kv-cache : utilize ggml_set_rows broadcast

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-22 10:28:22 +03:00
parent 332f073589
commit 36f8e20d08

View File

@ -821,17 +821,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
return ggml_set_rows(ctx, v, v_cur, kv_idxs);
}
// note: the V cache is transposed when not using flash attention
v_cur = ggml_transpose(ctx, v_cur);
// the row becomes a single element and we repeat the KV indices d_head times
// the row becomes a single element
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
// note: the V cache is transposed when not using flash attention
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
// TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
// note: we can be more explicit here at the cost of extra cont
// however, above we take advantage that a row of single element is always contiguous regardless of the row stride
//v_cur = ggml_transpose(ctx, v_cur);
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
// we broadcast the KV indices n_embd_v_gqa times
// v [1, n_kv, n_embd_v_gqa]
// v_cur [1, n_tokens, n_embd_v_gqa]
// kv_idxs [n_tokens, 1, 1]
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
}