llama : fix FA when KV cache is not used (i.e. embeddings) (#12825)

* ggml : FA supports F32 V

* graph : cast KV to F16 when the KV cache is not used

ggml-ci

* server : add test that exercises embeddings with FA enabled

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-04-08 19:54:51 +03:00
committed by GitHub
parent 78a1ba0a4f
commit a19b5cef16
6 changed files with 59 additions and 6 deletions

View File

@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
v = ggml_transpose(ctx0, v);
}
// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
if (k->type == GGML_TYPE_F32) {
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
}
if (v->type == GGML_TYPE_F32) {
v = ggml_cast(ctx0, v, GGML_TYPE_F16);
}
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);