diff --git a/.gitignore b/.gitignore
index 694f36e04..56b5ac2c1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -98,6 +98,7 @@ examples/server/*.css.hpp
 examples/server/*.html.hpp
 examples/server/*.js.hpp
 examples/server/*.mjs.hpp
+examples/server/*.gz.hpp
 !build_64.sh
 !examples/*.bat
 !examples/*/*.kts
diff --git a/common/common.cpp b/common/common.cpp
index 8661e164a..c79f1e736 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -580,6 +580,7 @@ std::string string_from(const struct llama_context * ctx, const std::vector common_get_hf_file(const std::string &, cons
 // Batch utils
 //

-void common_batch_clear(struct llama_batch & batch) {
-    batch.n_tokens = 0;
+void common_batch_clear(struct llama_batch * batch) {
+    llama_batch_clear(batch);
 }

 void common_batch_add(
-                 struct llama_batch & batch,
+                 struct llama_batch * batch,
                        llama_token   id,
                          llama_pos   pos,
    const std::vector<llama_seq_id> & seq_ids,
                               bool   logits) {
-    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
-
-    batch.token   [batch.n_tokens] = id;
-    batch.pos     [batch.n_tokens] = pos;
-    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
-    for (size_t i = 0; i < seq_ids.size(); ++i) {
-        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+    int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits);
+    if (res == -1) {
+        LOG_ERR("%s: llama_batch size exceeded\n", __func__);
     }
-    batch.logits  [batch.n_tokens] = logits;
-
-    batch.n_tokens++;
 }

 //
diff --git a/common/common.h b/common/common.h
index 98b9a4464..8ce6f2f12 100644
--- a/common/common.h
+++ b/common/common.h
@@ -554,10 +554,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector & seq_ids,
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 318e96ea3..0836845ec 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -13,7 +13,7 @@ struct common_speculative {
     struct llama_context * ctx;
     struct common_sampler * smpl;

-    llama_batch batch;
+    llama_batch * batch;
     llama_tokens prompt;
 };

@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 1),
         /* .prompt = */ {},
     };

@@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft(
     }

     // we should rarely end-up here during normal decoding
-    if (batch.n_tokens > 0) {
+    if (llama_batch_get_n_tokens(batch) > 0) {
         //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

         llama_decode(ctx, batch);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 71151183b..41f8dc505 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1215,7 +1215,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;

-    llama_batch batch_spec = {};
+    llama_batch_ptr batch_spec;

     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1787,7 +1787,7 @@ struct server_context {

     llama_context_params cparams_dft;

-    llama_batch batch = {};
+    llama_batch_ptr batch;

     bool clean_kv_cache = true;
     bool add_bos_token  = true;
@@ -1820,11 +1820,7 @@ struct server_context {

             common_speculative_free(slot.spec);
             slot.spec = nullptr;
-
-            llama_batch_free(slot.batch_spec);
         }
-
-        llama_batch_free(batch);
     }

     bool load_model(const common_params & params) {
@@ -1944,7 +1940,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;

             if (model_dft) {
-                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+                slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1));

                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1969,7 +1965,7 @@ struct server_context {

             slot.reset();

-            slots.push_back(slot);
+            slots.push_back(std::move(slot));
         }

         default_generation_settings_for_props = slots[0].to_json();
@@ -1980,7 +1976,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);

             // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
+            batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1));
         }

         metrics.init();
@@ -2098,9 +2094,7 @@ struct server_context {
         }

         if (slot.ctx_dft) {
-            llama_batch_free(slot.batch_spec);
-
-            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+            slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1));
         }

         slot.state = SLOT_STATE_STARTED;
@@ -2408,7 +2402,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }

-    void send_embedding(const server_slot & slot, const llama_batch & batch) {
+    void send_embedding(const server_slot & slot, llama_batch_ptr & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id    = slot.id_task;
         res->index = slot.index;
@@ -2419,18 +2413,19 @@ struct server_context {

         std::vector<float> embd_res(n_embd, 0.0f);

-        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+            if (!tok.logits || tok.seq_id[0] != slot.id) {
                 continue;
             }

-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }

             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);

                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
@@ -2451,24 +2446,25 @@ struct server_context {
         queue_results.send(std::move(res));
     }

-    void send_rerank(const server_slot & slot, const llama_batch & batch) {
+    void send_rerank(const server_slot & slot, llama_batch_ptr & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id       = slot.id_task;
         res->index    = slot.index;
         res->n_tokens = slot.n_prompt_tokens;

-        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+        for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) {
+            llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i);
+            if (!tok.logits || tok.seq_id[0] != slot.id) {
                 continue;
             }

-            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }

             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);

                 res->score = -1e6;
                 continue;
@@ -2859,7 +2855,7 @@ struct server_context {
         }

         // start populating the batch for this iteration
-        common_batch_clear(batch);
+        common_batch_clear(batch.get());

         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2881,9 +2877,9 @@ struct server_context {
                 continue;
             }

-            slot.i_batch = batch.n_tokens;
+            slot.i_batch = llama_batch_get_n_tokens(batch.get());

-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true);

             slot.n_past += 1;

@@ -2900,7 +2896,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);

         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3066,7 +3062,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
+                        if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                             continue;
                         }
                     }
@@ -3086,11 +3082,11 @@ struct server_context {
                     slot.cache_tokens.resize(slot.n_past);

                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

-                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
+                        common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3100,13 +3096,13 @@ struct server_context {
                         slot.n_past++;
                     }

-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);

                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;

-                        GGML_ASSERT(batch.n_tokens > 0);
+                        GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0);

                         common_sampler_reset(slot.smpl);

@@ -3116,27 +3112,27 @@ struct server_context {
                         }

                         // extract the logits only for the last token
-                        batch.logits[batch.n_tokens - 1] = true;
+                        llama_batch_set_logits_last(batch.get());

                         slot.n_decoded = 0;
-                        slot.i_batch   = batch.n_tokens - 1;
+                        slot.i_batch   = llama_batch_get_n_tokens(batch.get()) - 1;

-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get()));
                     }
                 }

-                if (batch.n_tokens >= n_batch) {
+                if (llama_batch_get_n_tokens(batch.get()) >= n_batch) {
                     break;
                 }
             }
         }

-        if (batch.n_tokens == 0) {
+        if (llama_batch_get_n_tokens(batch.get()) == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }

-        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get()));

         if (slot_batched) {
             // make sure we're in the right embedding mode
@@ -3146,20 +3142,12 @@ struct server_context {
         }

         // process the created batch of tokens
-        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
+        for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i);

-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
+            llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens));

-            const int ret = llama_decode(ctx, batch_view);
+            const int ret = llama_decode(ctx, batch_view.get());
             metrics.on_decoded(slots);

             if (ret != 0) {
@@ -3294,16 +3282,16 @@ struct server_context {
             }

             // construct the speculation batch
-            common_batch_clear(slot.batch_spec);
-            common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+            common_batch_clear(slot.batch_spec.get());
+            common_batch_add  (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true);

             for (size_t i = 0; i < draft.size(); ++i) {
-                common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true);
             }

-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get()));

-            llama_decode(ctx, slot.batch_spec);
+            llama_decode(ctx, slot.batch_spec.get());

             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
diff --git a/include/llama-cpp.h b/include/llama-cpp.h
index 8f6368177..80c726e30 100644
--- a/include/llama-cpp.h
+++ b/include/llama-cpp.h
@@ -24,7 +24,12 @@ struct llama_adapter_lora_deleter {
     void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
 };

+struct llama_batch_deleter {
+    void operator()(llama_batch * batch) { llama_batch_free(batch); }
+};
+
 typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
 typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
 typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
 typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
+typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr;
diff --git a/include/llama.h b/include/llama.h
index 2c7569f8a..79dc8604d 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -233,6 +233,14 @@ extern "C" {

     struct llama_batch;

+    struct llama_batch_token_info {
+        llama_token    token;
+        llama_pos      pos;
+        int32_t        n_seq_id;
+        llama_seq_id * seq_id;
+        int8_t         logits;
+    };
+
     enum llama_model_kv_override_type {
         LLAMA_KV_OVERRIDE_TYPE_INT,
         LLAMA_KV_OVERRIDE_TYPE_FLOAT,
@@ -837,34 +845,44 @@ extern "C" {
                        int32_t   pos0,
                        int32_t   seq_id);

+    // Get the number of tokens in the batch
+    LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch);
+
+    LLAMA_API struct llama_batch_token_info llama_batch_get_token_info(
+            struct llama_batch * batch,
+                       int32_t   i);
+
     // Add text tokens to the batch
-    // First token in the list starts at position pos0
     // Return values:
     //  0 : success
     // -1 : not enough space in the batch
     // -2 : embd is already set, cannot add text tokens
-    LLAMA_API int32_t llama_batch_add_text(
+    LLAMA_API int32_t llama_batch_add_text_token(
             struct llama_batch * batch,
-                   llama_token * tokens,
-                        size_t   n_tokens,
-                       int32_t   pos0,
-                       int32_t   seq_id);
-
-    // Same as llama_batch_add_text, but accepts multiple sequences
-    LLAMA_API int32_t llama_batch_add_text(
-            struct llama_batch * batch,
-                   llama_token * tokens,
-                        size_t   n_tokens,
-                       int32_t   pos0,
-                       int32_t * seq_ids,
-                        size_t   n_seq_ids);
+                   llama_token   token,
+                     llama_pos   pos,
+          const llama_seq_id  * seq_ids,
+                        size_t   n_seq_ids,
+                         float   logits);

     // Set logits for the token in the ith sequence
     // If pos == -1, logits will be set for the all tokens
+    // Returns -1 if the token is not in the batch
     LLAMA_API int32_t llama_batch_set_logits(
             struct llama_batch * batch,
-                       int32_t   pos,
-                       int32_t   seq_id);
+                     llama_pos   pos,
+                  llama_seq_id   seq_id);
+
+    // Set logits for the last added token
+    // Returns -1 if there is no tokens in the batch
+    LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch);
+
+    // Get a "view" from a number of tokens offset
+    // Return returned batch must be freed with llama_batch_free()
+    LLAMA_API struct llama_batch * llama_batch_get_view(
+            struct llama_batch * batch,
+                       int32_t   offset,
+                       int32_t   n_tokens);

     // Remove everything from the batch
     LLAMA_API void llama_batch_clear(struct llama_batch * batch);
@@ -878,7 +896,7 @@ extern "C" {
     // < 0 - error. the KV cache state is restored to the state before this call
     LLAMA_API int32_t llama_encode(
             struct llama_context * ctx,
-              struct llama_batch   batch);
+              struct llama_batch * batch);

     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     // < 0 - error. the KV cache state is restored to the state before this call
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
-              struct llama_batch   batch);
+              struct llama_batch * batch);

     // Set the number of threads used for decoding
     // n_threads is the number of threads used for generation (single token)
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 027ac2413..c9b6a97f7 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -314,6 +314,8 @@ struct llama_batch * llama_batch_get_one(
         int32_t   n_tokens) {
     return new llama_batch{
         /*n_tokens   =*/ n_tokens,
+        /*max_tokens =*/ n_tokens,
+        /*is_view    =*/ false,
         /*tokens     =*/ tokens,
         /*embd       =*/ nullptr,
         /*pos        =*/ nullptr,
@@ -326,6 +328,8 @@
 static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
     llama_batch * batch = new llama_batch{
         /*n_tokens   =*/ 0,
+        /*max_tokens =*/ n_tokens_alloc,
+        /*is_view    =*/ false,
         /*tokens     =*/ nullptr,
         /*embd       =*/ nullptr,
         /*pos        =*/ nullptr,
@@ -364,50 +368,46 @@ struct llama_batch * llama_batch_init_from_embd(
         int32_t   seq_id) {
     struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
     memcpy(batch->embd, embd, n_embd * sizeof(float));
-    for (int32_t i = 0; i < n_embd; i++) {
+    for (size_t i = 0; i < n_embd; i++) {
         batch->pos     [i] = pos0 + i;
         batch->n_seq_id[i] = 1;
         batch->seq_id  [i][0] = seq_id;
     }
+
     return batch;
 }

-int32_t llama_batch_add_text(
+int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) {
+    return batch->n_tokens;
+}
+
+int32_t llama_batch_add_text_token(
         struct llama_batch * batch,
-        llama_token * tokens,
-        size_t n_tokens,
-        int32_t pos0,
-        int32_t * seq_ids,
-        size_t n_seq_ids) {
-    if (batch->n_tokens + n_tokens > batch->n_tokens) {
-        return -1;
+        llama_token token,
+        llama_pos pos,
+        const llama_seq_id * seq_ids,
+        size_t n_seq_ids,
+        float logits) {
+    if (batch->n_tokens + 1 > batch->max_tokens) {
+        return -1; // llama_batch size exceeded
     }
     if (batch->embd) {
-        return -2;
+        return -2; // embd is already set, cannot add text tokens
     }
-    for (int32_t i = 0; i < n_tokens; i++) {
-        batch->token   [batch->n_tokens + i] = tokens[i];
-        batch->pos     [batch->n_tokens + i] = pos0 + i;
-        batch->n_seq_id[batch->n_tokens + i] = n_seq_ids;
-        for (int32_t j = 0; j < n_seq_ids; j++) {
-            batch->seq_id[batch->n_tokens + i][j] = seq_ids[j];
-        }
+    batch->token   [batch->n_tokens] = token;
+    batch->pos     [batch->n_tokens] = pos;
+    batch->n_seq_id[batch->n_tokens] = n_seq_ids;
+    for (size_t j = 0; j < n_seq_ids; j++) {
+        batch->seq_id[batch->n_tokens][j] = seq_ids[j];
     }
-}
-
-int32_t llama_batch_add_text(
-        struct llama_batch * batch,
-        llama_token * tokens,
-        size_t n_tokens,
-        int32_t pos0,
-        int32_t seq_id) {
-    std::array seq_ids = { seq_id };
-    return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size());
+    batch->logits  [batch->n_tokens] = logits;
+    batch->n_tokens++;
+    return 0;
 }

 int32_t llama_batch_set_logits(
         struct llama_batch * batch,
-        int32_t pos,
-        int32_t seq_id) {
+        llama_pos pos,
+        llama_seq_id seq_id) {
     for (int32_t i = 0; i < batch->n_tokens; i++) {
         // find the token having seq_id
         for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
@@ -415,28 +415,74 @@ int32_t llama_batch_set_logits(
                 // found the sequence
                 if (pos == -1 || pos == batch->pos[i]) {
                     batch->logits[i] = true;
-                    break;
+                    return 0;
                 }
             }
         }
     }
+    return -1; // not found
+}
+
+int32_t llama_batch_set_logits_last(struct llama_batch * batch) {
+    if (batch->n_tokens == 0) {
+        return -1;
+    }
+    batch->logits[batch->n_tokens - 1] = true;
+    return 0;
 }

 void llama_batch_clear(struct llama_batch * batch) {
     batch->n_tokens = 0;
 }

-void llama_batch_free(struct llama_batch * batch) {
-    if (batch->token)    free(batch->token);
-    if (batch->embd)     free(batch->embd);
-    if (batch->pos)      free(batch->pos);
-    if (batch->n_seq_id) free(batch->n_seq_id);
-    if (batch->seq_id) {
-        for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
-            free(batch->seq_id[i]);
-        }
-        free(batch->seq_id);
+struct llama_batch * llama_batch_get_view(
+        struct llama_batch * batch,
+        int32_t offset,
+        int32_t n_tokens) {
+    if (batch->embd) {
+        return nullptr; // not yet supported
+    }
+    llama_batch * batch_view = new llama_batch{
+        /*n_tokens   =*/ n_tokens,
+        /*max_tokens =*/ n_tokens,
+        /*is_view    =*/ true,
+        /*tokens     =*/ batch->token    + offset,
+        /*embd       =*/ nullptr,
+        /*pos        =*/ batch->pos      + offset,
+        /*n_seq_id   =*/ batch->n_seq_id + offset,
+        /*seq_id     =*/ batch->seq_id   + offset,
+        /*logits     =*/ batch->logits   + offset,
+    };
+    return batch_view;
+}
+
+struct llama_batch_token_info llama_batch_get_token_info(
+        struct llama_batch * batch,
+        int32_t i) {
+    GGML_ASSERT(i >= 0 && i < batch->n_tokens);
+    return llama_batch_token_info{
+        /*token    =*/ batch->token   [i],
+        /*pos      =*/ batch->pos     [i],
+        /*n_seq_id =*/ batch->n_seq_id[i],
+        /*seq_id   =*/ batch->seq_id  [i],
+        /*logits   =*/ batch->logits  [i],
+    };
+}
+
+void llama_batch_free(struct llama_batch * batch) {
+    // do not free the members if it's a view
+    if (!batch->is_view) {
+        if (batch->token)    free(batch->token);
+        if (batch->embd)     free(batch->embd);
+        if (batch->pos)      free(batch->pos);
+        if (batch->n_seq_id) free(batch->n_seq_id);
+        if (batch->seq_id) {
+            for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
+                free(batch->seq_id[i]);
+            }
+            free(batch->seq_id);
+        }
+        if (batch->logits) free(batch->logits);
     }
-    if (batch->logits) free(batch->logits);
     delete batch;
 }
diff --git a/src/llama-batch.h b/src/llama-batch.h
index de702da76..70bc6d405 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -20,6 +20,8 @@
 //
 struct llama_batch {
     int32_t n_tokens;
+    int32_t max_tokens;
+    bool    is_view;

     llama_token  * token;
     float        * embd;
diff --git a/src/llama.cpp b/src/llama.cpp
index 607f27861..978ce0dd7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9978,8 +9978,8 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {

 int32_t llama_encode(
         struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_encode_impl(*ctx, batch);
+          struct llama_batch * batch) {
+    const int ret = llama_encode_impl(*ctx, *batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -9989,8 +9989,8 @@
 int32_t llama_decode(
         struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_decode_impl(*ctx, batch);
+          struct llama_batch * batch) {
+    const int ret = llama_decode_impl(*ctx, *batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
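Usage sketch (not part of the patch): the snippet below shows how a caller might drive the pointer-based batch API introduced above, combining llama_batch_ptr from llama-cpp.h with llama_batch_add_text_token, llama_batch_set_logits_last, llama_batch_get_n_tokens and the new pointer-taking llama_decode. The helper name decode_prompt is hypothetical, model/context setup and tokenization are omitted, and the two-argument llama_batch_init(n_tokens_alloc, n_seq_max) signature is an assumption inferred from the updated call sites in this diff.

// Hypothetical usage sketch for the pointer-based batch API (assumptions noted above).
#include "llama.h"
#include "llama-cpp.h" // llama_batch_ptr

#include <cstdio>
#include <vector>

static int decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
    // the batch is now heap-allocated; llama_batch_ptr releases it via llama_batch_free()
    llama_batch_ptr batch(llama_batch_init((int32_t) tokens.size(), /*n_seq_max =*/ 1));

    const llama_seq_id seq_id = 0;
    for (size_t i = 0; i < tokens.size(); ++i) {
        // logits are requested only for the last prompt token (set below)
        if (llama_batch_add_text_token(batch.get(), tokens[i], (llama_pos) i, &seq_id, 1, false) != 0) {
            fprintf(stderr, "failed to add token %zu (batch full or embd already set)\n", i);
            return -1;
        }
    }

    // replaces the old direct write `batch.logits[batch.n_tokens - 1] = true`
    llama_batch_set_logits_last(batch.get());

    fprintf(stderr, "decoding %d tokens\n", llama_batch_get_n_tokens(batch.get()));

    // llama_decode/llama_encode now take the batch by pointer
    return llama_decode(ctx, batch.get());
}

The same pattern appears throughout the server changes above: long-lived batches live in llama_batch_ptr members, raw pointers are passed down with .get(), and sub-ranges are obtained with llama_batch_get_view() instead of hand-built struct copies.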