Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-07-05 21:43:45 +00:00)
graph : support iSWA virtual sequences
ggml-ci
@@ -1002,9 +1002,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
+    inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
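The substance of this change is the mask shape: with cparams.n_seq_virt > 1 the ubatch is split evenly across virtual sequences, so the KQ mask goes from one flat [n_kv, n_tokens] matrix to a stack of per-sequence slices [n_kv, n_tokens/n_seqs, n_seqs]. Below is a minimal host-side sketch of how such a 3D mask buffer could be laid out and filled; the function name, the simple causal rule, and the assumption that KV cell j holds position j are illustrative, not the repository's actual fill logic.

// Illustration only: layout of a mask buffer with ne = {n_kv, n_tok_per_seq, n_seqs},
// dimension 0 (n_kv) contiguous, as in ggml's default layout.
#include <cstdint>
#include <limits>
#include <vector>

static void fill_kq_mask(std::vector<float> & mask,
                         int64_t n_kv, int64_t n_tok_per_seq, int64_t n_seqs) {
    const float neg_inf = -std::numeric_limits<float>::infinity();

    mask.resize(n_kv * n_tok_per_seq * n_seqs);

    for (int64_t s = 0; s < n_seqs; ++s) {            // virtual sequence
        for (int64_t i = 0; i < n_tok_per_seq; ++i) { // token within this sequence's slice
            for (int64_t j = 0; j < n_kv; ++j) {      // KV cell
                // assumed causal rule: cell j is visible if it is not in token i's future
                const bool visible = j <= i;

                // flat index for ne = {n_kv, n_tok_per_seq, n_seqs}
                const int64_t idx = j + i*n_kv + s*n_kv*n_tok_per_seq;

                mask[idx] = visible ? 0.0f : neg_inf;
            }
        }
    }
}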
@@ -1213,7 +1213,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     ggml_set_input(inp->self_kv_idxs);
 
     inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1440,14 +1439,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
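The iSWA builder keeps two masks with the same 3D layout: self_kq_mask over the base (non-SWA) cache above, and self_kq_mask_swa over the sliding-window cache in the next hunk. Conceptually the SWA mask adds a window constraint on top of the causal one; a hedged sketch of that visibility predicate follows (the name swa_visible, the parameter n_swa, and the exact boundary convention are assumptions for illustration, not the precise rule used by llama.cpp).

#include <cstdint>

// Sketch: sliding-window visibility layered on top of the causal rule.
static bool swa_visible(int64_t pos_q, int64_t pos_kv, int64_t n_swa) {
    if (pos_kv > pos_q) {
        return false;              // causal: the key must not lie in the query's future
    }
    return pos_q - pos_kv < n_swa; // window: only the most recent n_swa positions stay visible
}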
@@ -1461,8 +1461,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
         inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs_swa);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
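One detail worth noting across all four hunks: the padding is now applied per sequence, i.e. the second mask dimension is GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD) rather than GGML_PAD(n_tokens, GGML_KQ_MASK_PAD). A tiny standalone example of what that rounding does (the PAD macro replicates GGML_PAD's round-up-to-a-multiple behaviour, and 64 is an assumed value for GGML_KQ_MASK_PAD):

#include <cstdint>
#include <cstdio>

// Round x up to a multiple of n (same effect as ggml's GGML_PAD for positive n).
#define PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const int64_t n_tokens = 96; // tokens in the ubatch
    const int64_t n_seqs   = 4;  // virtual sequences
    const int64_t kq_pad   = 64; // assumed GGML_KQ_MASK_PAD

    // per-sequence mask rows, padded: 96/4 = 24 -> 64
    printf("%lld\n", (long long) PAD(n_tokens/n_seqs, kq_pad));

    return 0;
}

The F32-to-F16 cast guarded by cparams.flash_attn (self_kq_mask_cnv / self_kq_mask_swa_cnv) is unchanged by this commit: the mask is still filled in F32 and only converted when the flash-attention path is used.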