diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 25ab044dc..7131855f3 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -281,43 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { - if (self_k_idxs) { - mctx->set_input_k_idxs(self_k_idxs, ubatch); - } + mctx->set_input_k_idxs(self_k_idxs, ubatch); + mctx->set_input_v_idxs(self_v_idxs, ubatch); - if (self_v_idxs) { - mctx->set_input_v_idxs(self_v_idxs, ubatch); - } - - if (self_kq_mask) { - mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { - if (self_k_idxs) { - mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - } + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - if (self_v_idxs) { - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - } + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - if (self_k_idxs_swa) { - mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - } + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - if (self_v_idxs_swa) { - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - } - - if (self_kq_mask) { - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } - - if (self_kq_mask_swa) { - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); - } + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { @@ -357,17 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { - if (self_k_idxs) { - mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch); - } + mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch); - if (self_v_idxs) { - mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch); - } - - if (self_kq_mask) { - mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); - } + mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); const int64_t n_rs = mctx->get_recr()->get_n_rs();