From fbe6a07256c36264bfbb0749d2285f397edf38bb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 12 Feb 2025 17:16:44 +0200 Subject: [PATCH] context : rename to llama_context_kv_self --- src/llama-context.cpp | 140 +++++++++++++++++++++--------------------- src/llama-context.h | 54 ++++++++-------- src/llama-graph.h | 3 + src/llama-model.h | 1 + src/llama.cpp | 2 +- 5 files changed, 102 insertions(+), 98 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 62f76f48b..665a144d7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -332,10 +332,10 @@ void llama_context::perf_reset() { } // -// llama_context_unified +// llama_context_kv_self // -llama_context_unified::llama_context_unified( +llama_context_kv_self::llama_context_kv_self( const llama_model & model, const llama_context_params & params) : llama_context(model) { const auto & hparams = model.hparams; @@ -636,29 +636,29 @@ llama_context_unified::llama_context_unified( } } -llama_context_unified::~llama_context_unified() = default; +llama_context_kv_self::~llama_context_kv_self() = default; -uint32_t llama_context_unified::n_seq_max() const { +uint32_t llama_context_kv_self::n_seq_max() const { // TODO: add notion of n_seq_max to llama_kv_cache and use it here return kv_self.size; } -llama_kv_cache * llama_context_unified::get_kv_self() { +llama_kv_cache * llama_context_kv_self::get_kv_self() { return &kv_self; } -const llama_kv_cache * llama_context_unified::get_kv_self() const { +const llama_kv_cache * llama_context_kv_self::get_kv_self() const { return &kv_self; } -float * llama_context_unified::get_logits() { +float * llama_context_kv_self::get_logits() { // reorder logits for backward compatibility reorder_outputs(); return logits; } -float * llama_context_unified::get_logits_ith(int32_t i) { +float * llama_context_kv_self::get_logits_ith(int32_t i) { int32_t j = -1; try { @@ -696,14 +696,14 @@ float * llama_context_unified::get_logits_ith(int32_t i) { } } -float * llama_context_unified::get_embeddings() { +float * llama_context_kv_self::get_embeddings() { // reorder embeddings for backward compatibility reorder_outputs(); return embd; } -float * llama_context_unified::get_embeddings_ith(int32_t i) { +float * llama_context_kv_self::get_embeddings_ith(int32_t i) { int32_t j = -1; try { @@ -741,7 +741,7 @@ float * llama_context_unified::get_embeddings_ith(int32_t i) { } } -float * llama_context_unified::get_embeddings_seq(llama_seq_id seq_id) { +float * llama_context_kv_self::get_embeddings_seq(llama_seq_id seq_id) { auto it = embd_seq.find(seq_id); if (it == embd_seq.end()) { return nullptr; @@ -750,7 +750,7 @@ float * llama_context_unified::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } -ggml_context_ptr llama_context_unified::init() { +ggml_context_ptr llama_context_kv_self::init() { inp_tokens = nullptr; inp_embd = nullptr; inp_pos = nullptr; @@ -771,8 +771,8 @@ ggml_context_ptr llama_context_unified::init() { return llama_context::init(); } -struct llama_context_unified::batch_manager { - batch_manager(llama_context_unified & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) { +struct llama_context_kv_self::batch_manager { + batch_manager(llama_context_kv_self & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) { const auto & model = lctx.model; const auto & cparams = lctx.cparams; const auto & hparams = lctx.model.hparams; @@ -982,18 +982,18 @@ struct llama_context_unified::batch_manager { int64_t n_outputs_all = 0; - llama_context_unified & lctx; + llama_context_kv_self & lctx; const llama_batch & batch; llama_kv_slot_restorer kv_slot_restorer; }; -std::unique_ptr llama_context_unified::prepare_batch(const llama_batch & batch) { +std::unique_ptr llama_context_kv_self::prepare_batch(const llama_batch & batch) { return std::make_unique(*this, batch); } -int llama_context_unified::decode(llama_batch & inp_batch) { +int llama_context_kv_self::decode(llama_batch & inp_batch) { is_encoding = false; if (inp_batch.n_tokens == 0) { @@ -1198,7 +1198,7 @@ int llama_context_unified::decode(llama_batch & inp_batch) { return 0; } -int llama_context_unified::encode(llama_batch & inp_batch) { +int llama_context_kv_self::encode(llama_batch & inp_batch) { is_encoding = true; if (inp_batch.n_tokens == 0) { @@ -1375,7 +1375,7 @@ int llama_context_unified::encode(llama_batch & inp_batch) { return 0; } -enum ggml_status llama_context_unified::compute_graph( +enum ggml_status llama_context_kv_self::compute_graph( ggml_cgraph * graph, bool batched) { int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; @@ -1402,23 +1402,23 @@ enum ggml_status llama_context_unified::compute_graph( return status; } -llama_pos llama_context_unified::pos_max() const { +llama_pos llama_context_kv_self::pos_max() const { return kv_self.pos_max(); } -uint32_t llama_context_unified::get_ctx_padding(const llama_cparams & cparams) const { +uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) const { return kv_self.get_padding(cparams); } -void llama_context_unified::prepare_k_shift() { +void llama_context_kv_self::prepare_k_shift() { } -void llama_context_unified::prepare_defrag() { +void llama_context_kv_self::prepare_defrag() { } // llama input -void llama_context_unified::set_inputs(const llama_ubatch & ubatch) { +void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) { const llama_hparams & hparams = model.hparams; // @@ -1837,7 +1837,7 @@ void llama_context_unified::set_inputs(const llama_ubatch & ubatch) { } } -void llama_context_unified::reorder_outputs() { +void llama_context_kv_self::reorder_outputs() { std::vector & out_ids = sbatch.out_ids; if (!out_ids.empty()) { const uint32_t n_vocab = model.vocab.n_tokens(); @@ -1875,7 +1875,7 @@ void llama_context_unified::reorder_outputs() { } } -size_t llama_context_unified::reserve_outputs(size_t n_outputs) { +size_t llama_context_kv_self::reserve_outputs(size_t n_outputs) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1944,7 +1944,7 @@ size_t llama_context_unified::reserve_outputs(size_t n_outputs) { return n_outputs_max; } -void llama_context_unified::kv_self_update() { +void llama_context_kv_self::kv_self_update() { auto & kv = kv_self; if (kv.has_shift) { @@ -2009,7 +2009,7 @@ void llama_context_unified::kv_self_update() { } } -void llama_context_unified::build_attn_inp( +void llama_context_kv_self::build_attn_inp( ggml_context * ctx0, int32_t n_tokens, bool causal, @@ -2040,7 +2040,7 @@ void llama_context_unified::build_attn_inp( } } -void llama_context_unified::build_attn_kv_store( +void llama_context_kv_self::build_attn_kv_store( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * k_cur, @@ -2084,7 +2084,7 @@ void llama_context_unified::build_attn_kv_store( ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view)); } -ggml_tensor * llama_context_unified::build_attn_qkv( +ggml_tensor * llama_context_kv_self::build_attn_qkv( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * wo, @@ -2236,7 +2236,7 @@ ggml_tensor * llama_context_unified::build_attn_qkv( return cur; } -ggml_tensor * llama_context_unified::build_soft_max_ext( +ggml_tensor * llama_context_kv_self::build_soft_max_ext( ggml_context * ctx0, ggml_tensor * kq, float kq_scale) { @@ -2245,7 +2245,7 @@ ggml_tensor * llama_context_unified::build_soft_max_ext( return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias); } -ggml_tensor * llama_context_unified::build_inp_embd( +ggml_tensor * llama_context_kv_self::build_inp_embd( ggml_context * ctx0, ggml_tensor * tok_embd, const llama_ubatch & ubatch) { @@ -2295,7 +2295,7 @@ ggml_tensor * llama_context_unified::build_inp_embd( return inpL; } -ggml_tensor * llama_context_unified::build_inp_pos( +ggml_tensor * llama_context_kv_self::build_inp_pos( ggml_context * ctx0, int32_t n_tokens) { inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token()); @@ -2304,7 +2304,7 @@ ggml_tensor * llama_context_unified::build_inp_pos( return inp_pos; } -ggml_tensor * llama_context_unified::build_inp_out_ids( +ggml_tensor * llama_context_kv_self::build_inp_out_ids( ggml_context * ctx0, int32_t n_tokens, bool worst_case) { @@ -2316,7 +2316,7 @@ ggml_tensor * llama_context_unified::build_inp_out_ids( return inp_out_ids; } -ggml_tensor * llama_context_unified::build_inp_mean( +ggml_tensor * llama_context_kv_self::build_inp_mean( ggml_context * ctx0, int32_t n_tokens) { inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); @@ -2325,7 +2325,7 @@ ggml_tensor * llama_context_unified::build_inp_mean( return inp_mean; } -ggml_tensor * llama_context_unified::build_inp_cls( +ggml_tensor * llama_context_kv_self::build_inp_cls( ggml_context * ctx0, int32_t n_tokens) { inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); @@ -2334,7 +2334,7 @@ ggml_tensor * llama_context_unified::build_inp_cls( return inp_cls; } -void llama_context_unified::build_k_shift( +void llama_context_kv_self::build_k_shift( ggml_context * ctx0, ggml_cgraph * graph) { const auto & n_ctx = cparams.n_ctx; @@ -2406,7 +2406,7 @@ void llama_context_unified::build_k_shift( } } -void llama_context_unified::build_defrag( +void llama_context_kv_self::build_defrag( ggml_context * ctx0, ggml_cgraph * graph) { const auto & hparams = model.hparams; @@ -2676,7 +2676,7 @@ void llama_context_unified::build_defrag( #endif } -ggml_tensor * llama_context_unified::build_inp_embd_enc( +ggml_tensor * llama_context_kv_self::build_inp_embd_enc( ggml_context * ctx0, int32_t n_tokens, bool worst_case) { @@ -2692,7 +2692,7 @@ ggml_tensor * llama_context_unified::build_inp_embd_enc( return inp_embd_enc; } -ggml_tensor * llama_context_unified::build_inp_KQ_mask_cross( +ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross( ggml_context * ctx0, int32_t n_tokens, bool worst_case) { @@ -2708,7 +2708,7 @@ ggml_tensor * llama_context_unified::build_inp_KQ_mask_cross( return inp_KQ_mask_cross; } -ggml_tensor * llama_context_unified::build_inp_s_copy( +ggml_tensor * llama_context_kv_self::build_inp_s_copy( ggml_context * ctx0, bool worst_case) { const auto n_kv = worst_case ? kv_self.size : kv_self.n; @@ -2719,7 +2719,7 @@ ggml_tensor * llama_context_unified::build_inp_s_copy( return inp_s_copy; } -ggml_tensor * llama_context_unified::build_inp_s_mask( +ggml_tensor * llama_context_kv_self::build_inp_s_mask( ggml_context * ctx0, bool worst_case) { const auto n_kv = worst_case ? kv_self.size : kv_self.n; @@ -2729,7 +2729,7 @@ ggml_tensor * llama_context_unified::build_inp_s_mask( return inp_s_mask; } -ggml_tensor * llama_context_unified::build_copy_mask_state( +ggml_tensor * llama_context_kv_self::build_copy_mask_state( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * s, @@ -2764,7 +2764,7 @@ ggml_tensor * llama_context_unified::build_copy_mask_state( } // TODO: split -ggml_tensor * llama_context_unified::build_mamba_layer( +ggml_tensor * llama_context_kv_self::build_mamba_layer( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * cur, @@ -2900,7 +2900,7 @@ ggml_tensor * llama_context_unified::build_mamba_layer( } -ggml_tensor * llama_context_unified::build_rwkv_token_shift_load( +ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * state_copy, @@ -2927,7 +2927,7 @@ ggml_tensor * llama_context_unified::build_rwkv_token_shift_load( } -ggml_tensor * llama_context_unified::build_rwkv_token_shift_store( +ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, @@ -2951,7 +2951,7 @@ ggml_tensor * llama_context_unified::build_rwkv_token_shift_store( } -ggml_tensor * llama_context_unified::build_rwkv6_time_mix( +ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * cur, @@ -3130,7 +3130,7 @@ ggml_tensor * llama_context_unified::build_rwkv6_time_mix( // TODO: replace all non-fatal assertions with returned errors or exceptions struct llama_data_write { - llama_data_write(llama_context_unified * ctx) : ctx(ctx) {} + llama_data_write(llama_context_kv_self * ctx) : ctx(ctx) {} virtual ~llama_data_write() = default; virtual void write(const void * src, size_t size) = 0; @@ -3215,11 +3215,11 @@ struct llama_data_write { } } - llama_context_unified * ctx; + llama_context_kv_self * ctx; }; struct llama_data_read { - llama_data_read(llama_context_unified * ctx) : ctx(ctx) {} + llama_data_read(llama_context_kv_self * ctx) : ctx(ctx) {} virtual ~llama_data_read() = default; virtual const uint8_t * read(size_t size) = 0; @@ -3311,11 +3311,11 @@ struct llama_data_read { } } - llama_context_unified * ctx; + llama_context_kv_self * ctx; }; struct llama_data_write_dummy : llama_data_write { - llama_data_write_dummy(llama_context_unified * ctx) : llama_data_write(ctx) {} + llama_data_write_dummy(llama_context_kv_self * ctx) : llama_data_write(ctx) {} void write(const void * /* src */, size_t size) override { size_written += size; @@ -3334,7 +3334,7 @@ struct llama_data_write_dummy : llama_data_write { struct llama_data_write_buffer : llama_data_write { llama_data_write_buffer( - llama_context_unified * ctx, + llama_context_kv_self * ctx, uint8_t * p, size_t len) : llama_data_write(ctx), ptr(p), buf_size(len) {} void write(const void * src, size_t size) override { @@ -3368,7 +3368,7 @@ struct llama_data_write_buffer : llama_data_write { struct llama_data_read_buffer : llama_data_read { llama_data_read_buffer( - llama_context_unified * ctx, + llama_context_kv_self * ctx, const uint8_t * p, size_t len) : llama_data_read(ctx), ptr(p), buf_size(len) {} const uint8_t * read(size_t size) override { @@ -3397,7 +3397,7 @@ struct llama_data_read_buffer : llama_data_read { struct llama_data_write_file : llama_data_write { llama_data_write_file( - llama_context_unified * ctx, + llama_context_kv_self * ctx, llama_file * f) : llama_data_write(ctx), file(f) {} void write(const void * src, size_t size) override { @@ -3422,7 +3422,7 @@ struct llama_data_write_file : llama_data_write { struct llama_data_read_file : llama_data_read { llama_data_read_file( - llama_context_unified * ctx, + llama_context_kv_self * ctx, llama_file * f) : llama_data_read(ctx), file(f) {} void read_to(void * dst, size_t size) override { @@ -3445,7 +3445,7 @@ struct llama_data_read_file : llama_data_read { std::vector temp_buffer; }; -size_t llama_context_unified::state_get_size() { +size_t llama_context_kv_self::state_get_size() { llama_data_write_dummy data_ctx(this); try { return state_get_data(data_ctx); @@ -3455,7 +3455,7 @@ size_t llama_context_unified::state_get_size() { } } -size_t llama_context_unified::state_get_data(uint8_t * dst, size_t size) { +size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) { llama_data_write_buffer data_ctx(this, dst, size); try { return state_get_data(data_ctx); @@ -3465,7 +3465,7 @@ size_t llama_context_unified::state_get_data(uint8_t * dst, size_t size) { } } -size_t llama_context_unified::state_set_data(const uint8_t * src, size_t size) { +size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) { llama_data_read_buffer data_ctx(this, src, size); try { return state_set_data(data_ctx); @@ -3475,7 +3475,7 @@ size_t llama_context_unified::state_set_data(const uint8_t * src, size_t size) { } } -size_t llama_context_unified::state_seq_get_size(llama_seq_id seq_id) { +size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) { llama_data_write_dummy data_ctx(this); try { return state_seq_get_data(data_ctx, seq_id); @@ -3485,7 +3485,7 @@ size_t llama_context_unified::state_seq_get_size(llama_seq_id seq_id) { } } -size_t llama_context_unified::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { +size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { llama_data_write_buffer data_ctx(this, dst, size); try { return state_seq_get_data(data_ctx, seq_id); @@ -3495,7 +3495,7 @@ size_t llama_context_unified::state_seq_get_data(llama_seq_id seq_id, uint8_t * } } -size_t llama_context_unified::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { +size_t llama_context_kv_self::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { llama_data_read_buffer data_ctx(this, src, size); try { return state_seq_set_data(data_ctx, seq_id); @@ -3505,7 +3505,7 @@ size_t llama_context_unified::state_seq_set_data(llama_seq_id seq_id, const uint } } -bool llama_context_unified::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +bool llama_context_kv_self::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(filepath, "rb"); // sanity checks @@ -3548,7 +3548,7 @@ bool llama_context_unified::state_load_file(const char * filepath, llama_token * return true; } -bool llama_context_unified::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { +bool llama_context_kv_self::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { llama_file file(filepath, "wb"); file.write_u32(LLAMA_SESSION_MAGIC); @@ -3565,7 +3565,7 @@ bool llama_context_unified::state_save_file(const char * filepath, const llama_t return true; } -size_t llama_context_unified::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { +size_t llama_context_kv_self::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { llama_file file(filepath, "rb"); // version checks @@ -3608,7 +3608,7 @@ size_t llama_context_unified::state_seq_load_file(llama_seq_id seq_id, const cha return file.tell(); } -size_t llama_context_unified::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { +size_t llama_context_kv_self::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { llama_file file(filepath, "wb"); file.write_u32(LLAMA_STATE_SEQ_MAGIC); @@ -3641,7 +3641,7 @@ size_t llama_context_unified::state_seq_save_file(llama_seq_id seq_id, const cha * llama_state_get_data_internal(ctx, data_ctx); * */ -size_t llama_context_unified::state_get_data(llama_data_write & data_ctx) { +size_t llama_context_kv_self::state_get_data(llama_data_write & data_ctx) { synchronize(); data_ctx.write_model_info(); @@ -3667,7 +3667,7 @@ size_t llama_context_unified::state_get_data(llama_data_write & data_ctx) { return data_ctx.get_size_written(); } -size_t llama_context_unified::state_set_data(llama_data_read & data_ctx) { +size_t llama_context_kv_self::state_set_data(llama_data_read & data_ctx) { synchronize(); data_ctx.read_model_info(); @@ -3693,7 +3693,7 @@ size_t llama_context_unified::state_set_data(llama_data_read & data_ctx) { return data_ctx.get_size_read(); } -size_t llama_context_unified::state_seq_get_data(llama_data_write & data_ctx, llama_seq_id seq_id) { +size_t llama_context_kv_self::state_seq_get_data(llama_data_write & data_ctx, llama_seq_id seq_id) { synchronize(); llama_kv_cache::io io = { @@ -3712,7 +3712,7 @@ size_t llama_context_unified::state_seq_get_data(llama_data_write & data_ctx, ll return data_ctx.get_size_written(); } -size_t llama_context_unified::state_seq_set_data(llama_data_read & data_ctx, llama_seq_id seq_id) { +size_t llama_context_kv_self::state_seq_set_data(llama_data_read & data_ctx, llama_seq_id seq_id) { synchronize(); llama_kv_cache::io io = { diff --git a/src/llama-context.h b/src/llama-context.h index dc85c7971..648a41045 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -82,6 +82,8 @@ struct llama_context : public llama_graph_i { int32_t il_start, int32_t il_end); + // graph build API (generic) + virtual void build_cb( ggml_tensor * cur, const char * name, @@ -91,6 +93,27 @@ struct llama_context : public llama_graph_i { // TODO: add encode/decode graphs virtual ggml_cgraph * build_graph(const llama_ubatch & ubatch, bool worst_case); + // apply control vector for layer il + virtual ggml_tensor * build_cvec( + ggml_context * ctx0, + ggml_tensor * cur, + int il); + + // do mat_mul, while optionally apply lora + virtual ggml_tensor * build_lora_mm( + ggml_context * ctx0, + ggml_tensor * w, + ggml_tensor * cur); + + // do mat_mul_id, while optionally apply lora + virtual ggml_tensor * build_lora_mm_id( + ggml_context * ctx0, + ggml_tensor * w, // struct ggml_tensor * as + ggml_tensor * cur, // struct ggml_tensor * b + ggml_tensor * ids); + + virtual ggml_tensor * build_rope_factors(int il); + // decode a batch of tokens by evaluating the transformer // in case of unsuccessful decoding (error or warning), // the kv_cache state will be returned to its original state @@ -116,29 +139,6 @@ struct llama_context : public llama_graph_i { // virtual int encode(llama_batch & inp_batch) = 0; - // graph build API (generic) - - // apply control vector for layer il - virtual ggml_tensor * build_cvec( - ggml_context * ctx0, - ggml_tensor * cur, - int il); - - // do mat_mul, while optionally apply lora - virtual ggml_tensor * build_lora_mm( - ggml_context * ctx0, - ggml_tensor * w, - ggml_tensor * cur); - - // do mat_mul_id, while optionally apply lora - virtual ggml_tensor * build_lora_mm_id( - ggml_context * ctx0, - ggml_tensor * w, // struct ggml_tensor * as - ggml_tensor * cur, // struct ggml_tensor * b - ggml_tensor * ids); - - virtual ggml_tensor * build_rope_factors(int il); - // state save/load virtual size_t state_get_size() = 0; @@ -217,16 +217,16 @@ protected: mutable int32_t n_eval = 0; // number of eval calls }; -// TODO: make implementation details private -class llama_context_unified : public llama_context { +// transformer with a self-attention KV cache +class llama_context_kv_self : public llama_context { public: struct batch_manager; - llama_context_unified( + llama_context_kv_self( const llama_model & model, const llama_context_params & params); - virtual ~llama_context_unified(); + virtual ~llama_context_kv_self(); virtual uint32_t n_seq_max() const override; diff --git a/src/llama-graph.h b/src/llama-graph.h index d111d76e9..5267d53da 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -2,6 +2,9 @@ #include +// note: do not add high-level objects here, such as llama_context, llama_kv_cache, etc. +// not sure about llama_batch/llama_sbatch yet + struct ggml_cgraph; struct ggml_context; struct ggml_tensor; diff --git a/src/llama-model.h b/src/llama-model.h index 5d2a07abc..0374b484b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -368,6 +368,7 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; // TODO: add encode/decode graphs + // TODO: return a struct containing the graph and the output tensors, such as logits, embeddings, etc. ggml_cgraph * build_graph( llama_graph_i & lgf, const llama_cparams & cparams, diff --git a/src/llama.cpp b/src/llama.cpp index 83b66035f..d20a2a6d5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -327,7 +327,7 @@ struct llama_context * llama_init_from_model( try { // TODO: add logic which llama_context implementation to construct - ctx = new llama_context_unified(*model, params); + ctx = new llama_context_kv_self(*model, params); } catch (const std::exception & e) { LLAMA_LOG_ERROR("%s: failed to initialize context: %s\n", __func__, e.what()); return nullptr;