mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-07 01:16:35 -04:00
ggml : add SSM Metal kernels (#8546)
* ggml : add ggml_ssm_conv metal impl * ggml : add ssm_scan metal impl ggml-ci
This commit is contained in:
@@ -82,6 +82,8 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
||||
@@ -542,6 +544,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
||||
@@ -803,6 +807,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||||
return false;
|
||||
}
|
||||
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_SSM_CONV:
|
||||
case GGML_OP_SSM_SCAN:
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return ctx->support_simdgroup_reduction &&
|
||||
@@ -1538,6 +1545,121 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
{
|
||||
GGML_ASSERT(src0t == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
|
||||
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||
[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
|
||||
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
|
||||
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SSM_SCAN:
|
||||
{
|
||||
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
||||
struct ggml_tensor * src4 = gf->nodes[i]->src[4];
|
||||
struct ggml_tensor * src5 = gf->nodes[i]->src[5];
|
||||
|
||||
GGML_ASSERT(src3);
|
||||
GGML_ASSERT(src4);
|
||||
GGML_ASSERT(src5);
|
||||
|
||||
size_t offs_src3 = 0;
|
||||
size_t offs_src4 = 0;
|
||||
size_t offs_src5 = 0;
|
||||
|
||||
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
|
||||
id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
|
||||
id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
|
||||
|
||||
const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
|
||||
const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
|
||||
|
||||
const uint64_t nb30 = src3->nb[0];
|
||||
const uint64_t nb31 = src3->nb[1];
|
||||
|
||||
const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
|
||||
const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
|
||||
const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
|
||||
|
||||
const uint64_t nb40 = src4->nb[0];
|
||||
const uint64_t nb41 = src4->nb[1];
|
||||
const uint64_t nb42 = src4->nb[2];
|
||||
|
||||
const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
|
||||
const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
|
||||
const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
|
||||
|
||||
const uint64_t nb50 = src5->nb[0];
|
||||
const uint64_t nb51 = src5->nb[1];
|
||||
const uint64_t nb52 = src5->nb[2];
|
||||
|
||||
const int64_t d_state = ne00;
|
||||
const int64_t d_inner = ne01;
|
||||
const int64_t n_seq_tokens = ne11;
|
||||
const int64_t n_seqs = ne02;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
||||
[encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
|
||||
[encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
[encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
|
||||
[encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
|
||||
[encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
|
||||
[encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
|
||||
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
||||
[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
|
||||
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
|
||||
[encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
|
||||
[encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
|
||||
[encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
|
||||
[encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
|
||||
[encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
|
||||
[encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
|
||||
[encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
|
||||
[encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
GGML_ASSERT(ne00 == ne10);
|
||||
|
Reference in New Issue
Block a user