kv-cache : separate recurrent vs non-recurrent impl (#12799)

* kv-cache : serparate recurrent vs non-recurrent impl (wip)

ggml-ci

* kv-cache : init -> contructor + add llama_memory_params

ggml-ci

* kv-cache : fix callback reference

ggml-ci

* context : llama_kv_cache -> llama_memory_i

ggml-ci

* context : move memory creation logic to model

ggml-ci

* llama : remove reference of memory during encode

ggml-ci

* kv-cache : hide padding details in the implementation

ggml-ci

* kv-cache : add ubatch_next()

ggml-ci

* context : simplify sbatch logic

ggml-ci

* kv-cache : hide defrag logic in the implementation

ggml-ci

* context : hide kv cache details in implementation

ggml-ci

* build : fix

ggml-ci

* cont : another fix

ggml-ci

* kv-cache : simplify interface (wip)

ggml-ci

* kv-cache : use separate KV cell structs for unified/recurrent

ggml-ci

* kv-cache : clean-up

ggml-ci

* model : better llama_model::create_model() signature

ggml-ci

* kv-cache : fix recurrent seq_rm()

ggml-ci

* kv-cache : replace `struct callbacks` with `llama_model &`

ggml-ci

* kv-cache : replace `struct graph_params` with `llama_context &`

ggml-ci

* kv-cache : fix offload check

ggml-ci

* context : avoid passing unique_ptr

ggml-ci

* kv-cache : avoid using the backends from the llama_context

ref #13113

ggml-ci

* kv-cache : more consistent debug logs [no ci]

* kv-cache : do not pass the full llama_context for kv graphs

ggml-ci

* kv-cache : remove comment

* kv-cache : ggml_rope_ext_inplace -> ggml_rope_ext

ggml-ci

* kv-cache : fix recurrent multi-user case

ggml-ci

* memory : remove comments [no ci]
This commit is contained in:
Georgi Gerganov
2025-05-02 17:48:36 +03:00
committed by GitHub
parent cb06a3c363
commit c642bc014c
11 changed files with 1960 additions and 1048 deletions

View File

@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
return ubatch;
}
void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch;
this->n_embd = n_embd;
@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
for (size_t i = 0; i < n_tokens; ++i) {
ids[i] = i;
}
if (simple_split) {
seq.resize(1);
llama_sbatch_seq & s = seq[0];
@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
s.length = n_tokens;
return;
}
std::sort(ids.begin(), ids.end(),
[&batch](size_t a, size_t b) {
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
return n_seq_a > n_seq_b;
}
);
// init seq
llama_sbatch_seq * last_seq = nullptr;
@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim
seq.push_back(new_seq);
last_seq = &seq.back();
}
// keep shared prompts first at the end, then sort by length descending.
std::sort(seq.begin(), seq.end(),
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {

View File

@ -70,7 +70,8 @@ struct llama_sbatch {
// sequence-wise split
llama_ubatch split_seq(size_t n_ubatch);
void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
llama_sbatch() = default;
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
};
// temporary allocate memory for the input batch if needed

View File

@ -6,11 +6,9 @@
#include "llama-model.h"
#include "llama-kv-cache.h"
#include <cassert>
#include <cstring>
#include <stdexcept>
#include <cinttypes>
#include <cmath>
//
// llama_context
@ -177,44 +175,13 @@ llama_context::llama_context(
}
// init the memory module
// TODO: for now, always create a unified KV cache
if (!hparams.vocab_only) {
kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
llama_memory_params params_mem = {
/*.type_k =*/ params.type_k,
/*.type_v =*/ params.type_v,
};
LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
uint32_t kv_size = cparams.n_ctx;
ggml_type type_k = params.type_k;
ggml_type type_v = params.type_v;
if (llama_model_is_recurrent(&model)) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
}
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
throw std::runtime_error("failed to initialize self-attention cache");
}
{
const size_t memory_size_k = kv_self->size_k_bytes();
const size_t memory_size_v = kv_self->size_v_bytes();
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
memory.reset(model.create_memory(params_mem, cparams));
}
// init backends
@ -305,7 +272,9 @@ llama_context::llama_context(
int n_nodes_tg = -1;
// simulate full KV cache
kv_self->n = kv_self->size;
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->set_full();
cross.v_embd.clear();
@ -427,6 +396,18 @@ const llama_model & llama_context::get_model() const {
return model;
}
const llama_cparams & llama_context::get_cparams() const {
return cparams;
}
ggml_backend_sched_t llama_context::get_sched() const {
return sched.get();
}
ggml_context * llama_context::get_ctx_compute() const {
return ctx_compute.get();
}
uint32_t llama_context::n_ctx() const {
return cparams.n_ctx;
}
@ -456,337 +437,21 @@ uint32_t llama_context::n_threads_batch() const {
}
llama_kv_cache * llama_context::get_kv_self() {
return kv_self.get();
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
return kv_self;
}
const llama_kv_cache * llama_context::get_kv_self() const {
return kv_self.get();
}
ggml_tensor * llama_context::build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const {
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
const auto & hparams = model.hparams;
const auto & n_rot = hparams.n_rot;
const auto & rope_type = hparams.rope_type;
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
ggml_tensor * tmp;
if (ggml_is_quantized(cur->type)) {
// dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
tmp = ggml_rope_ext(ctx0, tmp,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
tmp = ggml_cpy(ctx0, tmp, cur);
} else {
// we rotate only the first n_rot dimensions
tmp = ggml_rope_ext_inplace(ctx0, cur,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
}
return tmp;
}
class llm_graph_input_k_shift : public llm_graph_input_i {
public:
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_k_shift() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * k_shift; // I32 [kv_size]
const llama_kv_cache_unified * kv_self;
};
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
if (k_shift) {
assert(ggml_backend_buffer_is_host(k_shift->buffer));
int32_t * data = (int32_t *) k_shift->data;
for (uint32_t i = 0; i < kv_self->size; ++i) {
data[i] = kv_self->cells[i].delta;
}
}
}
llm_graph_result_ptr llama_context::build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const {
auto res = std::make_unique<llm_graph_result>();
const auto & hparams = model.hparams;
const auto & n_layer = hparams.n_layer;
const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
//GGML_ASSERT(kv_self->size == n_ctx);
auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
ggml_set_input(inp->k_shift);
for (uint32_t il = 0; il < n_layer; ++il) {
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const bool is_swa = hparams.is_swa(il);
// note: the swa rope params could become part of the cparams in the future
// if we decide to make them configurable, like the non-sliding ones
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
ggml_tensor * k =
ggml_view_3d(ctx0, kv_self->k_l[il],
n_embd_head_k, n_head_kv, kv_self->size,
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
0);
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
ggml_build_forward_expand(gf, cur);
}
res->add_input(std::move(inp));
return res;
}
llm_graph_result_ptr llama_context::build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf) const {
auto res = std::make_unique<llm_graph_result>();
const auto & hparams = model.hparams;
const auto & ids = kv_self->defrag_info.ids;
#if 0
// CPU defrag
//
// TODO: optimizations are possible:
// - multiple threads
// - avoid copying to the host memory when already there
//
// likely not worth the effort, as we have ggml_graph based defrag
//
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t kv_size = size;
std::vector<uint8_t> buf_k;
std::vector<uint8_t> buf_v;
for (uint32_t il = 0; il < n_layer; ++il) {
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
const size_t v_size_el = ggml_type_size(v_l[il]->type);
const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
buf_k.resize(k_size);
buf_v.resize(v_size);
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
// batch move [i, i+nm) to [id, id+nm)
// note: cells can move only to a lower index
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t id = ids[i];
if (i == id || id == n_kv) {
continue;
}
uint32_t nm = 1;
while (i + nm < n_kv && ids[i + nm] == id + nm) {
nm++;
}
// move keys
{
const int64_t os = i*k_size_row;
const int64_t od = id*k_size_row;
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
}
// move values (note: they are transposed)
{
const int64_t os = i;
const int64_t od = id;
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
}
}
i += nm - 1;
}
ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
}
#else
for (uint32_t i = 0; i < ids.size(); ++i) {
const uint32_t id = ids[i];
if (i == id || id == ids.size()) {
continue;
}
uint32_t nm = 1;
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
nm++;
}
for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
n_embd_k_gqa, nm,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
n_embd_k_gqa, nm,
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
ggml_tensor * view_v_src;
ggml_tensor * view_v_dst;
if (cparams.flash_attn) {
// NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
n_embd_v_gqa, nm,
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
} else {
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self->v_l[il]->type, i));
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
nm, n_embd_v_gqa,
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
ggml_row_size(kv_self->v_l[il]->type, id));
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
}
i += nm - 1;
}
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
#endif
return res;
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
return kv_self;
}
void llama_context::kv_self_update() {
auto & kv = kv_self;
bool need_reserve = false;
if (kv->has_shift) {
if (!kv->get_can_shift()) {
GGML_ABORT("The current context does not support K-shift");
}
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
// apply K-shift if needed
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
ggml_backend_sched_reset(sched.get());
auto * gf = graph_init();
auto res = build_kv_self_shift(ctx_compute.get(), gf);
ggml_backend_sched_alloc_graph(sched.get(), gf);
res->set_inputs(nullptr);
graph_compute(gf, false);
need_reserve = true;
}
{
kv->has_shift = false;
for (uint32_t i = 0; i < kv->size; ++i) {
kv->cells[i].delta = 0;
}
}
}
// defragment the KV cache if needed
if (kv->do_defrag) {
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
if (kv->defrag_prepare(graph_max_nodes())) {
ggml_backend_sched_reset(sched.get());
auto * gf = graph_init();
auto res = build_kv_self_defrag(ctx_compute.get(), gf);
ggml_backend_sched_alloc_graph(sched.get(), gf);
res->set_inputs(nullptr);
graph_compute(gf, false);
need_reserve = true;
}
kv->do_defrag = false;
}
need_reserve = kv_self->update(*this);
// reserve a worst case graph if needed
if (need_reserve) {
@ -797,7 +462,7 @@ void llama_context::kv_self_update() {
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// simulate full KV cache
kv_self->n = kv_self->size;
kv_self->set_full();
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
@ -818,9 +483,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
}
float * llama_context::get_logits() {
// reorder logits for backward compatibility
output_reorder();
return logits;
}
@ -863,9 +525,6 @@ float * llama_context::get_logits_ith(int32_t i) {
}
float * llama_context::get_embeddings() {
// reorder embeddings for backward compatibility
output_reorder();
return embd;
}
@ -1017,8 +676,8 @@ int llama_context::encode(llama_batch & inp_batch) {
}
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
// note: during encode, we always pass the full sequence starting from pos = 0
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens;
@ -1047,7 +706,7 @@ int llama_context::encode(llama_batch & inp_batch) {
const int64_t n_embd = hparams.n_embd;
sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
@ -1181,9 +840,11 @@ int llama_context::decode(llama_batch & inp_batch) {
return -1;
}
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
@ -1195,7 +856,7 @@ int llama_context::decode(llama_batch & inp_batch) {
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
llama_kv_cache_guard kv_guard(kv_self.get());
llama_kv_cache_guard kv_guard(kv_self);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
@ -1236,11 +897,7 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs_all = 1;
}
const bool logits_all = n_outputs_all == n_tokens_all;
sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self->recurrent,
/* logits_all */ logits_all);
llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
@ -1254,22 +911,7 @@ int llama_context::decode(llama_batch & inp_batch) {
int64_t n_outputs_prev = 0;
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch = llama_ubatch();
const auto & n_ubatch = cparams.n_ubatch;
if (kv_self->recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(cparams.n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = sbatch.split_equal(cparams.n_ubatch);
}
} else {
ubatch = sbatch.split_simple(n_ubatch);
}
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
// count the outputs in this u_batch
{
@ -1289,24 +931,12 @@ int llama_context::decode(llama_batch & inp_batch) {
}
// find KV slot
{
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
if (!kv_self->find_slot(ubatch)) {
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
return 1;
}
if (!kv_self->recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self->get_padding(cparams);
kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
}
return 1;
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
@ -1424,18 +1054,52 @@ int llama_context::decode(llama_batch & inp_batch) {
{
bool sorted_output = true;
GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
auto & out_ids = sbatch.out_ids;
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
for (int64_t i = 0; i < n_outputs_all; ++i) {
int64_t out_id = sbatch.out_ids[i];
int64_t out_id = out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
sorted_output = false;
}
}
if (sorted_output) {
sbatch.out_ids.clear();
// make the outputs have the same order they had in the user-provided batch
// note: this is mostly relevant for recurrent models atm
if (!sorted_output) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
}
}
std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
}
}
@ -1446,17 +1110,8 @@ int llama_context::decode(llama_batch & inp_batch) {
//synchronize();
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
kv_self->defrag();
}
if (cparams.defrag_thold > 0.0f) {
kv_self->defrag_sched(cparams.defrag_thold);
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@ -1542,44 +1197,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
return n_outputs_max;
}
void llama_context::output_reorder() {
auto & out_ids = sbatch.out_ids;
if (!out_ids.empty()) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
GGML_ASSERT((size_t) n_outputs == out_ids.size());
// TODO: is there something more efficient which also minimizes swaps?
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
for (int32_t i = 0; i < n_outputs - 1; ++i) {
int32_t j_min = i;
for (int32_t j = i + 1; j < n_outputs; ++j) {
if (out_ids[j] < out_ids[j_min]) {
j_min = j;
}
}
if (j_min == i) { continue; }
std::swap(out_ids[i], out_ids[j_min]);
if (logits_size > 0) {
for (uint32_t k = 0; k < n_vocab; k++) {
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
}
}
if (embd_size > 0) {
for (uint32_t k = 0; k < n_embd; k++) {
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
}
}
}
std::fill(output_ids.begin(), output_ids.end(), -1);
for (int32_t i = 0; i < n_outputs; ++i) {
output_ids[out_ids[i]] = i;
}
out_ids.clear();
}
}
//
// graph
//
@ -1616,7 +1233,7 @@ llm_graph_result_ptr llama_context::graph_build(
/*.backend_cpu =*/ backend_cpu,
/*.cvec =*/ &cvec,
/*.loras =*/ &loras,
/*.memory =*/ kv_self.get(),
/*.memory =*/ memory.get(),
/*.cross =*/ &cross,
/*.n_outputs =*/ n_outputs,
/*.cb =*/ graph_get_cb(),
@ -2020,8 +1637,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
{
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
output_reorder();
const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
@ -2075,6 +1690,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
}
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io);
return io.n_bytes();
@ -2159,6 +1776,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
}
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_read(io);
return io.n_bytes();
@ -2167,6 +1786,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
GGML_UNUSED(seq_id);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_write(io, seq_id);
return io.n_bytes();
@ -2175,6 +1796,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
GGML_UNUSED(seq_id);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
kv_self->state_read(io, seq_id);
return io.n_bytes();
@ -2530,7 +2153,7 @@ void llama_kv_cache_seq_cp(
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
}
void llama_kv_self_seq_cp(
@ -2544,14 +2167,14 @@ void llama_kv_self_seq_cp(
return;
}
return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
// deprecated
void llama_kv_cache_seq_keep(
llama_context * ctx,
llama_seq_id seq_id) {
return llama_kv_self_seq_keep(ctx, seq_id);
llama_kv_self_seq_keep(ctx, seq_id);
}
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@ -2560,7 +2183,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
return;
}
return kv->seq_keep(seq_id);
kv->seq_keep(seq_id);
}
// deprecated
@ -2570,7 +2193,7 @@ void llama_kv_cache_seq_add(
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
}
void llama_kv_self_seq_add(
@ -2584,7 +2207,7 @@ void llama_kv_self_seq_add(
return;
}
return kv->seq_add(seq_id, p0, p1, delta);
kv->seq_add(seq_id, p0, p1, delta);
}
// deprecated
@ -2594,7 +2217,7 @@ void llama_kv_cache_seq_div(
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
}
void llama_kv_self_seq_div(
@ -2608,7 +2231,7 @@ void llama_kv_self_seq_div(
return;
}
return kv->seq_div(seq_id, p0, p1, d);
kv->seq_div(seq_id, p0, p1, d);
}
// deprecated
@ -2627,7 +2250,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
// deprecated
void llama_kv_cache_defrag(llama_context * ctx) {
return llama_kv_self_defrag(ctx);
llama_kv_self_defrag(ctx);
}
void llama_kv_self_defrag(llama_context * ctx) {
@ -2636,7 +2259,8 @@ void llama_kv_self_defrag(llama_context * ctx) {
return;
}
return kv->defrag();
// force defrag
kv->defrag_sched(-1.0f);
}
// deprecated

