mirror of https://github.com/ggml-org/llama.cpp.git
metal : add glu kernels
ggml-ci
committed by Akarshan
parent a341aa3c2b
commit d9ddeb9dfd
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -422,6 +422,12 @@ typedef struct {
     int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
 } ggml_metal_kargs_im2col;
 
+typedef struct {
+    int32_t  ne00;
+    uint64_t nb01;
+    uint64_t nb1;
+} ggml_metal_kargs_glu;
+
 typedef struct {
     int64_t ne00;
     int64_t ne01;
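For reference: ne00 is the number of float elements in a source row (the output row holds ne00/2), while nb01 and nb1 are the byte strides between consecutive source and destination rows. That is why the kernels further down index through char pointers before casting to float. A minimal sketch of that addressing in plain C, with hypothetical helper names (not part of the diff):

    #include <stdint.h>

    /* byte-stride row addressing, as the GLU kernels in this commit do it */
    static const float * glu_src_row(const void * src0, uint64_t nb01, int64_t row) {
        return (const float *) ((const char *) src0 + row*nb01);
    }

    static float * glu_dst_row(void * dst, uint64_t nb1, int64_t row) {
        return (float *) ((char *) dst + row*nb1);
    }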
ggml/src/ggml-metal/ggml-metal.m

@@ -514,6 +514,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
+    GGML_METAL_KERNEL_TYPE_REGLU,
+    GGML_METAL_KERNEL_TYPE_GEGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
@@ -1478,6 +1481,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,      sin,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,      cos,      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,      neg,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,    reglu,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,    geglu,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,   swiglu,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,     mean,     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,   argmax,   true);
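Note on the registration lines above: the second macro argument is pasted into the Metal function name, so GGML_METAL_KERNEL_TYPE_REGLU binds to kernel_reglu from the shader hunk below. A minimal sketch of that naming convention using standard C token pasting; this is an assumption about the convention, not the verbatim GGML_METAL_ADD_KERNEL macro from ggml-metal.m:

    #include <stdio.h>

    /* illustrative stand-in for how a kernel suffix becomes a function name */
    #define METAL_KERNEL_NAME(suffix) ("kernel_" #suffix)

    int main(void) {
        printf("%s\n", METAL_KERNEL_NAME(reglu));  /* prints: kernel_reglu  */
        printf("%s\n", METAL_KERNEL_NAME(swiglu)); /* prints: kernel_swiglu */
        return 0;
    }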
@@ -1652,6 +1658,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                    return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                default:
+                    return false;
+            }
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
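So the backend advertises GLU support only for F32 inputs whose rows are contiguous (ggml_is_contiguous_1); anything else falls back to another backend. A hedged usage sketch from the graph-building side, assuming the ggml_swiglu() convenience wrapper exposed alongside the GLU ops in ggml.h (treat the exact name as an assumption if your tree differs):

    #include "ggml.h"

    /* cur: shape [2*n_ff, n_tokens], F32, rows contiguous. The GLU op splits
       dim 0 in half, gates one half with the activated other half, and yields
       [n_ff, n_tokens], eligible for the Metal path after this commit. */
    static struct ggml_tensor * build_ffn_gate(struct ggml_context * ctx,
                                               struct ggml_tensor  * cur) {
        return ggml_swiglu(ctx, cur);
    }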
@@ -2370,6 +2385,43 @@ static bool ggml_metal_encode_node(
                         GGML_ABORT("fatal error");
                 }
             } break;
+        case GGML_OP_GLU:
+            {
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (ggml_get_glu_op(node)) {
+                    case GGML_GLU_OP_REGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
+                        break;
+                    case GGML_GLU_OP_GEGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
+                        break;
+                    case GGML_GLU_OP_SWIGLU:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                ggml_metal_kargs_glu args = {
+                    /*.ne00 =*/ ne00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb1  =*/ nb1,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_SQR:
             {
                 GGML_ASSERT(ggml_is_contiguous(src0));
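Dispatch shape: one threadgroup per row (nrows = ggml_nrows(src0)), each with nth = MIN(maxTotalThreadsPerThreadgroup, ne00/2) threads, and each thread strides over the ne00/2 output elements of its row inside the kernel. For example, with ne00 = 8192 and a 1024-thread limit, nth = min(1024, 4096) = 1024 and each thread writes 4 outputs. A small CPU model of that mapping (illustrative only, not from the diff):

    #include <stdint.h>

    /* models the grid: threadgroup index = row, thread index strides the half-row */
    static void glu_dispatch_model(int64_t nrows, int32_t ne00, int32_t max_threads) {
        const int32_t nth = max_threads < ne00/2 ? max_threads : ne00/2;
        for (int64_t tgpig = 0; tgpig < nrows; tgpig++) {   /* one threadgroup per row */
            for (int32_t tpitg = 0; tpitg < nth; tpitg++) { /* threads in the group    */
                for (int32_t i00 = tpitg; i00 < ne00/2; i00 += nth) {
                    /* kernel body computes dst_row[i00] from src_row[i00]
                       and src_row[i00 + ne00/2] here */
                }
            }
        }
    }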
ggml/src/ggml-metal/ggml-metal.metal

@@ -993,6 +993,64 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }
 
+kernel void kernel_reglu(
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+
+        dst_row[i00] = x0*x1*(x0 > 0.0f);
+    }
+}
+
+kernel void kernel_geglu(
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+
+        const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
+
+        dst_row[i00] = gelu*x1;
+    }
+}
+
+kernel void kernel_swiglu(
+        device const char * src0,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i00 = tpitg; i00 < args.ne00/2; i00 += ntg) {
+        const float x0 = src_row[i00];
+        const float x1 = src_row[i00 + args.ne00/2];
+
+        const float silu = x0 / (1.0f + exp(-x0));
+
+        dst_row[i00] = silu*x1;
+    }
+}
+
 template <bool norm>
 kernel void kernel_sum_rows(
         constant ggml_metal_kargs_sum_rows & args,
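All three kernels implement the GLU (gated linear unit) family: a source row of ne00 floats is split in half, the first half x0 passes through an activation, and the result gates the second half x1 elementwise. A CPU reference of the same math (illustrative only, not from the diff; the constants match ggml's tanh-approximation GELU, with GELU_COEF_A = 0.044715 and SQRT_2_OVER_PI = sqrt(2/pi)):

    #include <math.h>
    #include <stdio.h>

    #define GELU_COEF_A    0.044715f
    #define SQRT_2_OVER_PI 0.79788456080286535587989211986876f

    /* ReGLU: relu(x0)*x1 */
    static void reglu_row(const float * src, float * dst, int ne00) {
        for (int i = 0; i < ne00/2; i++) {
            const float x0 = src[i], x1 = src[i + ne00/2];
            dst[i] = (x0 > 0.0f ? x0 : 0.0f)*x1;
        }
    }

    /* GEGLU: gelu(x0)*x1, tanh approximation */
    static void geglu_row(const float * src, float * dst, int ne00) {
        for (int i = 0; i < ne00/2; i++) {
            const float x0 = src[i], x1 = src[i + ne00/2];
            const float gelu = 0.5f*x0*(1.0f + tanhf(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
            dst[i] = gelu*x1;
        }
    }

    /* SwiGLU: silu(x0)*x1, where silu(x) = x*sigmoid(x) */
    static void swiglu_row(const float * src, float * dst, int ne00) {
        for (int i = 0; i < ne00/2; i++) {
            const float x0 = src[i], x1 = src[i + ne00/2];
            dst[i] = (x0/(1.0f + expf(-x0)))*x1;
        }
    }

    int main(void) {
        const float row[4] = { -1.0f, 2.0f, 3.0f, 4.0f }; /* ne00 = 4: x0 = {-1,2}, x1 = {3,4} */
        float out[2];
        reglu_row(row, out, 4);
        printf("%f %f\n", out[0], out[1]);                /* 0.000000 8.000000 */
        return 0;
    }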