diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 364efcaec..2f2fce067 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -362,6 +362,26 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { #endif // FP16_AVAILABLE } +// Row reduction kernel template - compute sum (norm=false) or mean (norm=true) +template +static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum); + + if (col != 0) { + return; + } + + dst[row] = norm ? sum / ncols : sum; +} + template static __device__ __forceinline__ float warp_reduce_max(float x) { #pragma unroll diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5bab92e34..c6bdd4fb3 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -37,6 +37,7 @@ #include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/sum.cuh" #include "ggml-cuda/sumrows.cuh" +#include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/unary.cuh" #include "ggml-cuda/upscale.cuh" @@ -2357,6 +2358,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SUM_ROWS: ggml_cuda_op_sum_rows(ctx, dst); break; + case GGML_OP_MEAN: + ggml_cuda_op_mean(ctx, dst); + break; case GGML_OP_SSM_CONV: ggml_cuda_op_ssm_conv(ctx, dst); break; @@ -3260,6 +3264,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_POOL_2D: case GGML_OP_SUM: case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: case GGML_OP_ARGSORT: case GGML_OP_ACC: return true; diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu new file mode 100644 index 000000000..4b238a399 --- /dev/null +++ b/ggml/src/ggml-cuda/mean.cu @@ -0,0 +1,19 @@ +#include "mean.cuh" + +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + reduce_rows_f32<<>>(src0_d, dst_d, ncols); +} diff --git a/ggml/src/ggml-cuda/mean.cuh b/ggml/src/ggml-cuda/mean.cuh new file mode 100644 index 000000000..2b9b10433 --- /dev/null +++ b/ggml/src/ggml-cuda/mean.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 38dbf1b5e..2eee08fa0 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -1,25 +1,9 @@ #include "sumrows.cuh" -static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) { - const int row = blockIdx.x; - const int col = threadIdx.x; - - float sum = 0.0f; - for (int i = col; i < ncols; i += blockDim.x) { - sum += x[row * ncols + i]; - } - - sum = warp_reduce_sum(sum); - - if (col == 0) { - dst[row] = sum; - } -} - void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_nums(nrows, 1, 1); - k_sum_rows_f32<<>>(x, dst, ncols); + reduce_rows_f32<<>>(x, dst, ncols); } void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -35,5 +19,8 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t ncols = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); - sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream); + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + + reduce_rows_f32<<>>(src0_d, dst_d, ncols); } diff --git a/ggml/src/ggml-cuda/sumrows.cuh b/ggml/src/ggml-cuda/sumrows.cuh index 191db1c13..3431c599b 100644 --- a/ggml/src/ggml-cuda/sumrows.cuh +++ b/ggml/src/ggml-cuda/sumrows.cuh @@ -1,5 +1,4 @@ #include "common.cuh" void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); - void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 772bee346..7be7f2205 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4652,6 +4652,8 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_conv_transpose_2d({256, 256, 256, 1}, {3, 3, 16, 256}, 1)); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1})); + return test_cases; }