diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index d83e02811..bb3e6c17c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -404,13 +404,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
-    const llama_hparams & hparams,
-    const llama_cparams & cparams,
-    const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
-}
-
 //
 // llm_graph_context
 //
@@ -1269,7 +1262,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
+    // encapsulated in inp
+    const auto * kv_state = inp->kv_state;
 
     // store to KV cache
     {
@@ -1301,10 +1296,10 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
@@ -1318,25 +1313,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_hybrid_recurrent * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    return build_attn(
-        static_cast<llm_graph_input_attn_kv_unified *>(inp),
-        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
-    );
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1479,13 +1456,17 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_recurrent_state(
-    ggml_cgraph * gf,
-    ggml_tensor * s,
-    ggml_tensor * state_copy,
-    int32_t state_size,
-    int32_t n_seqs,
-    bool avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies,
+        const llama_kv_cache_recurrent_state * kv_state) const {
+
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 5abdfde24..5f5846ab7 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -286,16 +286,6 @@ public:
     const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
-class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
-public:
-    llm_graph_input_attn_kv_hybrid_recurrent(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_kv_cache_hybrid_recurrent_state * kv_state);
-
-    virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
-};
-
 class llm_graph_input_attn_cross : public llm_graph_input_i {
 public:
     llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -585,20 +575,7 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
-
-    ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_hybrid_recurrent * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float kq_scale,
-            int il) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const;
 
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
@@ -620,12 +597,13 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        int32_t state_size,
-        int32_t n_seqs,
-        bool avoid_copies = false) const;
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false,
+            const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             ggml_cgraph * gf,
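
For context, a rough sketch of how a hybrid model's graph builder would drive the updated API: attention layers reuse the plain llm_graph_input_attn_kv_unified input returned by build_attn_inp_kv_hybrid_recurrent() with the generic build_attn() overload, while recurrent layers pass the hybrid cache's child recurrent state through the new kv_state parameter of build_recurrent_state(). This is illustrative only and not part of the patch; the helper name, the placeholder tensor arguments, and the get_state_recurrent() accessor are assumptions.

// Sketch only (not part of this patch); would be compiled inside the llama.cpp tree.
// get_state_recurrent() is an assumed accessor on llama_kv_cache_hybrid_recurrent_state,
// mirroring the get_state_attn() used in the diff above.
#include "llama-graph.h" // plus the internal header defining llama_kv_cache_hybrid_recurrent_state

static ggml_tensor * build_hybrid_layer_sketch(
        const llm_graph_context                     & ctx,
        llm_graph_input_attn_kv_unified             * inp_attn,  // from ctx.build_attn_inp_kv_hybrid_recurrent()
        const llama_kv_cache_hybrid_recurrent_state * kv_hybrid, // the hybrid cache state behind ctx.mstate
        ggml_cgraph * gf,
        ggml_tensor * s,          // recurrent state tensor (placeholder)
        ggml_tensor * state_copy, // state copy tensor (placeholder)
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        int32_t       state_size,
        int32_t       n_seqs,
        float         kq_scale,
        int           il,
        bool          is_recurrent_layer) {
    if (is_recurrent_layer) {
        // recurrent layers: hand the child recurrent state to build_recurrent_state()
        // through the new optional parameter instead of letting it cast mstate itself
        return ctx.build_recurrent_state(gf, s, state_copy, state_size, n_seqs,
                /*avoid_copies=*/false,
                kv_hybrid->get_state_recurrent()); // assumed accessor name
    }
    // attention layers: the shared unified input already carries the child attention
    // state (inp->kv_state), so the generic build_attn() overload is used directly
    return ctx.build_attn(inp_attn, gf, wo, wo_b, q_cur, k_cur, v_cur,
            /*kq_b=*/nullptr, /*v_mla=*/nullptr, kq_scale, il);
}

Defaulting the new kv_state parameter to nullptr keeps existing recurrent-only callers unchanged: they still fall back to casting mstate, while hybrid callers opt in explicitly.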