mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 12:35:16 +00:00
qwen2vl: use llama_batch_ext_set_pos
This commit is contained in:
@ -66,18 +66,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
|
|||||||
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
|
memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
|
||||||
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
|
||||||
|
|
||||||
// TODO: move this to llama_batch_ext API
|
float * batch_embd = image_embed->embed+i*n_embd;
|
||||||
llama_batch batch = {
|
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(batch_embd, n_eval, n_embd, 0, 0));
|
||||||
int32_t(n_eval), // n_tokens
|
llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
|
||||||
nullptr, // token
|
|
||||||
(image_embed->embed+i*n_embd), // embed
|
|
||||||
batch_mrope_pos.data(), // pos
|
|
||||||
nullptr, // n_seq_id
|
|
||||||
nullptr, // seq_id
|
|
||||||
nullptr, // logits
|
|
||||||
};
|
|
||||||
|
|
||||||
if (llama_decode(ctx_llama, batch)) {
|
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -950,6 +950,12 @@ extern "C" {
|
|||||||
int32_t pos0,
|
int32_t pos0,
|
||||||
int32_t seq_id);
|
int32_t seq_id);
|
||||||
|
|
||||||
|
// Set arbitrary token to the embeddings batch
|
||||||
|
// Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd()
|
||||||
|
// n_pos must match the n_tokens of the batch
|
||||||
|
// Returns -1 if n_pos does not match the n_tokens of the batch
|
||||||
|
LLAMA_API int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos);
|
||||||
|
|
||||||
// Get the number of tokens in the batch
|
// Get the number of tokens in the batch
|
||||||
LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
|
LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
|
||||||
|
|
||||||
|
@ -405,6 +405,14 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
|||||||
return batch;
|
return batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) {
|
||||||
|
if (batch->n_tokens != n_pos) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
memcpy(batch->pos, pos, n_pos * sizeof(llama_pos));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
|
int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
|
||||||
return batch->n_tokens;
|
return batch->n_tokens;
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user