feat: Support hybrid recurrent in llama-graph

NOTE: I intentionally did not add support for s_mask since it will be going
away soon

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-06-04 08:47:55 -06:00
parent cf03d4ae5c
commit e3c1631556
2 changed files with 78 additions and 4 deletions

View File

@@ -7,6 +7,7 @@
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-kv-cache-hybrid-recurrent.h"
#include <cassert>
#include <cmath>
@@ -403,6 +404,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
}
llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
}
//
// llm_graph_context
//
@@ -961,8 +969,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
return cur;
}
ggml_tensor * llm_graph_context::build_inp_s_copy() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
if (kv_state == nullptr) {
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
}
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
@@ -1291,6 +1301,44 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = kv_state->get_state_attn()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_hybrid_recurrent * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
return build_attn(
static_cast<llm_graph_input_attn_kv_unified *>(inp),
gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
);
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);

View File

@@ -22,6 +22,7 @@ struct llama_memory_state_i;
class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state;
class llama_kv_cache_recurrent_state;
class llama_kv_cache_hybrid_recurrent_state;
// certain models (typically multi-modal) can produce different types of graphs
enum llm_graph_type {
@@ -242,7 +243,7 @@ public:
cparams(cparams),
kv_state(kv_state) {
}
~llm_graph_input_attn_kv_unified() = default;
virtual ~llm_graph_input_attn_kv_unified() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -285,6 +286,16 @@ public:
const llama_kv_cache_unified_iswa_state * kv_state;
};
class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
public:
llm_graph_input_attn_kv_hybrid_recurrent(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_hybrid_recurrent_state * kv_state);
virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
};
class llm_graph_input_attn_cross : public llm_graph_input_i {
public:
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -508,7 +519,7 @@ struct llm_graph_context {
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -574,6 +585,21 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_hybrid_recurrent * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(