From ad4a700117d1746799d2d6599e526e2c3a7938d2 Mon Sep 17 00:00:00 2001
From: uvos
Date: Wed, 30 Jul 2025 17:38:06 +0200
Subject: [PATCH] HIP: enable mfma mmq on gfx908 and gfx90a for select datatypes and shapes (#14949)

---
 ggml/src/ggml-cuda/common.cuh |  7 +++----
 ggml/src/ggml-cuda/mmq.cu     | 24 ++++++++++++++++++++----
 ggml/src/ggml-cuda/mmq.cuh    |  4 ++--
 3 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 4560d85c4..7fb04d51b 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -227,9 +227,9 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 
-#if defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
+#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 #define AMD_MFMA_AVAILABLE
-#endif // defined(GGML_USE_HIP) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA)
+#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
 
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
@@ -293,10 +293,9 @@ static bool fp32_mma_hardware_available(const int cc) {
     return GGML_CUDA_CC_IS_CDNA(cc);
 }
 
-// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later.
 static bool amd_mfma_available(const int cc) {
 #if !defined(GGML_HIP_NO_MMQ_MFMA)
-    return GGML_CUDA_CC_IS_CDNA3(cc);
+    return GGML_CUDA_CC_IS_CDNA(cc);
 #else
     return false;
 #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index e2fd0c1c2..d4954fbe6 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -109,8 +109,8 @@ void ggml_cuda_mul_mat_q(
     const int64_t s03 = src0->nb[3] / ts_src0;
     const int64_t s3 = dst->nb[3] / ts_dst;
 
-    const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
-        || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));
+    const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+        || GGML_CUDA_CC_IS_CDNA(cc);
 
     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -252,7 +252,7 @@ void ggml_cuda_op_mul_mat_q(
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
     const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
-        || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
+        || GGML_CUDA_CC_IS_CDNA(cc))
         && src1_ncols == ne11;
     const mmq_args args = {
         src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
@@ -306,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (new_mma_available(cc) || amd_mfma_available(cc)) {
+    if (new_mma_available(cc)) {
         return true;
     }
 
@@ -322,5 +322,21 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
+    if (amd_mfma_available(cc)) {
+        // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT)
+        // performs better but is currently suffering from a crash on this architecture.
+        // TODO: Revisit when hipblaslt is fixed on CDNA3
+        if (GGML_CUDA_CC_IS_CDNA3(cc)) {
+            return true;
+        }
+        if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+            return true;
+        }
+        if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
+            return true;
+        }
+        return false;
+    }
+
     return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index d8650dd70..dd60529fa 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -3096,8 +3096,8 @@ static __global__ void mul_mat_q(
     }
     __syncthreads();
 
-    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+    // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
         const int wt = blockIdx.z / nchannels_y;
         const int zt = blockIdx.z - wt*nchannels_y;
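Review note: below is a minimal, standalone sketch of the batch-size/type gating this patch adds to ggml_cuda_should_use_mmq() for CDNA GPUs (always take MMQ on CDNA3; on gfx908/gfx90a only for small ne11 or the listed quantization formats). The helper mmq_wanted_on_cdna(), the simplified quant_type enum, and main() are illustrative only and not part of ggml.

    // Illustrative host-side sketch mirroring the patch's CDNA gating logic.
    #include <cstdint>
    #include <cstdio>

    enum class quant_type { q4_0, q4_1, q5_0, q5_1, q4_k, q5_k, other };

    // CDNA3 always prefers MMQ (per the patch: rocblas/tensile is slow there and
    // hipblaslt currently crashes); gfx908/gfx90a (older CDNA) prefer MMQ only for
    // small batches (ne11) or the quantization formats singled out in the patch.
    static bool mmq_wanted_on_cdna(bool is_cdna3, quant_type type, int64_t ne11) {
        if (is_cdna3) {
            return true;
        }
        if (ne11 <= 128 || type == quant_type::q4_0 || type == quant_type::q4_1 ||
            type == quant_type::q5_0 || type == quant_type::q5_1) {
            return true;
        }
        if (ne11 <= 256 && (type == quant_type::q4_k || type == quant_type::q5_k)) {
            return true;
        }
        return false;
    }

    int main() {
        // gfx90a, Q4_K, 192 columns -> MMQ (ne11 <= 256 applies to the K-quants listed).
        std::printf("gfx90a Q4_K ne11=192: %d\n", mmq_wanted_on_cdna(false, quant_type::q4_k, 192));
        // gfx908, other type, 512 columns -> fall back to the non-MMQ path.
        std::printf("gfx908 other ne11=512: %d\n", mmq_wanted_on_cdna(false, quant_type::other, 512));
        return 0;
    }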