apply to the rest

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions

View File

@ -582,43 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
return buf.str(); return buf.str();
} }
/*
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
std::stringstream buf;
buf << "[ ";
bool first = true;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!first) {
buf << ", ";
} else {
first = false;
}
auto detokenized = common_token_to_piece(ctx, batch.token[i]);
detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());
buf << "\n" << std::to_string(i)
<< ", token '" << detokenized << "'"
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
}
buf << " ]";
return buf.str();
}
*/
void string_process_escapes(std::string & input) { void string_process_escapes(std::string & input) {
std::size_t input_len = input.length(); std::size_t input_len = input.length();
std::size_t output_idx = 0; std::size_t output_idx = 0;

View File

@ -516,7 +516,6 @@ void string_process_escapes(std::string & input);
std::string string_from(bool value); std::string string_from(bool value);
std::string string_from(const std::vector<int> & values); std::string string_from(const std::vector<int> & values);
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens); std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);
// //
// Filesystem utils // Filesystem utils
@ -587,10 +586,10 @@ struct common_batch {
llama_batch_ext_ptr batch; llama_batch_ext_ptr batch;
struct batch_token { struct batch_token {
llama_token token; llama_token token;
llama_seq_id seq_id;
bool logits; bool logits;
}; };
std::vector<batch_token> tokens; std::vector<batch_token> tokens;
int n_outputs = 0;
common_batch() = default; common_batch() = default;
common_batch(int32_t n_tokens, int32_t n_seq_max) { common_batch(int32_t n_tokens, int32_t n_seq_max) {
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
@ -602,7 +601,17 @@ struct common_batch {
} }
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, seq_id, logits}); tokens.push_back({token, logits});
if (logits) {
n_outputs++;
}
}
void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
tokens.push_back({token, logits});
if (logits) {
n_outputs++;
}
} }
void set_logits_last() { void set_logits_last() {
if (!tokens.empty()) { if (!tokens.empty()) {
@ -622,6 +631,9 @@ struct common_batch {
view.tokens.reserve(n_tokens); view.tokens.reserve(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) { for (int32_t i = 0; i < n_tokens; i++) {
view.tokens.push_back(tokens[offset + i]); view.tokens.push_back(tokens[offset + i]);
if (tokens[offset + i].logits) {
view.n_outputs++;
}
} }
return view; return view;
} }

View File

@ -5,6 +5,7 @@
#include "clip.h" #include "clip.h"
#include "stb_image.h" #include "stb_image.h"
#include "llama.h" #include "llama.h"
#include "llama-cpp.h"
#include "ggml.h" #include "ggml.h"
#include "console.h" #include "console.h"
@ -63,7 +64,7 @@ struct gemma3_context {
llama_model * model; llama_model * model;
llama_context * lctx; llama_context * lctx;
const llama_vocab * vocab; const llama_vocab * vocab;
llama_batch batch; llama_batch_ext_ptr batch;
int n_threads = 1; int n_threads = 1;
llama_pos n_past = 0; llama_pos n_past = 0;
@ -73,7 +74,7 @@ struct gemma3_context {
lctx = llama_init.context.get(); lctx = llama_init.context.get();
vocab = llama_model_get_vocab(model); vocab = llama_model_get_vocab(model);
n_threads = params.cpuparams.n_threads; n_threads = params.cpuparams.n_threads;
batch = llama_batch_init(params.n_batch, 0, 1); batch.reset(llama_batch_ext_init(params.n_batch, 1));
init_clip_model(params); init_clip_model(params);
} }
@ -87,50 +88,18 @@ struct gemma3_context {
} }
}; };
struct decode_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};
static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) { static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
common_batch_clear(ctx.batch); llama_batch_ext_clear(ctx.batch.get());
for (llama_token & t : tokens) { for (llama_token & t : tokens) {
common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false);
} }
if (logits_last) { if (logits_last) {
ctx.batch.logits[ctx.batch.n_tokens - 1] = true; llama_batch_ext_set_output_last(ctx.batch.get());
} }
// LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str()); // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
if (llama_decode(ctx.lctx, ctx.batch)) { if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
LOG_ERR("Failed to decode text\n"); LOG_ERR("Failed to decode text\n");
return 1; return 1;
} }
@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
int64_t t1 = ggml_time_ms(); int64_t t1 = ggml_time_ms();
eval_text(ctx, "<start_of_image>"); eval_text(ctx, "<start_of_image>");
llama_set_causal_attn(ctx.lctx, false); llama_set_causal_attn(ctx.lctx, false);
decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0); llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
if (llama_decode(ctx.lctx, batch_img.batch)) { if (llama_decode_ext(ctx.lctx, batch_img.get())) {
LOG_ERR("failed to decode image\n"); LOG_ERR("failed to decode image\n");
return 1; return 1;
} }
@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
fflush(stdout); fflush(stdout);
// eval the token // eval the token
common_batch_clear(ctx.batch); llama_batch_ext_clear(ctx.batch.get());
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true); llama_seq_id seq_id = 0;
if (llama_decode(ctx.lctx, ctx.batch)) { llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true);
if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
LOG_ERR("failed to decode token\n"); LOG_ERR("failed to decode token\n");
return 1; return 1;
} }

