llama : fix FA when KV cache is not used (i.e. embeddings) (#12825)

* ggml : FA supports F32 V

* graph : cast KV to F16 when the KV cache is not used

ggml-ci

* server : add test that exercises embeddings with FA enabled

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-04-08 19:54:51 +03:00
committed by GitHub
parent 78a1ba0a4f
commit a19b5cef16
6 changed files with 59 additions and 6 deletions

View File

@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
GGML_ASSERT(v_to_float && "fattn: unsupported V-type");
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) {
@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
vs = expf(s - M);
}
v_to_float(v_data, V32, DV);
// V += v*expf(s - M)
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
if (v_to_float) {
v_to_float(v_data, V32, DV);
ggml_vec_mad_f32(DV, VKQ32, V32, vs);
} else {
// V is F32
ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
}
}
S = S*ms + vs; // scale and increment sum with partial sum