diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 337fb5cb0..d83e02811 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -7,6 +7,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include <cassert>
 #include <cmath>
@@ -403,6 +404,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 //
 // llm_graph_context
 //
@@ -961,8 +969,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
 
@@ -1291,6 +1301,44 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    return build_attn(
+        static_cast<llm_graph_input_attn_kv_unified *>(inp),
+        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
+    );
+}
+
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
 
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 87813119b..5abdfde24 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -22,6 +22,7 @@ struct llama_memory_state_i;
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
 class llama_kv_cache_recurrent_state;
+class llama_kv_cache_hybrid_recurrent_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -242,7 +243,7 @@ public:
         cparams(cparams),
         kv_state(kv_state) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    virtual ~llm_graph_input_attn_kv_unified() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -285,6 +286,16 @@ public:
     const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
+class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
+public:
+    llm_graph_input_attn_kv_hybrid_recurrent(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_hybrid_recurrent_state * kv_state);
+
+    virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
+};
+
 class llm_graph_input_attn_cross : public llm_graph_input_i {
 public:
     llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -508,7 +519,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -574,6 +585,21 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
+    llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_hybrid_recurrent * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(