diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp
index 8f6f120f6..917d2a60c 100644
--- a/src/llama-kv-cache-recurrent.cpp
+++ b/src/llama-kv-cache-recurrent.cpp
@@ -16,12 +16,13 @@
 //
 
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    offload,
+                 uint32_t    kv_size,
+                 uint32_t    n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -63,6 +64,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
@@ -88,8 +94,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
-        k_l.push_back(k);
-        v_l.push_back(v);
+        k_l[i] = k;
+        v_l[i] = v;
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h
index f9b01a651..b89590b7b 100644
--- a/src/llama-kv-cache-recurrent.h
+++ b/src/llama-kv-cache-recurrent.h
@@ -15,13 +15,18 @@
 // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
 class llama_kv_cache_recurrent : public llama_memory_i {
 public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    offload,
+                     uint32_t    kv_size,
+                     uint32_t    n_seq_max);
 
     ~llama_kv_cache_recurrent() = default;
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index a49a2c09f..c82cdc627 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -13759,6 +13759,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = new llama_kv_cache_recurrent(
                         *this,
+                        nullptr,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
                         cparams.offload_kqv,
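
The patch threads an optional layer_filter_cb through the recurrent cache: layers rejected by the callback are skipped in the constructor loop, and the per-layer tensors are now stored by layer index (k_l[i] = k) instead of push_back, so indices stay stable when some layers have no cache. The call site in llama-model.cpp above simply passes nullptr, i.e. no filtering. Below is a minimal sketch of how a caller might use the new parameter; the is_recurrent_layer predicate and the even/odd layer split are hypothetical stand-ins for whatever per-layer check a hybrid model would actually perform, and are not part of this patch.

    // Sketch only: exercising the layer filter outside of llama.cpp.
    #include <cstdint>
    #include <cstdio>
    #include <functional>

    // mirrors the alias added in llama-kv-cache-recurrent.h
    using layer_filter_cb = std::function<bool(int32_t il)>;

    // hypothetical predicate: a hybrid model that places a recurrent block
    // on every other layer
    static bool is_recurrent_layer(int32_t il) {
        return il % 2 == 0;
    }

    int main() {
        layer_filter_cb filter = is_recurrent_layer;

        // mirrors the skip logic added to the constructor loop: layers
        // rejected by the filter get no cache_k_l%d / cache_v_l%d tensors
        const int32_t n_layer = 6;
        for (int32_t i = 0; i < n_layer; i++) {
            if (filter && !filter(i)) {
                std::printf("layer %3d: skipped\n", i);
                continue;
            }
            std::printf("layer %3d: cached\n", i);
        }

        // the call in llama_model::create_memory could then pass a real
        // filter instead of nullptr, along the lines of (pseudocode):
        //
        //   res = new llama_kv_cache_recurrent(
        //           *this,
        //           [&](int32_t il) { return is_recurrent_layer(il); },
        //           GGML_TYPE_F32,
        //           GGML_TYPE_F32,
        //           cparams.offload_kqv,
        //           kv_size,
        //           cparams.n_seq_max);
        return 0;
    }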