From d9ddeb9dfd6fc070bfac019cce879ff67056aea3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 13 Jun 2025 16:12:25 +0300
Subject: [PATCH] metal : add glu kernels

ggml-ci
---
 ggml/src/ggml-metal/ggml-metal-impl.h |  6 +++
 ggml/src/ggml-metal/ggml-metal.m      | 52 ++++++++++++++++++++++++
 ggml/src/ggml-metal/ggml-metal.metal  | 58 +++++++++++++++++++++++++++
 3 files changed, 116 insertions(+)

diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 17eab976f..ec9069c52 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -422,6 +422,12 @@ typedef struct {
     int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } ggml_metal_kargs_im2col;
 
+typedef struct { // arguments shared by the reglu / geglu / swiglu kernels
+    int32_t  ne00; // full src0 row width in elements; the kernels emit ne00/2 values per row
+    uint64_t nb01; // byte stride between consecutive src0 rows
+    uint64_t nb1;  // byte stride between consecutive dst rows
+} ggml_metal_kargs_glu;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 19f4d59e5..cd1ff2844 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -514,6 +514,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
+    GGML_METAL_KERNEL_TYPE_REGLU,  // gated linear unit kernels, one pipeline per variant
+    GGML_METAL_KERNEL_TYPE_GEGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1478,6 +1481,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); @@ -1652,6 +1658,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex default: return false; } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + default: + return false; + } case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -2370,6 +2385,43 @@ static bool ggml_metal_encode_node( GGML_ABORT("fatal error"); } } break; + case GGML_OP_GLU: + { + GGML_ASSERT(ggml_is_contiguous_1(src0)); + + id pipeline = nil; + + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_REGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline; + break; + case GGML_GLU_OP_GEGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline; + break; + case GGML_GLU_OP_SWIGLU: + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline; + break; + default: + GGML_ABORT("fatal error"); + } + + ggml_metal_kargs_glu args = { + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb1 =*/ nb1, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:2]; + + const int64_t nrows = ggml_nrows(src0); + + const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2); + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case GGML_OP_SQR: { GGML_ASSERT(ggml_is_contiguous(src0)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3da19879b..4154e5054 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -993,6 +993,64 @@ kernel void 
kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_reglu( // ReGLU: dst[i] = max(x0, 0) * x1, with x0/x1 taken from the first/second half of a src0 row
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]], // selects the row handled by this threadgroup
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+    // the threads of the group step through the ne00/2 output columns, ntg apart
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+        // select rather than multiply by the mask: (x0*x1)*0 is NaN whenever x0*x1 is not finite
+        dst_row[i00] = x0 > 0.0f ? x0*x1 : 0.0f;
+    }
+}
+
+kernel void kernel_geglu( // GEGLU: dst[i] = gelu(x0) * x1, with x0/x1 taken from the first/second half of a src0 row
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]], // selects the row handled by this threadgroup
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+    // the threads of the group step through the ne00/2 output columns, ntg apart
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+        // tanh-based GELU formula applied to the gate value x0
+        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+        dst_row[i00] = gelu*x1;
+    }
+}
+
+kernel void kernel_swiglu( // SwiGLU: dst[i] = silu(x0) * x1, with x0/x1 taken from the first/second half of a src0 row
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]], // selects the row handled by this threadgroup
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+    // the threads of the group step through the ne00/2 output columns, ntg apart
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+        // silu(x0) = x0 * sigmoid(x0)
+        const float silu = x0 / (1.0f + exp(-x0));
+
+        dst_row[i00] = silu*x1;
+    }
+}
+
 template
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,