From a363251fac38be94d27f4c63c1765efa65d9d0d2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 11:25:36 +0100 Subject: [PATCH] qwen2vl: use llama_batch_ext_set_pos --- examples/llava/qwen2vl-cli.cpp | 15 ++++----------- include/llama.h | 6 ++++++ src/llama-batch.cpp | 8 ++++++++ 3 files changed, 18 insertions(+), 11 deletions(-) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 1f4242580..a702ab46a 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -66,18 +66,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); - // TODO: move this to llama_batch_ext API - llama_batch batch = { - int32_t(n_eval), // n_tokens - nullptr, // token - (image_embed->embed+i*n_embd), // embed - batch_mrope_pos.data(), // pos - nullptr, // n_seq_id - nullptr, // seq_id - nullptr, // logits - }; + float * batch_embd = image_embed->embed+i*n_embd; + llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(batch_embd, n_eval, n_embd, 0, 0)); + llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval); - if (llama_decode(ctx_llama, batch)) { + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/include/llama.h b/include/llama.h index 2f58085fc..28fb82606 100644 --- a/include/llama.h +++ b/include/llama.h @@ -950,6 +950,12 @@ extern "C" { int32_t pos0, int32_t seq_id); + // Set arbitrary positions in the embeddings batch + // Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd() + // n_pos must match the n_tokens of the batch + // Returns -1 if n_pos does not match the n_tokens of the batch + LLAMA_API int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, 
llama_pos * pos, size_t n_pos); + // Get the number of tokens in the batch LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index a7f2717f1..f56b3b03b 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -405,6 +405,14 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd( return batch; } +int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) { + if ((size_t) batch->n_tokens != n_pos) { + return -1; + } + memcpy(batch->pos, pos, n_pos * sizeof(llama_pos)); + return 0; + } + int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) { return batch->n_tokens; }