diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index d1a92045d..9e4616e64 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1001,10 +1001,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); - const auto n_kv = inp->mctx->get_attn()->get_n_kv(); + const auto n_kv = inp->mctx->get_attn()->get_n_kv(); + const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1; - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1206,14 +1206,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1; inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens); ggml_set_input(inp->self_kv_idxs); inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs); - //cb(inp->self_kq_mask, "KQ_mask", -1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1440,14 +1439,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif auto inp = std::make_unique(hparams, cparams, mctx_cur); + const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1; + { const auto n_kv = mctx_cur->get_base()->get_n_kv(); inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens); ggml_set_input(inp->self_kv_idxs); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1461,8 +1461,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens); ggml_set_input(inp->self_kv_idxs_swa); - inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); + inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;