View File

@ -2,6 +2,7 @@
#include "llava.h" #include "llava.h"
#include "llama.h" #include "llama.h"
#include "llama-cpp.h"
#include <algorithm> #include <algorithm>
#include <cerrno> #include <cerrno>
@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
return true; return true;
} }
struct llava_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch; n_eval = n_batch;
} }
float * embd = image_embed->embed+i*n_embd; float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
if (llama_decode(ctx_llama, llava_batch.batch)) { if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__); LOG_ERR("%s : failed to eval\n", __func__);
return false; return false;
} }

View File

@ -66,6 +66,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
// TODO: move this to llama_batch_ext API
llama_batch batch = { llama_batch batch = {
int32_t(n_eval), // n_tokens int32_t(n_eval), // n_tokens
nullptr, // token nullptr, // token

View File

@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
// seq_id == 0 : the current input token // seq_id == 0 : the current input token
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
// seq_id [W + 1, W + G] : verification n-grams // seq_id [W + 1, W + G] : verification n-grams
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);
// target model sampling context // target model sampling context
struct common_sampler * smpl = common_sampler_init(model, params.sampling); struct common_sampler * smpl = common_sampler_init(model, params.sampling);
@ -204,10 +204,10 @@ int main(int argc, char ** argv) {
// V V V V V V // V V V V V V
// id // id
{ {
common_batch_clear(batch); llama_batch_ext_clear(batch);
// current token - first token of the first level // current token - first token of the first level
common_batch_add(batch, id, n_past, seq_id_all, true); llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true);
// verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
{ {
@ -230,9 +230,10 @@ int main(int argc, char ** argv) {
const llama_token t = ngrams_observed.tokens[idx + j]; const llama_token t = ngrams_observed.tokens[idx + j];
ngrams_cur[g].tokens [j + 1] = t; ngrams_cur[g].tokens [j + 1] = t;
ngrams_cur[g].i_batch[j + 1] = batch.n_tokens; ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);
common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); llama_seq_id seq_id = W + 1 + g;
llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);
} }
} }
} }
@ -244,18 +245,20 @@ int main(int argc, char ** argv) {
seq_id_look[j] = i + j + 1; seq_id_look[j] = i + j + 1;
} }
common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
seq_id_look.data(), seq_id_look.size(), false);
} }
// fill the rest of the levels // fill the rest of the levels
for (int j = 1; j < N - 1; j++) { for (int j = 1; j < N - 1; j++) {
for (int i = 0; i < W; i++) { for (int i = 0; i < W; i++) {
common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); llama_seq_id seq_id = i + 1;
llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2);
} }
} }
} }
if (llama_decode(ctx, batch) != 0) { if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__); LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
return 1; return 1;
} }
@ -475,7 +478,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_view_free(&kvc_view); llama_kv_cache_view_free(&kvc_view);
llama_batch_free(batch); llama_batch_ext_free(batch);
llama_backend_free(); llama_backend_free();

View File

