From 756aa1020aabc70b71b9a8efcd35fbf42b3bb9ab Mon Sep 17 00:00:00 2001
From: Slobodan Josic <127323561+slojosic-amd@users.noreply.github.com>
Date: Fri, 11 Jul 2025 18:55:00 +0200
Subject: [PATCH] HIP : Add HIP 7.0+ compatibility for hipBLAS compute types
 (#14634)

---
 ggml/src/ggml-cuda/vendors/hip.h | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 1a28831b7..184d445f5 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -10,9 +10,6 @@
 #include "rocblas/rocblas.h"
 #endif // __HIP_PLATFORM_AMD__
 
-#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
-#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
-#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_OP_N HIPBLAS_OP_N
@@ -30,7 +27,6 @@
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
 #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
-#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
 #define cublasDestroy hipblasDestroy
 #define cublasGemmEx hipblasGemmEx
@@ -42,7 +38,6 @@
 #define cublasSgemm hipblasSgemm
 #define cublasStatus_t hipblasStatus_t
 #define cublasOperation_t hipblasOperation_t
-#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
@@ -144,6 +139,20 @@
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
 
+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
+#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
+#define cublasComputeType_t hipblasComputeType_t
+#define cudaDataType_t hipDataType
+#else
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define cublasComputeType_t hipblasDatatype_t
+#define cudaDataType_t hipblasDatatype_t
+#endif
+
 #define __CUDA_ARCH__ 1300
 
 #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
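
Usage note (not part of the patch above): a minimal sketch of what the remapped
compute-type names resolve to on either side of the HIP 7.0 boundary. It assumes
an AMD HIP/hipBLAS toolchain and that the include path resolves
"ggml-cuda/vendors/hip.h"; the file name check_compute_types.cpp and the hipcc
invocation are illustrative only, not part of the ggml build.

// check_compute_types.cpp - hypothetical standalone probe, not part of ggml.
// Assumed build line: hipcc -I<path-to-ggml>/src check_compute_types.cpp -o check_compute_types
#include <cstdio>

#include <hip/hip_version.h>        // HIP_VERSION (e.g. 70000000 for HIP 7.0.0)
#include "ggml-cuda/vendors/hip.h"  // the compatibility header patched above

int main() {
    // CUBLAS_COMPUTE_32F is the CUDA-style name used by the ggml cuBLAS code path.
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
    // HIP >= 7.0: expands to HIPBLAS_COMPUTE_32F, a hipblasComputeType_t value.
    std::printf("HIP %d: CUBLAS_COMPUTE_32F -> HIPBLAS_COMPUTE_32F (%d)\n",
                HIP_VERSION, (int) CUBLAS_COMPUTE_32F);
#else
    // Older HIP: falls back to HIPBLAS_R_32F, a hipblasDatatype_t value.
    std::printf("HIP %d: CUBLAS_COMPUTE_32F -> HIPBLAS_R_32F (%d)\n",
                HIP_VERSION, (int) CUBLAS_COMPUTE_32F);
#endif
    return 0;
}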