From 1d72c841888b9450916bdd5a9b3274da380f5b36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 7 Aug 2025 10:53:21 +0200 Subject: [PATCH] CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 (#15131) * CUDA: GEMM for FP32/FP16/BF16 and ne11 <= 16 --- ggml/src/ggml-cuda/common.cuh | 12 +- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 12 +- ggml/src/ggml-cuda/fattn.cu | 4 +- ggml/src/ggml-cuda/ggml-cuda.cu | 31 +- ggml/src/ggml-cuda/mma.cuh | 110 ++++-- ggml/src/ggml-cuda/mmf.cu | 431 +++++++++++++++++++++++ ggml/src/ggml-cuda/mmf.cuh | 5 + ggml/src/ggml-cuda/mmq.cu | 2 +- ggml/src/ggml-cuda/mmq.cuh | 264 +++++++------- ggml/src/ggml-cuda/{mmv.cu => mmvf.cu} | 94 ++--- ggml/src/ggml-cuda/{mmv.cuh => mmvf.cuh} | 6 +- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 3 +- 13 files changed, 750 insertions(+), 225 deletions(-) create mode 100644 ggml/src/ggml-cuda/mmf.cu create mode 100644 ggml/src/ggml-cuda/mmf.cuh rename ggml/src/ggml-cuda/{mmv.cu => mmvf.cu} (86%) rename ggml/src/ggml-cuda/{mmv.cuh => mmvf.cuh} (55%) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8f2725547..2e5d48797 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -233,9 +233,13 @@ typedef float2 dfloat2; #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING -#define NEW_MMA_AVAILABLE +#define TURING_MMA_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define AMPERE_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #define CP_ASYNC_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE @@ -303,10 +307,14 @@ static bool amd_mfma_available(const int cc) { } // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. 
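+// Rough capability note (illustrative, assuming NVIDIA's documented PTX support):
+// Turing (CC 7.5) introduced the mma.sync instructions for FP16/INT8 tiles that
+// turing_mma_available() gates below, while Ampere (CC 8.0) added the TF32 and
+// BF16 mma.sync variants, hence the separate ampere_mma_available() check.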
-static bool new_mma_available(const int cc) {
+static bool turing_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
+static bool ampere_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
 static bool cp_async_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index e7570f9d3..371253844 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -418,7 +418,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         float * const __restrict__ KQ_max,
         float * const __restrict__ KQ_rowsum,
         const int kb0) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     typedef fattn_mma_f16_config<DKQ, DV> c;
 
 #ifdef CP_ASYNC_AVAILABLE
@@ -776,7 +776,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
@@ -800,7 +800,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt, const int kb0_start, const int kb0_stop) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     typedef fattn_mma_f16_config<DKQ, DV> c;
@@ -1196,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask);
     GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
@@ -1223,7 +1223,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1354,7 +1354,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index 656e04a47..8ddd0415b 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -327,7 +327,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies &&
!mma_needs_data_conversion && (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000); const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { @@ -340,7 +340,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst } // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: - if (fp16_mma_available(cc) && !new_mma_available(cc)) { + if (fp16_mma_available(cc) && !turing_mma_available(cc)) { ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 60e481b95..ec7ab2551 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -22,8 +22,9 @@ #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" +#include "ggml-cuda/mmf.cuh" #include "ggml-cuda/mmq.cuh" -#include "ggml-cuda/mmv.cuh" +#include "ggml-cuda/mmvf.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" @@ -2008,7 +2009,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; - bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) + bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + bool use_mul_mat_f = !ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 @@ -2028,14 +2031,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = ggml_cuda_info().devices[id].warp_size; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { const int cc = ggml_cuda_info().devices[ctx.device].cc; + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - use_mul_mat_vec = use_mul_mat_vec && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src1->ne[1]); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } @@ -2048,15 +2055,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", 
ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
     //TODO update for generic tensor parallelism
-    const int cc                 = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int cc                     = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
     bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
     bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;
 
-    if (!split && use_mul_mat_vec) {
+    if (!split && use_mul_mat_vec_f) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+        ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_f) {
+        ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && use_mul_mat_q) {
@@ -2065,8 +2074,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
-    } else if (use_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr);
+    } else if (use_mul_mat_vec_f) {
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
@@ -2094,7 +2103,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         if (ggml_is_quantized(src0->type)) {
             ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
         } else {
-            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+            ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
         }
         return;
     }
@@ -3516,7 +3525,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc)) {
+                if (!turing_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index a86365c6a..83ee16b27 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -23,13 +23,13 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     int ret = 0;
 
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // defined(NEW_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE)
     return ret;
 }
 
@@ -167,6 +167,38 @@ namespace ggml_cuda_mma {
         }
     };
 
+    template <int I_, int J_>
+    struct tile<I_, J_, nv_bfloat162> {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
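+    // Example of the fragment layout encoded by get_i()/get_j() above, assuming
+    // the usual 32-thread warp: for tile<16, 8, nv_bfloat162>, ne = 16*8/32 = 4,
+    // and lane t holds element l at
+    //     i = (l % 2)*8 + t/4,    j = (l / 2)*4 + t%4,
+    // e.g. lane 0 holds (0,0), (8,0), (0,4), (8,4), i.e. the standard m16n8
+    // mma.sync operand layout.
+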
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
         tile<I, J/2, half2> ret;
@@ -209,7 +241,7 @@ namespace ggml_cuda_mma {
 
     template <typename T>
    static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int *) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
         asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
             : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#if defined(NEW_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
 #else
         load_generic(t, xs0, stride);
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     template <typename T>
     static __device__ __forceinline__ void load_ldmatrix_trans(
             tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
         asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
             : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
             : "l"(xs));
 #else
         GGML_UNUSED(t);
         GGML_UNUSED(xs0);
         GGML_UNUSED(stride);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef TURING_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
             : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
@@ -287,12 +319,12 @@ namespace ggml_cuda_mma {
         GGML_UNUSED(A);
         GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // NEW_MMA_AVAILABLE
+#endif // TURING_MMA_AVAILABLE
     }
 
     static __device__ __forceinline__ void mma(
             tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
-#ifdef NEW_MMA_AVAILABLE
+#ifdef
TURING_MMA_AVAILABLE #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) @@ -317,12 +349,12 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -344,12 +376,12 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -380,12 +412,29 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) { +#ifdef AMPERE_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -407,12 +456,29 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) { +#ifdef AMPERE_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE } static __device__ __forceinline__ void mma( tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { -#ifdef NEW_MMA_AVAILABLE +#ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -443,7 +509,7 @@ namespace ggml_cuda_mma { GGML_UNUSED(A); GGML_UNUSED(B); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // TURING_MMA_AVAILABLE } static __device__ __forceinline__ void mma( diff --git 
a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
new file mode 100644
index 000000000..1437367e8
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -0,0 +1,431 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "mma.cuh"
+#include "mmf.cuh"
+
+using namespace ggml_cuda_mma;
+
+#define MMF_ROWS_PER_BLOCK 32
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA = rows_per_block / tile_A::I;
+    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0        = blockIdx.x * rows_per_block;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = channel_dst / channel_ratio;
+    const int channel_y   = channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row ;
+    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
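+        // Note: x is staged through shared memory (tile_xy) because ldmatrix can
+        // only source its 8x8 fragments from shared memory; the extra 4 elements
+        // of row pitch (tile_k_padded = warp_size + 4) are presumably there to
+        // keep the row-major stores above free of shared memory bank conflicts.
+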
+#pragma unroll
+        for (int itB = 0; itB < ntB; ++itB) {
+            if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+                }
+            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                }
+            } else {
+                static_assert(std::is_same_v<T, void>, "unsupported type");
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                tile_B B;
+                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                for (int itA = 0; itA < ntA; ++itA) {
+                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                }
+            }
+        }
+    }
+
+    float * buf_iw = (float *) data_mmv;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+
+            sum += buf_iw[j*kiw + i];
+        }
+        dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+    }
+#else
+    NO_DEVICE_CODE;
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(ids); GGML_UNUSED(dst);
+    GGML_UNUSED(ncols); GGML_UNUSED(nchannels_y); GGML_UNUSED(stride_row); GGML_UNUSED(stride_col_y); GGML_UNUSED(stride_col_dst);
+    GGML_UNUSED(channel_ratio); GGML_UNUSED(stride_channel_x); GGML_UNUSED(stride_channel_y); GGML_UNUSED(stride_channel_dst);
+    GGML_UNUSED(sample_ratio); GGML_UNUSED(stride_sample_x); GGML_UNUSED(stride_sample_y); GGML_UNUSED(stride_sample_dst);
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+template <typename T, int cols_per_block>
+static void mul_mat_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    GGML_ASSERT(!ids && "mul_mat_id not implemented");
+
+    GGML_ASSERT(ncols_x      % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+
+    const int device    = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t nwarps_best    = 1;
+    int64_t niter_best     = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+    int64_t max_block_size = 256;
+    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+        if (niter < niter_best) {
+            niter_best  = niter;
+            nwarps_best = nwarps;
+        }
+    }
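+    // Worked example of the heuristic above, assuming warp_size == 32 and
+    // max_block_size == 256: a block of nwarps warps is charged
+    // ceil(ncols_x / (nwarps*64)) iterations, so for ncols_x == 4096 one gets
+    // 64 iterations for nwarps == 1, 16 for nwarps == 4 and 8 for nwarps == 8;
+    // the strict '<' keeps the smallest nwarps that reaches the minimum, here 8.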
+
+    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
+    const int nbytes_shared_iter    = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
+    const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
+    const dim3 block_dims(warp_size, nwarps_best, 1);
+    switch (nwarps_best) {
+        case 1: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 2: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 3: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 4: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 5: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 6: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 7: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        case 8: {
+            mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+template <typename T>
+static void mul_mat_f_switch_cols_per_block(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case  1: {
+            mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  2: {
+            mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  3: {
+            mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  4: {
+            mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  5: {
+            mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  6: {
+            mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  7: {
+            mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  8: {
+            mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  9: {
+            mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 10: {
+            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 11: {
+            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 12: {
+            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 13: {
+            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 14: {
+            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 15: {
+            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 16: {
+            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
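+// The switch above maps the runtime column count (ncols_dst == ne11 for plain
+// MUL_MAT) to a compile-time cols_per_block, so 16 instantiations of
+// mul_mat_f_cuda exist per source type; this is what limits the fast path to
+// ne11 <= 16 (see ggml_cuda_should_use_mmf below).
+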
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT(        nb0        == ts_dst);
+
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t * ids_d  = ids ? (const int32_t *)  ids->data : nullptr;
+    float         * dst_d  =       (float         *)  dst->data;
+
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
+
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    GGML_ASSERT(!ids || ncols_dst == 1);
+
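+    // F16 and BF16 sources are handled as packed half2 / nv_bfloat162 below, so
+    // ne00 and the strides along dimension 0 are passed to the kernels in units
+    // of T: e.g. F16 with ne00 == 4096 gives the kernel ncols_x == 2048 half2
+    // values (vals_per_T == 2), while F32 keeps vals_per_T == 1.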
+    switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            constexpr int vals_per_T = 1;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+        } break;
+        case GGML_TYPE_F16: {
+            const half2 * src0_d = (const half2 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
+            constexpr int vals_per_T = 2;
+            mul_mat_f_switch_cols_per_block(
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
+                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
+}
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
+        return false;
+    }
+    if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
+        return false;
+    }
+    if (ne11 > 16) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            return ampere_mma_available(cc);
+        case GGML_TYPE_F16:
+            return turing_mma_available(cc);
+        case GGML_TYPE_BF16:
+            return ampere_mma_available(cc);
+        default:
+            return false;
+    }
+}
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
new file mode 100644
index 000000000..785f9f211
--- /dev/null
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11);
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 8954a3831..384ee7615 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -310,7 +310,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (new_mma_available(cc)) {
+    if (turing_mma_available(cc)) {
         return true;
     }
 
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 1634725c2..96129bd83 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -92,7 +92,7 @@ struct tile_x_sizes {
 };
 
 static int get_mmq_x_max_host(const int cc) {
-    return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 :
+    return (amd_mfma_available(cc) || turing_mma_available(cc)) ? 128 :
        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
#ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -102,9 +102,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -121,7 +121,7 @@ static constexpr __device__ int get_mmq_x_max_device() { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -233,7 +233,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { static int mmq_get_granularity_host(const int mmq_x, const int cc) { if (amd_mfma_available(cc)) { return mmq_x >= 128 ? 32 : 16; - } else if (new_mma_available(cc) && mmq_x >= 48) { + } else if (turing_mma_available(cc) && mmq_x >= 48) { return 16; } else { return 8; @@ -244,7 +244,7 @@ static int mmq_get_granularity_host(const int mmq_x, const int cc) { static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 128 ? 32 : 16; } -#elif defined(NEW_MMA_AVAILABLE) +#elif defined(TURING_MMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 16 : 8; } @@ -279,14 +279,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); constexpr int nrows = warp_size / threads_per_row; @@ -305,12 +305,12 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; @@ -327,11 +327,11 @@ template static __device__ __forceinline__ void loa const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -382,14 +382,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); constexpr int nrows = warp_size / threads_per_row; @@ -408,12 +408,12 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; @@ -430,11 +430,11 @@ template static __device__ __forceinline__ void loa const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -485,14 +485,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); constexpr int nrows = warp_size / threads_per_row; @@ -527,13 +527,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + 
kqsx + QI5_0] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; @@ -550,11 +550,11 @@ template static __device__ __forceinline__ void loa const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -563,14 +563,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); constexpr int nrows = warp_size / threads_per_row; @@ -603,13 +603,13 @@ template static __device__ __forceinline__ void loa qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; @@ -626,11 +626,11 @@ template static __device__ __forceinline__ void loa const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -639,14 +639,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) // 
MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp constexpr int threads_per_row = 32; @@ -665,13 +665,13 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; @@ -688,11 +688,11 @@ template static __device__ __forceinline__ void loa const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -701,14 +701,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); constexpr int nrows = warp_size / threads_per_row; @@ -730,13 +730,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); const int k0 = kbx * (2 * QI_MXFP4) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; @@ -753,11 +753,11 @@ template static __device__ __forceinline__ void loa const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; #else x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -1178,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } } -#elif defined(NEW_MMA_AVAILABLE) +#elif defined(TURING_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -1264,14 +1264,14 @@ template static __device__ __forceinline__ void loa const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; @@ -1295,11 +1295,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1310,11 +1310,11 @@ template static __device__ __forceinline__ void loa const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -1452,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } } -#elif defined(NEW_MMA_AVAILABLE) +#elif defined(TURING_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -1582,7 +1582,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1590,7 +1590,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); constexpr int nrows = warp_size / threads_per_row; @@ -1618,11 +1618,11 @@ template static __device__ __forceinline__ void loa const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) 
|| defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -1649,7 +1649,7 @@ template static __device__ __forceinline__ void loa const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) &sc; const float d = bxi->d; @@ -1659,10 +1659,10 @@ template static __device__ __forceinline__ void loa } #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; @@ -1675,7 +1675,7 @@ template static __device__ __forceinline__ void loa x_df[i] = bxi->d; } -#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) } template @@ -1728,7 +1728,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else @@ -1736,7 +1736,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); constexpr int nrows = warp_size / threads_per_row; @@ -1753,15 +1753,15 @@ template static __device__ __forceinline__ void loa const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; const int qs0 = get_int_b4(bxi->qs, txi); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { @@ -1829,7 +1829,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } template @@ -1872,7 +1872,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int
warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else @@ -1880,7 +1880,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); constexpr int nrows = warp_size / threads_per_row; @@ -1908,16 +1908,16 @@ template static __device__ __forceinline__ void loa const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int rows_per_warp = warp_size / 2; #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { @@ -1986,7 +1986,7 @@ template static __device__ __forceinline__ void loa x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } template @@ -2029,7 +2029,7 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); @@ -2038,7 +2038,7 @@ template static __device__ __forceinline__ void loa int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); constexpr int nrows = warp_size / threads_per_row; @@ -2065,13 +2065,13 @@ template static __device__ __forceinline__ void loa const int kq0 = 2*txi - txi % (QI6_K/2) + 0; const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } #pragma unroll @@ -2084,11 +2084,11 @@ 
template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 4; @@ -2102,11 +2102,11 @@ template static __device__ __forceinline__ void loa const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2199,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } } -#elif defined(NEW_MMA_AVAILABLE) +#elif defined(TURING_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile< 8, 4, int> tile_B; @@ -2311,14 +2311,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); constexpr int nrows = warp_size / threads_per_row; @@ -2340,13 +2340,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = kbx * (2 * QI4_NL) + kqsx; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; @@ -2363,11 +2363,11 @@ template static __device__ __forceinline__ void loa const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2376,14 +2376,14 @@ 
template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2414,22 +2414,22 @@ template static __device__ __forceinline__ void loa const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2438,14 +2438,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2472,24 +2472,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if 
defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2498,14 +2498,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2539,24 +2539,24 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2565,14 +2565,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; constexpr int nrows = warp_size / threads_per_row; 
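A recurring pattern in the tile loaders above: the MMA layouts address x_qs through the fixed MMQ_MMA_TILE_X_K_* strides, while the dp4a fallback pads each row of the shared tile to 2*MMQ_TILE_NE_K + 1 ints. The odd +1 is the usual shared-memory padding trick: with a power-of-two row stride, a warp stepping down a column would hit the same memory bank on every access. A minimal standalone sketch of the idea, assuming a single 32x32 thread block (illustrative only, not code from this patch):

// Sketch of the +1 shared-memory padding used by the dp4a tile layouts.
// TILE_DIM and the kernel are hypothetical: with a row stride of TILE_DIM
// (a multiple of 32), a warp reading one column would hit a single bank 32
// times; the extra element per row spreads the accesses over all 32 banks.
#define TILE_DIM 32

__global__ void transpose_tile(const int * __restrict__ in, int * __restrict__ out) {
    __shared__ int tile[TILE_DIM][TILE_DIM + 1]; // +1 breaks the bank conflict

    tile[threadIdx.y][threadIdx.x] = in[threadIdx.y*TILE_DIM + threadIdx.x];
    __syncthreads();
    out[threadIdx.y*TILE_DIM + threadIdx.x] = tile[threadIdx.x][threadIdx.y]; // conflict-free column read
}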
@@ -2601,22 +2601,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2625,14 +2625,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; constexpr int nrows = warp_size / threads_per_row; @@ -2668,22 +2668,22 @@ template static __device__ __forceinline__ void loa const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2692,14 +2692,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else 
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); constexpr int nrows = warp_size / threads_per_row; @@ -2727,23 +2727,23 @@ template static __device__ __forceinline__ void loa const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2752,14 +2752,14 @@ template static __device__ __forceinline__ void loa constexpr int nwarps = mmq_get_nwarps_device(); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); constexpr int nrows = warp_size / threads_per_row; @@ -2779,13 +2779,13 @@ template static __device__ __forceinline__ void loa const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); const int k0 = 8 * (kqsx / 4) + kqsx % 4; -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } constexpr int rows_per_warp = warp_size / 8; @@ -2804,11 +2804,11 @@ template static __device__ __forceinline__ void loa const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else 
x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) } } @@ -2859,9 +2859,9 @@ static __device__ __forceinline__ void mmq_write_back_mma( constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { @@ -3061,13 +3061,13 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -3534,7 +3534,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? 
mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmvf.cu similarity index 86% rename from ggml/src/ggml-cuda/mmv.cu rename to ggml/src/ggml-cuda/mmvf.cu index e14c93516..1ad4bc75b 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -1,9 +1,9 @@ #include "ggml.h" #include "common.cuh" -#include "mmv.cuh" +#include "mmvf.cuh" template -static __global__ void mul_mat_vec( +static __global__ void mul_mat_vec_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, @@ -37,7 +37,7 @@ static __global__ void mul_mat_vec( float sumf[ncols_dst] = {0.0f}; - if constexpr (std::is_same::value) { + if constexpr (std::is_same_v) { const float2 * x2 = (const float2 *) x; for (int col2 = tid; col2 < ncols2; col2 += block_size) { @@ -50,10 +50,10 @@ static __global__ void mul_mat_vec( sumf[j] += tmpx.y*tmpy.y; } } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same_v) { const half2 * x2 = (const half2 *) x; - if (std::is_same::value) { + if (std::is_same_v) { for (int col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = __half22float2(x2[col2]); @@ -86,7 +86,7 @@ static __global__ void mul_mat_vec( NO_DEVICE_CODE; #endif // FP16_AVAILABLE } - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same_v) { const int * x2 = (const int *) x; for (int col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; @@ -98,7 +98,7 @@ static __global__ void mul_mat_vec( } } } else { - static_assert(std::is_same::value, "unsupported type"); + static_assert(std::is_same_v, "unsupported type"); } #pragma unroll @@ -126,7 +126,7 @@ static __global__ void mul_mat_vec( } template -static void launch_mul_mat_vec_cuda( +static void launch_mul_mat_vec_f_cuda( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, @@ -141,11 +141,9 @@ static void launch_mul_mat_vec_cuda( GGML_ASSERT( nsamples_dst % nsamples_x == 0); const int64_t channel_ratio = nchannels_dst / nchannels_x; const int64_t sample_ratio = nsamples_dst / nsamples_x; - int device; - int warp_size; - CUDA_CHECK(cudaGetDevice(&device)); - warp_size = ggml_cuda_info().devices[device].warp_size; + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; int64_t block_size_best = warp_size; int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); @@ -161,54 +159,54 @@ static void launch_mul_mat_vec_cuda( } } - const int smem = warp_size*sizeof(float); + const int nbytes_shared = warp_size*sizeof(float); const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, 
stride_sample_y, stride_sample_dst); } break; case 64: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { - mul_mat_vec<<>> + mul_mat_vec_f<<>> (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); @@ -220,7 +218,7 @@ static void launch_mul_mat_vec_cuda( } template -static void mul_mat_vec_cuda_switch_ncols_dst( +static void mul_mat_vec_f_cuda_switch_ncols_dst( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, @@ -230,49 +228,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst( cudaStream_t stream) { switch (ncols_dst) { case 1: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 2: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 3: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 4: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, 
dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 5: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 6: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 7: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); break; case 8: - launch_mul_mat_vec_cuda + launch_mul_mat_vec_f_cuda (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); @@ -284,7 +282,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst( } template -static void mul_mat_vec_cuda( +static void mul_mat_vec_f_cuda( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, @@ -292,22 +290,22 @@ static void mul_mat_vec_cuda( const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { - if constexpr(std::is_same::value) { + if constexpr(std::is_same_v) { if (prec == GGML_PREC_DEFAULT) { - mul_mat_vec_cuda_switch_ncols_dst + mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); return; } } - mul_mat_vec_cuda_switch_ncols_dst + mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -355,19 +353,19 @@ void 
ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; @@ -376,7 +374,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * } } -void ggml_cuda_op_mul_mat_vec( +void ggml_cuda_op_mul_mat_vec_f( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, @@ -414,19 +412,19 @@ void ggml_cuda_op_mul_mat_vec( switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; @@ -442,15 +440,15 @@ void ggml_cuda_op_mul_mat_vec( GGML_UNUSED(src1_padded_row_size); } -bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { 
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) { if (src0_ne[0] % 2 != 0) { return false; } switch (type) { case GGML_TYPE_F32: if (GGML_CUDA_CC_IS_NVIDIA(cc)) { - if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { - return ne11 <= 8; + if (ampere_mma_available(cc)) { + return ne11 <= 3; } if (cc >= GGML_CUDA_CC_TURING) { return ne11 <= 4; @@ -466,6 +464,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ case GGML_TYPE_F16: if (GGML_CUDA_CC_IS_NVIDIA(cc)) { const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (ampere_mma_available(cc)) { + return src0_small && ne11 == 1; + } if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { return src0_small && ne11 <= 4; } @@ -486,6 +487,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ case GGML_TYPE_BF16: if (GGML_CUDA_CC_IS_NVIDIA(cc)) { const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (ampere_mma_available(cc)) { + return src0_small && ne11 == 1; + } if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { return src0_small && ne11 <= 4; } diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmvf.cuh similarity index 55% rename from ggml/src/ggml-cuda/mmv.cuh rename to ggml/src/ggml-cuda/mmvf.cuh index 1330bcb6a..1da460992 100644 --- a/ggml/src/ggml-cuda/mmv.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,11 +1,11 @@ #include "common.cuh" -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); -void ggml_cuda_op_mul_mat_vec( +void ggml_cuda_op_mul_mat_vec_f( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); -bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); +bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 8b172e60f..c31f31923 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -200,6 +200,7 @@ #endif typedef hip_bfloat16 nv_bfloat16; +typedef short2 nv_bfloat162; // FIXME: no 2x BF16 type is defined in bfloat16.h; ad-hoc compilation fix typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 198963202..8c55a2e4e 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -137,4 +137,5 @@ #define cudaStreamEndCapture musaStreamEndCapture #define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor -typedef mt_bfloat16 nv_bfloat16; +typedef __mt_bfloat16 nv_bfloat16; +typedef __mt_bfloat162 nv_bfloat162;
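The two vendor-header hunks rest on one assumption: on HIP and MUSA the FP kernels only ever move nv_bfloat162 values around (vectorized loads, stores, and reinterpreting casts), so a bit-compatible stand-in is enough to make the translation units compile. A sketch of the compatibility requirements behind the short2 stand-in (illustrative, not part of the patch):

// Hedged sketch: the two properties the ad-hoc short2 stand-in must satisfy
// for nv_bfloat162 on HIP. If any kernel started doing BF16 arithmetic
// through nv_bfloat162 on HIP, a real 2x BF16 type would be needed instead.
static_assert(sizeof(short2) == 2*sizeof(hip_bfloat16), "stand-in must hold two 16-bit BF16 values");
static_assert(alignof(short2) >= 4, "vectorized 32-bit loads assume 4-byte alignment");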
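Taken together with the new mmf.cu kernels, the retuned ggml_cuda_should_use_mmvf thresholds split FP32/FP16/BF16 matrix multiplication three ways: the vector kernel for the smallest ne11, the new MMA-based GEMM for skinny batches up to ne11 <= 16, and the generic GEMM path beyond that. A simplified sketch of that selection; choose_fp_mat_mul_path is a hypothetical name, and the real dispatch in ggml-cuda.cu also weighs contiguity, batching, and expert ids:

// Hypothetical three-way dispatch in the spirit of this patch; only the
// ne11 thresholds are taken from the code above, the rest is illustrative.
enum class fp_mat_mul_path { MMVF, MMF, CUBLAS };

static fp_mat_mul_path choose_fp_mat_mul_path(
        enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
    if (ggml_cuda_should_use_mmvf(type, cc, src0_ne, ne11)) {
        return fp_mat_mul_path::MMVF; // one block per output row, best for the smallest ne11
    }
    if (turing_mma_available(cc) && ne11 <= 16) { // per the patch title, mmf covers ne11 <= 16
        return fp_mat_mul_path::MMF;
    }
    return fp_mat_mul_path::CUBLAS; // larger batches go to the generic GEMM
}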