From 36f8e20d08bfd5b712eff3a407f38dd8a86e46d1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 22 Jun 2025 10:28:22 +0300 Subject: [PATCH] kv-cache : utilize ggml_set_rows broadcast ggml-ci --- src/llama-kv-cache-unified.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 8fac081d1..f2bc03597 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -821,17 +821,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ return ggml_set_rows(ctx, v, v_cur, kv_idxs); } - // note: the V cache is transposed when not using flash attention - v_cur = ggml_transpose(ctx, v_cur); - - // the row becomes a single element and we repeat the KV indices d_head times + // the row becomes a single element ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]); - v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); + // note: the V cache is transposed when not using flash attention + v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3); - // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast - kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1); + // note: we can be more explicit here at the cost of extra cont + // however, above we take advantage that a row of single element is always contiguous regardless of the row stride + //v_cur = ggml_transpose(ctx, v_cur); + //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); + // we broadcast the KV indices n_embd_v_gqa times + // v [1, n_kv, n_embd_v_gqa] + // v_cur [1, n_tokens, n_embd_v_gqa] + // kv_idxs [n_tokens, 1, 1] return ggml_set_rows(ctx, v_view, v_cur, kv_idxs); }