mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 04:35:05 +00:00
move to llama_batch_ext
This commit is contained in:
@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
||||
return ubatch;
|
||||
}
|
||||
|
||||
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) {
|
||||
GGML_ASSERT(batch.n_tokens >= 0);
|
||||
this->batch = &batch;
|
||||
this->n_embd = n_embd;
|
||||
@ -273,49 +273,61 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
|
||||
);
|
||||
}
|
||||
|
||||
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
|
||||
batch = in_batch;
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
if (!batch.pos) {
|
||||
pos.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) {
|
||||
batch = new llama_batch_ext{
|
||||
/*n_tokens =*/ in_batch.n_tokens,
|
||||
/*max_tokens =*/ in_batch.n_tokens,
|
||||
/*is_view =*/ false,
|
||||
/*tokens =*/ in_batch.token,
|
||||
/*embd =*/ in_batch.embd,
|
||||
/*pos =*/ in_batch.pos,
|
||||
/*n_seq_id =*/ in_batch.n_seq_id,
|
||||
/*seq_id =*/ in_batch.seq_id,
|
||||
/*logits =*/ in_batch.logits,
|
||||
};
|
||||
GGML_ASSERT(batch->n_tokens > 0);
|
||||
if (!in_batch.pos) {
|
||||
pos.resize(batch->n_tokens);
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
pos[i] = i + p0;
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
batch->pos = pos.data();
|
||||
}
|
||||
if (!batch.n_seq_id) {
|
||||
n_seq_id.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch->n_seq_id) {
|
||||
n_seq_id.resize(batch->n_tokens);
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
n_seq_id[i] = seq_id_0.size();
|
||||
}
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
batch->n_seq_id = n_seq_id.data();
|
||||
}
|
||||
if (!batch.seq_id) {
|
||||
seq_id.resize(batch.n_tokens + 1);
|
||||
seq_id[batch.n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch->seq_id) {
|
||||
seq_id.resize(batch->n_tokens + 1);
|
||||
seq_id[batch->n_tokens] = NULL;
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
seq_id[i] = seq_id_0.data();
|
||||
}
|
||||
batch.seq_id = seq_id.data();
|
||||
batch->seq_id = seq_id.data();
|
||||
}
|
||||
if (!batch.logits) {
|
||||
logits.resize(batch.n_tokens);
|
||||
if (!batch->logits) {
|
||||
logits.resize(batch->n_tokens);
|
||||
logits[logits.size() - 1] = true;
|
||||
batch.logits = logits.data();
|
||||
batch->logits = logits.data();
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_allocr::~llama_batch_allocr() {
|
||||
delete batch;
|
||||
}
|
||||
|
||||
//
|
||||
// interface implementation
|
||||
//
|
||||
|
||||
struct llama_batch * llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens) {
|
||||
return new llama_batch{
|
||||
struct llama_batch llama_batch_get_one(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens) {
|
||||
return llama_batch{
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*max_tokens =*/ n_tokens,
|
||||
/*is_view =*/ false,
|
||||
/*tokens =*/ tokens,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
@ -325,8 +337,20 @@ struct llama_batch * llama_batch_get_one(
|
||||
};
|
||||
}
|
||||
|
||||
static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
llama_batch * batch = new llama_batch{
|
||||
struct llama_batch_ext * llama_batch_ext_init_from_text(
|
||||
llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
int32_t pos0,
|
||||
int32_t seq_id) {
|
||||
llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
|
||||
for (int32_t i = 0; i < n_tokens; i++) {
|
||||
llama_batch_ext_add_text_token(batch, tokens[i], pos0 + i, &seq_id, 1, false);
|
||||
}
|
||||
return batch;
|
||||
}
|
||||
|
||||
static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
llama_batch_ext * batch = new llama_batch_ext{
|
||||
/*n_tokens =*/ 0,
|
||||
/*max_tokens =*/ n_tokens_alloc,
|
||||
/*is_view =*/ false,
|
||||
@ -357,16 +381,16 @@ static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_
|
||||
return batch;
|
||||
}
|
||||
|
||||
struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
|
||||
return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max);
|
||||
struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
|
||||
return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max);
|
||||
}
|
||||
|
||||
struct llama_batch * llama_batch_init_from_embd(
|
||||
struct llama_batch_ext * llama_batch_ext_init_from_embd(
|
||||
float * embd,
|
||||
size_t n_embd,
|
||||
int32_t pos0,
|
||||
int32_t seq_id) {
|
||||
struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
|
||||
struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1);
|
||||
memcpy(batch->embd, embd, n_embd * sizeof(float));
|
||||
for (size_t i = 0; i < n_embd; i++) {
|
||||
batch->pos [i] = pos0 + i;
|
||||
@ -376,12 +400,12 @@ struct llama_batch * llama_batch_init_from_embd(
|
||||
return batch;
|
||||
}
|
||||
|
||||
int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) {
|
||||
int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
|
||||
return batch->n_tokens;
|
||||
}
|
||||
|
||||
int32_t llama_batch_add_text_token(
|
||||
struct llama_batch * batch,
|
||||
int32_t llama_batch_ext_add_text_token(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_token token,
|
||||
llama_pos pos,
|
||||
const llama_seq_id * seq_ids,
|
||||
@ -404,8 +428,8 @@ int32_t llama_batch_add_text_token(
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t llama_batch_set_logits(
|
||||
struct llama_batch * batch,
|
||||
int32_t llama_batch_ext_set_logits(
|
||||
struct llama_batch_ext * batch,
|
||||
llama_pos pos,
|
||||
llama_seq_id seq_id) {
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
@ -423,7 +447,7 @@ int32_t llama_batch_set_logits(
|
||||
return -1; // not found
|
||||
}
|
||||
|
||||
int32_t llama_batch_set_logits_last(struct llama_batch * batch) {
|
||||
int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) {
|
||||
if (batch->n_tokens == 0) {
|
||||
return -1;
|
||||
}
|
||||
@ -431,18 +455,18 @@ int32_t llama_batch_set_logits_last(struct llama_batch * batch) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
void llama_batch_clear(struct llama_batch * batch) {
|
||||
void llama_batch_ext_clear(struct llama_batch_ext * batch) {
|
||||
batch->n_tokens = 0;
|
||||
}
|
||||
|
||||
struct llama_batch * llama_batch_get_view(
|
||||
struct llama_batch * batch,
|
||||
struct llama_batch_ext * llama_batch_ext_get_view(
|
||||
struct llama_batch_ext * batch,
|
||||
int32_t offset,
|
||||
int32_t n_tokens) {
|
||||
if (batch->embd) {
|
||||
return nullptr; // not yet supported
|
||||
}
|
||||
llama_batch * batch_view = new llama_batch{
|
||||
llama_batch_ext * batch_view = new llama_batch_ext{
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*max_tokens =*/ n_tokens,
|
||||
/*is_view =*/ true,
|
||||
@ -456,11 +480,11 @@ struct llama_batch * llama_batch_get_view(
|
||||
return batch_view;
|
||||
}
|
||||
|
||||
struct llama_batch_token_info llama_batch_get_token_info(
|
||||
struct llama_batch * batch,
|
||||
struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
|
||||
struct llama_batch_ext * batch,
|
||||
int32_t i) {
|
||||
GGML_ASSERT(i >= 0 && i < batch->n_tokens);
|
||||
return llama_batch_token_info{
|
||||
return llama_batch_ext_token_info{
|
||||
/*token =*/ batch->token [i],
|
||||
/*pos =*/ batch->pos [i],
|
||||
/*n_seq_id =*/ batch->n_seq_id[i],
|
||||
@ -469,7 +493,7 @@ struct llama_batch_token_info llama_batch_get_token_info(
|
||||
};
|
||||
}
|
||||
|
||||
void llama_batch_free(struct llama_batch * batch) {
|
||||
void llama_batch_ext_free(struct llama_batch_ext * batch) {
|
||||
// do not free the members if it's a view
|
||||
if (!batch->is_view) {
|
||||
if (batch->token) free(batch->token);
|
||||
|
Reference in New Issue
Block a user