Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-06-29 12:35:16 +00:00
feat: Add layer filter to recurrent cache
Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
@@ -16,12 +16,13 @@
 //
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                 ggml_type   type_k,
+                 ggml_type   type_v,
+                      bool   offload,
+                  uint32_t   kv_size,
+                  uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;

     LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
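
Note that filter is now the second constructor argument, so every call site has to change; the one caller updated in this commit, llama_model::create_memory (last hunk below), passes nullptr, which leaves the cache covering every layer exactly as before.
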
@@ -63,6 +64,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);

     for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
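
To make the new guard concrete, here is a minimal, self-contained C++ sketch of the same skip logic; the even-layers-only filter is purely hypothetical and not part of the commit:

#include <cstdint>
#include <cstdio>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t il)>;

int main() {
    const int32_t n_layer = 6;

    // hypothetical filter: keep only even-numbered layers in the cache
    layer_filter_cb filter = [](int32_t il) { return il % 2 == 0; };

    for (int32_t i = 0; i < n_layer; i++) {
        if (filter && !filter(i)) {  // same guard as in the constructor above
            printf("layer %3d: skipped\n", i);
            continue;
        }
        printf("layer %3d: cached\n", i);
    }
    return 0;
}

An empty std::function is falsy, so a null filter never skips anything.
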
@@ -88,8 +94,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
-        k_l.push_back(k);
-        v_l.push_back(v);
+        k_l[i] = k;
+        v_l[i] = v;
     }

     // allocate tensors and initialize the buffers to avoid NaNs in the padding
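
Because skipped layers leave gaps, the cache tensors are now stored at their layer index (k_l[i] = k) instead of appended, keeping k_l and v_l aligned with layer ids even when some layers have no cache. Indexed assignment does assume the vectors hold n_layer elements (resize rather than reserve); the context above still shows reserve(n_layer), so any sizing change falls outside this hunk.
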
@@ -15,13 +15,18 @@
 // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
 class llama_kv_cache_recurrent : public llama_memory_i {
 public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                     ggml_type   type_k,
+                     ggml_type   type_v,
+                          bool   offload,
+                      uint32_t   kv_size,
+                      uint32_t   n_seq_max);

     ~llama_kv_cache_recurrent() = default;
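
The parameter is declared as layer_filter_cb && (an rvalue reference to the std::function), so callers can pass a temporary lambda or nullptr directly. A hypothetical sketch of both call shapes, with make_recurrent_cache standing in for the real constructor (it is not llama.cpp API):

#include <cstdint>
#include <cstdio>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t il)>;

// stand-in for the real constructor; counts how many layers would get cache tensors
static int32_t make_recurrent_cache(layer_filter_cb && filter, int32_t n_layer) {
    int32_t cached = 0;
    for (int32_t i = 0; i < n_layer; i++) {
        if (filter && !filter(i)) {
            continue;
        }
        cached++;
    }
    return cached;
}

int main() {
    // nullptr, as passed by create_memory in the last hunk: no filtering
    printf("no filter  : %d layers cached\n", make_recurrent_cache(nullptr, 4));

    // a temporary lambda binds to the && parameter; here only layers >= 2
    printf("with filter: %d layers cached\n",
           make_recurrent_cache([](int32_t il) { return il >= 2; }, 4));
    return 0;
}
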
@@ -13759,6 +13759,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = new llama_kv_cache_recurrent(
                         *this,
+                        nullptr,
                         GGML_TYPE_F32,
                         GGML_TYPE_F32,
                         cparams.offload_kqv,
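
Passing nullptr here keeps the plain recurrent-model path unfiltered, so this commit changes no behavior on its own; judging by the HybridCache branch name, the filter presumably exists so a hybrid cache can later build a recurrent cache over only the recurrent layers of a mixed-architecture model.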