mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 20:45:04 +00:00
fix llama_batch_ext_init_from_embd
This commit is contained in:
@ -148,7 +148,7 @@ static int eval_image(gemma3_context & ctx, std::string & fname) {
|
|||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
eval_text(ctx, "<start_of_image>");
|
eval_text(ctx, "<start_of_image>");
|
||||||
llama_set_causal_attn(ctx.lctx, false);
|
llama_set_causal_attn(ctx.lctx, false);
|
||||||
llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0));
|
llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0));
|
||||||
if (llama_decode_ext(ctx.lctx, batch_img.get())) {
|
if (llama_decode_ext(ctx.lctx, batch_img.get())) {
|
||||||
LOG_ERR("failed to decode image\n");
|
LOG_ERR("failed to decode image\n");
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
|
|||||||
n_eval = n_batch;
|
n_eval = n_batch;
|
||||||
}
|
}
|
||||||
float * embd = image_embed->embed+i*n_embd;
|
float * embd = image_embed->embed+i*n_embd;
|
||||||
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0));
|
llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0));
|
||||||
if (llama_decode_ext(ctx_llama, batch.get())) {
|
if (llama_decode_ext(ctx_llama, batch.get())) {
|
||||||
LOG_ERR("%s : failed to eval\n", __func__);
|
LOG_ERR("%s : failed to eval\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
|
@ -938,11 +938,14 @@ extern "C" {
|
|||||||
bool output_last);
|
bool output_last);
|
||||||
|
|
||||||
// Same with llama_batch_init, but initializes the batch with the provided raw embeddings
|
// Same with llama_batch_init, but initializes the batch with the provided raw embeddings
|
||||||
|
// Size of embd should be n_tokens * n_embd
|
||||||
|
// n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
|
||||||
// First token will be at position pos0
|
// First token will be at position pos0
|
||||||
// The sequence ID will be fixed to seq_id
|
// The sequence ID will be fixed to seq_id
|
||||||
// The batch has to be freed with llama_batch_ext_free()
|
// The batch has to be freed with llama_batch_ext_free()
|
||||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
||||||
float * embd,
|
float * embd,
|
||||||
|
size_t n_tokens,
|
||||||
size_t n_embd,
|
size_t n_embd,
|
||||||
int32_t pos0,
|
int32_t pos0,
|
||||||
int32_t seq_id);
|
int32_t seq_id);
|
||||||
|
@ -353,7 +353,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
|
|||||||
return batch;
|
return batch;
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
|
||||||
llama_batch_ext * batch = new llama_batch_ext{
|
llama_batch_ext * batch = new llama_batch_ext{
|
||||||
/*n_tokens =*/ 0,
|
/*n_tokens =*/ 0,
|
||||||
/*max_tokens =*/ n_tokens_alloc,
|
/*max_tokens =*/ n_tokens_alloc,
|
||||||
@ -366,8 +366,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
|
|||||||
/*logits =*/ nullptr,
|
/*logits =*/ nullptr,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (embd) {
|
if (n_embd) {
|
||||||
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd);
|
||||||
} else {
|
} else {
|
||||||
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
||||||
}
|
}
|
||||||
@ -391,14 +391,15 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_
|
|||||||
|
|
||||||
struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
||||||
float * embd,
|
float * embd,
|
||||||
|
size_t n_tokens,
|
||||||
size_t n_embd,
|
size_t n_embd,
|
||||||
int32_t pos0,
|
int32_t pos0,
|
||||||
int32_t seq_id) {
|
int32_t seq_id) {
|
||||||
struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1);
|
struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
|
||||||
memcpy(batch->embd, embd, n_embd * sizeof(float));
|
memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
|
||||||
for (size_t i = 0; i < n_embd; i++) {
|
for (size_t i = 0; i < n_tokens; i++) {
|
||||||
batch->pos [i] = pos0 + i;
|
batch->pos [i] = pos0 + i;
|
||||||
batch->n_seq_id[i] = 1;
|
batch->n_seq_id[i] = 1;
|
||||||
batch->seq_id [i][0] = seq_id;
|
batch->seq_id [i][0] = seq_id;
|
||||||
}
|
}
|
||||||
return batch;
|
return batch;
|
||||||
|
Reference in New Issue
Block a user