From a19b5cef16d885c44c635da4a5c97113c1577de8 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 8 Apr 2025 19:54:51 +0300
Subject: [PATCH] llama : fix FA when KV cache is not used (i.e. embeddings)
 (#12825)

* ggml : FA supports F32 V

* graph : cast KV to F16 when the KV cache is not used

ggml-ci

* server : add test that exercises embeddings with FA enabled

ggml-ci
---
 examples/server/tests/unit/test_embedding.py | 20 ++++++++++++++++++++
 examples/server/tests/utils.py               | 15 +++++++++++++++
 examples/server_embd.py                      |  2 +-
 ggml/src/ggml-cpu/ops.cpp                    | 14 +++++++++-----
 ggml/src/ggml-metal/ggml-metal.m             |  5 +++++
 src/llama-graph.cpp                          |  9 +++++++++
 6 files changed, 59 insertions(+), 6 deletions(-)

diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index 8b0eb42b0..0feb452cc 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -49,6 +49,26 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1
 
 
+def test_embedding_multiple_with_fa():
+    server = ServerPreset.bert_bge_small_with_fa()
+    server.pooling = 'last'
+    server.start()
+    # one of these should trigger the FA branch (i.e. context size % 256 == 0)
+    res = server.make_request("POST", "/v1/embeddings", data={
+        "input": [
+            "a "*253,
+            "b "*254,
+            "c "*255,
+            "d "*256,
+        ],
+    })
+    assert res.status_code == 200
+    assert len(res.body['data']) == 4
+    for d in res.body['data']:
+        assert 'embedding' in d
+        assert len(d['embedding']) > 1
+
+
 @pytest.mark.parametrize(
     "input,is_multi_prompt",
     [
diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py
index 30aa86609..4dc2062a8 100644
--- a/examples/server/tests/utils.py
+++ b/examples/server/tests/utils.py
@@ -323,6 +323,21 @@ class ServerPreset:
         server.server_embeddings = True
         return server
 
+    @staticmethod
+    def bert_bge_small_with_fa() -> ServerProcess:
+        server = ServerProcess()
+        server.model_hf_repo = "ggml-org/models"
+        server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
+        server.model_alias = "bert-bge-small"
+        server.n_ctx = 1024
+        server.n_batch = 300
+        server.n_ubatch = 300
+        server.n_slots = 2
+        server.fa = True
+        server.seed = 42
+        server.server_embeddings = True
+        return server
+
     @staticmethod
     def tinyllama_infill() -> ServerProcess:
         server = ServerProcess()
diff --git a/examples/server_embd.py b/examples/server_embd.py
index 0e34c6cea..f8b0ffecd 100644
--- a/examples/server_embd.py
+++ b/examples/server_embd.py
@@ -15,7 +15,7 @@ async def main():
     model_url = "http://127.0.0.1:6900"
     responses: list[requests.Response] = await asyncio.gather(*[requests_post_async(
         url= f"{model_url}/embedding",
-        json= {"content": str(0)*1024}
+        json= {"content": "a "*1022}
     ) for i in range(n)])
 
     for response in responses:
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 7a8d5ac6f..f63656be5 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t   const v_to_float = ggml_get_type_traits(v->type)->to_float;
 
-    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
-    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
+    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
+    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, DV);
-
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                if (v_to_float) {
+                    v_to_float(v_data, V32, DV);
+                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+                } else {
+                    // V is F32
+                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
+                }
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 456e1fd99..f22682602 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -1345,6 +1345,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 32) {
+                // head size == 32 (e.g. bert-bge-small)
+                // TODO: not sure if it is worth adding kernels for this size
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index c3469177e..cd955d63b 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1215,6 +1215,15 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }
 
+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
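
For manual verification outside the pytest harness, a minimal standalone sketch is shown below. It assumes a llama-server built with this change is already running locally with embeddings and flash attention enabled (for example `llama-server -m bert-bge-small/ggml-model-f16.gguf --embeddings -fa --port 8080`; the model path, port, and host are placeholder assumptions, not part of the patch). It sends the same near-multiple-of-256 inputs as test_embedding_multiple_with_fa() to the OpenAI-compatible embeddings endpoint:

    # Standalone sketch mirroring test_embedding_multiple_with_fa().
    # Assumes a local llama-server with --embeddings and -fa listening on port 8080.
    import requests

    url = "http://127.0.0.1:8080/v1/embeddings"  # assumed host/port

    # one of these should land on a context size that is a multiple of 256,
    # exercising the FA path touched by this patch
    inputs = ["a " * 253, "b " * 254, "c " * 255, "d " * 256]

    res = requests.post(url, json={"input": inputs})
    res.raise_for_status()

    data = res.json()["data"]
    assert len(data) == 4
    for d in data:
        assert len(d["embedding"]) > 1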