Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-26 11:45:21 +00:00)
kv-cache : utilize ggml_set_rows broadcast
ggml-ci
@@ -821,17 +821,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         return ggml_set_rows(ctx, v, v_cur, kv_idxs);
     }
 
-    // note: the V cache is transposed when not using flash attention
-    v_cur = ggml_transpose(ctx, v_cur);
-
-    // the row becomes a single element and we repeat the KV indices d_head times
+    // the row becomes a single element
     ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
 
-    v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+    // note: the V cache is transposed when not using flash attention
+    v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
 
-    // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
-    kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+    // note: we can be more explicit here at the cost of extra cont
+    // however, above we take advantage that a row of single element is always contiguous regardless of the row stride
+    //v_cur = ggml_transpose(ctx, v_cur);
+    //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 
+    // we broadcast the KV indices n_embd_v_gqa times
+    // v       [1, n_kv,     n_embd_v_gqa]
+    // v_cur   [1, n_tokens, n_embd_v_gqa]
+    // kv_idxs [n_tokens, 1, 1]
     return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
 }
 
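To make the shape comments in the new code concrete, the plain-C loop below is an illustrative model of the broadcast scatter that ggml_set_rows performs at this call site. It is not the ggml implementation: the helper name set_rows_broadcast and its parameters are hypothetical, and the layout follows ggml's convention that ne[0] is the fastest-moving dimension.

#include <stdint.h>

// Model of: ggml_set_rows(ctx, v_view, v_cur, kv_idxs) with
//   v_view  [1, n_kv,     n_embd_v_gqa]  -- transposed V cache, one element per row
//   v_cur   [1, n_tokens, n_embd_v_gqa]  -- values of the new tokens
//   kv_idxs [n_tokens, 1, 1]             -- destination cache row per token
// kv_idxs holds a single run of n_tokens indices; it is reused (broadcast)
// for every one of the n_embd_v_gqa slices along dimension 2.
static void set_rows_broadcast(
        float * v_view, int64_t n_kv,
        const float * v_cur, int64_t n_tokens,
        int64_t n_embd_v_gqa,
        const int64_t * kv_idxs) {
    for (int64_t i2 = 0; i2 < n_embd_v_gqa; i2++) {    // broadcast: same indices for every slice
        for (int64_t i1 = 0; i1 < n_tokens; i1++) {    // each "row" is a single element here
            v_view[i2*n_kv + kv_idxs[i1]] = v_cur[i2*n_tokens + i1];
        }
    }
}

Before this change the same scatter required kv_idxs to be repeated to n_tokens * n_embd_v_gqa entries with ggml_repeat_4d (the removed TODO); relying on the broadcast drops that intermediate tensor and the extra copy.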