From 17d3658b5f0a595cb6e0c56fa04dc00f8a6ab58d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 16 Feb 2025 00:02:53 +0100 Subject: [PATCH] move to llama_batch_ext --- common/common.cpp | 21 +++++--- common/common.h | 6 ++- common/speculative.cpp | 6 +-- include/llama-cpp.h | 6 +-- include/llama.h | 109 ++++++++++++++++++++++++++++---------- src/llama-batch.cpp | 116 +++++++++++++++++++++++++---------------- src/llama-batch.h | 16 +++--- src/llama.cpp | 59 +++++++++++++-------- 8 files changed, 222 insertions(+), 117 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index c79f1e736..b54e546f9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1610,20 +1610,29 @@ std::pair common_get_hf_file(const std::string &, cons // Batch utils // -void common_batch_clear(struct llama_batch * batch) { - llama_batch_clear(batch); +// DEPRECATED +void common_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; } +// DEPRECATED void common_batch_add( - struct llama_batch * batch, + struct llama_batch & batch, llama_token id, llama_pos pos, const std::vector & seq_ids, bool logits) { - int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits); - if (res == -1) { - LOG_ERR("%s: llama_batch size exceeded\n", __func__); + GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); + + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos; + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } + batch.logits [batch.n_tokens] = logits; + + batch.n_tokens++; } // diff --git a/common/common.h b/common/common.h index 8ce6f2f12..524559de4 100644 --- a/common/common.h +++ b/common/common.h @@ -554,10 +554,12 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector & seq_ids, diff --git a/common/speculative.cpp b/common/speculative.cpp index 0836845ec..318e96ea3 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -13,7 +13,7 @@ struct common_speculative { struct llama_context * ctx; struct common_sampler * smpl; - llama_batch * batch; + llama_batch batch; llama_tokens prompt; }; @@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init( auto * result = new common_speculative { /* .ctx = */ ctx_dft, /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 1), + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), /* .prompt = */ {}, }; @@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft( } // we should rarely end-up here during normal decoding - if (llama_batch_get_n_tokens(batch) > 0) { + if (batch.n_tokens > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); llama_decode(ctx, batch); diff --git a/include/llama-cpp.h b/include/llama-cpp.h index 80c726e30..880a6a5fa 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -24,12 +24,12 @@ struct llama_adapter_lora_deleter { void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } }; -struct llama_batch_deleter { - void operator()(llama_batch * batch) { llama_batch_free(batch); } +struct llama_batch_ext_deleter { + void operator()(llama_batch_ext * batch) { llama_batch_ext_free(batch); } }; typedef std::unique_ptr llama_model_ptr; typedef std::unique_ptr llama_context_ptr; typedef std::unique_ptr llama_sampler_ptr; typedef std::unique_ptr llama_adapter_lora_ptr; -typedef std::unique_ptr llama_batch_ptr; +typedef std::unique_ptr llama_batch_ext_ptr; diff --git a/include/llama.h b/include/llama.h index 79dc8604d..32b4cdbe1 100644 --- a/include/llama.h +++ b/include/llama.h @@ -231,9 +231,38 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); - struct llama_batch; + // Input data for llama_decode + // + // WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead + // + // A llama_batch object can contain input about one or many sequences + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens + // + // - token : the token ids of the input (used when embd is NULL) + // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + // - pos : the positions of the respective token in the sequence + // (if set to NULL, the token position will be tracked automatically by llama_decode) + // - seq_id : the sequence to which the respective token belongs + // (if set to NULL, the sequence ID will be assumed to be 0) + // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output + // (if set to NULL, only the logits for last token will be returned) + // + typedef struct llama_batch { + int32_t n_tokens; - struct llama_batch_token_info { + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; // TODO: rename this to "output" + } llama_batch; + + // Input data for llama_decode / llama_encode + // It can contain text tokens and embeddings for one or many sequences + struct llama_batch_ext; + + struct llama_batch_ext_token_info { llama_token token; llama_pos pos; int32_t n_seq_id; @@ -815,9 +844,9 @@ extern "C" { // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // - LLAMA_API struct llama_batch * llama_batch_get_one( + DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens); + int32_t n_tokens), "use llama_batch_ext API instead"); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids @@ -826,30 +855,47 @@ extern "C" { // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token // The rest of the llama_batch members are allocated with size n_tokens // All members are left uninitialized - // LLAMA_API struct llama_batch llama_batch_init( - // int32_t n_tokens, - // int32_t embd, - // int32_t n_seq_max); + DEPRECATED(LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd, + int32_t n_seq_max), "use llama_batch_ext API instead"); + + // Frees a batch of tokens allocated with llama_batch_init() + DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch), + "use llama_batch_ext API instead"); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids - // The batch has to be freed with llama_batch_free() - LLAMA_API struct llama_batch * llama_batch_init( + // The batch has to be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_init( int32_t n_tokens, int32_t n_seq_max); + // Same with llama_batch_init, but initializes the batch with the provided text tokens + // First token will be at position pos0 + // The sequence ID will be fixed to seq_id + // The batch has to be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text( + llama_token * tokens, + int32_t n_tokens, + int32_t pos0, + int32_t seq_id); + // Same with llama_batch_init, but initializes the batch with the provided raw embeddings - LLAMA_API struct llama_batch * llama_batch_init_from_embd( + // First token will be at position pos0 + // The sequence ID will be fixed to seq_id + // The batch has to be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd( float * embd, size_t n_embd, int32_t pos0, int32_t seq_id); // Get the number of tokens in the batch - LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch); + LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch); - LLAMA_API struct llama_batch_token_info llama_batch_get_token_info( - struct llama_batch * batch, + LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info( + struct llama_batch_ext * batch, int32_t i); // Add text tokens to the batch @@ -857,8 +903,8 @@ extern "C" { // 0 : success // -1 : not enough space in the batch // -2 : embd is already set, cannot add text tokens - LLAMA_API int32_t llama_batch_add_text_token( - struct llama_batch * batch, + LLAMA_API int32_t llama_batch_ext_add_text_token( + struct llama_batch_ext * batch, llama_token token, llama_pos pos, const llama_seq_id * seq_ids, @@ -868,43 +914,50 @@ extern "C" { // Set logits for the token in the ith sequence // If pos == -1, logits will be set for the all tokens // Returns -1 if the token is not in the batch - LLAMA_API int32_t llama_batch_set_logits( - struct llama_batch * batch, + LLAMA_API int32_t llama_batch_ext_set_logits( + struct llama_batch_ext * batch, llama_pos pos, llama_seq_id seq_id); // Set logits for the last added token // Returns -1 if there is no tokens in the batch - LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch); + LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch); // Get a "view" from a number of tokens offset // Return returned batch must be freed with llama_batch_free() - LLAMA_API struct llama_batch * llama_batch_get_view( - struct llama_batch * batch, + LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view( + struct llama_batch_ext * batch, int32_t offset, int32_t n_tokens); // Remove everything from the batch - LLAMA_API void llama_batch_clear(struct llama_batch * batch); + LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch); - // Frees a batch of tokens allocated with llama_batch_init() - LLAMA_API void llama_batch_free(struct llama_batch * batch); + // Frees a batch of tokens allocated with llama_batch_ext_init() + // If this is a view, the original batch is not freed + LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch); // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success // < 0 - error. the KV cache state is restored to the state before this call - LLAMA_API int32_t llama_encode( + DEPRECATED(LLAMA_API int32_t llama_encode( struct llama_context * ctx, - struct llama_batch * batch); + struct llama_batch batch), "use llama_batch_ext API instead"); + LLAMA_API int32_t llama_text_encode( + struct llama_context * ctx, + struct llama_batch_ext * batch); // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // < 0 - error. the KV cache state is restored to the state before this call - LLAMA_API int32_t llama_decode( + DEPRECATED(LLAMA_API int32_t llama_decode( struct llama_context * ctx, - struct llama_batch * batch); + struct llama_batch batch), "use llama_batch_ext API instead"); + LLAMA_API int32_t llama_text_decode( + struct llama_context * ctx, + struct llama_batch_ext * batch); // Set the number of threads used for decoding // n_threads is the number of threads used for generation (single token) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index c9b6a97f7..36a3d00be 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { return ubatch; } -void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { +void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; @@ -273,49 +273,61 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ); } -llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { - batch = in_batch; - GGML_ASSERT(batch.n_tokens > 0); - if (!batch.pos) { - pos.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { +llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) { + batch = new llama_batch_ext{ + /*n_tokens =*/ in_batch.n_tokens, + /*max_tokens =*/ in_batch.n_tokens, + /*is_view =*/ false, + /*tokens =*/ in_batch.token, + /*embd =*/ in_batch.embd, + /*pos =*/ in_batch.pos, + /*n_seq_id =*/ in_batch.n_seq_id, + /*seq_id =*/ in_batch.seq_id, + /*logits =*/ in_batch.logits, + }; + GGML_ASSERT(batch->n_tokens > 0); + if (!in_batch.pos) { + pos.resize(batch->n_tokens); + for (int32_t i = 0; i < batch->n_tokens; i++) { pos[i] = i + p0; } - batch.pos = pos.data(); + batch->pos = pos.data(); } - if (!batch.n_seq_id) { - n_seq_id.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch->n_seq_id) { + n_seq_id.resize(batch->n_tokens); + for (int32_t i = 0; i < batch->n_tokens; i++) { n_seq_id[i] = seq_id_0.size(); } - batch.n_seq_id = n_seq_id.data(); + batch->n_seq_id = n_seq_id.data(); } - if (!batch.seq_id) { - seq_id.resize(batch.n_tokens + 1); - seq_id[batch.n_tokens] = NULL; - for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch->seq_id) { + seq_id.resize(batch->n_tokens + 1); + seq_id[batch->n_tokens] = NULL; + for (int32_t i = 0; i < batch->n_tokens; i++) { seq_id[i] = seq_id_0.data(); } - batch.seq_id = seq_id.data(); + batch->seq_id = seq_id.data(); } - if (!batch.logits) { - logits.resize(batch.n_tokens); + if (!batch->logits) { + logits.resize(batch->n_tokens); logits[logits.size() - 1] = true; - batch.logits = logits.data(); + batch->logits = logits.data(); } } +llama_batch_allocr::~llama_batch_allocr() { + delete batch; +} + // // interface implementation // -struct llama_batch * llama_batch_get_one( - llama_token * tokens, - int32_t n_tokens) { - return new llama_batch{ +struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens) { + return llama_batch{ /*n_tokens =*/ n_tokens, - /*max_tokens =*/ n_tokens, - /*is_view =*/ false, /*tokens =*/ tokens, /*embd =*/ nullptr, /*pos =*/ nullptr, @@ -325,8 +337,20 @@ struct llama_batch * llama_batch_get_one( }; } -static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { - llama_batch * batch = new llama_batch{ +struct llama_batch_ext * llama_batch_ext_init_from_text( + llama_token * tokens, + int32_t n_tokens, + int32_t pos0, + int32_t seq_id) { + llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1); + for (int32_t i = 0; i < n_tokens; i++) { + llama_batch_ext_add_text_token(batch, tokens[i], pos0 + i, &seq_id, 1, false); + } + return batch; +} + +static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch_ext * batch = new llama_batch_ext{ /*n_tokens =*/ 0, /*max_tokens =*/ n_tokens_alloc, /*is_view =*/ false, @@ -357,16 +381,16 @@ static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_ return batch; } -struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) { - return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max); +struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) { + return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max); } -struct llama_batch * llama_batch_init_from_embd( +struct llama_batch_ext * llama_batch_ext_init_from_embd( float * embd, size_t n_embd, int32_t pos0, int32_t seq_id) { - struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1); + struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1); memcpy(batch->embd, embd, n_embd * sizeof(float)); for (size_t i = 0; i < n_embd; i++) { batch->pos [i] = pos0 + i; @@ -376,12 +400,12 @@ struct llama_batch * llama_batch_init_from_embd( return batch; } -int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) { +int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) { return batch->n_tokens; } -int32_t llama_batch_add_text_token( - struct llama_batch * batch, +int32_t llama_batch_ext_add_text_token( + struct llama_batch_ext * batch, llama_token token, llama_pos pos, const llama_seq_id * seq_ids, @@ -404,8 +428,8 @@ int32_t llama_batch_add_text_token( return 0; } -int32_t llama_batch_set_logits( - struct llama_batch * batch, +int32_t llama_batch_ext_set_logits( + struct llama_batch_ext * batch, llama_pos pos, llama_seq_id seq_id) { for (int32_t i = 0; i < batch->n_tokens; i++) { @@ -423,7 +447,7 @@ int32_t llama_batch_set_logits( return -1; // not found } -int32_t llama_batch_set_logits_last(struct llama_batch * batch) { +int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) { if (batch->n_tokens == 0) { return -1; } @@ -431,18 +455,18 @@ int32_t llama_batch_set_logits_last(struct llama_batch * batch) { return 0; } -void llama_batch_clear(struct llama_batch * batch) { +void llama_batch_ext_clear(struct llama_batch_ext * batch) { batch->n_tokens = 0; } -struct llama_batch * llama_batch_get_view( - struct llama_batch * batch, +struct llama_batch_ext * llama_batch_ext_get_view( + struct llama_batch_ext * batch, int32_t offset, int32_t n_tokens) { if (batch->embd) { return nullptr; // not yet supported } - llama_batch * batch_view = new llama_batch{ + llama_batch_ext * batch_view = new llama_batch_ext{ /*n_tokens =*/ n_tokens, /*max_tokens =*/ n_tokens, /*is_view =*/ true, @@ -456,11 +480,11 @@ struct llama_batch * llama_batch_get_view( return batch_view; } -struct llama_batch_token_info llama_batch_get_token_info( - struct llama_batch * batch, +struct llama_batch_ext_token_info llama_batch_ext_get_token_info( + struct llama_batch_ext * batch, int32_t i) { GGML_ASSERT(i >= 0 && i < batch->n_tokens); - return llama_batch_token_info{ + return llama_batch_ext_token_info{ /*token =*/ batch->token [i], /*pos =*/ batch->pos [i], /*n_seq_id =*/ batch->n_seq_id[i], @@ -469,7 +493,7 @@ struct llama_batch_token_info llama_batch_get_token_info( }; } -void llama_batch_free(struct llama_batch * batch) { +void llama_batch_ext_free(struct llama_batch_ext * batch) { // do not free the members if it's a view if (!batch->is_view) { if (batch->token) free(batch->token); diff --git a/src/llama-batch.h b/src/llama-batch.h index 70bc6d405..bbd2205b3 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -5,8 +5,8 @@ #include #include -// Input data for llama_decode -// A llama_batch object can contain input about one or many sequences +// Input data for llama_decode / llama_encode +// A llama_batch_ext object can contain input about one or many sequences // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens // // - token : the token ids of the input (used when embd is NULL) @@ -18,7 +18,7 @@ // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output // (if set to NULL, only the logits for last token will be returned) // -struct llama_batch { +struct llama_batch_ext { int32_t n_tokens; int32_t max_tokens; bool is_view; @@ -73,7 +73,7 @@ struct llama_sbatch { std::vector out_ids; std::vector seq; - const llama_batch * batch = nullptr; + const llama_batch_ext * batch = nullptr; // buffers for the ubatch std::vector ubatch_token; @@ -96,12 +96,12 @@ struct llama_sbatch { // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); - void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); + void from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); }; // temporary allocate memory for the input batch if needed struct llama_batch_allocr { - struct llama_batch batch; + struct llama_batch_ext * batch; std::array seq_id_0 = { 0 }; // default sequence id std::vector pos; @@ -110,5 +110,7 @@ struct llama_batch_allocr { std::vector logits; // optionally fulfill the batch returned by llama_batch_get_one - llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); + llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0); + + ~llama_batch_allocr(); }; diff --git a/src/llama.cpp b/src/llama.cpp index 978ce0dd7..a3dc7824a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8445,7 +8445,7 @@ static enum ggml_status llama_graph_compute( static int llama_prepare_sbatch( llama_context & lctx, - const llama_batch & batch, + const llama_batch_ext & batch, uint32_t & n_outputs) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -8585,7 +8585,7 @@ static int llama_prepare_ubatch( // static int llama_decode_impl( llama_context & lctx, - llama_batch inp_batch) { + llama_batch_ext & inp_batch) { lctx.is_encoding = false; @@ -8594,10 +8594,6 @@ static int llama_decode_impl( return -1; } - // temporarily allocate memory for the input batch if needed - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); - const llama_batch & batch = batch_allocr.batch; - const auto & model = lctx.model; const auto & vocab = model.vocab; const auto & hparams = model.hparams; @@ -8616,7 +8612,7 @@ static int llama_decode_impl( uint32_t n_outputs_prev = 0; { - const int ret = llama_prepare_sbatch(lctx, batch, n_outputs); + const int ret = llama_prepare_sbatch(lctx, inp_batch, n_outputs); if (ret != 0) { return ret; } @@ -8625,7 +8621,7 @@ static int llama_decode_impl( while (lctx.sbatch.n_tokens > 0) { llama_ubatch ubatch; { - const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens); + const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, inp_batch.n_tokens); if (ret != 0) { return ret; } @@ -8832,7 +8828,7 @@ static int llama_decode_impl( // static int llama_encode_impl( llama_context & lctx, - llama_batch inp_batch) { + llama_batch_ext & inp_batch) { lctx.is_encoding = true; @@ -8841,22 +8837,18 @@ static int llama_encode_impl( return -1; } - // temporary allocate memory for the input batch if needed - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); - - const llama_batch & batch = batch_allocr.batch; - const uint32_t n_tokens = batch.n_tokens; + const uint32_t n_tokens = inp_batch.n_tokens; const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + GGML_ASSERT((!inp_batch.token && inp_batch.embd) || (inp_batch.token && !inp_batch.embd)); // NOLINT - if (batch.token) { + if (inp_batch.token) { for (uint32_t i = 0; i < n_tokens; ++i) { - if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { - LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + if (inp_batch.token[i] < 0 || (uint32_t) inp_batch.token[i] >= model.vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, inp_batch.token[i]); return -1; } } @@ -8873,7 +8865,7 @@ static int llama_encode_impl( const int64_t n_embd = hparams.n_embd; - lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + lctx.sbatch.from_batch(inp_batch, n_embd, /* simple_split */ true, /* logits_all */ true); const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); @@ -9976,9 +9968,32 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) { /// + +// DEPRECATED int32_t llama_encode( struct llama_context * ctx, - struct llama_batch * batch) { + struct llama_batch batch) { + // temporarily allocate memory for the input batch if needed + // also convert llama_batch to llama_batch_ext + llama_batch_allocr batch_allocr(batch, batch.pos ? -1 : ctx->kv_self.max_pos() + 1); + llama_batch_ext * batch_ext = batch_allocr.batch; + return llama_text_encode(ctx, batch_ext); +} + +// DEPRECATED +int32_t llama_decode( + struct llama_context * ctx, + struct llama_batch batch) { + // temporarily allocate memory for the input batch if needed + // also convert llama_batch to llama_batch_ext + llama_batch_allocr batch_allocr(batch, batch.pos ? -1 : ctx->kv_self.max_pos() + 1); + llama_batch_ext * batch_ext = batch_allocr.batch; + return llama_text_decode(ctx, batch_ext); +} + +int32_t llama_text_encode( + struct llama_context * ctx, + struct llama_batch_ext * batch) { const int ret = llama_encode_impl(*ctx, *batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); @@ -9987,9 +10002,9 @@ int32_t llama_encode( return ret; } -int32_t llama_decode( +int32_t llama_text_decode( struct llama_context * ctx, - struct llama_batch * batch) { + struct llama_batch_ext * batch) { const int ret = llama_decode_impl(*ctx, *batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);