Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-08-05 16:38:29 -04:00)
CUDA: use async data loading for FlashAttention (#11894)
* CUDA: use async data loading for FlashAttention

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
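For context on the commit title: "async data loading" refers to copying tiles (e.g. the K/V data in FlashAttention) from global memory into shared memory asynchronously with cp.async, so the copies can overlap with compute instead of staging through registers. The sketch below shows the general mechanism with CUDA's pipeline intrinsics; it is illustrative only, not the code from this commit, and the kernel name, buffer size and the trivial "compute" step are invented for the example.

// Minimal sketch of asynchronous global->shared loading (requires CUDA 11+;
// cp.async is hardware-accelerated on Ampere/sm_80 and newer and falls back
// to synchronous copies on older GPUs).
#include <cuda_pipeline_primitives.h>

__global__ void async_load_example(const float * __restrict__ src, float * __restrict__ dst) {
    __shared__ float buf[1024];

    // Issue asynchronous copies: data moves from global to shared memory
    // without passing through the register file, so independent work issued
    // by the same threads can overlap with the transfers.
    for (int i = threadIdx.x; i < 1024; i += blockDim.x) {
        __pipeline_memcpy_async(&buf[i], &src[blockIdx.x*1024 + i], sizeof(float));
    }
    __pipeline_commit();      // close the current batch of async copies
    __pipeline_wait_prior(0); // wait until that batch has landed in shared memory
    __syncthreads();

    // Placeholder "compute": write the staged tile back out.
    for (int i = threadIdx.x; i < 1024; i += blockDim.x) {
        dst[blockIdx.x*1024 + i] = buf[i];
    }
}

In a FlashAttention-style kernel the same primitives let the next tile stream into shared memory while the current tile is being multiplied, which is where the benefit of asynchronous loading comes from.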
@@ -7,6 +7,8 @@
 #include <climits>
 #include <cstdint>

+using namespace ggml_cuda_mma;
+
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
 #define MMQ_NWARPS 8
@@ -647,15 +649,15 @@ template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

-typedef mma_A_I16K8<int> mma_A;
-typedef mma_B_J8K8<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 8, int> tile_A;
+typedef tile< 8, 8, int> tile_B;
+typedef tile<16, 8, int> tile_C;

 constexpr int granularity = mmq_get_granularity_device(mmq_x);
 constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

 const int * x_qs = (const int *) x;
 const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -663,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 const float * y_df = (const float *) y;
 const half2 * y_ds = (const half2 *) y;

-mma_A A[ntx][WARP_SIZE/QI8_0];
-float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+tile_A A[ntx][WARP_SIZE/QI8_0];
+float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];

 const int i0 = (threadIdx.y/ntx)*rows_per_warp;

@@ -674,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
 const int k0 = k00 + k01;

-A[n][k01/QI8_0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
 }

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -691,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 }

 #pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-mma_B B;
-float dB[mma_C::ne/2];
+tile_B B;
+float dB[tile_C::ne/2];

-B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);

 if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -712,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-mma_C C;
-C.mma(A[n][k01/QI8_0], B);
+tile_C C;
+mma(C, A[n][k01/QI8_0], B);

 #pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
 }
 }
 }
@@ -758,23 +760,23 @@ template <int mmq_x, int mmq_y, int nwarps>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {

-typedef mma_A_I16K8<int> mma_A;
-typedef mma_B_J8K8<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 8, int> tile_A;
+typedef tile< 8, 8, int> tile_B;
+typedef tile<16, 8, int> tile_C;

 constexpr int granularity = mmq_get_granularity_device(mmq_x);
 constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);

 const int * x_qs = (const int *) x;
 const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
 const int * y_qs = (const int *) y + 4;
 const half2 * y_dm = (const half2 *) y;

-mma_A A[ntx][WARP_SIZE/QI8_1];
-float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+tile_A A[ntx][WARP_SIZE/QI8_1];
+float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];

 const int i0 = (threadIdx.y/ntx)*rows_per_warp;

@@ -784,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
 const int k0 = k00 + k01;

-A[n][k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
 }

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);

 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -801,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 }

 #pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-mma_B B;
-float2 dsB[mma_C::ne/2];
+tile_B B;
+float2 dsB[tile_C::ne/2];

-B.load_generic(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);

 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
 }

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-mma_C C;
-C.mma(A[n][k01/QI8_1], B);
+tile_C C;
+mma(C, A[n][k01/QI8_1], B);

 #pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
 }
 }
 }
@@ -868,26 +870,26 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE

-typedef mma_A_I16K4<int> mma_A;
-typedef mma_A_I16K8<int> mma_A_K8;
-typedef mma_B_J8K4<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 4, int> tile_A;
+typedef tile<16, 8, int> tile_A_8;
+typedef tile< 8, 4, int> tile_B;
+typedef tile<16, 8, int> tile_C;

 constexpr int granularity = mmq_get_granularity_device(mmq_x);
 constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

 const int * x_qs = (const int *) x;
 const float * x_df = (const float *) x_qs + WARP_SIZE*2;
 const int * y_qs = (const int *) y + 4;
 const float * y_df = (const float *) y;

-const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

-mma_A A[ntx][8];
-float dA[ntx][mma_C::ne/2][8];
+tile_A A[ntx][8];
+float dA[ntx][tile_C::ne/2][8];

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
@@ -895,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
 const int k0 = k00 + k01;

-((mma_A_K8 *) A[n])[k01/8].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
 }

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -912,32 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 }

 #pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
-mma_B B[2];
-float dB[mma_C::ne/2];
+tile_B B[2];
+float dB[tile_C::ne/2];

 // Here load_generic is faster than load_ldmatrix.
-B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
-B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);

 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
 }

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-mma_C C[2];
-C[0].mma(A[n][k01/4 + 0], B[0]);
-C[1].mma(A[n][k01/4 + 1], B[1]);
+tile_C C[2];
+mma(C[0], A[n][k01/4 + 0], B[0]);
+mma(C[1], A[n][k01/4 + 1], B[1]);

 #pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
 }
 }
 }
@@ -1056,27 +1058,27 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE

-typedef mma_A_I16K4<int> mma_A;
-typedef mma_A_I16K8<int> mma_A_K8;
-typedef mma_B_J8K4<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 4, int> tile_A;
+typedef tile<16, 8, int> tile_A_8;
+typedef tile< 8, 4, int> tile_B;
+typedef tile<16, 8, int> tile_C;

 constexpr int granularity = mmq_get_granularity_device(mmq_x);
 constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

 const int * x_qs = (const int *) x;
 const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
 const int * y_qs = (const int *) y + 4;
 const half2 * y_ds = (const half2 *) y;

-const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

-mma_A A[ntx][8];
-float dA[ntx][mma_C::ne/2][8];
-float mA[ntx][mma_C::ne/2][8];
+tile_A A[ntx][8];
+float dA[ntx][tile_C::ne/2][8];
+float mA[ntx][tile_C::ne/2][8];

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
@@ -1084,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
 const int k0 = k00 + k01;

-((mma_A_K8 *) A[n])[k01/QI8_1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
 }
 }

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1107,58 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 }

 #pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-float2 dB[mma_C::ne/2];
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+float2 dB[tile_C::ne/2];

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);

 dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
 }

 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-mma_B B[2];
+tile_B B[2];

 // Here load_generic is faster than load_ldmatrix.
-B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
-B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);

-mma_C Cm[2];
+tile_C Cm[2];
 if (k01 >= WARP_SIZE * 3/4) {
-mma_A A1;
+tile_A A1;
 A1.x[0] = 0x01010101;
 A1.x[1] = 0x01010101;
-Cm[0].mma(A1, B[0]);
-Cm[1].mma(A1, B[1]);
+mma(Cm[0], A1, B[0]);
+mma(Cm[1], A1, B[1]);
 }

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-mma_C Cd[2];
+tile_C Cd[2];

-Cd[0].mma(A[n][k01/4 + 0], B[0]);
-Cd[1].mma(A[n][k01/4 + 1], B[1]);
+mma(Cd[0], A[n][k01/4 + 0], B[0]);
+mma(Cd[1], A[n][k01/4 + 1], B[1]);

 #pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
+for (int l = 0; l < tile_C::ne; ++l) {
 float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
 if (k01 >= WARP_SIZE * 3/4) {
 tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
 }
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
 }
 }
 }

 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
-float2 sB[mma_C::ne/2];
+float2 sB[tile_C::ne/2];

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);

 sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
 }
@@ -1166,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
-sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
 }
 }
 }
@@ -1708,15 +1710,15 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 #ifdef NEW_MMA_AVAILABLE

-typedef mma_A_I16K4<int> mma_A;
-typedef mma_B_J8K4<int> mma_B;
-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 4, int> tile_A;
+typedef tile< 8, 4, int> tile_B;
+typedef tile<16, 8, int> tile_C;

 constexpr int granularity = mmq_get_granularity_device(mmq_x);
 constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);

 const int * x_qs = (const int *) x;
 const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1724,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 const int * y_qs = (const int *) y + 4;
 const float * y_df = (const float *) y;

-const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);

-mma_A A[ntx][8];
-int scA[ntx][mma_C::ne/2][8];
-float dA[ntx][mma_C::ne/2];
+tile_A A[ntx][8];
+int scA[ntx][tile_C::ne/2][8];
+float dA[ntx][tile_C::ne/2];

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
@@ -1736,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
 const int k0 = k00 + k01;

-A[n][k01/4 + 0].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
-A[n][k01/4 + 1].load_ldmatrix(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
 }

 #pragma unroll
@@ -1745,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 const int k0 = k00 + k01;

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

 const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
 const int8_t * sc = (const int8_t *) &sc_packed;
@@ -1759,41 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 }

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);

 dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
 }
 }

 #pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-float tmp[ntx][mma_C::ne] = {{0.0f}};
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+float tmp[ntx][tile_C::ne] = {{0.0f}};

 #pragma unroll
 for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-mma_B B[2];
-float dB[mma_C::ne/2];
+tile_B B[2];
+float dB[tile_C::ne/2];

 // Here load_generic is faster than load_ldmatrix.
-B[0].load_generic(y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
-B[1].load_generic(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
+load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);

 #pragma unroll
-for (int l = 0; l < mma_C::ne/2; ++l) {
-const int j = j0 + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne/2; ++l) {
+const int j = j0 + tile_C::get_j(l);

 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
 }

 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
-mma_C C[2];
-C[0].mma(A[n][k01/4 + 0], B[0]);
-C[1].mma(A[n][k01/4 + 1], B[1]);
+tile_C C[2];
+mma(C[0], A[n][k01/4 + 0], B[0]);
+mma(C[1], A[n][k01/4 + 1], B[1]);

 #pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
+for (int l = 0; l < tile_C::ne; ++l) {
 tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
 }
 }
@@ -1802,8 +1804,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+for (int l = 0; l < tile_C::ne; ++l) {
+sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
 }
 }
 }
@@ -2312,36 +2314,36 @@ template<int mmq_x, int mmq_y, int nwarps, bool need_check>
 static __device__ __forceinline__ void mmq_write_back_mma(
 const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {

-typedef mma_C_I16J8<int> mma_C;
+typedef tile<16, 8, int> tile_C;

 constexpr int granularity = mmq_get_granularity_device(mmq_x);
 constexpr int rows_per_warp = 2 * granularity;
-constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

-const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
+const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
 #ifdef NEW_MMA_AVAILABLE
-static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
+static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
 #endif // NEW_MMA_AVAILABLE

 #pragma unroll
-for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
 for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-for (int l = 0; l < mma_C::ne; ++l) {
-const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+for (int l = 0; l < tile_C::ne; ++l) {
+const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);

 if (j > j_max) {
 continue;
 }

-const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+const int i = i0 + n*tile_C::I + tile_C::get_i(l);

 if (need_check && i > i_max) {
 continue;
 }

-dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
 }
 }
 }
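Aside from the FlashAttention work itself, the MMQ hunks above are largely a mechanical port of the tensor core helpers: the per-shape mma_A_I16K8<int> / mma_B_J8K8<int> / mma_C_I16J8<int> typedefs become instances of a single ggml_cuda_mma::tile<I, J, T> template, and member calls such as A.load_ldmatrix(...), B.load_generic(...) and C.mma(A, B) become free functions load_ldmatrix(A, ...), load_generic(B, ...) and mma(C, A, B). The toy stand-in below only illustrates that call shape; the real definitions live in llama.cpp's mma.cuh, operate on per-thread tensor core fragments, and are not reproduced here (everything in this snippet is a simplified assumption).

// Toy, host-side stand-in for the tile interface used in the diff above.
// Not the real ggml_cuda_mma implementation: ne is the full tile size here,
// whereas the real tiles hold per-thread fragments and use ldmatrix/mma PTX.
#include <cstdio>

template <int I_, int J_, typename T>
struct tile {
    static constexpr int I  = I_;
    static constexpr int J  = J_;
    static constexpr int ne = I_*J_;
    T x[ne] = {0};
};

// Free functions taking the tile as the first argument (the new call style).
template <int I, int J, typename T>
static void load_generic(tile<I, J, T> & t, const T * src, int stride) {
    for (int i = 0; i < I; ++i)
        for (int j = 0; j < J; ++j)
            t.x[i*J + j] = src[i*stride + j];
}

template <int I, int J, int K, typename T>
static void mma(tile<I, J, T> & C, const tile<I, K, T> & A, const tile<J, K, T> & B) {
    for (int i = 0; i < I; ++i)       // C += A * B^T over the shared K dimension
        for (int j = 0; j < J; ++j)
            for (int k = 0; k < K; ++k)
                C.x[i*J + j] += A.x[i*K + k] * B.x[j*K + k];
}

int main() {
    tile<16, 8, int> A;   // corresponds to tile_A in the diff
    tile< 8, 8, int> B;   // corresponds to tile_B
    tile<16, 8, int> C;   // corresponds to tile_C (the accumulator)

    int a[16*8], b[8*8];
    for (int i = 0; i < 16*8; ++i) a[i] = 1;
    for (int i = 0; i <  8*8; ++i) b[i] = 2;

    load_generic(A, a, 8);   // old style: A.load_generic(a, 8)
    load_generic(B, b, 8);
    mma(C, A, B);            // old style: C.mma(A, B)

    printf("C.x[0] = %d\n", C.x[0]); // 1*2 summed over K=8 -> 16
    return 0;
}

The visible effect in the diff is that all fragment shapes are now spelled with one template (tile<16, 8, int>, tile<8, 4, int>, ...) instead of one named class per shape, and the load/mma entry points no longer depend on which shape they operate on.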