From 15e92fd33791e60a4ddb5970b47242a855c27117 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 2 Aug 2025 17:13:05 +0300
Subject: [PATCH] cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1
 (#15038)

* cuda, sycl : fix batched gemm when ne02 == 1 && ne03 > 1

ggml-ci

* cont : fix cont types

ggml-ci

* cont : adopt variable names and comment from the other branch
---
 ggml/src/ggml-cuda/ggml-cuda.cu  | 17 +++++++++++++----
 ggml/src/ggml-sycl/ggml-sycl.cpp | 15 ++++++++++++---
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 517927946..8885fb7fb 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
     ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // Handle src0
     src0_ptr = (const cuda_t *) src0->data;
 
@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
+
+        is_src1_cont_2 = true;
     }
 
     // Setup destination buffer
@@ -1918,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+        const int64_t smb = ne12 == 1 ? s13 : s12;
+
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
+                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB
+                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 2acdef98a..f68f1739a 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -2688,6 +2688,9 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     const size_t type_size_src0 = ggml_type_size(src0->type);
     const size_t type_size_src1 = ggml_type_size(src1->type);
 
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // SRC1 strides
     int64_t s11 = nb11 / type_size_src1;
     int64_t s12 = nb12 / type_size_src1;
@@ -2737,6 +2740,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
         s11 = ne10;
         s12 = ne11 * s11;
         s13 = ne12 * s12;
+
+        is_src1_cont_2 = true;
     }
 
     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
@@ -2852,12 +2857,16 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     else
 #endif
     {
-        if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+        if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+            // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+            const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+            const int64_t smb = ne12 == 1 ? s13 : s12;
+
             // there is no broadcast and src0, src1 are contiguous across dims 2, 3
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
                                                         oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-                                                        src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
+                                                        src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
                                                         mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
         } else {
            const int ne23 = ne12 * ne13;
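
Note (not part of the patch): a minimal standalone sketch of the stride selection this patch introduces. The tensor shape, fp16 element size, and stride values below are illustrative assumptions, not taken from ggml; ne*/nb* follow ggml conventions (element counts / byte strides per dim).

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical src0: a contiguous fp16 tensor of shape (64, 1, 32, 4)
    // viewed through a [0, 2, 1, 3] permutation -> permuted shape (64, 32, 1, 4).
    const int64_t ne02 = 1;            // dim 2 collapsed by the permutation
    const int64_t nb00 = 2;            // fp16 element stride (bytes)
    const int64_t nb01 = 2 * 64;       // row stride
    const int64_t nb02 = 2 * 64;       // stride inherited from the swapped dim; NOT one matrix
    const int64_t nb03 = 2 * 64 * 32;  // matrices actually advance along dim 3

    // The patched logic: with ne02 == 1 the batch stride must come from dim 3.
    // The old code used nb02/nb00 unconditionally, i.e. 64 elements here, while
    // each 64x32 matrix spans 2048 elements, so batched GEMM read overlapping data.
    const int64_t sma = ne02 == 1 ? nb03 / nb00 : nb02 / nb00;

    printf("strideA = %lld elements (old, incorrect value: %lld)\n",
           (long long) sma, (long long) (nb02 / nb00));
    return 0;
}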