llama_batch_ext_ptr::from_text/embd

This commit is contained in:
Xuan Son Nguyen
2025-03-14 17:12:03 +01:00
parent 8e7714fa77
commit eaffba0f2e
12 changed files with 34 additions and 14 deletions

View File

@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;

View File

@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch;
}
float * embd = image_embed->embed+i*n_embd;
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0));
auto batch = llama_batch_ext_ptr::from_embd(embd, n_eval, n_embd, 0, 0);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;

View File

@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true);
if (llama_decode_ext(ctx_llama, batch.get())) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;

View File

@ -67,7 +67,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
float * batch_embd = image_embed->embed+i*n_embd;
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(batch_embd, n_eval, n_embd, 0, 0));
auto batch = llama_batch_ext_ptr::from_embd(batch_embd, n_eval, n_embd, 0, 0);
llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
if (llama_decode_ext(ctx_llama, batch.get())) {