diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index e1b02e4c0..f1f5d41d4 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -15,6 +15,7 @@ add_library(llama
             llama-chat.cpp
             llama-context.cpp
             llama-grammar.cpp
+            llama-graph.cpp
             llama-hparams.cpp
             llama-impl.cpp
             llama-kv-cache.cpp
diff --git a/src/llama-context.h b/src/llama-context.h
index dd1030388..b446118ff 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -3,6 +3,7 @@
 #include "llama.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
+#include "llama-graph.h"
 #include "llama-model.h"
 #include "llama-kv-cache.h"
 #include "llama-adapter.h"
@@ -16,7 +17,7 @@
 
 using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
 
-struct llama_context {
+struct llama_context : public llama_graph_i {
     llama_context(const llama_model & model);
 
     virtual ~llama_context();
@@ -129,137 +130,6 @@ struct llama_context {
 
     virtual ggml_tensor * build_rope_factors(int il);
 
-    // graph build API (context-specific)
-
-    virtual ggml_tensor * build_inp_embd(
-            ggml_context * ctx0,
-            ggml_tensor * tok_embd,
-            const llama_ubatch & ubatch) = 0;
-
-    virtual ggml_tensor * build_inp_pos(
-            ggml_context * ctx0,
-            int32_t n_tokens) = 0;
-
-    virtual ggml_tensor * build_inp_out_ids(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_inp_mean(
-            ggml_context * ctx0,
-            int32_t n_tokens) = 0;
-
-    virtual ggml_tensor * build_inp_cls(
-            ggml_context * ctx0,
-            int32_t n_tokens) = 0;
-
-    virtual void build_attn_inp(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool causal,
-            bool swa,
-            bool worst_case) = 0;
-
-    virtual void build_attn_kv_store(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * k_cur,
-            ggml_tensor * v_cur,
-            int32_t n_tokens,
-            int64_t il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_attn_qkv(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur,
-            int32_t n_tokens,
-            float kq_scale,
-            int il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_soft_max_ext(
-            ggml_context * ctx0,
-            ggml_tensor * kq,
-            float kq_scale) = 0;
-
-    virtual void build_k_shift(
-            ggml_context * ctx0,
-            ggml_cgraph * graph) = 0;
-
-    // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-    virtual void build_defrag(
-            ggml_context * ctx0,
-            ggml_cgraph * graph) = 0;
-
-    virtual ggml_tensor * build_inp_embd_enc(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_inp_KQ_mask_cross(
-            ggml_context * ctx0,
-            int32_t n_tokens,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_inp_s_copy(
-            ggml_context * ctx0,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_inp_s_mask(
-            ggml_context * ctx0,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_copy_mask_state(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            int32_t n_tokens,
-            int32_t n_state,
-            int32_t n_seqs,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_mamba_layer(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * cur,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            const llama_ubatch & ubatch,
-            int il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_rwkv_token_shift_load(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            const llama_ubatch & ubatch,
-            int il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_rwkv_token_shift_store(
-            ggml_context * ctx0,
-            ggml_tensor * token_shift,
-            const llama_ubatch & ubatch,
-            int il,
-            bool worst_case) = 0;
-
-    virtual ggml_tensor * build_rwkv6_time_mix(
-            ggml_context * ctx0,
-            ggml_cgraph * graph,
-            ggml_tensor * cur,
-            ggml_tensor * x_prev,
-            ggml_tensor * state_copy,
-            ggml_tensor * state_mask,
-            const llama_ubatch & ubatch,
-            int il,
-            bool worst_case) = 0;
-
     // state save/load
 
     virtual size_t state_get_size() = 0;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
new file mode 100644
index 000000000..20f2ee0bd
--- /dev/null
+++ b/src/llama-graph.cpp
@@ -0,0 +1 @@
+#include "llama-graph.h"
diff --git a/src/llama-graph.h b/src/llama-graph.h
new file mode 100644
index 000000000..37dff8db4
--- /dev/null
+++ b/src/llama-graph.h
@@ -0,0 +1,164 @@
+#pragma once
+
+#include <cstdint>
+
+struct ggml_cgraph;
+struct ggml_context;
+struct ggml_tensor;
+struct llama_ubatch;
+
+// TODO: pass to llama_model graph build
+class llama_graph_i {
+public:
+    // apply control vector for layer il
+    virtual ggml_tensor * build_cvec(
+            ggml_context * ctx0,
+            ggml_tensor * cur,
+            int il) = 0;
+
+    // do mat_mul, while optionally apply lora
+    virtual ggml_tensor * build_lora_mm(
+            ggml_context * ctx0,
+            ggml_tensor * w,
+            ggml_tensor * cur) = 0;
+
+    // do mat_mul_id, while optionally apply lora
+    virtual ggml_tensor * build_lora_mm_id(
+            ggml_context * ctx0,
+            ggml_tensor * w,   // struct ggml_tensor * as
+            ggml_tensor * cur, // struct ggml_tensor * b
+            ggml_tensor * ids) = 0;
+
+    virtual ggml_tensor * build_rope_factors(int il) = 0;
+
+    // graph build API (context-specific)
+
+    virtual ggml_tensor * build_inp_embd(
+            ggml_context * ctx0,
+            ggml_tensor * tok_embd,
+            const llama_ubatch & ubatch) = 0;
+
+    virtual ggml_tensor * build_inp_pos(
+            ggml_context * ctx0,
+            int32_t n_tokens) = 0;
+
+    virtual ggml_tensor * build_inp_out_ids(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_inp_mean(
+            ggml_context * ctx0,
+            int32_t n_tokens) = 0;
+
+    virtual ggml_tensor * build_inp_cls(
+            ggml_context * ctx0,
+            int32_t n_tokens) = 0;
+
+    virtual void build_attn_inp(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool causal,
+            bool swa,
+            bool worst_case) = 0;
+
+    virtual void build_attn_kv_store(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            int32_t n_tokens,
+            int64_t il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_attn_qkv(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            int32_t n_tokens,
+            float kq_scale,
+            int il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_soft_max_ext(
+            ggml_context * ctx0,
+            ggml_tensor * kq,
+            float kq_scale) = 0;
+
+    virtual void build_k_shift(
+            ggml_context * ctx0,
+            ggml_cgraph * graph) = 0;
+
+    // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
+    virtual void build_defrag(
+            ggml_context * ctx0,
+            ggml_cgraph * graph) = 0;
+
+    virtual ggml_tensor * build_inp_embd_enc(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_inp_KQ_mask_cross(
+            ggml_context * ctx0,
+            int32_t n_tokens,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_inp_s_copy(
+            ggml_context * ctx0,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_inp_s_mask(
+            ggml_context * ctx0,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_copy_mask_state(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            int32_t n_tokens,
+            int32_t n_state,
+            int32_t n_seqs,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_mamba_layer(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * cur,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_rwkv_token_shift_load(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_rwkv_token_shift_store(
+            ggml_context * ctx0,
+            ggml_tensor * token_shift,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case) = 0;
+
+    virtual ggml_tensor * build_rwkv6_time_mix(
+            ggml_context * ctx0,
+            ggml_cgraph * graph,
+            ggml_tensor * cur,
+            ggml_tensor * x_prev,
+            ggml_tensor * state_copy,
+            ggml_tensor * state_mask,
+            const llama_ubatch & ubatch,
+            int il,
+            bool worst_case) = 0;
+};
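For orientation, here is a minimal sketch, not part of the diff above, of how a concrete class implements the new llama_graph_i interface: llama_context now derives from it, and each pure-virtual graph builder becomes an override in the concrete context. The class name my_graph_ctx and the body of build_inp_pos are illustrative assumptions for this sketch, not code from this change.

#include "llama-graph.h"   // the interface added by this diff
#include "ggml.h"          // ggml_context, ggml_tensor, ggml_new_tensor_1d

// hypothetical subclass, for illustration only
struct my_graph_ctx : public llama_graph_i {
    // example override: allocate the token-positions input tensor for a ubatch
    ggml_tensor * build_inp_pos(ggml_context * ctx0, int32_t n_tokens) override {
        return ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    }

    // every remaining pure-virtual method of llama_graph_i must also be
    // overridden before my_graph_ctx can be instantiated
};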