View File

@ -27,7 +27,12 @@ struct llama_context {
void synchronize();
const llama_model & get_model() const;
const llama_model & get_model() const;
const llama_cparams & get_cparams() const;
ggml_backend_sched_t get_sched() const;
ggml_context * get_ctx_compute() const;
uint32_t n_ctx() const;
uint32_t n_ctx_per_seq() const;
@ -137,49 +142,30 @@ private:
// Returns max number of outputs for which space was reserved.
int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
void output_reorder();
//
// graph
//
public:
int32_t graph_max_nodes() const;
// zero-out inputs and create the ctx_compute for the compute graph
ggml_cgraph * graph_init();
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);
// returns the result of ggml_backend_sched_graph_compute_async execution
ggml_status graph_compute(
ggml_cgraph * gf,
bool batched);
private:
llm_graph_result_ptr graph_build(
ggml_context * ctx,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
llm_graph_type gtype);
llm_graph_cb graph_get_cb() const;
// used by kv_self_update()
ggml_tensor * build_rope_shift(
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_kv_self_shift(
ggml_context * ctx0,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_kv_self_defrag(
ggml_context * ctx0,
ggml_cgraph * gf) const;
// TODO: read/write lora adapters and cvec
size_t state_write_data(llama_io_write_i & io);
size_t state_read_data (llama_io_read_i & io);
@ -196,11 +182,10 @@ private:
llama_cparams cparams;
llama_adapter_cvec cvec;
llama_adapter_loras loras;
llama_sbatch sbatch;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
std::unique_ptr<llama_kv_cache_unified> kv_self;
std::unique_ptr<llama_memory_i> memory;
// TODO: remove
bool logits_all = false;

View File

@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self->head;
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
kv_cell.src = cell_id;
}
data[i] = kv_cell.src;
// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
data[i] = kv_self->s_copy(i);
}
}
}
@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
// clear unused states
for (int i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self->head;
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
data[i] = (float) (kv_cell.src >= 0);
// only clear once
if (kv_cell.src < 0) {
kv_cell.src = cell_id;
}
data[i] = kv_self->s_mask(i);
}
}
}
@ -1105,7 +1077,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
}
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
@ -1122,7 +1094,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
}
ggml_tensor * llm_graph_context::build_inp_s_mask() const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
@ -1436,8 +1408,6 @@ ggml_tensor * llm_graph_context::build_attn(
// store to KV cache
{
GGML_ASSERT(!kv_self->recurrent);
const auto kv_head = kv_self->head;
GGML_ASSERT(kv_self->size == n_ctx);
@ -1587,7 +1557,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_kv = kv_self->n;
const auto kv_head = kv_self->head;
@ -1619,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto token_shift_count = hparams.token_shift_count;
@ -1640,7 +1610,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;

View File

@ -19,6 +19,7 @@ struct llama_cparams;
class llama_memory_i;
class llama_kv_cache_unified;
class llama_kv_cache_recurrent;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@ -186,26 +187,26 @@ public:
class llm_graph_input_s_copy : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_copy() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
const llama_kv_cache_unified * kv_self;
const llama_kv_cache_recurrent * kv_self;
};
class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_unified * kv_self;
const llama_kv_cache_recurrent * kv_self;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
@ -350,8 +351,8 @@ struct llm_graph_params {
const llama_cparams & cparams;
const llama_ubatch & ubatch;
ggml_backend_sched * sched;
ggml_backend * backend_cpu;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;
@ -402,9 +403,9 @@ struct llm_graph_context {
ggml_context * ctx0 = nullptr;
ggml_backend_sched * sched;
ggml_backend_sched_t sched;
ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
const llama_adapter_cvec * cvec;
const llama_adapter_loras * loras;

File diff suppressed because it is too large Load Diff

View File

@ -2,32 +2,72 @@
#include "llama.h"
#include "llama-io.h"
#include "llama-graph.h"
#include "llama-memory.h"
#include "ggml-cpp.h"
#include <functional>
#include <set>
#include <vector>
struct llama_cparams;
struct llama_hparams;
struct llama_ubatch;
struct llama_sbatch;
struct llama_model;
struct llama_context;
struct llama_kv_cache : public llama_memory_i {
using llama_memory_i::llama_memory_i;
virtual ~llama_kv_cache() = default;
virtual void restore() = 0; // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state
// call if batch processing fails - restores the cache state
virtual void restore() = 0;
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
// call after successful batch processing - clears any pending state
virtual void commit() = 0;
virtual bool get_can_shift() const = 0;
// process any pending defrag/shift/etc. operations
// optionally call once before processing a new batch
virtual bool update(llama_context & lctx) = 0;
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
virtual void defrag_sched(float thold) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual void set_full() = 0;
//
// batch processing
//
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
// different KV caches require different batch splitting strategies
virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0;
// find an empty slot of size "n_tokens" in the cache
virtual bool find_slot(const llama_ubatch & batch) = 0;
// getters
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual llama_pos get_pos_max() const = 0;
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};
//
// llama_kv_cache_guard
//
struct llama_kv_cache_guard {
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
@ -43,65 +83,50 @@ private:
llama_kv_cache * kv;
};
struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
int32_t src = -1; // used by recurrent state models to copy states
int32_t tail = -1;
//
// llama_kv_cache_unified
//
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const llama_kv_cell & other) const {
return seq_id == other.seq_id;
}
};
// ring-buffer of cached KV data
// TODO: pimpl
// TODO: add notion of max sequences
class llama_kv_cache_unified : public llama_kv_cache {
public:
// can be used to query data from the model if needed
struct callbacks {
std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
struct kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
static uint32_t get_padding(const llama_cparams & cparams);
llama_kv_cache_unified(
const llama_hparams & hparams,
callbacks cbs);
virtual ~llama_kv_cache_unified() = default;
// TODO: become constructor
bool init(
const llama_model & model, // TODO: do not reference the model
const llama_cparams & cparams,
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
uint32_t kv_size,
bool offload);
uint32_t padding);
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
~llama_kv_cache_unified() = default;
size_t total_size() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos pos_max() const;
//
// llama_memory_i
//
void clear() override;
void defrag() override;
virtual void restore() override;
virtual void commit() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
@ -111,63 +136,40 @@ public:
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
bool get_can_shift() const override;
//
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & ctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
// find an empty slot of size "n_tokens" in the cache
// updates the cache head
// Note: On success, it's important that cache.head points
// to the first cell of the slot.
bool find_slot(const llama_ubatch & batch);
bool find_slot(const llama_ubatch & batch) override;
// TODO: maybe not needed
uint32_t get_padding(const llama_cparams & cparams) const;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// find how many cells are currently in use
uint32_t cell_max() const;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
// defrag
struct {
std::vector<uint32_t> ids;
} defrag_info;
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
// commit/restore cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
bool get_can_shift() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
// members
const llama_hparams & hparams;
callbacks cbs;
bool has_shift = false;
bool do_defrag = false;
// TODO: remove this and implement llama_kv_cache_recurrent instead
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
@ -179,18 +181,213 @@ public:
// computed before each graph build
uint32_t n = 0;
std::vector<llama_kv_cell> cells;
std::vector<kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
const llama_model & model;
const llama_hparams & hparams;
bool has_shift = false;
bool do_defrag = false;
bool v_trans = true; // the value tensor is transposed
bool can_shift = false;
// required padding
uint32_t padding = 1;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// defrag
struct {
std::vector<uint32_t> ids;
} defrag_info;
// return true if cells have been moved
bool defrag_prepare(int32_t n_max_nodes);
// commit/restore cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t total_size() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale) const;
llm_graph_result_ptr build_graph_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf) const;
llm_graph_result_ptr build_graph_defrag(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_cgraph * gf) const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
//
// llama_kv_cache_recurrent
//
class llama_kv_cache_recurrent : public llama_kv_cache {
public:
struct kv_cell {
llama_pos pos = -1;
int32_t src = -1; // used to copy states
int32_t tail = -1;
std::set<llama_seq_id> seq_id;
bool has_seq_id(const llama_seq_id & id) const {
return seq_id.find(id) != seq_id.end();
}
bool is_empty() const {
return seq_id.empty();
}
bool is_same_seq(const kv_cell & other) const {
return seq_id == other.seq_id;
}
};
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size);
~llama_kv_cache_recurrent() = default;
//
// llama_memory_i
//
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
void restore() override;
void commit() override;
bool update(llama_context & lctx) override;
void defrag_sched(float thold) override;
void set_full() override;
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
bool find_slot(const llama_ubatch & batch) override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
// TODO: better data structures to reduce the cost of this operation
llama_pos get_pos_max() const override;
bool get_can_shift() const override;
// TODO: temporary methods - they are not really const as they do const_cast<>, fix this
int32_t s_copy(int i) const;
float s_mask(int i) const;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override;
// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_impl also uses it, so it
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
uint32_t n = 0;
std::vector<kv_cell> cells;
std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
private:
//const llama_model & model;
const llama_hparams & hparams;
// commit/restore cache
// TODO: rework for recurrent cache
struct slot_range {
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
uint32_t c1 = 0;
};
// pending cell updates that are not yet committed
struct {
std::vector<slot_range> ranges;
} pending;
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
// find how many cells are currently in use
uint32_t cell_max() const;
size_t total_size() const;
size_t size_k_bytes() const;
size_t size_v_bytes() const;
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@ -198,11 +395,6 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
};
// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified
//class llama_kv_cache_recurrent : public llama_kv_cache_unified {
//public:
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};
//
// kv cache view

