Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-06-29 20:45:04 +00:00
graph : make FA compatible with MLA + add initial Metal kernels (#12953)
* graph : make mla compatible with FA

* metal : add exp FA kernels for DeepSeek models

ggml-ci

* llama : minor naming updates

ggml-ci

* ggml : disable FA for DS head sizes

* tests : add FA tests for MLA shapes

ggml-ci
@@ -484,7 +484,7 @@ ggml_tensor * llama_context::build_rope_shift(
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;

     ggml_tensor * tmp;
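The DeepSeek2 branch of the renamed factor above is a pure function of freq_scale. A minimal standalone sketch of what that expression evaluates to (the example freq_scale values are illustrative assumptions, not values used by llama.cpp):

/*
 * Minimal sketch: evaluate the DeepSeek2-specific YaRN attention factor
 * from the hunk above, 1 / (1 + 0.1 * ln(1 / freq_scale)), for a few
 * illustrative freq_scale values.
 */
#include <math.h>
#include <stdio.h>

static float yarn_attn_factor_ds2(float freq_scale) {
    return 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
}

int main(void) {
    const float scales[] = { 1.0f, 0.5f, 0.25f, 0.125f };
    for (size_t i = 0; i < sizeof(scales)/sizeof(scales[0]); ++i) {
        // e.g. freq_scale = 0.25f yields roughly 0.878
        printf("freq_scale = %.3f -> attn_factor = %.3f\n",
               scales[i], yarn_attn_factor_ds2(scales[i]));
    }
    return 0;
}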
@@ -504,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(
         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);

         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
     }

     return tmp;
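Both call sites above route through the same ggml_rope_ext_inplace operator. Below is a self-contained sketch of constructing such a node directly against the ggml C API; the tensor shapes and YaRN parameter values are illustrative assumptions rather than values taken from this model:

/*
 * Minimal sketch: building a ggml_rope_ext_inplace node with explicit YaRN
 * parameters, mirroring the patched call sites. Shapes and parameter values
 * are assumptions for illustration only.
 */
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int n_rot = 64, n_head = 8, n_tokens = 4;

    // activations to rotate: [n_rot, n_head, n_tokens]
    struct ggml_tensor * cur   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_rot, n_head, n_tokens);
    // per-token positions (the "shift"), one int32 per token
    struct ggml_tensor * shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);

    // third tensor argument (frequency factors) may be NULL
    struct ggml_tensor * rotated = ggml_rope_ext_inplace(ctx, cur, shift, NULL,
            n_rot, GGML_ROPE_TYPE_NEOX, /*n_ctx_orig*/ 4096,
            /*freq_base*/ 10000.0f, /*freq_scale*/ 1.0f,
            /*ext_factor*/ 0.0f, /*attn_factor*/ 1.0f,
            /*beta_fast*/ 32.0f, /*beta_slow*/ 1.0f);

    // the op only builds a graph node here; it still has to be scheduled
    // through a ggml compute graph before any data is produced
    (void) rotated;

    ggml_free(ctx);
    return 0;
}

The non-inplace variant, ggml_rope_ext, takes the same argument list and is the choice when the source tensor must remain untouched.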
@@ -2278,11 +2278,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }

-    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
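With the DeepSeek2-specific override removed, the only constraint left at this point of initialization is that a quantized V cache still requires flash attention. A small illustrative sketch of the resulting check, using placeholder names rather than the real llama.cpp structures:

/*
 * Illustrative sketch of the check that remains after this hunk.
 * The struct and function names are placeholders, not llama.cpp API.
 */
#include <stdbool.h>
#include <stdio.h>

struct init_params_sketch {
    bool flash_attn;         // whether flash attention was requested
    bool v_cache_quantized;  // stands in for ggml_is_quantized(params.type_v)
};

// returns false where llama_init_from_model would bail out with nullptr
static bool check_params(const struct init_params_sketch * p) {
    if (p->v_cache_quantized && !p->flash_attn) {
        fprintf(stderr, "V cache quantization requires flash_attn\n");
        return false;
    }
    return true;
}

int main(void) {
    struct init_params_sketch p = { /*.flash_attn =*/ true, /*.v_cache_quantized =*/ true };
    printf("params ok: %s\n", check_params(&p) ? "yes" : "no");
    return 0;
}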