From 5ba36f61033d2819be27e2abb70e9a3bd20c0fde Mon Sep 17 00:00:00 2001
From: uvos
Date: Thu, 14 Aug 2025 16:23:56 +0200
Subject: [PATCH] HIP: Cleanup hipification header (#15285)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add an explicit conversion operator to support older versions of ROCm.
Switch over to hip_bf16 from the legacy hip_bfloat16.
Simplify the RDNA3 define.
Lower the switch-over to the new hipBLAS API to ROCm 6.5, as this is the
version reported by the ROCm 7.0 previews.

---------

Co-authored-by: Johannes Gäßler
---
 ggml/src/ggml-cuda/convert.cu    |  6 +++---
 ggml/src/ggml-cuda/convert.cuh   | 13 +++++++++++++
 ggml/src/ggml-cuda/cpy-utils.cuh | 12 ++----------
 ggml/src/ggml-cuda/getrows.cu    |  7 ++++---
 ggml/src/ggml-cuda/mmvf.cu       |  5 +++--
 ggml/src/ggml-cuda/set-rows.cu   |  9 +--------
 ggml/src/ggml-cuda/vendors/hip.h | 13 ++++++-------
 7 files changed, 32 insertions(+), 33 deletions(-)

diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index e3beddbc1..8f0efdcc1 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -31,8 +31,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
     dequantize_kernel(vx, ib, iqs, v);
 
     const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
-    y[iy0 + 0]        = float(v.x);
-    y[iy0 + y_offset] = float(v.y);
+    y[iy0 + 0]        = ggml_cuda_cast<dst_t>(v.x);
+    y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 }
 
 template
@@ -630,7 +630,7 @@ static __global__ void convert_unary(
     const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
 
     const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
-    y[iy] = float(x[ix]);
+    y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
 }
 
 template
diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh
index f04214be1..c62e8a1b1 100644
--- a/ggml/src/ggml-cuda/convert.cuh
+++ b/ggml/src/ggml-cuda/convert.cuh
@@ -29,3 +29,16 @@ typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
 to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
 to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
 to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
+
+template <typename dst_t, typename src_t>
+__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
+        return x;
+    } else if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        return __float2bfloat16(float(x));
+    } else if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
+        return __bfloat162float(x);
+    } else {
+        return float(x);
+    }
+}
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
index 410c12b7b..e621cb981 100644
--- a/ggml/src/ggml-cuda/cpy-utils.cuh
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh
@@ -1,15 +1,7 @@
 #pragma once
 
 #include "ggml-common.h"
-
-template <typename src_t, typename dst_t>
-static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
-    if constexpr (std::is_same_v<src_t, dst_t>) {
-        *dst = *src;
-    } else {
-        *dst = float(*src);
-    }
-}
+#include "convert.cuh"
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
@@ -221,5 +213,5 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
 
 template <typename src_t, typename dst_t>
 static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
-    convert_flt((const src_t *)cxi, (dst_t *)cdsti);
+    *(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
 }
diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu
index f77b2629a..68d3254fb 100644
--- a/ggml/src/ggml-cuda/getrows.cu
+++ b/ggml/src/ggml-cuda/getrows.cu
@@ -1,5 +1,6 @@
 #include "getrows.cuh"
 #include "dequantize.cuh"
+#include "convert.cuh"
 
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
@@ -34,8 +35,8 @@ static __global__ void k_get_rows(
     dfloat2 v;
     dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst_row[iybs + iqs + 0]        = float(v.x);
-    dst_row[iybs + iqs + y_offset] = float(v.y);
+    dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
+    dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
 }
 
 template<typename src0_t, typename dst_t>
@@ -62,7 +63,7 @@ static __global__ void k_get_rows_float(
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
     const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-    dst_row[i00] = float(src0_row[i00]);
+    dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 }
 
 template
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index 1ad4bc75b..16100b680 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "common.cuh"
+#include "convert.cuh"
 #include "mmvf.cuh"
 
 template
@@ -93,8 +94,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
             }
         }
     } else {
diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu
index 079834364..b4115a43c 100644
--- a/ggml/src/ggml-cuda/set-rows.cu
+++ b/ggml/src/ggml-cuda/set-rows.cu
@@ -3,11 +3,6 @@
 
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
-template <typename src_t, typename dst_t>
-__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
-    convert_flt(src_f, dst_f);
-}
-
 // Generic quantized set_rows kernel template
 template
 static __global__ void k_set_rows_quant(
@@ -117,9 +112,7 @@ static __global__ void k_set_rows(
     const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
     dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
 
-    const src_t* src_elem = src0_row + i00;
-    dst_t* dst_elem = dst_row_ptr + i00;
-    set_rows_1(src_elem, dst_elem);
+    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 
     GGML_UNUSED(ne10);
     GGML_UNUSED(ne13);
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index ec1b59caa..6e9c67aca 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -4,7 +4,7 @@
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
-#include <hip/hip_bfloat16.h>
+#include <hip/hip_bf16.h>
 
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@@ -135,7 +135,7 @@
 #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
 #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
 
-#if HIP_VERSION >= 70000000
+#if HIP_VERSION >= 60500000
 #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
@@ -147,7 +147,7 @@
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
 #define cublasComputeType_t hipblasDatatype_t
 #define cudaDataType_t hipblasDatatype_t
-#endif // HIP_VERSION >= 7000000
+#endif // HIP_VERSION >= 60500000
 
 #if !defined(__HIP_PLATFORM_AMD__)
 #error "The HIP backend supports only AMD targets"
@@ -179,8 +179,7 @@
 #define RDNA4
 #endif
 
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
-    defined(__gfx1150__) || defined(__gfx1151__)
+#if defined(__GFX11__)
 #define RDNA3
 #endif
 
@@ -197,8 +196,8 @@
 #define __has_builtin(x) 0
 #endif
 
-typedef hip_bfloat16 nv_bfloat16;
-typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
 
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
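
Note: the sketch below exercises, in isolation, the conversion dispatch that this patch centralizes in convert.cuh as ggml_cuda_cast. It is illustrative only: the wrapper name cast_like_ggml_cuda_cast, the main() driver, and the CUDA headers are assumptions made for a standalone nvcc build and are not part of the patch. Under HIP the same logic would build against <hip/hip_fp16.h> and <hip/hip_bf16.h>, with nv_bfloat16 mapped to __hip_bfloat16 exactly as vendors/hip.h now does, which is what lets the bf16-to-float branch work on older ROCm versions without an implicit conversion operator.

// Standalone sketch (not part of the patch) of the dispatch implemented by
// ggml_cuda_cast: identity, to-bf16, from-bf16, and a plain float bridge for
// all other type pairs (e.g. half <-> float). Build as a .cu file with nvcc.
#include <cstdio>
#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

template <typename dst_t, typename src_t>
__host__ __device__ inline dst_t cast_like_ggml_cuda_cast(src_t x) {
    if constexpr (std::is_same_v<dst_t, src_t>) {
        return x;                           // same type: no conversion
    } else if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
        return __float2bfloat16(float(x));  // anything -> bf16 goes through float
    } else if constexpr (std::is_same_v<src_t, nv_bfloat16>) {
        return __bfloat162float(x);         // bf16 -> float without an implicit operator
    } else {
        return float(x);                    // e.g. half -> float, float -> half
    }
}

int main() {
    const float x = 1.5f;

    const nv_bfloat16 xb = cast_like_ggml_cuda_cast<nv_bfloat16>(x); // float -> bf16
    const float       xf = cast_like_ggml_cuda_cast<float>(xb);      // bf16  -> float
    const half        xh = cast_like_ggml_cuda_cast<half>(x);        // float -> half

    printf("bf16 round trip: %f, half: %f\n", xf, float(xh));
    return 0;
}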