mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
Load all MoE experts during warmup (#11571)
* llama : introduce llama_set_warmup() API call that controls warmup mode; use all MoE experts during warmup * common : use new API to enable warmup mode during model warmup --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
@ -1033,6 +1033,8 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
if (params.warmup) {
|
if (params.warmup) {
|
||||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||||
|
|
||||||
|
llama_set_warmup(lctx, true);
|
||||||
|
|
||||||
std::vector<llama_token> tmp;
|
std::vector<llama_token> tmp;
|
||||||
llama_token bos = llama_vocab_bos(vocab);
|
llama_token bos = llama_vocab_bos(vocab);
|
||||||
llama_token eos = llama_vocab_eos(vocab);
|
llama_token eos = llama_vocab_eos(vocab);
|
||||||
@ -1063,6 +1065,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||||||
llama_kv_self_clear(lctx);
|
llama_kv_self_clear(lctx);
|
||||||
llama_synchronize(lctx);
|
llama_synchronize(lctx);
|
||||||
llama_perf_context_reset(lctx);
|
llama_perf_context_reset(lctx);
|
||||||
|
llama_set_warmup(lctx, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
iparams.model.reset(model);
|
iparams.model.reset(model);
|
||||||
|
@ -945,6 +945,10 @@ extern "C" {
|
|||||||
// If set to true, the model will only attend to the past tokens
|
// If set to true, the model will only attend to the past tokens
|
||||||
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
|
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
|
||||||
|
|
||||||
|
// Set whether the model is in warmup mode or not
|
||||||
|
// If true, all model tensors are activated during llama_decode() to load and cache their weights.
|
||||||
|
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
|
||||||
|
|
||||||
// Set abort callback
|
// Set abort callback
|
||||||
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
|
||||||
|
|
||||||
|
@ -39,6 +39,7 @@ llama_context::llama_context(
|
|||||||
cparams.flash_attn = params.flash_attn;
|
cparams.flash_attn = params.flash_attn;
|
||||||
cparams.no_perf = params.no_perf;
|
cparams.no_perf = params.no_perf;
|
||||||
cparams.pooling_type = params.pooling_type;
|
cparams.pooling_type = params.pooling_type;
|
||||||
|
cparams.warmup = false;
|
||||||
|
|
||||||
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
|
||||||
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
|
||||||
@ -948,6 +949,12 @@ void llama_context::set_causal_attn(bool value) {
|
|||||||
cparams.causal_attn = value;
|
cparams.causal_attn = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_context::set_warmup(bool value) {
|
||||||
|
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||||
|
|
||||||
|
cparams.warmup = value;
|
||||||
|
}
|
||||||
|
|
||||||
void llama_context::set_adapter_lora(
|
void llama_context::set_adapter_lora(
|
||||||
llama_adapter_lora * adapter,
|
llama_adapter_lora * adapter,
|
||||||
float scale) {
|
float scale) {
|
||||||
@ -1594,7 +1601,7 @@ void llama_context::output_reorder() {
|
|||||||
//
|
//
|
||||||
|
|
||||||
int32_t llama_context::graph_max_nodes() const {
|
int32_t llama_context::graph_max_nodes() const {
|
||||||
return std::max<int32_t>(8192, 5*model.n_tensors());
|
return std::max<int32_t>(65536, 5*model.n_tensors());
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph * llama_context::graph_init() {
|
ggml_cgraph * llama_context::graph_init() {
|
||||||
@ -2372,6 +2379,10 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
|
|||||||
ctx->set_causal_attn(causal_attn);
|
ctx->set_causal_attn(causal_attn);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_set_warmup(llama_context * ctx, bool warmup) {
|
||||||
|
ctx->set_warmup(warmup);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_synchronize(llama_context * ctx) {
|
void llama_synchronize(llama_context * ctx) {
|
||||||
ctx->synchronize();
|
ctx->synchronize();
|
||||||
}
|
}
|
||||||
|
@ -64,6 +64,7 @@ struct llama_context {
|
|||||||
|
|
||||||
void set_embeddings (bool value);
|
void set_embeddings (bool value);
|
||||||
void set_causal_attn(bool value);
|
void set_causal_attn(bool value);
|
||||||
|
void set_warmup(bool value);
|
||||||
|
|
||||||
void set_adapter_lora(
|
void set_adapter_lora(
|
||||||
llama_adapter_lora * adapter,
|
llama_adapter_lora * adapter,
|
||||||
|
@ -29,6 +29,7 @@ struct llama_cparams {
|
|||||||
bool offload_kqv;
|
bool offload_kqv;
|
||||||
bool flash_attn;
|
bool flash_attn;
|
||||||
bool no_perf;
|
bool no_perf;
|
||||||
|
bool warmup;
|
||||||
|
|
||||||
enum llama_pooling_type pooling_type;
|
enum llama_pooling_type pooling_type;
|
||||||
|
|
||||||
|
@ -577,7 +577,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|||||||
n_embd_head_v (hparams.n_embd_head_v),
|
n_embd_head_v (hparams.n_embd_head_v),
|
||||||
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
n_embd_v_gqa (hparams.n_embd_v_gqa()),
|
||||||
n_expert (hparams.n_expert),
|
n_expert (hparams.n_expert),
|
||||||
n_expert_used (hparams.n_expert_used),
|
n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
|
||||||
freq_base (cparams.rope_freq_base),
|
freq_base (cparams.rope_freq_base),
|
||||||
freq_scale (cparams.rope_freq_scale),
|
freq_scale (cparams.rope_freq_scale),
|
||||||
ext_factor (cparams.yarn_ext_factor),
|
ext_factor (cparams.yarn_ext_factor),
|
||||||
|
Reference in New Issue
Block a user