mirror of https://github.com/ggml-org/llama.cpp.git
metal : improve FA + improve MoE (#12612)
* ggml : FA with different K, V head sizes (CPU) ggml-ci
* metal : add FA with HS=192
* metal : extend FA to support different K and V head sizes ggml-ci
* metal : add FA vector kernels for heads K 192 and V 128 ggml-ci
* ggml : restrict op on other backends to equal head sizes ggml-ci
* metal : optimize FA-vec kernel ggml-ci
* metal : FA remove mq registers
* metal : improve MoE mul_mat_id condition ggml-ci
* metal : fix comments + remove unnecessary addition ggml-ci
* metal : avoid too much shared memory usage with mul_mat_id ggml-ci
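The CPU part of the change ("ggml : FA with different K, V head sizes") boils down to using two head sizes: the Q.K dot products run over DK elements, while the V accumulation, normalization, and output use DV elements. Below is a minimal float-only sketch of that online-softmax loop; the function name and the flat row layout are illustrative assumptions, not the actual ggml code.

// Minimal float-only sketch of online-softmax flash attention with distinct
// K and V head sizes (DK vs DV). Illustrative only, not the ggml symbols.
#include <math.h>
#include <string.h>

// q: DK floats, k: n_kv rows of DK floats, v: n_kv rows of DV floats, out: DV floats
static void attn_row(const float * q, const float * k, const float * v,
                     float * out, int n_kv, int DK, int DV, float scale) {
    float S = 0.0f;      // running sum of expf(s - M)
    float M = -INFINITY; // running maximum of the scaled KQ values

    memset(out, 0, DV*sizeof(float)); // VKQ accumulator has DV elements

    for (int ic = 0; ic < n_kv; ++ic) {
        float s = 0.0f;
        for (int d = 0; d < DK; ++d) { // dot product over DK elements
            s += q[d]*k[ic*DK + d];
        }
        s *= scale;

        const float Mold = M;
        float ms = 1.0f; // rescale factor for the existing accumulator
        float vs = 1.0f; // weight of the current V row

        if (s > M) {
            M  = s;
            ms = expf(Mold - M); // s is the new maximum, vs == 1.0f
            for (int d = 0; d < DV; ++d) { // rescale the DV-sized accumulator
                out[d] *= ms;
            }
        } else {
            vs = expf(s - M);    // no new maximum, ms == 1.0f
        }

        for (int d = 0; d < DV; ++d) { // accumulate the DV-sized V row
            out[d] += v[ic*DV + d]*vs;
        }

        S = S*ms + vs;
    }

    for (int d = 0; d < DV; ++d) { // final normalization: V /= S
        out[d] /= S;
    }
}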
@@ -12238,10 +12238,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int64_t D = neq0;
-    const int64_t N = neq1;
+    const int64_t DK = nek0;
+    const int64_t DV = nev0;
+    const int64_t N  = neq1;
 
-    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne0 == DV);
     GGML_ASSERT(ne2 == N);
 
     // input tensor rows must be contiguous
@@ -12249,12 +12250,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(nbk0 == ggml_type_size(k->type));
     GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
-    GGML_ASSERT(neq0 == D);
-    GGML_ASSERT(nek0 == D);
-    GGML_ASSERT(nev0 == D);
+    GGML_ASSERT(neq0 == DK);
+    GGML_ASSERT(nek0 == DK);
+    GGML_ASSERT(nev0 == DV);
 
     GGML_ASSERT(neq1 == N);
-    GGML_ASSERT(nev0 == D);
 
     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
@@ -12320,15 +12320,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         float S = 0.0f;      // sum
         float M = -INFINITY; // maximum KQ value
 
-        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
+        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
 
         if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
         } else {
-            memset(VKQ32, 0, D*sizeof(float));
+            memset(VKQ32, 0, DV*sizeof(float));
         }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
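With split head sizes, the per-thread scratch slice shrinks from 3*D floats to 1*DK + 2*DV floats. A small standalone sketch of how the pointers above carve up that slice (stand-in types and names, cache-line padding omitted):

// Sketch of the per-thread scratch layout implied by the pointer math above.
// fp16_t and the function name are assumptions for illustration.
#include <stdint.h>

typedef uint16_t fp16_t; // stand-in for ggml_fp16_t

static void scratch_layout(float * wdata, int ith, int64_t DK, int64_t DV,
                           float ** VKQ32, float ** V32, fp16_t ** VKQ16, fp16_t ** Q_q) {
    float * base = wdata + ith*(1*DK + 2*DV); // per-thread slice: DK + 2*DV floats
    *VKQ32 = base;                    // [0, DV)         : FP32 VKQ accumulator
    *V32   = base + 1*DV;             // [DV, 2*DV)      : FP32 V row buffer
    *VKQ16 = (fp16_t *)(base + 1*DV); // aliases V32     : FP16 accumulator (F16 V path)
    *Q_q   = (fp16_t *)(base + 2*DV); // [2*DV, 2*DV+DK) : Q converted to the vec_dot type
}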
@@ -12342,7 +12342,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv2 = iq2 / rv2;
 
         const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, D);
+        q_to_vec_dot(pq, Q_q, DK);
 
         // online softmax / attention
         // loop over n_kv and n_head_kv
@@ -12356,7 +12356,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             float s; // KQ value
 
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
+            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
 
             s = s*scale; // scale KQ value
 
@@ -12380,14 +12380,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(D, VKQ16, ms);
+                    ggml_vec_scale_f16(DV, VKQ16, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
+                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
@@ -12395,30 +12395,30 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                     ms = expf(Mold - M);
 
                     // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(D, VKQ32, ms);
+                    ggml_vec_scale_f32(DV, VKQ32, ms);
                 } else {
                     // no new maximum, ms == 1.0f, vs != 1.0f
                     vs = expf(s - M);
                 }
 
-                v_to_float(v_data, V32, D);
+                v_to_float(v_data, V32, DV);
 
                 // V += v*expf(s - M)
-                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+                ggml_vec_mad_f32(DV, VKQ32, V32, vs);
             }
 
             S = S*ms + vs; // scale and increment sum with partial sum
         }
 
         if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < D; ++d) {
+            for (int64_t d = 0; d < DV; ++d) {
                 VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
            }
        }
 
         // V /= S
         const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(D, VKQ32, S_inv);
+        ggml_vec_scale_f32(DV, VKQ32, S_inv);
 
         // dst indices
         const int i1 = iq1;
@@ -15277,7 +15277,6 @@ struct ggml_cplan ggml_graph_plan(
        size_t cur = 0;
 
        if (!ggml_cpu_extra_work_size(n_threads, node, &cur)) {
-
            switch (node->op) {
                case GGML_OP_CPY:
                case GGML_OP_DUP:
@@ -15386,9 +15385,10 @@ struct ggml_cplan ggml_graph_plan(
                    } break;
                case GGML_OP_FLASH_ATTN_EXT:
                    {
-                        const int64_t ne00 = node->src[0]->ne[0]; // D
+                        const int64_t ne10 = node->src[1]->ne[0]; // DK
+                        const int64_t ne20 = node->src[2]->ne[0]; // DV
 
-                        cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
+                        cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
                    } break;
                case GGML_OP_FLASH_ATTN_BACK:
                    {
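For the head sizes mentioned in the commit message (K head size 192, V head size 128), the new formula reserves 1*192 + 2*128 = 448 floats per thread instead of the 3*192 = 576 that the old equal-head-size formula would have used. A tiny worked example of the computation, with an assumed thread count:

// Worked example of the new FLASH_ATTN_EXT work-size formula.
// The thread count is an assumption for illustration only.
#include <stdio.h>

int main(void) {
    const long ne10 = 192;  // DK: K head size, as in the new Metal vec kernels
    const long ne20 = 128;  // DV: V head size
    const int  n_tasks = 8; // assumed number of threads

    const size_t cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 448 floats per thread
    printf("FLASH_ATTN_EXT work size: %zu bytes\n", cur);       // prints 14336
    return 0;
}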