From d67341dc18fc5cc63362880ab2f8f9ecfc7932e7 Mon Sep 17 00:00:00 2001 From: aa956 Date: Thu, 19 Jun 2025 16:01:03 +0300 Subject: [PATCH] server : add server parameters for draft model cache type (#13782) Co-authored-by: aa956 <27946957+aa956@users.noreply.github.com> --- common/arg.cpp | 26 ++++++++++++++++++++++++++ common/common.h | 3 +++ tools/server/README.md | 2 ++ tools/server/server.cpp | 6 ++---- 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 231de227a..3dfaa71ef 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3210,6 +3210,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.model.path = value; } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT")); + add_opt(common_arg( + {"-ctkd", "--cache-type-k-draft"}, "TYPE", + string_format( + "KV cache data type for K for the draft model\n" + "allowed values: %s\n" + "(default: %s)", + get_all_kv_cache_types().c_str(), + ggml_type_name(params.speculative.cache_type_k) + ), + [](common_params & params, const std::string & value) { + params.speculative.cache_type_k = kv_cache_type_from_str(value); + } + ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT")); + add_opt(common_arg( + {"-ctvd", "--cache-type-v-draft"}, "TYPE", + string_format( + "KV cache data type for V for the draft model\n" + "allowed values: %s\n" + "(default: %s)", + get_all_kv_cache_types().c_str(), + ggml_type_name(params.speculative.cache_type_v) + ), + [](common_params & params, const std::string & value) { + params.speculative.cache_type_v = kv_cache_type_from_str(value); + } + ).set_env("LLAMA_ARG_CACHE_TYPE_V_DRAFT")); add_opt(common_arg( {"-mv", "--model-vocoder"}, "FNAME", diff --git a/common/common.h b/common/common.h index 00b6ca03a..5710c4e97 100644 --- a/common/common.h +++ b/common/common.h @@ -199,6 +199,9 @@ struct common_params_speculative { float p_split = 0.1f; // speculative decoding 
split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) + ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K + ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V + struct cpu_params cpuparams; struct cpu_params cpuparams_batch; diff --git a/tools/server/README.md b/tools/server/README.md index 06533c172..43aa65d50 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -187,6 +187,8 @@ The project is under active development, and we are [looking for feedback and co | `-devd, --device-draft ` | comma-separated list of devices to use for offloading the draft model (none = don't offload)
use --list-devices to see a list of available devices | | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_MODEL_DRAFT) | +| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) | +| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for the draft model
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) | | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) | | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall | | `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) | diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 721d09182..9d55b3338 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1969,10 +1969,8 @@ struct server_context { params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx; params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_parallel = 1; - - // force F16 KV cache for the draft model for extra performance - params_dft.cache_type_k = GGML_TYPE_F16; - params_dft.cache_type_v = GGML_TYPE_F16; + params_dft.cache_type_k = params_base.speculative.cache_type_k; + params_dft.cache_type_v = params_base.speculative.cache_type_v; llama_init_dft = common_init_from_params(params_dft);