mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-14 14:53:55 +00:00
610 lines
18 KiB
C++
610 lines
18 KiB
C++
#pragma once
|
|
|
|
#include "llama.h"
|
|
#include "llama-batch.h"
|
|
#include "llama-cparams.h"
|
|
#include "llama-graph.h"
|
|
#include "llama-model.h"
|
|
#include "llama-kv-cache.h"
|
|
#include "llama-adapter.h"
|
|
|
|
#include "ggml-cpp.h"
|
|
|
|
#include <map>
|
|
#include <unordered_map>
|
|
#include <vector>
|
|
#include <set>
|
|
|
|
class llama_io_read_i;
|
|
class llama_io_write_i;
|
|
|
|
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
|
|
|
|
// basic transformer without KV cache
|
|
struct llama_context : public llama_graph_i {
|
|
public:
|
|
llama_context(
|
|
const llama_model & model,
|
|
const llama_context_params & params);
|
|
|
|
virtual ~llama_context();
|
|
|
|
// init scheduler and compute buffers, reserve worst-case graphs
|
|
// call once after the context is constructed
|
|
virtual void init();
|
|
|
|
virtual void synchronize();
|
|
|
|
protected:
|
|
// called by init() to reserve the worst-case graphs
|
|
// override in child classes
|
|
virtual void reserve();
|
|
|
|
public:
|
|
const llama_model & get_model() const;
|
|
const llama_cparams & get_cparams() const;
|
|
|
|
virtual uint32_t n_ctx() const;
|
|
virtual uint32_t n_ctx_per_seq() const;
|
|
virtual uint32_t n_batch() const;
|
|
virtual uint32_t n_ubatch() const;
|
|
virtual uint32_t n_seq_max() const;
|
|
|
|
virtual uint32_t n_threads() const;
|
|
virtual uint32_t n_threads_batch() const;
|
|
|
|
virtual int32_t max_nodes() const;
|
|
|
|
// self-attention:
|
|
|
|
// if the context does not have a KV cache, return nullptr
|
|
virtual llama_kv_cache * get_kv_self();
|
|
virtual const llama_kv_cache * get_kv_self() const;
|
|
|
|
// if the context does not have a KV cache, noop
|
|
virtual void kv_self_update();
|
|
|
|
virtual enum llama_pooling_type pooling_type() const;
|
|
|
|
virtual float * get_logits();
|
|
virtual float * get_logits_ith(int32_t i);
|
|
|
|
virtual float * get_embeddings();
|
|
virtual float * get_embeddings_ith(int32_t i);
|
|
virtual float * get_embeddings_seq(llama_seq_id seq_id);
|
|
|
|
virtual int64_t n_pos_per_token() const; // vision
|
|
|
|
virtual void attach_threadpool(
|
|
ggml_threadpool_t threadpool,
|
|
ggml_threadpool_t threadpool_batch);
|
|
|
|
virtual void detach_threadpool();
|
|
|
|
virtual void set_n_threads(int32_t n_threads, int32_t n_threads_batch);
|
|
|
|
virtual void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);
|
|
|
|
virtual void set_embeddings (bool value);
|
|
virtual void set_causal_attn(bool value);
|
|
|
|
virtual void set_adapter_lora(
|
|
llama_adapter_lora * adapter,
|
|
float scale);
|
|
|
|
virtual bool rm_adapter_lora(
|
|
llama_adapter_lora * adapter);
|
|
|
|
virtual void clear_adapter_lora();
|
|
|
|
virtual bool apply_adapter_cvec(
|
|
const float * data,
|
|
size_t len,
|
|
int32_t n_embd,
|
|
int32_t il_start,
|
|
int32_t il_end);
|
|
|
|
// encode a batch of tokens by evaluating the encoder part of the transformer
|
|
//
|
|
// - lctx: llama context
|
|
// - batch: batch to evaluate
|
|
//
|
|
// return 0 on success
|
|
// return positive int on warning
|
|
// return negative int on error
|
|
//
|
|
virtual int encode(llama_batch & inp_batch);
|
|
|
|
// decode a batch of tokens by evaluating the transformer
|
|
// in case of unsuccessful decoding (error or warning),
|
|
// the kv_cache state will be returned to its original state
|
|
// (for non-recurrent models) or cleaned (for recurrent models)
|
|
//
|
|
// - lctx: llama context
|
|
// - inp_batch: batch to evaluate
|
|
//
|
|
// return 0 on success
|
|
// return positive int on warning
|
|
// return negative int on error
|
|
//
|
|
virtual int decode(llama_batch & inp_batch);
|
|
|
|
protected:
|
|
//
|
|
// input
|
|
//
|
|
|
|
// when the compute graph is built, it creates the input tensors that it needs
|
|
// the contents of the input tensors are set by the input_set() function
|
|
|
|
virtual void input_set(const llama_ubatch & ubatch);
|
|
|
|
struct {
|
|
// base input tensors
|
|
ggml_tensor * tokens; // I32 [n_batch]
|
|
ggml_tensor * embd; // F32 [n_embd, n_batch]
|
|
ggml_tensor * pos; // I32 [n_batch]
|
|
ggml_tensor * out_ids; // I32 [n_outputs]
|
|
ggml_tensor * mean; // F32 [n_batch, n_batch]
|
|
ggml_tensor * cls; // I32 [n_batch]
|
|
|
|
// KQ mask input tensors
|
|
ggml_tensor * kq_mask; // F32 [n_tokens, n_batch]
|
|
ggml_tensor * kq_mask_cnv; // [n_tokens, n_batch]
|
|
} inp;
|
|
|
|
//
|
|
// output
|
|
//
|
|
|
|
// Make sure enough space is available for outputs.
|
|
// Returns max number of outputs for which space was reserved.
|
|
virtual int32_t output_reserve(int32_t n_outputs);
|
|
|
|
// make the outputs have the same order they had in the user-provided batch
|
|
// TODO: maybe remove this
|
|
virtual void output_reorder();
|
|
|
|
//
|
|
// graph
|
|
//
|
|
|
|
// zero-out inputs and create the ctx_context for the compute graph
|
|
virtual ggml_cgraph * graph_init();
|
|
|
|
// TODO: add encode/decode graphs
|
|
virtual llama_graph_result graph_build(
|
|
ggml_context * ctx,
|
|
ggml_cgraph * gf,
|
|
const llama_ubatch & ubatch);
|
|
|
|
// returns the result of ggml_backend_sched_graph_compute_async execution
|
|
virtual enum ggml_status graph_compute(
|
|
ggml_cgraph * gf,
|
|
bool batched);
|
|
|
|
ggml_context_ptr ctx_compute;
|
|
|
|
//
|
|
// graph build API (generic)
|
|
//
|
|
|
|
virtual void build_cb(
|
|
ggml_tensor * cur,
|
|
const char * name,
|
|
const llama_ubatch & ubatch,
|
|
int il);
|
|
|
|
// apply control vector for layer il
|
|
virtual ggml_tensor * build_cvec(
|
|
ggml_context * ctx0,
|
|
ggml_tensor * cur,
|
|
int il);
|
|
|
|
// do mat_mul, while optionally apply lora
|
|
virtual ggml_tensor * build_lora_mm(
|
|
ggml_context * ctx0,
|
|
ggml_tensor * w,
|
|
ggml_tensor * cur);
|
|
|
|
// do mat_mul_id, while optionally apply lora
|
|
virtual ggml_tensor * build_lora_mm_id(
|
|
ggml_context * ctx0,
|
|
ggml_tensor * w, // struct ggml_tensor * as
|
|
ggml_tensor * cur, // struct ggml_tensor * b
|
|
ggml_tensor * ids);
|
|
|
|
virtual ggml_tensor * build_rope_factors(int il);
|
|
|
|
virtual ggml_tensor * build_rope_shift(
|
|
ggml_context * ctx0,
|
|
ggml_tensor * cur,
|
|
ggml_tensor * shift,
|
|
ggml_tensor * factors,
|
|
ggml_backend_buffer * bbuf);
|
|
|
|
virtual ggml_tensor * build_inp_embd(
|
|
ggml_context * ctx0,
|
|
ggml_tensor * tok_embd,
|
|
const llama_ubatch & ubatch);
|
|
|
|
virtual ggml_tensor * build_inp_pos(
|
|
ggml_context * ctx0,
|
|
int32_t n_tokens);
|
|
|
|
virtual ggml_tensor * build_inp_out_ids(
|
|
ggml_context * ctx0);
|
|
|
|
virtual ggml_tensor * build_inp_mean(
|
|
ggml_context * ctx0,
|
|
int32_t n_tokens);
|
|
|
|
virtual ggml_tensor * build_inp_cls(
|
|
ggml_context * ctx0,
|
|
int32_t n_tokens);
|
|
|
|
virtual void build_attn_inp(
|
|
ggml_context * ctx0,
|
|
int32_t n_tokens,
|
|
bool causal,
|
|
bool swa);
|
|
|
|
virtual ggml_tensor * build_attn(
|
|
ggml_context * ctx0,
|
|
ggml_cgraph * gf,
|
|
ggml_tensor * wo,
|
|
ggml_tensor * wo_b,
|
|
ggml_tensor * q_cur,
|
|
ggml_tensor * k_cur,
|
|
ggml_tensor * v_cur,
|
|
int32_t n_tokens,
|
|
float kq_scale,
|
|
int il);
|
|
|
|
public:
|
|
//
|
|
// perf
|
|
//
|
|
|
|
virtual llama_perf_context_data perf_get_data() const;
|
|
virtual void perf_reset();
|
|
|
|
protected:
|
|
mutable int64_t t_start_us = 0;
|
|
mutable int64_t t_load_us = 0;
|
|
mutable int64_t t_p_eval_us = 0;
|
|
mutable int64_t t_eval_us = 0;
|
|
|
|
mutable int64_t t_compute_start_us = 0;
|
|
mutable int64_t n_queued_tokens = 0;
|
|
|
|
mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
|
|
mutable int32_t n_eval = 0; // number of eval calls
|
|
|
|
public:
|
|
//
|
|
// state save/load
|
|
//
|
|
|
|
virtual size_t state_get_size();
|
|
virtual size_t state_get_data( uint8_t * dst, size_t size);
|
|
virtual size_t state_set_data(const uint8_t * src, size_t size);
|
|
|
|
virtual size_t state_seq_get_size(llama_seq_id seq_id);
|
|
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
|
|
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
|
|
|
|
virtual bool state_load_file(
|
|
const char * filepath,
|
|
llama_token * tokens_out,
|
|
size_t n_token_capacity,
|
|
size_t * n_token_count_out);
|
|
|
|
virtual bool state_save_file(
|
|
const char * filepath,
|
|
const llama_token * tokens,
|
|
size_t n_token_count);
|
|
|
|
virtual size_t state_seq_load_file(
|
|
llama_seq_id seq_id,
|
|
const char * filepath,
|
|
llama_token * tokens_out,
|
|
size_t n_token_capacity,
|
|
size_t * n_token_count_out);
|
|
|
|
virtual size_t state_seq_save_file(
|
|
llama_seq_id seq_id,
|
|
const char * filepath,
|
|
const llama_token * tokens,
|
|
size_t n_token_count);
|
|
|
|
protected:
|
|
virtual size_t state_get_data(llama_io_write_i & io);
|
|
virtual size_t state_set_data(llama_io_read_i & io);
|
|
|
|
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
|
|
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
|
|
|
|
//
|
|
// members
|
|
//
|
|
|
|
const llama_model & model;
|
|
|
|
llama_cparams cparams;
|
|
llama_adapter_cvec cvec;
|
|
llama_loras loras;
|
|
llama_sbatch sbatch;
|
|
|
|
ggml_threadpool_t threadpool = nullptr;
|
|
ggml_threadpool_t threadpool_batch = nullptr;
|
|
|
|
ggml_abort_callback abort_callback = nullptr;
|
|
void * abort_callback_data = nullptr;
|
|
|
|
ggml_backend_t backend_cpu = nullptr;
|
|
std::vector<ggml_backend_ptr> backends;
|
|
|
|
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
|
|
|
|
ggml_backend_sched_ptr sched;
|
|
|
|
// buffer types used for the compute buffer of each backend
|
|
std::vector<ggml_backend_t> backend_ptrs;
|
|
std::vector<ggml_backend_buffer_type_t> backend_buft;
|
|
|
|
// memory buffers used to evaluate the model
|
|
std::vector<uint8_t> buf_compute_meta;
|
|
|
|
// host buffer for the model output (logits and embeddings)
|
|
ggml_backend_buffer_ptr buf_output;
|
|
|
|
// TODO: remove
|
|
bool logits_all = false;
|
|
|
|
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
|
size_t logits_size = 0; // capacity (of floats) for logits
|
|
float * logits = nullptr;
|
|
|
|
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
|
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
|
size_t embd_size = 0; // capacity (of floats) for embeddings
|
|
float * embd = nullptr;
|
|
|
|
// sequence embeddings output (map of [n_embd] vectors)
|
|
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
|
|
std::map<llama_seq_id, std::vector<float>> embd_seq;
|
|
|
|
int32_t output_size = 0; // capacity (of tokens positions) for the output buffers
|
|
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
|
|
|
|
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
|
|
|
|
bool has_evaluated_once = false;
|
|
};
|
|
|
|
// transformer with a self-attention KV cache
|
|
class llama_context_kv_self : public llama_context {
|
|
public:
|
|
llama_context_kv_self(
|
|
const llama_model & model,
|
|
const llama_context_params & params);
|
|
|
|
virtual ~llama_context_kv_self();
|
|
|
|
protected:
|
|
virtual void reserve() override;
|
|
|
|
public:
|
|
virtual llama_kv_cache * get_kv_self() override;
|
|
virtual const llama_kv_cache * get_kv_self() const override;
|
|
|
|
virtual void kv_self_update() override;
|
|
|
|
virtual int encode(llama_batch & inp_batch) override;
|
|
virtual int decode(llama_batch & inp_batch) override;
|
|
|
|
protected:
|
|
//
|
|
// input
|
|
//
|
|
|
|
virtual void input_set(const llama_ubatch & ubatch) override;
|
|
|
|
struct {
|
|
ggml_tensor * self_kq_mask; // F32 [kv_size, n_batch]
|
|
ggml_tensor * self_kq_mask_cnv; // [kv_size, n_batch]
|
|
ggml_tensor * self_kq_mask_swa; // F32 [kv_size, n_batch]
|
|
ggml_tensor * self_kq_mask_swa_cnv; // [kv_size, n_batch]
|
|
ggml_tensor * self_k_shift; // I32 [kv_size]
|
|
} inp;
|
|
|
|
//
|
|
// graph
|
|
//
|
|
|
|
virtual ggml_cgraph * graph_init() override;
|
|
|
|
//
|
|
// graph build
|
|
//
|
|
|
|
virtual ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override;
|
|
|
|
virtual void build_attn_inp(
|
|
ggml_context * ctx0,
|
|
int32_t n_tokens,
|
|
bool causal,
|
|
bool swa) override;
|
|
|
|
virtual ggml_tensor * build_attn(
|
|
ggml_context * ctx0,
|
|
ggml_cgraph * gf,
|
|
ggml_tensor * wo,
|
|
ggml_tensor * wo_b,
|
|
ggml_tensor * q_cur,
|
|
ggml_tensor * k_cur,
|
|
ggml_tensor * v_cur,
|
|
int32_t n_tokens,
|
|
float kq_scale,
|
|
int il) override;
|
|
|
|
virtual void build_kv_self_shift(
|
|
ggml_context * ctx0,
|
|
ggml_cgraph * gf) override;
|
|
|
|
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
|
|
virtual void build_kv_self_defrag(
|
|
ggml_context * ctx0,
|
|
ggml_cgraph * gf) override;
|
|
|
|
// === encoder-decoder ===
|
|
|
|
// whether we are computing encoder output or decoder output
|
|
bool is_encoding = false;
|
|
|
|
// output of the encoder part of the encoder-decoder models
|
|
std::vector<float> embd_enc;
|
|
std::vector<std::set<llama_seq_id>> seq_ids_enc;
|
|
|
|
struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc]
|
|
struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch]
|
|
struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch]
|
|
|
|
virtual ggml_tensor * build_inp_embd_enc(
|
|
ggml_context * ctx0) override;
|
|
|
|
virtual ggml_tensor * build_inp_kq_mask_cross(
|
|
ggml_context * ctx0,
|
|
int32_t n_tokens) override;
|
|
|
|
//
|
|
// state save/load
|
|
//
|
|
|
|
virtual size_t state_get_data(llama_io_write_i & io) override;
|
|
virtual size_t state_set_data(llama_io_read_i & io) override;
|
|
|
|
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
|
|
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
|
|
|
|
//
|
|
// members
|
|
//
|
|
|
|
llama_kv_cache kv_self;
|
|
};
|
|
|
|
// a recurrent transformer (e.g. RWKV, Mamba)
|
|
class llama_context_recurrent : public llama_context {
|
|
public:
|
|
llama_context_recurrent(
|
|
const llama_model & model,
|
|
const llama_context_params & params);
|
|
|
|
virtual ~llama_context_recurrent();
|
|
|
|
protected:
|
|
virtual void reserve() override;
|
|
|
|
public:
|
|
virtual llama_kv_cache * get_kv_self() override;
|
|
virtual const llama_kv_cache * get_kv_self() const override;
|
|
|
|
virtual void kv_self_update() override;
|
|
|
|
virtual int encode(llama_batch & inp_batch) override;
|
|
virtual int decode(llama_batch & inp_batch) override;
|
|
|
|
protected:
|
|
//
|
|
// input
|
|
//
|
|
|
|
virtual void input_set(const llama_ubatch & ubatch) override;
|
|
|
|
struct {
|
|
ggml_tensor * s_copy; // I32 [kv_size]
|
|
ggml_tensor * s_mask; // F32 [1, n_kv]
|
|
} inp;
|
|
|
|
//
|
|
// graph
|
|
//
|
|
|
|
virtual ggml_cgraph * graph_init() override;
|
|
|
|
//
|
|
// graph build
|
|
//
|
|
|
|
virtual ggml_tensor * build_inp_s_copy(
|
|
ggml_context * ctx0) override;
|
|
|
|
virtual ggml_tensor * build_inp_s_mask(
|
|
ggml_context * ctx0) override;
|
|
|
|
virtual ggml_tensor * build_copy_mask_state(
|
|
ggml_context * ctx0,
|
|
ggml_cgraph * gf,
|
|
ggml_tensor * s,
|
|
ggml_tensor * state_copy,
|
|
ggml_tensor * state_mask,
|
|
int32_t n_state,
|
|
int32_t n_seqs) override;
|
|
|
|
virtual ggml_tensor * build_mamba_layer(
|
|
ggml_context * ctx0,
|
|
ggml_cgraph * gf,
|
|
ggml_tensor * cur,
|
|
ggml_tensor * state_copy,
|
|
ggml_tensor * state_mask,
|
|
const llama_ubatch & ubatch,
|
|
int il) override;
|
|
|
|
virtual ggml_tensor * build_rwkv_token_shift_load(
|
|
ggml_context * ctx0,
|
|
ggml_cgraph * gf,
|
|
ggml_tensor * state_copy,
|
|
ggml_tensor * state_mask,
|
|
const llama_ubatch & ubatch,
|
|
int il) override;
|
|
|
|
virtual ggml_tensor * build_rwkv_token_shift_store(
|
|
ggml_context * ctx0,
|
|
ggml_tensor * token_shift,
|
|
const llama_ubatch & ubatch,
|
|
int il) override;
|
|
|
|
virtual ggml_tensor * build_rwkv6_time_mix(
|
|
ggml_context * ctx0,
|
|
ggml_cgraph * gf,
|
|
ggml_tensor * cur,
|
|
ggml_tensor * x_prev,
|
|
ggml_tensor * state_copy,
|
|
ggml_tensor * state_mask,
|
|
const llama_ubatch & ubatch,
|
|
int il) override;
|
|
|
|
//
|
|
// state save/load
|
|
//
|
|
|
|
virtual size_t state_get_data(llama_io_write_i & io) override;
|
|
virtual size_t state_set_data(llama_io_read_i & io) override;
|
|
|
|
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
|
|
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
|
|
|
|
//
|
|
// members
|
|
//
|
|
|
|
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
|
|
llama_kv_cache_recurrent kv_self;
|
|
};
|
|
|
|
// Internal accessor intended for tests only: returns a reference to a list of
// (name, tensor) pairs for the given context.
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> &
llama_internal_get_tensor_map(
        struct llama_context * ctx);
|