Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-29 12:35:16 +00:00)

Commit: first proposal for private llama_batch
--- a/include/llama.h
+++ b/include/llama.h
@@ -231,29 +231,7 @@ extern "C" {

     typedef bool (*llama_progress_callback)(float progress, void * user_data);

-    // Input data for llama_decode
-    // A llama_batch object can contain input about one or many sequences
-    // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
-    //
-    // - token  : the token ids of the input (used when embd is NULL)
-    // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
-    // - pos    : the positions of the respective token in the sequence
-    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
-    // - seq_id : the sequence to which the respective token belongs
-    //            (if set to NULL, the sequence ID will be assumed to be 0)
-    // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
-    //            (if set to NULL, only the logits for last token will be returned)
-    //
-    typedef struct llama_batch {
-        int32_t n_tokens;
-
-        llama_token  *  token;
-        float        *  embd;
-        llama_pos    *  pos;
-        int32_t      *  n_seq_id;
-        llama_seq_id ** seq_id;
-        int8_t       *  logits; // TODO: rename this to "output"
-    } llama_batch;
+    struct llama_batch;

     enum llama_model_kv_override_type {
         LLAMA_KV_OVERRIDE_TYPE_INT,
@@ -829,7 +807,7 @@ extern "C" {
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
-    LLAMA_API struct llama_batch llama_batch_get_one(
+    LLAMA_API struct llama_batch * llama_batch_get_one(
                          llama_token * tokens,
                              int32_t   n_tokens);
@@ -840,13 +818,59 @@ extern "C" {
     // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
     // The rest of the llama_batch members are allocated with size n_tokens
     // All members are left uninitialized
-    LLAMA_API struct llama_batch llama_batch_init(
+    // LLAMA_API struct llama_batch llama_batch_init(
+    //         int32_t n_tokens,
+    //         int32_t embd,
+    //         int32_t n_seq_max);
+
+    // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
+    // Each token can be assigned up to n_seq_max sequence ids
+    // The batch has to be freed with llama_batch_free()
+    LLAMA_API struct llama_batch * llama_batch_init(
             int32_t n_tokens,
             int32_t embd,
             int32_t n_seq_max);
+
+    // Same as llama_batch_init, but initializes the batch with the provided raw embeddings
+    LLAMA_API struct llama_batch * llama_batch_init_from_embd(
+             float * embd,
+            size_t   n_embd,
+           int32_t   pos0,
+           int32_t   seq_id);
+
+    // Add text tokens to the batch
+    // The first token in the list starts at position pos0
+    // Return values:
+    //    0 : success
+    //   -1 : not enough space in the batch
+    //   -2 : embd is already set, cannot add text tokens
+    LLAMA_API int32_t llama_batch_add_text(
+            struct llama_batch * batch,
+                   llama_token * tokens,
+                        size_t   n_tokens,
+                       int32_t   pos0,
+                       int32_t   seq_id);
+
+    // Same as llama_batch_add_text, but accepts multiple sequences
+    LLAMA_API int32_t llama_batch_add_text(
+            struct llama_batch * batch,
+                   llama_token * tokens,
+                        size_t   n_tokens,
+                       int32_t   pos0,
+                       int32_t * seq_ids,
+                        size_t   n_seq_ids);
+
+    // Set output logits for the token at position pos in the sequence seq_id
+    // If pos == -1, logits will be set for all tokens
+    LLAMA_API int32_t llama_batch_set_logits(
+            struct llama_batch * batch,
+                       int32_t   pos,
+                       int32_t   seq_id);
+
+    // Remove everything from the batch
+    LLAMA_API void llama_batch_clear(struct llama_batch * batch);

     // Frees a batch of tokens allocated with llama_batch_init()
-    LLAMA_API void llama_batch_free(struct llama_batch batch);
+    LLAMA_API void llama_batch_free(struct llama_batch * batch);

     // Processes a batch of tokens with the encoder part of the encoder-decoder model.
     // Stores the encoder output internally for later use by the decoder cross-attention layers.
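To make the intended call pattern concrete, here is a minimal caller sketch against the declarations above. This is hypothetical illustration code, not part of the commit; it assumes the three-argument llama_batch_init declared above and that llama_decode still takes the batch by value at this stage of the proposal:

    #include "llama.h"
    #include <vector>

    // feed one tokenized prompt into sequence 0 and request logits for its last token
    static int32_t decode_prompt(llama_context * ctx, std::vector<llama_token> & prompt) {
        llama_batch * batch = llama_batch_init(/*n_tokens=*/ (int32_t) prompt.size(), /*embd=*/ 0, /*n_seq_max=*/ 1);

        // 0 on success, -1 if the batch is full, -2 if the batch was created for embeddings
        if (llama_batch_add_text(batch, prompt.data(), prompt.size(), /*pos0=*/ 0, /*seq_id=*/ 0) != 0) {
            llama_batch_free(batch);
            return -1;
        }
        llama_batch_set_logits(batch, /*pos=*/ (int32_t) prompt.size() - 1, /*seq_id=*/ 0);

        const int32_t ret = llama_decode(ctx, *batch); // by-value decode, unchanged by this diff
        llama_batch_free(batch);
        return ret;
    }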
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -309,10 +309,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
 // interface implementation
 //

-struct llama_batch llama_batch_get_one(
+struct llama_batch * llama_batch_get_one(
              llama_token * tokens,
                  int32_t   n_tokens) {
-    return {
+    return new llama_batch{
         /*n_tokens =*/ n_tokens,
         /*tokens   =*/ tokens,
         /*embd     =*/ nullptr,
@@ -323,8 +323,8 @@ struct llama_batch llama_batch_get_one(
     };
 }

-struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-    llama_batch batch = {
+static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
+    llama_batch * batch = new llama_batch{
         /*n_tokens =*/ 0,
         /*tokens   =*/ nullptr,
         /*embd     =*/ nullptr,
@@ -335,34 +335,108 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
     };

     if (embd) {
-        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
+        batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
     } else {
-        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
+        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }

-    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
-    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
-    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
+    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
+    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
+    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
     for (int i = 0; i < n_tokens_alloc; ++i) {
-        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
+        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch.seq_id[n_tokens_alloc] = nullptr;
+    batch->seq_id[n_tokens_alloc] = nullptr;

-    batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

     return batch;
 }

-void llama_batch_free(struct llama_batch batch) {
-    if (batch.token)    free(batch.token);
-    if (batch.embd)     free(batch.embd);
-    if (batch.pos)      free(batch.pos);
-    if (batch.n_seq_id) free(batch.n_seq_id);
-    if (batch.seq_id) {
-        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
-            free(batch.seq_id[i]);
-        }
-        free(batch.seq_id);
-    }
-    if (batch.logits)   free(batch.logits);
-}
+struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
+    return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max);
+}
+
+struct llama_batch * llama_batch_init_from_embd(
+             float * embd,
+            size_t   n_embd,
+           int32_t   pos0,
+           int32_t   seq_id) {
+    // allocate room for a single embeddings token of size n_embd
+    struct llama_batch * batch = llama_batch_init_impl(1, (int32_t) n_embd, 1);
+    memcpy(batch->embd, embd, n_embd * sizeof(float));
+    batch->n_tokens       = 1;
+    batch->pos     [0]    = pos0;
+    batch->n_seq_id[0]    = 1;
+    batch->seq_id  [0][0] = seq_id;
+    return batch;
+}
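As a usage sketch (hypothetical caller code, not from the commit): with the single-token reading above, a caller holding one precomputed embedding vector, e.g. from a multimodal encoder, would submit it like this, where n_embd is assumed to match the model's embedding size:

    const size_t n_embd = 4096;            // assumed model embedding size
    std::vector<float> embd(n_embd);       // filled by some external encoder
    llama_batch * batch = llama_batch_init_from_embd(embd.data(), embd.size(), /*pos0=*/ 0, /*seq_id=*/ 0);
    // ... llama_decode(ctx, *batch); ...
    llama_batch_free(batch);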

+int32_t llama_batch_add_text(
+        struct llama_batch * batch,
+               llama_token * tokens,
+                    size_t   n_tokens,
+                   int32_t   pos0,
+                   int32_t * seq_ids,
+                    size_t   n_seq_ids) {
+    // TODO: compare against the allocated capacity; n_tokens only tracks the used slots
+    if (batch->n_tokens + n_tokens > batch->n_tokens) {
+        return -1;
+    }
+    if (batch->embd) {
+        return -2;
+    }
+    for (size_t i = 0; i < n_tokens; i++) {
+        batch->token   [batch->n_tokens + i] = tokens[i];
+        batch->pos     [batch->n_tokens + i] = pos0 + (llama_pos) i;
+        batch->n_seq_id[batch->n_tokens + i] = (int32_t) n_seq_ids;
+        for (size_t j = 0; j < n_seq_ids; j++) {
+            batch->seq_id[batch->n_tokens + i][j] = seq_ids[j];
+        }
+    }
+    batch->n_tokens += (int32_t) n_tokens;
+    return 0;
+}
+
+int32_t llama_batch_add_text(
+        struct llama_batch * batch,
+               llama_token * tokens,
+                    size_t   n_tokens,
+                   int32_t   pos0,
+                   int32_t   seq_id) {
+    std::array<int32_t, 1> seq_ids = { seq_id };
+    return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size());
+}
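A typical use of the multi-sequence overload is decoding a shared prefix once for several parallel sequences. A small hedged sketch (sys_tokens and n_sys_tokens are hypothetical caller variables):

    // place the same system-prompt tokens into sequences 0..3 in one call,
    // so the shared prefix is evaluated a single time
    int32_t seq_ids[4] = { 0, 1, 2, 3 };
    llama_batch_add_text(batch, sys_tokens, n_sys_tokens, /*pos0=*/ 0, seq_ids, /*n_seq_ids=*/ 4);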

+int32_t llama_batch_set_logits(
+        struct llama_batch * batch,
+                   int32_t   pos,
+                   int32_t   seq_id) {
+    for (int32_t i = 0; i < batch->n_tokens; i++) {
+        // find the tokens that belong to seq_id
+        for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
+            if (batch->seq_id[i][j] == seq_id) {
+                // found the sequence
+                if (pos == -1 || pos == batch->pos[i]) {
+                    batch->logits[i] = true;
+                }
+                break;
+            }
+        }
+    }
+    return 0;
+}
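For illustration, a hedged caller sketch of the two supported modes (the batch is assumed to already contain tokens for sequences 1 and 2):

    // request outputs for every token of sequence 1 ...
    llama_batch_set_logits(batch, /*pos=*/ -1, /*seq_id=*/ 1);
    // ... and only for the token at position 42 of sequence 2
    llama_batch_set_logits(batch, /*pos=*/ 42, /*seq_id=*/ 2);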

+void llama_batch_clear(struct llama_batch * batch) {
+    batch->n_tokens = 0;
+}
+
+void llama_batch_free(struct llama_batch * batch) {
+    if (batch->token)    free(batch->token);
+    if (batch->embd)     free(batch->embd);
+    if (batch->pos)      free(batch->pos);
+    if (batch->n_seq_id) free(batch->n_seq_id);
+    if (batch->seq_id) {
+        for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
+            free(batch->seq_id[i]);
+        }
+        free(batch->seq_id);
+    }
+    if (batch->logits)   free(batch->logits);
+    delete batch;
+}
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -5,6 +5,30 @@
+#include <array>
 #include <vector>

+// Input data for llama_decode
+// A llama_batch object can contain input about one or many sequences
+// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
+//
+// - token  : the token ids of the input (used when embd is NULL)
+// - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
+// - pos    : the positions of the respective token in the sequence
+//            (if set to NULL, the token position will be tracked automatically by llama_decode)
+// - seq_id : the sequence to which the respective token belongs
+//            (if set to NULL, the sequence ID will be assumed to be 0)
+// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+//            (if set to NULL, only the logits for the last token will be returned)
+//
+struct llama_batch {
+    int32_t n_tokens;
+
+    llama_token  *  token;
+    float        *  embd;
+    llama_pos    *  pos;
+    int32_t      *  n_seq_id;
+    llama_seq_id ** seq_id;
+    int8_t       *  logits; // TODO: rename this to "output"
+};

 // very similar to llama_batch,
 // but has more metadata about sequences
 struct llama_ubatch {
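Moving the definition out of the public header is the classic opaque-handle pattern: callers now see only the forward declaration struct llama_batch; and work through a pointer, so the library can add or reorder fields without breaking the ABI. A minimal standalone illustration of the pattern (not from the commit; the my_batch names are hypothetical):

    // public header: forward declaration plus accessor functions only
    struct my_batch;                      // layout hidden from callers
    struct my_batch * my_batch_init(int32_t n);
    int32_t           my_batch_size(const struct my_batch * b);
    void              my_batch_free(struct my_batch * b);

    // implementation file: the real definition can change freely
    struct my_batch {
        int32_t n; // fields can be added or reordered without recompiling callers
    };
    struct my_batch * my_batch_init(int32_t n)                 { return new my_batch{ n }; }
    int32_t           my_batch_size(const struct my_batch * b) { return b->n; }
    void              my_batch_free(struct my_batch * b)       { delete b; }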