mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-29 13:43:38 -04:00
CUDA: use mma PTX instructions for FlashAttention (#11583)
* CUDA: use mma PTX instructions for FlashAttention * __shfl_sync workaround for movmatrix * add __shfl_sync to HIP Co-authored-by: Diego Devesa <slarengh@gmail.com>
This commit is contained in:
@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
||||
nullptr;
|
||||
}
|
||||
|
||||
template<int D, int ncols, int KQ_stride> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
|
||||
|
||||
const int iter_k = ne11 / KQ_stride;
|
||||
const int iter_j = (ne01 + (ncols - 1)) / ncols;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
|
||||
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
|
||||
dst += jt*ncols*ne02*D + channel*D;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val[ncols] = {0.0f};
|
||||
float max_val[ncols] = {0.0f};
|
||||
float rowsum[ncols] = {0.0f};
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
break;
|
||||
}
|
||||
dst_val[j] = dst[j*ne02*D + threadIdx.x];
|
||||
|
||||
const float2 tmp = dst_fixup[bidx0*ncols + j];
|
||||
max_val[j] = tmp.x;
|
||||
rowsum[j] = tmp.y;
|
||||
}
|
||||
|
||||
// Iterate over previous blocks and compute the combined results.
|
||||
// All CUDA blocks that get here must have a previous block that needs a fixup.
|
||||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
continue;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
break;
|
||||
}
|
||||
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
|
||||
|
||||
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
|
||||
|
||||
// Scale the current and new value accumulators depending on the max. values.
|
||||
const float max_val_new = fmaxf(max_val[j], tmp.x);
|
||||
|
||||
const float diff_val = max_val[j] - max_val_new;
|
||||
const float diff_add = tmp.x - max_val_new;
|
||||
|
||||
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
|
||||
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
|
||||
|
||||
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
|
||||
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
|
||||
|
||||
max_val[j] = max_val_new;
|
||||
}
|
||||
|
||||
// If this block started in a previous tile we are done and don't need to combine additional partial results.
|
||||
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
|
||||
break;
|
||||
}
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
}
|
||||
|
||||
// Write back final result:
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
if (jt*ncols + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
|
||||
}
|
||||
}
|
||||
|
||||
template<int D, int parallel_blocks> // D == head size
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
|
||||
}
|
||||
}
|
||||
|
||||
template <int D, int parallel_blocks>
|
||||
// parallel_blocks == 0 is stream-k decomposition
|
||||
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
|
||||
void launch_fattn(
|
||||
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
||||
const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
|
||||
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
|
||||
) {
|
||||
const ggml_tensor * Q = dst->src[0];
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
@@ -603,20 +702,23 @@ void launch_fattn(
|
||||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
GGML_ASSERT(Q->ne[3] == 1);
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
|
||||
|
||||
ggml_cuda_pool_alloc<half> K_f16(pool);
|
||||
ggml_cuda_pool_alloc<half> V_f16(pool);
|
||||
ggml_cuda_pool_alloc<float> dst_tmp(pool);
|
||||
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
|
||||
|
||||
char * K_data = (char *) K->data;
|
||||
const char * K_data = (const char *) K->data;
|
||||
size_t nb11 = K->nb[1];
|
||||
size_t nb12 = K->nb[2];
|
||||
size_t nb13 = K->nb[3];
|
||||
|
||||
char * V_data = (char *) V->data;
|
||||
const char * V_data = (const char *) V->data;
|
||||
size_t nb21 = V->nb[1];
|
||||
size_t nb22 = V->nb[2];
|
||||
size_t nb23 = V->nb[3];
|
||||
@@ -649,39 +751,60 @@ void launch_fattn(
|
||||
nb23 = nb23*bs*sizeof(half)/ts;
|
||||
}
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||
}
|
||||
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
|
||||
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
|
||||
|
||||
const dim3 block_dim(WARP_SIZE, nwarps, 1);
|
||||
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
|
||||
const int shmem = 0;
|
||||
dim3 blocks_num;
|
||||
if (parallel_blocks == 0) {
|
||||
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
|
||||
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
|
||||
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
|
||||
const bool short_context = K->ne[1] < 4096;
|
||||
|
||||
const int nblocks_stream_k = 2*nsm;
|
||||
|
||||
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
|
||||
blocks_num.y = 1;
|
||||
blocks_num.z = 1;
|
||||
|
||||
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
|
||||
} else {
|
||||
blocks_num.x = parallel_blocks*ntiles_x;
|
||||
blocks_num.y = Q->ne[2];
|
||||
blocks_num.z = Q->ne[3];
|
||||
|
||||
if (parallel_blocks > 1) {
|
||||
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
|
||||
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
float logit_softcap = 0.0f;
|
||||
|
||||
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
|
||||
memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
||||
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
scale /= logit_softcap;
|
||||
}
|
||||
|
||||
const uint32_t n_head = Q->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
|
||||
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
||||
(const char *) Q->data,
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
@@ -693,16 +816,22 @@ void launch_fattn(
|
||||
);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if ((parallel_blocks) == 1) {
|
||||
return;
|
||||
if constexpr (parallel_blocks == 0) {
|
||||
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine = blocks_num;
|
||||
|
||||
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
}
|
||||
} else if constexpr (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
}
|
||||
|
||||
const dim3 block_dim_combine(D, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
|
||||
const int shmem_combine = 0;
|
||||
|
||||
flash_attn_combine_results<D, parallel_blocks>
|
||||
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
|
||||
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
Reference in New Issue
Block a user