context : remove batch_manager

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-14 16:10:55 +02:00
parent 131743ff4f
commit d5e8e1a2ba
4 changed files with 242 additions and 291 deletions

View File

@ -161,7 +161,7 @@ llama_context::llama_context(
// graph outputs buffer
{
// resized during inference when a batch uses more outputs
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
throw std::runtime_error("failed to reserve initial output buffer");
}
@ -747,11 +747,11 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
);
}
size_t llama_context::output_reserve(size_t n_outputs) {
int32_t llama_context::output_reserve(int32_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, cparams.n_seq_max);
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
@ -817,7 +817,7 @@ size_t llama_context::output_reserve(size_t n_outputs) {
}
void llama_context::output_reorder() {
std::vector<size_t> & out_ids = sbatch.out_ids;
auto & out_ids = sbatch.out_ids;
if (!out_ids.empty()) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
@ -1320,8 +1320,8 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
{
output_reorder();
const uint32_t n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
std::vector<int32_t> w_output_pos;
@ -1334,7 +1334,7 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
// map an output id to a position in the batch
int32_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT((uint32_t) pos < n_outputs);
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
}
}
@ -1386,15 +1386,15 @@ size_t llama_context::state_set_data(llama_io_read_i & io) {
// read output ids
{
std::vector<int32_t> output_pos;
uint32_t n_outputs;
auto n_outputs = this->n_outputs;
io.read_to(&n_outputs, sizeof(n_outputs));
if (n_outputs > output_reserve(n_outputs)) {
throw std::runtime_error("could not reserve outputs");
}
std::vector<int32_t> output_pos;
if (n_outputs) {
output_pos.resize(n_outputs);
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
@ -1543,228 +1543,6 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
return llama_context::graph_init();
}
struct llama_context_kv_self::batch_manager {
batch_manager(llama_context_kv_self & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
const auto & model = lctx.model;
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
const auto & kv_self = lctx.kv_self;
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
}
}
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
lctx.n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
lctx.embd_seq.clear();
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (lctx.logits_all || embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}
const bool logits_all = n_outputs_all == n_tokens_all;
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ logits_all);
}
~batch_manager() {
}
bool is_done() const {
return lctx.sbatch.n_tokens == 0;
}
llama_ubatch next() {
llama_ubatch ubatch = llama_ubatch();
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
const auto & n_ubatch = cparams.n_ubatch;
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
}
return ubatch;
}
bool prepare(const llama_ubatch & ubatch) {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
const auto & batch = lctx.sbatch.batch;
const auto n_tokens_all = batch->n_tokens;
auto & kv_self = lctx.kv_self;
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
lctx.kv_self_update();
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}
const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
return false;
}
kv_slot_restorer.save(slot_info);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
// reserve a worst case graph if needed
if (lctx.need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
const auto & cparams = lctx.cparams;
const auto & model = lctx.model;
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = lctx.build_graph(ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched.get());
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
lctx.need_reserve = false;
}
return true;
}
void restore() {
kv_slot_restorer.restore(lctx.kv_self);
}
void update(const llama_ubatch & ubatch) {
auto & kv_self = lctx.kv_self;
// update the kv ring buffer
{
kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}
}
void finalize() {
const auto & cparams = lctx.cparams;
auto & kv_self = lctx.kv_self;
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + lctx.get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
kv_self.defrag();
}
}
}
int64_t n_outputs_all = 0;
llama_context_kv_self & lctx;
const llama_batch & batch;
llama_kv_slot_restorer kv_slot_restorer;
};
std::unique_ptr<llama_context_kv_self::batch_manager> llama_context_kv_self::prepare_batch(const llama_batch & batch) {
return std::make_unique<batch_manager>(*this, batch);
}
int llama_context_kv_self::decode(llama_batch & inp_batch) {
is_encoding = false;
@ -1783,29 +1561,179 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
// TODO: try catch
auto bman = prepare_batch(batch);
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
const auto n_outputs_all = bman->n_outputs_all;
// TODO: remove this stuff
class batch_guard {
public:
batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
}
~batch_guard() {
if (!is_done) {
kv_slot_restorer.restore();
}
}
void done() {
is_done = true;
}
void save(const llama_kv_cache_slot_info & slot_info) {
kv_slot_restorer.save(slot_info);
}
private:
bool is_done = false;
llama_kv_slot_restorer kv_slot_restorer;
};
batch_guard bg(kv_self);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
}
}
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
}
n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
embd_seq.clear();
int64_t n_outputs_all = 0;
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (logits_all || embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}
const bool logits_all = n_outputs_all == n_tokens_all;
sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ logits_all);
// reserve output buffer
// TODO: move to batch manager?
if (output_reserve(bman->n_outputs_all) < (size_t) n_outputs_all) {
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
return -2;
};
int64_t n_outputs_prev = 0;
while (!bman->is_done()) {
llama_ubatch ubatch = bman->next();
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch = llama_ubatch();
if (!bman->prepare(ubatch)) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
bman->restore();
return -3;
const auto & n_ubatch = cparams.n_ubatch;
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = sbatch.split_equal(n_ubatch);
}
} else {
ubatch = sbatch.split_simple(n_ubatch);
}
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
kv_self_update();
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}
const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
return -3;
}
bg.save(slot_info);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
// reserve a worst case graph if needed
if (need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = build_graph(ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(sched.get());
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
need_reserve = false;
}
ggml_backend_sched_reset(sched.get());
@ -1844,7 +1772,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
bman->restore();
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
@ -1856,7 +1783,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
}
}
bman->update(ubatch);
// update the kv ring buffer
{
kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}
// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
@ -1936,14 +1871,17 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
n_outputs_prev += n_outputs;
}
// finalize the batch processing
bg.done();
// set output mappings
{
bool sorted_output = true;
GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
size_t out_id = sbatch.out_ids[i];
for (int64_t i = 0; i < n_outputs_all; ++i) {
int64_t out_id = sbatch.out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
sorted_output = false;
@ -1961,7 +1899,19 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();
bman->finalize();
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
kv_self.defrag();
}
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
@ -1983,14 +1933,14 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens = batch.n_tokens;
const int32_t n_tokens = batch.n_tokens;
const auto & hparams = model.hparams;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens; ++i) {
for (int32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
@ -1999,7 +1949,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
}
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
@ -2019,7 +1969,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
return -2;
};
for (uint32_t i = 0; i < n_tokens; ++i) {
for (int32_t i = 0; i < n_tokens; ++i) {
output_ids[i] = i;
}
@ -2087,7 +2037,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
// remember the sequence ids used during the encoding - needed for cross attention later
seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
for (int32_t i = 0; i < n_tokens; i++) {
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
llama_seq_id seq_id = ubatch.seq_id[i][s];
seq_ids_enc[i].insert(seq_id);
@ -2116,7 +2066,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
for (uint32_t i = 0; i < n_tokens; i++) {
for (int32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
@ -2448,7 +2398,7 @@ void llama_context_kv_self::kv_self_update() {
ggml_backend_sched_reset(sched.get());
auto ctx = graph_init();
auto ctx0 = ctx.get();
auto * ctx0 = ctx.get();
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@ -2477,7 +2427,7 @@ void llama_context_kv_self::kv_self_update() {
ggml_backend_sched_reset(sched.get());
auto ctx = graph_init();
auto ctx0 = ctx.get();
auto * ctx0 = ctx.get();
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);