memory : migrate from llama_kv_cache to more generic llama_memory (#14006)

* memory : merge llama_kv_cache into llama_memory + new `llama_memory` API

ggml-ci

* context : fix casts

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-05 15:29:22 +03:00
committed by GitHub
parent 3a077146a4
commit 7f37b6cf1e
11 changed files with 324 additions and 220 deletions

View File

@ -2,8 +2,8 @@
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"
#include <unordered_map>
#include <vector>
@ -17,7 +17,7 @@ struct llama_context;
// llama_kv_cache_unified
//
class llama_kv_cache_unified : public llama_kv_cache {
class llama_kv_cache_unified : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
@ -56,21 +56,6 @@ public:
// llama_memory_i
//
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
@ -83,6 +68,17 @@ public:
bool get_can_shift() const override;
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;