kv-cache : refactor + add llama_memory_state_i (#13746)
* kv-cache : simplify the "struct llama_kv_cache" interface
* kv-cache : revert the (n_swa + n_ubatch) change (for next PR)
* kv-cache : some comments
* context : fix graph reserve for multiple sequences
* kv-cache : fix typo [no ci]
* kv-cache : fix find_slot() logic for free slots
* llama : add TODO for deprecating the defrag API in the future
* kv-cache : improve find_slot() using min/max seq pos info
* llama : handle aborts and compute errors
* memory : extract state into llama_memory_state
* kv-cache : add comments
* server : update batching logic to reset n_batch on successful decode
* server : upon full re-processing, remove the sequence from the cache
* kv-cache : add TODO for doing split_equal when split_simple fails
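The central piece of this refactor is the new llama_memory_state_i interface: instead of graph inputs and batch processing poking at the KV cache (kv_self) directly, the memory first produces a state object for the batch, and that state is then consumed one ubatch at a time. Below is a minimal C++ sketch of the shape such an interface could take; the member names (next, apply, get_ubatch, get_status) and the placeholder llama_memory_status enum are assumptions for illustration, not the literal header:

    // Sketch only (assumed shape, not the literal header): an abstract memory
    // state that owns the plan for applying one batch to the memory/KV cache,
    // split into ubatches.
    struct llama_ubatch; // declared elsewhere in the library

    enum llama_memory_status { // (assumed) placeholder for the real status enum
        LLAMA_MEMORY_STATUS_SUCCESS,
        LLAMA_MEMORY_STATUS_FAILED_PREPARE,
    };

    class llama_memory_state_i {
    public:
        virtual ~llama_memory_state_i() = default;

        virtual bool next()  = 0; // (assumed) advance to the next ubatch; false when done
        virtual bool apply() = 0; // (assumed) apply the current ubatch to the memory

        virtual const llama_ubatch & get_ubatch() const = 0; // (assumed) current ubatch

        virtual llama_memory_status get_status() const = 0;  // (assumed) preparation status
    };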
@@ -17,10 +17,11 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
-class llama_memory_i;
-class llama_kv_cache_unified;
-class llama_kv_cache_unified_iswa;
-class llama_kv_cache_recurrent;
+class llama_memory_state_i;
+
+class llama_kv_cache_unified_state;
+class llama_kv_cache_unified_iswa_state;
+class llama_kv_cache_recurrent_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -133,7 +134,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {}
+            const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -141,7 +142,7 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
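Every remaining hunk repeats the same mechanical substitution: the graph input class holds a pointer to the per-batch state (kv_state) instead of the cache itself (kv_self), both in the constructor and in the stored member. As a hedged sketch of how such an input might then use the state in set_input() — the helper set_input_pos_bucket() is an assumed name, not confirmed by this diff:

    // Sketch only: set_input() now delegates to the resolved state instead of
    // querying the KV cache; set_input_pos_bucket() is an assumed helper name.
    void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
        if (pos_bucket) {
            // fill the I32 [n_kv, n_batch] tensor from the state's view of the cache
            kv_state->set_input_pos_bucket(pos_bucket, ubatch);
        }
    }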
@@ -188,26 +189,26 @@ public:
 
 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
     virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_recurrent * kv_self;
+    const llama_kv_cache_recurrent_state * kv_state;
 };
 
 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
-    llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
     virtual ~llm_graph_input_s_mask() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_mask; // F32 [1, n_kv]
 
-    const llama_kv_cache_recurrent * kv_self;
+    const llama_kv_cache_recurrent_state * kv_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -247,10 +248,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified * kv_self) :
+            const llama_kv_cache_unified_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        kv_self(kv_self) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -264,7 +265,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_state * kv_state;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -272,10 +273,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa * kv_self) :
+            const llama_kv_cache_unified_iswa_state * kv_state) :
         hparams(hparams),
         cparams(cparams),
-        kv_self(kv_self) {
+        kv_state(kv_state) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -292,7 +293,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified_iswa * kv_self;
+    const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -383,10 +384,10 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
-    const llama_adapter_cvec  * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_i      * memory;
-    const llama_cross         * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     int32_t n_outputs;
 
@@ -435,10 +436,10 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
-    const llama_adapter_cvec  * cvec;
-    const llama_adapter_loras * loras;
-    const llama_memory_i      * memory;
-    const llama_cross         * cross;
+    const llama_adapter_cvec   * cvec;
+    const llama_adapter_loras  * loras;
+    const llama_memory_state_i * mstate;
+    const llama_cross          * cross;
 
     const llm_graph_cb & cb_func;
 
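Both llm_graph_params and llm_graph_context now carry the resolved mstate instead of the raw memory pointer, so graph construction only ever sees a state that has already been prepared for the batch. A hedged sketch of the decode flow this design suggests — all names here (init_batch, apply, next, get_ubatch) are assumptions for illustration, not the actual implementation:

    // Sketch only: how a decode loop could drive the new state object.
    static bool decode_with_state(llama_memory_i * memory, const llama_batch & batch, uint32_t n_ubatch) {
        // resolve a state for the whole batch up front (slot search, splitting)
        auto mstate = memory->init_batch(batch, n_ubatch, /*embd_pooled=*/false);

        do {
            const auto & ubatch = mstate->get_ubatch(); // current micro-batch

            // prepare the KV cache for this ubatch; failure here corresponds to
            // the abort/compute-error handling mentioned in the commit message
            if (!mstate->apply()) {
                return false;
            }

            // ... build + compute the graph for ubatch, passing mstate via llm_graph_params ...
            (void) ubatch;
        } while (mstate->next()); // advance until the batch is exhausted

        return true;
    }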