kv-cache : simplify + fix warning for recurrent models (#12756)

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-04-04 21:48:10 +03:00
committed by GitHub
parent 1be76e4620
commit 3e1d29348b
4 changed files with 80 additions and 173 deletions

View File

@@ -20,8 +20,8 @@ struct llama_kv_cache : public llama_memory_i {
virtual void restore() = 0; // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state
virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual bool get_can_shift() const = 0;
@@ -89,8 +89,8 @@ public:
uint32_t kv_size,
bool offload);
int32_t get_n_tokens() const override;
uint32_t get_used_cells() const override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;
size_t total_size() const;
@@ -109,7 +109,7 @@ public:
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_max(llama_seq_id seq_id) override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
bool get_can_shift() const override;
@@ -204,48 +204,6 @@ private:
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};
// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
void llama_kv_cache_clear(llama_kv_cache * kv);
bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
void llama_kv_cache_defrag(llama_kv_cache * kv);
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
//
// kv cache view
//