feat: Allow custom layer filters for hybrid recurrent
This should help support architectures like Falcon H1 where there is overlap between layers that need attention and recurrent caches.

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
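As a rough illustration of what the new filters enable, here is a minimal, self-contained C++ sketch. Only the layer_filter_cb alias mirrors the header change in the diff below; the fake_hparams struct and the "every layer gets both caches" policy are hypothetical stand-ins for a Falcon-H1-style model, not code from this commit or from llama.cpp.

// Hypothetical, standalone sketch -- fake_hparams and the overlap policy below
// are illustrative only; layer_filter_cb matches the alias added in this commit.
#include <cstdint>
#include <cstdio>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t il)>;

// stand-in for llama_hparams::recurrent_layer(il): pretend even layers are recurrent
struct fake_hparams {
    bool recurrent_layer(int32_t il) const { return il % 2 == 0; }
};

int main() {
    fake_hparams hparams;

    // default behavior (what the constructor falls back to when the caller
    // passes nullptr): attention and recurrent layers are disjoint
    layer_filter_cb default_attn      = [&](int32_t il) { return !hparams.recurrent_layer(il); };
    layer_filter_cb default_recurrent = [&](int32_t il) { return  hparams.recurrent_layer(il); };

    // custom behavior for a Falcon-H1-style model: every layer needs both an
    // attention cache and a recurrent cache, so the two filters overlap
    layer_filter_cb custom_attn      = [](int32_t /*il*/) { return true; };
    layer_filter_cb custom_recurrent = [](int32_t /*il*/) { return true; };

    // same selection logic the constructor uses: prefer the caller's filter,
    // otherwise fall back to the default disjoint split
    layer_filter_cb attn_filter      = custom_attn      == nullptr ? default_attn      : custom_attn;
    layer_filter_cb recurrent_filter = custom_recurrent == nullptr ? default_recurrent : custom_recurrent;

    for (int32_t il = 0; il < 4; ++il) {
        std::printf("layer %d: attn=%d recurrent=%d\n", (int) il, (int) attn_filter(il), (int) recurrent_filter(il));
    }
    return 0;
}

With the commit applied, a hybrid model could pass such callbacks as the trailing attn_filter/recurrent_filter constructor arguments; leaving them as nullptr keeps the previous mutually exclusive split on hparams.recurrent_layer(il).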
@@ -24,11 +24,16 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     uint32_t recurrent_kv_size,
     /* common */
     uint32_t n_seq_max,
-    bool offload) :
+    bool offload,
+    /* layer filters */
+    layer_filter_cb && attn_filter,
+    layer_filter_cb && recurrent_filter) :
     hparams(model.hparams),
     kv_attn(new llama_kv_cache_unified(
         model,
-        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_filter == nullptr ?
+            [&](int32_t il) { return !model.hparams.recurrent_layer(il); }
+            : attn_filter,
         attn_type_k,
         attn_type_v,
         attn_v_trans,
@@ -41,7 +46,9 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     )),
     kv_recurrent(new llama_kv_cache_recurrent(
         model,
-        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_filter == nullptr ?
+            [&](int32_t il) { return model.hparams.recurrent_layer(il); }
+            : recurrent_filter,
         recurrent_type_k,
         recurrent_type_v,
         offload,

@@ -19,6 +19,10 @@

 class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
         /* attn */
@@ -35,7 +39,10 @@ public:
     uint32_t recurrent_kv_size,
     /* common */
     uint32_t n_seq_max,
-    bool offload);
+    bool offload,
+    /* layer filters */
+    layer_filter_cb && attn_filter = nullptr,
+    layer_filter_cb && recurrent_filter = nullptr);

     ~llama_kv_cache_hybrid_recurrent() = default;
