From c0cfc2f78b17c41103db8dc09c1346783b1783d4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 22 Jun 2025 18:45:52 +0300 Subject: [PATCH] metal : add ggml_set_rows implementation ggml-ci --- ggml/src/ggml-metal/ggml-metal-impl.h | 16 + ggml/src/ggml-metal/ggml-metal.m | 110 +++++- ggml/src/ggml-metal/ggml-metal.metal | 469 ++++++++++++++++---------- 3 files changed, 411 insertions(+), 184 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 17eab976f..260440aed 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -521,6 +521,22 @@ typedef struct { uint64_t nb2; } ggml_metal_kargs_get_rows; +typedef struct { + int32_t nk0; + int32_t ne01; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_set_rows; + typedef struct { int64_t ne00; int64_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 19f4d59e5..fd138ef0e 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -202,6 +202,15 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, + GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, + GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, + GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, + GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, + GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, GGML_METAL_KERNEL_TYPE_RMS_NORM, GGML_METAL_KERNEL_TYPE_L2_NORM, GGML_METAL_KERNEL_TYPE_GROUP_NORM, @@ -1166,6 +1175,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction); @@ -1630,7 +1648,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex if (!use_bfloat) { for (size_t i = 0, n = 3; i < n; ++i) { - if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { + if (op->src[i] != NULL && (op->src[i]->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_BF16)) { return false; } } @@ -1798,6 +1816,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex { return op->ne[3] == 1; } + case GGML_OP_SET_ROWS: + { + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + }; + } default: return false; } @@ -3757,13 +3796,74 @@ static bool ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; + case GGML_OP_SET_ROWS: + { + id pipeline = nil; + + switch (dst->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break; + default: GGML_ABORT("not implemented"); + } + + const int32_t nk0 = ne0/ggml_blck_size(dst->type); + + int nth = 32; // SIMD width + + while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + int nrptg = 1; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) { + nrptg--; + } + } + + nth = MIN(nth, nk0); + + ggml_metal_kargs_set_rows args = { + /*.nk0 =*/ nk0, + /*.ne01 =*/ ne01, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; + } break; case GGML_OP_RMS_NORM: { GGML_ASSERT(ne00 % 4 == 0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3da19879b..d2dec8d29 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = { -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f }; +static inline int best_index_int8(int n, constant float * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } } +void quantize_q4_0(device const float * src, device block_q4_0 & dst) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_0/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + dst.qs[j] = xi0; + dst.qs[j] |= xi1 << 4; + } +} + +void quantize_q4_1(device const float * src, device block_q4_1 & dst) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < QK4_1; j++) { + const float v = src[j]; + if (min > v) min = v; + if (max < v) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + dst.m = min; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK4_1/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + dst.qs[j] = xi0; + dst.qs[j] |= xi1 << 4; + } +} + +void quantize_q5_0(device const float * src, device block_q5_0 & dst) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK5_0; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK5_0/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + + for (int j = 0; j < 4; ++j) { + dst.qh[j] = qh8[j]; + } +} + +void quantize_q5_1(device const float * src, device block_q5_1 & dst) { + float max = src[0]; + float min = src[0]; + + for (int j = 1; j < QK5_1; j++) { + const float v = src[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + dst.m = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (src[0 + j] - min)*id; + const float x1 = (src[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + + thread const uint8_t * qh8 = (thread const uint8_t *)&qh; + + for (int j = 0; j < 4; ++j) { + dst.qh[j] = qh8[j]; + } +} + +void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < QK4_NL; j++) { + const float v = src[j]; + if (amax < fabs(v)) { + amax = fabs(v); + max = v; + } + } + + const float d = max / kvalues_iq4nl_f[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = src[0 + j]*id; + const float x1 = src[QK4_NL/2 + j]*id; + + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); + + dst.qs[j] = xi0 | (xi1 << 4); + + const float v0 = kvalues_iq4nl_f[xi0]; + const float v1 = kvalues_iq4nl_f[xi1]; + const float w0 = src[0 + j]*src[0 + j]; + const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; + sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + + } + + dst.d = sumq2 > 0 ? sumqx/sumq2 : d; +} + template void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 2); @@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re } } +void quantize_q8_0(device const float * src, device block_q8_0 & dst) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = src[j]; + amax = MAX(amax, fabs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + dst.d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = src[j]*id; + + dst.qs[j] = round(x0); + } +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -4341,6 +4539,7 @@ template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; #endif +// TODO: templetify these kernels kernel void kernel_cpy_f32_q8_0( constant ggml_metal_kargs_cpy & args, device const char * src0, @@ -4364,23 +4563,7 @@ kernel void kernel_cpy_f32_q8_0( for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = src[j]; - amax = MAX(amax, fabs(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK8_0].d = d; - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = src[j]*id; - - dst_data[i00/QK8_0].qs[j] = round(x0); - } + quantize_q8_0(src, dst_data[i00/QK8_0]); } } @@ -4407,32 +4590,7 @@ kernel void kernel_cpy_f32_q4_0( for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_0].d = d; - - for (int j = 0; j < QK4_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_0/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - dst_data[i00/QK4_0].qs[j] = xi0; - dst_data[i00/QK4_0].qs[j] |= xi1 << 4; - } + quantize_q4_0(src, dst_data[i00/QK4_0]); } } @@ -4459,31 +4617,7 @@ kernel void kernel_cpy_f32_q4_1( for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < QK4_1; j++) { - const float v = src[j]; - if (min > v) min = v; - if (max < v) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK4_1].d = d; - dst_data[i00/QK4_1].m = min; - - for (int j = 0; j < QK4_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK4_1/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - dst_data[i00/QK4_1].qs[j] = xi0; - dst_data[i00/QK4_1].qs[j] |= xi1 << 4; - } + quantize_q4_1(src, dst_data[i00/QK4_1]); } } @@ -4510,38 +4644,7 @@ kernel void kernel_cpy_f32_q5_0( for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK5_0; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_0].d = d; - - uint32_t qh = 0; - for (int j = 0; j < QK5_0/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK5_0/2 + j]*id; - - const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - - dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_0].qh[j] = qh8[j]; - } + quantize_q5_0(src, dst_data[i00/QK5_0]); } } @@ -4568,51 +4671,10 @@ kernel void kernel_cpy_f32_q5_1( for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float max = src[0]; - float min = src[0]; - - for (int j = 1; j < QK5_1; j++) { - const float v = src[j]; - min = v < min ? v : min; - max = v > max ? v : max; - } - - const float d = (max - min) / 31; - const float id = d ? 1.0f/d : 0.0f; - - dst_data[i00/QK5_1].d = d; - dst_data[i00/QK5_1].m = min; - - uint32_t qh = 0; - for (int j = 0; j < QK5_1/2; ++j) { - const float x0 = (src[0 + j] - min)*id; - const float x1 = (src[QK5_1/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); - qh |= ((xi0 & 0x10u) >> 4) << (j + 0); - qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); - } - thread const uint8_t * qh8 = (thread const uint8_t *)&qh; - for (int j = 0; j < 4; ++j) { - dst_data[i00/QK5_1].qh[j] = qh8[j]; - } + quantize_q5_1(src, dst_data[i00/QK5_1]); } } -static inline int best_index_int8(int n, constant float * val, float x) { - if (x <= val[0]) return 0; - if (x >= val[n-1]) return n-1; - int ml = 0, mu = n-1; - while (mu-ml > 1) { - int mav = (ml+mu)/2; - if (x < val[mav]) mu = mav; else ml = mav; - } - return x - val[mu-1] < val[mu] - x ? mu-1 : mu; -} - kernel void kernel_cpy_f32_iq4_nl( constant ggml_metal_kargs_cpy & args, device const char * src0, @@ -4636,40 +4698,7 @@ kernel void kernel_cpy_f32_iq4_nl( for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < QK4_NL; j++) { - const float v = src[j]; - if (amax < fabs(v)) { - amax = fabs(v); - max = v; - } - } - - const float d = max / kvalues_iq4nl_f[0]; - const float id = d ? 1.0f/d : 0.0f; - - float sumqx = 0, sumq2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - const float x0 = src[0 + j]*id; - const float x1 = src[QK4_NL/2 + j]*id; - - const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); - const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); - - dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4); - - const float v0 = kvalues_iq4nl_f[xi0]; - const float v1 = kvalues_iq4nl_f[xi1]; - const float w0 = src[0 + j]*src[0 + j]; - const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; - sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; - sumq2 += w0*v0*v0 + w1*v1*v1; - - } - - dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d; + quantize_iq4_nl(src, dst_data[i00/QK4_NL]); } } @@ -6350,10 +6379,10 @@ kernel void kernel_mul_mv_iq4_xs_f32( template kernel void kernel_get_rows_q( + constant ggml_metal_kargs_get_rows & args, device const void * src0, device const void * src1, device float * dst, - constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { @@ -6373,10 +6402,10 @@ kernel void kernel_get_rows_q( template kernel void kernel_get_rows_f( + constant ggml_metal_kargs_get_rows & args, device const void * src0, device const void * src1, device float * dst, - constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { @@ -6394,10 +6423,10 @@ kernel void kernel_get_rows_f( } kernel void kernel_get_rows_i32( + constant ggml_metal_kargs_get_rows & args, device const void * src0, device const void * src1, device int32_t * dst, - constant ggml_metal_kargs_get_rows & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint3 tptg [[threads_per_threadgroup]]) { @@ -6414,6 +6443,67 @@ kernel void kernel_get_rows_i32( } } +template +kernel void kernel_set_rows_q32( + constant ggml_metal_kargs_set_rows & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + + const int32_t i12 = i03%args.ne12; + const int32_t i11 = i02%args.ne11; + + const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x; + if (i01 >= args.ne01) { + return; + } + + const int32_t i10 = i01; + const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + + device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + + for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { + quantize_func(src_row + 32*ind, dst_row[ind]); + } +} + +template +kernel void kernel_set_rows_f( + constant ggml_metal_kargs_set_rows & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg [[threads_per_threadgroup]]) { + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + + const int32_t i12 = i03%args.ne12; + const int32_t i11 = i02%args.ne11; + + const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x; + if (i01 >= args.ne01) { + return; + } + + const int32_t i10 = i01; + const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; + + device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); + const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); + + for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { + dst_row[ind] = (T) src_row[ind]; + } +} #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B @@ -6837,6 +6927,27 @@ template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q; +// +// set rows +// + +typedef decltype(kernel_set_rows_f) set_rows_f_t; + +template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f; +template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f; +#endif + +typedef decltype(kernel_set_rows_q32) set_rows_q32_t; + +template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32; +template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32; + // // matrix-matrix multiplication //