From 9e50456e19ac5c24c40387e6b4a2b3072f7a9d8e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 18 Feb 2025 14:53:02 +0200
Subject: [PATCH] context : minor simplify

ggml-ci
---
 src/llama-context.cpp | 24 +++++++++++-------------
 src/llama-context.h   |  2 +-
 src/llama-model.cpp   | 20 +++++++++-----------
 src/llama-model.h     |  2 +-
 4 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 0e0af806d..d9735cfaa 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -256,7 +256,7 @@ void llama_context::init() {
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
             auto ctx = graph_init();
-            auto res_pp = graph_build(ctx, ubatch_pp, true);
+            auto res_pp = graph_build(ctx.get(), ubatch_pp, true);
             auto & gf_pp = res_pp.gf;
             if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -271,7 +271,7 @@ void llama_context::init() {
         {
             llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
             auto ctx = graph_init();
-            auto res_tg = graph_build(ctx, ubatch_tg, true);
+            auto res_tg = graph_build(ctx.get(), ubatch_tg, true);
             auto & gf_tg = res_tg.gf;
             if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
@@ -285,7 +285,7 @@ void llama_context::init() {
         {
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
             auto ctx = graph_init();
-            auto res_pp = graph_build(ctx, ubatch_pp, true);
+            auto res_pp = graph_build(ctx.get(), ubatch_pp, true);
             auto & gf_pp = res_pp.gf;
             if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
                 LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -573,7 +573,7 @@ ggml_context_ptr llama_context::graph_init() {
 }
 llama_graph_result llama_context::graph_build(
-        ggml_context_ptr & ctx,
+        ggml_context * ctx,
         const llama_ubatch & ubatch,
         bool worst_case) {
     return model.build_graph(ctx, *this, cparams, ubatch, worst_case);
 }
@@ -1720,7 +1720,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
     auto ctx = graph_init();
-    auto res = graph_build(ctx, ubatch, false);
+    auto res = graph_build(ctx.get(), ubatch, false);
 
     auto * gf = res.gf;
 
@@ -2000,7 +2000,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
             llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
             auto ctx = graph_init();
-            auto res = graph_build(ctx, ubatch, true);
+            auto res = graph_build(ctx.get(), ubatch, true);
 
             // initialize scheduler with the worst-case graph
             ggml_backend_sched_reset(sched.get());
@@ -2015,7 +2015,7 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         auto ctx = graph_init();
-        auto res = graph_build(ctx, ubatch, false);
+        auto res = graph_build(ctx.get(), ubatch, false);
 
         auto * gf = res.gf;
 
@@ -2483,11 +2483,10 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());
 
         auto ctx = graph_init();
-        auto * ctx0 = ctx.get();
 
-        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), model.max_nodes(), false);
 
-        build_kv_self_shift(ctx0, gf);
+        build_kv_self_shift(ctx.get(), gf);
 
         ggml_backend_sched_alloc_graph(sched.get(), gf);
 
@@ -2512,11 +2511,10 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());
 
         auto ctx = graph_init();
-        auto * ctx0 = ctx.get();
 
-        ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+        ggml_cgraph * gf = ggml_new_graph_custom(ctx.get(), model.max_nodes(), false);
 
-        build_kv_self_defrag(ctx0, gf);
+        build_kv_self_defrag(ctx.get(), gf);
 
         ggml_backend_sched_alloc_graph(sched.get(), gf);
 
diff --git a/src/llama-context.h b/src/llama-context.h
index 9f6abfc82..4bf8244e6 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -97,7 +97,7 @@ struct llama_context : public llama_graph_i {
 
     // TODO: add encode/decode graphs
     virtual llama_graph_result graph_build(
-            ggml_context_ptr & ctx,
+            ggml_context * ctx,
             const llama_ubatch & ubatch,
             bool worst_case);
 
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index ecfd6f185..289c3422e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3841,19 +3841,18 @@ struct llm_build_context {
     const enum llama_pooling_type pooling_type;
     const enum llama_rope_type rope_type;
 
-    ggml_context_ptr & ctx;
-    ggml_context     * ctx0 = nullptr;
+    ggml_context * ctx0 = nullptr;
 
     llama_graph_result res;
 
    // TODO: consider making the entire interface noexcept
    llm_build_context(
-            ggml_context_ptr & ctx,
-            llama_graph_i & lgf,
-            const llama_model & model,
-            const llama_cparams & cparams,
-            const llama_ubatch & ubatch,
-            bool worst_case) :
+        ggml_context * ctx,
+        llama_graph_i & lgf,
+        const llama_model & model,
+        const llama_cparams & cparams,
+        const llama_ubatch & ubatch,
+        bool worst_case) :
         lgf              (lgf),
         model            (model),
         hparams          (model.hparams),
@@ -3885,8 +3884,7 @@ struct llm_build_context {
         flash_attn       (cparams.flash_attn),
         pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
-        ctx              (ctx),
-        ctx0             (this->ctx.get()) {
+        ctx0             (ctx) {
     }
 
     // TODO: tmp
@@ -10937,7 +10935,7 @@ struct llm_build_context {
 };
 
 llama_graph_result llama_model::build_graph(
-        ggml_context_ptr & ctx,
+        ggml_context * ctx,
         llama_graph_i & lgf,
         const llama_cparams & cparams,
         const llama_ubatch & ubatch,
diff --git a/src/llama-model.h b/src/llama-model.h
index f5d1f7b79..a7c53bdbd 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -370,7 +370,7 @@ struct llama_model {
 
     // TODO: add encode/decode graphs
     llama_graph_result build_graph(
-            ggml_context_ptr & ctx,
+            ggml_context * ctx,
             llama_graph_i & lgf,
             const llama_cparams & cparams,
             const llama_ubatch & ubatch,