Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-06-30 04:45:17 +00:00
kv-cache : use ggml_set_rows
ggml-ci
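For readers unfamiliar with the op: the commit replaces the view-and-copy writes into the KV cache with ggml_set_rows(ctx, dst, src, idxs), which, as used below, writes row i of src into row idxs[i] of dst. The following is a minimal standalone sketch of that row-scatter semantics in plain C++ (no ggml); the buffer layout, the sizes and the helper name set_rows are illustrative assumptions, not part of the commit.

// plain-C++ emulation of the row-scatter used by the commit
#include <cstdint>
#include <cstdio>
#include <vector>

// dst: n_rows_dst x n_cols (row-major), src: idxs.size() x n_cols
// row i of src is written into row idxs[i] of dst
static void set_rows(std::vector<float> & dst, const std::vector<float> & src,
                     const std::vector<int64_t> & idxs, size_t n_cols) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        for (size_t c = 0; c < n_cols; ++c) {
            dst[(size_t) idxs[i]*n_cols + c] = src[i*n_cols + c];
        }
    }
}

int main() {
    const size_t n_cols = 4, n_cells = 8, n_tokens = 3;

    std::vector<float>   cache(n_cells*n_cols, 0.0f); // stand-in for one KV layer
    std::vector<float>   cur(n_tokens*n_cols, 1.0f);  // new K/V rows for this ubatch
    std::vector<int64_t> idxs = {5, 6, 7};            // destination cells, e.g. head..head+n_tokens-1

    set_rows(cache, cur, idxs, n_cols);

    printf("cache row 5, col 0: %.1f\n", cache[5*n_cols + 0]); // prints 1.0
}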
@@ -281,12 +281,24 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kv_idxs) {
+        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    }
+
+    if (self_kv_idxs_swa) {
+        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
@@ -1192,6 +1204,9 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+    ggml_set_input(inp->self_kv_idxs);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1224,8 +1239,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1278,8 +1295,10 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1383,8 +1402,8 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1419,6 +1438,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
@@ -1431,6 +1453,9 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
+        ggml_set_input(inp->self_kv_idxs_swa);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
@@ -248,8 +248,12 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
+    // TODO: should this be I64?
+    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
 
@@ -273,9 +277,14 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
+    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
+    ggml_tensor * self_kv_idxs = nullptr; // I32 [n_batch]
+    ggml_tensor * self_kv_idxs_swa = nullptr; // I32 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
     ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
@@ -746,13 +746,17 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
 
     const int64_t n_tokens = k_cur->ne[2];
 
+    if (kv_idxs) {
+        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
+    }
+
     ggml_tensor * k_view = ggml_view_1d(ctx, k,
             n_tokens*hparams.n_embd_k_gqa(il),
             ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
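A hedged side note on the K path above: set_input_kv_idxs() (added further down) fills the index tensor with head_cur, head_cur + 1, ..., so with such contiguous indices the ggml_set_rows path writes the same region of the K buffer that the older ggml_view_1d + ggml_cpy path did; the explicit indices only become interesting once non-contiguous destination cells are allowed. The sketch below emulates both paths with plain C++ arrays; all sizes and names are made-up assumptions.

// plain-C++ check that the old view+copy write and the new index-based scatter
// land in the same place when the indices are head_cur, head_cur+1, ...
#include <algorithm>
#include <cassert>
#include <vector>

int main() {
    const size_t n_embd_k_gqa = 8, kv_size = 32, n_tokens = 4, head_cur = 10;

    std::vector<float> k_old(kv_size*n_embd_k_gqa, 0.0f);
    std::vector<float> k_new(kv_size*n_embd_k_gqa, 0.0f);
    std::vector<float> k_cur(n_tokens*n_embd_k_gqa, 2.0f);

    // old path: 1-D view of n_tokens*n_embd_k_gqa elements starting at row head_cur, then copy
    std::copy(k_cur.begin(), k_cur.end(), k_old.begin() + head_cur*n_embd_k_gqa);

    // new path: per-row scatter using explicit destination indices
    for (size_t i = 0; i < n_tokens; ++i) {
        const size_t idx = head_cur + i; // what set_input_kv_idxs() writes
        std::copy(k_cur.begin() + i*n_embd_k_gqa,
                  k_cur.begin() + (i + 1)*n_embd_k_gqa,
                  k_new.begin() + idx*n_embd_k_gqa);
    }

    assert(k_old == k_new); // both paths wrote the same bytes
}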
@@ -760,7 +764,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -772,21 +776,48 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     ggml_tensor * v_view = nullptr;
 
     if (!v_trans) {
+        if (kv_idxs) {
+            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_1d(ctx, v,
                 n_tokens*hparams.n_embd_v_gqa(il),
                 ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
     } else {
+        v_cur = ggml_transpose(ctx, v_cur);
+
         // note: the V cache is transposed when not using flash attention
+        if (kv_idxs) {
+            // the row becomes a single element and we repeat the KV indices d_head times
+            // TODO: this seems not very optimal - can we do something better?
+            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
+
+            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        }
+
         v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
                 (v->ne[1])*ggml_element_size(v),
                 (head_cur)*ggml_element_size(v));
-
-        v_cur = ggml_transpose(ctx, v_cur);
     }
 
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
+void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = head_cur + i;
+    }
+}
+
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
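A hedged sketch of the transposed-V reasoning in the hunk above: with the V cache stored transposed, one token's V vector is spread across n_embd_v_gqa rows of the cache (one element per row), so the same destination cell index has to be reused for every element, which is why the KV indices are repeated d_head times. The emulation below uses plain C++ containers rather than the ggml ops; dimensions and names are illustrative assumptions.

// plain-C++ emulation of scattering a ubatch into a transposed V cache
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const size_t n_embd_v_gqa = 4, kv_size = 16, n_tokens = 2, head_cur = 3;

    // transposed V cache: indexed as v[e][cell]
    std::vector<std::vector<float>> v(n_embd_v_gqa, std::vector<float>(kv_size, 0.0f));

    // V for the current ubatch: v_cur[t][e]
    std::vector<std::vector<float>> v_cur(n_tokens, std::vector<float>(n_embd_v_gqa, 5.0f));

    // contiguous destination cells, as filled by set_input_kv_idxs()
    std::vector<int64_t> kv_idxs(n_tokens);
    for (size_t t = 0; t < n_tokens; ++t) {
        kv_idxs[t] = (int64_t) (head_cur + t);
    }

    // scatter: the cell index kv_idxs[t] is reused across all n_embd_v_gqa rows
    for (size_t e = 0; e < n_embd_v_gqa; ++e) {
        for (size_t t = 0; t < n_tokens; ++t) {
            v[e][(size_t) kv_idxs[t]] = v_cur[t][e];
        }
    }

    printf("v[0][%zu] = %.1f\n", head_cur, v[0][head_cur]); // prints 5.0
}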
@@ -1789,18 +1820,22 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, kv_idxs, il, head);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, il, head);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, kv_idxs, il, head);
 }
 
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
+void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_kv_idxs(dst, ubatch, head);
+}
+
 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
@@ -102,8 +102,8 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, uint32_t head_cur) const;
 
     //
     // preparation API
@@ -126,6 +126,7 @@ public:
     // set_input API
     //
 
+    void set_input_kv_idxs   (ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const;
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift   (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -257,11 +258,12 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const;
 
     void set_input_k_shift(ggml_tensor * dst) const;
 
+    void set_input_kv_idxs   (ggml_tensor * dst, const llama_ubatch * ubatch) const;
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 