diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 445aac017..c27150b9a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -454,6 +454,8 @@ typedef struct { int64_t ne00; int64_t ne01; int64_t ne02; + uint64_t nb11; + uint64_t nb12; float scale; float max_bias; float m0; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 946ebc509..ef70b35e7 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node( memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); - const int64_t nrows_x = ggml_nrows(src0); - const int64_t nrows_y = src0->ne[1]; - - const uint32_t n_head = nrows_x/nrows_y; + const uint32_t n_head = src0->ne[2]; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node( /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, /*.scale =*/ scale, /*.max_bias =*/ max_bias, /*.m0 =*/ m0, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 6942df7cf..a633fb4b1 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1263,7 +1263,7 @@ kernel void kernel_soft_max( const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; + device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr; device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); float slope = 1.0f; @@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4( const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; - device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; + device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr; device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; float slope = 1.0f;