metal : extend ggml_soft_max_ext() to support n_seq dim

This commit is contained in:
Georgi Gerganov
2025-06-24 20:00:40 +03:00
parent 401c13e3c3
commit 7c6487b22f
3 changed files with 7 additions and 6 deletions

View File

@@ -454,6 +454,8 @@ typedef struct {
int64_t ne00; int64_t ne00;
int64_t ne01; int64_t ne01;
int64_t ne02; int64_t ne02;
uint64_t nb11;
uint64_t nb12;
float scale; float scale;
float max_bias; float max_bias;
float m0; float m0;

View File

@@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node(
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale)); memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias)); memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
const int64_t nrows_x = ggml_nrows(src0); const uint32_t n_head = src0->ne[2];
const int64_t nrows_y = src0->ne[1];
const uint32_t n_head = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node(
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
/*.ne01 =*/ ne01, /*.ne01 =*/ ne01,
/*.ne02 =*/ ne02, /*.ne02 =*/ ne02,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.scale =*/ scale, /*.scale =*/ scale,
/*.max_bias =*/ max_bias, /*.max_bias =*/ max_bias,
/*.m0 =*/ m0, /*.m0 =*/ m0,

View File

@@ -1263,7 +1263,7 @@ kernel void kernel_soft_max(
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr; device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00); device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
float slope = 1.0f; float slope = 1.0f;
@@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4(
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01); const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr; device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4; device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
float slope = 1.0f; float slope = 1.0f;