mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-13 20:07:41 -04:00
fix llama_batch_ext_init_from_embd
This commit is contained in:
@@ -353,7 +353,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
|
||||
return batch;
|
||||
}
|
||||
|
||||
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
|
||||
llama_batch_ext * batch = new llama_batch_ext{
|
||||
/*n_tokens =*/ 0,
|
||||
/*max_tokens =*/ n_tokens_alloc,
|
||||
@@ -366,8 +366,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
|
||||
/*logits =*/ nullptr,
|
||||
};
|
||||
|
||||
if (embd) {
|
||||
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
|
||||
if (n_embd) {
|
||||
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd);
|
||||
} else {
|
||||
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
|
||||
}
|
||||
@@ -391,14 +391,15 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_
|
||||
|
||||
struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
||||
float * embd,
|
||||
size_t n_tokens,
|
||||
size_t n_embd,
|
||||
int32_t pos0,
|
||||
int32_t seq_id) {
|
||||
struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1);
|
||||
memcpy(batch->embd, embd, n_embd * sizeof(float));
|
||||
for (size_t i = 0; i < n_embd; i++) {
|
||||
batch->pos [i] = pos0 + i;
|
||||
batch->n_seq_id[i] = 1;
|
||||
struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
|
||||
memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
|
||||
for (size_t i = 0; i < n_tokens; i++) {
|
||||
batch->pos [i] = pos0 + i;
|
||||
batch->n_seq_id[i] = 1;
|
||||
batch->seq_id [i][0] = seq_id;
|
||||
}
|
||||
return batch;
|
||||
|
Reference in New Issue
Block a user