diff --git a/common/common.h b/common/common.h
index c7e71bb29..86cad8655 100644
--- a/common/common.h
+++ b/common/common.h
@@ -565,6 +565,52 @@ void common_batch_add(
     const std::vector<llama_seq_id> & seq_ids,
     bool logits);
+// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
+// this is meant to be temporary
+struct common_batch {
+    llama_batch_ext_ptr batch;
+    struct batch_token {
+        llama_token token;
+        llama_seq_id seq_id;
+        bool logits;
+    };
+    std::vector<batch_token> tokens;
+    common_batch() = default;
+    common_batch(int32_t n_tokens, int32_t n_seq_max) {
+        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
+        tokens.reserve(n_tokens);
+    }
+    void clear() {
+        llama_batch_ext_clear(batch.get());
+        tokens.clear();
+    }
+    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
+        tokens.push_back({token, seq_id, logits});
+    }
+    void set_logits_last() {
+        if (!tokens.empty()) {
+            llama_batch_ext_set_logits_last(batch.get());
+            tokens.back().logits = true;
+        }
+    }
+    int32_t get_n_tokens() const {
+        return (int32_t)tokens.size();
+    }
+    llama_batch_ext * get() {
+        return batch.get();
+    }
+    common_batch get_view(int32_t offset, int32_t n_tokens) {
+        common_batch view;
+        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
+        view.tokens.reserve(n_tokens);
+        for (int32_t i = 0; i < n_tokens; i++) {
+            view.tokens.push_back(tokens[offset + i]);
+        }
+        return view;
+    }
+};
+
 //
 // Token utils
 //
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 0659ab6f1..829bf7f94 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -59,24 +59,17 @@ int main(int argc, char ** argv) {
     const int32_t n_kv_max = llama_n_ctx(ctx);
-    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
     // decode in batches of ctx_params.n_batch tokens
-    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+    auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
+        const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch);
+        for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i));
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token + i,
-                nullptr,
-                batch.pos + i,
-                batch.n_seq_id + i,
-                batch.seq_id + i,
-                batch.logits + i,
-            };
+            llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens));
-            const int ret = llama_decode(ctx, batch_view);
+            const int ret = llama_decode_ext(ctx, batch_view.get());
             if (ret != 0) {
                 LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
                 return false;
@@ -91,7 +84,8 @@ int main(int argc, char ** argv) {
     // warm up
     {
         for (int i = 0; i < 16; ++i) {
-            common_batch_add(batch, 0, i, { 0 }, false);
+            const llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
         }
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -121,14 +115,14 @@ int main(int argc, char ** argv) {
            continue;
        }
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch);
        for (int i = 0; i < pp; ++i) {
            for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
-                common_batch_add(batch, 0, i, { j }, false);
+                llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
            }
        }
-        batch.logits[batch.n_tokens - 1] = true;
+        llama_batch_ext_set_logits_last(batch);
        const auto t_pp_start = ggml_time_us();
@@ -150,10 +144,10 @@ int main(int argc, char ** argv) {
        const auto t_tg_start = ggml_time_us();
        for (int i = 0; i < tg; ++i) {
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch);
            for (int j = 0; j < pl; ++j) {
-                common_batch_add(batch, 0, pp + i, { j }, true);
+                llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, false);
            }
            if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -191,7 +185,7 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
     llama_free(ctx);
     llama_model_free(model);
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 21b95ef5e..858053a88 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
     // create a llama_batch
     // we use this object to submit token data for decoding
-    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
+    llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
     std::vector<llama_seq_id> seq_ids(n_parallel, 0);
     for (int32_t i = 0; i < n_parallel; ++i) {
@@ -111,12 +111,12 @@ int main(int argc, char ** argv) {
     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); ++i) {
-        common_batch_add(batch, tokens_list[i], i, seq_ids, false);
+        llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false);
     }
-    GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
+    GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size());
     if (llama_model_has_encoder(model)) {
-        if (llama_encode(ctx, batch)) {
+        if (llama_encode_ext(ctx, batch)) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return 1;
        }
@@ -126,14 +126,14 @@ int main(int argc, char ** argv) {
            decoder_start_token_id = llama_vocab_bos(vocab);
        }
-        common_batch_clear(batch);
-        common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
+        llama_batch_ext_clear(batch);
+        llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false);
     }
     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    llama_batch_ext_set_logits_last(batch);
-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode_ext(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode() failed\n", __func__);
        return 1;
    }
@@ -155,16 +155,16 @@ int main(int argc, char ** argv) {
    // remember the batch index of the last token for each parallel sequence
    // we need this to determine which logits to sample from
-    std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
+    std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
-    int n_cur = batch.n_tokens;
+    int n_cur = llama_batch_ext_get_n_tokens(batch);
    int n_decode = 0;
    const auto t_main_start = ggml_time_us();
    while (n_cur <= n_predict) {
        // prepare the next batch
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch);
        // sample the next token for each parallel sequence / stream
        for (int32_t i = 0; i < n_parallel; ++i) {
@@ -193,23 +193,23 @@ int main(int argc, char ** argv) {
            streams[i] += common_token_to_piece(ctx, new_token_id);
-            i_batch[i] = batch.n_tokens;
+            i_batch[i] = llama_batch_ext_get_n_tokens(batch);
            // push this new token for next evaluation
-            common_batch_add(batch, new_token_id, n_cur, { i }, true);
+            llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, false);
            n_decode += 1;
        }
        // all streams are finished
-        if (batch.n_tokens == 0) {
+        if (llama_batch_ext_get_n_tokens(batch) == 0) {
            break;
        }
        n_cur += 1;
        // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch)) {
            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
@@ -234,7 +234,7 @@ int main(int argc, char ** argv) {
    fprintf(stderr, "\n");
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
    llama_sampler_free(smpl);
    llama_free(ctx);
diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp
index 413b71d34..689e3e539 100644
--- a/examples/cvector-generator/cvector-generator.cpp
+++ b/examples/cvector-generator/cvector-generator.cpp
@@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_cache_clear(ctx);
-    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    if (llama_decode_ext(ctx, batch.get())) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return false;
    }
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 38d22c90f..c71200958 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -25,14 +25,14 @@ static std::vector<std::string> split_lines(const std::string & s, const std::st
     return lines;
 }
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
     size_t n_tokens = tokens.size();
     for (size_t i = 0; i < n_tokens; i++) {
-        common_batch_add(batch, tokens[i], i, { seq_id }, true);
+        batch.add_text(tokens[i], i, seq_id, true);
     }
 }
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
     const struct llama_model * model = llama_get_model(ctx);
@@ -40,21 +40,21 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     llama_kv_cache_clear(ctx);
     // run model
-    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
+    LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
     if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
        // encoder-only model
-        if (llama_encode(ctx, batch) < 0) {
+        if (llama_encode_ext(ctx, batch.get()) < 0) {
            LOG_ERR("%s : failed to encode\n", __func__);
        }
     } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
        // decoder-only model
-        if (llama_decode(ctx, batch) < 0) {
+        if (llama_decode_ext(ctx, batch.get()) < 0) {
            LOG_ERR("%s : failed to decode\n", __func__);
        }
     }
-    for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+    for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
+        if (!batch.tokens[i].logits) {
            continue;
        }
@@ -68,8 +68,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
        } else {
            // try to get sequence embeddings - supported only when pooling_type is not NONE
-            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-            embd_pos = batch.seq_id[i][0];
+            embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
+            embd_pos = batch.tokens[i].seq_id;
            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
        }
@@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
     // initialize batch
     const int n_prompts = prompts.size();
-    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    struct common_batch batch = common_batch(n_batch, 1);
     // count number of embeddings
     int n_embd_count = 0;
@@ -197,12 +197,12 @@ int main(int argc, char ** argv) {
        const uint64_t n_toks = inp.size();
        // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        if (batch.get_n_tokens() + n_toks > n_batch) {
            float * out = emb + e * n_embd;
            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s;
            s = 0;
-            common_batch_clear(batch);
+            batch.clear();
        }
        // add to batch
@@ -318,7 +318,6 @@ int main(int argc, char ** argv) {
     llama_perf_context_print(ctx);
     // clean up
-    llama_batch_free(batch);
     llama_backend_free();
     return 0;
diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index fb188f5a9..7e600440d 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) {
     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
-    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    if (llama_decode_ext(ctx, batch.get())) {
        LOG_ERR("%s : failed to eval\n", __func__);
        return false;
    }
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 72eb46257..aa87c3a27 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -13,10 +13,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
-    llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1);
     for (uint64_t i = 0; i < sentences.size(); i++) {
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch);
        const std::string input_string = instruction + sentences[i];
@@ -41,7 +41,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
        // add input to batch (this increments n_tokens)
        for (int32_t j = 0; j < n_toks; j++) {
-            common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
+            const llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst);
        }
        // clear previous kv_cache values (irrelevant for embeddings)
@@ -50,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
        llama_set_causal_attn(ctx, false);
        // run model
-        llama_decode(ctx, batch);
+        llama_decode_ext(ctx, batch);
        // get embedding dimensions
        uint64_t n_embd = llama_model_n_embd(model);
@@ -89,7 +90,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 #endif
     }
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
     return result;
 }
@@ -106,25 +107,26 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
-    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
+    llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1);
     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
     int32_t i_current_token = 0;
     while (true) {
-        common_batch_clear(bat);
+        llama_batch_ext_clear(bat);
        {
            const int32_t n_inputs = inputs.size();
            for (int32_t i = 0; i < n_inputs; i++) {
-                common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+                const llama_seq_id seq_id = 0;
+                llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1);
            }
        }
        inputs.clear();
-        llama_decode(ctx, bat);
+        llama_decode_ext(ctx, bat);
-        llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
+        llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1);
        if (token == eos_token) {
            break;
@@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
        std::printf("\n");
     }
-    llama_batch_free(bat);
+    llama_batch_ext_free(bat);
     return result;
 }
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index 4edc0bfac..86f7ccbc3 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -500,7 +500,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
        // clear the KV cache
        llama_kv_cache_clear(ctx);
-        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+        llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = start + j * n_batch;
@@ -514,14 +514,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
                tokens[batch_start] = llama_vocab_bos(vocab);
            }
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch);
            for (int i = 0; i < batch_size; i++) {
-                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
+                const llama_seq_id seq_id = 0;
+                llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
            }
-            if (llama_decode(ctx, batch)) {
+            if (llama_decode_ext(ctx, batch)) {
                LOG_ERR("%s : failed to eval\n", __func__);
-                llama_batch_free(batch);
+                llama_batch_ext_free(batch);
                return false;
            }
@@ -534,7 +535,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
            }
        }
-        llama_batch_free(batch);
+        llama_batch_ext_free(batch);
        const auto t_end = std::chrono::high_resolution_clock::now();
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 489a208b6..738fd6e11 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -353,7 +353,8 @@ int main(int argc, char ** argv) {
            LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+            if (llama_decode_ext(ctx, batch.get())) {
                LOG_ERR("%s : failed to eval\n", __func__);
                return 1;
            }
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index f518d02d3..f270cce69 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
        for (int i = 1; i < n_tokens; i++) {
            tokens[i] = std::rand() % n_vocab;
        }
-        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0));
+        llama_decode_ext(ctx, batch.get());
        n_processed += n_tokens;
     }
@@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
     for (int i = 0; i < n_gen; i++) {
-        llama_decode(ctx, llama_batch_get_one(&token, 1));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0));
+        llama_decode_ext(ctx, batch.get());
        llama_synchronize(ctx);
        token = std::rand() % n_vocab;
     }
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index dbd0444ec..fee09adcd 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -91,8 +91,10 @@ int main(int argc, char ** argv){
     const auto t_enc_start = ggml_time_us();
-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
-    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_decode_ext(ctx, batch0.get());
+    llama_decode_ext(ctx, batch1.get());
     const auto t_enc_end = ggml_time_us();
@@ -108,7 +110,7 @@ int main(int argc, char ** argv){
     std::vector<llama_token> draft;
-    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
+    llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1);
     // debug
     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
@@ -194,8 +196,9 @@ int main(int argc, char ** argv){
            // clean the cache of draft tokens that weren't accepted
            llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
-            common_batch_clear(batch_tgt);
-            common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
+            const llama_seq_id seq_id = 0;
+            llama_batch_ext_clear(batch_tgt);
+            llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true);
            // Draft already contains a single token sampled from the model:
            GGML_ASSERT(draft.size() == 1);
@@ -205,13 +208,13 @@ int main(int argc, char ** argv){
                common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
                for (size_t i = 1; i < draft.size(); ++i) {
-                    common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+                    llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
                }
                t_draft_us += ggml_time_us() - t_start_draft_us;
                n_drafted += draft.size() - 1;
-            llama_decode(ctx, batch_tgt);
+            llama_decode_ext(ctx, batch_tgt);
            ++n_past;
            draft.erase(draft.begin());
@@ -243,7 +246,7 @@ int main(int argc, char ** argv){
     common_sampler_free(smpl);
-    llama_batch_free(batch_tgt);
+    llama_batch_ext_free(batch_tgt);
     llama_backend_free();
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 057184764..22d2c6e92 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1205,47 +1205,6 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
-struct server_batch {
-    llama_batch_ext_ptr batch;
-    struct batch_token {
-        llama_token token;
-        llama_seq_id seq_id;
-        bool logits;
-    };
-    std::vector<batch_token> tokens;
-    server_batch() = default;
-    server_batch(int32_t n_tokens, int32_t n_seq_max) {
-        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
-        tokens.reserve(n_tokens);
-    }
-    void clear() {
-        llama_batch_ext_clear(batch.get());
-        tokens.clear();
-    }
-    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
-        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
-        tokens.push_back({token, seq_id, logits});
-    }
-    void set_logits_last() {
-        if (!tokens.empty()) {
-            llama_batch_ext_set_logits_last(batch.get());
-            tokens.back().logits = true;
-        }
-    }
-    int32_t get_n_tokens() const {
-        return (int32_t)tokens.size();
-    }
-    server_batch get_view(int32_t offset, int32_t n_tokens) {
-        server_batch view;
-        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
-        view.tokens.reserve(n_tokens);
-        for (int32_t i = 0; i < n_tokens; i++) {
-            view.tokens.push_back(tokens[offset + i]);
-        }
-        return view;
-    }
-};
-
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1253,7 +1212,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
-    server_batch batch_spec;
+    common_batch batch_spec;
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1825,7 +1784,7 @@ struct server_context {
     llama_context_params cparams_dft;
-    server_batch batch;
+    common_batch batch;
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1950,7 +1909,7 @@ struct server_context {
            slot.n_predict = params_base.n_predict;
            if (model_dft) {
-                slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
+                slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
                slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                if (slot.ctx_dft == nullptr) {
@@ -1986,7 +1945,7 @@ struct server_context {
            const int32_t n_batch = llama_n_batch(ctx);
            // only a single seq_id per token is needed
-            batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
+            batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
        }
        metrics.init();
@@ -2104,7 +2063,7 @@ struct server_context {
        }
        if (slot.ctx_dft) {
-            slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
+            slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
        }
        slot.state = SLOT_STATE_STARTED;
@@ -2412,7 +2371,7 @@ struct server_context {
        queue_results.send(std::move(res));
     }
-    void send_embedding(const server_slot & slot, server_batch & batch) {
+    void send_embedding(const server_slot & slot, common_batch & batch) {
        auto res = std::make_unique<server_task_result_embd>();
        res->id = slot.id_task;
        res->index = slot.index;
@@ -2456,7 +2415,7 @@ struct server_context {
        queue_results.send(std::move(res));
     }
-    void send_rerank(const server_slot & slot, server_batch & batch) {
+    void send_rerank(const server_slot & slot, common_batch & batch) {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id = slot.id_task;
        res->index = slot.index;
@@ -3155,9 +3114,9 @@ struct server_context {
            for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
                const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
-                server_batch batch_view = batch.get_view(i, n_tokens);
+                common_batch batch_view = batch.get_view(i, n_tokens);
-                const int ret = llama_decode_ext(ctx, batch_view.get());
+                const int ret = llama_decode_ext(ctx, batch_view.get());
                metrics.on_decoded(slots);
                if (ret != 0) {
@@ -3301,7 +3260,7 @@ struct server_context {
            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
-            llama_decode_ext(ctx, slot.batch_spec.batch.get());
+            llama_decode_ext(ctx, slot.batch_spec.get());
            // the accepted tokens from the speculation
            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
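
Note (not part of the diff): the migration above consistently replaces direct llama_batch construction and the old `batch.logits[batch.n_tokens - 1] = true;` idiom with the llama_batch_ext_* calls. Below is a minimal sketch of the new decode path, using only the functions that appear in this diff; the helper name decode_prompt_ext is hypothetical and the exact upstream signatures may differ from what is assumed here.

#include "llama.h"

#include <vector>

// Feed a prompt on sequence 0 and request logits only for the last token,
// mirroring the pattern used by the examples in this diff.
static bool decode_prompt_ext(llama_context * ctx, const std::vector<llama_token> & prompt) {
    const llama_seq_id seq_id = 0;

    // one slot per prompt token, a single sequence
    llama_batch_ext * batch = llama_batch_ext_init((int32_t) prompt.size(), 1);

    llama_batch_ext_clear(batch);
    for (size_t i = 0; i < prompt.size(); ++i) {
        // token at position i, one sequence id, no logits for intermediate tokens
        llama_batch_ext_add_text(batch, prompt[i], (llama_pos) i, &seq_id, 1, false);
    }

    // replaces the old batch.logits[batch.n_tokens - 1] = true;
    llama_batch_ext_set_logits_last(batch);

    const int ret = llama_decode_ext(ctx, batch);

    llama_batch_ext_free(batch);
    return ret == 0;
}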