metal : improve FA + improve MoE (#12612)

* ggml : FA with different K, V head sizes (CPU)

ggml-ci

* metal : add FA with HS=192

* metal : extend FA to support different K and V head sizes

ggml-ci

* metal : add FA vector kernels for heads K 192 and V 128

ggml-ci

* ggml : restrict op on other backends to equal head sizes

ggml-ci

* metal : optimize FA-vec kernel

ggml-ci

* metal : FA remove mq registers

* metal : improve MoE mul_mat_id condition

ggml-ci

* metal : fix comments + remove unnecessary addition

ggml-ci

* metal : avoid too much shared memory usage with mul_mat_id

ggml-ci
Georgi Gerganov
2025-03-28 20:21:59 +02:00
committed by GitHub
parent b86f600723
commit b4ae50810e
10 changed files with 913 additions and 706 deletions
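The CPU-side hunks below generalize GGML_OP_FLASH_ATTN_EXT so the K and V head sizes no longer have to match. As a rough sketch of what that enables at the graph level (not code from this commit; the 192/128 split, tensor types, and the helper name build_fa_example are illustrative assumptions):

// Sketch only: builds a flash-attention node whose K head size differs from its V head size.
#include <math.h>
#include "ggml.h"

static struct ggml_tensor * build_fa_example(struct ggml_context * ctx) {
    const int64_t DK = 192, DV = 128;               // K/V head sizes (illustrative, e.g. HS=192/128)
    const int64_t n_head = 16, n_kv = 4096, n_q = 1;

    struct ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, DK, n_q,  n_head);
    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, DK, n_kv, n_head);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, DV, n_kv, n_head);

    // the result rows take the V head size: ne0 == DV (see the GGML_ASSERT changes below)
    return ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ NULL,
                               1.0f/sqrtf((float) DK), /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);
}

Note that per the commit message, other backends are restricted to equal K/V head sizes for now, so a graph like this is only expected to run on the CPU and Metal backends.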


@@ -12238,10 +12238,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
 
-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == DV);
     GGML_ASSERT(ne2 == N);
 
     // input tensor rows must be contiguous
@@ -12249,12 +12250,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
 
     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float S = 0.0f;      // sum
         float M = -INFINITY; // maximum KQ value
 
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
 
         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
         } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, DV*sizeof(float));
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
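For context, the per-thread scratch set up by the pointers above is laid out as sketched below (offsets in FP32 elements; the aliasing of V32 and VKQ16 is intentional, since only one of the two is used depending on v->type). This is a reading aid derived from the diff, not code from it:

// per-thread scratch layout after this change (sizes in FP32 elements):
//   [0,     DV)          VKQ32 : FP32 accumulator for V*softmax(K^T*Q), length DV
//   [DV,    2*DV)        V32   : FP32 buffer for one dequantized V row (F32 path)
//                        VKQ16 : FP16 accumulator aliased onto the same region (F16 path)
//   [2*DV,  2*DV + DK)   Q_q   : Q converted to the vec_dot type (FP16/quantized), DK elements
// total per thread: 1*DK + 2*DV floats, plus CACHE_LINE_SIZE_F32 of padding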
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
        const int iv2 = iq2 / rv2;
 
        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-       q_to_vec_dot(pq, Q_q, D);
+       q_to_vec_dot(pq, Q_q, DK);
 
        // online softmax / attention
        // loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
            float s; // KQ value
 
            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-           kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+           kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
 
            s = s*scale; // scale KQ value
 
@@ -12380,14 +12380,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                    ms = expf(Mold - M);
 
                    // V = V*expf(Mold - M)
-                   ggml_vec_scale_f16(D, VKQ16, ms);
+                   ggml_vec_scale_f16(DV, VKQ16, ms);
                } else {
                    // no new maximum, ms == 1.0f, vs != 1.0f
                    vs = expf(s - M);
                }
 
                // V += v*expf(s - M)
-               ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+               ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
            } else {
                if (s > M) {
                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
@@ -12395,30 +12395,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                    ms = expf(Mold - M);
 
                    // V = V*expf(Mold - M)
-                   ggml_vec_scale_f32(D, VKQ32, ms);
+                   ggml_vec_scale_f32(DV, VKQ32, ms);
                } else {
                    // no new maximum, ms == 1.0f, vs != 1.0f
                    vs = expf(s - M);
                }
 
-               v_to_float(v_data, V32, D);
+               v_to_float(v_data, V32, DV);
 
                // V += v*expf(s - M)
-               ggml_vec_mad_f32(D, VKQ32, V32, vs);
+               ggml_vec_mad_f32(DV, VKQ32, V32, vs);
            }
 
            S = S*ms + vs; // scale and increment sum with partial sum
        }
 
        if (v->type == GGML_TYPE_F16) {
-           for (int64_t d = 0; d < D; ++d) {
+           for (int64_t d = 0; d < DV; ++d) {
                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
            }
        }
 
        // V /= S
        const float S_inv = 1.0f/S;
-       ggml_vec_scale_f32(D, VKQ32, S_inv);
+       ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
        // dst indices
        const int i1 = iq1;
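The hunk above is the streaming ("online") softmax update; only the vector lengths change from D to DV, the math is untouched. Below is a simplified, self-contained scalar sketch of one update step (F32 path only, no masking); function and variable names are illustrative and this is not code from the commit:

#include <math.h>
#include <stdint.h>

// One online-softmax accumulation step: fold the value row v_row (DV elements)
// with score s into the running accumulator VKQ, sum S and maximum M.
static void fa_online_update(float * VKQ, float * S, float * M,
                             const float * v_row, int64_t DV, float s) {
    const float Mold = *M;

    float ms = 1.0f; // scale applied to the running accumulator when the max moves
    float vs = 1.0f; // softmax weight of the incoming value row, expf(s - M)

    if (s > *M) {
        // new maximum: rescale the existing accumulator by expf(Mold - M)
        *M = s;
        ms = expf(Mold - *M);
        for (int64_t d = 0; d < DV; ++d) {
            VKQ[d] *= ms;
        }
    } else {
        vs = expf(s - *M);
    }

    // VKQ += v*expf(s - M) -- the vector length is DV, the V head size
    for (int64_t d = 0; d < DV; ++d) {
        VKQ[d] += v_row[d]*vs;
    }

    *S = *S*ms + vs; // running softmax denominator
}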
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
        size_t cur = 0;
 
        if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
-
            switch (node->op) {
                case GGML_OP_CPY:
                case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
                    } break;
                case GGML_OP_FLASH_ATTN_EXT:
                    {
-                       const int64_t ne00 = node->src[0]->ne[0]; // D
+                       const int64_t ne10 = node->src[1]->ne[0]; // DK
+                       const int64_t ne20 = node->src[2]->ne[0]; // DV
 
-                       cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                       cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
                    } break;
                case GGML_OP_FLASH_ATTN_BACK:
                    {
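As a quick sanity check of the new work-size formula (numbers purely illustrative): with DK = 192, DV = 128 and n_tasks = 8 threads, the old expression reserved 3*sizeof(float)*192*8 = 18432 bytes, while the new one reserves sizeof(float)*(192 + 2*128)*8 = 14336 bytes. When the K and V head sizes are equal, both expressions reduce to 3*sizeof(float)*D*n_tasks, so planning is unchanged for existing models.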