feat: Allow custom layer filters for hybrid recurrent

This should help support architectures like Falcon H1, where the set of
layers that need an attention cache overlaps with the set that need a
recurrent cache.

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Author: Gabe Goodhart
Date:   2025-06-11 13:41:52 -06:00
Commit: d8c929ff5d (parent d5d7628b5f)

2 changed files with 46 additions and 32 deletions


@@ -10,25 +10,30 @@
 llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     const llama_model & model,
     /* attn */
     ggml_type attn_type_k,
     ggml_type attn_type_v,
     bool attn_v_trans,
     uint32_t attn_kv_size,
     uint32_t attn_n_pad,
     uint32_t attn_n_swa,
     llama_swa_type attn_swa_type,
     /* recurrent */
     ggml_type recurrent_type_k,
     ggml_type recurrent_type_v,
     uint32_t recurrent_kv_size,
     /* common */
     uint32_t n_seq_max,
-    bool offload) :
+    bool offload,
+    /* layer filters */
+    layer_filter_cb && attn_filter,
+    layer_filter_cb && recurrent_filter) :
     hparams(model.hparams),
     kv_attn(new llama_kv_cache_unified(
         model,
-        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_filter == nullptr ?
+            [&](int32_t il) { return !model.hparams.recurrent_layer(il); }
+            : attn_filter,
         attn_type_k,
         attn_type_v,
         attn_v_trans,
@@ -41,7 +46,9 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     )),
     kv_recurrent(new llama_kv_cache_recurrent(
         model,
-        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_filter == nullptr ?
+            [&](int32_t il) { return model.hparams.recurrent_layer(il); }
+            : recurrent_filter,
         recurrent_type_k,
         recurrent_type_v,
         offload,
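
The implementation side keeps the old behaviour as the default: when a filter is
nullptr, the constructor falls back to the original hparams.recurrent_layer(il)
split. Below is a minimal, self-contained sketch of that fallback pattern; the
names (pick_filter, the il % 2 predicate) are illustrative only and are not part
of the llama.cpp sources. It also shows why the ternary is well-formed: the
default lambda converts to layer_filter_cb, so both branches share a common type.

#include <cstdint>
#include <cstdio>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t il)>;

// assumed helper, not from llama.cpp: use the caller's filter, or fall back
// to a default predicate when none was supplied
static layer_filter_cb pick_filter(layer_filter_cb && filter) {
    return filter == nullptr
        ? [](int32_t il) { return il % 2 == 0; }  // hypothetical default split
        : std::move(filter);
}

int main() {
    layer_filter_cb def    = pick_filter(nullptr);                       // keeps the default
    layer_filter_cb custom = pick_filter([](int32_t) { return true; });  // caller override
    std::printf("layer 1: default=%d custom=%d\n", (int) def(1), (int) custom(1));
}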


@@ -19,23 +19,30 @@
 class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
         /* attn */
         ggml_type attn_type_k,
         ggml_type attn_type_v,
         bool attn_v_trans,
         uint32_t attn_kv_size,
         uint32_t attn_n_pad,
         uint32_t attn_n_swa,
         llama_swa_type attn_swa_type,
         /* recurrent */
         ggml_type recurrent_type_k,
         ggml_type recurrent_type_v,
         uint32_t recurrent_kv_size,
         /* common */
         uint32_t n_seq_max,
-        bool offload);
+        bool offload,
+        /* layer filters */
+        layer_filter_cb && attn_filter      = nullptr,
+        layer_filter_cb && recurrent_filter = nullptr);

     ~llama_kv_cache_hybrid_recurrent() = default;
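
The declaration defaults both filters to nullptr, so existing call sites compile
unchanged. The following is a rough, self-contained sketch of how an overlapping
architecture such as Falcon H1 could pass filters that admit every layer into
both caches; hybrid_cache_sketch and its default predicates are illustrative
stand-ins, not the actual llama.cpp types or call path.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

using layer_filter_cb = std::function<bool(int32_t il)>;

// illustrative stand-in for the hybrid cache: it only records which layers
// each sub-cache would cover, to show how the defaulted filters behave
struct hybrid_cache_sketch {
    hybrid_cache_sketch(
            int32_t n_layer,
            layer_filter_cb && attn_filter      = nullptr,
            layer_filter_cb && recurrent_filter = nullptr) {
        // stand-ins for the !recurrent_layer(il) / recurrent_layer(il) defaults
        layer_filter_cb attn = attn_filter      == nullptr ? [](int32_t il) { return il % 2 == 0; } : std::move(attn_filter);
        layer_filter_cb rec  = recurrent_filter == nullptr ? [](int32_t il) { return il % 2 != 0; } : std::move(recurrent_filter);
        for (int32_t il = 0; il < n_layer; ++il) {
            if (attn(il)) { attn_layers.push_back(il); }
            if (rec(il))  { recurrent_layers.push_back(il); }
        }
    }

    std::vector<int32_t> attn_layers;
    std::vector<int32_t> recurrent_layers;
};

int main() {
    // default filters: layers are split between the two caches
    hybrid_cache_sketch split(4);

    // Falcon-H1-style overlap: every layer goes into both caches
    hybrid_cache_sketch overlap(4,
        [](int32_t) { return true; },
        [](int32_t) { return true; });

    std::printf("split: attn=%zu recurrent=%zu  overlap: attn=%zu recurrent=%zu\n",
        split.attn_layers.size(), split.recurrent_layers.size(),
        overlap.attn_layers.size(), overlap.recurrent_layers.size());
}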