mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-29 13:43:38 -04:00
HIP : Add HIP 7.0+ compatibility for hipBLAS compute types (#14634)
This commit is contained in:
19
ggml/src/ggml-cuda/vendors/hip.h
vendored
19
ggml/src/ggml-cuda/vendors/hip.h
vendored
@@ -10,9 +10,6 @@
|
|||||||
#include "rocblas/rocblas.h"
|
#include "rocblas/rocblas.h"
|
||||||
#endif // __HIP_PLATFORM_AMD__
|
#endif // __HIP_PLATFORM_AMD__
|
||||||
|
|
||||||
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
|
||||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
|
||||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
|
||||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||||
@@ -30,7 +27,6 @@
|
|||||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||||
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
|
|
||||||
#define cublasCreate hipblasCreate
|
#define cublasCreate hipblasCreate
|
||||||
#define cublasDestroy hipblasDestroy
|
#define cublasDestroy hipblasDestroy
|
||||||
#define cublasGemmEx hipblasGemmEx
|
#define cublasGemmEx hipblasGemmEx
|
||||||
@@ -42,7 +38,6 @@
|
|||||||
#define cublasSgemm hipblasSgemm
|
#define cublasSgemm hipblasSgemm
|
||||||
#define cublasStatus_t hipblasStatus_t
|
#define cublasStatus_t hipblasStatus_t
|
||||||
#define cublasOperation_t hipblasOperation_t
|
#define cublasOperation_t hipblasOperation_t
|
||||||
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
|
|
||||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||||
@@ -144,6 +139,20 @@
|
|||||||
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
||||||
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
||||||
|
|
||||||
|
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
|
||||||
|
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
|
||||||
|
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
|
||||||
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
|
||||||
|
#define cublasComputeType_t hipblasComputeType_t
|
||||||
|
#define cudaDataType_t hipDataType
|
||||||
|
#else
|
||||||
|
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
||||||
|
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||||
|
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||||
|
#define cublasComputeType_t hipblasDatatype_t
|
||||||
|
#define cudaDataType_t hipblasDatatype_t
|
||||||
|
#endif
|
||||||
|
|
||||||
#define __CUDA_ARCH__ 1300
|
#define __CUDA_ARCH__ 1300
|
||||||
|
|
||||||
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
|
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
|
||||||
|
Reference in New Issue
Block a user