return output ID from llama_batch_ext_add/set

This commit is contained in:
Xuan Son Nguyen
2025-03-13 17:47:07 +01:00
parent 86973cb14a
commit 4aabf4e8f4
5 changed files with 31 additions and 25 deletions

View File

@@ -410,25 +410,26 @@ int32_t llama_batch_ext_add_text(
llama_pos pos,
const llama_seq_id * seq_ids,
size_t n_seq_ids,
float logits) {
bool output) {
if (batch->n_tokens + 1 > batch->max_tokens) {
return -1; // llama_batch size exceeded
}
if (batch->embd) {
return -2; // embd is already set, cannot add text tokens
}
batch->token [batch->n_tokens] = token;
batch->pos [batch->n_tokens] = pos;
batch->n_seq_id[batch->n_tokens] = n_seq_ids;
const int32_t output_id = batch->n_tokens;
batch->token [output_id] = token;
batch->pos [output_id] = pos;
batch->n_seq_id[output_id] = n_seq_ids;
for (size_t j = 0; j < n_seq_ids; j++) {
batch->seq_id[batch->n_tokens][j] = seq_ids[j];
}
batch->logits [batch->n_tokens] = logits;
batch->logits [output_id] = output;
batch->n_tokens++;
return 0;
return output_id;
}
int32_t llama_batch_ext_set_logits(
int32_t llama_batch_ext_set_output(
struct llama_batch_ext * batch,
llama_pos pos,
llama_seq_id seq_id) {
@@ -439,7 +440,7 @@ int32_t llama_batch_ext_set_logits(
// found the sequence
if (pos == -1 || pos == batch->pos[i]) {
batch->logits[i] = true;
return 0;
return i;
}
}
}
@@ -447,12 +448,13 @@ int32_t llama_batch_ext_set_logits(
return -1; // not found
}
int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) {
int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) {
if (batch->n_tokens == 0) {
return -1;
}
batch->logits[batch->n_tokens - 1] = true;
return 0;
const int32_t output_id = batch->n_tokens - 1;
batch->logits[output_id] = true;
return output_id;
}
void llama_batch_ext_clear(struct llama_batch_ext * batch) {