From 4aabf4e8f4b88e96c6c98a504b2c8cbe0d815e46 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 13 Mar 2025 17:47:07 +0100
Subject: [PATCH] return output ID from llama_batch_ext_add/set

---
 common/common.h                          |  2 +-
 examples/batched-bench/batched-bench.cpp |  2 +-
 examples/batched/batched.cpp             |  2 +-
 include/llama.h                          | 26 ++++++++++++++----------
 src/llama-batch.cpp                      | 24 ++++++++++++----------
 5 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/common/common.h b/common/common.h
index c7dbcc202..afede57bb 100644
--- a/common/common.h
+++ b/common/common.h
@@ -606,7 +606,7 @@ struct common_batch {
     }
     void set_logits_last() {
         if (!tokens.empty()) {
-            llama_batch_ext_set_logits_last(batch.get());
+            llama_batch_ext_set_output_last(batch.get());
             tokens.back().logits = true;
         }
     }
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 1eb0ede77..8f7c2c94b 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -122,7 +122,7 @@ int main(int argc, char ** argv) {
                 llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
             }
         }
-        llama_batch_ext_set_logits_last(batch);
+        llama_batch_ext_set_output_last(batch);
 
         const auto t_pp_start = ggml_time_us();
 
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 858053a88..1ed189859 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    llama_batch_ext_set_logits_last(batch);
+    llama_batch_ext_set_output_last(batch);
 
     if (llama_decode_ext(ctx, batch) != 0) {
         LOG_ERR("%s: llama_decode() failed\n", __func__);
diff --git a/include/llama.h b/include/llama.h
index 564ffe1aa..ee74d9a8c 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -900,7 +900,7 @@ extern "C" {
    //
    DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
                   llama_token * tokens,
-                      int32_t   n_tokens), "use llama_batch_ext API instead");
+                      int32_t   n_tokens), "use llama_batch_ext_init_from_text instead");
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
@@ -912,7 +912,7 @@ extern "C" {
     DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
             int32_t n_tokens,
             int32_t embd,
-            int32_t n_seq_max), "use llama_batch_ext API instead");
+            int32_t n_seq_max), "use llama_batch_ext_init instead");
 
     // Frees a batch of tokens allocated with llama_batch_init()
     DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),
@@ -950,28 +950,32 @@ extern "C" {
 
     // Add text tokens to the batch
     // Return values:
-    //  0 : success
     // -1 : not enough space in the batch
     // -2 : embd is already set, cannot add text tokens
+    // otherwise, returns the output ID
     LLAMA_API int32_t llama_batch_ext_add_text(
         struct llama_batch_ext * batch,
         llama_token token,
         llama_pos pos,
         const llama_seq_id * seq_ids,
         size_t n_seq_ids,
-        float logits);
+        bool output);
 
-    // Set logits for the token in the ith sequence
-    // If pos == -1, logits will be set for the all tokens
-    // Returns -1 if the token is not in the batch
-    LLAMA_API int32_t llama_batch_ext_set_logits(
+    // Set output (logits/embeddings) for the token in the ith sequence
+    // If pos == -1, output will be set for all tokens
+    // Return values:
+    // -1 : the token is not in the batch
+    // otherwise, returns the output ID
+    LLAMA_API int32_t llama_batch_ext_set_output(
         struct llama_batch_ext * batch,
         llama_pos pos,
         llama_seq_id seq_id);
 
-    // Set logits for the last added token
-    // Returns -1 if there is no tokens in the batch
-    LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch);
+    // Set output (logits/embeddings) for the last added token
+    // Return values:
+    // -1 : the batch is empty
+    // otherwise, returns the output ID
+    LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch);
 
     // Get a "view" from a number of tokens offset
     // Return returned batch must be freed with llama_batch_free()
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index d8117c3f0..bae8b37b3 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -410,25 +410,26 @@ int32_t llama_batch_ext_add_text(
         llama_pos pos,
         const llama_seq_id * seq_ids,
         size_t n_seq_ids,
-        float logits) {
+        bool output) {
     if (batch->n_tokens + 1 > batch->max_tokens) {
         return -1; // llama_batch size exceeded
     }
     if (batch->embd) {
         return -2; // embd is already set, cannot add text tokens
     }
-    batch->token   [batch->n_tokens] = token;
-    batch->pos     [batch->n_tokens] = pos;
-    batch->n_seq_id[batch->n_tokens] = n_seq_ids;
+    const int32_t output_id = batch->n_tokens;
+    batch->token   [output_id] = token;
+    batch->pos     [output_id] = pos;
+    batch->n_seq_id[output_id] = n_seq_ids;
     for (size_t j = 0; j < n_seq_ids; j++) {
         batch->seq_id[batch->n_tokens][j] = seq_ids[j];
     }
-    batch->logits  [batch->n_tokens] = logits;
+    batch->logits  [output_id] = output;
     batch->n_tokens++;
-    return 0;
+    return output_id;
 }
 
-int32_t llama_batch_ext_set_logits(
+int32_t llama_batch_ext_set_output(
         struct llama_batch_ext * batch,
         llama_pos pos,
         llama_seq_id seq_id) {
@@ -439,7 +440,7 @@ int32_t llama_batch_ext_set_logits(
             // found the sequence
             if (pos == -1 || pos == batch->pos[i]) {
                 batch->logits[i] = true;
-                return 0;
+                return i;
             }
         }
     }
@@ -447,12 +448,13 @@ int32_t llama_batch_ext_set_logits(
     return -1; // not found
 }
 
-int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) {
+int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) {
     if (batch->n_tokens == 0) {
         return -1;
     }
-    batch->logits[batch->n_tokens - 1] = true;
-    return 0;
+    const int32_t output_id = batch->n_tokens - 1;
+    batch->logits[output_id] = true;
+    return output_id;
 }
 
 void llama_batch_ext_clear(struct llama_batch_ext * batch) {