diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 20467c54d..231250efc 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -174,6 +174,7 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF) +option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF) option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index cdc3bb5ae..19fcc5982 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -227,7 +227,7 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) -#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) && !defined(GGML_HIP_NO_MMQ_MFMA) #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) @@ -295,7 +295,11 @@ static bool fp32_mma_hardware_available(const int cc) { // AMD CDNA3 matrix cores.. Will add support for other CDNA generations later. static bool amd_mfma_available(const int cc) { - return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc); +#if !defined(GGML_HIP_NO_MMQ_MFMA) + return GGML_CUDA_CC_IS_CDNA3(cc); +#else + return false; +#endif //!defined(GGML_HIP_NO_MMQ_MFMA) } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index e29df9856..e92ec7faa 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -113,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN) add_compile_definitions(GGML_HIP_ROCWMMA_FATTN) endif() +if (NOT GGML_HIP_MMQ_MFMA) + add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) +endif() + if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0) add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12) endif()