From b123d89445f2a982839dbdf08038f3eaa1244ab0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 3 Jul 2025 15:05:27 +0300 Subject: [PATCH] kv-cache : prepare K/V buffers for separation ggml-ci --- src/llama-hparams.cpp | 40 ++++++++++++ src/llama-hparams.h | 8 +++ src/llama-kv-cache-unified.cpp | 115 ++++++++++++++++++++++----------- 3 files changed, 127 insertions(+), 36 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 86c814d51..81fb74d19 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } +bool llama_hparams::is_n_embd_k_gqa_variable() const { + const uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_k_gqa(il)) { + return true; + } + } + + return false; +} + +bool llama_hparams::is_n_embd_v_gqa_variable() const { + const uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_v_gqa(il)) { + return true; + } + } + + return false; +} + +uint32_t llama_hparams::n_embd_k_gqa_max() const { + uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_k_gqa(il)); + } + + return val; +} + +uint32_t llama_hparams::n_embd_v_gqa_max() const { + uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_v_gqa(il)); + } + + return val; +} + uint32_t llama_hparams::n_embd_r() const { if (wkv_head_size != 0) { // for RWKV models diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 476d0a5ea..12bff2eb1 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -189,6 +189,14 @@ struct llama_hparams { // dimension of value embeddings across all k-v heads uint32_t n_embd_v_gqa(uint32_t il = 0) const; + // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa + bool is_n_embd_k_gqa_variable() const; + bool is_n_embd_v_gqa_variable() const; + + // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers + uint32_t n_embd_k_gqa_max() const; + uint32_t n_embd_v_gqa_max() const; + // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size uint32_t n_embd_r() const; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index d3129cc53..075e46255 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -68,14 +68,21 @@ llama_kv_cache_unified::llama_kv_cache_unified( cells.resize(kv_size); + // [TAG_V_CACHE_VARIABLE] + if (v_trans && hparams.is_n_embd_v_gqa_variable()) { + LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n", + __func__, hparams.n_embd_v_gqa_max()); + } + for (uint32_t il = 0; il < n_layer_cache; il++) { if (filter && !filter(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + // [TAG_V_CACHE_VARIABLE] + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); const char * dev_name = "CPU"; @@ -98,8 +105,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_tensor * k; ggml_tensor * v; - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1); + v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); @@ -785,11 +792,17 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint auto * k = layers[ikv].k; - return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + const uint64_t kv_size = get_size(); + const uint64_t n_embd_k_gqa = k->ne[0]; + + assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il)); + + return ggml_view_4d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1, ggml_row_size(k->type, hparams.n_embd_head_k), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), - 0); + ggml_row_size(k->type, n_embd_k_gqa), + ggml_row_size(k->type, n_embd_k_gqa*kv_size), + ggml_row_size(k->type, n_embd_k_gqa*kv_size)*0); } ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { @@ -797,21 +810,29 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint auto * v = layers[ikv].v; + const uint64_t kv_size = get_size(); + const uint64_t n_embd_v_gqa = v->ne[0]; + + // [TAG_V_CACHE_VARIABLE] + assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il)); + if (!v_trans) { // note: v->nb[1] <= v->nb[2] - return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] + ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] + ggml_row_size(v->type, n_embd_v_gqa*kv_size)*0); } // note: v->nb[1] > v->nb[2] - return ggml_view_3d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, - ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, v->ne[1]), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1, + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, kv_size), // v->nb[2] + ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] + ggml_row_size(v->type, kv_size*n_embd_v_gqa)*0); } ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { @@ -825,6 +846,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); if (k_idxs && supports_set_rows) { + if (k->ne[2] > 1) { + k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); + } + return ggml_set_rows(ctx, k, k_cur, k_idxs); } @@ -843,31 +868,30 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v->ne[0]; - const int64_t n_tokens = v_cur->ne[2]; + const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; + const int64_t n_tokens = v_cur->ne[2]; v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); if (v_idxs && supports_set_rows) { if (!v_trans) { + if (v->ne[2] > 1) { + v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); + } + return ggml_set_rows(ctx, v, v_cur, v_idxs); } + // [TAG_V_CACHE_VARIABLE] + if (n_embd_v_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + } + // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]); + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); - // note: the V cache is transposed when not using flash attention - v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3); + v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); - // note: we can be more explicit here at the cost of extra cont - // however, above we take advantage that a row of single element is always continuous regardless of the row stride - //v_cur = ggml_transpose(ctx, v_cur); - //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); - - // we broadcast the KV indices n_embd_v_gqa times - // v [1, n_kv, n_embd_v_gqa] - // v_cur [1, n_tokens, n_embd_v_gqa] - // v_idxs [n_tokens, 1, 1] return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } @@ -904,7 +928,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; - ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + ggml_tensor * v_idxs; + + if (!v_trans) { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + } else { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max()); + } ggml_set_input(v_idxs); @@ -921,7 +951,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { + for (uint32_t i = 0; i < n_tokens; ++i) { data[i] = sinfo.idxs.at(i); } } @@ -936,8 +966,21 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + if (!v_trans) { + for (uint32_t i = 0; i < n_tokens; ++i) { + data[i] = sinfo.idxs.at(i); + } + } else { + // note: the V cache is transposed when not using flash attention + const int64_t kv_size = get_size(); + + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max(); + + for (uint32_t i = 0; i < n_tokens; ++i) { + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i); + } + } } }