fix: Fix logic for initializing inputs and attn layers for hybrid caches

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart
2025-06-04 15:02:14 -06:00
parent e3c1631556
commit a9b5fe98ad
2 changed files with 25 additions and 66 deletions
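Taken together, the two files below drop the llm_graph_input_attn_kv_hybrid_recurrent wrapper and its forwarding build_attn() overload: build_attn_inp_kv_hybrid_recurrent() now constructs a plain llm_graph_input_attn_kv_unified from the hybrid state's attention child, and build_attn() reads the cache state captured in the input object instead of re-casting mstate, which for hybrid caches is the parent hybrid state rather than the unified one. A minimal self-contained sketch of that pattern with simplified stand-in types (not the real llama.cpp classes):

// Sketch only: stand-in types illustrating "the input object carries the exact
// cache state it was built from", so build_attn() never down-casts mstate.
#include <iostream>

struct kv_cache_unified_state {              // stand-in for llama_kv_cache_unified_state
    int n_kv = 32;
};

struct kv_cache_recurrent_state {            // stand-in for llama_kv_cache_recurrent_state
    int head = 0;
};

// stand-in for llama_kv_cache_hybrid_recurrent_state: owns one child state per sub-cache
struct kv_cache_hybrid_recurrent_state {
    kv_cache_unified_state   attn;
    kv_cache_recurrent_state recr;

    const kv_cache_unified_state * get_state_attn() const { return &attn; }
};

// stand-in for llm_graph_input_attn_kv_unified: remembers the attention state it was built from
struct input_attn_kv_unified {
    const kv_cache_unified_state * kv_state;

    explicit input_attn_kv_unified(const kv_cache_unified_state * s) : kv_state(s) {}
};

// stand-in for llm_graph_context::build_attn(): uses inp->kv_state and never re-casts
// the opaque memory state, so the same code path serves plain unified and hybrid caches
void build_attn(const input_attn_kv_unified * inp) {
    const auto * kv_state = inp->kv_state;   // not static_cast<...>(mstate)
    std::cout << "attention sees n_kv = " << kv_state->n_kv << "\n";
}

int main() {
    kv_cache_hybrid_recurrent_state hybrid;

    // equivalent of build_attn_inp_kv_hybrid_recurrent(): build a plain unified input
    // from the hybrid state's attention child
    input_attn_kv_unified inp(hybrid.get_state_attn());

    build_attn(&inp);
    return 0;
}

Because the input object already carries the right child state, no hybrid-specific subclass or build_attn() overload is needed, which is what the deletions below remove.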

View File

@@ -404,13 +404,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
-}
-
 //
 // llm_graph_context
 //
@@ -1269,7 +1262,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
+    // encapsulated in inp
+    const auto * kv_state = inp->kv_state;
 
     // store to KV cache
     {
@@ -1301,10 +1296,10 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
@@ -1318,25 +1313,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_hybrid_recurrent * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    return build_attn(
-        static_cast<llm_graph_input_attn_kv_unified *>(inp),
-        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
-    );
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1484,8 +1461,12 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
         ggml_tensor * state_copy,
         int32_t state_size,
         int32_t n_seqs,
-        bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        bool avoid_copies,
+        const llama_kv_cache_recurrent_state * kv_state) const {
+
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
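The last hunk above makes the recurrent-state lookup overridable: build_recurrent_state() now accepts an optional explicit llama_kv_cache_recurrent_state and only falls back to casting mstate when the caller passes nothing, which is what lets a hybrid model hand in the recurrent child of its hybrid cache state. A minimal sketch of that fallback pattern with simplified stand-in types (the get_state_recurrent() accessor is an assumption here; only get_state_attn() appears in this diff):

// Sketch only: stand-in types for the "optional explicit state, nullptr falls back
// to mstate" pattern introduced above.
#include <iostream>

struct memory_state_i { virtual ~memory_state_i() = default; };   // opaque mstate base, stand-in
struct recurrent_state : memory_state_i { int head = 7; };         // llama_kv_cache_recurrent_state stand-in

// hybrid parent whose recurrent half is a child object; get_state_recurrent() is an
// assumed accessor name, mirroring the get_state_attn() seen in this commit
struct hybrid_recurrent_state : memory_state_i {
    recurrent_state recr;
    const recurrent_state * get_state_recurrent() const { return &recr; }
};

struct graph_context {
    const memory_state_i * mstate;   // may point at a plain recurrent state or a hybrid parent

    // mirrors the new signature: the explicit state is optional and nullptr means
    // "derive it from mstate", which is exactly what the old code always did
    void build_recurrent_state(const recurrent_state * kv_state = nullptr) const {
        if (kv_state == nullptr) {
            // legacy fallback: only valid when mstate really is the recurrent state
            kv_state = static_cast<const recurrent_state *>(mstate);
        }
        std::cout << "recurrent head = " << kv_state->head << "\n";
    }
};

int main() {
    // pure recurrent model: rely on the nullptr fallback
    recurrent_state recr;
    graph_context ctx_recurrent{&recr};
    ctx_recurrent.build_recurrent_state();

    // hybrid model: mstate is the hybrid parent, so the recurrent child is passed explicitly
    hybrid_recurrent_state hybrid;
    graph_context ctx_hybrid{&hybrid};
    ctx_hybrid.build_recurrent_state(hybrid.get_state_recurrent());

    return 0;
}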

View File

@@ -286,16 +286,6 @@ public:
     const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
-class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
-public:
-    llm_graph_input_attn_kv_hybrid_recurrent(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_kv_cache_hybrid_recurrent_state * kv_state);
-
-    virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
-};
-
 class llm_graph_input_attn_cross : public llm_graph_input_i {
 public:
     llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -585,20 +575,7 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
-
-    ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_hybrid_recurrent * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale,
-            int il) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const;
 
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
@@ -625,7 +602,8 @@ struct llm_graph_context {
             ggml_tensor * state_copy,
             int32_t state_size,
             int32_t n_seqs,
-            bool avoid_copies = false) const;
+            bool avoid_copies = false,
+            const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             ggml_cgraph * gf,
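Note that the defaults for the new arguments (avoid_copies = false, kv_state = nullptr) live on the declarations in this header, while the definition in the source file above repeats the parameters without them; C++ does not allow the same default argument to be respecified on the out-of-line definition. A small stand-alone illustration of that split, using made-up types rather than the llama.cpp ones:

// Sketch only: hypothetical builder type showing header-vs-source placement of defaults.
#include <iostream>

struct kv_state { int n_kv = 4; };   // hypothetical stand-in type

// --- declaration, as it would appear in the header: defaults live here ---
struct builder {
    int build(int n_seqs, bool avoid_copies = false, const kv_state * st = nullptr) const;
};

// --- out-of-line definition, as it would appear in the .cpp: no "= false" / "= nullptr" ---
int builder::build(int n_seqs, bool avoid_copies, const kv_state * st) const {
    const int extra = (st != nullptr) ? st->n_kv : 0;
    return avoid_copies ? n_seqs : n_seqs + extra;
}

int main() {
    builder b;
    kv_state s;
    std::cout << b.build(2) << "\n";             // both defaults apply -> 2
    std::cout << b.build(2, false, &s) << "\n";  // explicit state, like the hybrid caller -> 6
    return 0;
}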