kv-cache : prepare K/V buffers for separation

ggml-ci
Georgi Gerganov
2025-07-03 15:05:27 +03:00
parent 67d1ef23c6
commit b123d89445
3 changed files with 127 additions and 36 deletions

@@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
+bool llama_hparams::is_n_embd_k_gqa_variable() const {
+    const uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_k_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+bool llama_hparams::is_n_embd_v_gqa_variable() const {
+    const uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        if (val != n_embd_v_gqa(il)) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+uint32_t llama_hparams::n_embd_k_gqa_max() const {
+    uint32_t val = n_embd_k_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_k_gqa(il));
+    }
+
+    return val;
+}
+
+uint32_t llama_hparams::n_embd_v_gqa_max() const {
+    uint32_t val = n_embd_v_gqa();
+    for (uint32_t il = 0; il < n_layer; ++il) {
+        val = std::max(val, n_embd_v_gqa(il));
+    }
+
+    return val;
+}
+
 uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
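As a side note on the new helpers: they compare every layer against layer 0 and track the running maximum, which is what the KV cache below uses to decide whether padding is needed. A minimal standalone sketch of that logic, using made-up per-layer widths (the array name and values are hypothetical, not taken from any model):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// hypothetical per-layer V widths (n_embd_head_v * n_head_kv for each layer)
static const std::vector<uint32_t> n_embd_v_gqa_per_layer = { 1024, 1024, 512, 1024 };

int main() {
    const uint32_t ref = n_embd_v_gqa_per_layer[0];

    bool     variable = false; // mirrors is_n_embd_v_gqa_variable(): any layer differs from layer 0
    uint32_t max_val  = ref;   // mirrors n_embd_v_gqa_max(): largest width across layers

    for (uint32_t v : n_embd_v_gqa_per_layer) {
        variable = variable || (v != ref);
        max_val  = std::max(max_val, v);
    }

    printf("variable = %d, max = %u\n", variable, max_val); // prints: variable = 1, max = 1024
    return 0;
}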

@@ -189,6 +189,14 @@ struct llama_hparams {
     // dimension of value embeddings across all k-v heads
     uint32_t n_embd_v_gqa(uint32_t il = 0) const;
 
+    // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
+    bool is_n_embd_k_gqa_variable() const;
+    bool is_n_embd_v_gqa_variable() const;
+
+    // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
+    uint32_t n_embd_k_gqa_max() const;
+    uint32_t n_embd_v_gqa_max() const;
+
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
     uint32_t n_embd_r() const;

@@ -68,14 +68,21 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    // [TAG_V_CACHE_VARIABLE]
+    if (v_trans && hparams.is_n_embd_v_gqa_variable()) {
+        LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n",
+                __func__, hparams.n_embd_v_gqa_max());
+    }
+
     for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+        // [TAG_V_CACHE_VARIABLE]
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
 
         const char * dev_name = "CPU";
@@ -98,8 +105,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * k;
         ggml_tensor * v;
 
-        k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+        k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, 1);
+        v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, 1);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
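With the change above, K and V are allocated as 3D tensors whose trailing dimension is 1, presumably the slot that the "separation" in the commit title will later occupy, and when the V cache is transposed (FA disabled) every layer's V row is widened to n_embd_v_gqa_max(). A rough sketch of the size implication, with made-up numbers (kv_size, widths and type size are hypothetical):

#include <cstdint>
#include <cstdio>

int main() {
    // made-up example values
    const uint64_t kv_size        = 4096; // number of cache cells
    const uint64_t n_embd_v_gqa_l = 512;  // this layer's V width
    const uint64_t n_embd_v_gqa_m = 1024; // n_embd_v_gqa_max() across layers
    const uint64_t type_size      = 2;    // e.g. F16

    // FA enabled (v_trans == false): each layer keeps its own width
    const uint64_t bytes_exact  = n_embd_v_gqa_l * kv_size * 1 * type_size;

    // FA disabled (v_trans == true): V is padded to the max width so the
    // transposed layout has one common row length across all layers
    const uint64_t bytes_padded = n_embd_v_gqa_m * kv_size * 1 * type_size;

    printf("exact: %llu bytes, padded: %llu bytes\n",
           (unsigned long long) bytes_exact, (unsigned long long) bytes_padded);
    return 0;
}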
@@ -785,11 +792,17 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
 
     auto * k = layers[ikv].k;
 
-    return ggml_view_3d(ctx, k,
-            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+    const uint64_t kv_size = get_size();
+    const uint64_t n_embd_k_gqa = k->ne[0];
+
+    assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il));
+
+    return ggml_view_4d(ctx, k,
+            hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, 1,
             ggml_row_size(k->type, hparams.n_embd_head_k),
-            ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-            0);
+            ggml_row_size(k->type, n_embd_k_gqa),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size),
+            ggml_row_size(k->type, n_embd_k_gqa*kv_size)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
@@ -797,21 +810,29 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
 
     auto * v = layers[ikv].v;
 
+    const uint64_t kv_size = get_size();
+    const uint64_t n_embd_v_gqa = v->ne[0];
+
+    // [TAG_V_CACHE_VARIABLE]
+    assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il));
+
     if (!v_trans) {
         // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+        return ggml_view_4d(ctx, v,
+                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, 1,
                 ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+                ggml_row_size(v->type, n_embd_v_gqa),          // v->nb[2]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size),  // v->nb[3]
+                ggml_row_size(v->type, n_embd_v_gqa*kv_size)*0);
     }
 
     // note: v->nb[1] > v->nb[2]
-    return ggml_view_3d(ctx, v,
-            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-            ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-            ggml_row_size(v->type, v->ne[1]),                       // v->nb[2]
-            0);
+    return ggml_view_4d(ctx, v,
+            n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, 1,
+            ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+            ggml_row_size(v->type, kv_size),                       // v->nb[2]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa),          // v->nb[3]
+            ggml_row_size(v->type, kv_size*n_embd_v_gqa)*0);
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
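A note on the get_k()/get_v() views above: they gain a fourth dimension of size 1 plus an explicit nb[3] stride, and the final offset is written as nb[3]*0. A small sketch of that offset arithmetic, under the assumption that the 0 is a placeholder for a future per-stream index (the function and parameter names here are hypothetical):

#include <cstdint>
#include <cstdio>

// sketch: byte offset of one "stream" inside the K buffer, mirroring the last
// argument of the get_k() view above; this commit always uses s == 0, so the
// offset is 0 and the 4D view behaves exactly like the previous 3D one
static uint64_t k_view_offset(uint64_t row_bytes, // ggml_row_size(k->type, n_embd_k_gqa)
                              uint64_t kv_size,
                              uint64_t s) {       // hypothetical stream index
    const uint64_t nb3 = row_bytes * kv_size;     // bytes spanned by one full K buffer (nb[3])
    return nb3 * s;
}

int main() {
    printf("s=0 -> %llu, s=1 -> %llu\n",
           (unsigned long long) k_view_offset(2048, 4096, 0),
           (unsigned long long) k_view_offset(2048, 4096, 1));
    return 0;
}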
@@ -825,6 +846,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 
     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
     if (k_idxs && supports_set_rows) {
+        if (k->ne[2] > 1) {
+            k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+        }
+
         return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }
@@ -843,31 +868,30 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
     auto * v = layers[ikv].v;
 
-    const int64_t n_embd_v_gqa = v->ne[0];
+    const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1];
     const int64_t n_tokens = v_cur->ne[2];
 
     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 
     if (v_idxs && supports_set_rows) {
         if (!v_trans) {
+            if (v->ne[2] > 1) {
+                v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+            }
+
             return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }
 
+        // [TAG_V_CACHE_VARIABLE]
+        if (n_embd_v_gqa < v->ne[0]) {
+            v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0);
+        }
+
         // the row becomes a single element
-        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+        ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
-        // note: the V cache is transposed when not using flash attention
-        v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
-
-        // note: we can be more explicit here at the cost of extra cont
-        // however, above we take advantage that a row of single element is always continuous regardless of the row stride
-        //v_cur = ggml_transpose(ctx, v_cur);
-        //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+        v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-        // we broadcast the KV indices n_embd_v_gqa times
-        // v      [1, n_kv,     n_embd_v_gqa]
-        // v_cur  [1, n_tokens, n_embd_v_gqa]
-        // v_idxs [n_tokens, 1, 1]
         return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }
@@ -904,7 +928,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
-    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    ggml_tensor * v_idxs;
+
+    if (!v_trans) {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+    } else {
+        v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max());
+    }
 
     ggml_set_input(v_idxs);
@@ -921,7 +951,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
+    for (uint32_t i = 0; i < n_tokens; ++i) {
         data[i] = sinfo.idxs.at(i);
     }
 }
@@ -936,8 +966,21 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     int64_t * data = (int64_t *) dst->data;
 
-    for (int64_t i = 0; i < n_tokens; ++i) {
-        data[i] = sinfo.idxs.at(i);
+    if (!v_trans) {
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            data[i] = sinfo.idxs.at(i);
+        }
+    } else {
+        // note: the V cache is transposed when not using flash attention
+        const int64_t kv_size = get_size();
+
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max();
+
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                data[i*n_embd_v_gqa + j] = j*kv_size + sinfo.idxs.at(i);
+            }
+        }
     }
 }
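Putting cpy_v() and set_input_v_idxs() together for the transposed case: the V buffer is viewed as rows of a single element, so element j of token i, destined for cache cell sinfo.idxs[i], is scattered to flat row j*kv_size + idxs[i]. A self-contained sketch with made-up sizes (the values below are illustrative only):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // made-up example: 3 tokens going to cache cells 5, 6, 7
    const std::vector<int64_t> idxs = { 5, 6, 7 };
    const int64_t kv_size      = 4096; // cache cells
    const int64_t n_embd_v_gqa = 4;    // tiny V width for illustration (stands in for n_embd_v_gqa_max())

    // same layout as the transposed branch of set_input_v_idxs():
    // destination row of element j of token i in the flattened [1, kv_size*n_embd_v_gqa] view
    std::vector<int64_t> data(idxs.size() * n_embd_v_gqa);
    for (size_t i = 0; i < idxs.size(); ++i) {
        for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
            data[i*n_embd_v_gqa + j] = j*kv_size + idxs[i];
        }
    }

    // token 0 scatters to rows 5, 4101, 8197, 12293 -> one element per transposed row
    for (int64_t j = 0; j < n_embd_v_gqa; ++j) {
        printf("token 0, element %lld -> row %lld\n",
               (long long) j, (long long) data[0*n_embd_v_gqa + j]);
    }
    return 0;
}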