CUDA: use async data loading for FlashAttention (#11894)

* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
Authored by Johannes Gäßler on 2025-02-17 14:03:24 +01:00, committed by GitHub
parent f7b1116af1
commit 73e2ed3ce3
6 changed files with 744 additions and 739 deletions
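The hunks below all apply one mechanical refactor to the MMQ kernels: the per-shape fragment typedefs (mma_A_I16K8<int>, mma_B_J8K4<int>, mma_C_I16J8<int>, ...) are replaced by a single tile<I, J, int> template from namespace ggml_cuda_mma, and the member calls load_ldmatrix(), load_generic() and mma() become free functions that take the tile as their first argument. The following stand-alone sketch mirrors only that calling convention; the tile contents and the load/mma bodies are placeholders, not the real ggml_cuda_mma header.

// Illustrative sketch only: a toy stand-in for the ggml_cuda_mma API change applied in the
// hunks below. The real tile<> wraps tensor-core fragments and issues ldmatrix/mma.sync;
// here the bodies are placeholders so only the calling convention is shown.
template <int I_, int J_, typename T>
struct tile {
    static constexpr int I  = I_;   // rows covered by the fragment
    static constexpr int J  = J_;   // columns covered by the fragment
    static constexpr int ne = 4;    // elements held per thread; fixed here for the sketch only
    T x[ne] = {};
};

typedef tile<16, 8, int> tile_A;    // was: typedef mma_A_I16K8<int> mma_A;
typedef tile< 8, 8, int> tile_B;    // was: typedef mma_B_J8K8<int>  mma_B;
typedef tile<16, 8, int> tile_C;    // was: typedef mma_C_I16J8<int> mma_C;

// Free functions replace the old member calls A.load_ldmatrix(...) and C.mma(A, B):
template <int I, int J, typename T>
static __device__ void load_ldmatrix(tile<I, J, T> & t, const int * src, int stride) {
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = src[l*stride];     // placeholder for the ldmatrix/ldsm path
    }
}

static __device__ void mma(tile_C & C, const tile_A & A, const tile_B & B) {
    for (int l = 0; l < tile_C::ne; ++l) {
        C.x[l] += A.x[l]*B.x[l];    // placeholder for mma.sync on the real fragments
    }
}

__global__ void tile_api_demo(const int * x, const int * y, int * dst) {
    tile_A A;
    tile_B B;
    tile_C C;
    load_ldmatrix(A, x, 1);         // old style: A.load_ldmatrix(x, 1);
    load_ldmatrix(B, y, 1);         // old style: B.load_ldmatrix(y, 1);
    mma(C, A, B);                   // old style: C.mma(A, B);
    dst[threadIdx.x] = C.x[0];
}

The reinterpret-casts such as ((tile_A_8 *) A[n])[k01/8] in the hunks load two 16x4 tiles as one 16x8 tile, exactly as the old mma_A_K8 cast did; only the spelling of the types and calls changes.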


@@ -7,6 +7,8 @@
#include <climits>
#include <cstdint>
+using namespace ggml_cuda_mma;
#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
#define MMQ_ITER_K 256
#define MMQ_NWARPS 8
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-typedef mma_A_I16K8<int> mma_A;
-typedef mma_B_J8K8<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 8, int> tile_A;
+typedef tile< 8, 8, int> tile_B;
+typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const float * y_df = (const float *) y;
const half2 * y_ds = (const half2 *) y;
-mma_A A[ntx][WARP_SIZE/QI8_0];
-float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+tile_A A[ntx][WARP_SIZE/QI8_0];
+float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
const int k0 = k00 + k01;
-A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
}
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
}
#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-mma_B B;
-float dB[mma_C::ne/2];
+tile_B B;
+float dB[tile_C::ne/2];
-B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);
if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
-mma_C C;
-C.mma(A[n][k01/QI8_0], B);
+tile_C C;
+mma(C, A[n][k01/QI8_0], B);
#pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
}
}
}
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-typedef mma_A_I16K8<int> mma_A;
-typedef mma_B_J8K8<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 8, int> tile_A;
+typedef tile< 8, 8, int> tile_B;
+typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
const int * y_qs = (const int *) y + 4;
const half2 * y_dm = (const half2 *) y;
-mma_A A[ntx][WARP_SIZE/QI8_1];
-float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+tile_A A[ntx][WARP_SIZE/QI8_1];
+float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
const int i0 = (threadIdx.y/ntx)*rows_per_warp;
@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
-A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
}
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
}
#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-mma_B B;
-float2 dsB[mma_C::ne/2];
+tile_B B;
+float2 dsB[tile_C::ne/2];
-B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);
dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
-mma_C C;
-C.mma(A[n][k01/QI8_1], B);
+tile_C C;
+mma(C, A[n][k01/QI8_1], B);
#pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
}
}
}
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
-typedef mma_A_I16K4<int> mma_A;
-typedef mma_A_I16K8<int> mma_A_K8;
-typedef mma_B_J8K4<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 4, int> tile_A;
+typedef tile<16, 8, int> tile_A_8;
+typedef tile< 8, 4, int> tile_B;
+typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
-const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
-mma_A A[ntx][8];
-float dA[ntx][mma_C::ne/2][8];
+tile_A A[ntx][8];
+float dA[ntx][tile_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
const int k0 = k00 + k01;
-((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
}
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
-mma_B B[2];
-float dB[mma_C::ne/2];
+tile_B B[2];
+float dB[tile_C::ne/2];
// Here load_generic is faster than load_ldmatrix.
-B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
-B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
-mma_C C[2];
-C[0].mma(A[n][k01/4 + 0], B[0]);
-C[1].mma(A[n][k01/4 + 1], B[1]);
+tile_C C[2];
+mma(C[0], A[n][k01/4 + 0], B[0]);
+mma(C[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
}
}
}
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
-typedef mma_A_I16K4<int> mma_A;
-typedef mma_A_I16K8<int> mma_A_K8;
-typedef mma_B_J8K4<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 4, int> tile_A;
+typedef tile<16, 8, int> tile_A_8;
+typedef tile< 8, 4, int> tile_B;
+typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
const int * y_qs = (const int *) y + 4;
const half2 * y_ds = (const half2 *) y;
-const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
-mma_A A[ntx][8];
-float dA[ntx][mma_C::ne/2][8];
-float mA[ntx][mma_C::ne/2][8];
+tile_A A[ntx][8];
+float dA[ntx][tile_C::ne/2][8];
+float mA[ntx][tile_C::ne/2][8];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
const int k0 = k00 + k01;
-((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
}
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-float2 dB[mma_C::ne/2];
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+float2 dB[tile_C::ne/2];
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);
dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-mma_B B[2];
+tile_B B[2];
// Here load_generic is faster than load_ldmatrix.
-B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
-B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
-mma_C Cm[2];
+tile_C Cm[2];
if (k01 >= WARP_SIZE * 3/4) {
-mma_A A1;
+tile_A A1;
A1.x[0] = 0x01010101;
A1.x[1] = 0x01010101;
-Cm[0].mma(A1, B[0]);
-Cm[1].mma(A1, B[1]);
+mma(Cm[0], A1, B[0]);
+mma(Cm[1], A1, B[1]);
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
-mma_C Cd[2];
+tile_C Cd[2];
-Cd[0].mma(A[n][k01/4 + 0], B[0]);
-Cd[1].mma(A[n][k01/4 + 1], B[1]);
+mma(Cd[0], A[n][k01/4 + 0], B[0]);
+mma(Cd[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
+for (int l = 0; l < tile_C::ne; ++l) {
float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
if (k01 >= WARP_SIZE * 3/4) {
tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
}
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
}
}
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
-float2 sB[mma_C::ne/2];
+float2 sB[tile_C::ne/2];
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);
sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
-sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
}
}
}
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
#ifdef NEW_MMA_AVAILABLE
-typedef mma_A_I16K4<int> mma_A;
-typedef mma_B_J8K4<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 4, int> tile_A;
+typedef tile< 8, 4, int> tile_B;
+typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
const int * x_qs = (const int *) x;
const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * y_qs = (const int *) y + 4;
const float * y_df = (const float *) y;
-const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
-mma_A A[ntx][8];
-int scA[ntx][mma_C::ne/2][8];
-float dA[ntx][mma_C::ne/2];
+tile_A A[ntx][8];
+int scA[ntx][tile_C::ne/2][8];
+float dA[ntx][tile_C::ne/2];
#pragma unroll
for (int n = 0; n < ntx; ++n) {
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
const int k0 = k00 + k01;
-A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
-A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
}
#pragma unroll
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int k0 = k00 + k01;
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
const int8_t * sc = (const int8_t *) &sc_packed;
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
}
}
#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-float tmp[ntx][mma_C::ne] = {{0.0f}};
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+float tmp[ntx][tile_C::ne] = {{0.0f}};
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-mma_B B[2];
-float dB[mma_C::ne/2];
+tile_B B[2];
+float dB[tile_C::ne/2];
// Here load_generic is faster than load_ldmatrix.
-B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
-B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
+load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
#pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);
dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
}
#pragma unroll
for (int n = 0; n < ntx; ++n) {
-mma_C C[2];
-C[0].mma(A[n][k01/4 + 0], B[0]);
-C[1].mma(A[n][k01/4 + 1], B[1]);
+tile_C C[2];
+mma(C[0], A[n][k01/4 + 0], B[0]);
+mma(C[1], A[n][k01/4 + 1], B[1]);
#pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
+for (int l = 0; l < tile_C::ne; ++l) {
tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
}
}
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
}
}
}
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
-const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
#ifdef NEW_MMA_AVAILABLE
-static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
#endif // NEW_MMA_AVAILABLE
#pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
#pragma unroll
for (int n = 0; n < ntx; ++n) {
#pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne; ++l) {
+const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
if (j > j_max) {
continue;
}
-const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+const int i = i0 + n*tile_C::I + tile_C::get_i(l);
if (need_check && i > i_max) {
continue;
}
-dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
}
}
}
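
Only the ggml_cuda_mma tile refactor of the MMQ kernels is shown above; the asynchronous data loading named in the commit title lives in the FlashAttention kernels, which are not part of this excerpt. As a rough, generic illustration of what asynchronous loading means on Ampere and newer GPUs, the sketch below double-buffers a tile in shared memory with the CUDA toolkit's __pipeline_memcpy_async primitives so the copy of the next tile overlaps with compute on the current one. This is a textbook cp.async pattern, not the code added by this commit (which may use raw cp.async PTX instead), and all names in it are made up for the example.

// Generic Ampere+ async-copy sketch (illustrative only, not this commit's code).
// Compile for sm_80 or newer, e.g.: nvcc -arch=sm_80 async_load_demo.cu
#include <cuda_pipeline.h>

#define TILE_INTS 256   // ints staged per iteration; size chosen only for the example

__global__ void async_load_demo(const int * __restrict__ src, int * __restrict__ dst, int niter) {
    __shared__ int buf[2][TILE_INTS];   // double buffer: copy into one half, compute on the other

    int acc = 0;

    // Prefetch tile 0 before entering the loop.
    for (int i = threadIdx.x; i < TILE_INTS; i += blockDim.x) {
        __pipeline_memcpy_async(&buf[0][i], &src[i], sizeof(int));
    }
    __pipeline_commit();

    for (int iter = 0; iter < niter; ++iter) {
        const int cur = iter       % 2;
        const int nxt = (iter + 1) % 2;

        // Kick off the copy of the next tile; it runs while the loop below computes.
        if (iter + 1 < niter) {
            const int * src_next = src + (size_t)(iter + 1)*TILE_INTS;
            for (int i = threadIdx.x; i < TILE_INTS; i += blockDim.x) {
                __pipeline_memcpy_async(&buf[nxt][i], &src_next[i], sizeof(int));
            }
        }
        __pipeline_commit();

        // Wait until only the newest commit group can still be in flight, i.e. the
        // current tile's data has landed in shared memory, then consume it.
        __pipeline_wait_prior(1);
        __syncthreads();
        for (int i = threadIdx.x; i < TILE_INTS; i += blockDim.x) {
            acc += buf[cur][i];     // stand-in for the real per-tile math
        }
        __syncthreads();            // everyone is done with buf[cur] before it is overwritten
    }

    __pipeline_wait_prior(0);
    dst[blockIdx.x*blockDim.x + threadIdx.x] = acc;
}

cp.async requires compute capability 8.0, so code like this is normally guarded by an architecture check with a synchronous fallback for older GPUs, much as the hunks above guard the tensor-core path behind NEW_MMA_AVAILABLE.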