context : introduce llama_batch_manager

ggml-ci
Georgi Gerganov
2025-01-17 20:30:16 +02:00
parent cb8f2095c6
commit 99422dfa3f
3 changed files with 162 additions and 73 deletions


@@ -32,6 +32,132 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
    return relative_bucket;
}
struct llama_batch_manager : public llama_batch_manager_i {
    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
        const auto & hparams = lctx.model.hparams;
        const auto & n_embd  = hparams.n_embd;
        const auto & kv_self = lctx.kv_self;

        // prepare the sbatch from the input batch
        lctx.sbatch.from_batch(batch, n_embd,
                /* simple_split */ !kv_self.recurrent,
                /* logits_all   */ logits_all);
    }

    ~llama_batch_manager() override {
    }

    // split off the next ubatch to process from the sbatch
    virtual llama_ubatch next() override {
        ubatch = llama_ubatch();

        const auto & cparams = lctx.cparams;
        const auto & kv_self = lctx.kv_self;

        const auto & n_ubatch = cparams.n_ubatch;

        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

        if (kv_self.recurrent) {
            if (embd_pooled) {
                // Pooled embeddings cannot be split across ubatches (yet)
                ubatch = lctx.sbatch.split_seq(n_ubatch);
            } else {
                // recurrent model architectures are easier to implement
                // with equal-length sequences
                ubatch = lctx.sbatch.split_equal(n_ubatch);
            }
        } else {
            ubatch = lctx.sbatch.split_simple(n_ubatch);
        }

        return ubatch;
    }

    // find a KV cache slot for the current ubatch and remember it so it can be restored on failure
    virtual bool prepare() override {
        const auto & cparams = lctx.cparams;
        const auto & hparams = lctx.model.hparams;

        auto & kv_self = lctx.kv_self;

        // non-causal masks do not use the KV cache
        if (hparams.causal_attn) {
            llama_kv_self_update(&lctx);

            // if we have enough unused cells before the current head ->
            //   better to start searching from the beginning of the cache, hoping to fill it
            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
                kv_self.head = 0;
            }

            const auto slot_info = kv_self.find_slot(ubatch);
            if (!slot_info) {
                return false;
            }

            kv_slot_restorer.save(slot_info);

            if (!kv_self.recurrent) {
                // a heuristic, to avoid attending the full cache if it is not yet utilized
                // after enough generations, the benefit from this heuristic disappears
                // if we start defragmenting the cache, the benefit from this will be more important
                const uint32_t pad = kv_self.get_padding(cparams);
                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
                //kv_self.n = llama_kv_cache_cell_max(kv_self);
            }
        }

        return true;
    }

    // revert the KV cache to the state saved in prepare()
    virtual void restore() override {
        kv_slot_restorer.restore(lctx.kv_self);
    }

    // advance the KV cache head past the processed ubatch
    virtual void update() override {
        auto & kv_self = lctx.kv_self;

        // update the kv ring buffer
        {
            kv_self.head += ubatch.n_tokens;

            // Ensure kv cache head points to a valid index.
            if (kv_self.head >= kv_self.size) {
                kv_self.head = 0;
            }
        }
    }

    virtual void finalize() override {
        const auto & cparams = lctx.cparams;

        auto & kv_self = lctx.kv_self;

        // decide if we need to defrag the kv cache
        if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
            const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;

            // queue defragmentation for next llama_kv_cache_update
            if (fragmentation > cparams.defrag_thold) {
                //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
                kv_self.defrag();
            }
        }
    }

    llama_context & lctx;

    const llama_batch & batch;

    llama_ubatch ubatch;

    llama_kv_slot_restorer kv_slot_restorer;
};
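Side note on the kv_self.n expression in prepare(): it limits how much of the cache is attended to by rounding the highest used cell up to a padding multiple. The standalone snippet below illustrates only that arithmetic, assuming GGML_PAD is the usual round-up-to-a-multiple macro from ggml.h and using made-up values for cell_max and the padding.

// worked example of the kv_self.n heuristic from llama_batch_manager::prepare()
// (standalone sketch; the cell_max and pad values are hypothetical)
#include <algorithm>
#include <cstdint>
#include <cstdio>

// round x up to a multiple of n (GGML_PAD as commonly defined in ggml.h)
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const uint32_t size     = 4096; // total number of KV cache cells
    const uint32_t cell_max = 70;   // highest used cell index (hypothetical)
    const uint32_t pad      = 32;   // value returned by kv_self.get_padding() (hypothetical)

    // same expression as in prepare(): attend to at most `n` cells
    const uint32_t n = std::min(size, std::max(pad, (uint32_t) GGML_PAD(cell_max, pad)));

    printf("attend to %u of %u cells\n", n, size); // prints: attend to 96 of 4096 cells
}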
std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
}
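For context, the new interface is meant to be driven from the decode path roughly as follows. This is an illustrative sketch only: the loop condition, error codes, and the graph build/compute step are simplified placeholders, not the code from this commit.

// illustrative sketch of a decode loop driving llama_batch_manager_i
// (the graph build/compute step is elided; error handling is simplified)
static int decode_with_batch_manager(llama_context & lctx, const llama_batch & batch) {
    auto bman = lctx.prepare_batch(batch, /* logits_all */ false);

    while (lctx.sbatch.n_tokens > 0) {
        // split off the next ubatch from the sbatch
        const llama_ubatch ubatch = bman->next();
        (void) ubatch; // would be passed to the graph build/compute step

        // reserve a KV cache slot for this ubatch
        if (!bman->prepare()) {
            bman->restore();
            return 1;
        }

        // ... build the graph, set inputs and compute here ...
        const bool compute_ok = true; // placeholder for the compute_graph() result

        if (!compute_ok) {
            // roll back the KV cache changes made by prepare()
            bman->restore();
            return 2;
        }

        // advance the KV cache head past the processed tokens
        bman->update();
    }

    // possibly queue KV cache defragmentation
    bman->finalize();

    return 0;
}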
enum ggml_status llama_context::compute_graph(
        ggml_cgraph * graph,
        bool batched) {
@@ -59,7 +185,6 @@ enum ggml_status llama_context::compute_graph(
    return status;
}

llama_pos llama_context::pos_max() const {
    return kv_self.pos_max();
}
@@ -94,9 +219,6 @@ void llama_context::prepare_k_shift() {

void llama_context::prepare_defrag() {
}

void llama_context::prepare_decode(const llama_ubatch & /*ubatch*/) {
}

// llama input
void llama_context::set_inputs(const llama_ubatch & ubatch) {