@ -174,7 +174,7 @@ int main(int argc, char ** argv) {
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
llama_batch batch = llama_batch_init(n_ctx, 0, 1); llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);
int32_t n_total_prompt = 0; int32_t n_total_prompt = 0;
int32_t n_total_gen = 0; int32_t n_total_gen = 0;
@ -192,10 +192,11 @@ int main(int argc, char ** argv) {
LOG_INF("%s: Evaluating the system prompt ...\n", __func__); LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
for (int32_t i = 0; i < n_tokens_system; ++i) { for (int32_t i = 0; i < n_tokens_system; ++i) {
common_batch_add(batch, tokens_system[i], i, { 0 }, false); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false);
} }
if (llama_decode(ctx, batch) != 0) { if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__); LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1; return 1;
} }
@ -216,7 +217,7 @@ int main(int argc, char ** argv) {
common_kv_cache_dump_view_seqs(kvc_view, 40); common_kv_cache_dump_view_seqs(kvc_view, 40);
} }
common_batch_clear(batch); llama_batch_ext_clear(batch);
// decode any currently ongoing sequences // decode any currently ongoing sequences
for (auto & client : clients) { for (auto & client : clients) {
@ -224,14 +225,15 @@ int main(int argc, char ** argv) {
continue; continue;
} }
client.i_batch = batch.n_tokens; client.i_batch = llama_batch_ext_get_n_tokens(batch);
common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); llama_seq_id seq_id = client.id + 1;
llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true);
client.n_decoded += 1; client.n_decoded += 1;
} }
if (batch.n_tokens == 0) { if (llama_batch_ext_get_n_tokens(batch) == 0) {
// all sequences have ended - clear the entire KV cache // all sequences have ended - clear the entire KV cache
for (int i = 1; i <= n_clients; ++i) { for (int i = 1; i <= n_clients; ++i) {
llama_kv_self_seq_rm(ctx, i, -1, -1); llama_kv_self_seq_rm(ctx, i, -1, -1);
@ -243,7 +245,7 @@ int main(int argc, char ** argv) {
} }
// insert new sequences for decoding // insert new sequences for decoding
if (cont_batching || batch.n_tokens == 0) { if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) {
for (auto & client : clients) { for (auto & client : clients) {
if (client.seq_id == -1 && g_seq_id < n_seq) { if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = g_seq_id; client.seq_id = g_seq_id;
@ -262,17 +264,18 @@ int main(int argc, char ** argv) {
tokens_prompt = common_tokenize(ctx, client.prompt, false); tokens_prompt = common_tokenize(ctx, client.prompt, false);
for (size_t i = 0; i < tokens_prompt.size(); ++i) { for (size_t i = 0; i < tokens_prompt.size(); ++i) {
common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); llama_seq_id seq_id = client.id + 1;
llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false);
} }
// extract the logits only for the last token // extract the logits only for the last token
if (batch.n_tokens > 0) { if (llama_batch_ext_get_n_tokens(batch) > 0) {
batch.logits[batch.n_tokens - 1] = true; llama_batch_ext_set_output_last(batch);
} }
client.n_prompt = tokens_prompt.size(); client.n_prompt = tokens_prompt.size();
client.n_decoded = 0; client.n_decoded = 0;
client.i_batch = batch.n_tokens - 1; client.i_batch = llama_batch_ext_get_n_tokens(batch) - 1;
LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
@ -286,14 +289,15 @@ int main(int argc, char ** argv) {
} }
} }
if (batch.n_tokens == 0) { if (llama_batch_ext_get_n_tokens(batch) == 0) {
break; break;
} }
// process in chunks of params.n_batch // process in chunks of params.n_batch
int32_t n_batch = params.n_batch; int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch);
for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
// experiment: process in powers of 2 // experiment: process in powers of 2
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
// n_batch /= 2; // n_batch /= 2;
@ -301,19 +305,11 @@ int main(int argc, char ** argv) {
// continue; // continue;
//} //}
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i));
llama_batch batch_view = { llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens);
n_tokens, const int ret = llama_decode_ext(ctx, batch_view);
batch.token + i, llama_batch_ext_free(batch_view);
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) { if (ret != 0) {
if (n_batch == 1 || ret < 0) { if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try increasing it via the context size // if you get here, it means the KV cache is full - try increasing it via the context size
@ -417,7 +413,7 @@ int main(int argc, char ** argv) {
// TODO: print sampling/grammar timings for all clients // TODO: print sampling/grammar timings for all clients
llama_perf_context_print(ctx); llama_perf_context_print(ctx);
llama_batch_free(batch); llama_batch_ext_free(batch);
llama_backend_free(); llama_backend_free();

View File

@ -2,6 +2,7 @@
#include "common.h" #include "common.h"
#include "log.h" #include "log.h"
#include "llama.h" #include "llama.h"
#include "llama-cpp.h"
#include <cmath> #include <cmath>
#include <cstdio> #include <cstdio>
@ -122,7 +123,7 @@ int main(int argc, char ** argv) {
LOG_INF("prompt tokens: %d\n", n_tokens_all); LOG_INF("prompt tokens: %d\n", n_tokens_all);
//LOG_INF("prompt: %s\n", params.prompt.c_str()); //LOG_INF("prompt: %s\n", params.prompt.c_str());
llama_batch batch = llama_batch_init(params.n_batch, 0, 1); llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
int n_past = 0; int n_past = 0;
@ -140,17 +141,18 @@ int main(int argc, char ** argv) {
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
} }
common_batch_clear(batch); llama_batch_ext_clear(batch.get());
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
} }
if (i + n_batch >= n_tokens_all) { if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true; llama_batch_ext_set_output_last(batch.get());
} }
if (llama_decode(ctx, batch) != 0) { if (llama_decode_ext(ctx, batch.get()) != 0) {
LOG_INF("%s: llama_decode() failed\n", __func__); LOG_INF("%s: llama_decode() failed\n", __func__);
return 1; return 1;
} }
@ -174,17 +176,18 @@ int main(int argc, char ** argv) {
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
common_batch_clear(batch); llama_batch_ext_clear(batch.get());
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
} }
if (i + n_batch >= n_tokens_all) { if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true; llama_batch_ext_set_output_last(batch.get());
} }
if (llama_decode(ctx, batch) != 0) { if (llama_decode_ext(ctx, batch.get()) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__); LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1; return 1;
} }
@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) { while (n_cur <= n_len) {
// sample the next token // sample the next token
{ {
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
// is it an end of generation? // is it an end of generation?
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
@ -237,16 +240,17 @@ int main(int argc, char ** argv) {
n_decode += 1; n_decode += 1;
// prepare the next batch // prepare the next batch
common_batch_clear(batch); llama_batch_ext_clear(batch.get());
// push this new token for next evaluation // push this new token for next evaluation
common_batch_add(batch, new_token_id, n_past++, { 0 }, true); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true);
} }
n_cur += 1; n_cur += 1;
// evaluate the current batch with the transformer model // evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) { if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1; return 1;
} }
@ -266,8 +270,6 @@ int main(int argc, char ** argv) {
llama_sampler_free(smpl); llama_sampler_free(smpl);
llama_batch_free(batch);
llama_free(ctx); llama_free(ctx);
llama_model_free(model); llama_model_free(model);

View File

@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// clear the KV cache // clear the KV cache
llama_kv_self_clear(ctx); llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1); common_batch batch(n_batch, 1);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch); const int batch_size = std::min(end - batch_start, n_batch);
common_batch_clear(batch); batch.clear();
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
} }
//LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
if (llama_decode(ctx, batch)) { if (llama_decode_ext(ctx, batch.get())) {
//LOG_ERR("%s : failed to eval\n", __func__); //LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return {tokens, -1, logit_history, prob_history}; return {tokens, -1, logit_history, prob_history};
} }
@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
} }
} }
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now(); const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) { if (i == 0) {
@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
GGML_ASSERT(params.n_ctx == n_seq * n_ctx); GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);
std::vector<float> logits; std::vector<float> logits;
if (num_batches > 1) { if (num_batches > 1) {
@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
int n_outputs = 0; int n_outputs = 0;
batch.n_tokens = 0; batch.clear();
for (int seq = 0; seq < n_seq_batch; seq++) { for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx; int seq_start = batch_start + seq*n_ctx;
@ -569,21 +566,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
for (int k = 0; k < batch_size; ++k) { for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k; const int idx = seq*n_ctx + k;
batch.token [idx] = tokens[seq_start + k]; const llama_pos pos = j*n_batch + k;
batch.pos [idx] = j*n_batch + k; bool output = pos >= first;
batch.n_seq_id[idx] = 1; batch.add_text(tokens[seq_start + k], pos, seq, output);
batch.seq_id [idx][0] = seq;
batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
n_outputs += batch.logits[idx] != 0; n_outputs += output ? 1 : 0;
} }
batch.n_tokens += batch_size;
// restore the original token in case it was set to BOS // restore the original token in case it was set to BOS
tokens[seq_start] = token_org; tokens[seq_start] = token_org;
} }
if (llama_decode(ctx, batch)) { if (llama_decode_ext(ctx, batch.get())) {
LOG_INF("%s : failed to eval\n", __func__); LOG_INF("%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history}; return {tokens, -1, logit_history, prob_history};
} }
@ -653,36 +647,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); LOG_ERR("Unexpected negative standard deviation of log(prob)\n");
} }
llama_batch_free(batch);
return {tokens, ppl, logit_history, prob_history}; return {tokens, ppl, logit_history, prob_history};
} }
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) { static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
int prev_outputs = 0; int prev_outputs = 0;
for (int i = 0; i < (int) batch.n_tokens; i += n_batch) { for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
const int n_tokens = std::min<int>(n_batch, batch.n_tokens - i); const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
llama_batch batch_view = { common_batch batch_view = batch.get_view(i, n_tokens);
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};
const int ret = llama_decode(ctx, batch_view); const int ret = llama_decode_ext(ctx, batch_view.get());
if (ret != 0) { if (ret != 0) {
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false; return false;
} }
int n_outputs = 0; int n_outputs = batch_view.n_outputs;
for (int i = 0; i < n_tokens; ++i) {
n_outputs += batch_view.logits[i] != 0;
}
memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@ -863,7 +844,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, 4); common_batch batch(n_ctx, 4);
std::vector<float> tok_logits(n_vocab); std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size // TODO: this could be made smaller; it's currently the worst-case size
@ -879,7 +860,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
size_t i1 = i0; size_t i1 = i0;
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
common_batch_clear(batch); batch.clear();
// batch as much tasks as possible into the available context // batch as much tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending // each task has 4 unique sequence ids - one for each ending
@ -895,9 +876,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
} }
for (size_t i = 0; i < hs_cur.common_prefix; ++i) { for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
} }
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix llama_batch_ext_set_output_last(batch.get());
n_logits += 1; n_logits += 1;
for (int s = 0; s < 4; ++s) { for (int s = 0; s < 4; ++s) {
@ -905,7 +886,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
// TODO: don't evaluate the last token of each sequence // TODO: don't evaluate the last token of each sequence
for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1; const bool needs_logits = i < seq_tokens_size - 1;
common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits; n_logits += needs_logits;
} }
} }
@ -992,8 +973,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
i0 = i1 - 1; i0 = i1 - 1;
} }
llama_batch_free(batch);
LOG("\n"); LOG("\n");
} }
@ -1147,7 +1126,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int max_tasks_per_batch = 128; const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, 2); common_batch batch(n_ctx, 2);
std::vector<float> tok_logits(n_vocab); std::vector<float> tok_logits(n_vocab);
// TODO: this could be made smaller; it's currently the worst-case size // TODO: this could be made smaller; it's currently the worst-case size
@ -1166,7 +1145,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
size_t i1 = i0; size_t i1 = i0;
size_t i_logits = 0; size_t i_logits = 0;
common_batch_clear(batch); batch.clear();
while (n_cur + (int) data[i1].required_tokens <= n_ctx) { while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
int n_logits = 0; int n_logits = 0;
@ -1176,15 +1155,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
} }
for (size_t i = 0; i < data[i1].common_prefix; ++i) { for (size_t i = 0; i < data[i1].common_prefix; ++i) {
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
} }
batch.logits[batch.n_tokens - 1] = true; llama_batch_ext_set_output_last(batch.get());
n_logits += 1; n_logits += 1;
for (int s = 0; s < 2; ++s) { for (int s = 0; s < 2; ++s) {
// TODO: end before the last token, no need to predict past the end of the sequences // TODO: end before the last token, no need to predict past the end of the sequences
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
n_logits += 1; n_logits += 1;
} }
} }
@ -1501,7 +1480,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); common_batch batch(n_ctx, max_seq);
std::vector<float> tok_logits(n_vocab); std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(size_t(n_ctx)*n_vocab); std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@ -1521,7 +1500,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
size_t i1 = i0; size_t i1 = i0;
size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
common_batch_clear(batch); batch.clear();
// batch as much tasks as possible into the available context // batch as much tasks as possible into the available context
// each task has 4 unique sequence ids - one for each ending // each task has 4 unique sequence ids - one for each ending
@ -1544,9 +1523,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
for (size_t i = 0; i < cur_task.common_prefix; ++i) { for (size_t i = 0; i < cur_task.common_prefix; ++i) {
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
} }
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
n_logits += 1; n_logits += 1;
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
@ -1554,7 +1533,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
// TODO: don't evaluate the last token of each sequence // TODO: don't evaluate the last token of each sequence
for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
const bool needs_logits = i < seq_tokens_size - 1; const bool needs_logits = i < seq_tokens_size - 1;
common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
n_logits += needs_logits; n_logits += needs_logits;
} }
} }
@ -1653,8 +1632,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
i0 = i1 - 1; i0 = i1 - 1;
} }
llama_batch_free(batch);
if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return; if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return;
float p = 1.f*n_correct/n_done; float p = 1.f*n_correct/n_done;
@ -1767,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
// clear the KV cache // clear the KV cache
llama_kv_self_clear(ctx); llama_kv_self_clear(ctx);
llama_batch batch = llama_batch_init(n_batch, 0, 1); common_batch batch(n_batch, 1);
for (int j = 0; j < num_batches; ++j) { for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
@ -1781,14 +1758,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_vocab_bos(vocab); tokens[batch_start] = llama_vocab_bos(vocab);
} }
common_batch_clear(batch); batch.clear();
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true);
} }
if (llama_decode(ctx, batch)) { if (llama_decode_ext(ctx, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__); LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return; return;
} }
@ -1801,8 +1777,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
} }
} }
llama_batch_free(batch);
const auto t_end = std::chrono::high_resolution_clock::now(); const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) { if (i == 0) {

View File

@ -74,40 +74,56 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
return chunks; return chunks;
} }
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { static void batch_add_seq(common_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
size_t n_tokens = tokens.size(); size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) { for (size_t i = 0; i < n_tokens; i++) {
common_batch_add(batch, tokens[i], i, { seq_id }, true); batch.add_text(tokens[i], i, seq_id, true);
} }
} }
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
llama_kv_self_clear(ctx); llama_kv_self_clear(ctx);
// run model // run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq);
if (llama_decode(ctx, batch) < 0) { if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode_ext(ctx, batch.get()) < 0) {
LOG_ERR("%s : failed to decode\n", __func__); LOG_ERR("%s : failed to decode\n", __func__);
} }
}
for (int i = 0; i < batch.n_tokens; i++) { for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) {
if (!batch.logits[i]) { if (!batch.tokens[i].logits) {
continue; continue;
} }
// try to get sequence embeddings - supported only when pooling_type is not NONE const float * embd = nullptr;
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); int embd_pos = 0;
if (embd == NULL) {
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
// try to get token embeddings
embd = llama_get_embeddings_ith(ctx, i); embd = llama_get_embeddings_ith(ctx, i);
if (embd == NULL) { embd_pos = i;
LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i); GGML_ASSERT(embd != NULL && "failed to get token embeddings");
continue; } else {
} // try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id);
embd_pos = batch.tokens[i].seq_id;
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
} }
float * out = output + batch.seq_id[i][0] * n_embd; float * out = output + embd_pos * n_embd;
common_embd_normalize(embd, out, n_embd, 2); common_embd_normalize(embd, out, n_embd, embd_norm);
} }
} }
@ -214,7 +230,7 @@ int main(int argc, char ** argv) {
// initialize batch // initialize batch
const int n_chunks = chunks.size(); const int n_chunks = chunks.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1); struct common_batch batch = common_batch(n_batch, 1);
// allocate output // allocate output
const int n_embd = llama_model_n_embd(model); const int n_embd = llama_model_n_embd(model);
@ -231,10 +247,10 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size(); const uint64_t n_toks = inp.size();
// encode if at capacity // encode if at capacity
if (batch.n_tokens + n_toks > n_batch) { if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) {
float * out = emb + p * n_embd; float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd); batch_decode(ctx, batch, out, s, n_embd);
common_batch_clear(batch); batch.clear();
p += s; p += s;
s = 0; s = 0;
} }
@ -255,7 +271,7 @@ int main(int argc, char ** argv) {
chunks[i].tokens.clear(); chunks[i].tokens.clear();
} }
struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); struct common_batch query_batch = common_batch(n_batch, 1);
// start loop, receive query and return top k similar chunks based on cosine similarity // start loop, receive query and return top k similar chunks based on cosine similarity
std::string query; std::string query;
@ -269,7 +285,7 @@ int main(int argc, char ** argv) {
std::vector<float> query_emb(n_embd, 0); std::vector<float> query_emb(n_embd, 0);
batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
common_batch_clear(query_batch); query_batch.clear();
// compute cosine similarities // compute cosine similarities
{ {
@ -299,6 +315,5 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx); llama_perf_context_print(ctx);
// clean up // clean up
llama_batch_free(query_batch);
llama_backend_free(); llama_backend_free();
} }

