feat: Allow custom layer filters for hybrid recurrent

This should help support architectures like Falcon H1 where there is overlap
between layers that need attention and recurrent caches.

https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
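
The filter type involved here (declared in the header diff further below) is a per-layer predicate, std::function<bool(int32_t il)>, returning true when layer il should be handled by the cache in question. As a hypothetical sketch only (the filters below are illustrative, not taken from any real model configuration), an architecture whose layers need both caches could supply something like:

#include <cstdint>
#include <functional>

// Same shape as the layer_filter_cb type introduced in the header below.
using layer_filter_cb = std::function<bool(int32_t il)>;

// Hypothetical filters for a Falcon-H1-style hybrid, where a single layer may
// need both an attention KV cache and a recurrent state cache.
static layer_filter_cb attn_filter      = [](int32_t /*il*/) { return true; };
static layer_filter_cb recurrent_filter = [](int32_t /*il*/) { return true; };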

@@ -10,25 +10,30 @@
 
 llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     const llama_model & model,
-                   /* attn */
-         ggml_type attn_type_k,
-         ggml_type attn_type_v,
-              bool attn_v_trans,
-          uint32_t attn_kv_size,
-          uint32_t attn_n_pad,
-          uint32_t attn_n_swa,
-    llama_swa_type attn_swa_type,
-                   /* recurrent */
-         ggml_type recurrent_type_k,
-         ggml_type recurrent_type_v,
-          uint32_t recurrent_kv_size,
-                   /* common */
-          uint32_t n_seq_max,
-              bool offload) :
+                       /* attn */
+             ggml_type attn_type_k,
+             ggml_type attn_type_v,
+                  bool attn_v_trans,
+              uint32_t attn_kv_size,
+              uint32_t attn_n_pad,
+              uint32_t attn_n_swa,
+        llama_swa_type attn_swa_type,
+                       /* recurrent */
+             ggml_type recurrent_type_k,
+             ggml_type recurrent_type_v,
+              uint32_t recurrent_kv_size,
+                       /* common */
+              uint32_t n_seq_max,
+                  bool offload,
+                       /* layer filters */
+    layer_filter_cb && attn_filter,
+    layer_filter_cb && recurrent_filter) :
     hparams(model.hparams),
     kv_attn(new llama_kv_cache_unified(
         model,
-        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_filter == nullptr ?
+            [&](int32_t il) { return !model.hparams.recurrent_layer(il); }
+            : attn_filter,
         attn_type_k,
         attn_type_v,
         attn_v_trans,
@@ -41,7 +46,9 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
     )),
     kv_recurrent(new llama_kv_cache_recurrent(
         model,
-        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_filter == nullptr ?
+            [&](int32_t il) { return model.hparams.recurrent_layer(il); }
+            : recurrent_filter,
         recurrent_type_k,
         recurrent_type_v,
         offload,
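
When a filter is left as nullptr, the constructor falls back to model.hparams.recurrent_layer(il), so the previous behavior is kept: the attention cache takes the non-recurrent layers and the recurrent cache takes the recurrent ones. A non-null filter simply replaces that split. A minimal sketch of the same select-or-default pattern in isolation (names are illustrative, not llama.cpp API):

#include <cstdint>
#include <functional>
#include <utility>

using layer_filter_cb = std::function<bool(int32_t il)>;

// Illustrative helper mirroring the ternary used above: prefer the caller's
// filter when one is supplied, otherwise fall back to a default predicate.
static layer_filter_cb filter_or_default(layer_filter_cb user, layer_filter_cb fallback) {
    return user == nullptr ? std::move(fallback) : std::move(user);
}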

@@ -19,23 +19,30 @@
 
 class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
-                       /* attn */
-             ggml_type attn_type_k,
-             ggml_type attn_type_v,
-                  bool attn_v_trans,
-              uint32_t attn_kv_size,
-              uint32_t attn_n_pad,
-              uint32_t attn_n_swa,
-        llama_swa_type attn_swa_type,
-                       /* recurrent */
-             ggml_type recurrent_type_k,
-             ggml_type recurrent_type_v,
-              uint32_t recurrent_kv_size,
-                       /* common */
-              uint32_t n_seq_max,
-                  bool offload);
+                           /* attn */
+                 ggml_type attn_type_k,
+                 ggml_type attn_type_v,
+                      bool attn_v_trans,
+                  uint32_t attn_kv_size,
+                  uint32_t attn_n_pad,
+                  uint32_t attn_n_swa,
+            llama_swa_type attn_swa_type,
+                           /* recurrent */
+                 ggml_type recurrent_type_k,
+                 ggml_type recurrent_type_v,
+                  uint32_t recurrent_kv_size,
+                           /* common */
+                  uint32_t n_seq_max,
+                      bool offload,
+                           /* layer filters */
+        layer_filter_cb && attn_filter      = nullptr,
+        layer_filter_cb && recurrent_filter = nullptr);
 
     ~llama_kv_cache_hybrid_recurrent() = default;
 
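
Since both new parameters default to nullptr, existing call sites compile unchanged and keep the old disjoint layer split. A small self-contained example (not llama.cpp code; the layer count and recurrent-layer rule are made up) contrasting the default split with an overlapping custom one:

#include <cstdint>
#include <functional>
#include <iostream>

using layer_filter_cb = std::function<bool(int32_t il)>;

int main() {
    const int32_t n_layer = 4;

    // Stand-in for model.hparams.recurrent_layer(il); purely illustrative.
    auto recurrent_layer = [](int32_t il) { return il % 2 == 1; };

    // Default behavior (filters left as nullptr): each layer belongs to
    // exactly one of the two child caches.
    layer_filter_cb attn_default      = [&](int32_t il) { return !recurrent_layer(il); };
    layer_filter_cb recurrent_default = [&](int32_t il) { return  recurrent_layer(il); };

    // Custom filters (the new capability): both caches may claim a layer,
    // as a Falcon-H1-style architecture would need.
    layer_filter_cb attn_custom      = [](int32_t) { return true; };
    layer_filter_cb recurrent_custom = [](int32_t) { return true; };

    for (int32_t il = 0; il < n_layer; ++il) {
        std::cout << "layer " << il
                  << " | default: attn=" << attn_default(il)
                  << " recurrent="       << recurrent_default(il)
                  << " | custom: attn="  << attn_custom(il)
                  << " recurrent="       << recurrent_custom(il) << "\n";
    }
    return 0;
}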