#pragma once

#include "llama.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-graph.h"
#include "llama-model.h"
#include "llama-kv-cache.h"
#include "llama-adapter.h"

#include "ggml-cpp.h"

#include <map>
#include <set>
#include <unordered_map>
#include <vector>

class llama_io_read_i;
class llama_io_write_i;

using llama_loras = std::unordered_map<llama_adapter_lora *, float>;

// basic transformer without KV cache
struct llama_context : public llama_graph_i {
public:
    llama_context(
            const llama_model & model,
            const llama_context_params & params);

    virtual ~llama_context();

    // init scheduler and compute buffers, reserve worst-case graphs
    // call once after the context is constructed
    virtual void init();

    virtual void synchronize();

protected:
    // called by init() to reserve the worst-case graphs
    // override in child classes
    virtual void reserve();

public:
    const llama_model   & get_model()   const;
    const llama_cparams & get_cparams() const;

    virtual uint32_t n_ctx()         const;
    virtual uint32_t n_ctx_per_seq() const;
    virtual uint32_t n_batch()       const;
    virtual uint32_t n_ubatch()      const;
    virtual uint32_t n_seq_max()     const;

    virtual uint32_t n_threads()       const;
    virtual uint32_t n_threads_batch() const;

    virtual int32_t max_nodes() const;

    // self-attention:
    //   if the context does not have a KV cache, return nullptr
    virtual       llama_kv_cache * get_kv_self();
    virtual const llama_kv_cache * get_kv_self() const;

    // if the context does not have a KV cache, noop
    virtual void kv_self_update();

    virtual enum llama_pooling_type pooling_type() const;

    virtual float * get_logits();
    virtual float * get_logits_ith(int32_t i);

    virtual float * get_embeddings();
    virtual float * get_embeddings_ith(int32_t i);
    virtual float * get_embeddings_seq(llama_seq_id seq_id);

    virtual int64_t n_pos_per_token() const; // vision

    virtual void attach_threadpool(
            ggml_threadpool_t threadpool,
            ggml_threadpool_t threadpool_batch);

    virtual void detach_threadpool();

    virtual void set_n_threads(int32_t n_threads, int32_t n_threads_batch);

    virtual void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data);

    virtual void set_embeddings (bool value);
    virtual void set_causal_attn(bool value);

    virtual void set_adapter_lora(
            llama_adapter_lora * adapter,
            float scale);

    virtual bool rm_adapter_lora(
            llama_adapter_lora * adapter);

    virtual void clear_adapter_lora();

    virtual bool apply_adapter_cvec(
            const float * data,
                 size_t   len,
                int32_t   n_embd,
                int32_t   il_start,
                int32_t   il_end);

    // encode a batch of tokens by evaluating the encoder part of the transformer
    //
    //   - lctx:      llama context
    //   - batch:     batch to evaluate
    //
    // return 0 on success
    // return positive int on warning
    // return negative int on error
    //
    virtual int encode(llama_batch & inp_batch);

    // decode a batch of tokens by evaluating the transformer
    // in case of unsuccessful decoding (error or warning),
    // the kv_cache state will be returned to its original state
    // (for non-recurrent models) or cleaned (for recurrent models)
    //
    //   - lctx:      llama context
    //   - inp_batch: batch to evaluate
    //
    // return 0 on success
    // return positive int on warning
    // return negative int on error
    //
    virtual int decode(llama_batch & inp_batch);
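    // Illustrative caller-side sketch (not part of this header's contract): batches
    // normally reach decode() through the public llama_decode() wrapper in llama.h;
    // the `ctx` and `tokens` names below are assumed to exist at the call site:
    //
    //     llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    //
    //     const int ret = llama_decode(ctx, batch);
    //     if (ret < 0) {
    //         // error - the call failed and processing should stop
    //     } else if (ret > 0) {
    //         // warning (e.g. no space for the batch) - retry with a smaller batch
    //     }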
protected:
    //
    // input
    //

    // when the compute graph is built, it creates the input tensors that it needs
    // the contents of the input tensors are set by the input_set() function
    virtual void input_set(const llama_ubatch & ubatch);

    struct {
        // base input tensors
        ggml_tensor * tokens;  // I32 [n_batch]
        ggml_tensor * embd;    // F32 [n_embd, n_batch]
        ggml_tensor * pos;     // I32 [n_batch]
        ggml_tensor * out_ids; // I32 [n_outputs]
        ggml_tensor * mean;    // F32 [n_batch, n_batch]
        ggml_tensor * cls;     // I32 [n_batch]

        // KQ mask input tensors
        ggml_tensor * kq_mask;     // F32 [n_tokens, n_batch]
        ggml_tensor * kq_mask_cnv; //     [n_tokens, n_batch]
    } inp;

    //
    // output
    //

    // Make sure enough space is available for outputs.
    // Returns max number of outputs for which space was reserved.
    virtual int32_t output_reserve(int32_t n_outputs);

    // make the outputs have the same order they had in the user-provided batch
    // TODO: maybe remove this
    virtual void output_reorder();

    //
    // graph
    //

    // zero-out inputs and create the ctx_compute for the compute graph
    virtual ggml_cgraph * graph_init();

    // TODO: add encode/decode graphs
    virtual llama_graph_result graph_build(
            ggml_context * ctx,
             ggml_cgraph * gf,
      const llama_ubatch & ubatch);

    // returns the result of ggml_backend_sched_graph_compute_async execution
    virtual enum ggml_status graph_compute(
            ggml_cgraph * gf,
                    bool   batched);

    ggml_context_ptr ctx_compute;

    //
    // graph build API (generic)
    //

    virtual void build_cb(
             ggml_tensor * cur,
              const char * name,
      const llama_ubatch & ubatch,
                     int   il);

    // apply control vector for layer il
    virtual ggml_tensor * build_cvec(
            ggml_context * ctx0,
             ggml_tensor * cur,
                     int   il);

    // do mat_mul, while optionally applying lora
    virtual ggml_tensor * build_lora_mm(
            ggml_context * ctx0,
             ggml_tensor * w,
             ggml_tensor * cur);

    // do mat_mul_id, while optionally applying lora
    virtual ggml_tensor * build_lora_mm_id(
            ggml_context * ctx0,
             ggml_tensor * w,   // struct ggml_tensor * as
             ggml_tensor * cur, // struct ggml_tensor * b
             ggml_tensor * ids);

    virtual ggml_tensor * build_rope_factors(int il);

    virtual ggml_tensor * build_rope_shift(
             ggml_context * ctx0,
              ggml_tensor * cur,
              ggml_tensor * shift,
              ggml_tensor * factors,
      ggml_backend_buffer * bbuf);

    virtual ggml_tensor * build_inp_embd(
            ggml_context * ctx0,
             ggml_tensor * tok_embd,
      const llama_ubatch & ubatch);

    virtual ggml_tensor * build_inp_pos(
            ggml_context * ctx0,
                 int32_t   n_tokens);

    virtual ggml_tensor * build_inp_out_ids(
            ggml_context * ctx0);

    virtual ggml_tensor * build_inp_mean(
            ggml_context * ctx0,
                 int32_t   n_tokens);

    virtual ggml_tensor * build_inp_cls(
            ggml_context * ctx0,
                 int32_t   n_tokens);

    virtual void build_attn_inp(
            ggml_context * ctx0,
                 int32_t   n_tokens,
                    bool   causal,
                    bool   swa);

    virtual ggml_tensor * build_attn(
            ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
                 int32_t   n_tokens,
                   float   kq_scale,
                     int   il);

public:
    //
    // perf
    //

    virtual llama_perf_context_data perf_get_data() const;
    virtual void perf_reset();

protected:
    mutable int64_t t_start_us  = 0;
    mutable int64_t t_load_us   = 0;
    mutable int64_t t_p_eval_us = 0;
    mutable int64_t t_eval_us   = 0;

    mutable int64_t t_compute_start_us = 0;
    mutable int64_t n_queued_tokens    = 0;

    mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
    mutable int32_t n_eval   = 0; // number of eval calls
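    // Sketch of consuming these counters through the public API (assumes the
    // llama_perf_context() accessor from llama.h, which is backed by perf_get_data();
    // real code should also guard against zero elapsed times):
    //
    //     const llama_perf_context_data pd = llama_perf_context(ctx);
    //
    //     const double pp_tps = 1e3 * pd.n_p_eval / pd.t_p_eval_ms; // prompt processing tok/s
    //     const double tg_tps = 1e3 * pd.n_eval   / pd.t_eval_ms;   // text generation tok/s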
public:
    //
    // state save/load
    //

    virtual size_t state_get_size();
    virtual size_t state_get_data(      uint8_t * dst, size_t size);
    virtual size_t state_set_data(const uint8_t * src, size_t size);

    virtual size_t state_seq_get_size(llama_seq_id seq_id);
    virtual size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size);
    virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);

    virtual bool state_load_file(
            const char * filepath,
           llama_token * tokens_out,
                size_t   n_token_capacity,
                size_t * n_token_count_out);

    virtual bool state_save_file(
            const char * filepath,
     const llama_token * tokens,
                size_t   n_token_count);

    virtual size_t state_seq_load_file(
          llama_seq_id   seq_id,
            const char * filepath,
           llama_token * tokens_out,
                size_t   n_token_capacity,
                size_t * n_token_count_out);

    virtual size_t state_seq_save_file(
          llama_seq_id   seq_id,
            const char * filepath,
     const llama_token * tokens,
                size_t   n_token_count);

protected:
    virtual size_t state_get_data(llama_io_write_i & io);
    virtual size_t state_set_data(llama_io_read_i  & io);

    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
    virtual size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id);

    //
    // members
    //

    const llama_model & model;

    llama_cparams      cparams;
    llama_adapter_cvec cvec;
    llama_loras        loras;
    llama_sbatch       sbatch;

    ggml_threadpool_t threadpool       = nullptr;
    ggml_threadpool_t threadpool_batch = nullptr;

    ggml_abort_callback abort_callback      = nullptr;
    void *              abort_callback_data = nullptr;

    ggml_backend_t backend_cpu = nullptr;
    std::vector<ggml_backend_ptr> backends;

    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;

    ggml_backend_sched_ptr sched;

    // buffer types used for the compute buffer of each backend
    std::vector<ggml_backend_t>             backend_ptrs;
    std::vector<ggml_backend_buffer_type_t> backend_buft;

    // memory buffers used to evaluate the model
    std::vector<uint8_t> buf_compute_meta;

    // host buffer for the model output (logits and embeddings)
    ggml_backend_buffer_ptr buf_output;

    // TODO: remove
    bool logits_all = false;

    // decode output (2-dimensional array: [n_outputs][n_vocab])
    size_t  logits_size = 0; // capacity (of floats) for logits
    float * logits      = nullptr;

    // embeddings output (2-dimensional array: [n_outputs][n_embd])
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    size_t  embd_size = 0; // capacity (of floats) for embeddings
    float * embd      = nullptr;

    // sequence embeddings output (map of [n_embd] vectors)
    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
    std::map<llama_seq_id, std::vector<float>> embd_seq;

    int32_t output_size = 0; // capacity (of token positions) for the output buffers
    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch or last logical batch

    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers

    bool has_evaluated_once = false;
};
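// Illustrative state snapshot/restore sketch: the public wrappers in llama.h
// (llama_state_get_size / llama_state_get_data / llama_state_set_data) map onto the
// state_get_size/state_get_data/state_set_data members above; `ctx` is assumed:
//
//     std::vector<uint8_t> buf(llama_state_get_size(ctx));
//     const size_t n_written = llama_state_get_data(ctx, buf.data(), buf.size());
//
//     // ... later, restore the saved context state:
//     llama_state_set_data(ctx, buf.data(), n_written);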
// transformer with a self-attention KV cache
class llama_context_kv_self : public llama_context {
public:
    llama_context_kv_self(
            const llama_model & model,
            const llama_context_params & params);

    virtual ~llama_context_kv_self();

protected:
    virtual void reserve() override;

public:
    virtual       llama_kv_cache * get_kv_self()       override;
    virtual const llama_kv_cache * get_kv_self() const override;

    virtual void kv_self_update() override;

    virtual int encode(llama_batch & inp_batch) override;
    virtual int decode(llama_batch & inp_batch) override;

protected:
    //
    // input
    //

    virtual void input_set(const llama_ubatch & ubatch) override;

    struct {
        ggml_tensor * self_kq_mask;         // F32 [kv_size, n_batch]
        ggml_tensor * self_kq_mask_cnv;     //     [kv_size, n_batch]
        ggml_tensor * self_kq_mask_swa;     // F32 [kv_size, n_batch]
        ggml_tensor * self_kq_mask_swa_cnv; //     [kv_size, n_batch]
        ggml_tensor * self_k_shift;         // I32 [kv_size]
    } inp;

    //
    // graph
    //

    virtual ggml_cgraph * graph_init() override;

    //
    // graph build
    //

    virtual ggml_tensor * build_inp_self_k_shift(ggml_context * ctx0) override;

    virtual void build_attn_inp(
            ggml_context * ctx0,
                 int32_t   n_tokens,
                    bool   causal,
                    bool   swa) override;

    virtual ggml_tensor * build_attn(
            ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
                 int32_t   n_tokens,
                   float   kq_scale,
                     int   il) override;

    virtual void build_kv_self_shift(
            ggml_context * ctx0,
             ggml_cgraph * gf) override;

    // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
    virtual void build_kv_self_defrag(
            ggml_context * ctx0,
             ggml_cgraph * gf) override;

    // === encoder-decoder ===

    // whether we are computing encoder output or decoder output
    bool is_encoding = false;

    // output of the encoder part of the encoder-decoder models
    std::vector<float> embd_enc;
    std::vector<std::set<llama_seq_id>> seq_ids_enc;

    struct ggml_tensor * inp_embd_enc;      // F32 [n_embd, n_outputs_enc]
    struct ggml_tensor * inp_pos_bucket;    // I32 [n_batch|n_kv, n_batch]
    struct ggml_tensor * inp_kq_mask_cross; // F32 [n_outputs_enc, n_batch]

    virtual ggml_tensor * build_inp_embd_enc(
            ggml_context * ctx0) override;

    virtual ggml_tensor * build_inp_kq_mask_cross(
            ggml_context * ctx0,
                 int32_t   n_tokens) override;

    //
    // state save/load
    //

    virtual size_t state_get_data(llama_io_write_i & io) override;
    virtual size_t state_set_data(llama_io_read_i  & io) override;

    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
    virtual size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id) override;

    //
    // members
    //

    llama_kv_cache kv_self;
};

// a recurrent transformer (e.g. RWKV, Mamba)
class llama_context_recurrent : public llama_context {
public:
    llama_context_recurrent(
            const llama_model & model,
            const llama_context_params & params);

    virtual ~llama_context_recurrent();

protected:
    virtual void reserve() override;

public:
    virtual       llama_kv_cache * get_kv_self()       override;
    virtual const llama_kv_cache * get_kv_self() const override;

    virtual void kv_self_update() override;

    virtual int encode(llama_batch & inp_batch) override;
    virtual int decode(llama_batch & inp_batch) override;

protected:
    //
    // input
    //

    virtual void input_set(const llama_ubatch & ubatch) override;

    struct {
        ggml_tensor * s_copy; // I32 [kv_size]
        ggml_tensor * s_mask; // F32 [1, n_kv]
    } inp;

    //
    // graph
    //

    virtual ggml_cgraph * graph_init() override;

    //
    // graph build
    //

    virtual ggml_tensor * build_inp_s_copy(
            ggml_context * ctx0) override;

    virtual ggml_tensor * build_inp_s_mask(
            ggml_context * ctx0) override;

    virtual ggml_tensor * build_copy_mask_state(
            ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
             ggml_tensor * state_mask,
                 int32_t   n_state,
                 int32_t   n_seqs) override;

    virtual ggml_tensor * build_mamba_layer(
            ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * state_copy,
             ggml_tensor * state_mask,
      const llama_ubatch & ubatch,
                     int   il) override;

    virtual ggml_tensor * build_rwkv_token_shift_load(
            ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * state_copy,
             ggml_tensor * state_mask,
      const llama_ubatch & ubatch,
                     int   il) override;

    virtual ggml_tensor * build_rwkv_token_shift_store(
            ggml_context * ctx0,
             ggml_tensor * token_shift,
      const llama_ubatch & ubatch,
                     int   il) override;

    virtual ggml_tensor * build_rwkv6_time_mix(
            ggml_context * ctx0,
             ggml_cgraph * gf,
             ggml_tensor * cur,
             ggml_tensor * x_prev,
             ggml_tensor * state_copy,
             ggml_tensor * state_mask,
      const llama_ubatch & ubatch,
                     int   il) override;

    //
    // state save/load
    //

    virtual size_t state_get_data(llama_io_write_i & io) override;
    virtual size_t state_set_data(llama_io_read_i  & io) override;

    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
    virtual size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id) override;

    //
    // members
    //

    // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
    llama_kv_cache_recurrent kv_self;
};
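// Hypothetical sketch of how a concrete context type could be chosen at creation
// time (the real dispatch lives in the implementation; llama_model_is_recurrent()
// is the public helper from llama.h, and `model`/`params` are assumed to exist):
//
//     llama_context * ctx = llama_model_is_recurrent(&model)
//         ? static_cast<llama_context *>(new llama_context_recurrent(model, params))
//         : static_cast<llama_context *>(new llama_context_kv_self (model, params));
//
//     ctx->init(); // reserve worst-case graphs and compute buffers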
// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);