mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-18 08:37:43 +00:00
kv-cache : prepare for abstraction
ggml-ci
This commit is contained in:
@ -8,11 +8,10 @@
|
||||
struct ggml_cgraph;
|
||||
struct ggml_context;
|
||||
struct ggml_tensor;
|
||||
struct ggml_backend_buffer;
|
||||
struct llama_ubatch;
|
||||
|
||||
struct llama_graph_result {
|
||||
ggml_cgraph * gf = nullptr;
|
||||
|
||||
// important graph nodes
|
||||
ggml_tensor * t_logits = nullptr;
|
||||
ggml_tensor * t_embd = nullptr;
|
||||
@ -50,6 +49,14 @@ public:
|
||||
|
||||
virtual ggml_tensor * build_rope_factors(int il) = 0;
|
||||
|
||||
// note: optionally set the backend to be the same as the bbuf's backend
|
||||
virtual ggml_tensor * build_rope_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * factors,
|
||||
ggml_backend_buffer * bbuft) = 0;
|
||||
|
||||
// graph build API (context-specific)
|
||||
|
||||
virtual ggml_tensor * build_inp_embd(
|
||||
@ -83,7 +90,7 @@ public:
|
||||
|
||||
virtual void build_attn_kv_store(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
int32_t n_tokens,
|
||||
@ -92,7 +99,7 @@ public:
|
||||
|
||||
virtual ggml_tensor * build_attn_qkv(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
@ -106,14 +113,8 @@ public:
|
||||
ggml_tensor * kq,
|
||||
float kq_scale) = 0;
|
||||
|
||||
virtual void build_kv_self_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph) = 0;
|
||||
|
||||
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
|
||||
virtual void build_kv_self_defrag(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph) = 0;
|
||||
virtual ggml_tensor * build_inp_k_shift(
|
||||
ggml_context * ctx0) = 0;
|
||||
|
||||
virtual ggml_tensor * build_inp_embd_enc(
|
||||
ggml_context * ctx0,
|
||||
@ -135,7 +136,7 @@ public:
|
||||
|
||||
virtual ggml_tensor * build_copy_mask_state(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
@ -146,7 +147,7 @@ public:
|
||||
|
||||
virtual ggml_tensor * build_mamba_layer(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
@ -156,7 +157,7 @@ public:
|
||||
|
||||
virtual ggml_tensor * build_rwkv_token_shift_load(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
@ -172,7 +173,7 @@ public:
|
||||
|
||||
virtual ggml_tensor * build_rwkv6_time_mix(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * graph,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
ggml_tensor * state_copy,
|
||||
@ -181,3 +182,18 @@ public:
|
||||
int il,
|
||||
bool worst_case) = 0;
|
||||
};
|
||||
|
||||
class llama_graph_kv_cache_i {
|
||||
public:
|
||||
virtual void build_shift(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
llama_graph_i * lgf) = 0;
|
||||
|
||||
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
|
||||
virtual void build_defrag(
|
||||
ggml_context * ctx0,
|
||||
ggml_cgraph * gf,
|
||||
int32_t max_nodes,
|
||||
bool v_trans) = 0;
|
||||
};
|
||||
|
Reference in New Issue
Block a user