mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-18 00:27:31 +00:00
context : introduce llama_graph_i
ggml-ci
This commit is contained in:
@ -15,6 +15,7 @@ add_library(llama
|
|||||||
llama-chat.cpp
|
llama-chat.cpp
|
||||||
llama-context.cpp
|
llama-context.cpp
|
||||||
llama-grammar.cpp
|
llama-grammar.cpp
|
||||||
|
llama-graph.cpp
|
||||||
llama-hparams.cpp
|
llama-hparams.cpp
|
||||||
llama-impl.cpp
|
llama-impl.cpp
|
||||||
llama-kv-cache.cpp
|
llama-kv-cache.cpp
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "llama-batch.h"
|
#include "llama-batch.h"
|
||||||
#include "llama-cparams.h"
|
#include "llama-cparams.h"
|
||||||
|
#include "llama-graph.h"
|
||||||
#include "llama-model.h"
|
#include "llama-model.h"
|
||||||
#include "llama-kv-cache.h"
|
#include "llama-kv-cache.h"
|
||||||
#include "llama-adapter.h"
|
#include "llama-adapter.h"
|
||||||
@ -16,7 +17,7 @@
|
|||||||
|
|
||||||
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
|
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
|
||||||
|
|
||||||
struct llama_context {
|
struct llama_context : public llama_graph_i {
|
||||||
llama_context(const llama_model & model);
|
llama_context(const llama_model & model);
|
||||||
virtual ~llama_context();
|
virtual ~llama_context();
|
||||||
|
|
||||||
@ -129,137 +130,6 @@ struct llama_context {
|
|||||||
|
|
||||||
virtual ggml_tensor * build_rope_factors(int il);
|
virtual ggml_tensor * build_rope_factors(int il);
|
||||||
|
|
||||||
// graph build API (context-specific)
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_embd(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_tensor * tok_embd,
|
|
||||||
const llama_ubatch & ubatch) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_pos(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
int32_t n_tokens) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_out_ids(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
int32_t n_tokens,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_mean(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
int32_t n_tokens) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_cls(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
int32_t n_tokens) = 0;
|
|
||||||
|
|
||||||
virtual void build_attn_inp(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
int32_t n_tokens,
|
|
||||||
bool causal,
|
|
||||||
bool swa,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual void build_attn_kv_store(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * graph,
|
|
||||||
ggml_tensor * k_cur,
|
|
||||||
ggml_tensor * v_cur,
|
|
||||||
int32_t n_tokens,
|
|
||||||
int64_t il,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_attn_qkv(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * graph,
|
|
||||||
ggml_tensor * wo,
|
|
||||||
ggml_tensor * wo_b,
|
|
||||||
ggml_tensor * q_cur,
|
|
||||||
int32_t n_tokens,
|
|
||||||
float kq_scale,
|
|
||||||
int il,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_soft_max_ext(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_tensor * kq,
|
|
||||||
float kq_scale) = 0;
|
|
||||||
|
|
||||||
virtual void build_k_shift(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * graph) = 0;
|
|
||||||
|
|
||||||
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
|
|
||||||
virtual void build_defrag(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * graph) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_embd_enc(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
int32_t n_tokens,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_KQ_mask_cross(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
int32_t n_tokens,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_s_copy(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_inp_s_mask(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_copy_mask_state(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * graph,
|
|
||||||
ggml_tensor * s,
|
|
||||||
ggml_tensor * state_copy,
|
|
||||||
ggml_tensor * state_mask,
|
|
||||||
int32_t n_tokens,
|
|
||||||
int32_t n_state,
|
|
||||||
int32_t n_seqs,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_mamba_layer(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * graph,
|
|
||||||
ggml_tensor * cur,
|
|
||||||
ggml_tensor * state_copy,
|
|
||||||
ggml_tensor * state_mask,
|
|
||||||
const llama_ubatch & ubatch,
|
|
||||||
int il,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_rwkv_token_shift_load(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * graph,
|
|
||||||
ggml_tensor * state_copy,
|
|
||||||
ggml_tensor * state_mask,
|
|
||||||
const llama_ubatch & ubatch,
|
|
||||||
int il,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_rwkv_token_shift_store(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_tensor * token_shift,
|
|
||||||
const llama_ubatch & ubatch,
|
|
||||||
int il,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
virtual ggml_tensor * build_rwkv6_time_mix(
|
|
||||||
ggml_context * ctx0,
|
|
||||||
ggml_cgraph * graph,
|
|
||||||
ggml_tensor * cur,
|
|
||||||
ggml_tensor * x_prev,
|
|
||||||
ggml_tensor * state_copy,
|
|
||||||
ggml_tensor * state_mask,
|
|
||||||
const llama_ubatch & ubatch,
|
|
||||||
int il,
|
|
||||||
bool worst_case) = 0;
|
|
||||||
|
|
||||||
// state save/load
|
// state save/load
|
||||||
|
|
||||||
virtual size_t state_get_size() = 0;
|
virtual size_t state_get_size() = 0;
|
||||||
|
1
src/llama-graph.cpp
Normal file
1
src/llama-graph.cpp
Normal file
@ -0,0 +1 @@
|
|||||||
|
#include "llama-graph.h"
|
164
src/llama-graph.h
Normal file
164
src/llama-graph.h
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
struct ggml_cgraph;
|
||||||
|
struct ggml_context;
|
||||||
|
struct ggml_tensor;
|
||||||
|
struct llama_ubatch;
|
||||||
|
|
||||||
|
// TODO: pass to llama_model graph build
|
||||||
|
class llama_graph_i {
|
||||||
|
public:
|
||||||
|
// apply control vector for layer il
|
||||||
|
virtual ggml_tensor * build_cvec(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * cur,
|
||||||
|
int il) = 0;
|
||||||
|
|
||||||
|
// do mat_mul, while optionally apply lora
|
||||||
|
virtual ggml_tensor * build_lora_mm(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * w,
|
||||||
|
ggml_tensor * cur) = 0;
|
||||||
|
|
||||||
|
// do mat_mul_id, while optionally apply lora
|
||||||
|
virtual ggml_tensor * build_lora_mm_id(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * w, // struct ggml_tensor * as
|
||||||
|
ggml_tensor * cur, // struct ggml_tensor * b
|
||||||
|
ggml_tensor * ids) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_rope_factors(int il) = 0;
|
||||||
|
|
||||||
|
// graph build API (context-specific)
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_embd(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * tok_embd,
|
||||||
|
const llama_ubatch & ubatch) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_pos(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
int32_t n_tokens) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_out_ids(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
int32_t n_tokens,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_mean(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
int32_t n_tokens) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_cls(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
int32_t n_tokens) = 0;
|
||||||
|
|
||||||
|
virtual void build_attn_inp(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
int32_t n_tokens,
|
||||||
|
bool causal,
|
||||||
|
bool swa,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual void build_attn_kv_store(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * graph,
|
||||||
|
ggml_tensor * k_cur,
|
||||||
|
ggml_tensor * v_cur,
|
||||||
|
int32_t n_tokens,
|
||||||
|
int64_t il,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_attn_qkv(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * graph,
|
||||||
|
ggml_tensor * wo,
|
||||||
|
ggml_tensor * wo_b,
|
||||||
|
ggml_tensor * q_cur,
|
||||||
|
int32_t n_tokens,
|
||||||
|
float kq_scale,
|
||||||
|
int il,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_soft_max_ext(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * kq,
|
||||||
|
float kq_scale) = 0;
|
||||||
|
|
||||||
|
virtual void build_k_shift(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * graph) = 0;
|
||||||
|
|
||||||
|
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
|
||||||
|
virtual void build_defrag(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * graph) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_embd_enc(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
int32_t n_tokens,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_KQ_mask_cross(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
int32_t n_tokens,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_s_copy(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_inp_s_mask(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_copy_mask_state(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * graph,
|
||||||
|
ggml_tensor * s,
|
||||||
|
ggml_tensor * state_copy,
|
||||||
|
ggml_tensor * state_mask,
|
||||||
|
int32_t n_tokens,
|
||||||
|
int32_t n_state,
|
||||||
|
int32_t n_seqs,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_mamba_layer(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * graph,
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * state_copy,
|
||||||
|
ggml_tensor * state_mask,
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
int il,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_rwkv_token_shift_load(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * graph,
|
||||||
|
ggml_tensor * state_copy,
|
||||||
|
ggml_tensor * state_mask,
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
int il,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_rwkv_token_shift_store(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_tensor * token_shift,
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
int il,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
|
||||||
|
virtual ggml_tensor * build_rwkv6_time_mix(
|
||||||
|
ggml_context * ctx0,
|
||||||
|
ggml_cgraph * graph,
|
||||||
|
ggml_tensor * cur,
|
||||||
|
ggml_tensor * x_prev,
|
||||||
|
ggml_tensor * state_copy,
|
||||||
|
ggml_tensor * state_mask,
|
||||||
|
const llama_ubatch & ubatch,
|
||||||
|
int il,
|
||||||
|
bool worst_case) = 0;
|
||||||
|
};
|
Reference in New Issue
Block a user