From 79dac3c8614013bb7c90ac5d95bcf0f828f2f9f7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 19 Jun 2025 19:26:47 +0300
Subject: [PATCH] kv-cache : use ggml_set_rows

ggml-ci
---
 src/llama-graph.cpp            | 37 ++++++++++++++++++++----
 src/llama-graph.h              |  9 ++++++
 src/llama-kv-cache-unified.cpp | 51 ++++++++++++++++++++++++++++------
 src/llama-kv-cache-unified.h   | 10 ++++---
 4 files changed, 89 insertions(+), 18 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 48589a50a..ca7e75ede 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -281,12 +281,24 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }

 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
+    if (self_kv_idxs_swa) {
+        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
@@ -1192,6 +1204,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

     const auto n_kv = mctx_cur->get_n_kv();

+    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+    ggml_set_input(inp->self_kv_idxs);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1224,8 +1239,10 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }

     const auto & kq_mask = inp->get_kq_mask();
@@ -1278,8 +1295,10 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }

     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1383,8 +1402,8 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
     }

     const auto & kq_mask = inp->get_kq_mask();
@@ -1419,6 +1438,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

+        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
@@ -1431,6 +1453,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

         const auto n_kv = mctx_cur->get_swa()->get_n_kv();

+        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs_swa);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index b433f266d..4cf48a110 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -248,8 +248,12 @@ public:

     void set_input(const llama_ubatch * ubatch) override;

+    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

+    // TODO: should this be I64?
+    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@@ -273,9 +277,14 @@ public:

     void set_input(const llama_ubatch * ubatch) override;

+    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
+    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

+    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
+    ggml_tensor * self_kv_idxs_swa = nullptr; // I32 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
     ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index b506d32ed..5744e00a8 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -746,13 +746,17 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * k = layers[ikv].k;

     const int64_t n_tokens = k_cur->ne[2];

+    if (kv_idxs) {
+        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    }
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*hparams.n_embd_k_gqa(il),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);

@@ -760,7 +764,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * v = layers[ikv].v;
@@ -772,21 +776,48 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     ggml_tensor * v_view = nullptr;

     if (!v_trans) {
+        if (kv_idxs) {
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_1d(ctx, v, n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
         // note: the V cache is transposed when not using flash attention
+        if (kv_idxs) {
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
-
-        v_cur = ggml_transpose(ctx, v_cur);
     }

     return ggml_cpy(ctx, v_cur, v_view);
 }

+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = head_cur + i;
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;

@@ -1789,18 +1820,22 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
 }

 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }

+void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_kv_idxs(dst, ubatch, head);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index 4c53f1273..5b5da3dc2 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -102,8 +102,8 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;

     //
     // preparation API
@@ -126,6 +126,7 @@ public:
     // set_input API
     //

+    void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const;
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -257,11 +258,12 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const;

     void set_input_k_shift(ggml_tensor * dst) const;

+    void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch) const;
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
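Note for reviewers unfamiliar with ggml_set_rows: the stand-alone sketch below (not part of the patch) shows the scatter semantics the cache write now relies on: a destination tensor, a batch of new rows, and an I64 index tensor naming the destination cells, mirroring what set_input_kv_idxs fills with head_cur + i. It is a minimal sketch assuming a ggml build that provides ggml_set_rows and the CPU backend header (ggml_graph_compute_with_ctx); n_embd, n_ctx, n_tokens and head_cur below are toy values for the demo, not values taken from the patch.

// Stand-alone sketch: scatter rows into a "cache" tensor with ggml_set_rows,
// the way cpy_k does above when kv_idxs is provided.
// Assumption: a ggml build that ships ggml_set_rows and ggml-cpu.h.
#include "ggml.h"
#include "ggml-cpu.h"
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd   = 4; // toy row size (n_embd_k_gqa in the patch)
    const int64_t n_ctx    = 8; // toy number of cache cells
    const int64_t n_tokens = 2; // tokens in the current ubatch
    const int64_t head_cur = 5; // first destination cell

    struct ggml_tensor * k     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ctx);    // the cache
    struct ggml_tensor * k_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); // new rows
    struct ggml_tensor * idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);         // destination cells

    // out aliases k's buffer; computing it scatters the rows of k_cur into k at idxs
    struct ggml_tensor * out = ggml_set_rows(ctx, k, k_cur, idxs);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    // fill the inputs: indices head_cur + i (as set_input_kv_idxs does) and recognizable row values
    memset(k->data, 0, ggml_nbytes(k));
    for (int64_t i = 0; i < n_tokens; ++i) {
        ((int64_t *) idxs->data)[i] = head_cur + i;
        for (int64_t j = 0; j < n_embd; ++j) {
            ((float *) k_cur->data)[i*n_embd + j] = (float) (i + 1);
        }
    }

    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // cells head_cur and head_cur+1 now hold the new rows, the rest stay zero
    for (int64_t r = 0; r < n_ctx; ++r) {
        printf("cell %2lld:", (long long) r);
        for (int64_t j = 0; j < n_embd; ++j) {
            printf(" %3.0f", ((float *) k->data)[r*n_embd + j]);
        }
        printf("\n");
    }

    ggml_free(ctx);
    return 0;
}

The K path and the non-transposed V path in the patch map directly onto this pattern. The transposed V path has to reshape the cache to rows of a single element and repeat the indices per value row before calling ggml_set_rows, which is what the in-code TODO flags as a candidate for a better approach.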