mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
metal : extend ggml_soft_max_ext() to support n_seq dim
This commit is contained in:
@ -454,6 +454,8 @@ typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
|
@ -2562,10 +2562,7 @@ static bool ggml_metal_encode_node(
|
||||
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
||||
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
||||
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
const uint32_t n_head = nrows_x/nrows_y;
|
||||
const uint32_t n_head = src0->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
@ -2625,6 +2622,8 @@ static bool ggml_metal_encode_node(
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
|
@ -1263,7 +1263,7 @@ kernel void kernel_soft_max(
|
||||
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
||||
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
||||
|
||||
float slope = 1.0f;
|
||||
@ -1359,7 +1359,7 @@ kernel void kernel_soft_max_4(
|
||||
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) (src1 + i01*args.nb11 + i03*args.nb12) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
||||
|
||||
float slope = 1.0f;
|
||||
|
Reference in New Issue
Block a user