feat: Allow custom layer filters for hybrid recurrent

This should help support architectures like Falcon H1, where the set of layers
that needs an attention cache overlaps with the set that needs a recurrent cache.

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Gabe Goodhart
2025-06-11 13:41:52 -06:00
parent d5d7628b5f
commit d8c929ff5d
2 changed files with 46 additions and 32 deletions
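
A minimal sketch of how the new filters can be used (illustrative only, not taken from this commit): when no filters are passed, the hybrid cache keeps its old behavior and splits layers by hparams.recurrent_layer(il); an architecture like Falcon H1, where a layer can carry both an attention block and a recurrent block, could instead pass filters that accept every layer for both child caches. Assuming a llama_model named model is in scope:

    // Default split (mirrors the fallback lambdas added in this commit):
    // attention cache takes the non-recurrent layers, recurrent cache takes the recurrent ones.
    auto default_attn_filter      = [&](int32_t il) { return !model.hparams.recurrent_layer(il); };
    auto default_recurrent_filter = [&](int32_t il) { return  model.hparams.recurrent_layer(il); };

    // Hypothetical Falcon-H1-style overlap: every layer appears in both caches.
    auto overlap_attn_filter      = [](int32_t /*il*/) { return true; };
    auto overlap_recurrent_filter = [](int32_t /*il*/) { return true; };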

View File

@@ -10,25 +10,30 @@
 llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
-                    /* attn */
-          ggml_type attn_type_k,
-          ggml_type attn_type_v,
-               bool attn_v_trans,
-           uint32_t attn_kv_size,
-           uint32_t attn_n_pad,
-           uint32_t attn_n_swa,
-     llama_swa_type attn_swa_type,
-                    /* recurrent */
-          ggml_type recurrent_type_k,
-          ggml_type recurrent_type_v,
-           uint32_t recurrent_kv_size,
-                    /* common */
-           uint32_t n_seq_max,
-               bool offload) :
+                        /* attn */
+              ggml_type attn_type_k,
+              ggml_type attn_type_v,
+                   bool attn_v_trans,
+               uint32_t attn_kv_size,
+               uint32_t attn_n_pad,
+               uint32_t attn_n_swa,
+         llama_swa_type attn_swa_type,
+                        /* recurrent */
+              ggml_type recurrent_type_k,
+              ggml_type recurrent_type_v,
+               uint32_t recurrent_kv_size,
+                        /* common */
+               uint32_t n_seq_max,
+                   bool offload,
+                        /* layer filters */
+     layer_filter_cb && attn_filter,
+     layer_filter_cb && recurrent_filter) :
     hparams(model.hparams),
     kv_attn(new llama_kv_cache_unified(
         model,
-        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_filter == nullptr ?
+            [&](int32_t il) { return !model.hparams.recurrent_layer(il); }
+            : attn_filter,
         attn_type_k,
         attn_type_v,
         attn_v_trans,
@@ -41,7 +46,9 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     )),
     kv_recurrent(new llama_kv_cache_recurrent(
         model,
-        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_filter == nullptr ?
+            [&](int32_t il) { return model.hparams.recurrent_layer(il); }
+            : recurrent_filter,
         recurrent_type_k,
         recurrent_type_v,
         offload,
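
Both child caches receive the (possibly defaulted) filter through their existing layer_filter_cb constructor parameter. Roughly, a cache only allocates per-layer state for the layers its filter accepts; the loop below is a simplified, assumed sketch of that pattern, not the actual body of llama_kv_cache_unified or llama_kv_cache_recurrent:

    // Simplified sketch of how a layer filter is typically consumed (assumed shape):
    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        if (filter && !filter(il)) {
            continue; // this layer is owned by the other cache (or needs no cache at all)
        }
        // ... create the per-layer K/V tensors (attention) or recurrent state for layer il ...
    }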

View File

@@ -19,23 +19,30 @@
 class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
     llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
-                    /* attn */
-          ggml_type attn_type_k,
-          ggml_type attn_type_v,
-               bool attn_v_trans,
-           uint32_t attn_kv_size,
-           uint32_t attn_n_pad,
-           uint32_t attn_n_swa,
-     llama_swa_type attn_swa_type,
-                    /* recurrent */
-          ggml_type recurrent_type_k,
-          ggml_type recurrent_type_v,
-           uint32_t recurrent_kv_size,
-                    /* common */
-           uint32_t n_seq_max,
-               bool offload);
+                        /* attn */
+              ggml_type attn_type_k,
+              ggml_type attn_type_v,
+                   bool attn_v_trans,
+               uint32_t attn_kv_size,
+               uint32_t attn_n_pad,
+               uint32_t attn_n_swa,
+         llama_swa_type attn_swa_type,
+                        /* recurrent */
+              ggml_type recurrent_type_k,
+              ggml_type recurrent_type_v,
+               uint32_t recurrent_kv_size,
+                        /* common */
+               uint32_t n_seq_max,
+                   bool offload,
+                        /* layer filters */
+     layer_filter_cb && attn_filter = nullptr,
+     layer_filter_cb && recurrent_filter = nullptr);
     ~llama_kv_cache_hybrid_recurrent() = default;
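
Since both new parameters default to nullptr, existing call sites keep compiling unchanged. A hypothetical call that overrides the filters might look like the sketch below; all argument values are illustrative placeholders (and a llama_model named model is assumed in scope), not values prescribed by this commit:

    // Hypothetical usage sketch -- placeholder values only.
    llama_kv_cache_hybrid_recurrent cache(
        model,
        /* attn */
        GGML_TYPE_F16,        // attn_type_k
        GGML_TYPE_F16,        // attn_type_v
        false,                // attn_v_trans
        4096,                 // attn_kv_size
        32,                   // attn_n_pad
        0,                    // attn_n_swa
        LLAMA_SWA_TYPE_NONE,  // attn_swa_type
        /* recurrent */
        GGML_TYPE_F32,        // recurrent_type_k
        GGML_TYPE_F32,        // recurrent_type_v
        1,                    // recurrent_kv_size
        /* common */
        1,                    // n_seq_max
        true,                 // offload
        /* layer filters: e.g. every layer in both caches, as in a Falcon-H1-style model */
        [](int32_t /*il*/) { return true; },
        [](int32_t /*il*/) { return true; });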