mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 20:45:04 +00:00
return output ID from llama_batch_ext_add/set
This commit is contained in:
@ -606,7 +606,7 @@ struct common_batch {
|
|||||||
}
|
}
|
||||||
void set_logits_last() {
|
void set_logits_last() {
|
||||||
if (!tokens.empty()) {
|
if (!tokens.empty()) {
|
||||||
llama_batch_ext_set_logits_last(batch.get());
|
llama_batch_ext_set_output_last(batch.get());
|
||||||
tokens.back().logits = true;
|
tokens.back().logits = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -122,7 +122,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
|
llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
llama_batch_ext_set_logits_last(batch);
|
llama_batch_ext_set_output_last(batch);
|
||||||
|
|
||||||
const auto t_pp_start = ggml_time_us();
|
const auto t_pp_start = ggml_time_us();
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// llama_decode will output logits only for the last token of the prompt
|
// llama_decode will output logits only for the last token of the prompt
|
||||||
llama_batch_ext_set_logits_last(batch);
|
llama_batch_ext_set_output_last(batch);
|
||||||
|
|
||||||
if (llama_decode_ext(ctx, batch) != 0) {
|
if (llama_decode_ext(ctx, batch) != 0) {
|
||||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||||
|
@ -900,7 +900,7 @@ extern "C" {
|
|||||||
//
|
//
|
||||||
DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
|
DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens), "use llama_batch_ext API instead");
|
int32_t n_tokens), "use llama_batch_ext_init_from_text instead");
|
||||||
|
|
||||||
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||||
// Each token can be assigned up to n_seq_max sequence ids
|
// Each token can be assigned up to n_seq_max sequence ids
|
||||||
@ -912,7 +912,7 @@ extern "C" {
|
|||||||
DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
|
DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int32_t embd,
|
int32_t embd,
|
||||||
int32_t n_seq_max), "use llama_batch_ext API instead");
|
int32_t n_seq_max), "use llama_batch_ext_init instead");
|
||||||
|
|
||||||
// Frees a batch of tokens allocated with llama_batch_init()
|
// Frees a batch of tokens allocated with llama_batch_init()
|
||||||
DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),
|
DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),
|
||||||
@ -950,28 +950,32 @@ extern "C" {
|
|||||||
|
|
||||||
// Add text tokens to the batch
|
// Add text tokens to the batch
|
||||||
// Return values:
|
// Return values:
|
||||||
// 0 : success
|
|
||||||
// -1 : not enough space in the batch
|
// -1 : not enough space in the batch
|
||||||
// -2 : embd is already set, cannot add text tokens
|
// -2 : embd is already set, cannot add text tokens
|
||||||
|
// otherwise, returns the output ID
|
||||||
LLAMA_API int32_t llama_batch_ext_add_text(
|
LLAMA_API int32_t llama_batch_ext_add_text(
|
||||||
struct llama_batch_ext * batch,
|
struct llama_batch_ext * batch,
|
||||||
llama_token token,
|
llama_token token,
|
||||||
llama_pos pos,
|
llama_pos pos,
|
||||||
const llama_seq_id * seq_ids,
|
const llama_seq_id * seq_ids,
|
||||||
size_t n_seq_ids,
|
size_t n_seq_ids,
|
||||||
float logits);
|
bool output);
|
||||||
|
|
||||||
// Set logits for the token in the ith sequence
|
// Set output (logits/embeddings) for the token in the ith sequence
|
||||||
// If pos == -1, logits will be set for the all tokens
|
// If pos == -1, output will be set for the all tokens
|
||||||
// Returns -1 if the token is not in the batch
|
// Return values:
|
||||||
LLAMA_API int32_t llama_batch_ext_set_logits(
|
// -1 : the token is not in the batch
|
||||||
|
// otherwise, returns the output ID
|
||||||
|
LLAMA_API int32_t llama_batch_ext_set_output(
|
||||||
struct llama_batch_ext * batch,
|
struct llama_batch_ext * batch,
|
||||||
llama_pos pos,
|
llama_pos pos,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
// Set logits for the last added token
|
// Set output (logits/embeddings) for the last added token
|
||||||
// Returns -1 if there is no tokens in the batch
|
// Return values:
|
||||||
LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch);
|
// -1 : the batch is empty
|
||||||
|
// otherwise, returns the output ID
|
||||||
|
LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch);
|
||||||
|
|
||||||
// Get a "view" from a number of tokens offset
|
// Get a "view" from a number of tokens offset
|
||||||
// Return returned batch must be freed with llama_batch_free()
|
// Return returned batch must be freed with llama_batch_free()
|
||||||
|
@ -410,25 +410,26 @@ int32_t llama_batch_ext_add_text(
|
|||||||
llama_pos pos,
|
llama_pos pos,
|
||||||
const llama_seq_id * seq_ids,
|
const llama_seq_id * seq_ids,
|
||||||
size_t n_seq_ids,
|
size_t n_seq_ids,
|
||||||
float logits) {
|
bool output) {
|
||||||
if (batch->n_tokens + 1 > batch->max_tokens) {
|
if (batch->n_tokens + 1 > batch->max_tokens) {
|
||||||
return -1; // llama_batch size exceeded
|
return -1; // llama_batch size exceeded
|
||||||
}
|
}
|
||||||
if (batch->embd) {
|
if (batch->embd) {
|
||||||
return -2; // embd is already set, cannot add text tokens
|
return -2; // embd is already set, cannot add text tokens
|
||||||
}
|
}
|
||||||
batch->token [batch->n_tokens] = token;
|
const int32_t output_id = batch->n_tokens;
|
||||||
batch->pos [batch->n_tokens] = pos;
|
batch->token [output_id] = token;
|
||||||
batch->n_seq_id[batch->n_tokens] = n_seq_ids;
|
batch->pos [output_id] = pos;
|
||||||
|
batch->n_seq_id[output_id] = n_seq_ids;
|
||||||
for (size_t j = 0; j < n_seq_ids; j++) {
|
for (size_t j = 0; j < n_seq_ids; j++) {
|
||||||
batch->seq_id[batch->n_tokens][j] = seq_ids[j];
|
batch->seq_id[batch->n_tokens][j] = seq_ids[j];
|
||||||
}
|
}
|
||||||
batch->logits [batch->n_tokens] = logits;
|
batch->logits [output_id] = output;
|
||||||
batch->n_tokens++;
|
batch->n_tokens++;
|
||||||
return 0;
|
return output_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_batch_ext_set_logits(
|
int32_t llama_batch_ext_set_output(
|
||||||
struct llama_batch_ext * batch,
|
struct llama_batch_ext * batch,
|
||||||
llama_pos pos,
|
llama_pos pos,
|
||||||
llama_seq_id seq_id) {
|
llama_seq_id seq_id) {
|
||||||
@ -439,7 +440,7 @@ int32_t llama_batch_ext_set_logits(
|
|||||||
// found the sequence
|
// found the sequence
|
||||||
if (pos == -1 || pos == batch->pos[i]) {
|
if (pos == -1 || pos == batch->pos[i]) {
|
||||||
batch->logits[i] = true;
|
batch->logits[i] = true;
|
||||||
return 0;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -447,12 +448,13 @@ int32_t llama_batch_ext_set_logits(
|
|||||||
return -1; // not found
|
return -1; // not found
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) {
|
int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) {
|
||||||
if (batch->n_tokens == 0) {
|
if (batch->n_tokens == 0) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
batch->logits[batch->n_tokens - 1] = true;
|
const int32_t output_id = batch->n_tokens - 1;
|
||||||
return 0;
|
batch->logits[output_id] = true;
|
||||||
|
return output_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_batch_ext_clear(struct llama_batch_ext * batch) {
|
void llama_batch_ext_clear(struct llama_batch_ext * batch) {
|
||||||
|
Reference in New Issue
Block a user