diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 9436e6ea9..2450100b4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5763,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_long nb01 = src0->nb[1]; + const cl_long nb02 = src0->nb[2]; + const cl_long nb03 = src0->nb[3]; + + const int ne12 = src1 ? src1->ne[2] : 0; + const int ne13 = src1 ? src1->ne[3] : 0; + + const cl_long nb11 = src1 ? src1->nb[1] : 0; + const cl_long nb12 = src1 ? src1->nb[2] : 0; + const cl_long nb13 = src1 ? src1->nb[3] : 0; + + const cl_long nb1 = dst->nb[1]; + const cl_long nb2 = dst->nb[2]; + const cl_long nb3 = dst->nb[3]; float scale, max_bias; memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float)); - const int nrows_x = ggml_nrows(src0); - const int nrows_y = src0->ne[1]; - - const int n_head = nrows_x/nrows_y; + const int n_head = src0->ne[2]; const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -5820,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2)); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl index 62c05369a..a6d8ede67 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -22,32 +22,45 @@ REQD_SUBGROUP_SIZE_64 #endif kernel void kernel_soft_max_4_f16( - global float * src0, + global char * src0, ulong offset0, - global half * src1, + global char * src1, ulong offset1, - global float * dst, + global char * dst, ulong offsetd, int ne00, - int ne01, - int ne02, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, float scale, float max_bias, float m0, float m1, int n_head_log2 ) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; int i03 = get_group_id(2); int i02 = get_group_id(1); int i01 = get_group_id(0); - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + int i13 = i03%ne13; + int i12 = i02%ne12; + int i11 = i01; + + global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl index d562774ea..35b5573b4 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -22,32 +22,45 @@ REQD_SUBGROUP_SIZE_64 #endif kernel void kernel_soft_max_4( - global float * src0, + global char * src0, ulong offset0, - global float * src1, + global char * src1, ulong offset1, - global float * dst, + global char * dst, ulong offsetd, int ne00, - int ne01, - int ne02, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, float scale, float max_bias, float m0, float m1, int n_head_log2 ) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; int i03 = get_group_id(2); int i02 = get_group_id(1); int i01 = get_group_id(0); - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + int i13 = i03%ne13; + int i12 = i02%ne12; + int i11 = i01; + + global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl index d38d09967..9d292b574 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f16.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -22,32 +22,45 @@ REQD_SUBGROUP_SIZE_64 #endif kernel void kernel_soft_max_f16( - global float * src0, + global char * src0, ulong offset0, - global half * src1, + global char * src1, ulong offset1, - global float * dst, + global char * dst, ulong offsetd, int ne00, - int ne01, - int ne02, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, float scale, float max_bias, float m0, float m1, int n_head_log2 ) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; int i03 = get_group_id(2); int i02 = get_group_id(1); int i01 = get_group_id(0); - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + int i13 = i03%ne13; + int i12 = i02%ne12; + int i11 = i01; + + global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f; diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl index 001b587ab..7c53dfbe5 100644 --- a/ggml/src/ggml-opencl/kernels/softmax_f32.cl +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -22,32 +22,45 @@ REQD_SUBGROUP_SIZE_64 #endif kernel void kernel_soft_max( - global float * src0, + global char * src0, ulong offset0, - global float * src1, + global char * src1, ulong offset1, - global float * dst, + global char * dst, ulong offsetd, int ne00, - int ne01, - int ne02, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, float scale, float max_bias, float m0, float m1, int n_head_log2 ) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; int i03 = get_group_id(2); int i02 = get_group_id(1); int i01 = get_group_id(0); - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + int i13 = i03%ne13; + int i12 = i02%ne12; + int i11 = i01; + + global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03); + global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0; + global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3); float slope = 1.0f;