CUDA: add softmax broadcast (#14475)
* CUDA: add softmax broadcast
* Pass by const ref
* Review: Use blockDims for indexing, remove designated initializers
* Add TODO for noncontiguous input/output
Committed by: Georgi Gerganov
Parent: 12a81af45f
Commit: 55a1c5a5fd
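Before the diff, a note on what "softmax broadcast" means here: the optional mask (src1) no longer has to match src0 in its two outer dimensions; the kernel picks the mask row for input row (i01, i02, i03) as i11 = i01, i12 = i02 % ne12, i13 = i03 % ne13. A minimal sketch of that rule (mask_row_offset is a hypothetical helper, not part of the commit):

#include <cstdint>

// Sketch only: byte offset of the mask row used for softmax input row (i01, i02, i03),
// mirroring the i11/i12/i13 indexing added to soft_max_f32 below.
static int64_t mask_row_offset(int64_t i01, int64_t i02, int64_t i03,
                               int64_t ne12, int64_t ne13,
                               int64_t nb11, int64_t nb12, int64_t nb13) {
    const int64_t i11 = i01;        // rows are matched one-to-one
    const int64_t i12 = i02 % ne12; // mask reused across heads when ne12 == 1
    const int64_t i13 = i03 % ne13; // mask reused across the batch when ne13 == 1
    return i11*nb11 + i12*nb12 + i13*nb13;
}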
@@ -3329,13 +3329,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
-            // TODO: support batching
-            if (op->src[0]->ne[3] != 1) {
-                return false;
-            }
-            // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
-            return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
+            return true;
         case GGML_OP_SOFT_MAX_BACK: {
             float max_bias = 0.0f;
             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 #ifdef __clang__
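An aside on the "remove designated initializers" review point: the struct above is filled with zero-initialization plus per-field assignments further down, presumably to keep the file building on pre-C++20 host compilers. A self-contained sketch of the two styles (example_params is a stand-in, not from the commit):

struct example_params { // stand-in for soft_max_params
    long long ncols;
    float     scale;
};

int main() {
    // style used in the commit: zero-init, then assign field by field
    example_params a = {};
    a.ncols = 128;
    a.scale = 1.0f;

    // C++20 designated-initializer style that the review asked to drop:
    // example_params b = { .ncols = 128, .scale = 1.0f };

    (void) a;
    return 0;
}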
@@ -21,16 +44,24 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    //TODO: noncontigous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
 
     x    += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst  += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
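The kernel now gets its row coordinates from a 3-D grid; the flattened rowx above equals the contiguous row index only because input and output are assumed contiguous, which is what the //TODO refers to. A small self-check of that equivalence (illustrative, not part of the commit):

#include <cassert>
#include <cstdint>

int main() {
    const int64_t ne01 = 3, ne02 = 4, ne03 = 2; // example sizes (rows, heads, batch)
    for (int64_t i03 = 0; i03 < ne03; ++i03) {
        for (int64_t i02 = 0; i02 < ne02; ++i02) {
            for (int64_t i01 = 0; i01 < ne01; ++i01) {
                // what the kernel computes from blockIdx/gridDim with a (ne01, ne02, ne03) grid:
                const int64_t rowx = i01 + i02*ne01 + i03*ne01*ne02;
                // contiguous row index of (i01, i02, i03) in src0:
                const int64_t row  = (i03*ne02 + i02)*ne01 + i01;
                assert(rowx == row);
            }
        }
    }
    return 0;
}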
@@ -38,7 +69,7 @@ static __global__ void soft_max_f32(
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
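The head index fed to get_alibi_slope() is now i02 (blockIdx.y) rather than the reconstructed rowx/nrows_y. For orientation, a stand-in for the slope computation, assumed to match the helper in the CUDA common header (alibi_slope is illustrative, not the commit's code):

#include <cmath>
#include <cstdint>

// Assumed behaviour of get_alibi_slope(): per-head ALiBi slope derived from
// max_bias, the head index and the m0/m1/n_head_log2 constants set up by the host code.
static float alibi_slope(float max_bias, uint32_t head, uint32_t n_head_log2, float m0, float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // no ALiBi: the mask is added unscaled
    }
    return head < n_head_log2 ? powf(m0, head + 1)
                              : powf(m1, 2*(head - n_head_log2) + 1);
}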
@@ -55,7 +86,7 @@ static __global__ void soft_max_f32(
             break;
         }
 
-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
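For reference, the per-element expression above amounts to the following single-row CPU computation (soft_max_row_ref is an illustrative sketch, not part of the commit):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One row of scaled, masked, numerically stable softmax: val = x*scale + slope*mask.
static std::vector<float> soft_max_row_ref(const std::vector<float> & x,
                                           const float * mask, // may be nullptr
                                           float scale, float slope) {
    std::vector<float> vals(x.size());
    float max_val = -INFINITY;
    for (std::size_t i = 0; i < x.size(); ++i) {
        vals[i] = x[i]*scale + (mask ? slope*mask[i] : 0.0f);
        max_val = std::max(max_val, vals[i]);
    }
    float sum = 0.0f;
    for (float & v : vals) {
        v    = std::exp(v - max_val); // subtract the row max for stability
        sum += v;
    }
    for (float & v : vals) {
        v /= sum;
    }
    return vals;
}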
@@ -151,63 +182,60 @@ static __global__ void soft_max_back_f32(
 }
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head      = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 64:
                 soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 128:
                 soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 256:
                 soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 512:
                 soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             default:
                 soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
         }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }
 
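The launcher keeps one block per row, but the rows are now laid out as a 3-D grid of (ne01, ne02, ne03) blocks so that blockIdx.y and blockIdx.z directly give the head and batch index inside the kernel. A sketch of the launch-geometry choice (soft_max_launch_geom is hypothetical; WARP_SIZE and CUDA_SOFT_MAX_BLOCK_SIZE are hard-coded here):

#include <cstdint>

struct launch_geom {
    int     block_x;                // threads per block
    int64_t grid_x, grid_y, grid_z; // blocks: one per row of src0
};

// Block size: next power of two >= ncols, capped at 1024 threads (matching the
// static_assert on CUDA_SOFT_MAX_BLOCK_SIZE); grid: (ne01, ne02, ne03).
static launch_geom soft_max_launch_geom(int64_t ncols, int64_t ne01, int64_t ne02, int64_t ne03) {
    const int warp_size = 32;   // WARP_SIZE
    const int max_block = 1024; // CUDA_SOFT_MAX_BLOCK_SIZE
    int nth = warp_size;
    while (nth < ncols && nth < max_block) {
        nth *= 2;
    }
    return { nth, ne01, ne02, ne03 };
}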
@@ -235,10 +263,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
+    const int64_t ne00 = src0->ne[0];
+
     float scale = 1.0f;
     float max_bias = 0.0f;
 
@@ -247,10 +276,44 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols = ne00;
+    params.nrows_x = nrows_x;
+    params.nrows_y = nrows_y;
+    params.ne00 = src0->ne[0];
+    params.ne01 = src0->ne[1];
+    params.ne02 = src0->ne[2];
+    params.ne03 = src0->ne[3];
+    params.nb11 = nb11;
+    params.nb12 = nb12;
+    params.nb13 = nb13;
+    params.ne12 = ne12;
+    params.ne13 = ne13;
+    params.scale = scale;
+    params.max_bias = max_bias;
+    params.m0 = m0;
+    params.m1 = m1;
+
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
     }
 }
 
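As a worked example of the ALiBi constants set up above (example values, not from the commit): with n_head = 32 and max_bias = 8, n_head_log2 = 32, m0 = 2^(-8/32) ≈ 0.8409 and m1 = 2^(-4/32) ≈ 0.9170.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_head   = 32;   // src0->ne[2] in the snippet above
    const float    max_bias = 8.0f; // read from op_params in the real code

    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    printf("n_head_log2 = %u, m0 = %.4f, m1 = %.4f\n", n_head_log2, m0, m1); // 32, 0.8409, 0.9170
    return 0;
}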
|