batch : rework llama_batch_allocr (#14153)

* batch : rework llama_batch_allocr

ggml-ci

* cont : move validation inside class

ggml-ci

* cont : move output counting to class

ggml-ci

* cont : minor

ggml-ci

* batch : add TODOs

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-13 13:47:55 +03:00
committed by GitHub
parent b7cc7745e3
commit 60c666347b
7 changed files with 162 additions and 106 deletions

View File

@@ -1,6 +1,7 @@
#include "llama-context.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
#include "llama-memory.h"
#include "llama-mmap.h"
@@ -18,7 +19,8 @@
llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
model(model) {
model(model),
batch_allocr(std::make_unique<llama_batch_allocr>()) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
t_start_us = model.t_start_us;
@@ -494,7 +496,7 @@ float * llama_context::get_logits() {
}
float * llama_context::get_logits_ith(int32_t i) {
int32_t j = -1;
int64_t j = -1;
try {
if (logits == nullptr) {
@@ -517,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) {
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
return logits + j*model.vocab.n_tokens();
@@ -536,7 +538,7 @@ float * llama_context::get_embeddings() {
}
float * llama_context::get_embeddings_ith(int32_t i) {
int32_t j = -1;
int64_t j = -1;
try {
if (embd == nullptr) {
@@ -559,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs));
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
return embd + j*model.hparams.n_embd;
@@ -719,40 +721,27 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
return res;
}
int llama_context::encode(llama_batch & inp_batch) {
if (inp_batch.n_tokens == 0) {
int llama_context::encode(const llama_batch & batch_inp) {
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporary allocate memory for the input batch if needed
// note: during encode, we always pass the full sequence starting from pos = 0
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : 0)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens;
const llama_batch & batch = batch_allocr->get_batch();
const auto & hparams = model.hparams;
const uint32_t n_tokens = batch.n_tokens;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
// TODO: move the validation to the llama_batch_allocr
if (batch.token) {
for (int32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
}
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d] = %d > %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
throw -1;
}
}
}
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
@@ -763,6 +752,8 @@ int llama_context::encode(llama_batch & inp_batch) {
n_queued_tokens += n_tokens;
const auto & hparams = model.hparams;
const int64_t n_embd = hparams.n_embd;
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
@@ -775,7 +766,7 @@ int llama_context::encode(llama_batch & inp_batch) {
return -2;
};
for (int32_t i = 0; i < n_tokens; ++i) {
for (uint32_t i = 0; i < n_tokens; ++i) {
output_ids[i] = i;
}
@@ -831,7 +822,8 @@ int llama_context::encode(llama_batch & inp_batch) {
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
for (int32_t i = 0; i < n_tokens; i++) {
// TODO: fix indexing [UBATCH_IDX]
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
@@ -846,6 +838,7 @@ int llama_context::encode(llama_batch & inp_batch) {
auto & embd_seq_out = embd_seq;
const uint32_t n_cls_out = hparams.n_cls_out;
// TODO: fix indexing [UBATCH_IDX]
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
@@ -878,13 +871,11 @@ int llama_context::encode(llama_batch & inp_batch) {
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
// remember the sequence ids used during the encoding - needed for cross attention later
// TODO: the seuqence indexing here is likely not correct in the general case
// probably works only for split_simple
cross.seq_ids_enc.resize(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
for (uint32_t i = 0; i < n_tokens; i++) {
cross.seq_ids_enc[i].clear();
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
llama_seq_id seq_id = ubatch.seq_id[i][s];
for (int s = 0; s < batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = batch.seq_id[i][s];
cross.seq_ids_enc[i].insert(seq_id);
}
}
@@ -893,68 +884,44 @@ int llama_context::encode(llama_batch & inp_batch) {
return 0;
}
int llama_context::decode(llama_batch & inp_batch) {
int llama_context::decode(const llama_batch & batch_inp) {
if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
return encode(inp_batch);
return encode(batch_inp);
}
if (inp_batch.n_tokens == 0) {
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
if (!inp_batch.pos) {
if (inp_batch.seq_id) {
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
return -1;
}
// temporary allocate memory for the input batch if needed
if (!batch_allocr->init(batch_inp, model.vocab, batch_inp.pos ? -1 : memory->seq_pos_max(0) + 1)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
// temporary allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
const llama_batch & batch = batch_allocr.batch;
const llama_batch & batch = batch_allocr->get_batch();
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
const uint32_t n_tokens_all = batch.n_tokens;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
// TODO: move the validation to the llama_batch_allocr
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
return -1;
}
if (batch.seq_id && (batch.seq_id[i][0] < 0 || batch.seq_id[i][0] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%" PRId64 "] = %d >= %d\n", __func__, i, batch.seq_id[i][0], LLAMA_MAX_PARALLEL_SEQUENCES);
return -1;
}
}
}
// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
int64_t n_outputs_all = 0;
// count outputs
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
if (embd_pooled) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
__func__, n_outputs_all, n_tokens_all);
return -1;
}
@@ -1024,7 +991,7 @@ int llama_context::decode(llama_batch & inp_batch) {
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
return -2;
};
@@ -1063,6 +1030,7 @@ int llama_context::decode(llama_batch & inp_batch) {
pos_min[s] = std::numeric_limits<llama_pos>::max();
}
// TODO: fix sequence indexing
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
const auto & seq_id = ubatch.seq_id[i][0];
@@ -1176,14 +1144,14 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs = n_outputs_all;
// set output mappings
{
if (n_outputs > 0) {
bool sorted_output = true;
auto & out_ids = mstate->out_ids();
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
for (int64_t i = 0; i < n_outputs_all; ++i) {
for (int64_t i = 0; i < n_outputs; ++i) {
int64_t out_id = out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
@@ -1195,20 +1163,22 @@ int llama_context::decode(llama_batch & inp_batch) {
// note: this is mostly relevant for recurrent models atm
if (!sorted_output) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
const uint64_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
for (uint32_t i = 0; i < n_outputs - 1; ++i) {
uint32_t j_min = i;
for (uint32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
if (j_min == i) {
continue;
}
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
@@ -1221,8 +1191,10 @@ int llama_context::decode(llama_batch & inp_batch) {
}
}
}
std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {
for (uint32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
}
@@ -1242,7 +1214,7 @@ int llama_context::decode(llama_batch & inp_batch) {
// output
//
int32_t llama_context::output_reserve(int32_t n_outputs) {
uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
@@ -1308,8 +1280,7 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
// set all ids as invalid (negative)
std::fill(output_ids.begin(), output_ids.end(), -1);
this->n_outputs = 0;
this->n_outputs_max = n_outputs_max;
this->n_outputs = 0;
return n_outputs_max;
}
@@ -1800,14 +1771,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
std::vector<int32_t> w_output_pos;
GGML_ASSERT(n_outputs <= n_outputs_max);
w_output_pos.resize(n_outputs);
// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch(); ++i) {
// map an output id to a position in the batch
int32_t pos = output_ids[i];
int64_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
@@ -2082,7 +2051,7 @@ void llama_context::opt_epoch_iter(
embd_seq.clear();
int64_t n_outputs_all = n_tokens_all;
uint32_t n_outputs_all = n_tokens_all;
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
@@ -2092,7 +2061,7 @@ void llama_context::opt_epoch_iter(
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
GGML_ABORT("TODO: handle this error");
};