first proposal for private llama_batch

This commit is contained in:
Xuan Son Nguyen
2025-02-14 00:48:12 +01:00
parent 04045bb842
commit 4ed4fe75ed
3 changed files with 173 additions and 51 deletions

View File

@ -309,10 +309,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
// interface implementation
//
struct llama_batch llama_batch_get_one(
struct llama_batch * llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens) {
return {
return new llama_batch{
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
/*embd =*/ nullptr,
@ -323,8 +323,8 @@ struct llama_batch llama_batch_get_one(
};
}
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
llama_batch batch = {
static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
llama_batch * batch = new llama_batch{
/*n_tokens =*/ 0,
/*tokens =*/ nullptr,
/*embd =*/ nullptr,
@ -335,34 +335,108 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
};
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
} else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
}
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
for (int i = 0; i < n_tokens_alloc; ++i) {
batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
batch.seq_id[n_tokens_alloc] = nullptr;
batch->seq_id[n_tokens_alloc] = nullptr;
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
return batch;
}
void llama_batch_free(struct llama_batch batch) {
if (batch.token) free(batch.token);
if (batch.embd) free(batch.embd);
if (batch.pos) free(batch.pos);
if (batch.n_seq_id) free(batch.n_seq_id);
if (batch.seq_id) {
for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
free(batch.seq_id[i]);
}
free(batch.seq_id);
}
if (batch.logits) free(batch.logits);
struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max);
}
struct llama_batch * llama_batch_init_from_embd(
float * embd,
size_t n_embd,
int32_t pos0,
int32_t seq_id) {
struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
memcpy(batch->embd, embd, n_embd * sizeof(float));
for (int32_t i = 0; i < n_embd; i++) {
batch->pos [i] = pos0 + i;
batch->n_seq_id[i] = 1;
batch->seq_id [i][0] = seq_id;
}
}
int32_t llama_batch_add_text(
struct llama_batch * batch,
llama_token * tokens,
size_t n_tokens,
int32_t pos0,
int32_t * seq_ids,
size_t n_seq_ids) {
if (batch->n_tokens + n_tokens > batch->n_tokens) {
return -1;
}
if (batch->embd) {
return -2;
}
for (int32_t i = 0; i < n_tokens; i++) {
batch->token [batch->n_tokens + i] = tokens[i];
batch->pos [batch->n_tokens + i] = pos0 + i;
batch->n_seq_id[batch->n_tokens + i] = n_seq_ids;
for (int32_t j = 0; j < n_seq_ids; j++) {
batch->seq_id[batch->n_tokens + i][j] = seq_ids[j];
}
}
}
int32_t llama_batch_add_text(
struct llama_batch * batch,
llama_token * tokens,
size_t n_tokens,
int32_t pos0,
int32_t seq_id) {
std::array<int32_t, 1> seq_ids = { seq_id };
return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size());
}
int32_t llama_batch_set_logits(
struct llama_batch * batch,
int32_t pos,
int32_t seq_id) {
for (int32_t i = 0; i < batch->n_tokens; i++) {
// find the token having seq_id
for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
if (batch->seq_id[i][j] == seq_id) {
// found the sequence
if (pos == -1 || pos == batch->pos[i]) {
batch->logits[i] = true;
break;
}
}
}
}
}
void llama_batch_clear(struct llama_batch * batch) {
batch->n_tokens = 0;
}
void llama_batch_free(struct llama_batch * batch) {
if (batch->token) free(batch->token);
if (batch->embd) free(batch->embd);
if (batch->pos) free(batch->pos);
if (batch->n_seq_id) free(batch->n_seq_id);
if (batch->seq_id) {
for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
free(batch->seq_id[i]);
}
free(batch->seq_id);
}
if (batch->logits) free(batch->logits);
delete batch;
}