context : introduce llama_batch_manager
ggml-ci
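The concrete manager added below implements the `llama_batch_manager_i` interface, whose declaration is not part of this hunk. Judging from the overrides and the `prepare_batch()` factory shown here, it presumably looks something like the following sketch (method names and return types are taken from the diff; the exact declaration elsewhere in the tree may differ):

    // Hypothetical sketch of llama_batch_manager_i, inferred from the overrides
    // in this commit; the actual declaration lives outside this hunk.
    struct llama_batch_manager_i {
        virtual ~llama_batch_manager_i() = default;

        virtual llama_ubatch next()     = 0; // split the next ubatch off the sbatch
        virtual bool         prepare()  = 0; // reserve a KV cache slot for the current ubatch
        virtual void         restore()  = 0; // roll back KV cache changes after a failure
        virtual void         update()   = 0; // advance KV cache state after a successful decode
        virtual void         finalize() = 0; // post-batch bookkeeping (e.g. queue defragmentation)
    };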
@@ -32,6 +32,132 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
    return relative_bucket;
}

struct llama_batch_manager : public llama_batch_manager_i {
    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
        const auto & hparams = lctx.model.hparams;

        const auto & n_embd = hparams.n_embd;

        const auto & kv_self = lctx.kv_self;

        lctx.sbatch.from_batch(batch, n_embd,
                /* simple_split */ !kv_self.recurrent,
                /* logits_all   */ logits_all);
    }

    ~llama_batch_manager() override {
    }

    virtual llama_ubatch next() override {
        ubatch = llama_ubatch();

        const auto & cparams = lctx.cparams;
        const auto & kv_self = lctx.kv_self;

        const auto & n_ubatch = cparams.n_ubatch;

        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

        if (kv_self.recurrent) {
            if (embd_pooled) {
                // Pooled embeddings cannot be split across ubatches (yet)
                ubatch = lctx.sbatch.split_seq(n_ubatch);
            } else {
                // recurrent model architectures are easier to implement
                // with equal-length sequences
                ubatch = lctx.sbatch.split_equal(n_ubatch);
            }
        } else {
            ubatch = lctx.sbatch.split_simple(n_ubatch);
        }

        return ubatch;
    }

    virtual bool prepare() override {
        const auto & cparams = lctx.cparams;
        const auto & hparams = lctx.model.hparams;

        auto & kv_self = lctx.kv_self;

        // non-causal masks do not use the KV cache
        if (hparams.causal_attn) {
            llama_kv_self_update(&lctx);

            // if we have enough unused cells before the current head ->
            //   better to start searching from the beginning of the cache, hoping to fill it
            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
                kv_self.head = 0;
            }

            const auto slot_info = kv_self.find_slot(ubatch);
            if (!slot_info) {
                return false;
            }

            kv_slot_restorer.save(slot_info);

            if (!kv_self.recurrent) {
                // a heuristic, to avoid attending the full cache if it is not yet utilized
                // after enough generations, the benefit from this heuristic disappears
                // if we start defragmenting the cache, the benefit from this will be more important
                const uint32_t pad = kv_self.get_padding(cparams);
                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
                //kv_self.n = llama_kv_cache_cell_max(kv_self);
            }
        }

        return true;
    }

    virtual void restore() override {
        kv_slot_restorer.restore(lctx.kv_self);
    }

    virtual void update() override {
        auto & kv_self = lctx.kv_self;

        // update the kv ring buffer
        {
            kv_self.head += ubatch.n_tokens;

            // Ensure kv cache head points to a valid index.
            if (kv_self.head >= kv_self.size) {
                kv_self.head = 0;
            }
        }
    }

    virtual void finalize() override {
        const auto & cparams = lctx.cparams;

        auto & kv_self = lctx.kv_self;

        // decide if we need to defrag the kv cache
        if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
            const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;

            // queue defragmentation for next llama_kv_cache_update
            if (fragmentation > cparams.defrag_thold) {
                //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);

                kv_self.defrag();
            }
        }
    }

    llama_context & lctx;

    const llama_batch & batch;

    llama_ubatch ubatch;

    llama_kv_slot_restorer kv_slot_restorer;
};

std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
}
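A minimal sketch of how a decode loop might drive the new manager, assuming a `llama_context` named `ctx`, a populated `llama_batch` named `batch`, and that `sbatch.n_tokens` counts the tokens still to be split (the loop below is illustrative, not the committed decode path):

    // Illustrative usage only -- not the actual llama_decode implementation.
    auto bman = ctx.prepare_batch(batch, /* logits_all */ false);

    while (ctx.sbatch.n_tokens > 0) {
        llama_ubatch ubatch = bman->next(); // carve out the next micro-batch

        if (!bman->prepare()) {             // find/reserve a KV cache slot
            bman->restore();                // undo any partial KV cache changes
            break;                          // let the caller report the failure
        }

        // ... build and compute the graph for `ubatch` here ...

        bman->update();                     // advance the KV cache ring buffer head
    }

    bman->finalize();                       // e.g. queue KV cache defragmentation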

enum ggml_status llama_context::compute_graph(
            ggml_cgraph * graph,
                   bool   batched) {
@@ -59,7 +185,6 @@ enum ggml_status llama_context::compute_graph(
    return status;
}

llama_pos llama_context::pos_max() const {
    return kv_self.pos_max();
}
@@ -94,9 +219,6 @@ void llama_context::prepare_k_shift() {
void llama_context::prepare_defrag() {
}

void llama_context::prepare_decode(const llama_ubatch & /*ubatch*/) {
}

// llama input

void llama_context::set_inputs(const llama_ubatch & ubatch) {