mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-20 06:36:48 -04:00
HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA 3 (#14624)
This commit adds support for MFMA instructions to MMQ. CDNA1/GFX908 CDNA2/GFX90a and CDNA3/GFX942 are supported by the MFMA-enabled code path added by this commit. The code path and stream-k is only enabled on CDNA3 for now as it fails to outperform blas in all cases on the other devices. Blas is currently only consistently outperformed on CDNA3 due to issues in the amd-provided blas libraries. This commit also improves the awareness of MMQ towards different warp sizes and as a side effect improves the performance of all quant formats besides q4_0 and q4_1, which regress slightly, on GCN gpus.
This commit is contained in:
@@ -12,7 +12,8 @@
|
||||
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
|
||||
// All matrix tiles have ne physical 32 bit elements per warp.
|
||||
//
|
||||
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
|
||||
// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
|
||||
// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior.
|
||||
|
||||
#include "common.cuh"
|
||||
|
||||
@@ -66,7 +67,44 @@ namespace ggml_cuda_mma {
|
||||
struct tile {
|
||||
static constexpr int I = I_;
|
||||
static constexpr int J = J_;
|
||||
static constexpr int ne = I * J / WARP_SIZE;
|
||||
|
||||
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
static constexpr int ne = I * J / 64;
|
||||
T x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
||||
return threadIdx.x % 16;
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return threadIdx.x % 16;
|
||||
} else if constexpr (I == 32 && J == 4) {
|
||||
return threadIdx.x % 32;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return 4 * (threadIdx.x / 16) + l;
|
||||
} else if constexpr (I == 32 && J == 32) {
|
||||
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
||||
return (2 * ((threadIdx.x / 16) % 2) + l);
|
||||
} else if constexpr (I == 16 && J == 8) {
|
||||
return 2 * (threadIdx.x / 16) + l;
|
||||
} else if constexpr (I == 32 && J == 4) {
|
||||
return 2 * (threadIdx.x / 32) + l;
|
||||
} else if constexpr (I == 16 && J == 16) {
|
||||
return threadIdx.x % 16;
|
||||
} else if constexpr (I == 32 && J == 32) {
|
||||
return threadIdx.x % 32;
|
||||
} else {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
#else
|
||||
static constexpr int ne = I * J / 32;
|
||||
T x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
@@ -94,6 +132,7 @@ namespace ggml_cuda_mma {
|
||||
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
||||
}
|
||||
}
|
||||
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
|
||||
};
|
||||
|
||||
template <int I_, int J_>
|
||||
@@ -148,10 +187,23 @@ namespace ggml_cuda_mma {
|
||||
|
||||
template <int I, int J, typename T>
|
||||
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
||||
}
|
||||
} else {
|
||||
int64_t * xi = (int64_t *) t.x;
|
||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||
xi[0] = xs[0];
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int l = 0; l < t.ne; ++l) {
|
||||
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
|
||||
}
|
||||
#endif // defined(AMD_MFMA_AVAILABLE)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -186,7 +238,7 @@ namespace ggml_cuda_mma {
|
||||
template <typename T>
|
||||
static __device__ __forceinline__ void load_ldmatrix(
|
||||
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
|
||||
#ifdef NEW_MMA_AVAILABLE
|
||||
#if defined(NEW_MMA_AVAILABLE)
|
||||
int * xi = (int * ) t.x;
|
||||
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
|
||||
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
|
||||
@@ -393,4 +445,60 @@ namespace ggml_cuda_mma {
|
||||
NO_DEVICE_CODE;
|
||||
#endif // NEW_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||
int32x4_t * acc = (int32x4_t *) D.x;
|
||||
#if defined(CDNA3)
|
||||
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
|
||||
((int64_t *) B.x)[0],
|
||||
acc[0],
|
||||
0, 0, 0);
|
||||
#elif defined(CDNA2) || defined(CDNA)
|
||||
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
|
||||
B.x[0],
|
||||
acc[0],
|
||||
0, 0, 0);
|
||||
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
|
||||
B.x[1],
|
||||
acc[0],
|
||||
0, 0, 0);
|
||||
#endif // defined(CDNA3)
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
|
||||
#if defined(AMD_MFMA_AVAILABLE)
|
||||
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
|
||||
int32x16_t * acc = (int32x16_t *) D.x;
|
||||
#if defined(CDNA3)
|
||||
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
|
||||
((int64_t *) B.x)[0],
|
||||
acc[0],
|
||||
0, 0, 0);
|
||||
#elif defined(CDNA2) || defined(CDNA)
|
||||
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
|
||||
B.x[0],
|
||||
acc[0],
|
||||
0, 0, 0);
|
||||
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
|
||||
B.x[1],
|
||||
acc[0],
|
||||
0, 0, 0);
|
||||
#endif // defined(CDNA3)
|
||||
#else
|
||||
GGML_UNUSED(D);
|
||||
GGML_UNUSED(A);
|
||||
GGML_UNUSED(B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user