mirror of https://github.com/ggml-org/llama.cpp.git
graph : separate k and v indices
ggml-ci
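Context for the diff below (the sketch is illustrative and not part of the commit): previously a single I64 tensor, self_kv_idxs, fed both cpy_k() and cpy_v(). This change splits it into self_k_idxs and self_v_idxs, created per cache through build_input_k_idxs()/build_input_v_idxs() and filled through set_input_k_idxs()/set_input_v_idxs(), so the K and V write paths can carry different index layouts (relevant when the V cache is stored transposed and the indices have to be broadcast across n_embd_v_gqa). A minimal standalone sketch of the underlying ggml_set_rows() pattern, with made-up sizes and a plain CPU context, not llama.cpp code:

// sketch: scatter per-token K rows into a fixed-size cache tensor through an
// I64 index input, the same ggml_set_rows() pattern the new k_idxs/v_idxs
// tensors feed in llama.cpp (sizes below are made up for illustration)
#include "ggml.h"

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_embd   = 8;  // row width
    const int64_t kv_size  = 32; // number of cache slots
    const int64_t n_tokens = 4;  // rows written for the current ubatch

    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    // destination cache and the rows produced for the current ubatch
    ggml_tensor * k_cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, kv_size);
    ggml_tensor * k_cur   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);

    // graph input: one destination slot index per token (cf. build_input_k_idxs)
    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    ggml_set_input(k_idxs);

    // write row i of k_cur into row k_idxs[i] of k_cache
    ggml_tensor * k_upd = ggml_set_rows(ctx, k_cache, k_cur, k_idxs);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, k_upd);

    // eval time: fill the index input with the chosen slots (cf. set_input_k_idxs)
    int64_t * data = (int64_t *) k_idxs->data;
    for (int64_t i = 0; i < n_tokens; ++i) {
        data[i] = i; // stands in for sinfo.idxs[i]
    }

    printf("set_rows result: %lld x %lld\n",
           (long long) k_upd->ne[0], (long long) k_upd->ne[1]);

    ggml_free(ctx);
    return 0;
}

With the two index tensors separated, the SWA and hybrid-memory inputs get the same treatment (self_k_idxs_swa/self_v_idxs_swa and the get_attn() variants), as the hunks below show.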
@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
     }

+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     }

-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
     }

+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    }
+
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+    }
+
     if (self_kq_mask) {
@@ -345,6 +357,14 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_k_idxs) {
+        mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
     if (self_kq_mask) {
         mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
@@ -362,7 +382,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }

-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
@@ -1009,6 +1030,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {

     const auto n_kv = inp->mctx->get_attn()->get_n_kv();

+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1210,11 +1234,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()

     const auto n_kv = mctx_cur->get_n_kv();

-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1245,10 +1268,11 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();

-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }

     const auto & kq_mask = inp->get_kq_mask();
@@ -1307,15 +1331,15 @@ ggml_tensor * llm_graph_context::build_attn(

     // optionally store to KV cache
     if (k_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();

-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }

     if (v_cur) {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();

-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }

     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1419,8 +1443,11 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, nullptr, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, nullptr, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }

     const auto & kq_mask = inp->get_kq_mask();
@@ -1455,11 +1482,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1470,11 +1496,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

         const auto n_kv = mctx_cur->get_swa()->get_n_kv();

-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);

         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -248,10 +248,13 @@ public:

     void set_input(const llama_ubatch * ubatch) override;

-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }

     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

-    ggml_tensor * self_kv_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]

     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@@ -276,13 +279,18 @@ public:

     void set_input(const llama_ubatch * ubatch) override;

-    ggml_tensor * get_kv_idxs() const { return self_kv_idxs; }
-    ggml_tensor * get_kv_idxs_swa() const { return self_kv_idxs_swa; }
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+    ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+    ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }

     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

-    ggml_tensor * self_kv_idxs = nullptr; // I64 [n_batch]
-    ggml_tensor * self_kv_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]

     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@@ -326,8 +334,14 @@ public:

     ggml_tensor * s_copy; // I32 [kv_size]

+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+    ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+    ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
@@ -343,7 +357,7 @@ public:
     llm_graph_input_one() {}
     virtual ~llm_graph_input_one() = default;

-    void set_input(const llama_ubatch *) override;
+    void set_input(const llama_ubatch * ubatch) override;

     ggml_tensor * one = nullptr; // F32
 };
@@ -808,7 +808,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             0);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * k = layers[ikv].k;
@@ -818,8 +818,8 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_

     k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);

-    if (kv_idxs && supports_set_rows) {
-        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
+    if (k_idxs && supports_set_rows) {
+        return ggml_set_rows(ctx, k, k_cur, k_idxs);
     }

     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -832,7 +832,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);

     auto * v = layers[ikv].v;
@@ -842,9 +842,9 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_

     v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);

-    if (kv_idxs && supports_set_rows) {
+    if (v_idxs && supports_set_rows) {
         if (!v_trans) {
-            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
+            return ggml_set_rows(ctx, v, v_cur, v_idxs);
         }

         // the row becomes a single element
@@ -859,10 +859,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
         //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);

         // we broadcast the KV indices n_embd_v_gqa times
-        // v [1, n_kv, n_embd_v_gqa]
-        // v_cur [1, n_tokens, n_embd_v_gqa]
-        // kv_idxs [n_tokens, 1, 1]
-        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
+        // v [1, n_kv, n_embd_v_gqa]
+        // v_cur [1, n_tokens, n_embd_v_gqa]
+        // v_idxs [n_tokens, 1, 1]
+        return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
     }

     // TODO: fallback to old ggml_cpy() method for backwards compatibility
@@ -885,7 +885,42 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }

-void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(k_idxs);
+
+    return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    const uint32_t n_tokens = ubatch.n_tokens;
+
+    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+    ggml_set_input(v_idxs);
+
+    return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+    if (!supports_set_rows) {
+        return;
+    }
+
+    const uint32_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    int64_t * data = (int64_t *) dst->data;
+
+    for (int64_t i = 0; i < n_tokens; ++i) {
+        data[i] = sinfo.idxs[i];
+    }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
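The eval-time half of the same pattern, again only a sketch and not the llama.cpp implementation: the two I64 graph inputs are filled independently on the host, which is what the new set_input_k_idxs()/set_input_v_idxs() above do from sinfo.idxs. The fill_idxs() helper and the slots[] array below are hypothetical stand-ins:

// sketch of the eval-time fill; fill_idxs() and slots[] are made-up stand-ins
// for set_input_k_idxs()/set_input_v_idxs() and sinfo.idxs
#include "ggml.h"

#include <cstdint>
#include <cstdio>

// dst: I64 tensor of shape [n_tokens] that was marked with ggml_set_input()
static void fill_idxs(ggml_tensor * dst, const int64_t * slots, int64_t n_tokens) {
    int64_t * data = (int64_t *) dst->data; // assumes a host-visible buffer
    for (int64_t i = 0; i < n_tokens; ++i) {
        data[i] = slots[i];
    }
}

int main() {
    ggml_init_params params = { /*.mem_size =*/ 1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_tokens = 4;
    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);

    const int64_t slots[4] = {7, 8, 9, 10}; // cache slots chosen for the ubatch (made up)
    fill_idxs(k_idxs, slots, n_tokens);
    fill_idxs(v_idxs, slots, n_tokens); // free to diverge from k_idxs now that they are separate

    printf("k_idxs[0] = %lld, v_idxs[0] = %lld\n",
           (long long) ((int64_t *) k_idxs->data)[0],
           (long long) ((int64_t *) v_idxs->data)[0]);

    ggml_free(ctx);
    return 0;
}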
@@ -1906,20 +1941,32 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
     return kv->get_v(ctx, il, n_kv);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_k(ctx, k_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
 }

-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const {
-    return kv->cpy_v(ctx, v_cur, kv_idxs, il, sinfos[i_cur]);
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+    return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }

+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_k_idxs(ctx, ubatch);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+    return kv->build_input_v_idxs(ctx, ubatch);
+}
+
 void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }

-void llama_kv_cache_unified_context::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
-    kv->set_input_kv_idxs(dst, ubatch, sinfos[i_cur]);
+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+    kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
 }

 void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
@@ -124,8 +124,8 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

     //
     // preparation API
@@ -146,10 +146,15 @@ public:
     void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

     //
-    // set_input API
+    // input API
     //

-    void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -286,12 +291,16 @@ public:
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

     // store k_cur and v_cur in the cache based on the provided head location
-    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * kv_idxs, int32_t il) const;
-    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * kv_idxs, int32_t il) const;
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;

-    void set_input_k_shift(ggml_tensor * dst) const;
+    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

-    void set_input_kv_idxs (ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+
+    void set_input_k_shift (ggml_tensor * dst) const;
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;