fix llama_batch_ext_init_from_embd

This commit is contained in:
Xuan Son Nguyen
2025-03-14 11:17:22 +01:00
parent 07d84fa3c2
commit ba79369615
4 changed files with 14 additions and 10 deletions

View File

@@ -353,7 +353,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
return batch;
}
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
llama_batch_ext * batch = new llama_batch_ext{
/*n_tokens =*/ 0,
/*max_tokens =*/ n_tokens_alloc,
@@ -366,8 +366,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
/*logits =*/ nullptr,
};
if (embd) {
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
if (n_embd) {
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd);
} else {
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
}
@@ -391,14 +391,15 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_
struct llama_batch_ext * llama_batch_ext_init_from_embd(
float * embd,
size_t n_tokens,
size_t n_embd,
int32_t pos0,
int32_t seq_id) {
struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1);
memcpy(batch->embd, embd, n_embd * sizeof(float));
for (size_t i = 0; i < n_embd; i++) {
batch->pos [i] = pos0 + i;
batch->n_seq_id[i] = 1;
struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
for (size_t i = 0; i < n_tokens; i++) {
batch->pos [i] = pos0 + i;
batch->n_seq_id[i] = 1;
batch->seq_id [i][0] = seq_id;
}
return batch;