ggml : add GGML_PAD_REFLECT_1D operation (ggml/1034)

* ggml_pad_reflect_1d defined in header

* implemented on CPU

* called the forward pass

* impl Metal kernel

* added Metal kernel

* added OP_PAD_REFLECT_1D in test-backend-ops.cpp

* add test-pad-reflect-1d test case

* test case support multiple backend
This commit is contained in:
PAB
2024-12-03 20:20:04 +01:00
committed by Georgi Gerganov
parent d405804be8
commit c2082d93a8
6 changed files with 192 additions and 2 deletions

View File

@ -310,6 +310,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
GGML_METAL_KERNEL_TYPE_PAD_F32,
GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@ -877,6 +878,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@ -1099,6 +1101,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_POOL_2D:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
@ -3258,6 +3261,38 @@ static void ggml_metal_encode_node(
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_PAD_REFLECT_1D:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
[encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
[encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
const int nth = MIN(1024, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ARANGE:

View File

@ -2897,6 +2897,53 @@ kernel void kernel_pad_f32(
}
}
kernel void kernel_pad_reflect_1d_f32(
device const char * src0,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant int64_t & ne0,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
constant int32_t & p0,
constant int32_t & p1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (i0 < p0) {
dst_ptr[i0] = src0_ptr[p0 - i0];
} else if (i0 < ne0 - p1) {
dst_ptr[i0] = src0_ptr[i0 - p0];
} else {
dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
}
}
}
}
kernel void kernel_arange_f32(
device char * dst,
constant int64_t & ne0,