From d8c929ff5d8879c3faaf0a7aac9b8955bb9dfa26 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Wed, 11 Jun 2025 13:41:52 -0600
Subject: [PATCH] feat: Allow custom layer filters for hybrid recurrent

This should help support architectures like Falcon H1 where there is
overlap between layers that need attention and recurrent caches.

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart
---
 src/llama-kv-cache-hybrid-recurrent.cpp | 41 +++++++++++++++----------
 src/llama-kv-cache-hybrid-recurrent.h   | 37 +++++++++++++---------
 2 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp
index 8871dbf63..889f43025 100644
--- a/src/llama-kv-cache-hybrid-recurrent.cpp
+++ b/src/llama-kv-cache-hybrid-recurrent.cpp
@@ -10,25 +10,30 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     const llama_model & model,
-    /* attn */
-    ggml_type attn_type_k,
-    ggml_type attn_type_v,
-    bool attn_v_trans,
-    uint32_t attn_kv_size,
-    uint32_t attn_n_pad,
-    uint32_t attn_n_swa,
-    llama_swa_type attn_swa_type,
-    /* recurrent */
-    ggml_type recurrent_type_k,
-    ggml_type recurrent_type_v,
-    uint32_t recurrent_kv_size,
-    /* common */
-    uint32_t n_seq_max,
-    bool offload) :
+    /* attn */
+    ggml_type attn_type_k,
+    ggml_type attn_type_v,
+    bool attn_v_trans,
+    uint32_t attn_kv_size,
+    uint32_t attn_n_pad,
+    uint32_t attn_n_swa,
+    llama_swa_type attn_swa_type,
+    /* recurrent */
+    ggml_type recurrent_type_k,
+    ggml_type recurrent_type_v,
+    uint32_t recurrent_kv_size,
+    /* common */
+    uint32_t n_seq_max,
+    bool offload,
+    /* layer filters */
+    layer_filter_cb && attn_filter,
+    layer_filter_cb && recurrent_filter) :
     hparams(model.hparams),
     kv_attn(new llama_kv_cache_unified(
         model,
-        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_filter == nullptr ?
+            [&](int32_t il) { return !model.hparams.recurrent_layer(il); }
+            : attn_filter,
         attn_type_k,
         attn_type_v,
         attn_v_trans,
@@ -41,7 +46,9 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     )),
     kv_recurrent(new llama_kv_cache_recurrent(
         model,
-        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_filter == nullptr ?
+            [&](int32_t il) { return model.hparams.recurrent_layer(il); }
+            : recurrent_filter,
         recurrent_type_k,
         recurrent_type_v,
         offload,
diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h
index 8728fd733..444e87e10 100644
--- a/src/llama-kv-cache-hybrid-recurrent.h
+++ b/src/llama-kv-cache-hybrid-recurrent.h
@@ -19,23 +19,30 @@
 class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
-        /* attn */
-        ggml_type attn_type_k,
-        ggml_type attn_type_v,
-        bool attn_v_trans,
-        uint32_t attn_kv_size,
-        uint32_t attn_n_pad,
-        uint32_t attn_n_swa,
-        llama_swa_type attn_swa_type,
-        /* recurrent */
-        ggml_type recurrent_type_k,
-        ggml_type recurrent_type_v,
-        uint32_t recurrent_kv_size,
-        /* common */
-        uint32_t n_seq_max,
-        bool offload);
+        /* attn */
+        ggml_type attn_type_k,
+        ggml_type attn_type_v,
+        bool attn_v_trans,
+        uint32_t attn_kv_size,
+        uint32_t attn_n_pad,
+        uint32_t attn_n_swa,
+        llama_swa_type attn_swa_type,
+        /* recurrent */
+        ggml_type recurrent_type_k,
+        ggml_type recurrent_type_v,
+        uint32_t recurrent_kv_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        /* layer filters */
+        layer_filter_cb && attn_filter      = nullptr,
+        layer_filter_cb && recurrent_filter = nullptr);
 
     ~llama_kv_cache_hybrid_recurrent() = default;
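
Usage note (not part of the patch): a minimal sketch of how a caller might pass
explicit layer filters for a Falcon-H1-style model where every layer participates
in both the attention and the recurrent cache. The call site, cache sizes, and
ggml types below are illustrative assumptions, not values taken from llama.cpp;
only the constructor signature comes from this patch.

    // hypothetical call site; `model` is a const llama_model & in scope,
    // and all numeric values are placeholders
    llama_kv_cache_hybrid_recurrent * kv = new llama_kv_cache_hybrid_recurrent(
        model,
        /* attn */
        GGML_TYPE_F16,        // attn_type_k
        GGML_TYPE_F16,        // attn_type_v
        false,                // attn_v_trans
        4096,                 // attn_kv_size
        32,                   // attn_n_pad
        0,                    // attn_n_swa
        LLAMA_SWA_TYPE_NONE,  // attn_swa_type
        /* recurrent */
        GGML_TYPE_F32,        // recurrent_type_k
        GGML_TYPE_F32,        // recurrent_type_v
        4,                    // recurrent_kv_size
        /* common */
        4,                    // n_seq_max
        true,                 // offload
        /* layer filters: both caches cover every layer, unlike the defaults */
        [](int32_t /*il*/) { return true; },   // attn_filter
        [](int32_t /*il*/) { return true; });  // recurrent_filter

Passing nullptr (the default) for either filter falls back to the existing
hparams.recurrent_layer(il) split, so current callers are unaffected.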