cont : remove redundant ifs

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-30 14:44:09 +03:00
parent 3d930a9e4f
commit 82277da415

View File

@ -281,43 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
}
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);
if (self_v_idxs) {
mctx->set_input_v_idxs(self_v_idxs, ubatch);
}
if (self_kq_mask) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
}
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
if (self_v_idxs) {
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
}
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
if (self_k_idxs_swa) {
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
}
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
if (self_v_idxs_swa) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
}
if (self_kq_mask) {
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
if (self_kq_mask_swa) {
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@ -357,17 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_k_idxs) {
mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
}
mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
if (self_v_idxs) {
mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
}
if (self_kq_mask) {
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
const int64_t n_rs = mctx->get_recr()->get_n_rs();