diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 87d6642da..13beb097c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -33,14 +33,68 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t return relative_bucket; } -llama_context::llama_context( +// llama_context + +llama_context::llama_context(const llama_model & model) : + model (model), + t_start_us(model.t_start_us), + t_load_us (model.t_load_us) { +} + +llama_context::~llama_context() = default; + +void llama_context::synchronize() { + ggml_backend_sched_synchronize(sched.get()); + + // FIXME: if multiple single tokens are evaluated without a synchronization, + // the stats will be added to the prompt evaluation stats + // this should only happen when using batch size 1 to evaluate a batch + + // add the evaluation to the stats + if (n_queued_tokens == 1) { + if (!cparams.no_perf) { + t_eval_us += ggml_time_us() - t_compute_start_us; + } + n_eval++; + } else if (n_queued_tokens > 1) { + if (!cparams.no_perf) { + t_p_eval_us += ggml_time_us() - t_compute_start_us; + } + n_p_eval += n_queued_tokens; + } + + // get a more accurate load time, upon first eval + if (n_queued_tokens > 0 && !has_evaluated_once) { + t_load_us = ggml_time_us() - t_start_us; + has_evaluated_once = true; + } + + n_queued_tokens = 0; + t_compute_start_us = 0; +} + +int64_t llama_context::n_pos_per_token() const { + return model.arch == LLM_ARCH_QWEN2VL ? 4 : 1; +} + +ggml_context_ptr llama_context::init() { + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + return ggml_context_ptr { ggml_init(params) }; +} + +// llama_context_unified + +llama_context_unified::llama_context_unified( const llama_model & model, const llama_context_params & params, build_graph_callback && cb_build_graph) : - model(model), - cb_build_graph(std::move(cb_build_graph)), - t_start_us(model.t_start_us), - t_load_us (model.t_load_us) { + llama_context(model), + cb_build_graph(std::move(cb_build_graph)){ const auto & hparams = model.hparams; @@ -252,6 +306,7 @@ llama_context::llama_context( const size_t max_nodes = model.max_nodes(); // buffer used to store the computation graph and the tensor meta data + // TODO: move to base class buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); // TODO: move these checks to ggml_backend_sched @@ -337,25 +392,161 @@ llama_context::llama_context( } } } - } -struct llama_batch_manager_i { - virtual ~llama_batch_manager_i() = default; +llama_context_unified::~llama_context_unified() = default; - virtual bool is_done() const = 0; - virtual llama_ubatch next() = 0; - virtual bool prepare(const llama_ubatch & ubatch) = 0; - virtual void restore() = 0; - virtual void update(const llama_ubatch & ubatch) = 0; - virtual void finalize() = 0; +uint32_t llama_context_unified::n_ctx() const { + return cparams.n_ctx; +} - // TODO: might be temporary - int64_t n_outputs_all = 0; -}; +uint32_t llama_context_unified::n_batch() const { + return cparams.n_batch; +} -struct llama_batch_manager : public llama_batch_manager_i { - llama_batch_manager(llama_context & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) { +uint32_t llama_context_unified::n_ubatch() const { + return cparams.n_ubatch; +} + +uint32_t llama_context_unified::n_seq_max() const { + // TODO: add notion of n_seq_max to llama_kv_cache and use it here 
+ return kv_self.size; +} + +llama_kv_cache * llama_context_unified::get_kv_self() { + return &kv_self; +} + +const llama_kv_cache * llama_context_unified::get_kv_self() const { + return &kv_self; +} + +enum llama_pooling_type llama_context_unified::pooling_type() const { + return cparams.pooling_type; +} + +float * llama_context_unified::get_logits() { + // reorder logits for backward compatibility + reorder_outputs(); + + return logits; +} + +float * llama_context_unified::get_logits_ith(int32_t i) { + int32_t j = -1; + + try { + if (logits == nullptr) { + throw std::runtime_error("no logits"); + } + + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); + } else { + j = output_ids[i]; + } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); + } + + return logits + j*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_context_unified::get_embeddings() { + // reorder embeddings for backward compatibility + reorder_outputs(); + + return embd; +} + +float * llama_context_unified::get_embeddings_ith(int32_t i) { + int32_t j = -1; + + try { + if (embd == nullptr) { + throw std::runtime_error("no embeddings"); + } + + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); + } else { + j = output_ids[i]; + } + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if (j >= n_outputs) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); + } + + return embd + j*model.hparams.n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_context_unified::get_embeddings_seq(llama_seq_id seq_id) { + auto it = embd_seq.find(seq_id); + if (it == embd_seq.end()) { + return nullptr; + } + + return it->second.data(); +} + +ggml_context_ptr llama_context_unified::init() { + inp_tokens = nullptr; + inp_embd = nullptr; + inp_pos = nullptr; + inp_out_ids = nullptr; + inp_mean = nullptr; + inp_cls = nullptr; + inp_embd_enc = nullptr; + inp_pos_bucket = nullptr; + inp_KQ_mask = nullptr; + inp_KQ_mask_cnv = nullptr; + inp_KQ_mask_swa = nullptr; + inp_KQ_mask_swa_cnv = nullptr; + inp_KQ_mask_cross = nullptr; + inp_K_shift = nullptr; + inp_s_copy = nullptr; + inp_s_mask = nullptr; + + return llama_context::init(); +} + +struct llama_context_unified::batch_manager { + batch_manager(llama_context_unified & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) { const auto & model = lctx.model; const auto & cparams = lctx.cparams; const auto & hparams = lctx.model.hparams; @@ -409,14 +600,14 @@ struct 
llama_batch_manager : public llama_batch_manager_i { /* logits_all */ logits_all); } - ~llama_batch_manager() override { + ~batch_manager() { } - virtual bool is_done() const override { + bool is_done() const { return lctx.sbatch.n_tokens == 0; } - virtual llama_ubatch next() override { + llama_ubatch next() { llama_ubatch ubatch = llama_ubatch(); const auto & cparams = lctx.cparams; @@ -442,7 +633,7 @@ struct llama_batch_manager : public llama_batch_manager_i { return ubatch; } - virtual bool prepare(const llama_ubatch & ubatch) override { + bool prepare(const llama_ubatch & ubatch) { const auto & cparams = lctx.cparams; const auto & hparams = lctx.model.hparams; const auto & batch = lctx.sbatch.batch; @@ -525,11 +716,11 @@ struct llama_batch_manager : public llama_batch_manager_i { return true; } - virtual void restore() override { + void restore() { kv_slot_restorer.restore(lctx.kv_self); } - virtual void update(const llama_ubatch & ubatch) override { + void update(const llama_ubatch & ubatch) { auto & kv_self = lctx.kv_self; // update the kv ring buffer @@ -543,7 +734,7 @@ struct llama_batch_manager : public llama_batch_manager_i { } } - virtual void finalize() override { + void finalize() { const auto & cparams = lctx.cparams; auto & kv_self = lctx.kv_self; @@ -563,18 +754,20 @@ struct llama_batch_manager : public llama_batch_manager_i { } } - llama_context & lctx; + int64_t n_outputs_all = 0; + + llama_context_unified & lctx; const llama_batch & batch; llama_kv_slot_restorer kv_slot_restorer; }; -std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch) { - return std::make_unique<llama_batch_manager>(*this, batch); +std::unique_ptr<llama_context_unified::batch_manager> llama_context_unified::prepare_batch(const llama_batch & batch) { + return std::make_unique<batch_manager>(*this, batch); } -int llama_context::decode(llama_batch & inp_batch) { +int llama_context_unified::decode(llama_batch & inp_batch) { is_encoding = false; if (inp_batch.n_tokens == 0) { @@ -679,12 +872,11 @@ int llama_context::decode(llama_batch & inp_batch) { GGML_ASSERT(logits != nullptr); float * logits_out = logits + n_outputs_prev*n_vocab; - const int32_t n_outputs_new = n_outputs; - if (n_outputs_new) { - GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs_new*n_vocab*sizeof(float)); + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } } @@ -699,12 +891,11 @@ int llama_context::decode(llama_batch & inp_batch) { // extract token embeddings GGML_ASSERT(embd != nullptr); float * embd_out = embd + n_outputs_prev*n_embd; - const int32_t n_outputs_new = n_outputs; - if (n_outputs_new) { - GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float)); + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -770,7 +961,7 @@ int llama_context::decode(llama_batch & inp_batch) {
n_outputs = n_outputs_all; // wait for the computation to finish (automatically done when obtaining the model output) - //llama_synchronize(&; + //synchronize(); bman->finalize(); @@ -781,7 +972,7 @@ int llama_context::decode(llama_batch & inp_batch) { return 0; } -int llama_context::encode(llama_batch & inp_batch) { +int llama_context_unified::encode(llama_batch & inp_batch) { is_encoding = true; if (inp_batch.n_tokens == 0) { @@ -958,7 +1149,7 @@ int llama_context::encode(llama_batch & inp_batch) { return 0; } -enum ggml_status llama_context::compute_graph( +enum ggml_status llama_context_unified::compute_graph( ggml_cgraph * graph, bool batched) { int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads; @@ -985,43 +1176,23 @@ enum ggml_status llama_context::compute_graph( return status; } -llama_pos llama_context::pos_max() const { +llama_pos llama_context_unified::pos_max() const { return kv_self.pos_max(); } -uint32_t llama_context::get_ctx_padding(const llama_cparams & cparams) const { +uint32_t llama_context_unified::get_ctx_padding(const llama_cparams & cparams) const { return kv_self.get_padding(cparams); } -// TODO: improve -void llama_context::reset() { - inp_tokens = nullptr; - inp_embd = nullptr; - inp_pos = nullptr; - inp_out_ids = nullptr; - inp_mean = nullptr; - inp_cls = nullptr; - inp_embd_enc = nullptr; - inp_pos_bucket = nullptr; - inp_KQ_mask = nullptr; - inp_KQ_mask_cnv = nullptr; - inp_KQ_mask_swa = nullptr; - inp_KQ_mask_swa_cnv = nullptr; - inp_KQ_mask_cross = nullptr; - inp_K_shift = nullptr; - inp_s_copy = nullptr; - inp_s_mask = nullptr; +void llama_context_unified::prepare_k_shift() { } -void llama_context::prepare_k_shift() { -} - -void llama_context::prepare_defrag() { +void llama_context_unified::prepare_defrag() { } // llama input -void llama_context::set_inputs(const llama_ubatch & ubatch) { +void llama_context_unified::set_inputs(const llama_ubatch & ubatch) { const llama_hparams & hparams = model.hparams; // @@ -1056,8 +1227,8 @@ void llama_context::set_inputs(const llama_ubatch & ubatch) { if (ubatch.pos && inp_pos) { const int64_t n_tokens = ubatch.n_tokens; - auto n_pos = n_pos_per_token; - ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos*ggml_element_size(inp_pos)); + + ggml_backend_tensor_set(inp_pos, ubatch.pos, 0, n_tokens*n_pos_per_token()*ggml_element_size(inp_pos)); } if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { @@ -1440,7 +1611,7 @@ void llama_context::set_inputs(const llama_ubatch & ubatch) { } } -void llama_context::reorder_outputs() { +void llama_context_unified::reorder_outputs() { std::vector<size_t> & out_ids = sbatch.out_ids; if (!out_ids.empty()) { const uint32_t n_vocab = model.vocab.n_tokens(); @@ -1478,7 +1649,7 @@ void llama_context::reorder_outputs() { } } -size_t llama_context::reserve_outputs(size_t n_outputs) { +size_t llama_context_unified::reserve_outputs(size_t n_outputs) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1605,7 +1776,7 @@ ggml_tensor * llama_context::build_lora_mm_id( return res; } -void llama_context::kv_self_update() { +void llama_context_unified::kv_self_update() { auto & kv = kv_self; if (kv.has_shift) { @@ -1619,15 +1790,8 @@ void llama_context::kv_self_update() { ggml_backend_sched_reset(sched.get()); - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute_meta.size(), - /*.mem_buffer =*/ buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ggml_context * ctx0 = ggml_init(params); - - reset(); + auto ctx =
init(); + auto ctx0 = ctx.get(); ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); @@ -1639,8 +1803,6 @@ void llama_context::kv_self_update() { compute_graph(gf, false); - ggml_free(ctx0); - need_reserve = true; } @@ -1659,15 +1821,8 @@ void llama_context::kv_self_update() { ggml_backend_sched_reset(sched.get()); - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute_meta.size(), - /*.mem_buffer =*/ buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ggml_context * ctx0 = ggml_init(params); - - reset(); + auto ctx = init(); + auto ctx0 = ctx.get(); ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); @@ -1680,19 +1835,13 @@ void llama_context::kv_self_update() { compute_graph(gf, false); - ggml_free(ctx0); - kv.do_defrag = false; need_reserve = true; } } -void llama_kv_self_update(llama_context * ctx) { - ctx->kv_self_update(); -} - -void llama_context::build_attn_inp( +void llama_context_unified::build_attn_inp( ggml_context * ctx0, int32_t n_tokens, bool causal, @@ -1723,7 +1872,7 @@ void llama_context::build_attn_inp( } } -void llama_context::build_attn_kv_store( +void llama_context_unified::build_attn_kv_store( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * k_cur, @@ -1767,7 +1916,7 @@ void llama_context::build_attn_kv_store( ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view)); } -ggml_tensor * llama_context::build_attn_qkv( +ggml_tensor * llama_context_unified::build_attn_qkv( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * wo, @@ -1919,7 +2068,7 @@ ggml_tensor * llama_context::build_attn_qkv( return cur; } -ggml_tensor * llama_context::build_soft_max_ext( +ggml_tensor * llama_context_unified::build_soft_max_ext( ggml_context * ctx0, ggml_tensor * kq, float kq_scale) { @@ -1928,7 +2077,7 @@ ggml_tensor * llama_context::build_soft_max_ext( return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias); } -ggml_tensor * llama_context::get_rope_factors(int il) { +ggml_tensor * llama_context_unified::get_rope_factors(int il) { const auto & hparams = model.hparams; // choose long/short freq factors based on the context size @@ -1945,7 +2094,96 @@ ggml_tensor * llama_context::get_rope_factors(int il) { return model.layers[il].rope_short; } -void llama_context::build_k_shift( +ggml_tensor * llama_context_unified::build_inp_embd( + ggml_context * ctx0, + ggml_tensor * tok_embd, + const llama_ubatch & ubatch) { + const auto & hparams = model.hparams; + + const int64_t n_embd = hparams.n_embd; + + struct ggml_tensor * inpL; + + if (ubatch.token) { + inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + //cb(inp_tokens, "inp_tokens", -1); + ggml_set_input(inp_tokens); + + inpL = ggml_get_rows(ctx0, tok_embd, inp_tokens); + + // apply lora for embedding tokens if needed + for (const auto & lora : loras) { + struct llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd); + if (lw == nullptr) { + continue; + } + + const float adapter_scale = lora.second; + const float scale = lw->get_scale(lora.first->alpha, adapter_scale); + + struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat( + ctx0, lw->b, // non-transposed lora_b + ggml_get_rows(ctx0, lw->a, inp_tokens) + ), scale); + + inpL = ggml_add(ctx0, inpL, inpL_delta); + } + } else { + inp_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); + inpL = inp_embd; + ggml_set_input(inp_embd); + } + + // For Granite architecture + if (hparams.f_embedding_scale != 0.0f) { + 
inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale); + } + + //cb(inpL, "inp_embd", -1); + + return inpL; +} + +ggml_tensor * llama_context_unified::build_inp_pos( + ggml_context * ctx0, + int32_t n_tokens) { + inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token()); + ggml_set_input(inp_pos); + + return inp_pos; +} + +ggml_tensor * llama_context_unified::build_inp_out_ids( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) { + const int32_t n_out_ids = worst_case ? n_tokens : n_outputs; + + inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out_ids); + ggml_set_input(inp_out_ids); + + return inp_out_ids; +} + +ggml_tensor * llama_context_unified::build_inp_mean( + ggml_context * ctx0, + int32_t n_tokens) { + inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); + ggml_set_input(inp_mean); + + return inp_mean; +} + +ggml_tensor * llama_context_unified::build_inp_cls( + ggml_context * ctx0, + int32_t n_tokens) { + inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp_cls); + + return inp_cls; +} + +void llama_context_unified::build_k_shift( ggml_context * ctx0, ggml_cgraph * graph) { const auto & n_ctx = cparams.n_ctx; @@ -2017,7 +2255,7 @@ void llama_context::build_k_shift( } } -void llama_context::build_defrag( +void llama_context_unified::build_defrag( ggml_context * ctx0, ggml_cgraph * graph) { const auto & hparams = model.hparams; @@ -2287,7 +2525,39 @@ void llama_context::build_defrag( #endif } -ggml_tensor * llama_context::build_inp_s_copy( +ggml_tensor * llama_context_unified::build_inp_embd_enc( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) { + const auto & hparams = model.hparams; + const int64_t n_embd = hparams.n_embd; + + // TODO: not sure if this is correct + const int32_t n_outputs_enc = worst_case ? n_tokens : embd_enc.size() / n_embd; + + inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc); + ggml_set_input(inp_embd_enc); + + return inp_embd_enc; +} + +ggml_tensor * llama_context_unified::build_inp_KQ_mask_cross( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) { + const auto & hparams = model.hparams; + const int64_t n_embd = hparams.n_embd; + + // TODO: not sure if this is correct + const int32_t n_outputs_enc = worst_case ? n_tokens : embd_enc.size() / n_embd; + + inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + ggml_set_input(inp_KQ_mask_cross); + + return inp_KQ_mask_cross; +} + +ggml_tensor * llama_context_unified::build_inp_s_copy( ggml_context * ctx0, bool worst_case) { const auto n_kv = worst_case ? kv_self.size : kv_self.n; @@ -2298,7 +2568,7 @@ ggml_tensor * llama_context::build_inp_s_copy( return inp_s_copy; } -ggml_tensor * llama_context::build_inp_s_mask( +ggml_tensor * llama_context_unified::build_inp_s_mask( ggml_context * ctx0, bool worst_case) { const auto n_kv = worst_case ? 
kv_self.size : kv_self.n; @@ -2308,7 +2578,7 @@ ggml_tensor * llama_context::build_inp_s_mask( return inp_s_mask; } -ggml_tensor * llama_context::build_copy_mask_state( +ggml_tensor * llama_context_unified::build_copy_mask_state( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * s, @@ -2343,7 +2613,7 @@ ggml_tensor * llama_context::build_copy_mask_state( } // TODO: split -ggml_tensor * llama_context::build_mamba_layer( +ggml_tensor * llama_context_unified::build_mamba_layer( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * cur, @@ -2479,7 +2749,7 @@ ggml_tensor * llama_context::build_mamba_layer( } -ggml_tensor * llama_context::build_rwkv_token_shift_load( +ggml_tensor * llama_context_unified::build_rwkv_token_shift_load( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * state_copy, @@ -2506,7 +2776,7 @@ ggml_tensor * llama_context::build_rwkv_token_shift_load( } -ggml_tensor * llama_context::build_rwkv_token_shift_store( +ggml_tensor * llama_context_unified::build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, @@ -2530,7 +2800,7 @@ ggml_tensor * llama_context::build_rwkv_token_shift_store( } -ggml_tensor * llama_context::build_rwkv6_time_mix( +ggml_tensor * llama_context_unified::build_rwkv6_time_mix( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * cur, @@ -2701,6 +2971,608 @@ ggml_tensor * llama_context::build_rwkv6_time_mix( return cur; } +// +// state +// + +// TODO: this needs a big rework + +// TODO: replace all non-fatal assertions with returned errors or exceptions +struct llama_data_write { + llama_data_write(llama_context_unified * ctx) : ctx(ctx) {} + virtual ~llama_data_write() = default; + + virtual void write(const void * src, size_t size) = 0; + virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0; + virtual size_t get_size_written() = 0; + + void write_string(const std::string & str) { + uint32_t str_size = str.size(); + + write(&str_size, sizeof(str_size)); + write(str.data(), str_size); + } + + void write_model_info() { + const std::string arch_str = llm_arch_name(ctx->model.arch); + write_string(arch_str); + // TODO: add more model-specific info which should prevent loading the session file if not identical + } + + //void write_rng(const std::mt19937 & rng) { + // std::ostringstream rng_ss; + // rng_ss << rng; + + // const std::string & rng_str = rng_ss.str(); + + // write_string(rng_str); + //} + + void write_output_ids() { + ctx->reorder_outputs(); + + const uint32_t n_outputs = ctx->n_outputs; + + std::vector<int32_t> output_pos; + + const size_t n_batch = ctx->cparams.n_batch; + const auto & output_ids = ctx->output_ids; + + GGML_ASSERT(n_outputs <= ctx->output_size); + + output_pos.resize(n_outputs); + + // build a more compact representation of the output ids + for (size_t i = 0; i < n_batch; ++i) { + // map an output id to a position in the batch + int32_t pos = output_ids[i]; + if (pos >= 0) { + GGML_ASSERT((uint32_t) pos < n_outputs); + output_pos[pos] = i; + } + } + + write(&n_outputs, sizeof(n_outputs)); + + if (n_outputs) { + write(output_pos.data(), n_outputs * sizeof(int32_t)); + } + } + + void write_logits() { + const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens()); + + write(&logits_size, sizeof(logits_size)); + + if (logits_size) { + write(ctx->logits, logits_size * sizeof(float)); + } + } + + void write_embeddings() { + const uint64_t embeddings_size =
std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd); + + write(&embeddings_size, sizeof(embeddings_size)); + + if (embeddings_size) { + write(ctx->embd, embeddings_size * sizeof(float)); + } + } + + llama_context_unified * ctx; +}; + +struct llama_data_read { + llama_data_read(llama_context_unified * ctx) : ctx(ctx) {} + virtual ~llama_data_read() = default; + + virtual const uint8_t * read(size_t size) = 0; + virtual void read_to(void * dst, size_t size) = 0; + virtual size_t get_size_read() = 0; + + void read_string(std::string & str) { + uint32_t str_size; + read_to(&str_size, sizeof(str_size)); + + str.assign((const char *) read(str_size), str_size); + } + + // validate model information + void read_model_info() { + const std::string cur_arch_str = llm_arch_name(ctx->model.arch); + + std::string arch_str; + read_string(arch_str); + if (cur_arch_str != arch_str) { + throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); + } + // TODO: add more info which needs to be identical but which is not verified otherwise + } + + //void read_rng(std::mt19937 & rng) { + // std::string rng_str; + // read_string(rng_str); + + // std::istringstream rng_ss(rng_str); + // rng_ss >> rng; + + // if (rng_ss.fail()) { + // throw std::runtime_error("failed to load RNG state"); + // } + //} + + void read_output_ids() { + std::vector<int32_t> output_pos; + + uint32_t n_outputs; + read_to(&n_outputs, sizeof(n_outputs)); + + if (n_outputs > ctx->reserve_outputs(n_outputs)) { + throw std::runtime_error("could not reserve outputs"); + } + + if (n_outputs) { + output_pos.resize(n_outputs); + read_to(output_pos.data(), n_outputs * sizeof(int32_t)); + + for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { + int32_t id = output_pos[i]; + if ((uint32_t) id >= ctx->cparams.n_batch) { + throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch)); + } + ctx->output_ids[id] = i; + } + + ctx->n_outputs = n_outputs; + } + } + + void read_logits() { + uint64_t logits_size; + read_to(&logits_size, sizeof(logits_size)); + + if (ctx->logits_size < logits_size) { + throw std::runtime_error("logits buffer too small"); + } + + if (logits_size) { + read_to(ctx->logits, logits_size * sizeof(float)); + } + } + + void read_embeddings() { + uint64_t embeddings_size; + read_to(&embeddings_size, sizeof(embeddings_size)); + + if (ctx->embd_size < embeddings_size) { + throw std::runtime_error("embeddings buffer too small"); + } + + if (embeddings_size) { + read_to(ctx->embd, embeddings_size * sizeof(float)); + } + } + + llama_context_unified * ctx; +}; + +struct llama_data_write_dummy : llama_data_write { + llama_data_write_dummy(llama_context_unified * ctx) : llama_data_write(ctx) {} + + void write(const void * /* src */, size_t size) override { + size_written += size; + } + + void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } + + size_t size_written = 0; +}; + +struct llama_data_write_buffer : llama_data_write { + llama_data_write_buffer( + llama_context_unified * ctx, + uint8_t * p, size_t len) : llama_data_write(ctx), ptr(p), buf_size(len) {} + + void write(const void * src, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(ptr, src, size); + ptr += size; +
size_written += size; + buf_size -= size; + } + + void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + ggml_backend_tensor_get(tensor, ptr, offset, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + size_t get_size_written() override { + return size_written; + } + + uint8_t * ptr; + size_t buf_size = 0; + size_t size_written = 0; +}; + +struct llama_data_read_buffer : llama_data_read { + llama_data_read_buffer( + llama_context_unified * ctx, + const uint8_t * p, size_t len) : llama_data_read(ctx), ptr(p), buf_size(len) {} + + const uint8_t * read(size_t size) override { + const uint8_t * base_ptr = ptr; + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + ptr += size; + size_read += size; + buf_size -= size; + return base_ptr; + } + + void read_to(void * dst, size_t size) override { + memcpy(dst, read(size), size); + } + + size_t get_size_read() override { + return size_read; + } + + const uint8_t * ptr; + size_t buf_size = 0; + size_t size_read = 0; +}; + +struct llama_data_write_file : llama_data_write { + llama_data_write_file( + llama_context_unified * ctx, + llama_file * f) : llama_data_write(ctx), file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + temp_buffer.resize(size); + ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); + write(temp_buffer.data(), temp_buffer.size()); + } + + size_t get_size_written() override { + return size_written; + } + + llama_file * file; + size_t size_written = 0; + std::vector<uint8_t> temp_buffer; +}; + +struct llama_data_read_file : llama_data_read { + llama_data_read_file( + llama_context_unified * ctx, + llama_file * f) : llama_data_read(ctx), file(f) {} + + void read_to(void * dst, size_t size) override { + file->read_raw(dst, size); + size_read += size; + } + + const uint8_t * read(size_t size) override { + temp_buffer.resize(size); + read_to(temp_buffer.data(), size); + return temp_buffer.data(); + } + + size_t get_size_read() override { + return size_read; + } + + llama_file * file; + size_t size_read = 0; + std::vector<uint8_t> temp_buffer; +}; + +size_t llama_context_unified::state_get_size() { + llama_data_write_dummy data_ctx(this); + try { + return state_get_data(data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context_unified::state_get_data(uint8_t * dst, size_t size) { + llama_data_write_buffer data_ctx(this, dst, size); + try { + return state_get_data(data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context_unified::state_set_data(const uint8_t * src, size_t size) { + llama_data_read_buffer data_ctx(this, src, size); + try { + return state_set_data(data_ctx); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context_unified::state_seq_get_size(llama_seq_id seq_id) { + llama_data_write_dummy data_ctx(this); + try { + return state_seq_get_data(data_ctx, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error getting
state size: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context_unified::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { + llama_data_write_buffer data_ctx(this, dst, size); + try { + return state_seq_get_data(data_ctx, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); + return 0; + } +} + +size_t llama_context_unified::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { + llama_data_read_buffer data_ctx(this, src, size); + try { + return state_seq_set_data(data_ctx, seq_id); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); + return 0; + } +} + +bool llama_context_unified::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); + + // sanity checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); + return false; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity); + return false; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t n_state_size_cur = file.size() - file.tell(); + + llama_data_read_file data_ctx(this, &file); + const size_t n_read = state_set_data(data_ctx); + + if (n_read != n_state_size_cur) { + LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); + return false; + } + } + + return true; +} + +bool llama_context_unified::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_SESSION_MAGIC); + file.write_u32(LLAMA_SESSION_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_write_file data_ctx(this, &file); + state_get_data(data_ctx); + + return true; +} + +size_t llama_context_unified::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { + llama_file file(filepath, "rb"); + + // version checks + { + const uint32_t magic = file.read_u32(); + const uint32_t version = file.read_u32(); + + if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { + LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); + return 0; + } + } + + // load the prompt + { + const uint32_t n_token_count = file.read_u32(); + + if (n_token_count > n_token_capacity) { + LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! 
%u > %zu\n", __func__, n_token_count, n_token_capacity); + return 0; + } + + file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); + *n_token_count_out = n_token_count; + } + + // restore the context state + { + const size_t state_size = file.size() - file.tell(); + llama_data_read_file data_ctx(this, &file); + const size_t nread = state_seq_set_data(data_ctx, seq_id); + if (!nread) { + LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); + return 0; + } + GGML_ASSERT(nread <= state_size); + GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); + } + + return file.tell(); +} + +size_t llama_context_unified::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) { + llama_file file(filepath, "wb"); + + file.write_u32(LLAMA_STATE_SEQ_MAGIC); + file.write_u32(LLAMA_STATE_SEQ_VERSION); + + // save the prompt + file.write_u32((uint32_t) n_token_count); + file.write_raw(tokens, sizeof(llama_token) * n_token_count); + + // save the context state using stream saving + llama_data_write_file data_ctx(this, &file); + state_seq_get_data(data_ctx, seq_id); + + const size_t res = file.tell(); + GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); + + return res; +} + +/** copy state data into either a buffer or file depending on the passed in context + * + * file context: + * llama_file file("/path", "wb"); + * llama_data_write_file data_ctx(&file); + * llama_state_get_data_internal(ctx, data_ctx); + * + * buffer context: + * std::vector buf(max_size, 0); + * llama_data_write_buffer data_ctx(buf.data(), max_size); + * llama_state_get_data_internal(ctx, data_ctx); + * +*/ +size_t llama_context_unified::state_get_data(llama_data_write & data_ctx) { + synchronize(); + + data_ctx.write_model_info(); + + // copy outputs + data_ctx.write_output_ids(); + data_ctx.write_logits(); + data_ctx.write_embeddings(); + + llama_kv_cache::io io = { + /* .write = */ [&](const void * src, size_t size) { + data_ctx.write(src, size); + }, + /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) { + data_ctx.write_tensor_data(tensor, offset, size); + }, + /* .read = */ nullptr, + /* .read_to = */ nullptr, + }; + + kv_self.state_write(io, model.hparams); + + return data_ctx.get_size_written(); +} + +size_t llama_context_unified::state_set_data(llama_data_read & data_ctx) { + synchronize(); + + data_ctx.read_model_info(); + + // set outputs + data_ctx.read_output_ids(); + data_ctx.read_logits(); + data_ctx.read_embeddings(); + + llama_kv_cache::io io = { + /* .write = */ nullptr, + /* .write_tensor_data = */ nullptr, + /* .read = */ [&](size_t size) { + return data_ctx.read(size); + }, + /* .read_to = */ [&](void * dst, size_t size) { + data_ctx.read_to(dst, size); + }, + }; + + kv_self.state_read(io, model.hparams); + + return data_ctx.get_size_read(); +} + +size_t llama_context_unified::state_seq_get_data(llama_data_write & data_ctx, llama_seq_id seq_id) { + synchronize(); + + llama_kv_cache::io io = { + /* .write = */ [&](const void * src, size_t size) { + data_ctx.write(src, size); + }, + /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) { + data_ctx.write_tensor_data(tensor, offset, size); + }, + /* .read = */ nullptr, + /* .read_to = */ nullptr, + }; + + kv_self.state_write(io, model.hparams, seq_id); + + return data_ctx.get_size_written(); +} + +size_t 
llama_context_unified::state_seq_set_data(llama_data_read & data_ctx, llama_seq_id seq_id) { + synchronize(); + + llama_kv_cache::io io = { + /* .write = */ nullptr, + /* .write_tensor_data = */ nullptr, + /* .read = */ [&](size_t size) { + return data_ctx.read(size); + }, + /* .read_to = */ [&](void * dst, size_t size) { + data_ctx.read_to(dst, size); + }, + }; + + kv_self.state_read(io, model.hparams, seq_id); + + return data_ctx.get_size_read(); +} + // // interface implementation // @@ -2710,20 +3582,19 @@ void llama_free(struct llama_context * ctx) { } uint32_t llama_n_ctx(const struct llama_context * ctx) { - return ctx->cparams.n_ctx; + return ctx->n_ctx(); } uint32_t llama_n_batch(const struct llama_context * ctx) { - return ctx->cparams.n_batch; + return ctx->n_batch(); } uint32_t llama_n_ubatch(const struct llama_context * ctx) { - return ctx->cparams.n_ubatch; + return ctx->n_ubatch(); } uint32_t llama_n_seq_max(const struct llama_context * ctx) { - // TODO: add notion of n_seq_max to llama_kv_cache and use it here - return ctx->kv_self.size; + return ctx->n_seq_max(); } const llama_model * llama_get_model(const llama_context * ctx) { @@ -2731,11 +3602,15 @@ const llama_model * llama_get_model(const llama_context * ctx) { } llama_kv_cache * llama_get_kv_self(llama_context * ctx) { - return &ctx->kv_self; + return ctx->get_kv_self(); +} + +void llama_kv_self_update(llama_context * ctx) { + ctx->kv_self_update(); } enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { - return ctx->cparams.pooling_type; + return ctx->pooling_type(); } void llama_attach_threadpool( @@ -2786,142 +3661,37 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { } void llama_synchronize(struct llama_context * ctx) { - ggml_backend_sched_synchronize(ctx->sched.get()); - - // FIXME: if multiple single tokens are evaluated without a synchronization, - // the stats will be added to the prompt evaluation stats - // this should only happen when using batch size 1 to evaluate a batch - - // add the evaluation to the stats - if (ctx->n_queued_tokens == 1) { - if (!ctx->cparams.no_perf) { - ctx->t_eval_us += ggml_time_us() - ctx->t_compute_start_us; - } - ctx->n_eval++; - } else if (ctx->n_queued_tokens > 1) { - if (!ctx->cparams.no_perf) { - ctx->t_p_eval_us += ggml_time_us() - ctx->t_compute_start_us; - } - ctx->n_p_eval += ctx->n_queued_tokens; - } - - // get a more accurate load time, upon first eval - if (ctx->n_queued_tokens > 0 && !ctx->has_evaluated_once) { - ctx->t_load_us = ggml_time_us() - ctx->t_start_us; - ctx->has_evaluated_once = true; - } - - ctx->n_queued_tokens = 0; - ctx->t_compute_start_us = 0; + ctx->synchronize(); } float * llama_get_logits(struct llama_context * ctx) { - llama_synchronize(ctx); + ctx->synchronize(); - // reorder logits for backward compatibility - ctx->reorder_outputs(); - - return ctx->logits; + return ctx->get_logits(); } float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { - int32_t j = -1; + ctx->synchronize(); - llama_synchronize(ctx); - - try { - if (ctx->logits == nullptr) { - throw std::runtime_error("no logits"); - } - - if (i < 0) { - j = ctx->n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); - } - } else if ((size_t) i >= ctx->output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size())); - } else { - j = ctx->output_ids[i]; - } - - if (j < 0) { - throw 
std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= ctx->n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); - } - - return ctx->logits + j*ctx->model.vocab.n_tokens(); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); -#ifndef NDEBUG - GGML_ABORT("fatal error"); -#else - return nullptr; -#endif - } + return ctx->get_logits_ith(i); } float * llama_get_embeddings(struct llama_context * ctx) { - llama_synchronize(ctx); + ctx->synchronize(); - // reorder embeddings for backward compatibility - ctx->reorder_outputs(); - - return ctx->embd; + return ctx->get_embeddings(); } float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { - int32_t j = -1; + ctx->synchronize(); - llama_synchronize(ctx); - - try { - if (ctx->embd == nullptr) { - throw std::runtime_error("no embeddings"); - } - - if (i < 0) { - j = ctx->n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); - } - } else if ((size_t) i >= ctx->output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size())); - } else { - j = ctx->output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= ctx->n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); - } - - return ctx->embd + j*ctx->model.hparams.n_embd; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); -#ifndef NDEBUG - GGML_ABORT("fatal error"); -#else - return nullptr; -#endif - } + return ctx->get_embeddings_ith(i); } float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) { - llama_synchronize(ctx); + ctx->synchronize(); - auto it = ctx->embd_seq.find(seq_id); - if (it == ctx->embd_seq.end()) { - return nullptr; - } - - return it->second.data(); + return ctx->get_embeddings_seq(seq_id); } // llama adapter API @@ -2965,11 +3735,11 @@ int32_t llama_apply_adapter_cvec( // struct llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { - return llama_kv_cache_view_init(ctx->kv_self, n_seq_max); + return llama_kv_cache_view_init(*ctx->get_kv_self(), n_seq_max); } void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { - llama_kv_cache_view_update(view, ctx->kv_self); + llama_kv_cache_view_update(view, *ctx->get_kv_self()); } // @@ -2982,7 +3752,7 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { } int32_t llama_kv_self_n_tokens(const llama_context * ctx) { - return llama_kv_cache_n_tokens(&ctx->kv_self); + return llama_kv_cache_n_tokens(ctx->get_kv_self()); } // deprecated @@ -2991,7 +3761,7 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { } int32_t llama_kv_self_used_cells(const llama_context * ctx) { - return llama_kv_cache_used_cells(&ctx->kv_self); + return llama_kv_cache_used_cells(ctx->get_kv_self()); } // deprecated @@ -3000,7 +3770,7 @@ void llama_kv_cache_clear(llama_context * ctx) { } void llama_kv_self_clear(llama_context * ctx) { - llama_kv_cache_clear(&ctx->kv_self); + llama_kv_cache_clear(ctx->get_kv_self()); } // deprecated @@ -3017,7 +3787,7 @@ bool llama_kv_self_seq_rm( llama_seq_id seq_id, llama_pos p0, 
llama_pos p1) { - return llama_kv_cache_seq_rm(&ctx->kv_self, seq_id, p0, p1); + return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1); } // deprecated @@ -3036,7 +3806,7 @@ void llama_kv_self_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - return llama_kv_cache_seq_cp(&ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); + return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1); } // deprecated @@ -3047,7 +3817,7 @@ void llama_kv_cache_seq_keep( } void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_keep(&ctx->kv_self, seq_id); + return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id); } // deprecated @@ -3066,7 +3836,7 @@ void llama_kv_self_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { - return llama_kv_cache_seq_add(&ctx->kv_self, seq_id, p0, p1, delta); + return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta); } // deprecated @@ -3085,7 +3855,7 @@ void llama_kv_self_seq_div( llama_pos p0, llama_pos p1, int d) { - return llama_kv_cache_seq_div(&ctx->kv_self, seq_id, p0, p1, d); + return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d); } // deprecated @@ -3094,7 +3864,7 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { } llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_pos_max(&ctx->kv_self, seq_id); + return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id); } // deprecated @@ -3103,7 +3873,7 @@ void llama_kv_cache_defrag(llama_context * ctx) { } void llama_kv_self_defrag(llama_context * ctx) { - return llama_kv_cache_defrag(&ctx->kv_self); + return llama_kv_cache_defrag(ctx->get_kv_self()); } // deprecated @@ -3112,7 +3882,7 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) { } bool llama_kv_self_can_shift(const llama_context * ctx) { - return llama_kv_cache_can_shift(&ctx->kv_self); + return llama_kv_cache_can_shift(ctx->get_kv_self()); } // deprecated @@ -3147,603 +3917,54 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi return llama_state_save_file(ctx, path_session, tokens, n_token_count); } -// TODO: replace all non-fatal assertions with returned errors or exceptions -struct llama_data_write { - virtual void write(const void * src, size_t size) = 0; - virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0; - virtual size_t get_size_written() = 0; - virtual ~llama_data_write() = default; - - void write_string(const std::string & str) { - uint32_t str_size = str.size(); - - write(&str_size, sizeof(str_size)); - write(str.data(), str_size); - } - - void write_model_info(const struct llama_context * ctx) { - const std::string arch_str = llm_arch_name(ctx->model.arch); - write_string(arch_str); - // TODO: add more model-specific info which should prevent loading the session file if not identical - } - - //void write_rng(const std::mt19937 & rng) { - // std::ostringstream rng_ss; - // rng_ss << rng; - - // const std::string & rng_str = rng_ss.str(); - - // write_string(rng_str); - //} - - void write_output_ids(struct llama_context * ctx) { - ctx->reorder_outputs(); - - const uint32_t n_outputs = ctx->n_outputs; - - std::vector output_pos; - - const size_t n_batch = ctx->cparams.n_batch; - const auto & output_ids = ctx->output_ids; - - GGML_ASSERT(n_outputs <= ctx->output_size); - - output_pos.resize(n_outputs); - - // build a more compact representation of the output 
ids - for (size_t i = 0; i < n_batch; ++i) { - // map an output id to a position in the batch - int32_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT((uint32_t) pos < n_outputs); - output_pos[pos] = i; - } - } - - write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - write(output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - void write_logits(const struct llama_context * ctx) { - const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens()); - - write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - write(ctx->logits, logits_size * sizeof(float)); - } - } - - void write_embeddings(const struct llama_context * ctx) { - const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_embd); - - write(&embeddings_size, sizeof(embeddings_size)); - - if (embeddings_size) { - write(ctx->embd, embeddings_size * sizeof(float)); - } - } -}; - -struct llama_data_read { - virtual const uint8_t * read(size_t size) = 0; - virtual void read_to(void * dst, size_t size) = 0; - virtual size_t get_size_read() = 0; - virtual ~llama_data_read() = default; - - void read_string(std::string & str) { - uint32_t str_size; - read_to(&str_size, sizeof(str_size)); - - str.assign((const char *) read(str_size), str_size); - } - - // validate model information - void read_model_info(const struct llama_context * ctx) { - const std::string cur_arch_str = llm_arch_name(ctx->model.arch); - - std::string arch_str; - read_string(arch_str); - if (cur_arch_str != arch_str) { - throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); - } - // TODO: add more info which needs to be identical but which is not verified otherwise - } - - //void read_rng(std::mt19937 & rng) { - // std::string rng_str; - // read_string(rng_str); - - // std::istringstream rng_ss(rng_str); - // rng_ss >> rng; - - // if (rng_ss.fail()) { - // throw std::runtime_error("failed to load RNG state"); - // } - //} - - void read_output_ids(struct llama_context * ctx) { - std::vector output_pos; - - uint32_t n_outputs; - read_to(&n_outputs, sizeof(n_outputs)); - - if (n_outputs > ctx->reserve_outputs(n_outputs)) { - throw std::runtime_error("could not reserve outputs"); - } - - if (n_outputs) { - output_pos.resize(n_outputs); - read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= ctx->cparams.n_batch) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch)); - } - ctx->output_ids[id] = i; - } - - ctx->n_outputs = n_outputs; - } - } - - void read_logits(struct llama_context * ctx) { - uint64_t logits_size; - read_to(&logits_size, sizeof(logits_size)); - - if (ctx->logits_size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - read_to(ctx->logits, logits_size * sizeof(float)); - } - } - - void read_embeddings(struct llama_context * ctx) { - uint64_t embeddings_size; - read_to(&embeddings_size, sizeof(embeddings_size)); - - if (ctx->embd_size < embeddings_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embeddings_size) { - read_to(ctx->embd, embeddings_size * sizeof(float)); - } - } -}; - -struct llama_data_write_dummy : llama_data_write { - size_t size_written = 0; - - llama_data_write_dummy() {} - - 
void write(const void * /* src */, size_t size) override { - size_written += size; - } - - void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { - size_written += size; - } - - size_t get_size_written() override { - return size_written; - } -}; - -struct llama_data_write_buffer : llama_data_write { - uint8_t * ptr; - size_t buf_size = 0; - size_t size_written = 0; - - llama_data_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {} - - void write(const void * src, size_t size) override { - if (size > buf_size) { - throw std::runtime_error("unexpectedly reached end of buffer"); - } - memcpy(ptr, src, size); - ptr += size; - size_written += size; - buf_size -= size; - } - - void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { - if (size > buf_size) { - throw std::runtime_error("unexpectedly reached end of buffer"); - } - ggml_backend_tensor_get(tensor, ptr, offset, size); - ptr += size; - size_written += size; - buf_size -= size; - } - - size_t get_size_written() override { - return size_written; - } -}; - -struct llama_data_read_buffer : llama_data_read { - const uint8_t * ptr; - size_t buf_size = 0; - size_t size_read = 0; - - llama_data_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} - - const uint8_t * read(size_t size) override { - const uint8_t * base_ptr = ptr; - if (size > buf_size) { - throw std::runtime_error("unexpectedly reached end of buffer"); - } - ptr += size; - size_read += size; - buf_size -= size; - return base_ptr; - } - - void read_to(void * dst, size_t size) override { - memcpy(dst, read(size), size); - } - - size_t get_size_read() override { - return size_read; - } -}; - -struct llama_data_write_file : llama_data_write { - llama_file * file; - size_t size_written = 0; - std::vector temp_buffer; - - llama_data_write_file(llama_file * f) : file(f) {} - - void write(const void * src, size_t size) override { - file->write_raw(src, size); - size_written += size; - } - - void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { - temp_buffer.resize(size); - ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); - write(temp_buffer.data(), temp_buffer.size()); - } - - size_t get_size_written() override { - return size_written; - } -}; - -struct llama_data_read_file : llama_data_read { - llama_file * file; - size_t size_read = 0; - std::vector temp_buffer; - - llama_data_read_file(llama_file * f) : file(f) {} - - void read_to(void * dst, size_t size) override { - file->read_raw(dst, size); - size_read += size; - } - - const uint8_t * read(size_t size) override { - temp_buffer.resize(size); - read_to(temp_buffer.data(), size); - return temp_buffer.data(); - } - - size_t get_size_read() override { - return size_read; - } -}; - -/** copy state data into either a buffer or file depending on the passed in context - * - * file context: - * llama_file file("/path", "wb"); - * llama_data_write_file data_ctx(&file); - * llama_state_get_data_internal(ctx, data_ctx); - * - * buffer context: - * std::vector buf(max_size, 0); - * llama_data_write_buffer data_ctx(buf.data(), max_size); - * llama_state_get_data_internal(ctx, data_ctx); - * -*/ -static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx) { - llama_synchronize(ctx); - - data_ctx.write_model_info(ctx); - - // copy outputs - data_ctx.write_output_ids(ctx); - data_ctx.write_logits(ctx); - 
data_ctx.write_embeddings(ctx); - - llama_kv_cache::io io = { - /* .write = */ [&](const void * src, size_t size) { - data_ctx.write(src, size); - }, - /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) { - data_ctx.write_tensor_data(tensor, offset, size); - }, - /* .read = */ nullptr, - /* .read_to = */ nullptr, - }; - - ctx->kv_self.state_write(io, ctx->model.hparams); - - return data_ctx.get_size_written(); -} - -size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) { - llama_data_write_buffer data_ctx(dst, size); - try { - return llama_state_get_data_internal(ctx, data_ctx); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); - return 0; - } -} - // Returns the *actual* size of the state. // Intended to be used when saving to state to a buffer. size_t llama_state_get_size(struct llama_context * ctx) { - llama_data_write_dummy data_ctx; - try { - return llama_state_get_data_internal(ctx, data_ctx); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); - return 0; - } + return ctx->state_get_size(); } -static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) { - llama_synchronize(ctx); - - data_ctx.read_model_info(ctx); - - // set outputs - data_ctx.read_output_ids(ctx); - data_ctx.read_logits(ctx); - data_ctx.read_embeddings(ctx); - - llama_kv_cache::io io = { - /* .write = */ nullptr, - /* .write_tensor_data = */ nullptr, - /* .read = */ [&](size_t size) { - return data_ctx.read(size); - }, - /* .read_to = */ [&](void * dst, size_t size) { - data_ctx.read_to(dst, size); - }, - }; - - ctx->kv_self.state_read(io, ctx->model.hparams); - - return data_ctx.get_size_read(); +size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) { + return ctx->state_get_data(dst, size); } // Sets the state reading from the specified source address size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) { - llama_data_read_buffer data_ctx(src, size); - try { - return llama_state_set_data_internal(ctx, data_ctx); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); - return 0; - } -} - -static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { - llama_file file(path_session, "rb"); - - // sanity checks - { - const uint32_t magic = file.read_u32(); - const uint32_t version = file.read_u32(); - - if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) { - LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version); - return false; - } - } - - // load the prompt - { - const uint32_t n_token_count = file.read_u32(); - - if (n_token_count > n_token_capacity) { - LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! 
%u > %zu\n", __func__, n_token_count, n_token_capacity); - return false; - } - - file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); - *n_token_count_out = n_token_count; - } - - // restore the context state - { - const size_t n_state_size_cur = file.size() - file.tell(); - - llama_data_read_file data_ctx(&file); - const size_t n_read = llama_state_set_data_internal(ctx, data_ctx); - - if (n_read != n_state_size_cur) { - LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read); - return false; - } - } - return true; + return ctx->state_set_data(src, size); } bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { try { - return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out); + return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what()); return false; } } -static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { - llama_file file(path_session, "wb"); - - file.write_u32(LLAMA_SESSION_MAGIC); - file.write_u32(LLAMA_SESSION_VERSION); - - // save the prompt - file.write_u32((uint32_t) n_token_count); - file.write_raw(tokens, sizeof(llama_token) * n_token_count); - - // save the context state using stream saving - llama_data_write_file data_ctx(&file); - llama_state_get_data_internal(ctx, data_ctx); - - return true; -} - bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { try { - return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count); + return ctx->state_save_file(path_session, tokens, n_token_count); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what()); return false; } } -static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) { - llama_synchronize(ctx); - - llama_kv_cache::io io = { - /* .write = */ [&](const void * src, size_t size) { - data_ctx.write(src, size); - }, - /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) { - data_ctx.write_tensor_data(tensor, offset, size); - }, - /* .read = */ nullptr, - /* .read_to = */ nullptr, - }; - - ctx->kv_self.state_write(io, ctx->model.hparams, seq_id); - - return data_ctx.get_size_written(); -} - size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id) { - llama_data_write_dummy data_ctx; - return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); + return ctx->state_seq_get_size(seq_id); } size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) { - llama_data_write_buffer data_ctx(dst, size); - try { - return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what()); - return 0; - } + return ctx->state_seq_get_data(seq_id, dst, size); } -static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) { - llama_synchronize(ctx); - - 
llama_kv_cache::io io = { - /* .write = */ nullptr, - /* .write_tensor_data = */ nullptr, - /* .read = */ [&](size_t size) { - return data_ctx.read(size); - }, - /* .read_to = */ [&](void * dst, size_t size) { - data_ctx.read_to(dst, size); - }, - }; - - ctx->kv_self.state_read(io, ctx->model.hparams, dest_seq_id); - - return data_ctx.get_size_read(); -} - -size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id dest_seq_id) { - llama_data_read_buffer data_ctx(src, size); - try { - return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what()); - return 0; - } -} - -static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { - llama_file file(filepath, "wb"); - - file.write_u32(LLAMA_STATE_SEQ_MAGIC); - file.write_u32(LLAMA_STATE_SEQ_VERSION); - - // save the prompt - file.write_u32((uint32_t) n_token_count); - file.write_raw(tokens, sizeof(llama_token) * n_token_count); - - // save the context state using stream saving - llama_data_write_file data_ctx(&file); - llama_state_seq_get_data_internal(ctx, data_ctx, seq_id); - - const size_t res = file.tell(); - GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); - return res; -} - -static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { - llama_file file(filepath, "rb"); - - // version checks - { - const uint32_t magic = file.read_u32(); - const uint32_t version = file.read_u32(); - - if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) { - LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version); - return 0; - } - } - - // load the prompt - { - const uint32_t n_token_count = file.read_u32(); - - if (n_token_count > n_token_capacity) { - LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! 
%u > %zu\n", __func__, n_token_count, n_token_capacity); - return 0; - } - - file.read_raw(tokens_out, sizeof(llama_token) * n_token_count); - *n_token_count_out = n_token_count; - } - - // restore the context state - { - const size_t state_size = file.size() - file.tell(); - llama_data_read_file data_ctx(&file); - const size_t nread = llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id); - if (!nread) { - LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); - return 0; - } - GGML_ASSERT(nread <= state_size); - GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell()); - } - - return file.tell(); +size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) { + return ctx->state_seq_set_data(seq_id, src, size); } size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { try { - return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count); + return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what()); return 0; @@ -3752,7 +3973,7 @@ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepa size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { try { - return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); + return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what()); return 0; diff --git a/src/llama-context.h b/src/llama-context.h index 8f22fd3b1..f7e007f32 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -16,38 +16,245 @@ using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>; -struct llama_batch_manager_i; - -// TODO: make implementation details private -// TODO: become abstract base class, split the current implementation into different child classes struct llama_context { - // TODO: tmp until llama-model starts implementing the graph build function - typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool)> build_graph_callback; + llama_context(const llama_model & model); + virtual ~llama_context(); - llama_context( - const llama_model & model, - const llama_context_params & params, - build_graph_callback && cb_build_graph); + virtual void synchronize(); - virtual ~llama_context() = default; + virtual uint32_t n_ctx() const = 0; + virtual uint32_t n_batch() const = 0; + virtual uint32_t n_ubatch() const = 0; + virtual uint32_t n_seq_max() const = 0; - const struct llama_model & model; + virtual llama_kv_cache * get_kv_self() = 0; + virtual const llama_kv_cache * get_kv_self() const = 0; + + virtual void kv_self_update() = 0; + + virtual enum llama_pooling_type pooling_type() const = 0; + + virtual float * get_logits() = 0; + virtual float * get_logits_ith(int32_t i) = 0; + + virtual float * get_embeddings() = 0; + virtual float * get_embeddings_ith(int32_t i) = 0; + virtual float * get_embeddings_seq(llama_seq_id seq_id) = 0; + + int64_t n_pos_per_token() const; // vision + + virtual ggml_context_ptr init(); + + virtual int decode(llama_batch & inp_batch) = 0; + 
virtual int encode(llama_batch & inp_batch) = 0; + + // graph build API (generic) + + // do mat_mul, while optionally apply lora + virtual ggml_tensor * build_lora_mm( + ggml_context * ctx0, + ggml_tensor * w, + ggml_tensor * cur); + + // do mat_mul_id, while optionally apply lora + virtual ggml_tensor * build_lora_mm_id( + ggml_context * ctx0, + ggml_tensor * w, // struct ggml_tensor * as + ggml_tensor * cur, // struct ggml_tensor * b + ggml_tensor * ids); + + // graph build API (context-specific) + + virtual ggml_tensor * build_inp_embd( + ggml_context * ctx0, + ggml_tensor * tok_embd, + const llama_ubatch & ubatch) = 0; + + virtual ggml_tensor * build_inp_pos( + ggml_context * ctx0, + int32_t n_tokens) = 0; + + virtual ggml_tensor * build_inp_out_ids( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) = 0; + + virtual ggml_tensor * build_inp_mean( + ggml_context * ctx0, + int32_t n_tokens) = 0; + + virtual ggml_tensor * build_inp_cls( + ggml_context * ctx0, + int32_t n_tokens) = 0; + + virtual void build_attn_inp( + ggml_context * ctx0, + int32_t n_tokens, + bool causal, + bool swa, + bool worst_case) = 0; + + virtual void build_attn_kv_store( + ggml_context * ctx0, + ggml_cgraph * graph, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + int32_t n_tokens, + int64_t il, + bool worst_case) = 0; + + virtual ggml_tensor * build_attn_qkv( + ggml_context * ctx0, + ggml_cgraph * graph, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + int32_t n_tokens, + float kq_scale, + int il, + bool worst_case) = 0; + + virtual ggml_tensor * build_soft_max_ext( + ggml_context * ctx0, + ggml_tensor * kq, + float kq_scale) = 0; + + virtual ggml_tensor * get_rope_factors(int il) = 0; + + virtual void build_k_shift( + ggml_context * ctx0, + ggml_cgraph * graph) = 0; + + // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache + virtual void build_defrag( + ggml_context * ctx0, + ggml_cgraph * graph) = 0; + + virtual ggml_tensor * build_inp_embd_enc( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) = 0; + + virtual ggml_tensor * build_inp_KQ_mask_cross( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) = 0; + + virtual ggml_tensor * build_inp_s_copy( + ggml_context * ctx0, + bool worst_case) = 0; + + virtual ggml_tensor * build_inp_s_mask( + ggml_context * ctx0, + bool worst_case) = 0; + + virtual ggml_tensor * build_copy_mask_state( + ggml_context * ctx0, + ggml_cgraph * graph, + ggml_tensor * s, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + int32_t n_tokens, + int32_t n_state, + int32_t n_seqs, + bool worst_case) = 0; + + virtual ggml_tensor * build_mamba_layer( + ggml_context * ctx0, + ggml_cgraph * graph, + ggml_tensor * cur, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il, + bool worst_case) = 0; + + virtual ggml_tensor * build_rwkv_token_shift_load( + ggml_context * ctx0, + ggml_cgraph * graph, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il, + bool worst_case) = 0; + + virtual ggml_tensor * build_rwkv_token_shift_store( + ggml_context * ctx0, + ggml_tensor * token_shift, + const llama_ubatch & ubatch, + int il, + bool worst_case) = 0; + + virtual ggml_tensor * build_rwkv6_time_mix( + ggml_context * ctx0, + ggml_cgraph * graph, + ggml_tensor * cur, + ggml_tensor * x_prev, + ggml_tensor * state_copy, + ggml_tensor * state_mask, + const llama_ubatch & ubatch, + int il, + bool worst_case) = 0; + + // state 
save/load + + virtual size_t state_get_size() = 0; + virtual size_t state_get_data( uint8_t * dst, size_t size) = 0; + virtual size_t state_set_data(const uint8_t * src, size_t size) = 0; + + virtual size_t state_seq_get_size(llama_seq_id seq_id) = 0; + virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) = 0; + virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) = 0; + + virtual bool state_load_file( + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) = 0; + + virtual bool state_save_file( + const char * filepath, + const llama_token * tokens, + size_t n_token_count) = 0; + + virtual size_t state_seq_load_file( + llama_seq_id seq_id, + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) = 0; + + virtual size_t state_seq_save_file( + llama_seq_id seq_id, + const char * filepath, + const llama_token * tokens, + size_t n_token_count) = 0; + + // members + + const llama_model & model; llama_cparams cparams; - llama_sbatch sbatch; // TODO: revisit if needed llama_adapter_cvec cvec; llama_loras loras; - build_graph_callback cb_build_graph; + ggml_threadpool_t threadpool = nullptr; + ggml_threadpool_t threadpool_batch = nullptr; + + ggml_abort_callback abort_callback = nullptr; + void * abort_callback_data = nullptr; std::vector<ggml_backend_ptr> backends; std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns; ggml_backend_t backend_cpu = nullptr; - ggml_threadpool_t threadpool = nullptr; - ggml_threadpool_t threadpool_batch = nullptr; + ggml_backend_sched_ptr sched; + // memory buffers used to evaluate the model + std::vector<uint8_t> buf_compute_meta; + + // perf bool has_evaluated_once = false; mutable int64_t t_start_us; @@ -60,6 +267,49 @@ struct llama_context { mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) mutable int32_t n_eval = 0; // number of eval calls +}; + +// TODO: make implementation details private +struct llama_context_unified : public llama_context { + struct batch_manager; + + // TODO: tmp until llama-model starts implementing the graph build function + typedef std::function<ggml_cgraph *(llama_context &, const llama_ubatch &, bool)> build_graph_callback; + + llama_context_unified( + const llama_model & model, + const llama_context_params & params, + build_graph_callback && cb_build_graph); + + virtual ~llama_context_unified(); + + virtual uint32_t n_ctx() const override; + virtual uint32_t n_batch() const override; + virtual uint32_t n_ubatch() const override; + virtual uint32_t n_seq_max() const override; + + virtual llama_kv_cache * get_kv_self() override; + virtual const llama_kv_cache * get_kv_self() const override; + + virtual void kv_self_update() override; + + virtual enum llama_pooling_type pooling_type() const override; + + virtual float * get_logits() override; + virtual float * get_logits_ith(int32_t i) override; + + virtual float * get_embeddings() override; + virtual float * get_embeddings_ith(int32_t i) override; + virtual float * get_embeddings_seq(llama_seq_id seq_id) override; + + virtual ggml_context_ptr init() override; + + virtual int decode(llama_batch & inp_batch) override; + virtual int encode(llama_batch & inp_batch) override; + + llama_sbatch sbatch; + + build_graph_callback cb_build_graph; // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; @@ -72,7 +322,7 @@ struct llama_context { size_t output_size = 0; // capacity (of tokens positions) for the output buffers int32_t n_outputs = 0; 
// number of actually-used outputs in the current ubatch or last logical batch - bool logits_all = false; + bool logits_all = false; bool need_reserve = false; // embeddings output (2-dimensional array: [n_outputs][n_embd]) @@ -84,17 +334,7 @@ struct llama_context { // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map<llama_seq_id, std::vector<float>> embd_seq; - // memory buffers used to evaluate the model - std::vector<uint8_t> buf_compute_meta; - ggml_backend_sched_ptr sched; - - ggml_abort_callback abort_callback = nullptr; - void * abort_callback_data = nullptr; - - virtual std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch); - - virtual int decode(llama_batch & inp_batch); - virtual int encode(llama_batch & inp_batch); + virtual std::unique_ptr<batch_manager> prepare_batch(const llama_batch & batch); // returns the result of ggml_backend_sched_graph_compute_async execution enum ggml_status compute_graph( @@ -107,32 +347,19 @@ struct llama_context { // certain implementations could require a padding for the context size uint32_t get_ctx_padding(const llama_cparams & cparams) const; - void reset(); - void prepare_k_shift(); void prepare_defrag(); void set_inputs(const llama_ubatch & ubatch); // make the outputs have the same order they had in the user-provided batch - // TODO: maybe deprecate this + // TODO: maybe remove this void reorder_outputs(); // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. size_t reserve_outputs(size_t n_outputs); - ggml_tensor * build_lora_mm( - ggml_context * ctx0, - ggml_tensor * w, - ggml_tensor * cur); - - ggml_tensor * build_lora_mm_id( - ggml_context * ctx0, - ggml_tensor * w, // struct ggml_tensor * as - ggml_tensor * cur, // struct ggml_tensor * b - ggml_tensor * ids); - // input tensors struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] @@ -141,6 +368,81 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] + // === unified KV cache === + + llama_kv_cache kv_self; + + struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_KQ_mask_cnv; // [kv_size, n_batch] + struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] + struct ggml_tensor * inp_KQ_mask_swa_cnv; // [kv_size, n_batch] + struct ggml_tensor * inp_K_shift; // I32 [kv_size] + + virtual ggml_tensor * build_inp_embd( + ggml_context * ctx0, + ggml_tensor * tok_embd, + const llama_ubatch & ubatch) override; + + virtual ggml_tensor * build_inp_pos( + ggml_context * ctx0, + int32_t n_tokens) override; + + virtual ggml_tensor * build_inp_out_ids( + ggml_context * ctx0, + int32_t n_tokens, + bool worst_case) override; + + virtual ggml_tensor * build_inp_mean( + ggml_context * ctx0, + int32_t n_tokens) override; + + virtual ggml_tensor * build_inp_cls( + ggml_context * ctx0, + int32_t n_tokens) override; + + virtual void build_attn_inp( + ggml_context * ctx0, + int32_t n_tokens, + bool causal, + bool swa, + bool worst_case) override; + + virtual void build_attn_kv_store( + ggml_context * ctx0, + ggml_cgraph * graph, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + int32_t n_tokens, + int64_t il, + bool worst_case) override; + + virtual ggml_tensor * build_attn_qkv( + ggml_context * ctx0, + ggml_cgraph * graph, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + int32_t n_tokens, + float kq_scale, + int il, + bool worst_case) override; + + virtual ggml_tensor * build_soft_max_ext( + 
ggml_context * ctx0, + ggml_tensor * kq, + float kq_scale) override; + + virtual ggml_tensor * get_rope_factors(int il) override; + + virtual void build_k_shift( + ggml_context * ctx0, + ggml_cgraph * graph) override; + + // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache + virtual void build_defrag( + ggml_context * ctx0, + ggml_cgraph * graph) override; + // === encoder-decoder === // whether we are computing encoder output or decoder output @@ -152,79 +454,36 @@ struct llama_context { struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] + struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] - // === unified KV cache === - - llama_kv_cache kv_self; - - struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_KQ_mask_cnv; // [kv_size, n_batch] - struct ggml_tensor * inp_KQ_mask_swa; // F32 [kv_size, n_batch] - struct ggml_tensor * inp_KQ_mask_swa_cnv; // [kv_size, n_batch] - struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] - struct ggml_tensor * inp_K_shift; // I32 [kv_size] - - // return true if need to reserve new worst-case graph - void kv_self_update(); - - void build_attn_inp( + virtual ggml_tensor * build_inp_embd_enc( ggml_context * ctx0, int32_t n_tokens, - bool causal, - bool swa, - bool worst_case); + bool worst_case) override; - void build_attn_kv_store( + virtual ggml_tensor * build_inp_KQ_mask_cross( ggml_context * ctx0, - ggml_cgraph * graph, - ggml_tensor * k_cur, - ggml_tensor * v_cur, int32_t n_tokens, - int64_t il, - bool worst_case); - - ggml_tensor * build_attn_qkv( - ggml_context * ctx0, - ggml_cgraph * graph, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - int32_t n_tokens, - float kq_scale, - int il, - bool worst_case); - - ggml_tensor * build_soft_max_ext( - ggml_context * ctx0, - ggml_tensor * kq, - float kq_scale); - - ggml_tensor * get_rope_factors(int il); - - void build_k_shift( - ggml_context * ctx0, - ggml_cgraph * graph); - - // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache - void build_defrag( - ggml_context * ctx0, - ggml_cgraph * graph); + bool worst_case) override; // === recurrent === + struct ggml_tensor * inp_s_copy; // I32 [kv_size] + struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] + // TODO: add recurrent cache // TODO: add mamba-specific llama_context // TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl - ggml_tensor * build_inp_s_copy( + virtual ggml_tensor * build_inp_s_copy( ggml_context * ctx0, - bool worst_case); + bool worst_case) override; - ggml_tensor * build_inp_s_mask( + virtual ggml_tensor * build_inp_s_mask( ggml_context * ctx0, - bool worst_case); + bool worst_case) override; - ggml_tensor * build_copy_mask_state( + virtual ggml_tensor * build_copy_mask_state( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * s, @@ -233,9 +492,9 @@ struct llama_context { int32_t n_tokens, int32_t n_state, int32_t n_seqs, - bool worst_case); + bool worst_case) override; - ggml_tensor * build_mamba_layer( + virtual ggml_tensor * build_mamba_layer( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * cur, @@ -243,25 +502,25 @@ struct llama_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il, - bool worst_case); + bool worst_case) override; - ggml_tensor * build_rwkv_token_shift_load( + 
virtual ggml_tensor * build_rwkv_token_shift_load( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * state_copy, ggml_tensor * state_mask, const llama_ubatch & ubatch, int il, - bool worst_case); + bool worst_case) override; - ggml_tensor * build_rwkv_token_shift_store( + virtual ggml_tensor * build_rwkv_token_shift_store( ggml_context * ctx0, ggml_tensor * token_shift, const llama_ubatch & ubatch, int il, - bool worst_case); + bool worst_case) override; - ggml_tensor * build_rwkv6_time_mix( + virtual ggml_tensor * build_rwkv6_time_mix( ggml_context * ctx0, ggml_cgraph * graph, ggml_tensor * cur, @@ -270,17 +529,48 @@ struct llama_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il, - bool worst_case); + bool worst_case) override; - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] + // state save/load - // === vision === + virtual size_t state_get_size() override; + virtual size_t state_get_data( uint8_t * dst, size_t size) override; + virtual size_t state_set_data(const uint8_t * src, size_t size) override; - // TODO: find a better way to accommodate mutli-dimension position encoding methods - // number of position id each token get, 1 for each token in most cases. - // when using m-rope, it will be 3 position ids per token to representing 3 dimension coordinate. - int n_pos_per_token = 1; + virtual size_t state_seq_get_size(llama_seq_id seq_id) override; + virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override; + virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override; + + virtual bool state_load_file( + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) override; + + virtual bool state_save_file( + const char * filepath, + const llama_token * tokens, + size_t n_token_count) override; + + virtual size_t state_seq_load_file( + llama_seq_id seq_id, + const char * filepath, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out) override; + + virtual size_t state_seq_save_file( + llama_seq_id seq_id, + const char * filepath, + const llama_token * tokens, + size_t n_token_count) override; + +private: + size_t state_get_data(struct llama_data_write & data_ctx); + size_t state_set_data(struct llama_data_read & data_ctx); + + size_t state_seq_get_data(struct llama_data_write & data_ctx, llama_seq_id seq_id); + size_t state_seq_set_data(struct llama_data_read & data_ctx, llama_seq_id seq_id); }; // For internal test use diff --git a/src/llama.cpp b/src/llama.cpp index ed5e1e525..7c002f9bf 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8,7 +8,6 @@ #include "llama-model.h" #include "ggml.h" -#include "ggml-alloc.h" #include "ggml-backend.h" #include "ggml-cpp.h" @@ -86,8 +85,6 @@ struct llm_build_context { const float norm_rms_eps; const int32_t n_tokens; - const int32_t n_outputs; - const int32_t n_outputs_enc; const int32_t n_ctx_orig; const bool worst_case; @@ -98,9 +95,8 @@ struct llm_build_context { const llm_build_cb & cb; - std::vector & buf_compute_meta; - - struct ggml_context * ctx0 = nullptr; + const ggml_context_ptr ctx = nullptr; + ggml_context * ctx0 = nullptr; // TODO: consider making the entire interface noexcept llm_build_context( @@ -136,132 +132,37 @@ struct llm_build_context { norm_eps (hparams.f_norm_eps), norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (ubatch.n_tokens), - n_outputs (worst_case ? 
n_tokens : lctx.n_outputs), - n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), n_ctx_orig (cparams.n_ctx_orig_yarn), worst_case (worst_case), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), - buf_compute_meta (lctx.buf_compute_meta) { - // all initializations should be done in init() + ctx (lctx.init()), + ctx0 (ctx.get()) { } - void init() { - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute_meta.size(), - /*.mem_buffer =*/ buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ctx0 = ggml_init(params); - - lctx.reset(); - } - - void free() { - ggml_free(ctx0); - ctx0 = nullptr; - } - + // TODO: tmp struct ggml_tensor * build_inp_embd(struct ggml_tensor * tok_embd) { - struct ggml_tensor * inpL; - - if (ubatch.token) { - lctx.inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - cb(lctx.inp_tokens, "inp_tokens", -1); - ggml_set_input(lctx.inp_tokens); - - inpL = ggml_get_rows(ctx0, tok_embd, lctx.inp_tokens); - - // apply lora for embedding tokens if needed - for (const auto & lora : loras) { - struct llama_adapter_lora_weight * lw = lora.first->get_weight(tok_embd); - if (lw == nullptr) { - continue; - } - - const float adapter_scale = lora.second; - const float scale = lw->get_scale(lora.first->alpha, adapter_scale); - - struct ggml_tensor * inpL_delta = ggml_scale(ctx0, ggml_mul_mat( - ctx0, lw->b, // non-transposed lora_b - ggml_get_rows(ctx0, lw->a, lctx.inp_tokens) - ), scale); - - inpL = ggml_add(ctx0, inpL, inpL_delta); - } - } else { - lctx.inp_embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - inpL = lctx.inp_embd; - ggml_set_input(lctx.inp_embd); - } - - // For Granite architecture - if (hparams.f_embedding_scale != 0.0f) { - inpL = ggml_scale(ctx0, inpL, hparams.f_embedding_scale); - } - + struct ggml_tensor * inpL = lctx.build_inp_embd(ctx0, tok_embd, ubatch); cb(inpL, "inp_embd", -1); return inpL; } - // do mat_mul, while optionally apply lora + // TODO: tmp struct ggml_tensor * build_lora_mm( struct ggml_tensor * w, struct ggml_tensor * cur) { - struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); - - for (const auto & lora : loras) { - struct llama_adapter_lora_weight * lw = lora.first->get_weight(w); - if (lw == nullptr) { - continue; - } - - const float adapter_scale = lora.second; - const float scale = lw->get_scale(lora.first->alpha, adapter_scale); - - struct ggml_tensor * ab_cur = ggml_mul_mat( - ctx0, lw->b, - ggml_mul_mat(ctx0, lw->a, cur) - ); - - ab_cur = ggml_scale(ctx0, ab_cur, scale); - res = ggml_add(ctx0, res, ab_cur); - } - - return res; + return lctx.build_lora_mm(ctx0, w, cur); } - // do mat_mul_id, while optionally apply lora + // TODO: tmp struct ggml_tensor * build_lora_mm_id( struct ggml_tensor * w, // struct ggml_tensor * as struct ggml_tensor * cur, // struct ggml_tensor * b struct ggml_tensor * ids) { - struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids); - for (const auto & lora : loras) { - struct llama_adapter_lora_weight * lw = lora.first->get_weight(w); - if (lw == nullptr) { - continue; - } - - const float alpha = lora.first->alpha; - const float rank = (float) lw->b->ne[0]; - const float scale = alpha ? 
lora.second * alpha / rank : lora.second; - - struct ggml_tensor * ab_cur = ggml_mul_mat_id( - ctx0, lw->b, - ggml_mul_mat_id(ctx0, lw->a, cur, ids), - ids - ); - - ab_cur = ggml_scale(ctx0, ab_cur, scale); - res = ggml_add(ctx0, res, ab_cur); - } - - return res; + return lctx.build_lora_mm_id(ctx0, w, cur, ids); } struct ggml_tensor * build_norm( @@ -620,31 +521,31 @@ struct llm_build_context { } struct ggml_tensor * build_inp_pos() { - lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(lctx.inp_pos, "inp_pos", -1); - ggml_set_input(lctx.inp_pos); - return lctx.inp_pos; + ggml_tensor * cur = lctx.build_inp_pos(ctx0, n_tokens); + cb(cur, "inp_pos", -1); + + return cur; } struct ggml_tensor * build_inp_out_ids() { - lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs); - cb(lctx.inp_out_ids, "inp_out_ids", -1); - ggml_set_input(lctx.inp_out_ids); - return lctx.inp_out_ids; + ggml_tensor * cur = lctx.build_inp_out_ids(ctx0, n_tokens, worst_case); + cb(cur, "inp_out_ids", -1); + + return cur; } struct ggml_tensor * build_inp_mean() { - lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); - cb(lctx.inp_mean, "inp_mean", -1); - ggml_set_input(lctx.inp_mean); - return lctx.inp_mean; + ggml_tensor * cur = lctx.build_inp_mean(ctx0, n_tokens); + cb(cur, "inp_mean", -1); + + return cur; } struct ggml_tensor * build_inp_cls() { - lctx.inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - cb(lctx.inp_cls, "inp_cls", -1); - ggml_set_input(lctx.inp_cls); - return lctx.inp_cls; + ggml_tensor * cur = lctx.build_inp_cls(ctx0, n_tokens); + cb(cur, "inp_cls", -1); + + return cur; } struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { @@ -745,26 +646,22 @@ struct llm_build_context { //} struct ggml_tensor * build_inp_embd_enc() { - const int64_t n_embd = hparams.n_embd; - lctx.inp_embd_enc = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc); - ggml_set_input(lctx.inp_embd_enc); - cb(lctx.inp_embd_enc, "embd_enc", -1); - return lctx.inp_embd_enc; + ggml_tensor * cur = lctx.build_inp_embd_enc(ctx0, n_tokens, worst_case); + cb(cur, "embd_enc", -1); + + return cur; } struct ggml_tensor * build_inp_KQ_mask_cross() { - lctx.inp_KQ_mask_cross = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_outputs_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - ggml_set_input(lctx.inp_KQ_mask_cross); - cb(lctx.inp_KQ_mask_cross, "KQ_mask_cross", -1); - return lctx.inp_KQ_mask_cross; + ggml_tensor * cur = lctx.build_inp_KQ_mask_cross(ctx0, n_tokens, worst_case); + cb(cur, "KQ_mask_cross", -1); + + return cur; } struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -838,7 +735,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -927,9 +823,6 @@ struct llm_build_context { struct ggml_cgraph * build_deci() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t 
n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -1014,7 +907,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -1422,9 +1314,6 @@ struct llm_build_context { struct ggml_cgraph * build_grok() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -1498,7 +1387,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -1580,9 +1468,6 @@ struct llm_build_context { struct ggml_cgraph * build_dbrx() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -1649,7 +1534,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -2716,10 +2600,7 @@ struct llm_build_context { inpL = build_inp_embd(model.tok_embd); // inp_pos - contains the positions - lctx.inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens * 4); - cb(lctx.inp_pos, "inp_pos", -1); - ggml_set_input(lctx.inp_pos); - struct ggml_tensor * inp_pos = lctx.inp_pos; + struct ggml_tensor * inp_pos = build_inp_pos(); lctx.build_attn_inp(ctx0, n_tokens, true, false, worst_case); @@ -2825,9 +2706,6 @@ struct llm_build_context { struct ggml_cgraph * build_qwen2moe() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -2891,7 +2769,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -4685,9 +4562,6 @@ struct llm_build_context { struct ggml_cgraph * build_olmo() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ 
-4757,7 +4631,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -4808,9 +4681,6 @@ struct llm_build_context { struct ggml_cgraph * build_olmo2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -4880,7 +4750,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -4935,9 +4804,6 @@ struct llm_build_context { struct ggml_cgraph * build_olmoe() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -5006,7 +4872,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -5325,9 +5190,6 @@ struct llm_build_context { struct ggml_cgraph * build_arctic() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -5385,7 +5247,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -5458,9 +5319,6 @@ struct llm_build_context { struct ggml_cgraph * build_deepseek() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -5535,7 +5393,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -5616,9 +5473,6 @@ struct llm_build_context { struct ggml_cgraph * build_deepseek2() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - bool 
is_lite = (hparams.n_layer == 27); // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. @@ -5767,7 +5621,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -5996,9 +5849,6 @@ struct llm_build_context { //struct ggml_cgraph * build_t5_enc() { // struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // // mutable variable, needed during the last layer of the computation to skip unused tokens - // int32_t n_tokens = this->n_tokens; - // const int64_t n_embd_head = hparams.n_embd_head_v; // const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); // GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6072,7 +5922,6 @@ struct llm_build_context { // if (il == n_layer - 1) { // // skip computing output for unused tokens // struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - // n_tokens = n_outputs; // cur = ggml_get_rows(ctx0, cur, inp_out_ids); // inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); // } @@ -6128,9 +5977,6 @@ struct llm_build_context { //struct ggml_cgraph * build_t5_dec() { // struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // // mutable variable, needed during the last layer of the computation to skip unused tokens - // int32_t n_tokens = this->n_tokens; - // const int64_t n_embd_head = hparams.n_embd_head_v; // const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); // GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6272,7 +6118,6 @@ struct llm_build_context { // if (il == n_layer - 1) { // // skip computing output for unused tokens // struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - // n_tokens = n_outputs; // cur = ggml_get_rows(ctx0, cur, inp_out_ids); // inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); // inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); @@ -6673,9 +6518,6 @@ struct llm_build_context { struct ggml_cgraph * build_exaone() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -6748,7 +6590,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -6978,9 +6819,6 @@ struct llm_build_context { struct ggml_cgraph * build_chameleon() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -7076,7 +6914,6 @@ struct llm_build_context { if (il == n_layer - 1) { // skip computing output for unused tokens struct ggml_tensor * inp_out_ids = build_inp_out_ids(); - n_tokens = n_outputs; cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -7341,8 +7178,6 @@ static struct 
ggml_cgraph * llama_build_graph( struct llm_build_context llm(lctx, ubatch, cb, worst_case); - llm.init(); - switch (model.arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_MINICPM: @@ -7403,7 +7238,6 @@ static struct ggml_cgraph * llama_build_graph( } break; case LLM_ARCH_QWEN2VL: { - lctx.n_pos_per_token = 4; result = llm.build_qwen2vl(); } break; case LLM_ARCH_QWEN2MOE: @@ -7564,8 +7398,6 @@ static struct ggml_cgraph * llama_build_graph( result = llm.append_pooling(result); } - llm.free(); - return result; } @@ -7908,7 +7740,7 @@ struct llama_context * llama_init_from_model( try { // TODO: add logic which llama_context implementation to construct - ctx = new llama_context(*model, params, + ctx = new llama_context_unified(*model, params, [](llama_context & lctx, const llama_ubatch & ubatch, bool worst_case) { return llama_build_graph(lctx, ubatch, worst_case); });
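
Note on the public API: the session/state entry points in llama.h keep their signatures; as the hunks above show, the free functions now only delegate to the llama_context virtuals (state_get_data, state_load_file, and so on), and the file layout (magic, version, token count, tokens, serialized state) is the one previously produced by the *_internal helpers. A minimal caller-side sketch of the unchanged API, assuming an already-initialized context and prompt; the file name and helper function below are illustrative and not part of the patch:

#include "llama.h"

#include <vector>

// round-trip the context state through a file and through a raw buffer
static bool session_round_trip(llama_context * ctx, const std::vector<llama_token> & prompt) {
    // file-based: writes the magic/version header, the prompt tokens, then the serialized state
    if (!llama_state_save_file(ctx, "session.bin", prompt.data(), prompt.size())) {
        return false;
    }

    std::vector<llama_token> tokens(prompt.size());
    size_t n_tokens = 0;
    if (!llama_state_load_file(ctx, "session.bin", tokens.data(), tokens.size(), &n_tokens)) {
        return false;
    }

    // buffer-based: query the exact state size, then copy the state out and back in
    std::vector<uint8_t> state(llama_state_get_size(ctx));

    const size_t n_written = llama_state_get_data(ctx, state.data(), state.size());
    const size_t n_read    = llama_state_set_data(ctx, state.data(), n_written);

    return n_written != 0 && n_read == n_written;
}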