View File

@ -905,10 +905,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
} }
// Check if we have enough space in the context to evaluate this batch // Check if we have enough space in the context to evaluate this batch
static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) {
const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx_used = llama_kv_self_used_cells(ctx.get()); const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
if (n_ctx_used + batch.n_tokens > n_ctx) { if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
printf(LOG_COL_DEFAULT "\n"); printf(LOG_COL_DEFAULT "\n");
printe("context size exceeded\n"); printe("context size exceeded\n");
return 1; return 1;
@ -946,11 +946,11 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
} }
// prepare a batch for the prompt // prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
llama_token new_token_id; llama_token new_token_id;
while (true) { while (true) {
check_context_size(llama_data.context, batch); check_context_size(llama_data.context, batch);
if (llama_decode(llama_data.context.get(), batch)) { if (llama_decode_ext(llama_data.context.get(), batch.get())) {
printe("failed to decode\n"); printe("failed to decode\n");
return 1; return 1;
} }
@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
print_word_and_concatenate_to_response(piece, response); print_word_and_concatenate_to_response(piece, response);
// prepare the next batch with the sampled token // prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1); batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0));
} }
printf(LOG_COL_DEFAULT); printf(LOG_COL_DEFAULT);

View File

@ -48,15 +48,11 @@ int main(int argc, char ** argv) {
auto tokens = common_tokenize(ctx, params.prompt, true); auto tokens = common_tokenize(ctx, params.prompt, true);
// prepare the batch // prepare the batch
llama_batch batch = llama_batch_init(tokens.size(), 0, 1); llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0);
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
}
batch.logits[batch.n_tokens - 1] = true; // generate next token
// evaluate prompt // evaluate prompt
llama_decode(ctx, batch); llama_decode_ext(ctx, batch);
n_past += batch.n_tokens; n_past += llama_batch_ext_get_n_tokens(batch);
// save state (rng, logits, embedding and kv_cache) to file // save state (rng, logits, embedding and kv_cache) to file
{ {
@ -83,12 +79,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());
result0 += next_token_str; result0 += next_token_str;
common_batch_clear(batch); llama_batch_ext_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx, batch)) { if (llama_decode_ext(ctx, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__); fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch); llama_batch_ext_free(batch);
return 1; return 1;
} }
n_past += 1; n_past += 1;
@ -135,12 +132,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());
result1 += next_token_str; result1 += next_token_str;
common_batch_clear(batch); llama_batch_ext_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true); llama_seq_id seq_id = 1;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx2, batch)) { if (llama_decode_ext(ctx2, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__); fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch); llama_batch_ext_free(batch);
return 1; return 1;
} }
n_past += 1; n_past += 1;
@ -216,12 +214,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str()); printf("%s", next_token_str.c_str());
result2 += next_token_str; result2 += next_token_str;
common_batch_clear(batch); llama_batch_ext_clear(batch);
common_batch_add(batch, next_token, n_past, {1}, true); llama_seq_id seq_id = 1;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx3, batch)) { if (llama_decode_ext(ctx3, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__); fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch); llama_batch_ext_free(batch);
return 1; return 1;
} }
n_past += 1; n_past += 1;
@ -233,7 +232,7 @@ int main(int argc, char ** argv) {
llama_sampler_free(smpl2); llama_sampler_free(smpl2);
llama_sampler_free(smpl3); llama_sampler_free(smpl3);
llama_batch_free(batch); llama_batch_ext_free(batch);
if (result0 != result2) { if (result0 != result2) {
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);

View File

@ -108,19 +108,20 @@ int main(int argc, char ** argv) {
} }
// prepare a batch for the prompt // prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
llama_batch_ext_set_output_last(batch);
llama_token new_token_id; llama_token new_token_id;
while (true) { while (true) {
// check if we have enough space in the context to evaluate this batch // check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx); int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_kv_self_used_cells(ctx); int n_ctx_used = llama_kv_self_used_cells(ctx);
if (n_ctx_used + batch.n_tokens > n_ctx) { if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) {
printf("\033[0m\n"); printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n"); fprintf(stderr, "context size exceeded\n");
exit(0); exit(0);
} }
if (llama_decode(ctx, batch)) { if (llama_decode_ext(ctx, batch)) {
GGML_ABORT("failed to decode\n"); GGML_ABORT("failed to decode\n");
} }
@ -144,9 +145,13 @@ int main(int argc, char ** argv) {
response += piece; response += piece;
// prepare the next batch with the sampled token // prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1); llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
} }
llama_batch_ext_free(batch);
return response; return response;
}; };