View File

@ -2,12 +2,22 @@
#include "llama.h"
struct llama_memory_params {
// kv cache
ggml_type type_k;
ggml_type type_v;
// parameters for other types of memory
// ...
};
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
virtual ~llama_memory_i() = default;
virtual void clear() = 0;
virtual void defrag() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;

View File

@ -4445,6 +4445,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
return it->second;
}
ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
// choose long/short freq factors based on the context size
if (layers[il].rope_freqs != nullptr) {
return layers[il].rope_freqs;
}
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
return layers[il].rope_long;
}
return layers[il].rope_short;
}
struct llm_build_llama : public llm_graph_context {
llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@ -4485,7 +4498,7 @@ struct llm_build_llama : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -4710,7 +4723,7 @@ struct llm_build_deci : public llm_graph_context {
} else if (n_head > 0) {
// self-attention
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -7192,7 +7205,7 @@ struct llm_build_phi3 : public llm_graph_context {
// self-attention
{
// rope freq factors for 128k context
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor* attn_norm_output = build_norm(inpL,
model.layers[il].attn_norm,
@ -7944,7 +7957,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// norm
cur = build_norm(inpL,
@ -8711,7 +8724,7 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto kv_head = kv_self->head;
@ -9012,7 +9025,7 @@ struct llm_build_cohere2 : public llm_graph_context {
// self-attention
{
// rope freq factors for 128k context
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -9950,7 +9963,7 @@ struct llm_build_deepseek : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -11314,7 +11327,7 @@ struct llm_build_exaone : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -11459,7 +11472,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -11855,7 +11868,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
@ -12695,7 +12708,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@ -12815,7 +12828,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
}
};
llama_memory_i * llama_model::create_memory() const {
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
llama_memory_i * res;
switch (arch) {
@ -12825,26 +12838,29 @@ llama_memory_i * llama_model::create_memory() const {
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
{
res = new llama_kv_cache_unified(hparams, {
/*.get_rope_factors =*/ nullptr
});
res = new llama_kv_cache_recurrent(
*this,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max));
} break;
default:
{
res = new llama_kv_cache_unified(hparams, {
/*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
// choose long/short freq factors based on the context size
if (layers[il].rope_freqs != nullptr) {
return layers[il].rope_freqs;
}
const auto padding = llama_kv_cache_unified::get_padding(cparams);
if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
return layers[il].rope_long;
}
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
return layers[il].rope_short;
}
});
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
res = new llama_kv_cache_unified(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.n_ctx,
padding);
}
}

View File

@ -395,8 +395,11 @@ struct llama_model {
const struct ggml_tensor * get_tensor(const char * name) const;
ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
// note: can mutate `cparams`
// TODO: move this to new llm_arch_model_i interface
llama_memory_i * create_memory() const; // TODO: params
llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const;
// TODO: move this to new llm_arch_model_i interface
llm_graph_result_ptr build_graph(