mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 04:35:05 +00:00
rework, targeting llama-server
This commit is contained in:
@ -314,6 +314,8 @@ struct llama_batch * llama_batch_get_one(
|
||||
int32_t n_tokens) {
|
||||
return new llama_batch{
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*max_tokens =*/ n_tokens,
|
||||
/*is_view =*/ false,
|
||||
/*tokens =*/ tokens,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
@ -326,6 +328,8 @@ struct llama_batch * llama_batch_get_one(
|
||||
static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
||||
llama_batch * batch = new llama_batch{
|
||||
/*n_tokens =*/ 0,
|
||||
/*max_tokens =*/ n_tokens_alloc,
|
||||
/*is_view =*/ false,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ nullptr,
|
||||
@ -364,50 +368,46 @@ struct llama_batch * llama_batch_init_from_embd(
|
||||
int32_t seq_id) {
|
||||
struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1);
|
||||
memcpy(batch->embd, embd, n_embd * sizeof(float));
|
||||
for (int32_t i = 0; i < n_embd; i++) {
|
||||
for (size_t i = 0; i < n_embd; i++) {
|
||||
batch->pos [i] = pos0 + i;
|
||||
batch->n_seq_id[i] = 1;
|
||||
batch->seq_id [i][0] = seq_id;
|
||||
}
|
||||
return batch;
|
||||
}
|
||||
|
||||
int32_t llama_batch_add_text(
|
||||
int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) {
|
||||
return batch->n_tokens;
|
||||
}
|
||||
|
||||
int32_t llama_batch_add_text_token(
|
||||
struct llama_batch * batch,
|
||||
llama_token * tokens,
|
||||
size_t n_tokens,
|
||||
int32_t pos0,
|
||||
int32_t * seq_ids,
|
||||
size_t n_seq_ids) {
|
||||
if (batch->n_tokens + n_tokens > batch->n_tokens) {
|
||||
return -1;
|
||||
llama_token token,
|
||||
llama_pos pos,
|
||||
const llama_seq_id * seq_ids,
|
||||
size_t n_seq_ids,
|
||||
float logits) {
|
||||
if (batch->n_tokens + 1 > batch->max_tokens) {
|
||||
return -1; // llama_batch size exceeded
|
||||
}
|
||||
if (batch->embd) {
|
||||
return -2;
|
||||
return -2; // embd is already set, cannot add text tokens
|
||||
}
|
||||
for (int32_t i = 0; i < n_tokens; i++) {
|
||||
batch->token [batch->n_tokens + i] = tokens[i];
|
||||
batch->pos [batch->n_tokens + i] = pos0 + i;
|
||||
batch->n_seq_id[batch->n_tokens + i] = n_seq_ids;
|
||||
for (int32_t j = 0; j < n_seq_ids; j++) {
|
||||
batch->seq_id[batch->n_tokens + i][j] = seq_ids[j];
|
||||
}
|
||||
batch->token [batch->n_tokens] = token;
|
||||
batch->pos [batch->n_tokens] = pos;
|
||||
batch->n_seq_id[batch->n_tokens] = n_seq_ids;
|
||||
for (size_t j = 0; j < n_seq_ids; j++) {
|
||||
batch->seq_id[batch->n_tokens][j] = seq_ids[j];
|
||||
}
|
||||
}
|
||||
|
||||
int32_t llama_batch_add_text(
|
||||
struct llama_batch * batch,
|
||||
llama_token * tokens,
|
||||
size_t n_tokens,
|
||||
int32_t pos0,
|
||||
int32_t seq_id) {
|
||||
std::array<int32_t, 1> seq_ids = { seq_id };
|
||||
return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size());
|
||||
batch->logits [batch->n_tokens] = logits;
|
||||
batch->n_tokens++;
|
||||
return 0;
|
||||
}
|
||||
|
||||
int32_t llama_batch_set_logits(
|
||||
struct llama_batch * batch,
|
||||
int32_t pos,
|
||||
int32_t seq_id) {
|
||||
llama_pos pos,
|
||||
llama_seq_id seq_id) {
|
||||
for (int32_t i = 0; i < batch->n_tokens; i++) {
|
||||
// find the token having seq_id
|
||||
for (int32_t j = 0; j < batch->n_seq_id[i]; j++) {
|
||||
@ -415,28 +415,74 @@ int32_t llama_batch_set_logits(
|
||||
// found the sequence
|
||||
if (pos == -1 || pos == batch->pos[i]) {
|
||||
batch->logits[i] = true;
|
||||
break;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1; // not found
|
||||
}
|
||||
|
||||
int32_t llama_batch_set_logits_last(struct llama_batch * batch) {
|
||||
if (batch->n_tokens == 0) {
|
||||
return -1;
|
||||
}
|
||||
batch->logits[batch->n_tokens - 1] = true;
|
||||
return 0;
|
||||
}
|
||||
|
||||
void llama_batch_clear(struct llama_batch * batch) {
|
||||
batch->n_tokens = 0;
|
||||
}
|
||||
|
||||
void llama_batch_free(struct llama_batch * batch) {
|
||||
if (batch->token) free(batch->token);
|
||||
if (batch->embd) free(batch->embd);
|
||||
if (batch->pos) free(batch->pos);
|
||||
if (batch->n_seq_id) free(batch->n_seq_id);
|
||||
if (batch->seq_id) {
|
||||
for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
|
||||
free(batch->seq_id[i]);
|
||||
}
|
||||
free(batch->seq_id);
|
||||
struct llama_batch * llama_batch_get_view(
|
||||
struct llama_batch * batch,
|
||||
int32_t offset,
|
||||
int32_t n_tokens) {
|
||||
if (batch->embd) {
|
||||
return nullptr; // not yet supported
|
||||
}
|
||||
llama_batch * batch_view = new llama_batch{
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*max_tokens =*/ n_tokens,
|
||||
/*is_view =*/ true,
|
||||
/*tokens =*/ batch->token + offset,
|
||||
/*embd =*/ nullptr,
|
||||
/*pos =*/ batch->pos + offset,
|
||||
/*n_seq_id =*/ batch->n_seq_id + offset,
|
||||
/*seq_id =*/ batch->seq_id + offset,
|
||||
/*logits =*/ batch->logits + offset,
|
||||
};
|
||||
return batch_view;
|
||||
}
|
||||
|
||||
struct llama_batch_token_info llama_batch_get_token_info(
|
||||
struct llama_batch * batch,
|
||||
int32_t i) {
|
||||
GGML_ASSERT(i >= 0 && i < batch->n_tokens);
|
||||
return llama_batch_token_info{
|
||||
/*token =*/ batch->token [i],
|
||||
/*pos =*/ batch->pos [i],
|
||||
/*n_seq_id =*/ batch->n_seq_id[i],
|
||||
/*seq_id =*/ batch->seq_id [i],
|
||||
/*logits =*/ batch->logits [i],
|
||||
};
|
||||
}
|
||||
|
||||
void llama_batch_free(struct llama_batch * batch) {
|
||||
// do not free the members if it's a view
|
||||
if (!batch->is_view) {
|
||||
if (batch->token) free(batch->token);
|
||||
if (batch->embd) free(batch->embd);
|
||||
if (batch->pos) free(batch->pos);
|
||||
if (batch->n_seq_id) free(batch->n_seq_id);
|
||||
if (batch->seq_id) {
|
||||
for (int i = 0; batch->seq_id[i] != nullptr; ++i) {
|
||||
free(batch->seq_id[i]);
|
||||
}
|
||||
free(batch->seq_id);
|
||||
}
|
||||
if (batch->logits) free(batch->logits);
|
||||
}
|
||||
if (batch->logits) free(batch->logits);
|
||||
delete batch;
|
||||
}
|
||||
|
Reference in New Issue
Block a user