From 47086fa82d24b8d39ba4e4ecdc09927c721055ad Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 22:36:27 +0100 Subject: [PATCH] apply to the rest --- common/common.cpp | 37 ------- common/common.h | 18 +++- examples/llava/gemma3-cli.cpp | 58 +++-------- examples/llava/llava.cpp | 38 +------ examples/llava/qwen2vl-cli.cpp | 1 + examples/lookahead/lookahead.cpp | 21 ++-- examples/parallel/parallel.cpp | 50 +++++----- examples/passkey/passkey.cpp | 32 +++--- examples/perplexity/perplexity.cpp | 98 +++++++------------ examples/retrieval/retrieval.cpp | 61 +++++++----- examples/run/run.cpp | 10 +- examples/save-load-state/save-load-state.cpp | 39 ++++---- examples/simple-chat/simple-chat.cpp | 13 ++- examples/simple/simple.cpp | 14 ++- .../speculative-simple/speculative-simple.cpp | 12 ++- examples/speculative/speculative.cpp | 19 ++-- examples/tts/tts.cpp | 36 +++---- include/llama.h | 8 +- 18 files changed, 242 insertions(+), 323 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8eb65053c..ec4bf699a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -582,43 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector & values); std::string string_from(const struct llama_context * ctx, const std::vector & tokens); -std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch); // // Filesystem utils @@ -587,10 +586,10 @@ struct common_batch { llama_batch_ext_ptr batch; struct batch_token { llama_token token; - llama_seq_id seq_id; bool logits; }; std::vector tokens; + int n_outputs = 0; common_batch() = default; common_batch(int32_t n_tokens, int32_t n_seq_max) { batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); @@ -602,7 +601,17 @@ struct common_batch { } void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); - tokens.push_back({token, seq_id, logits}); + tokens.push_back({token, logits}); + if (logits) { + n_outputs++; + } + } + void add_text(llama_token token, llama_pos pos, std::vector seq_ids, bool logits) { + llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); + tokens.push_back({token, logits}); + if (logits) { + n_outputs++; + } } void set_logits_last() { if (!tokens.empty()) { @@ -622,6 +631,9 @@ struct common_batch { view.tokens.reserve(n_tokens); for (int32_t i = 0; i < n_tokens; i++) { view.tokens.push_back(tokens[offset + i]); + if (tokens[offset + i].logits) { + view.n_outputs++; + } } return view; } diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index c36bb2eda..9aa710652 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -5,6 +5,7 @@ #include "clip.h" #include "stb_image.h" #include "llama.h" +#include "llama-cpp.h" #include "ggml.h" #include "console.h" @@ -63,7 +64,7 @@ struct gemma3_context { llama_model * model; llama_context * lctx; const llama_vocab * vocab; - llama_batch batch; + llama_batch_ext_ptr batch; int n_threads = 1; llama_pos n_past = 0; @@ -73,7 +74,7 @@ struct gemma3_context { lctx = llama_init.context.get(); vocab = llama_model_get_vocab(model); n_threads = params.cpuparams.n_threads; - batch = llama_batch_init(params.n_batch, 0, 1); + batch.reset(llama_batch_ext_init(params.n_batch, 1)); init_clip_model(params); } @@ -87,50 +88,18 @@ struct gemma3_context { } }; -struct decode_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) { llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); - common_batch_clear(ctx.batch); + llama_batch_ext_clear(ctx.batch.get()); for (llama_token & t : tokens) { - common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false); } if (logits_last) { - ctx.batch.logits[ctx.batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(ctx.batch.get()); } // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str()); - if (llama_decode(ctx.lctx, ctx.batch)) { + if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { LOG_ERR("Failed to decode text\n"); return 1; } @@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) { int64_t t1 = ggml_time_ms(); eval_text(ctx, ""); llama_set_causal_attn(ctx.lctx, false); - decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0); - if (llama_decode(ctx.lctx, batch_img.batch)) { + llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0)); + if (llama_decode_ext(ctx.lctx, batch_img.get())) { LOG_ERR("failed to decode image\n"); return 1; } @@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_ fflush(stdout); // eval the token - common_batch_clear(ctx.batch); - common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true); - if (llama_decode(ctx.lctx, ctx.batch)) { + llama_batch_ext_clear(ctx.batch.get()); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true); + if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { LOG_ERR("failed to decode token\n"); return 1; } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 518aad3f1..53ce30215 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -2,6 +2,7 @@ #include "llava.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co return true; } -struct llava_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); @@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ n_eval = n_batch; } float * embd = image_embed->embed+i*n_embd; - llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); - if (llama_decode(ctx_llama, llava_batch.batch)) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0)); + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 132a7da54..d65e88f9d 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -66,6 +66,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); + // TODO: move this to llama_batch_ext API llama_batch batch = { int32_t(n_eval), // n_tokens nullptr, // token diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 7df20aee1..1c2c3ec46 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -115,7 +115,7 @@ int main(int argc, char ** argv) { // seq_id == 0 : the current input token // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations // seq_id [W + 1, W + G] : verification n-grams - llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); + llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1); // target model sampling context struct common_sampler * smpl = common_sampler_init(model, params.sampling); @@ -204,10 +204,10 @@ int main(int argc, char ** argv) { // V V V V V V // id { - common_batch_clear(batch); + llama_batch_ext_clear(batch); // current token - first token of the first level - common_batch_add(batch, id, n_past, seq_id_all, true); + llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true); // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation { @@ -230,9 +230,10 @@ int main(int argc, char ** argv) { const llama_token t = ngrams_observed.tokens[idx + j]; ngrams_cur[g].tokens [j + 1] = t; - ngrams_cur[g].i_batch[j + 1] = batch.n_tokens; + ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch); - common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); + llama_seq_id seq_id = W + 1 + g; + llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true); } } } @@ -244,18 +245,20 @@ int main(int argc, char ** argv) { seq_id_look[j] = i + j + 1; } - common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); + llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i, + seq_id_look.data(), seq_id_look.size(), false); } // fill the rest of the levels for (int j = 1; j < N - 1; j++) { for (int i = 0; i < W; i++) { - common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); + llama_seq_id seq_id = i + 1; + llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2); } } } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch) != 0) { LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__); return 1; } @@ -475,7 +478,7 @@ int main(int argc, char ** argv) { llama_kv_cache_view_free(&kvc_view); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_backend_free(); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 588632f04..1d5f59f7d 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -174,7 +174,7 @@ int main(int argc, char ** argv) { // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time - llama_batch batch = llama_batch_init(n_ctx, 0, 1); + llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1); int32_t n_total_prompt = 0; int32_t n_total_gen = 0; @@ -192,10 +192,11 @@ int main(int argc, char ** argv) { LOG_INF("%s: Evaluating the system prompt ...\n", __func__); for (int32_t i = 0; i < n_tokens_system; ++i) { - common_batch_add(batch, tokens_system[i], i, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -216,7 +217,7 @@ int main(int argc, char ** argv) { common_kv_cache_dump_view_seqs(kvc_view, 40); } - common_batch_clear(batch); + llama_batch_ext_clear(batch); // decode any currently ongoing sequences for (auto & client : clients) { @@ -224,14 +225,15 @@ int main(int argc, char ** argv) { continue; } - client.i_batch = batch.n_tokens; + client.i_batch = llama_batch_ext_get_n_tokens(batch); - common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); + llama_seq_id seq_id = client.id + 1; + llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true); client.n_decoded += 1; } - if (batch.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch) == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { llama_kv_self_seq_rm(ctx, i, -1, -1); @@ -243,7 +245,7 @@ int main(int argc, char ** argv) { } // insert new sequences for decoding - if (cont_batching || batch.n_tokens == 0) { + if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) { for (auto & client : clients) { if (client.seq_id == -1 && g_seq_id < n_seq) { client.seq_id = g_seq_id; @@ -262,17 +264,18 @@ int main(int argc, char ** argv) { tokens_prompt = common_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); + llama_seq_id seq_id = client.id + 1; + llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false); } // extract the logits only for the last token - if (batch.n_tokens > 0) { - batch.logits[batch.n_tokens - 1] = true; + if (llama_batch_ext_get_n_tokens(batch) > 0) { + llama_batch_ext_set_output_last(batch); } client.n_prompt = tokens_prompt.size(); client.n_decoded = 0; - client.i_batch = batch.n_tokens - 1; + client.i_batch = llama_batch_ext_get_n_tokens(batch) - 1; LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); @@ -286,14 +289,15 @@ int main(int argc, char ** argv) { } } - if (batch.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch) == 0) { break; } // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch); + for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { // n_batch /= 2; @@ -301,19 +305,11 @@ int main(int argc, char ** argv) { // continue; //} - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i)); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); + llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens); + const int ret = llama_decode_ext(ctx, batch_view); + llama_batch_ext_free(batch_view); if (ret != 0) { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size @@ -417,7 +413,7 @@ int main(int argc, char ** argv) { // TODO: print sampling/grammar timings for all clients llama_perf_context_print(ctx); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_backend_free(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index ea3a6c1fc..88e6ccdde 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -122,7 +123,7 @@ int main(int argc, char ** argv) { LOG_INF("prompt tokens: %d\n", n_tokens_all); //LOG_INF("prompt: %s\n", params.prompt.c_str()); - llama_batch batch = llama_batch_init(params.n_batch, 0, 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1)); int n_past = 0; @@ -140,17 +141,18 @@ int main(int argc, char ** argv) { n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; } - common_batch_clear(batch); + llama_batch_ext_clear(batch.get()); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false); } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch.get()) != 0) { LOG_INF("%s: llama_decode() failed\n", __func__); return 1; } @@ -174,17 +176,18 @@ int main(int argc, char ** argv) { n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; - common_batch_clear(batch); + llama_batch_ext_clear(batch.get()); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false); } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -223,7 +226,7 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1); // is it an end of generation? if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { @@ -237,16 +240,17 @@ int main(int argc, char ** argv) { n_decode += 1; // prepare the next batch - common_batch_clear(batch); + llama_batch_ext_clear(batch.get()); // push this new token for next evaluation - common_batch_add(batch, new_token_id, n_past++, { 0 }, true); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true); } n_cur += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } @@ -266,8 +270,6 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl); - llama_batch_free(batch); - llama_free(ctx); llama_model_free(model); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 8c413f7d6..d24fddbf4 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // clear the KV cache llama_kv_self_clear(ctx); - llama_batch batch = llama_batch_init(n_batch, 0, 1); + common_batch batch(n_batch, 1); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - common_batch_clear(batch); + batch.clear(); for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); } //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { //LOG_ERR("%s : failed to eval\n", __func__); - llama_batch_free(batch); return {tokens, -1, logit_history, prob_history}; } @@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params } } - llama_batch_free(batch); - const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { @@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); GGML_ASSERT(params.n_ctx == n_seq * n_ctx); - llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); + common_batch batch(std::min(n_batch, n_ctx*n_seq), 1); std::vector logits; if (num_batches > 1) { @@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & int n_outputs = 0; - batch.n_tokens = 0; + batch.clear(); for (int seq = 0; seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -569,21 +566,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & for (int k = 0; k < batch_size; ++k) { const int idx = seq*n_ctx + k; - batch.token [idx] = tokens[seq_start + k]; - batch.pos [idx] = j*n_batch + k; - batch.n_seq_id[idx] = 1; - batch.seq_id [idx][0] = seq; - batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; + const llama_pos pos = j*n_batch + k; + bool output = pos >= first; + batch.add_text(tokens[seq_start + k], pos, seq, output); - n_outputs += batch.logits[idx] != 0; + n_outputs += output ? 1 : 0; } - batch.n_tokens += batch_size; // restore the original token in case it was set to BOS tokens[seq_start] = token_org; } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_INF("%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -653,36 +647,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); } - llama_batch_free(batch); - return {tokens, ppl, logit_history, prob_history}; } -static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector & batch_logits, int n_batch, int n_vocab) { +static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector & batch_logits, int n_batch, int n_vocab) { int prev_outputs = 0; - for (int i = 0; i < (int) batch.n_tokens; i += n_batch) { - const int n_tokens = std::min(n_batch, batch.n_tokens - i); + for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) { + const int n_tokens = std::min(n_batch, batch.get_n_tokens() - i); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + common_batch batch_view = batch.get_view(i, n_tokens); - const int ret = llama_decode(ctx, batch_view); + const int ret = llama_decode_ext(ctx, batch_view.get()); if (ret != 0) { LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); return false; } - int n_outputs = 0; - for (int i = 0; i < n_tokens; ++i) { - n_outputs += batch_view.logits[i] != 0; - } + int n_outputs = batch_view.n_outputs; memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); @@ -863,7 +844,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, 4); + common_batch batch(n_ctx, 4); std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size @@ -879,7 +860,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - common_batch_clear(batch); + batch.clear(); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -895,9 +876,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { } for (size_t i = 0; i < hs_cur.common_prefix; ++i) { - common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); + batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + llama_batch_ext_set_output_last(batch.get()); n_logits += 1; for (int s = 0; s < 4; ++s) { @@ -905,7 +886,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { // TODO: don't evaluate the last token of each sequence for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); + batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); n_logits += needs_logits; } } @@ -992,8 +973,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { i0 = i1 - 1; } - llama_batch_free(batch); - LOG("\n"); } @@ -1147,7 +1126,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) const int max_tasks_per_batch = 128; const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, 2); + common_batch batch(n_ctx, 2); std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size @@ -1166,7 +1145,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) size_t i1 = i0; size_t i_logits = 0; - common_batch_clear(batch); + batch.clear(); while (n_cur + (int) data[i1].required_tokens <= n_ctx) { int n_logits = 0; @@ -1176,15 +1155,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params) } for (size_t i = 0; i < data[i1].common_prefix; ++i) { - common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); + batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); } - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); n_logits += 1; for (int s = 0; s < 2; ++s) { // TODO: end before the last token, no need to predict past the end of the sequences for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { - common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); + batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true); n_logits += 1; } } @@ -1501,7 +1480,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); + common_batch batch(n_ctx, max_seq); std::vector tok_logits(n_vocab); std::vector batch_logits(size_t(n_ctx)*n_vocab); @@ -1521,7 +1500,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - common_batch_clear(batch); + batch.clear(); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -1544,9 +1523,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par for (size_t i = 0; i < cur_task.common_prefix; ++i) { //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); - common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); + batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { @@ -1554,7 +1533,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par // TODO: don't evaluate the last token of each sequence for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); + batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); n_logits += needs_logits; } } @@ -1653,8 +1632,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par i0 = i1 - 1; } - llama_batch_free(batch); - if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return; float p = 1.f*n_correct/n_done; @@ -1767,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_self_clear(ctx); - llama_batch batch = llama_batch_init(n_batch, 0, 1); + common_batch batch(n_batch, 1); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -1781,14 +1758,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_vocab_bos(vocab); } - common_batch_clear(batch); + batch.clear(); for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true); } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); - llama_batch_free(batch); return; } @@ -1801,8 +1777,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } } - llama_batch_free(batch); - const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 0efe20d4b..d43270e85 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -74,40 +74,56 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { +static void batch_add_seq(common_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - common_batch_add(batch, tokens[i], i, { seq_id }, true); + batch.add_text(tokens[i], i, seq_id, true); } } -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { +static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + const struct llama_model * model = llama_get_model(ctx); + // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); // run model - LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_decode(ctx, batch) < 0) { - LOG_ERR("%s : failed to decode\n", __func__); + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); + if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { + // encoder-only model + if (llama_encode_ext(ctx, batch.get()) < 0) { + LOG_ERR("%s : failed to encode\n", __func__); + } + } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + // decoder-only model + if (llama_decode_ext(ctx, batch.get()) < 0) { + LOG_ERR("%s : failed to decode\n", __func__); + } } - for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { + if (!batch.tokens[i].logits) { continue; } - // try to get sequence embeddings - supported only when pooling_type is not NONE - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { + const float * embd = nullptr; + int embd_pos = 0; + + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + // try to get token embeddings embd = llama_get_embeddings_ith(ctx, i); - if (embd == NULL) { - LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i); - continue; - } + embd_pos = i; + GGML_ASSERT(embd != NULL && "failed to get token embeddings"); + } else { + // try to get sequence embeddings - supported only when pooling_type is not NONE + embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); + embd_pos = batch.tokens[i].seq_id; + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); } - float * out = output + batch.seq_id[i][0] * n_embd; - common_embd_normalize(embd, out, n_embd, 2); + float * out = output + embd_pos * n_embd; + common_embd_normalize(embd, out, n_embd, embd_norm); } } @@ -214,7 +230,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_chunks = chunks.size(); - struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + struct common_batch batch = common_batch(n_batch, 1); // allocate output const int n_embd = llama_model_n_embd(model); @@ -231,10 +247,10 @@ int main(int argc, char ** argv) { const uint64_t n_toks = inp.size(); // encode if at capacity - if (batch.n_tokens + n_toks > n_batch) { + if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) { float * out = emb + p * n_embd; batch_decode(ctx, batch, out, s, n_embd); - common_batch_clear(batch); + batch.clear(); p += s; s = 0; } @@ -255,7 +271,7 @@ int main(int argc, char ** argv) { chunks[i].tokens.clear(); } - struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); + struct common_batch query_batch = common_batch(n_batch, 1); // start loop, receive query and return top k similar chunks based on cosine similarity std::string query; @@ -269,7 +285,7 @@ int main(int argc, char ** argv) { std::vector query_emb(n_embd, 0); batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); - common_batch_clear(query_batch); + query_batch.clear(); // compute cosine similarities { @@ -299,6 +315,5 @@ int main(int argc, char ** argv) { llama_perf_context_print(ctx); // clean up - llama_batch_free(query_batch); llama_backend_free(); } diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 437f2533e..02cafa9da 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -905,10 +905,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt } // Check if we have enough space in the context to evaluate this batch -static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { +static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) { const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx_used = llama_kv_self_used_cells(ctx.get()); - if (n_ctx_used + batch.n_tokens > n_ctx) { + if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) { printf(LOG_COL_DEFAULT "\n"); printe("context size exceeded\n"); return 1; @@ -946,11 +946,11 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str } // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); - if (llama_decode(llama_data.context.get(), batch)) { + if (llama_decode_ext(llama_data.context.get(), batch.get())) { printe("failed to decode\n"); return 1; } @@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str print_word_and_concatenate_to_response(piece, response); // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0)); } printf(LOG_COL_DEFAULT); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 760ebbbf0..d1cf599b1 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -48,15 +48,11 @@ int main(int argc, char ** argv) { auto tokens = common_tokenize(ctx, params.prompt, true); // prepare the batch - llama_batch batch = llama_batch_init(tokens.size(), 0, 1); - for (size_t i = 0; i < tokens.size(); i++) { - common_batch_add(batch, tokens[i], i, {0}, false); - } - batch.logits[batch.n_tokens - 1] = true; // generate next token + llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0); // evaluate prompt - llama_decode(ctx, batch); - n_past += batch.n_tokens; + llama_decode_ext(ctx, batch); + n_past += llama_batch_ext_get_n_tokens(batch); // save state (rng, logits, embedding and kv_cache) to file { @@ -83,12 +79,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result0 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {0}, true); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch); return 1; } n_past += 1; @@ -135,12 +132,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result1 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {0}, true); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 1; + llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); - if (llama_decode(ctx2, batch)) { + if (llama_decode_ext(ctx2, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch); return 1; } n_past += 1; @@ -216,12 +214,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result2 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {1}, true); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 1; + llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); - if (llama_decode(ctx3, batch)) { + if (llama_decode_ext(ctx3, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch); return 1; } n_past += 1; @@ -233,7 +232,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl2); llama_sampler_free(smpl3); - llama_batch_free(batch); + llama_batch_ext_free(batch); if (result0 != result2) { fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 84f415973..cee00ea82 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -108,19 +108,20 @@ int main(int argc, char ** argv) { } // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); + llama_batch_ext_set_output_last(batch); llama_token new_token_id; while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); int n_ctx_used = llama_kv_self_used_cells(ctx); - if (n_ctx_used + batch.n_tokens > n_ctx) { + if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); exit(0); } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { GGML_ABORT("failed to decode\n"); } @@ -144,9 +145,13 @@ int main(int argc, char ** argv) { response += piece; // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); } + llama_batch_ext_free(batch); + return response; }; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 10e79a0a6..7b3ba8d81 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -143,7 +143,8 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); + llama_batch_ext_set_output_last(batch); // main loop @@ -151,14 +152,14 @@ int main(int argc, char ** argv) { int n_decode = 0; llama_token new_token_id; - for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) { + for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) { // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } - n_pos += batch.n_tokens; + n_pos += llama_batch_ext_get_n_tokens(batch); // sample the next token { @@ -180,7 +181,9 @@ int main(int argc, char ** argv) { fflush(stdout); // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); n_decode += 1; } @@ -198,6 +201,7 @@ int main(int argc, char ** argv) { llama_perf_context_print(ctx); fprintf(stderr, "\n"); + llama_batch_ext_free(batch); llama_sampler_free(smpl); llama_free(ctx); llama_model_free(model); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a5d2bc9d0..e61e863ce 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -132,7 +132,7 @@ int main(int argc, char ** argv) { struct common_speculative * spec = common_speculative_init(ctx_dft); - llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); + llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1); const auto t_enc_end = ggml_time_us(); @@ -151,8 +151,9 @@ int main(int argc, char ** argv) { //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); // always have a token to evaluate from before - id_last - common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); + llama_batch_ext_clear(batch_tgt); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { @@ -162,12 +163,12 @@ int main(int argc, char ** argv) { } for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); } //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); - llama_decode(ctx_tgt, batch_tgt); + llama_decode_ext(ctx_tgt, batch_tgt); } // sample from the full target batch and return the accepted tokens based on the target sampler @@ -253,6 +254,7 @@ int main(int argc, char ** argv) { common_sampler_free(smpl); common_speculative_free(spec); + llama_batch_ext_free(batch_tgt); llama_backend_free(); LOG("\n\n"); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index bfddc67e0..1f55db7b6 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -45,7 +45,7 @@ int main(int argc, char ** argv) { } common_init(); - +#ifdef 0 if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; @@ -199,8 +199,8 @@ int main(int argc, char ** argv) { drafts[s].smpl = common_sampler_init(model_dft, params.sampling); } - llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); + llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1); + llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft); const auto t_dec_start = ggml_time_us(); @@ -441,12 +441,13 @@ int main(int argc, char ** argv) { drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); - common_batch_clear(batch_dft); - common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); + llama_batch_ext_clear(batch_dft); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode(ctx_dft, batch_dft); + llama_decode_ext(ctx_dft, batch_dft); ++n_past_dft; } @@ -471,8 +472,9 @@ int main(int argc, char ** argv) { drafts[0].drafting = true; drafts[0].i_batch_dft = 0; - common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); + llama_batch_ext_clear(batch_tgt); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true); // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { @@ -640,5 +642,6 @@ int main(int argc, char ** argv) { LOG("\n\n"); +#endif return 0; } diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index c658f3182..32f8c43a8 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -817,7 +817,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // create a llama_batch // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel); + llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel); std::vector seq_ids(n_parallel, 0); for (int32_t i = 0; i < n_parallel; ++i) { @@ -826,14 +826,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // evaluate the initial prompt for (size_t i = 0; i < prompt_inp.size(); ++i) { - common_batch_add(batch, prompt_inp[i], i, seq_ids, false); + llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false); } - GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size()); // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch); - if (llama_decode(ctx_ttc, batch) != 0) { + if (llama_decode_ext(ctx_ttc, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -852,16 +852,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // remember the batch index of the last token for each parallel sequence // we need this to determine which logits to sample from - std::vector i_batch(n_parallel, batch.n_tokens - 1); + std::vector i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); - int n_past = batch.n_tokens; + int n_past = llama_batch_ext_get_n_tokens(batch); int n_decode = 0; bool next_token_uses_guide_token = true; while (n_decode <= n_predict) { // prepare the next batch - common_batch_clear(batch); + llama_batch_ext_clear(batch); // sample the next token for each parallel sequence / stream for (int32_t i = 0; i < n_parallel; ++i) { @@ -917,14 +917,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 //LOG_CNT("%d", i); } - i_batch[i] = batch.n_tokens; + i_batch[i] = llama_batch_ext_get_n_tokens(batch); // push this new token for next evaluation - common_batch_add(batch, new_token_id, n_past, { i }, true); + llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, false); } // all streams are finished - if (batch.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch) == 0) { break; } @@ -932,13 +932,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 n_past += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx_ttc, batch)) { + if (llama_decode_ext(ctx_ttc, batch)) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } } - llama_batch_free(batch); + llama_batch_ext_free(batch); LOG("\n"); LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); @@ -1007,14 +1007,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 const int n_codes = codes.size(); - llama_batch batch = llama_batch_init(n_codes, 0, 1); + llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1); for (size_t i = 0; i < codes.size(); ++i) { - common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits? + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits? } - GGML_ASSERT(batch.n_tokens == n_codes); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes); - if (llama_decode(ctx_cts, batch) != 0) { + if (llama_decode_ext(ctx_cts, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -1076,6 +1077,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str()); + llama_batch_ext_free(batch); llama_backend_free(); return 0; diff --git a/include/llama.h b/include/llama.h index ee74d9a8c..4521b3a41 100644 --- a/include/llama.h +++ b/include/llama.h @@ -995,9 +995,9 @@ extern "C" { // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success // < 0 - error. the KV cache state is restored to the state before this call - DEPRECATED(LLAMA_API int32_t llama_encode( + LLAMA_API int32_t llama_encode( struct llama_context * ctx, - struct llama_batch batch), "use llama_batch_ext API instead"); + struct llama_batch batch); LLAMA_API int32_t llama_encode_ext( struct llama_context * ctx, @@ -1007,9 +1007,9 @@ extern "C" { // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // < 0 - error. the KV cache state is restored to the state before this call - DEPRECATED(LLAMA_API int32_t llama_decode( + LLAMA_API int32_t llama_decode( struct llama_context * ctx, - struct llama_batch batch), "use llama_batch_ext API instead"); + struct llama_batch batch); LLAMA_API int32_t llama_decode_ext( struct llama_context * ctx,