diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 94d6d4f90..55f1c0382 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -246,31 +246,48 @@ void llama_context::init() { uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + int n_splits_pp = -1; + int n_nodes_pp = -1; + + int n_splits_tg = -1; + int n_nodes_tg = -1; + // reserve pp graph first so that buffers are only allocated once - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - ggml_cgraph * gf_pp = build_graph(ubatch_pp, true); - if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__); - throw std::runtime_error("failed to allocate compute buffers"); + { + llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + auto res_pp = graph_build(ubatch_pp, true); + auto & gf_pp = res_pp.gf; + if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__); + throw std::runtime_error("failed to allocate compute buffers"); + } + + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf_pp); } - int n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); - int n_nodes_pp = ggml_graph_n_nodes(gf_pp); // reserve with tg graph to get the number of splits and nodes - llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - ggml_cgraph * gf_tg = build_graph(ubatch_tg, true); - if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__); - throw std::runtime_error("failed to allocate compute buffers"); + { + llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + auto res_tg = graph_build(ubatch_tg, true); + auto & gf_tg = res_tg.gf; + if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__); + throw std::runtime_error("failed to allocate compute buffers"); + } + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf_tg); } - int n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); - int n_nodes_tg = ggml_graph_n_nodes(gf_tg); // reserve again with pp graph to avoid ggml-alloc reallocations during inference - gf_pp = build_graph(ubatch_pp, true); - if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__); - throw std::runtime_error("failed to allocate compute buffers"); + { + llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + auto res_pp = graph_build(ubatch_pp, true); + auto & gf_pp = res_pp.gf; + if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__); + throw std::runtime_error("failed to allocate compute buffers"); + } } for (size_t i = 0; i < backend_ptrs.size(); ++i) { @@ -890,7 +907,7 @@ void llama_context::build_cb( } } -ggml_cgraph * llama_context::build_graph(const llama_ubatch & ubatch, bool worst_case) { +llama_graph_result llama_context::graph_build(const llama_ubatch & ubatch, bool worst_case) { return model.build_graph(*this, cparams, ubatch, graph_init(), worst_case); } @@ -1814,11 +1831,11 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - ggml_cgraph * gf = build_graph(ubatch, true); + auto res = graph_build(ubatch, true); // initialize scheduler with the worst-case graph ggml_backend_sched_reset(sched.get()); - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + if (!ggml_backend_sched_reserve(sched.get(), res.gf)) { LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); } @@ -1828,7 +1845,9 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - ggml_cgraph * gf = build_graph(ubatch, false); + auto res = graph_build(ubatch, false); + + auto & gf = res.gf; // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -2073,7 +2092,9 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) { ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - ggml_cgraph * gf = build_graph(ubatch, false); + auto res = graph_build(ubatch, false); + + auto & gf = res.gf; ggml_backend_sched_alloc_graph(sched.get(), gf); diff --git a/src/llama-context.h b/src/llama-context.h index 7a10f84bd..981afcc00 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -95,6 +95,9 @@ struct llama_context : public llama_graph_i { // zero-out inputs and create ggml_context virtual ggml_context_ptr graph_init(); + // TODO: add encode/decode graphs + virtual llama_graph_result graph_build(const llama_ubatch & ubatch, bool worst_case); + // returns the result of ggml_backend_sched_graph_compute_async execution virtual enum ggml_status graph_compute( ggml_cgraph * graph, @@ -145,9 +148,6 @@ struct llama_context : public llama_graph_i { const llama_ubatch & ubatch, int il); - // TODO: add encode/decode graphs - virtual ggml_cgraph * build_graph(const llama_ubatch & ubatch, bool worst_case); - // apply control vector for layer il virtual ggml_tensor * build_cvec( ggml_context * ctx0, diff --git a/src/llama-graph.h b/src/llama-graph.h index d60b57491..de3cd2f04 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -10,6 +10,13 @@ struct ggml_context; struct ggml_tensor; struct llama_ubatch; +struct llama_graph_result { + ggml_cgraph * gf = nullptr; + + ggml_tensor * t_logits = nullptr; + ggml_tensor * t_embd = nullptr; +}; + // TODO: can become more granular in the future class llama_graph_i { public: diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 543e78d2b..4950af59b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4251,22 +4251,6 @@ struct llm_build_context { return cur; } - struct ggml_cgraph * build_kv_self_shift() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - - lgf.build_kv_self_shift(ctx0, gf); - - return gf; - } - - struct ggml_cgraph * build_kv_self_defrag() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - - lgf.build_kv_self_defrag(ctx0, gf); - - return gf; - } - struct ggml_tensor * build_inp_pos() { ggml_tensor * cur = lgf.build_inp_pos(ctx0, n_tokens); cb(cur, "inp_pos", -1); @@ -4295,7 +4279,7 @@ struct llm_build_context { return cur; } - struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { + void append_pooling(struct ggml_cgraph * gf) { // find result_norm tensor for input struct ggml_tensor * inp = nullptr; for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) { @@ -4356,8 +4340,6 @@ struct llm_build_context { cb(cur, "result_embd_pooled", -1); ggml_build_forward_expand(gf, cur); - - return gf; } //struct ggml_tensor * build_pos_bucket(bool causal) { @@ -4406,9 +4388,7 @@ struct llm_build_context { return cur; } - struct ggml_cgraph * build_llama() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_llama(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -4563,13 +4543,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_deci() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_deci(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -4719,13 +4695,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_baichuan() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_baichuan(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -4834,13 +4806,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_xverse() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_xverse(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -4937,13 +4905,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_falcon() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_falcon(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5057,13 +5021,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_grok() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_grok(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -5211,13 +5171,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_dbrx() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_dbrx(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5334,13 +5290,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_starcoder() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_starcoder(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5438,13 +5390,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_refact() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_refact(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5532,13 +5480,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_bert() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_bert(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -5726,13 +5670,9 @@ struct llm_build_context { cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_bloom() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_bloom(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5827,13 +5767,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_mpt() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_mpt(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5967,13 +5903,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_stablelm() { - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - + void build_stablelm(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6117,13 +6049,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_qwen() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_qwen(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6229,13 +6157,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_qwen2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_qwen2(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -6341,12 +6265,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_qwen2vl() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + void build_qwen2vl(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -6457,13 +6378,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_qwen2moe() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_qwen2moe(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -6601,13 +6518,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_phi2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_phi2(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6722,13 +6635,11 @@ struct llm_build_context { cur = ggml_add(ctx0, cur, model.output_b); cb(cur, "result_output", -1); + ggml_build_forward_expand(gf, cur); - return gf; } - struct ggml_cgraph * build_phi3() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_phi3(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6866,14 +6777,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - - struct ggml_cgraph * build_plamo() { - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - + void build_plamo(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -6971,13 +6877,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_gpt2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_gpt2(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7076,13 +6978,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_codeshell() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_codeshell(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7187,13 +7085,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_orion() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_orion(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -7305,13 +7199,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_internlm2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_internlm2(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -7423,13 +7313,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_minicpm3() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_minicpm3(ggml_cgraph * gf) { //TODO: if the model varies, these parameters need to be read from the model const int64_t n_embd_base = 256; const float scale_embd = 12.0f; @@ -7633,13 +7519,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_gemma() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_gemma(ggml_cgraph * gf) { const int64_t n_embd_head_k = hparams.n_embd_head_k; struct ggml_tensor * cur; @@ -7741,13 +7623,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_gemma2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_gemma2(ggml_cgraph * gf) { const int64_t n_embd_head_k = hparams.n_embd_head_k; struct ggml_tensor * cur; @@ -7871,14 +7749,10 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - - struct ggml_cgraph * build_starcoder2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + // TODO: move up next to build_starcoder + void build_starcoder2(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -7991,13 +7865,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_mamba() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_mamba(ggml_cgraph * gf) { struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -8045,14 +7915,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_command_r() { - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_command_r(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); const float f_logit_scale = hparams.f_logit_scale; @@ -8193,14 +8058,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; - } - struct ggml_cgraph * build_cohere2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_cohere2(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); const float f_logit_scale = hparams.f_logit_scale; @@ -8322,8 +8182,6 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } // ref: https://allenai.org/olmo @@ -8332,9 +8190,7 @@ struct llm_build_context { // * clamp qkv // * removed bias // * removed MoE - struct ggml_cgraph * build_olmo() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_olmo(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -8447,13 +8303,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_olmo2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_olmo2(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -8566,17 +8418,13 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } // based on the build_qwen2moe() function, changes: // * removed shared experts // * removed bias // * added q, k norm - struct ggml_cgraph * build_olmoe() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_olmoe(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -8692,13 +8540,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_openelm() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_openelm(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8817,13 +8661,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_gptneox() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_gptneox(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8960,13 +8800,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_arctic() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_arctic(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -9089,13 +8925,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_deepseek() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_deepseek(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -9244,13 +9076,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_deepseek2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_deepseek2(ggml_cgraph * gf) { bool is_lite = (hparams.n_layer == 27); // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. @@ -9471,13 +9299,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_bitnet() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_bitnet(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9622,12 +9446,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - return gf; } - //struct ggml_cgraph * build_t5_enc() { - // struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + //void build_t5_enc(ggml_cgraph * gf) { // const int64_t n_embd_head = hparams.n_embd_head_v; // const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); // GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9749,13 +9570,9 @@ struct llm_build_context { // cb(cur, "result_norm", -1); // ggml_build_forward_expand(gf, cur); - - // return gf; //} - //struct ggml_cgraph * build_t5_dec() { - // struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + //void build_t5_dec(ggml_cgraph * gf) { // const int64_t n_embd_head = hparams.n_embd_head_v; // const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); // GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9954,9 +9771,7 @@ struct llm_build_context { // return gf; //} - struct ggml_cgraph * build_jais() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_jais(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10041,13 +9856,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_chatglm() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_chatglm(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10170,13 +9981,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_nemotron() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_nemotron(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); //GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -10290,13 +10097,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_exaone() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_exaone(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -10412,13 +10215,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - ggml_cgraph * build_rwkv6() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_rwkv6(ggml_cgraph * gf) { GGML_ASSERT(hparams.token_shift_count == 2); struct ggml_tensor * cur; @@ -10502,14 +10301,10 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py - ggml_cgraph * build_rwkv6qwen2() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_rwkv6qwen2(ggml_cgraph * gf) { GGML_ASSERT(n_embd == hparams.n_embd_k_s()); struct ggml_tensor * cur; @@ -10586,8 +10381,6 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } // ref: https://github.com/facebookresearch/chameleon @@ -10596,9 +10389,7 @@ struct llm_build_context { // * swin-norm // * removed bias // * removed MoE - struct ggml_cgraph * build_chameleon() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_chameleon(ggml_cgraph * gf) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -10759,13 +10550,9 @@ struct llm_build_context { cb(cur, "result_output", -1); ggml_build_forward_expand(gf, cur); - - return gf; } - struct ggml_cgraph * build_wavtokenizer_dec() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); - + void build_wavtokenizer_dec(ggml_cgraph * gf) { struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -10911,231 +10698,233 @@ struct llm_build_context { cb(cur, "result_embd", -1); ggml_build_forward_expand(gf, cur); - - return gf; } }; -ggml_cgraph * llama_model::build_graph( +llama_graph_result llama_model::build_graph( llama_graph_i & lgf, const llama_cparams & cparams, const llama_ubatch & ubatch, ggml_context_ptr && ctx, bool worst_case) const { - struct ggml_cgraph * result = NULL; + llama_graph_result result = {}; struct llm_build_context llm(lgf, *this, cparams, ubatch, std::move(ctx), worst_case); + auto & gf = result.gf; + + gf = ggml_new_graph_custom(llm.ctx0, max_nodes(), false); + switch (arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_MINICPM: case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { - result = llm.build_llama(); + llm.build_llama(gf); } break; case LLM_ARCH_DECI: { - result = llm.build_deci(); + llm.build_deci(gf); } break; case LLM_ARCH_BAICHUAN: { - result = llm.build_baichuan(); + llm.build_baichuan(gf); } break; case LLM_ARCH_FALCON: { - result = llm.build_falcon(); + llm.build_falcon(gf); } break; case LLM_ARCH_GROK: { - result = llm.build_grok(); + llm.build_grok(gf); } break; case LLM_ARCH_STARCODER: { - result = llm.build_starcoder(); + llm.build_starcoder(gf); } break; case LLM_ARCH_REFACT: { - result = llm.build_refact(); + llm.build_refact(gf); } break; case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: { - result = llm.build_bert(); + llm.build_bert(gf); } break; case LLM_ARCH_BLOOM: { - result = llm.build_bloom(); + llm.build_bloom(gf); } break; case LLM_ARCH_MPT: { - result = llm.build_mpt(); + llm.build_mpt(gf); } break; case LLM_ARCH_STABLELM: { - result = llm.build_stablelm(); + llm.build_stablelm(gf); } break; case LLM_ARCH_QWEN: { - result = llm.build_qwen(); + llm.build_qwen(gf); } break; case LLM_ARCH_QWEN2: { - result = llm.build_qwen2(); + llm.build_qwen2(gf);; } break; case LLM_ARCH_QWEN2VL: { - result = llm.build_qwen2vl(); + llm.build_qwen2vl(gf); } break; case LLM_ARCH_QWEN2MOE: { - result = llm.build_qwen2moe(); + llm.build_qwen2moe(gf); } break; case LLM_ARCH_PHI2: { - result = llm.build_phi2(); + llm.build_phi2(gf); } break; case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: { - result = llm.build_phi3(); + llm.build_phi3(gf); } break; case LLM_ARCH_PLAMO: { - result = llm.build_plamo(); + llm.build_plamo(gf); } break; case LLM_ARCH_GPT2: { - result = llm.build_gpt2(); + llm.build_gpt2(gf); } break; case LLM_ARCH_CODESHELL: { - result = llm.build_codeshell(); + llm.build_codeshell(gf); } break; case LLM_ARCH_ORION: { - result = llm.build_orion(); + llm.build_orion(gf); } break; case LLM_ARCH_INTERNLM2: { - result = llm.build_internlm2(); + llm.build_internlm2(gf); } break; case LLM_ARCH_MINICPM3: { - result = llm.build_minicpm3(); + llm.build_minicpm3(gf); } break; case LLM_ARCH_GEMMA: { - result = llm.build_gemma(); + llm.build_gemma(gf); } break; case LLM_ARCH_GEMMA2: { - result = llm.build_gemma2(); + llm.build_gemma2(gf); } break; case LLM_ARCH_STARCODER2: { - result = llm.build_starcoder2(); + llm.build_starcoder2(gf); } break; case LLM_ARCH_MAMBA: { - result = llm.build_mamba(); + llm.build_mamba(gf); } break; case LLM_ARCH_XVERSE: { - result = llm.build_xverse(); + llm.build_xverse(gf); } break; case LLM_ARCH_COMMAND_R: { - result = llm.build_command_r(); + llm.build_command_r(gf); } break; case LLM_ARCH_COHERE2: { - result = llm.build_cohere2(); + llm.build_cohere2(gf); } break; case LLM_ARCH_DBRX: { - result = llm.build_dbrx(); + llm.build_dbrx(gf); } break; case LLM_ARCH_OLMO: { - result = llm.build_olmo(); + llm.build_olmo(gf); } break; case LLM_ARCH_OLMO2: { - result = llm.build_olmo2(); + llm.build_olmo2(gf); } break; case LLM_ARCH_OLMOE: { - result = llm.build_olmoe(); + llm.build_olmoe(gf); } break; case LLM_ARCH_OPENELM: { - result = llm.build_openelm(); + llm.build_openelm(gf); } break; case LLM_ARCH_GPTNEOX: { - result = llm.build_gptneox(); + llm.build_gptneox(gf); } break; case LLM_ARCH_ARCTIC: { - result = llm.build_arctic(); + llm.build_arctic(gf); } break; case LLM_ARCH_DEEPSEEK: { - result = llm.build_deepseek(); + llm.build_deepseek(gf); } break; case LLM_ARCH_DEEPSEEK2: { - result = llm.build_deepseek2(); + llm.build_deepseek2(gf); } break; case LLM_ARCH_CHATGLM: { - result = llm.build_chatglm(); + llm.build_chatglm(gf); } break; case LLM_ARCH_BITNET: { - result = llm.build_bitnet(); + llm.build_bitnet(gf); } break; //case LLM_ARCH_T5: // { // if (lctx.is_encoding) { - // result = llm.build_t5_enc(); + // llm.build_t5_enc(gf); // } else { - // result = llm.build_t5_dec(); + // llm.build_t5_dec(gf); // } // } break; //case LLM_ARCH_T5ENCODER: // { - // result = llm.build_t5_enc(); + // llm.build_t5_enc(gf); // } break; case LLM_ARCH_JAIS: { - result = llm.build_jais(); + llm.build_jais(gf); } break; case LLM_ARCH_NEMOTRON: { - result = llm.build_nemotron(); + llm.build_nemotron(gf); } break; case LLM_ARCH_EXAONE: { - result = llm.build_exaone(); + llm.build_exaone(gf); } break; case LLM_ARCH_RWKV6: { - result = llm.build_rwkv6(); + llm.build_rwkv6(gf); } break; case LLM_ARCH_RWKV6QWEN2: { - result = llm.build_rwkv6qwen2(); + llm.build_rwkv6qwen2(gf); } break; case LLM_ARCH_CHAMELEON: { - result = llm.build_chameleon(); + llm.build_chameleon(gf); } break; case LLM_ARCH_WAVTOKENIZER_DEC: { - result = llm.build_wavtokenizer_dec(); + llm.build_wavtokenizer_dec(gf); } break; default: GGML_ABORT("fatal error"); @@ -11143,7 +10932,7 @@ ggml_cgraph * llama_model::build_graph( // add on pooling layer if (cparams.embeddings) { - result = llm.append_pooling(result); + llm.append_pooling(gf); } return result; diff --git a/src/llama-model.h b/src/llama-model.h index 0374b484b..a3267bbbb 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -16,6 +16,7 @@ class llama_graph_i; struct llama_cparams; struct llama_ubatch; struct llama_model_loader; +struct llama_graph_result; // available models enum llm_type { @@ -368,8 +369,7 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; // TODO: add encode/decode graphs - // TODO: return a struct containing the graph and the output tensors, such as logits, embeddings, etc. - ggml_cgraph * build_graph( + llama_graph_result build_graph( llama_graph_i & lgf, const llama_cparams & cparams, const llama_ubatch & ubatch,