View File

@ -143,7 +143,8 @@ int main(int argc, char ** argv) {
// prepare a batch for the prompt // prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
llama_batch_ext_set_output_last(batch);
// main loop // main loop
@ -151,14 +152,14 @@ int main(int argc, char ** argv) {
int n_decode = 0; int n_decode = 0;
llama_token new_token_id; llama_token new_token_id;
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) { for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) {
// evaluate the current batch with the transformer model // evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) { if (llama_decode_ext(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1; return 1;
} }
n_pos += batch.n_tokens; n_pos += llama_batch_ext_get_n_tokens(batch);
// sample the next token // sample the next token
{ {
@ -180,7 +181,9 @@ int main(int argc, char ** argv) {
fflush(stdout); fflush(stdout);
// prepare the next batch with the sampled token // prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1); llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
n_decode += 1; n_decode += 1;
} }
@ -198,6 +201,7 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx); llama_perf_context_print(ctx);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
llama_batch_ext_free(batch);
llama_sampler_free(smpl); llama_sampler_free(smpl);
llama_free(ctx); llama_free(ctx);
llama_model_free(model); llama_model_free(model);

View File

@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
struct common_speculative * spec = common_speculative_init(ctx_dft); struct common_speculative * spec = common_speculative_init(ctx_dft);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1);
const auto t_enc_end = ggml_time_us(); const auto t_enc_end = ggml_time_us();
@ -151,8 +151,9 @@ int main(int argc, char ** argv) {
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
// always have a token to evaluate from before - id_last // always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt); llama_batch_ext_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{ {
@ -162,12 +163,12 @@ int main(int argc, char ** argv) {
} }
for (size_t i = 0; i < draft.size(); ++i) { for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
} }
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt); llama_decode_ext(ctx_tgt, batch_tgt);
} }
// sample from the full target batch and return the accepted tokens based on the target sampler // sample from the full target batch and return the accepted tokens based on the target sampler
@ -253,6 +254,7 @@ int main(int argc, char ** argv) {
common_sampler_free(smpl); common_sampler_free(smpl);
common_speculative_free(spec); common_speculative_free(spec);
llama_batch_ext_free(batch_tgt);
llama_backend_free(); llama_backend_free();
LOG("\n\n"); LOG("\n\n");

View File

@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
} }
common_init(); common_init();
#ifdef 0
if (params.speculative.model.empty()) { if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__); LOG_ERR("%s: --model-draft is required\n", __func__);
return 1; return 1;
@ -199,8 +199,8 @@ int main(int argc, char ** argv) {
drafts[s].smpl = common_sampler_init(model_dft, params.sampling); drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
} }
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft);
const auto t_dec_start = ggml_time_us(); const auto t_dec_start = ggml_time_us();
@ -441,12 +441,13 @@ int main(int argc, char ** argv) {
drafts[0].dists.push_back(std::vector<llama_token_data>()); drafts[0].dists.push_back(std::vector<llama_token_data>());
drafts[0].i_batch_tgt.push_back(0); drafts[0].i_batch_tgt.push_back(0);
common_batch_clear(batch_dft); llama_batch_ext_clear(batch_dft);
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true);
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft); llama_decode_ext(ctx_dft, batch_dft);
++n_past_dft; ++n_past_dft;
} }
@ -471,8 +472,9 @@ int main(int argc, char ** argv) {
drafts[0].drafting = true; drafts[0].drafting = true;
drafts[0].i_batch_dft = 0; drafts[0].i_batch_dft = 0;
common_batch_clear(batch_tgt); llama_batch_ext_clear(batch_tgt);
common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
// sample n_draft tokens from the draft model using tree-based sampling // sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) { for (int i = 0; i < n_draft; ++i) {
@ -640,5 +642,6 @@ int main(int argc, char ** argv) {
LOG("\n\n"); LOG("\n\n");
#endif
return 0; return 0;
} }

