diff --git a/include/llama.h b/include/llama.h index 1f5f3a09b..2c7569f8a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -231,29 +231,7 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); - // Input data for llama_decode - // A llama_batch object can contain input about one or many sequences - // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens - // - // - token : the token ids of the input (used when embd is NULL) - // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) - // - pos : the positions of the respective token in the sequence - // (if set to NULL, the token position will be tracked automatically by llama_decode) - // - seq_id : the sequence to which the respective token belongs - // (if set to NULL, the sequence ID will be assumed to be 0) - // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output - // (if set to NULL, only the logits for last token will be returned) - // - typedef struct llama_batch { - int32_t n_tokens; - - llama_token * token; - float * embd; - llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" - } llama_batch; + struct llama_batch; enum llama_model_kv_override_type { LLAMA_KV_OVERRIDE_TYPE_INT, @@ -829,7 +807,7 @@ extern "C" { // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // - LLAMA_API struct llama_batch llama_batch_get_one( + LLAMA_API struct llama_batch * llama_batch_get_one( llama_token * tokens, int32_t n_tokens); @@ -840,13 +818,59 @@ extern "C" { // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token // The rest of the llama_batch members are allocated with size n_tokens // All members are left uninitialized - LLAMA_API struct llama_batch llama_batch_init( + // LLAMA_API struct llama_batch llama_batch_init( + // int32_t n_tokens, + // int32_t embd, + // int32_t n_seq_max); + + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens + // Each token can be assigned up to n_seq_max sequence ids + // The batch has to be freed with llama_batch_free() + LLAMA_API struct llama_batch * llama_batch_init( int32_t n_tokens, - int32_t embd, int32_t n_seq_max); + // Same with llama_batch_init, but initializes the batch with the provided raw embeddings + LLAMA_API struct llama_batch * llama_batch_init_from_embd( + float * embd, + size_t n_embd, + int32_t pos0, + int32_t seq_id); + + // Add text tokens to the batch + // First token in the list starts at position pos0 + // Return values: + // 0 : success + // -1 : not enough space in the batch + // -2 : embd is already set, cannot add text tokens + LLAMA_API int32_t llama_batch_add_text( + struct llama_batch * batch, + llama_token * tokens, + size_t n_tokens, + int32_t pos0, + int32_t seq_id); + + // Same as llama_batch_add_text, but accepts multiple sequences + LLAMA_API int32_t llama_batch_add_text( + struct llama_batch * batch, + llama_token * tokens, + size_t n_tokens, + int32_t pos0, + int32_t * seq_ids, + size_t n_seq_ids); + + // Set logits for the token in the ith sequence + // If pos == -1, logits will be set for the all tokens + LLAMA_API int32_t llama_batch_set_logits( + struct llama_batch * batch, + int32_t pos, + int32_t seq_id); + + // Remove everything from the batch + LLAMA_API void llama_batch_clear(struct llama_batch * batch); + // Frees a batch of tokens allocated with llama_batch_init() - LLAMA_API void llama_batch_free(struct llama_batch batch); + LLAMA_API void llama_batch_free(struct llama_batch * batch); // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57f..027ac2413 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -309,10 +309,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 // interface implementation // -struct llama_batch llama_batch_get_one( +struct llama_batch * llama_batch_get_one( llama_token * tokens, int32_t n_tokens) { - return { + return new llama_batch{ /*n_tokens =*/ n_tokens, /*tokens =*/ tokens, /*embd =*/ nullptr, @@ -323,8 +323,8 @@ struct llama_batch llama_batch_get_one( }; } -struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { - llama_batch batch = { +static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch * batch = new llama_batch{ /*n_tokens =*/ 0, /*tokens =*/ nullptr, /*embd =*/ nullptr, @@ -335,34 +335,108 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ }; if (embd) { - batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); } else { - batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); } - batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); - batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); - batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); + batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); for (int i = 0; i < n_tokens_alloc; ++i) { - batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); } - batch.seq_id[n_tokens_alloc] = nullptr; + batch->seq_id[n_tokens_alloc] = nullptr; - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); return batch; } -void llama_batch_free(struct llama_batch batch) { - if (batch.token) free(batch.token); - if (batch.embd) free(batch.embd); - if (batch.pos) free(batch.pos); - if (batch.n_seq_id) free(batch.n_seq_id); - if (batch.seq_id) { - for (int i = 0; batch.seq_id[i] != nullptr; ++i) { - free(batch.seq_id[i]); - } - free(batch.seq_id); - } - if (batch.logits) free(batch.logits); +struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) { + return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max); +} + +struct llama_batch * llama_batch_init_from_embd( + float * embd, + size_t n_embd, + int32_t pos0, + int32_t seq_id) { + struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1); + memcpy(batch->embd, embd, n_embd * sizeof(float)); + for (int32_t i = 0; i < n_embd; i++) { + batch->pos [i] = pos0 + i; + batch->n_seq_id[i] = 1; + batch->seq_id [i][0] = seq_id; + } +} + +int32_t llama_batch_add_text( + struct llama_batch * batch, + llama_token * tokens, + size_t n_tokens, + int32_t pos0, + int32_t * seq_ids, + size_t n_seq_ids) { + if (batch->n_tokens + n_tokens > batch->n_tokens) { + return -1; + } + if (batch->embd) { + return -2; + } + for (int32_t i = 0; i < n_tokens; i++) { + batch->token [batch->n_tokens + i] = tokens[i]; + batch->pos [batch->n_tokens + i] = pos0 + i; + batch->n_seq_id[batch->n_tokens + i] = n_seq_ids; + for (int32_t j = 0; j < n_seq_ids; j++) { + batch->seq_id[batch->n_tokens + i][j] = seq_ids[j]; + } + } +} + +int32_t llama_batch_add_text( + struct llama_batch * batch, + llama_token * tokens, + size_t n_tokens, + int32_t pos0, + int32_t seq_id) { + std::array seq_ids = { seq_id }; + return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size()); +} + +int32_t llama_batch_set_logits( + struct llama_batch * batch, + int32_t pos, + int32_t seq_id) { + for (int32_t i = 0; i < batch->n_tokens; i++) { + // find the token having seq_id + for (int32_t j = 0; j < batch->n_seq_id[i]; j++) { + if (batch->seq_id[i][j] == seq_id) { + // found the sequence + if (pos == -1 || pos == batch->pos[i]) { + batch->logits[i] = true; + break; + } + } + } + } +} + +void llama_batch_clear(struct llama_batch * batch) { + batch->n_tokens = 0; +} + +void llama_batch_free(struct llama_batch * batch) { + if (batch->token) free(batch->token); + if (batch->embd) free(batch->embd); + if (batch->pos) free(batch->pos); + if (batch->n_seq_id) free(batch->n_seq_id); + if (batch->seq_id) { + for (int i = 0; batch->seq_id[i] != nullptr; ++i) { + free(batch->seq_id[i]); + } + free(batch->seq_id); + } + if (batch->logits) free(batch->logits); + delete batch; } diff --git a/src/llama-batch.h b/src/llama-batch.h index 773c3808b..de702da76 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -5,6 +5,30 @@ #include #include +// Input data for llama_decode +// A llama_batch object can contain input about one or many sequences +// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens +// +// - token : the token ids of the input (used when embd is NULL) +// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) +// - pos : the positions of the respective token in the sequence +// (if set to NULL, the token position will be tracked automatically by llama_decode) +// - seq_id : the sequence to which the respective token belongs +// (if set to NULL, the sequence ID will be assumed to be 0) +// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output +// (if set to NULL, only the logits for last token will be returned) +// +struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; // TODO: rename this to "output" +}; + // very similar to llama_batch, // but has more metadata about sequences struct llama_ubatch {