metal : add glu kernels

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-13 16:12:25 +03:00
committed by Akarshan
parent a341aa3c2b
commit d9ddeb9dfd
3 changed files with 116 additions and 0 deletions

View File

@ -422,6 +422,12 @@ typedef struct {
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
} ggml_metal_kargs_im2col; } ggml_metal_kargs_im2col;
// kernel arguments shared by the GLU kernels (kernel_reglu, kernel_geglu, kernel_swiglu)
typedef struct{
int32_t ne00; // number of elements in one source row; the kernels split each row into two halves of ne00/2
uint64_t nb01; // stride in bytes between consecutive source rows
uint64_t nb1; // stride in bytes between consecutive destination rows
} ggml_metal_kargs_glu;
typedef struct { typedef struct {
int64_t ne00; int64_t ne00;
int64_t ne01; int64_t ne01;

View File

@ -514,6 +514,9 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_SIN,
GGML_METAL_KERNEL_TYPE_COS, GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG, GGML_METAL_KERNEL_TYPE_NEG,
// GLU kernels (kernel_reglu, kernel_geglu, kernel_swiglu)
GGML_METAL_KERNEL_TYPE_REGLU,
GGML_METAL_KERNEL_TYPE_GEGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU,
GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_MEAN, GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@ -1478,6 +1481,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
// GLU kernels
// NOTE(review): the macro is defined outside this view - presumably the second argument is the
//               kernel name without its "kernel_" prefix and the third marks the kernel as supported; confirm
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
@ -1652,6 +1658,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
default: default:
return false; return false;
} }
case GGML_OP_GLU:
    switch (ggml_get_glu_op(op)) {
        case GGML_GLU_OP_REGLU:
        case GGML_GLU_OP_GEGLU:
        case GGML_GLU_OP_SWIGLU:
            {
                // the Metal GLU kernels read the source rows as float
                const bool src_is_f32 = op->src[0]->type == GGML_TYPE_F32;

                return ggml_is_contiguous_1(op->src[0]) && src_is_f32;
            }
        default:
            // any other GLU variant has no Metal kernel
            return false;
    }
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_RESHAPE: case GGML_OP_RESHAPE:
case GGML_OP_VIEW: case GGML_OP_VIEW:
@ -2370,6 +2385,43 @@ static bool ggml_metal_encode_node(
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
} break; } break;
case GGML_OP_GLU:
    {
        GGML_ASSERT(ggml_is_contiguous_1(src0));

        // select the compute pipeline that implements the requested GLU variant
        id<MTLComputePipelineState> glu_pipeline = nil;

        switch (ggml_get_glu_op(node)) {
            case GGML_GLU_OP_REGLU:
                glu_pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
                break;
            case GGML_GLU_OP_GEGLU:
                glu_pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
                break;
            case GGML_GLU_OP_SWIGLU:
                glu_pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                break;
            default:
                GGML_ABORT("fatal error");
        }

        // launch geometry: one threadgroup per source row, with at most ne00/2 threads
        // (the number of output elements in a row), capped by the pipeline limit
        const int64_t n_rows    = ggml_nrows(src0);
        const int32_t n_threads = MIN((int) glu_pipeline.maxTotalThreadsPerThreadgroup, ne00/2);

        ggml_metal_kargs_glu kargs = {
            /*.ne00 =*/ ne00,
            /*.nb01 =*/ nb01,
            /*.nb1  =*/ nb1,
        };

        // bind src0, dst and the argument struct in the order the kernels declare them
        [encoder setComputePipelineState:glu_pipeline];
        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
        [encoder setBytes:&kargs length:sizeof(kargs) atIndex:2];

        [encoder dispatchThreadgroups:MTLSizeMake(n_rows, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
    } break;
case GGML_OP_SQR: case GGML_OP_SQR:
{ {
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));

View File

@ -993,6 +993,64 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig]; dst[tpig] = -src0[tpig];
} }
// ReGLU: dst[i] = x0 > 0 ? x0*x1 : 0, with x0 = src[i] and x1 = src[i + ne00/2].
//
// src0  - source buffer; each row holds ne00 floats (first half x0, second half x1)
// dst   - destination buffer; ne00/2 floats are written per row
// args  - row length (ne00) and the byte strides of the source (nb01) and destination (nb1) rows
// tgpig - threadgroup index, used as the row index
// tpitg - index of the thread inside its threadgroup
// ntg   - number of threads in the threadgroup; used as the loop step so the threads
//         of a group share the ne00/2 output elements of the row between them
kernel void kernel_reglu(
        device const char * src0,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * src_row = (device const float *) (src0 + tgpig*args.nb01);
    device       float * dst_row = (device       float *) (dst  + tgpig*args.nb1);

    // number of output elements per row == offset of the second half of the source row
    const int nc = args.ne00/2;

    for (int i00 = tpitg; i00 < nc; i00 += ntg) {
        const float x0 = src_row[i00];
        const float x1 = src_row[i00 + nc];

        // select instead of multiplying by the comparison result: with a 0/1 mask a
        // gated-off element (x0 <= 0) becomes NaN whenever x1 is inf or NaN
        dst_row[i00] = x0 > 0.0f ? x0*x1 : 0.0f;
    }
}
// GEGLU: dst[i] = gelu(x0)*x1, with x0 = src[i], x1 = src[i + ne00/2] and gelu evaluated
// through the tanh-based expression below.
//
// tgpig selects the row (via the byte strides args.nb01 / args.nb1); the threads of the
// threadgroup step through the ne00/2 output elements of that row with stride ntg.
kernel void kernel_geglu(
        device const char * src0,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // width of each half of the source row (and of the destination row)
    const int n_half = args.ne00/2;

    device const float * x0_row  = (device const float *) (src0 + tgpig*args.nb01);
    device const float * x1_row  = x0_row + n_half;
    device       float * out_row = (device       float *) (dst + tgpig*args.nb1);

    for (int col = tpitg; col < n_half; col += ntg) {
        const float x0 = x0_row[col];
        const float x1 = x1_row[col];

        const float t    = precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0));
        const float gelu = 0.5f*x0*(1.0f + t);

        out_row[col] = gelu*x1;
    }
}
// SwiGLU: dst[i] = silu(x0)*x1, with x0 = src[i], x1 = src[i + ne00/2] and
// silu(x) = x / (1 + exp(-x)).
//
// tgpig selects the row (via the byte strides args.nb01 / args.nb1); the threads of the
// threadgroup step through the ne00/2 output elements of that row with stride ntg.
kernel void kernel_swiglu(
        device const char * src0,
        device       char * dst,
        constant ggml_metal_kargs_glu & args,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    // width of each half of the source row (and of the destination row)
    const int n_half = args.ne00/2;

    device const float * x0_row  = (device const float *) (src0 + tgpig*args.nb01);
    device const float * x1_row  = x0_row + n_half;
    device       float * out_row = (device       float *) (dst + tgpig*args.nb1);

    for (int col = tpitg; col < n_half; col += ntg) {
        const float x0 = x0_row[col];
        const float x1 = x1_row[col];

        const float silu = x0 / (1.0f + exp(-x0));

        out_row[col] = silu*x1;
    }
}
template <bool norm> template <bool norm>
kernel void kernel_sum_rows( kernel void kernel_sum_rows(
constant ggml_metal_kargs_sum_rows & args, constant ggml_metal_kargs_sum_rows & args,