View File

@ -817,7 +817,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// create a llama_batch // create a llama_batch
// we use this object to submit token data for decoding // we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel); llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel);
std::vector<llama_seq_id> seq_ids(n_parallel, 0); std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) { for (int32_t i = 0; i < n_parallel; ++i) {
@ -826,14 +826,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// evaluate the initial prompt // evaluate the initial prompt
for (size_t i = 0; i < prompt_inp.size(); ++i) { for (size_t i = 0; i < prompt_inp.size(); ++i) {
common_batch_add(batch, prompt_inp[i], i, seq_ids, false); llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false);
} }
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size());
// llama_decode will output logits only for the last token of the prompt // llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true; llama_batch_ext_set_output_last(batch);
if (llama_decode(ctx_ttc, batch) != 0) { if (llama_decode_ext(ctx_ttc, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__); LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1; return 1;
} }
@ -852,16 +852,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
// remember the batch index of the last token for each parallel sequence // remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from // we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1); std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
int n_past = batch.n_tokens; int n_past = llama_batch_ext_get_n_tokens(batch);
int n_decode = 0; int n_decode = 0;
bool next_token_uses_guide_token = true; bool next_token_uses_guide_token = true;
while (n_decode <= n_predict) { while (n_decode <= n_predict) {
// prepare the next batch // prepare the next batch
common_batch_clear(batch); llama_batch_ext_clear(batch);
// sample the next token for each parallel sequence / stream // sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) { for (int32_t i = 0; i < n_parallel; ++i) {
@ -917,14 +917,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
//LOG_CNT("%d", i); //LOG_CNT("%d", i);
} }
i_batch[i] = batch.n_tokens; i_batch[i] = llama_batch_ext_get_n_tokens(batch);
// push this new token for next evaluation // push this new token for next evaluation
common_batch_add(batch, new_token_id, n_past, { i }, true); llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, false);
} }
// all streams are finished // all streams are finished
if (batch.n_tokens == 0) { if (llama_batch_ext_get_n_tokens(batch) == 0) {
break; break;
} }
@ -932,13 +932,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
n_past += 1; n_past += 1;
// evaluate the current batch with the transformer model // evaluate the current batch with the transformer model
if (llama_decode(ctx_ttc, batch)) { if (llama_decode_ext(ctx_ttc, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1; return 1;
} }
} }
llama_batch_free(batch); llama_batch_ext_free(batch);
LOG("\n"); LOG("\n");
LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
@ -1007,14 +1007,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const int n_codes = codes.size(); const int n_codes = codes.size();
llama_batch batch = llama_batch_init(n_codes, 0, 1); llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1);
for (size_t i = 0; i < codes.size(); ++i) { for (size_t i = 0; i < codes.size(); ++i) {
common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits? llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits?
} }
GGML_ASSERT(batch.n_tokens == n_codes); GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes);
if (llama_decode(ctx_cts, batch) != 0) { if (llama_decode_ext(ctx_cts, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__); LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1; return 1;
} }
@ -1076,6 +1077,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str()); LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str());
llama_batch_ext_free(batch);
llama_backend_free(); llama_backend_free();
return 0; return 0;

View File

@ -995,9 +995,9 @@ extern "C" {
// Stores the encoder output internally for later use by the decoder cross-attention layers. // Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success // 0 - success
// < 0 - error. the KV cache state is restored to the state before this call // < 0 - error. the KV cache state is restored to the state before this call
DEPRECATED(LLAMA_API int32_t llama_encode( LLAMA_API int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch), "use llama_batch_ext API instead"); struct llama_batch batch);
LLAMA_API int32_t llama_encode_ext( LLAMA_API int32_t llama_encode_ext(
struct llama_context * ctx, struct llama_context * ctx,
@ -1007,9 +1007,9 @@ extern "C" {
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// < 0 - error. the KV cache state is restored to the state before this call // < 0 - error. the KV cache state is restored to the state before this call
DEPRECATED(LLAMA_API int32_t llama_decode( LLAMA_API int32_t llama_decode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch), "use llama_batch_ext API instead"); struct llama_batch batch);
LLAMA_API int32_t llama_decode_ext( LLAMA_API int32_t llama_decode_ext(
struct llama_context * ctx, struct llama_context * ctx,