diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
index e2fdfcc28..3efa604b9 100644
--- a/examples/llava/gemma3-cli.cpp
+++ b/examples/llava/gemma3-cli.cpp
@@ -148,7 +148,7 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
     int64_t t1 = ggml_time_ms();
     eval_text(ctx, "<start_of_image>");
     llama_set_causal_attn(ctx.lctx, false);
-    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
+    llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
     if (llama_decode_ext(ctx.lctx, batch_img.get())) {
         LOG_ERR("failed to decode image\n");
         return 1;
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index 53ce30215..de967e069 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0));
         if (llama_decode_ext(ctx_llama, batch.get())) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return false;
diff --git a/include/llama.h b/include/llama.h
index fb6edda8e..2f58085fc 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -938,11 +938,14 @@ extern "C" {
             bool output_last);
 
     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
+    // Size of embd should be n_tokens * n_embd
+    // n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
     // First token will be at position pos0
     // The sequence ID will be fixed to seq_id
     // The batch has to be freed with llama_batch_ext_free()
     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
             float * embd,
+            size_t n_tokens,
             size_t n_embd,
             int32_t pos0,
             int32_t seq_id);
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index d55625da1..a7f2717f1 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -353,7 +353,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
     return batch;
 }
 
-static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
+static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
     llama_batch_ext * batch = new llama_batch_ext{
         /*n_tokens   =*/ 0,
         /*max_tokens =*/ n_tokens_alloc,
@@ -366,8 +366,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
         /*logits     =*/ nullptr,
     };
 
-    if (embd) {
-        batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+    if (n_embd) {
+        batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd);
     } else {
         batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }
@@ -391,14 +391,15 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_
 
 struct llama_batch_ext * llama_batch_ext_init_from_embd(
         float * embd,
+        size_t n_tokens,
         size_t n_embd,
         int32_t pos0,
         int32_t seq_id) {
-    struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1);
-    memcpy(batch->embd, embd, n_embd * sizeof(float));
-    for (size_t i = 0; i < n_embd; i++) {
-        batch->pos     [i] = pos0 + i;
-        batch->n_seq_id[i] = 1;
+    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
+    memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
+    for (size_t i = 0; i < n_tokens; i++) {
+        batch->pos     [i] = pos0 + i;
+        batch->n_seq_id[i] = 1;
         batch->seq_id  [i][0] = seq_id;
     }
     return batch;
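
For context, a minimal usage sketch of the updated llama_batch_ext_init_from_embd() signature follows. It is illustrative only, not part of the patch: the names model, ctx, image_embd, n_past, and n_tokens are assumed to exist in the caller's scope, with image_embd holding n_tokens * n_embd floats as the new header comment requires.

    // Sketch: decoding raw embeddings with the new two-parameter sizing.
    // Assumes `model`, `ctx`, `image_embd`, `n_past`, and `n_tokens` are set up by the caller.
    const int n_embd = llama_model_n_embd(model); // embeddings per token

    llama_batch_ext * batch = llama_batch_ext_init_from_embd(
        image_embd, // n_tokens rows of n_embd floats
        n_tokens,   // number of tokens (rows)
        n_embd,     // row width, from llama_model_n_embd()
        n_past,     // pos0: position assigned to the first token
        0);         // seq_id: all tokens placed in sequence 0

    if (llama_decode_ext(ctx, batch)) {
        LOG_ERR("failed to decode embeddings\n");
    }

    llama_batch_ext_free(batch); // the batch must be freed by the caller

Splitting the old single n_embd parameter into n_tokens and n_embd removes the ambiguity that caused the previous implementation to copy only n_embd floats and treat the embedding width as the token count.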