server : avoid common_batch

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-03-20 16:52:24 +02:00
parent 76fd7d6f5b
commit 8a23b4a54a
3 changed files with 91 additions and 137 deletions

View File

@ -565,70 +565,6 @@ std::pair<std::string, std::string> common_get_hf_file(
// clear LoRA adapters from context, then apply new list of adapters
//   ctx  - context whose LoRA adapters are replaced
//   lora - adapters to apply once the previous ones have been cleared
// NOTE(review): lora is taken by non-const reference; whether the implementation
//               modifies its elements is not visible from this declaration - confirm
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
//
// Batch utils
//
// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
// this is meant to be temporary
//
// Alongside the underlying batch, the wrapper keeps a parallel record (`tokens`) of the
// token, first sequence id and output flag of every entry added through it.
// The member functions keep `n_outputs` equal to the number of entries in `tokens`
// whose `logits` flag is set.
struct common_batch {
    llama_batch_ext_ptr batch;

    // per-entry record mirrored from what was added to `batch`
    struct batch_token {
        llama_token  token;
        llama_seq_id seq_id; // only support single seq for now
        bool         logits;
    };

    std::vector<batch_token> tokens;
    int n_outputs = 0; // number of entries in `tokens` with logits == true

    common_batch() = default;

    // allocate the underlying batch via llama_batch_ext_init(n_tokens, n_seq_max)
    // and reserve room for n_tokens records
    common_batch(int32_t n_tokens, int32_t n_seq_max) {
        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
        tokens.reserve(n_tokens);
    }

    // remove all entries from the batch and from the mirrored records
    void clear() {
        llama_batch_ext_clear(batch.get());
        tokens.clear();
        // the output counter must be reset together with the records,
        // otherwise a reused batch reports outputs it no longer contains
        n_outputs = 0;
    }

    // append one token that belongs to a single sequence
    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
        tokens.push_back({token, seq_id, logits});
        if (logits) {
            n_outputs++;
        }
    }

    // append one token that belongs to several sequences
    // only the first sequence id is mirrored in `tokens`
    // seq_ids must not be empty: seq_ids[0] is read unconditionally
    void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
        llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
        tokens.push_back({token, seq_ids[0], logits});
        if (logits) {
            n_outputs++;
        }
    }

    // flag the last added entry as an output; does nothing when no entry was added
    void set_logits_last() {
        if (tokens.empty()) {
            return;
        }

        llama_batch_ext_set_output_last(batch.get());

        batch_token & last = tokens.back();
        if (!last.logits) {
            // count the entry only once, even if it already was an output
            last.logits = true;
            n_outputs++;
        }
    }

    // number of entries added through this wrapper
    int32_t get_n_tokens() const {
        return static_cast<int32_t>(tokens.size());
    }

    // raw pointer to the underlying batch; ownership stays with this object
    llama_batch_ext * get() {
        return batch.get();
    }

    // build a wrapper over llama_batch_ext_get_view(batch, offset, n_tokens) together
    // with a copy of the matching records; n_outputs of the view is recomputed from them
    // offset and n_tokens are not range-checked: the caller must make sure that
    // [offset, offset + n_tokens) lies within the recorded entries
    common_batch get_view(int32_t offset, int32_t n_tokens) {
        common_batch view;
        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
        view.tokens.reserve(n_tokens);
        for (int32_t i = 0; i < n_tokens; i++) {
            const batch_token & entry = tokens[offset + i];
            view.tokens.push_back(entry);
            if (entry.logits) {
                view.n_outputs++;
            }
        }
        return view;
    }
};
//
// Token utils
//