From 8fcb563613e20a04dd9791f0a9b8a41086428c09 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Fri, 14 Mar 2025 13:47:05 +0100 Subject: [PATCH] Load all MoE experts during warmup (#11571) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * llama : introduce llama_set_warmup() API call that controls warmup mode; use all MoE experts during warmup * common : use new API to enable warmup mode during model warmup --------- Co-authored-by: Stanisław Szymczyk --- common/common.cpp | 3 +++ include/llama.h | 4 ++++ src/llama-context.cpp | 13 ++++++++++++- src/llama-context.h | 1 + src/llama-cparams.h | 1 + src/llama-graph.cpp | 2 +- 6 files changed, 22 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8487e3834..18ffb4e73 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1033,6 +1033,8 @@ struct common_init_result common_init_from_params(common_params & params) { if (params.warmup) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); + llama_set_warmup(lctx, true); + std::vector tmp; llama_token bos = llama_vocab_bos(vocab); llama_token eos = llama_vocab_eos(vocab); @@ -1063,6 +1065,7 @@ struct common_init_result common_init_from_params(common_params & params) { llama_kv_self_clear(lctx); llama_synchronize(lctx); llama_perf_context_reset(lctx); + llama_set_warmup(lctx, false); } iparams.model.reset(model); diff --git a/include/llama.h b/include/llama.h index e5286f061..6a44be404 100644 --- a/include/llama.h +++ b/include/llama.h @@ -945,6 +945,10 @@ extern "C" { // If set to true, the model will only attend to the past tokens LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn); + // Set whether the model is in warmup mode or not + // If true, all model tensors are activated during llama_decode() to load and cache their weights. + LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4df6b18ec..c2fcce42a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -39,6 +39,7 @@ llama_context::llama_context( cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; + cparams.warmup = false; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -948,6 +949,12 @@ void llama_context::set_causal_attn(bool value) { cparams.causal_attn = value; } +void llama_context::set_warmup(bool value) { + LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + + cparams.warmup = value; +} + void llama_context::set_adapter_lora( llama_adapter_lora * adapter, float scale) { @@ -1594,7 +1601,7 @@ void llama_context::output_reorder() { // int32_t llama_context::graph_max_nodes() const { - return std::max(8192, 5*model.n_tensors()); + return std::max(65536, 5*model.n_tensors()); } ggml_cgraph * llama_context::graph_init() { @@ -2372,6 +2379,10 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) { ctx->set_causal_attn(causal_attn); } +void llama_set_warmup(llama_context * ctx, bool warmup) { + ctx->set_warmup(warmup); +} + void llama_synchronize(llama_context * ctx) { ctx->synchronize(); } diff --git a/src/llama-context.h b/src/llama-context.h index 88df8950e..04facb544 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -64,6 +64,7 @@ struct llama_context { void set_embeddings (bool value); void set_causal_attn(bool value); + void set_warmup(bool value); void set_adapter_lora( llama_adapter_lora * adapter, diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 252012f3d..30e550f02 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -29,6 +29,7 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool no_perf; + bool warmup; enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e4af50778..4e9087339 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -577,7 +577,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : n_embd_head_v (hparams.n_embd_head_v), n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), - n_expert_used (hparams.n_expert_used), + n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used), freq_base (cparams.rope_freq_base), freq_scale (cparams.rope_freq_scale), ext_factor (cparams.yarn_ext_factor),