mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 20:25:20 +00:00
move to llama_batch_ext
This commit is contained in:
@ -1610,20 +1610,29 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
|
||||
// Batch utils
|
||||
//
|
||||
|
||||
void common_batch_clear(struct llama_batch * batch) {
|
||||
llama_batch_clear(batch);
|
||||
// DEPRECATED
|
||||
void common_batch_clear(struct llama_batch & batch) {
|
||||
batch.n_tokens = 0;
|
||||
}
|
||||
|
||||
// DEPRECATED
|
||||
void common_batch_add(
|
||||
struct llama_batch * batch,
|
||||
struct llama_batch & batch,
|
||||
llama_token id,
|
||||
llama_pos pos,
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits) {
|
||||
int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits);
|
||||
if (res == -1) {
|
||||
LOG_ERR("%s: llama_batch size exceeded\n", __func__);
|
||||
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
||||
|
||||
batch.token [batch.n_tokens] = id;
|
||||
batch.pos [batch.n_tokens] = pos;
|
||||
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
|
||||
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
||||
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
||||
}
|
||||
batch.logits [batch.n_tokens] = logits;
|
||||
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -554,10 +554,12 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
|
||||
// Batch utils
|
||||
//
|
||||
|
||||
void common_batch_clear(struct llama_batch * batch);
|
||||
// DEPRECATED
|
||||
void common_batch_clear(struct llama_batch & batch);
|
||||
|
||||
// DEPRECATED
|
||||
void common_batch_add(
|
||||
struct llama_batch * batch,
|
||||
struct llama_batch & batch,
|
||||
llama_token id,
|
||||
llama_pos pos,
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
|
@ -13,7 +13,7 @@ struct common_speculative {
|
||||
struct llama_context * ctx;
|
||||
struct common_sampler * smpl;
|
||||
|
||||
llama_batch * batch;
|
||||
llama_batch batch;
|
||||
llama_tokens prompt;
|
||||
};
|
||||
|
||||
@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
|
||||
auto * result = new common_speculative {
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 1),
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt = */ {},
|
||||
};
|
||||
|
||||
@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft(
|
||||
}
|
||||
|
||||
// we should rarely end-up here during normal decoding
|
||||
if (llama_batch_get_n_tokens(batch) > 0) {
|
||||
if (batch.n_tokens > 0) {
|
||||
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
|
||||
|
||||
llama_decode(ctx, batch);
|
||||
|
@ -24,12 +24,12 @@ struct llama_adapter_lora_deleter {
|
||||
void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
|
||||
};
|
||||
|
||||
struct llama_batch_deleter {
|
||||
void operator()(llama_batch * batch) { llama_batch_free(batch); }
|
||||
struct llama_batch_ext_deleter {
|
||||
void operator()(llama_batch_ext * batch) { llama_batch_ext_free(batch); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
|
||||
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
|
||||
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
|
||||
typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
|
||||
typedef std::unique_ptr<llama_batch, llama_batch_deleter> llama_batch_ptr;
|
||||
typedef std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> llama_batch_ext_ptr;
|
||||
|
109
include/llama.h
109
include/llama.h
@ -231,9 +231,38 @@ extern "C" {
|
||||
|
||||
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
||||
|
||||
struct llama_batch;
|
||||
// Input data for llama_decode
|
||||
//
|
||||
// WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead
|
||||
//
|
||||
// A llama_batch object can contain input about one or many sequences
|
||||
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
|
||||
//
|
||||
// - token : the token ids of the input (used when embd is NULL)
|
||||
// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
|
||||
// - pos : the positions of the respective token in the sequence
|
||||
// (if set to NULL, the token position will be tracked automatically by llama_decode)
|
||||
// - seq_id : the sequence to which the respective token belongs
|
||||
// (if set to NULL, the sequence ID will be assumed to be 0)
|
||||
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||
// (if set to NULL, only the logits for last token will be returned)
|
||||
//
|
||||
typedef struct llama_batch {
|
||||
int32_t n_tokens;
|
||||
|
||||
struct llama_batch_token_info {
|
||||
llama_token * token;
|
||||
float * embd;
|
||||
llama_pos * pos;
|
||||
int32_t * n_seq_id;
|
||||
llama_seq_id ** seq_id;
|
||||
int8_t * logits; // TODO: rename this to "output"
|
||||
} llama_batch;
|
||||
|
||||
// Input data for llama_decode / llama_encode
|
||||
// It can contain text tokens and embeddings for one or many sequences
|
||||
struct llama_batch_ext;
|
||||
|
||||
struct llama_batch_ext_token_info {
|
||||
llama_token token;
|
||||
llama_pos pos;
|
||||
int32_t n_seq_id;
|
||||
@ -815,9 +844,9 @@ extern "C" {
|
||||
//
|
||||
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
||||
//
|
||||
LLAMA_API struct llama_batch * llama_batch_get_one(
|
||||
DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens);
|
||||
int32_t n_tokens), "use llama_batch_ext API instead");
|
||||
|
||||
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||
// Each token can be assigned up to n_seq_max sequence ids
|
||||
@ -826,30 +855,47 @@ extern "C" {
|
||||
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
|
||||
// The rest of the llama_batch members are allocated with size n_tokens
|
||||
// All members are left uninitialized
|
||||
// LLAMA_API struct llama_batch llama_batch_init(
|
||||
// int32_t n_tokens,
|
||||
// int32_t embd,
|
||||
// int32_t n_seq_max);
|
||||
DEPRECATED(LLAMA_API struct llama_batch llama_batch_init(
|
||||
int32_t n_tokens,
|
||||
int32_t embd,
|
||||
int32_t n_seq_max), "use llama_batch_ext API instead");
|
||||
|
||||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch),
|
||||
"use llama_batch_ext API instead");
|
||||
|
||||
// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
|
||||
// Each token can be assigned up to n_seq_max sequence ids
|
||||
// The batch has to be freed with llama_batch_free()
|
||||
LLAMA_API struct llama_batch * llama_batch_init(
|
||||
// The batch has to be freed with llama_batch_ext_free()
|
||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
|
||||
int32_t n_tokens,
|
||||
int32_t n_seq_max);
|
||||
|
||||
// Same with llama_batch_init, but initializes the batch with the provided text tokens
|
||||
// First token will be at position pos0
|
||||
// The sequence ID will be fixed to seq_id
|
||||
// The batch has to be freed with llama_batch_ext_free()
|
||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
int32_t pos0,
|
||||
int32_t seq_id);
|
||||
|
||||
// Same with llama_batch_init, but initializes the batch with the provided raw embeddings
|
||||
LLAMA_API struct llama_batch * llama_batch_init_from_embd(
|
||||
// First token will be at position pos0
|
||||
// The sequence ID will be fixed to seq_id
|
||||
// The batch has to be freed with llama_batch_ext_free()
|
||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
||||
float * embd,
|
||||
size_t n_embd,
|
||||
int32_t pos0,
|
||||
int32_t seq_id);
|
||||
|
||||
// Get the number of tokens in the batch
|
||||
LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch);
|
||||
LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
|
||||
|
||||
LLAMA_API struct llama_batch_token_info llama_batch_get_token_info(
|
||||
struct llama_batch * batch,
|
||||
LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
|
||||
struct llama_batch_ext * batch,
|
||||
int32_t i);
|
||||
|
||||
// Add text tokens to the batch
|
||||
@ -857,8 +903,8 @@ extern "C" {
|
||||
// 0 : success
|
||||
// -1 : not enough space in the batch
|
||||
// -2 : embd is already set, cannot add text tokens
|
||||
LLAMA_API int32_t llama_batch_add_text_token(
|
||||
struct llama_batch * batch,
|
||||
LLAMA_API int32_t llama_batch_ext_add_text_token(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_token token,
|
||||
llama_pos pos,
|
||||
const llama_seq_id * seq_ids,
|
||||
@ -868,43 +914,50 @@ extern "C" {
|
||||
// Set logits for the token in the ith sequence
|
||||
// If pos == -1, logits will be set for the all tokens
|
||||
// Returns -1 if the token is not in the batch
|
||||
LLAMA_API int32_t llama_batch_set_logits(
|
||||
struct llama_batch * batch,
|
||||
LLAMA_API int32_t llama_batch_ext_set_logits(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_pos pos,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// Set logits for the last added token
|
||||
// Returns -1 if there is no tokens in the batch
|
||||
LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch);
|
||||
LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch);
|
||||
|
||||
// Get a "view" from a number of tokens offset
|
||||
// Return returned batch must be freed with llama_batch_free()
|
||||
LLAMA_API struct llama_batch * llama_batch_get_view(
|
||||
struct llama_batch * batch,
|
||||
LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view(
|
||||
struct llama_batch_ext * batch,
|
||||
int32_t offset,
|
||||
int32_t n_tokens);
|
||||
|
||||
// Remove everything from the batch
|
||||
LLAMA_API void llama_batch_clear(struct llama_batch * batch);
|
||||
LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch);
|
||||
|
||||
// Frees a batch of tokens allocated with llama_batch_init()
|
||||
LLAMA_API void llama_batch_free(struct llama_batch * batch);
|
||||
// Frees a batch of tokens allocated with llama_batch_ext_init()
|
||||
// If this is a view, the original batch is not freed
|
||||
LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch);
|
||||
|
||||
// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
|
||||
// Stores the encoder output internally for later use by the decoder cross-attention layers.
|
||||
// 0 - success
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_encode(
|
||||
DEPRECATED(LLAMA_API int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch * batch);
|
||||
struct llama_batch batch), "use llama_batch_ext API instead");
|
||||
LLAMA_API int32_t llama_text_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch_ext * batch);
|
||||
|
||||
// Positive return values does not mean a fatal error, but rather a warning.
|
||||
// 0 - success
|
||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||
// < 0 - error. the KV cache state is restored to the state before this call
|
||||
LLAMA_API int32_t llama_decode(
|
||||
DEPRECATED(LLAMA_API int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch * batch);
|
||||
struct llama_batch batch), "use llama_batch_ext API instead");
|
||||
LLAMA_API int32_t llama_text_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch_ext * batch);
|
||||
|
||||
// Set the number of threads used for decoding
|
||||
// n_threads is the number of threads used for generation (single token)
|
||||
|
@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
@ -273,49 +273,61 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
);
|
||||
}
|
||||
|
||||
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
|
||||
batch = in_batch;
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
if (!batch.pos) {
|
||||
pos.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) {
|
||||
batch = new llama_batch_ext{
|
||||
/*n_tokens =*/ in_batch.n_tokens,
|
||||
/*max_tokens =*/ in_batch.n_tokens,
|
||||
/*is_view =*/ false,
|
||||
/*tokens =*/ in_batch.token,
|
||||
/*embd =*/ in_batch.embd,
|
||||
/*pos =*/ in_batch.pos,
|
||||
/*n_seq_id =*/ in_batch.n_seq_id,
|
||||
/*seq_id =*/ in_batch.seq_id,
|
||||
/*logits =*/ in_batch.logits,
|
||||
};
|
||||
GGML_ASSERT(batch->n_tokens > 0);
|
||||
if (!in_batch.pos) {
|
||||
pos.resize(batch->n_tokens);
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
pos[i] = i + p0;
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
batch->pos = pos.data();
|
||||
}
|
||||
if (!batch.n_seq_id) {
|
||||
n_seq_id.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch->n_seq_id) {
|
||||
n_seq_id.resize(batch->n_tokens);
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
n_seq_id[i] = seq_id_0.size();
|
||||
}
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
batch->n_seq_id = n_seq_id.data();
|
||||
}
|
||||
if (!batch.seq_id) {
|
||||
seq_id.resize(batch.n_tokens + 1);
|
||||
seq_id[batch.n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch->seq_id) {
|
||||
seq_id.resize(batch->n_tokens + 1);
|
||||
seq_id[batch->n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
seq_id[i] = seq_id_0.data();
|
||||
}
|
||||
batch.seq_id = seq_id.data();
|
||||
batch->seq_id = seq_id.data();
|
||||
}
|
||||
if (!batch.logits) {
|
||||
logits.resize(batch.n_tokens);
|
||||
if (!batch->logits) {
|
||||
logits.resize(batch->n_tokens);
|
||||
logits[logits.size() - 1] = true;
|
||||
batch.logits = logits.data();
|
||||
batch->logits = logits.data();
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_allocr::~llama_batch_allocr() {
|
||||
delete batch;
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
||||
struct llama_batch * llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens) {
|
||||
return new llama_batch{
|
||||
struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens) {
|
||||
return llama_batch{
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*max_tokens =*/ n_tokens,
|
||||
/*is_view =*/ false,
|
||||
/*tokens =*/ tokens,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
@ -325,8 +337,20 @@ struct llama_batch * llama_batch_get_one(
|
||||
};
|
||||
}
|
||||
|
||||
static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
llama_batch * batch = new llama_batch{
|
||||
struct llama_batch_ext * llama_batch_ext_init_from_text(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
int32_t pos0,
|
||||
int32_t seq_id) {
|
||||
llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
|
||||
for (int32_t i = 0; i < n_tokens; i++) {
|
||||
llama_batch_ext_add_text_token(batch, tokens[i], pos0 + i, &seq_id, 1, false);
|
||||
}
|
||||
return batch;
|
||||
}
|
||||
|
||||
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
llama_batch_ext * batch = new llama_batch_ext{
|
||||
/*n_tokens =*/ 0,
|
||||
/*max_tokens =*/ n_tokens_alloc,
|
||||
/*is_view =*/ false,
|
||||
@ -357,16 +381,16 @@ static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_
|
||||
return batch;
|
||||
}
|
||||
|
||||
struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
|
||||
return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max);
|
||||
struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
|
||||
return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max);
|
||||
}
|
||||
|
||||
struct llama_batch * llama_batch_init_from_embd(
|
||||
struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
||||
float * embd,
|
||||
size_t n_embd,
|
||||
int32_t pos0,
|
||||
int32_t seq_id) {
|
||||
struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
|
||||
struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1);
|
||||
memcpy(batch->embd, embd, n_embd * sizeof(float));
|
||||
for (size_t i = 0; i < n_embd; i++) {
|
||||
batch->pos [i] = pos0 + i;
|
||||
@ -376,12 +400,12 @@ struct llama_batch * llama_batch_init_from_embd(
|
||||
return batch;
|
||||
}
|
||||
|
||||
int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) {
|
||||
int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
|
||||
return batch->n_tokens;
|
||||
}
|
||||
|
||||
int32_t llama_batch_add_text_token(
|
||||
struct llama_batch * batch,
|
||||
int32_t llama_batch_ext_add_text_token(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_token token,
|
||||
llama_pos pos,
|
||||
const llama_seq_id * seq_ids,
|
||||
@ -404,8 +428,8 @@ int32_t llama_batch_add_text_token(
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t llama_batch_set_logits(
|
||||
struct llama_batch * batch,
|
||||
int32_t llama_batch_ext_set_logits(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_pos pos,
|
||||
llama_seq_id seq_id) {
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
@ -423,7 +447,7 @@ int32_t llama_batch_set_logits(
|
||||
return -1; // not found
|
||||
}
|
||||
|
||||
int32_t llama_batch_set_logits_last(struct llama_batch * batch) {
|
||||
int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) {
|
||||
if (batch->n_tokens == 0) {
|
||||
return -1;
|
||||
}
|
||||
@ -431,18 +455,18 @@ int32_t llama_batch_set_logits_last(struct llama_batch * batch) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void llama_batch_clear(struct llama_batch * batch) {
|
||||
void llama_batch_ext_clear(struct llama_batch_ext * batch) {
|
||||
batch->n_tokens = 0;
|
||||
}
|
||||
|
||||
struct llama_batch * llama_batch_get_view(
|
||||
struct llama_batch * batch,
|
||||
struct llama_batch_ext * llama_batch_ext_get_view(
|
||||
struct llama_batch_ext * batch,
|
||||
int32_t offset,
|
||||
int32_t n_tokens) {
|
||||
if (batch->embd) {
|
||||
return nullptr; // not yet supported
|
||||
}
|
||||
llama_batch * batch_view = new llama_batch{
|
||||
llama_batch_ext * batch_view = new llama_batch_ext{
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*max_tokens =*/ n_tokens,
|
||||
/*is_view =*/ true,
|
||||
@ -456,11 +480,11 @@ struct llama_batch * llama_batch_get_view(
|
||||
return batch_view;
|
||||
}
|
||||
|
||||
struct llama_batch_token_info llama_batch_get_token_info(
|
||||
struct llama_batch * batch,
|
||||
struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
|
||||
struct llama_batch_ext * batch,
|
||||
int32_t i) {
|
||||
GGML_ASSERT(i >= 0 && i < batch->n_tokens);
|
||||
return llama_batch_token_info{
|
||||
return llama_batch_ext_token_info{
|
||||
/*token =*/ batch->token [i],
|
||||
/*pos =*/ batch->pos [i],
|
||||
/*n_seq_id =*/ batch->n_seq_id[i],
|
||||
@ -469,7 +493,7 @@ struct llama_batch_token_info llama_batch_get_token_info(
|
||||
};
|
||||
}
|
||||
|
||||
void llama_batch_free(struct llama_batch * batch) {
|
||||
void llama_batch_ext_free(struct llama_batch_ext * batch) {
|
||||
// do not free the members if it's a view
|
||||
if (!batch->is_view) {
|
||||
if (batch->token) free(batch->token);
|
||||
|
@ -5,8 +5,8 @@
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
// Input data for llama_decode
|
||||
// A llama_batch object can contain input about one or many sequences
|
||||
// Input data for llama_decode / llama_encode
|
||||
// A llama_batch_ext object can contain input about one or many sequences
|
||||
// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
|
||||
//
|
||||
// - token : the token ids of the input (used when embd is NULL)
|
||||
@ -18,7 +18,7 @@
|
||||
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
|
||||
// (if set to NULL, only the logits for last token will be returned)
|
||||
//
|
||||
struct llama_batch {
|
||||
struct llama_batch_ext {
|
||||
int32_t n_tokens;
|
||||
int32_t max_tokens;
|
||||
bool is_view;
|
||||
@ -73,7 +73,7 @@ struct llama_sbatch {
|
||||
std::vector<size_t> out_ids;
|
||||
std::vector<llama_sbatch_seq> seq;
|
||||
|
||||
const llama_batch * batch = nullptr;
|
||||
const llama_batch_ext * batch = nullptr;
|
||||
|
||||
// buffers for the ubatch
|
||||
std::vector<llama_token> ubatch_token;
|
||||
@ -96,12 +96,12 @@ struct llama_sbatch {
|
||||
// sequence-wise split
|
||||
llama_ubatch split_seq(size_t n_ubatch);
|
||||
|
||||
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
void from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
|
||||
};
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
struct llama_batch_allocr {
|
||||
struct llama_batch batch;
|
||||
struct llama_batch_ext * batch;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
std::vector<llama_pos> pos;
|
||||
@ -110,5 +110,7 @@ struct llama_batch_allocr {
|
||||
std::vector<int8_t> logits;
|
||||
|
||||
// optionally fulfill the batch returned by llama_batch_get_one
|
||||
llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
|
||||
llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0);
|
||||
|
||||
~llama_batch_allocr();
|
||||
};
|
||||
|
@ -8445,7 +8445,7 @@ static enum ggml_status llama_graph_compute(
|
||||
|
||||
static int llama_prepare_sbatch(
|
||||
llama_context & lctx,
|
||||
const llama_batch & batch,
|
||||
const llama_batch_ext & batch,
|
||||
uint32_t & n_outputs) {
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
@ -8585,7 +8585,7 @@ static int llama_prepare_ubatch(
|
||||
//
|
||||
static int llama_decode_impl(
|
||||
llama_context & lctx,
|
||||
llama_batch inp_batch) {
|
||||
llama_batch_ext & inp_batch) {
|
||||
|
||||
lctx.is_encoding = false;
|
||||
|
||||
@ -8594,10 +8594,6 @@ static int llama_decode_impl(
|
||||
return -1;
|
||||
}
|
||||
|
||||
// temporarily allocate memory for the input batch if needed
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & vocab = model.vocab;
|
||||
const auto & hparams = model.hparams;
|
||||
@ -8616,7 +8612,7 @@ static int llama_decode_impl(
|
||||
uint32_t n_outputs_prev = 0;
|
||||
|
||||
{
|
||||
const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
|
||||
const int ret = llama_prepare_sbatch(lctx, inp_batch, n_outputs);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
@ -8625,7 +8621,7 @@ static int llama_decode_impl(
|
||||
while (lctx.sbatch.n_tokens > 0) {
|
||||
llama_ubatch ubatch;
|
||||
{
|
||||
const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
|
||||
const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, inp_batch.n_tokens);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
@ -8832,7 +8828,7 @@ static int llama_decode_impl(
|
||||
//
|
||||
static int llama_encode_impl(
|
||||
llama_context & lctx,
|
||||
llama_batch inp_batch) {
|
||||
llama_batch_ext & inp_batch) {
|
||||
|
||||
lctx.is_encoding = true;
|
||||
|
||||
@ -8841,22 +8837,18 @@ static int llama_encode_impl(
|
||||
return -1;
|
||||
}
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
const uint32_t n_tokens = inp_batch.n_tokens;
|
||||
|
||||
const auto & model = lctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
const auto & cparams = lctx.cparams;
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
GGML_ASSERT((!inp_batch.token && inp_batch.embd) || (inp_batch.token && !inp_batch.embd)); // NOLINT
|
||||
|
||||
if (batch.token) {
|
||||
if (inp_batch.token) {
|
||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
||||
if (inp_batch.token[i] < 0 || (uint32_t) inp_batch.token[i] >= model.vocab.n_tokens()) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, inp_batch.token[i]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@ -8873,7 +8865,7 @@ static int llama_encode_impl(
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
||||
lctx.sbatch.from_batch(inp_batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
||||
|
||||
const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
|
||||
|
||||
@ -9976,9 +9968,32 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
|
||||
|
||||
///
|
||||
|
||||
|
||||
// DEPRECATED
|
||||
int32_t llama_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch * batch) {
|
||||
struct llama_batch batch) {
|
||||
// temporarily allocate memory for the input batch if needed
|
||||
// also convert llama_batch to llama_batch_ext
|
||||
llama_batch_allocr batch_allocr(batch, batch.pos ? -1 : ctx->kv_self.max_pos() + 1);
|
||||
llama_batch_ext * batch_ext = batch_allocr.batch;
|
||||
return llama_text_encode(ctx, batch_ext);
|
||||
}
|
||||
|
||||
// DEPRECATED
|
||||
int32_t llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch) {
|
||||
// temporarily allocate memory for the input batch if needed
|
||||
// also convert llama_batch to llama_batch_ext
|
||||
llama_batch_allocr batch_allocr(batch, batch.pos ? -1 : ctx->kv_self.max_pos() + 1);
|
||||
llama_batch_ext * batch_ext = batch_allocr.batch;
|
||||
return llama_text_decode(ctx, batch_ext);
|
||||
}
|
||||
|
||||
int32_t llama_text_encode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch_ext * batch) {
|
||||
const int ret = llama_encode_impl(*ctx, *batch);
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
|
||||
@ -9987,9 +10002,9 @@ int32_t llama_encode(
|
||||
return ret;
|
||||
}
|
||||
|
||||
int32_t llama_decode(
|
||||
int32_t llama_text_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch * batch) {
|
||||
struct llama_batch_ext * batch) {
|
||||
const int ret = llama_decode_impl(*ctx, *batch);
|
||||
if (ret != 0) {
|
||||
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
|
||||
|
Reference in New Issue
Block a user