mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 12:35:16 +00:00
kv-cache : simplify set_rows logic
ggml-ci
This commit is contained in:
@ -937,17 +937,17 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
|
|||||||
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
|
hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
|
||||||
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
|
ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
|
||||||
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
|
ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
|
||||||
size_virt,
|
size_virt, // v->nb[3]
|
||||||
size_virt*sinfo.s0);
|
size_virt*sinfo.s0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// note: v->nb[1] > v->nb[2]
|
// note: v->nb[1] > v->nb[2]
|
||||||
return ggml_view_4d(ctx, v,
|
return ggml_view_4d(ctx, v,
|
||||||
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
|
n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
|
||||||
ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
|
ggml_row_size(v->type, v->ne[1]*n_seq_virt*hparams.n_embd_head_v), // v->nb[1]
|
||||||
ggml_row_size(v->type, v->ne[1]), // v->nb[2]
|
ggml_row_size(v->type, v->ne[1]*n_seq_virt), // v->nb[2]
|
||||||
size_virt,
|
ggml_row_size(v->type, v->ne[1]), // v->nb[3]
|
||||||
size_virt*sinfo.s0);
|
ggml_row_size(v->type, v->ne[1]*sinfo.s0));
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
|
ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
|
||||||
@ -961,20 +961,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
|
|||||||
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
|
k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
|
||||||
|
|
||||||
if (kv_idxs && supports_set_rows) {
|
if (kv_idxs && supports_set_rows) {
|
||||||
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
|
k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
|
||||||
|
|
||||||
const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
|
return ggml_set_rows(ctx, k, k_cur, kv_idxs);
|
||||||
|
|
||||||
ggml_tensor * k_view = ggml_view_3d(ctx, k, k->ne[0], k->ne[1], ns,
|
|
||||||
ggml_row_size(k->type, k->ne[0]),
|
|
||||||
size_virt,
|
|
||||||
size_virt*sinfo.s0);
|
|
||||||
|
|
||||||
k_cur = ggml_reshape_3d(ctx, k_cur, k_cur->ne[0], k_cur->ne[1]/ns, ns);
|
|
||||||
|
|
||||||
kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
|
|
||||||
|
|
||||||
return ggml_set_rows(ctx, k_view, k_cur, kv_idxs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: fallback to old ggml_cpy() method for backwards compatibility
|
// TODO: fallback to old ggml_cpy() method for backwards compatibility
|
||||||
@ -1000,45 +989,27 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
|
|||||||
v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
|
v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
|
||||||
|
|
||||||
if (kv_idxs && supports_set_rows) {
|
if (kv_idxs && supports_set_rows) {
|
||||||
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
|
|
||||||
|
|
||||||
const uint64_t size_virt = ggml_row_size(v->type, hparams.n_embd_v_gqa(il)*get_size());
|
|
||||||
|
|
||||||
if (!v_trans) {
|
if (!v_trans) {
|
||||||
ggml_tensor * v_view = ggml_view_3d(ctx, v, v->ne[0], v->ne[1], ns,
|
v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
|
||||||
ggml_row_size(v->type, v->ne[0]),
|
|
||||||
size_virt,
|
|
||||||
size_virt*sinfo.s0);
|
|
||||||
|
|
||||||
v_cur = ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], v_cur->ne[1]/ns, ns);
|
return ggml_set_rows(ctx, v, v_cur, kv_idxs);
|
||||||
|
|
||||||
kv_idxs = ggml_reshape_2d(ctx, kv_idxs, n_tokens/ns, ns);
|
|
||||||
|
|
||||||
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// the row becomes a single element
|
// the row becomes a single element
|
||||||
ggml_tensor * v_view = ggml_view_4d(ctx, v, 1, v->ne[1], v->ne[0], ns,
|
ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1]*v->ne[2], v->ne[0]);
|
||||||
ggml_row_size(v->type, 1),
|
|
||||||
ggml_row_size(v->type, v->ne[1]),
|
|
||||||
size_virt,
|
|
||||||
size_virt*sinfo.s0);
|
|
||||||
|
|
||||||
// note: the V cache is transposed when not using flash attention
|
// note: the V cache is transposed when not using flash attention
|
||||||
v_cur = ggml_permute(ctx, ggml_reshape_4d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]/ns, ns), 2, 0, 1, 3);
|
v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
|
||||||
|
|
||||||
// note: we can be more explicit here at the cost of extra cont
|
// note: we can be more explicit here at the cost of extra cont
|
||||||
// however, above we take advantage that a row of single element is always contiguous regardless of the row stride
|
// however, above we take advantage that a row of single element is always contiguous regardless of the row stride
|
||||||
//v_cur = ggml_reshape_3d(ctx, v_cur, n_embd_v_gqa, v_cur->ne[1]/ns, ns);
|
|
||||||
//v_cur = ggml_transpose(ctx, v_cur);
|
//v_cur = ggml_transpose(ctx, v_cur);
|
||||||
//v_cur = ggml_cont_4d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1], v_cur->ne[2]);
|
//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
|
||||||
|
|
||||||
// we broadcast the KV indices n_embd_v_gqa times
|
// we broadcast the KV indices n_embd_v_gqa times
|
||||||
// v [1, n_kv, n_embd_v_gqa, ns]
|
// v [1, n_kv*n_seq_virt, n_embd_v_gqa]
|
||||||
// v_cur [1, n_tokens/ns, n_embd_v_gqa, ns]
|
// v_cur [1, n_tokens, n_embd_v_gqa]
|
||||||
// kv_idxs [n_tokens/ns, 1, ns]
|
// kv_idxs [n_tokens, 1, 1]
|
||||||
|
|
||||||
kv_idxs = ggml_reshape_3d(ctx, kv_idxs, n_tokens/ns, 1, ns);
|
|
||||||
|
|
||||||
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
|
return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
|
||||||
}
|
}
|
||||||
@ -1077,8 +1048,10 @@ void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ub
|
|||||||
int64_t * data = (int64_t *) dst->data;
|
int64_t * data = (int64_t *) dst->data;
|
||||||
|
|
||||||
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
|
for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
|
||||||
|
const int64_t offs = sinfo.seq_id_virt[s]*get_size();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < sinfo.size(); ++i) {
|
for (uint32_t i = 0; i < sinfo.size(); ++i) {
|
||||||
data[s*sinfo.size() + i] = sinfo.idxs[s][i];
|
data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user