feat: Add layer filter to recurrent cache

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-05-20 13:43:16 -06:00
parent fb26e95ae7
commit 40e9187892
3 changed files with 26 additions and 14 deletions

View File

@ -16,12 +16,13 @@
//
llama_kv_cache_recurrent::llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
@ -63,6 +64,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
v_l.reserve(n_layer);
for (int i = 0; i < n_layer; i++) {
if (filter && !filter(i)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
continue;
}
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
@ -88,8 +94,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i);
k_l.push_back(k);
v_l.push_back(v);
k_l[i] = k;
v_l[i] = v;
}
// allocate tensors and initialize the buffers to avoid NaNs in the padding

View File

@ -15,13 +15,18 @@
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_kv_cache_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max);
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size,
uint32_t n_seq_max);
~llama_kv_cache_recurrent() = default;

View File

@ -13759,6 +13759,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
{
res = new llama_kv_cache_recurrent(
*this,
nullptr,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,