first proposal for private llama_batch

Xuan Son Nguyen
2025-02-14 00:48:12 +01:00
parent 04045bb842
commit 4ed4fe75ed
3 changed files with 173 additions and 51 deletions


@@ -231,29 +231,7 @@ extern "C" {
     typedef bool (*llama_progress_callback)(float progress, void * user_data);

-    // Input data for llama_decode
-    // A llama_batch object can contain input about one or many sequences
-    // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
-    //
-    // - token  : the token ids of the input (used when embd is NULL)
-    // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
-    // - pos    : the positions of the respective token in the sequence
-    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
-    // - seq_id : the sequence to which the respective token belongs
-    //            (if set to NULL, the sequence ID will be assumed to be 0)
-    // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
-    //            (if set to NULL, only the logits for last token will be returned)
-    //
-    typedef struct llama_batch {
-        int32_t n_tokens;
-
-        llama_token  *  token;
-        float        *  embd;
-        llama_pos    *  pos;
-        int32_t      *  n_seq_id;
-        llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
-    } llama_batch;
+    struct llama_batch;

     enum llama_model_kv_override_type {
         LLAMA_KV_OVERRIDE_TYPE_INT,
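
The effect of this hunk is that llama_batch becomes an opaque type: its fields are no longer visible to callers, so batches can only be handled through a pointer and mutated via the new API functions declared later in this file. A minimal caller-side sketch of the difference (hypothetical usage written for illustration, not part of the commit):

    llama_token some_token = 42;

    // before: value type with public fields
    // struct llama_batch batch = llama_batch_init(512, 0, 1);
    // batch.n_tokens = 1;
    // batch.token[0] = some_token; // direct field access

    // after: incomplete type, pointer only, state set through the API
    struct llama_batch * batch = llama_batch_init(512, 0, 1);
    llama_batch_add_text(batch, &some_token, 1, /*pos0=*/0, /*seq_id=*/0);
    llama_batch_free(batch);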
@@ -829,7 +807,7 @@ extern "C" {
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
-    LLAMA_API struct llama_batch llama_batch_get_one(
+    LLAMA_API struct llama_batch * llama_batch_get_one(
                   llama_token * tokens,
                       int32_t   n_tokens);
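
With this hunk, the transitional helper llama_batch_get_one returns a heap-allocated pointer instead of a by-value struct. A hedged sketch of the transitional usage (ownership of the returned object is not spelled out in this hunk; the sketch assumes it is released with llama_batch_free like batches from llama_batch_init):

    llama_token tokens[] = { 100, 200, 300 };
    struct llama_batch * batch = llama_batch_get_one(tokens, 3);
    // ... pass batch to llama_decode (its updated signature is not shown in this diff) ...
    llama_batch_free(batch);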
@@ -840,13 +818,59 @@ extern "C" {
     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
     // The rest of the llama_batch members are allocated with size n_tokens
     // All members are left uninitialized
-    LLAMA_API struct llama_batch llama_batch_init(
+    // LLAMA_API struct llama_batch llama_batch_init(
+    //         int32_t n_tokens,
+    //         int32_t embd,
+    //         int32_t n_seq_max);
+
+    // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
+    // Each token can be assigned up to n_seq_max sequence ids
+    // The batch has to be freed with llama_batch_free()
+    LLAMA_API struct llama_batch * llama_batch_init(
             int32_t n_tokens,
             int32_t embd,
             int32_t n_seq_max);
+
+    // Same as llama_batch_init, but initializes the batch with the provided raw embeddings
+    LLAMA_API struct llama_batch * llama_batch_init_from_embd(
+            float * embd,
+            size_t  n_embd,
+            int32_t pos0,
+            int32_t seq_id);
+
+    // Add text tokens to the batch
+    // The first token in the list starts at position pos0
+    // Return values:
+    //    0 : success
+    //   -1 : not enough space in the batch
+    //   -2 : embd is already set, cannot add text tokens
+    LLAMA_API int32_t llama_batch_add_text(
+            struct llama_batch * batch,
+            llama_token * tokens,
+            size_t  n_tokens,
+            int32_t pos0,
+            int32_t seq_id);
+
+    // Same as llama_batch_add_text, but accepts multiple sequences
+    LLAMA_API int32_t llama_batch_add_text(
+            struct llama_batch * batch,
+            llama_token * tokens,
+            size_t  n_tokens,
+            int32_t pos0,
+            int32_t * seq_ids,
+            size_t  n_seq_ids);
+
+    // Set logits for the token at position pos in the sequence seq_id
+    // If pos == -1, logits will be set for all tokens in the sequence
+    LLAMA_API int32_t llama_batch_set_logits(
+            struct llama_batch * batch,
+            int32_t pos,
+            int32_t seq_id);
+
+    // Remove everything from the batch
+    LLAMA_API void llama_batch_clear(struct llama_batch * batch);

     // Frees a batch of tokens allocated with llama_batch_init()
-    LLAMA_API void llama_batch_free(struct llama_batch batch);
+    LLAMA_API void llama_batch_free(struct llama_batch * batch);

     // Processes a batch of tokens with the encoder part of the encoder-decoder model.
     // Stores the encoder output internally for later use by the decoder cross-attention layers.
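
Taken together, the new calls replace direct struct manipulation with an init / add-text / set-logits / clear / free lifecycle. Below is a minimal sketch of how a caller might drive a two-sequence decode under this proposal. Two assumptions are flagged in the comments: llama_decode is presumed to be updated elsewhere in the commit to accept the batch pointer, and the single-sequence llama_batch_add_text form is used (the two add_text declarations share a name, which relies on C++-style overloading):

    #include "llama.h"

    // decode two prompts as independent sequences in one batch (sketch)
    static int decode_two_prompts(struct llama_context * ctx,
                                  llama_token * a, size_t n_a,
                                  llama_token * b, size_t n_b) {
        // room for 512 tokens, text-only (embd == 0), up to 2 sequence ids per token
        struct llama_batch * batch = llama_batch_init(512, 0, 2);

        // each prompt starts at position 0 of its own sequence
        if (llama_batch_add_text(batch, a, n_a, /*pos0=*/0, /*seq_id=*/0) != 0 ||
            llama_batch_add_text(batch, b, n_b, /*pos0=*/0, /*seq_id=*/1) != 0) {
            llama_batch_free(batch);
            return -1; // -1: batch full, -2: batch already holds embeddings
        }

        // request logits only for the last token of each sequence
        llama_batch_set_logits(batch, (int32_t) n_a - 1, 0);
        llama_batch_set_logits(batch, (int32_t) n_b - 1, 1);

        int ret = llama_decode(ctx, batch); // assumption: decode takes the pointer

        llama_batch_clear(batch); // the batch could be reused here instead
        llama_batch_free(batch);
        return ret;
    }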