diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b745dd044..057184764 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1205,6 +1205,47 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
 
+struct server_batch {
+    llama_batch_ext_ptr batch;
+    struct batch_token {
+        llama_token token;
+        llama_seq_id seq_id;
+        bool logits;
+    };
+    std::vector<batch_token> tokens;
+    server_batch() = default;
+    server_batch(int32_t n_tokens, int32_t n_seq_max) {
+        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
+        tokens.reserve(n_tokens);
+    }
+    void clear() {
+        llama_batch_ext_clear(batch.get());
+        tokens.clear();
+    }
+    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
+        tokens.push_back({token, seq_id, logits});
+    }
+    void set_logits_last() {
+        if (!tokens.empty()) {
+            llama_batch_ext_set_logits_last(batch.get());
+            tokens.back().logits = true;
+        }
+    }
+    int32_t get_n_tokens() const {
+        return (int32_t)tokens.size();
+    }
+    server_batch get_view(int32_t offset, int32_t n_tokens) {
+        server_batch view;
+        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
+        view.tokens.reserve(n_tokens);
+        for (int32_t i = 0; i < n_tokens; i++) {
+            view.tokens.push_back(tokens[offset + i]);
+        }
+        return view;
+    }
+};
+
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1212,7 +1253,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    llama_batch_ext_ptr batch_spec;
+    server_batch batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1784,7 +1825,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch_ext_ptr batch;
+    server_batch batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
@@ -1909,7 +1950,7 @@ struct server_context {
            slot.n_predict = params_base.n_predict;
 
            if (model_dft) {
-                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
+                slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
 
                slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                if (slot.ctx_dft == nullptr) {
@@ -1945,7 +1986,7 @@ struct server_context {
            const int32_t n_batch = llama_n_batch(ctx);
 
            // only a single seq_id per token is needed
-            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
+            batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
        }
 
        metrics.init();
@@ -2063,7 +2104,7 @@ struct server_context {
        }
 
        if (slot.ctx_dft) {
-            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
+            slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
        }
 
        slot.state = SLOT_STATE_STARTED;
@@ -2371,7 +2412,7 @@ struct server_context {
        queue_results.send(std::move(res));
    }
 
-    void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) {
+    void send_embedding(const server_slot & slot, server_batch & batch) {
        auto res = std::make_unique<server_task_result_embd>();
        res->id    = slot.id_task;
        res->index = slot.index;
@@ -2382,19 +2423,19 @@ struct server_context {
 
        std::vector<float> embd_res(n_embd, 0.0f);
 
-        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
-            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
-            if (!tok.logits || tok.seq_id[0] != slot.id) {
+        for (int i = 0; i < batch.get_n_tokens(); ++i) {
+            auto tok = batch.tokens[i];
+            if (!tok.logits || tok.seq_id != slot.id) {
                continue;
            }
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
            if (embd == NULL) {
                embd = llama_get_embeddings_ith(ctx, i);
            }
 
            if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
 
                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                continue;
@@ -2415,25 +2456,25 @@ struct server_context {
        queue_results.send(std::move(res));
    }
 
-    void send_rerank(const server_slot & slot, llama_batch_ext_ptr & batch) {
+    void send_rerank(const server_slot & slot, server_batch & batch) {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id       = slot.id_task;
        res->index    = slot.index;
        res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
-            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
-            if (!tok.logits || tok.seq_id[0] != slot.id) {
+        for (int i = 0; i < batch.get_n_tokens(); ++i) {
+            auto tok = batch.tokens[i];
+            if (!tok.logits || tok.seq_id != slot.id) {
                continue;
            }
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
            if (embd == NULL) {
                embd = llama_get_embeddings_ith(ctx, i);
            }
 
            if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
 
                res->score = -1e6;
                continue;
@@ -2824,7 +2865,7 @@ struct server_context {
        }
 
        // start populating the batch for this iteration
-        llama_batch_ext_clear(batch.get());
+        batch.clear();
 
        // track if given slot can be batched with slots already in the batch
        server_slot * slot_batched = nullptr;
@@ -2846,10 +2887,9 @@ struct server_context {
                continue;
            }
 
-            slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
+            slot.i_batch = batch.get_n_tokens();
 
-            std::array<llama_seq_id, 1> seq_id = { slot.id };
-            llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);
+            batch.add_text(slot.sampled, slot.n_past, slot.id, true);
 
            slot.n_past += 1;
 
@@ -2866,7 +2906,7 @@ struct server_context {
        int32_t n_ubatch = llama_n_ubatch(ctx);
 
        // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (params_base.cont_batching || batch.get_n_tokens() == 0) {
            for (auto & slot : slots) {
                // check if we can batch this slot with the previous one
                if (slot.is_processing()) {
@@ -3032,7 +3072,7 @@ struct server_context {
                    // non-causal tasks require to fit the entire prompt in the physical batch
                    if (slot.is_non_causal()) {
                        // cannot fit the prompt in the current batch - will try next iter
-                        if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
+                        if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }
@@ -3052,12 +3092,11 @@ struct server_context {
                    slot.cache_tokens.resize(slot.n_past);
 
                    // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
                        // without pooling, we want to output the embeddings for all the tokens in the batch
                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        std::array<llama_seq_id, 1> seq_id = { slot.id };
-                        llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);
+                        batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
 
                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3067,13 +3106,13 @@ struct server_context {
                        slot.n_past++;
                    }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                    // entire prompt has been processed
                    if (slot.n_past == slot.n_prompt_tokens) {
                        slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);
+                        GGML_ASSERT(batch.get_n_tokens() > 0);
 
                        common_sampler_reset(slot.smpl);
 
@@ -3083,27 +3122,27 @@ struct server_context {
                        }
 
                        // extract the logits only for the last token
-                        llama_batch_ext_set_logits_last(batch.get());
+                        batch.set_logits_last();
 
                        slot.n_decoded = 0;
-                        slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;
+                        slot.i_batch = batch.get_n_tokens() - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
                    }
                }
 
-                if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
+                if (batch.get_n_tokens() >= n_batch) {
                    break;
                }
            }
        }
 
-        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (batch.get_n_tokens() == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));
+        SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
 
        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3113,12 +3152,12 @@ struct server_context {
        }
 
        // process the created batch of tokens
-        for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);
+        for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
 
-            llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
+            server_batch batch_view = batch.get_view(i, n_tokens);
 
-            const int ret = llama_decode_ext(ctx, batch_view.get());
+            const int ret = llama_decode_ext(ctx, batch_view.batch.get());
            metrics.on_decoded(slots);
 
            if (ret != 0) {
@@ -3253,17 +3292,16 @@ struct server_context {
            }
 
            // construct the speculation batch
-            llama_batch_ext_clear(slot.batch_spec.get());
-            std::array<llama_seq_id, 1> seq_id = { slot.id };
-            llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);
+            slot.batch_spec.clear();
+            slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
 
            for (size_t i = 0; i < draft.size(); ++i) {
-                llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true);
+                slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
            }
 
-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
 
-            llama_decode_ext(ctx, slot.batch_spec.get());
+            llama_decode_ext(ctx, slot.batch_spec.batch.get());
 
            // the accepted tokens from the speculation
            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
diff --git a/include/llama.h b/include/llama.h
index dab1aea2b..d370e0523 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -263,14 +263,6 @@ extern "C" {
    // It can contain text tokens and embeddings for one or many sequences
    struct llama_batch_ext;
 
-    struct llama_batch_ext_token_info {
-        llama_token token;
-        llama_pos pos;
-        int32_t n_seq_id;
-        llama_seq_id * seq_id;
-        int8_t logits;
-    };
-
    enum llama_model_kv_override_type {
        LLAMA_KV_OVERRIDE_TYPE_INT,
        LLAMA_KV_OVERRIDE_TYPE_FLOAT,
@@ -896,10 +888,6 @@ extern "C" {
    // Get the number of tokens in the batch
    LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
 
-    LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
-            struct llama_batch_ext * batch,
-            int32_t i);
-
    // Add text tokens to the batch
    // Return values:
    //  0 : success
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index b63d4ec7f..d8117c3f0 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -480,19 +480,6 @@ struct llama_batch_ext * llama_batch_ext_get_view(
    return batch_view;
 }
 
-struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
-        struct llama_batch_ext * batch,
-        int32_t i) {
-    GGML_ASSERT(i >= 0 && i < batch->n_tokens);
-    return llama_batch_ext_token_info{
-        /*token    =*/ batch->token   [i],
-        /*pos      =*/ batch->pos     [i],
-        /*n_seq_id =*/ batch->n_seq_id[i],
-        /*seq_id   =*/ batch->seq_id  [i],
-        /*logits   =*/ batch->logits  [i],
-    };
-}
-
 void llama_batch_ext_free(struct llama_batch_ext * batch) {
    // do not free the members if it's a view
    if (!batch->is_view) {
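
For readers skimming the patch: since `llama_batch_ext_get_token_info` is removed from the public API, the server-side replacement keeps a shadow copy of the per-token metadata (token, seq_id, logits flag) next to the opaque batch and reads that copy back when extracting embeddings or rerank scores. The sketch below is a minimal, self-contained illustration of that pattern only; `opaque_batch`, `opaque_batch_add`, and `shadow_batch` are hypothetical stand-ins for `llama_batch_ext`, `llama_batch_ext_add_text`, and the `server_batch` wrapper introduced above, not real llama.cpp symbols.

```cpp
// Standalone sketch of the "shadow metadata" pattern used in the patch.
// The opaque batch (mocked here) no longer exposes per-token info, so the
// caller records token/seq_id/logits itself and consults that copy later.
#include <cstdint>
#include <cstdio>
#include <vector>

using token_id = int32_t;
using seq_id_t = int32_t;

struct opaque_batch {               // stand-in for the opaque llama_batch_ext
    int32_t n_tokens = 0;
};

static void opaque_batch_add(opaque_batch & b, token_id /*tok*/, int32_t /*pos*/,
                             seq_id_t /*seq*/, bool /*logits*/) {
    b.n_tokens++;                   // the real API stores the token data internally
}

struct shadow_batch {
    opaque_batch batch;

    struct token_info {
        token_id token;
        seq_id_t seq_id;
        bool     logits;
    };

    std::vector<token_info> tokens; // mirror of what was pushed into `batch`

    void add_text(token_id tok, int32_t pos, seq_id_t seq, bool logits) {
        opaque_batch_add(batch, tok, pos, seq, logits);
        tokens.push_back({tok, seq, logits});
    }

    int32_t get_n_tokens() const { return (int32_t) tokens.size(); }
};

int main() {
    shadow_batch b;
    b.add_text(/*tok=*/101, /*pos=*/0, /*seq=*/0, /*logits=*/false);
    b.add_text(/*tok=*/102, /*pos=*/1, /*seq=*/0, /*logits=*/true);

    // per-token metadata is read from the shadow copy, not from the batch
    for (int32_t i = 0; i < b.get_n_tokens(); ++i) {
        const auto & t = b.tokens[i];
        if (t.logits) {
            std::printf("token %d (seq %d) requested logits\n", (int) t.token, (int) t.seq_id);
        }
    }
    return 0;
}
```

The trade-off is the same one the diff makes: a small duplicated record per token in exchange for dropping a getter from the public batch API.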