rework, targeting llama-server

Author: Xuan Son Nguyen
Date:   2025-02-14 18:16:49 +01:00
parent  4ed4fe75ed
commit  f2e59a8eb9
10 changed files with 191 additions and 136 deletions
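Note: the diff below relies on a llama_batch_ptr smart-pointer wrapper and on new batch accessors (llama_batch_get_n_tokens, llama_batch_get_token_info, llama_batch_set_logits_last, llama_batch_get_view) introduced elsewhere in this commit; their definitions are not part of this file's hunks. As a rough sketch only — the deleter wiring and the pointer-taking free function below are assumptions, not the actual headers — the ownership model the server code depends on looks like this:

    // Sketch (assumption): llama_batch_ptr behaves like a std::unique_ptr with a
    // custom deleter, so the server calls .reset()/.get() instead of pairing
    // llama_batch_init() with manual llama_batch_free() calls.
    #include <memory>

    struct llama_batch;                            // opaque here
    void llama_batch_free(struct llama_batch * b); // assumed pointer-taking free in the reworked API

    struct llama_batch_deleter {
        void operator()(struct llama_batch * b) const { llama_batch_free(b); }
    };

    using llama_batch_ptr = std::unique_ptr<struct llama_batch, llama_batch_deleter>;

    // Usage pattern mirrored throughout the diff:
    //   batch.reset(llama_batch_init(n_batch, 1)); // allocate / reallocate
    //   llama_decode(ctx, batch.get());            // borrow the raw pointer
    //   // the batch is freed automatically when its owner is destroyed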


@@ -1215,7 +1215,7 @@ struct server_slot {
 // only used for completion/embedding/infill/rerank
 server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
-llama_batch batch_spec = {};
+llama_batch_ptr batch_spec;
 llama_context * ctx = nullptr;
 llama_context * ctx_dft = nullptr;
@@ -1787,7 +1787,7 @@ struct server_context {
 llama_context_params cparams_dft;
-llama_batch batch = {};
+llama_batch_ptr batch;
 bool clean_kv_cache = true;
 bool add_bos_token = true;
@@ -1820,11 +1820,7 @@ struct server_context {
 common_speculative_free(slot.spec);
 slot.spec = nullptr;
-llama_batch_free(slot.batch_spec);
 }
-llama_batch_free(batch);
 }
 bool load_model(const common_params & params) {
@@ -1944,7 +1940,7 @@ struct server_context {
 slot.n_predict = params_base.n_predict;
 if (model_dft) {
-slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));
 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
 if (slot.ctx_dft == nullptr) {
@@ -1969,7 +1965,7 @@ struct server_context {
 slot.reset();
-slots.push_back(slot);
+slots.push_back(std::move(slot));
 }
 default_generation_settings_for_props = slots[0].to_json();
@@ -1980,7 +1976,7 @@ struct server_context {
 const int32_t n_batch = llama_n_batch(ctx);
 // only a single seq_id per token is needed
-batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
 }
 metrics.init();
@@ -2098,9 +2094,7 @@ struct server_context {
 }
 if (slot.ctx_dft) {
-llama_batch_free(slot.batch_spec);
-slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
 }
 slot.state = SLOT_STATE_STARTED;
@@ -2408,7 +2402,7 @@ struct server_context {
 queue_results.send(std::move(res));
 }
-void send_embedding(const server_slot & slot, const llama_batch & batch) {
+void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
 auto res = std::make_unique<server_task_result_embd>();
 res->id = slot.id_task;
 res->index = slot.index;
@@ -2419,18 +2413,19 @@ struct server_context {
 std::vector<float> embd_res(n_embd, 0.0f);
-for (int i = 0; i < batch.n_tokens; ++i) {
-if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+if (!tok.logits || tok.seq_id[0] != slot.id) {
 continue;
 }
-const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
 if (embd == NULL) {
 embd = llama_get_embeddings_ith(ctx, i);
 }
 if (embd == NULL) {
-SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
 continue;
@@ -2451,24 +2446,25 @@ struct server_context {
 queue_results.send(std::move(res));
 }
-void send_rerank(const server_slot & slot, const llama_batch & batch) {
+void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
 auto res = std::make_unique<server_task_result_rerank>();
 res->id = slot.id_task;
 res->index = slot.index;
 res->n_tokens = slot.n_prompt_tokens;
-for (int i = 0; i < batch.n_tokens; ++i) {
-if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+if (!tok.logits || tok.seq_id[0] != slot.id) {
 continue;
 }
-const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
 if (embd == NULL) {
 embd = llama_get_embeddings_ith(ctx, i);
 }
 if (embd == NULL) {
-SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
 res->score = -1e6;
 continue;
@@ -2859,7 +2855,7 @@ struct server_context {
 }
 // start populating the batch for this iteration
-common_batch_clear(batch);
+common_batch_clear(batch.get());
 // track if given slot can be batched with slots already in the batch
 server_slot * slot_batched = nullptr;
@@ -2881,9 +2877,9 @@ struct server_context {
 continue;
 }
-slot.i_batch = batch.n_tokens;
+slot.i_batch = llama_batch_get_n_tokens(batch.get());
-common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);
 slot.n_past += 1;
@@ -2900,7 +2896,7 @@ struct server_context {
 int32_t n_ubatch = llama_n_ubatch(ctx);
 // next, batch any pending prompts without exceeding n_batch
-if (params_base.cont_batching || batch.n_tokens == 0) {
+if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
 for (auto & slot : slots) {
 // check if we can batch this slot with the previous one
 if (slot.is_processing()) {
@@ -3066,7 +3062,7 @@ struct server_context {
 // non-causal tasks require to fit the entire prompt in the physical batch
 if (slot.is_non_causal()) {
 // cannot fit the prompt in the current batch - will try next iter
-if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
 continue;
 }
 }
@@ -3086,11 +3082,11 @@ struct server_context {
 slot.cache_tokens.resize(slot.n_past);
 // add prompt tokens for processing in the current batch
-while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
 // without pooling, we want to output the embeddings for all the tokens in the batch
 const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
-common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
+common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
 if (slot.params.cache_prompt) {
 slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3100,13 +3096,13 @@ struct server_context {
 slot.n_past++;
 }
-SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 // entire prompt has been processed
 if (slot.n_past == slot.n_prompt_tokens) {
 slot.state = SLOT_STATE_DONE_PROMPT;
-GGML_ASSERT(batch.n_tokens > 0);
+GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);
 common_sampler_reset(slot.smpl);
@@ -3116,27 +3112,27 @@ struct server_context {
 }
 // extract the logits only for the last token
-batch.logits[batch.n_tokens - 1] = true;
+llama_batch_set_logits_last(batch.get());
 slot.n_decoded = 0;
-slot.i_batch = batch.n_tokens - 1;
+slot.i_batch = llama_batch_get_n_tokens(batch.get()) - 1;
-SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
 }
 }
-if (batch.n_tokens >= n_batch) {
+if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
 break;
 }
 }
 }
-if (batch.n_tokens == 0) {
+if (llama_batch_get_n_tokens(batch.get()) == 0) {
 SRV_WRN("%s", "no tokens to decode\n");
 return;
 }
-SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));
 if (slot_batched) {
 // make sure we're in the right embedding mode
@@ -3146,20 +3142,12 @@ struct server_context {
 }
 // process the created batch of tokens
-for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
+for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
+const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);
-llama_batch batch_view = {
-n_tokens,
-batch.token + i,
-nullptr,
-batch.pos + i,
-batch.n_seq_id + i,
-batch.seq_id + i,
-batch.logits + i,
-};
+llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));
-const int ret = llama_decode(ctx, batch_view);
+const int ret = llama_decode(ctx, batch_view.get());
 metrics.on_decoded(slots);
 if (ret != 0) {
@@ -3294,16 +3282,16 @@ struct server_context {
 }
 // construct the speculation batch
-common_batch_clear(slot.batch_spec);
-common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+common_batch_clear(slot.batch_spec.get());
+common_batch_add (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);
 for (size_t i = 0; i < draft.size(); ++i) {
-common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
 }
-SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));
-llama_decode(ctx, slot.batch_spec);
+llama_decode(ctx, slot.batch_spec.get());
 // the accepted tokens from the speculation
 const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);