#pragma once

#include "common.cuh"
#include "convert.cuh"
#include "vecdotq.cuh"

#include

#define FATTN_KQ_STRIDE       256
#define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

// Signature shared by the flash-attention kernels launched through launch_fattn.
// dst receives either the final F32 output or, for a non-stream-k launch with parallel_blocks > 1,
// per-block partial results; dst_meta receives the matching (max, rowsum) float2 pairs that the
// fixup / combine kernels below consume.
typedef void (* fattn_kernel_t)(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        float      * __restrict__ dst,
        float2     * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33);

// Per-thread partial K·Q dot product for one row of K, accumulated as half or float.
typedef half (*vec_dot_KQ_f16_t)(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
typedef float (*vec_dot_KQ_f32_t)(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);

// vec_dot_fattn_vec_KQ_<type>: partial dot product between one row of K (K_c, stored in the type
// named by the function) and the current Q vector.
//  - Each thread handles the 32-bit words k_KQ = k_KQ_0 + threadIdx.x (k_KQ_0 striding by
//    warp_size) and returns only its own partial sum; no cross-thread reduction happens here.
//  - For quantized K, Q is expected pre-quantized: Q_q8 holds one int8x4 word per loop iteration
//    and Q_ds_v the matching (scale, sum) pairs as half2 (T == half) or float2 (otherwise),
//    presumably as written by quantize_q8_1_to_shared. Q_v is unused in that case.
//  - The half-precision path is only compiled when FP16_AVAILABLE is defined.

// q4_0: nibbles are stored with an implicit offset of 8. The offset is removed through the Q block
// sum; each of the QI8_1 lanes sharing one Q block subtracts 8/QI8_1 of it, so the correction is
// only complete once the caller has summed those lanes' partial results.
template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
    GGML_UNUSED(Q_v);

    T sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
        const int k_KQ = k_KQ_0 + threadIdx.x;

        const int ib    = k_KQ /  QI8_1;
        const int iqs4  = k_KQ %  QI4_0;
        const int shift = k_KQ & (QI8_1/2); // Upper half of a Q block reads the high nibbles.

        const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
        const int u = Q_q8[k_KQ_0/warp_size];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

#ifdef FP16_AVAILABLE
        if (std::is_same::value) {
            const half2  * Q_ds = (const half2  *) Q_ds_v;

            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
        } else
#endif // FP16_AVAILABLE
        {
            const float2 * Q_ds = (const float2 *) Q_ds_v;

            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
        }
    }

    return sum;
}

// q4_1: each K block carries dm = (scale, min). The min term is min_K * sum_Q, split as 1/QI8_1
// per lane of the Q block (complete only after the caller sums those lanes).
template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
    GGML_UNUSED(Q_v);

    T sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
        const int k_KQ = k_KQ_0 + threadIdx.x;

        const int ib    = k_KQ /  QI8_1;
        const int iqs4  = k_KQ %  QI4_1;
        const int shift = k_KQ & (QI8_1/2); // Upper half of a Q block reads the high nibbles.

        const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
        const int u = Q_q8[k_KQ_0/warp_size];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

#ifdef FP16_AVAILABLE
        if (std::is_same::value) {
            const half2 * Q_ds = (const half2 *) Q_ds_v;

            // (d_K*d_Q, m_K*s_Q) scaled elementwise by (sumi, 1/QI8_1), then both halves summed.
            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
            const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
            sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
        } else
#endif // FP16_AVAILABLE
        {
            const float2 * Q_ds = (const float2 *) Q_ds_v;

            const float sumid4d8   =  __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

            sum += (T) (sumid4d8 + m4s8scaled);
        }
    }

    return sum;
}

// q5_0: like q4_0 but each value gets a fifth bit from qh, and the implicit offset is 16
// (16/QI8_1 of the Q block sum is subtracted per lane).
template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
    GGML_UNUSED(Q_v);

    T sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
        const int k_KQ = k_KQ_0 + threadIdx.x;

        const int ib    = k_KQ /  QI8_1;
        const int iqs4  = k_KQ %  QI5_0;
        const int iqs8  = k_KQ %  QI8_1;
        const int shift = k_KQ & (QI8_1/2);

        int v = (get_int_b2(K_q5_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
        // Move the 4 high bits belonging to this word into bit 4 of each byte.
        const int vh = get_int_b2(K_q5_0[ib].qh, 0) >> (iqs8 * QI5_0);
        v |= (vh <<  4) & 0x00000010; // 0 ->  4
        v |= (vh << 11) & 0x00001000; // 1 -> 12
        v |= (vh << 18) & 0x00100000; // 2 -> 20
        v |= (vh << 25) & 0x10000000; // 3 -> 28

        const int u = Q_q8[k_KQ_0/warp_size];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

#ifdef FP16_AVAILABLE
        if (std::is_same::value) {
            const half2  * Q_ds = (const half2  *) Q_ds_v;

            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
            sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
        } else
#endif // FP16_AVAILABLE
        {
            const float2 * Q_ds = (const float2 *) Q_ds_v;

            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
        }
    }

    return sum;
}

// q5_1: like q4_1 (scale + min in dm) with the fifth bit of each value taken from qh.
template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
    GGML_UNUSED(Q_v);

    T sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
        const int k_KQ = k_KQ_0 + threadIdx.x;

        const int ib    = k_KQ /  QI8_1;
        const int iqs4  = k_KQ %  QI5_1;
        const int iqs8  = k_KQ %  QI8_1;
        const int shift = k_KQ & (QI8_1/2);

        int v = (get_int_b2(K_q5_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
        // Move the 4 high bits belonging to this word into bit 4 of each byte.
        const int vh = get_int_b2(K_q5_1[ib].qh, 0) >> (iqs8 * QI5_1);
        v |= (vh <<  4) & 0x00000010; // 0 ->  4
        v |= (vh << 11) & 0x00001000; // 1 -> 12
        v |= (vh << 18) & 0x00100000; // 2 -> 20
        v |= (vh << 25) & 0x10000000; // 3 -> 28

        const int u = Q_q8[k_KQ_0/warp_size];

        const int sumi = ggml_cuda_dp4a(v, u, 0);

#ifdef FP16_AVAILABLE
        if (std::is_same::value) {
            const half2 * Q_ds = (const half2 *) Q_ds_v;

            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
            const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
            sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
        } else
#endif // FP16_AVAILABLE
        {
            const float2 * Q_ds = (const float2 *) Q_ds_v;

            const float sumid5d8   =  __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;

            sum += (T) (sumid5d8 + m5s8scaled);
        }
    }

    return sum;
}

// q8_0: no offset/min term, so only the Q scale (low half of Q_ds) is needed; the product itself is
// delegated to vec_dot_q8_0_q8_1_impl.
template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

    const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
    GGML_UNUSED(Q_v);

    T sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += warp_size) {
        const int k_KQ = k_KQ_0 + threadIdx.x;

        const int ib  = k_KQ / QI8_0;
        const int iqs = k_KQ % QI8_0;

        const int v = get_int_b2(K_q8_0[ib].qs, iqs);

        T Q_d;
        if (std::is_same::value) {
            const half2  * Q_ds = (const half2  *) Q_ds_v;
            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
        } else {
            const float2 * Q_ds = (const float2 *) Q_ds_v;
            Q_d = Q_ds[k_KQ_0/warp_size].x;
        }

        sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
    }

    return sum;
}

// f16: K is read as half2 pairs and Q_v holds the unquantized Q, one half2 (T == half with
// FP16_AVAILABLE) or float2 (otherwise) per loop iteration. Q_q8 / Q_ds_v are unused.
template static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {

    const half2 * K_h2 = (const half2 *) K_c;
    GGML_UNUSED(Q_q8);
    GGML_UNUSED(Q_ds_v);

#ifdef FP16_AVAILABLE
    if (std::is_same::value) {
        const half2 * Q_h2 = (const half2 *) Q_v;

        half2 sum2 = make_half2(0.0f, 0.0f);

#pragma unroll
        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
            const int k_KQ = k_KQ_0 + threadIdx.x;

            const half2 K_ik = K_h2[k_KQ];
            sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
        }

        return __low2half(sum2) + __high2half(sum2);
    }
#endif // FP16_AVAILABLE

    const float2 * Q_f2 = (const float2 *) Q_v;

    float sum = 0.0f;

#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
        const int k_KQ = k_KQ_0 + threadIdx.x;

        const half2 K_ik = K_h2[k_KQ];
        sum +=  __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
    }

    return sum;
}

// Quantizes scale * x to q8_1 with one packed int8x4 word per thread.
//  - Thread t reads x[4*t .. 4*t+3] and writes its word to yq32[t].
//  - Groups of QI8_1 consecutive lanes share one scale d = amax/127 and one sum of the scaled
//    (pre-quantization) values, both reduced with XOR shuffles. The full 0xFFFFFFFF mask means
//    all 32 lanes of the warp must execute this call.
//  - The first lane of each group stores (d, sum) to yds[t/QI8_1] as half2 or float2
//    depending on Tds. An all-zero group yields zero quants.
template static __device__ __forceinline__ void quantize_q8_1_to_shared(
    const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {

    float vals[sizeof(int)] = {0.0f};
#pragma unroll
    for (int l = 0; l < int(sizeof(int)); ++l) {
        vals[l] = scale * x[4*threadIdx.x + l];
    }

    float amax = fabsf(vals[0]);
    float sum  = vals[0];
#pragma unroll
    for (int l = 1; l < int(sizeof(int)); ++l) {
        amax = fmaxf(amax, fabsf(vals[l]));
        sum += vals[l];
    }
    // Reduce amax and sum over the QI8_1 lanes of this group.
#pragma unroll
    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
        sum +=             __shfl_xor_sync(0xFFFFFFFF, sum,  mask, 32);
    }

    const float d = amax / 127;
    int q32 = 0;
    int8_t * q8 = (int8_t *) &q32;

    if (d != 0.0f) { // Avoid 0/0 for an all-zero group; q32 then stays 0.
#pragma unroll
        for (int l = 0; l < int(sizeof(int)); ++l) {
            q8[l] = roundf(vals[l] / d);
        }
    }

    yq32[threadIdx.x] = q32;
    if (threadIdx.x % QI8_1 == 0) {
        if (std::is_same::value) {
            ((half2  *) yds)[threadIdx.x/QI8_1] =  make_half2(d, sum);
        } else {
            ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
        }
    }
}

// (data pointer, flat element index) -> that element converted to half / float.
typedef half  (*dequantize_1_f16_t)(const void *, const int64_t);
typedef float (*dequantize_1_f32_t)(const void *, const int64_t);

// dequantize_1_<type>: returns element i of the data at vx, converted to T. The half path is only
// compiled with FP16_AVAILABLE; otherwise the float expression is used and converted to T.

// q4_0: value = d * (nibble - 8); the first half of a block is stored in the low nibbles.
template static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__ vx, const int64_t i) {
    const block_q4_0 * x = (const block_q4_0 *) vx;

    const int64_t ib    =  i          /  QK4_0;
    const int     iqs   =  i          % (QK4_0/2);
    const int     shift = (i % QK4_0) / (QK4_0/2);

    const T   d  = x[ib].d;
    const int q0 = x[ib].qs[iqs];
    const int q  = ((q0 >> (4*shift)) & 0x0F) - 8;

#ifdef FP16_AVAILABLE
    if (std::is_same::value) {
        return ((half) d)*((half) q);
    }
#endif // FP16_AVAILABLE

    return ((float) d)*((float) q);
}

// q4_1: value = dm.x * nibble + dm.y.
template static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__ vx, const int64_t i) {
    const block_q4_1 * x = (const block_q4_1 *) vx;

    const int64_t ib    =  i          /  QK4_1;
    const int     iqs   =  i          % (QK4_1/2);
    const int     shift = (i % QK4_1) / (QK4_1/2);

    const half2 dm = x[ib].dm;
    const int   q0 = x[ib].qs[iqs];
    const int   q  = ((q0 >> (4*shift)) & 0x0F);

#ifdef FP16_AVAILABLE
    if (std::is_same::value) {
        return __low2half(dm)*((half) q) + __high2half(dm);
    }
#endif // FP16_AVAILABLE

    return __low2float(dm)*((float) q) + __high2float(dm);
}

// q5_0: value = d * ((nibble | high bit from qh << 4) - 16).
template static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__ vx, const int64_t i) {
    const block_q5_0 * x = (const block_q5_0 *) vx;

    const int64_t ib    =  i          /  QK5_0;
    const int     idq   =  i          %  QK5_0;
    const int     iqs   =  i          % (QK5_0/2);
    const int     shift = (i % QK5_0) / (QK5_0/2);

    const T   d   = x[ib].d;
    const int ql0 = x[ib].qs[iqs];
    const int qh0 = get_int_b2(x[ib].qh, 0);
    const int ql  = ((ql0 >> (4*shift)) & 0x0F);
    const int qh  = ((qh0 >> idq) << 4) & 0x10;
    const int q   = (ql | qh) - 16;

#ifdef FP16_AVAILABLE
    if (std::is_same::value) {
        return ((half) d)*((half) q);
    }
#endif // FP16_AVAILABLE

    return ((float) d)*((float) q);
}

// q5_1: value = dm.x * (nibble | high bit from qh << 4) + dm.y.
template static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__ vx, const int64_t i) {
    const block_q5_1 * x = (const block_q5_1 *) vx;

    const int64_t ib    =  i          /  QK5_1;
    const int     idq   =  i          %  QK5_1;
    const int     iqs   =  i          % (QK5_1/2);
    const int     shift = (i % QK5_1) / (QK5_1/2);

    const half2 dm  = x[ib].dm;
    const int   ql0 = x[ib].qs[iqs];
    const int   qh0 = get_int_b4(x[ib].qh, 0);
    const int   ql  = ((ql0 >> (4*shift)) & 0x0F);
    const int   qh  = ((qh0 >> idq) << 4) & 0x10;
    const int   q   = (ql | qh);

#ifdef FP16_AVAILABLE
    if (std::is_same::value) {
        return __low2half(dm)*((half) q) + __high2half(dm);
    }
#endif // FP16_AVAILABLE

    return __low2float(dm)*((float) q) + __high2float(dm);
}

// q8_0: value = d * int8 quant.
template static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__ vx, const int64_t i) {
    const block_q8_0 * x = (const block_q8_0 *) vx;

    const int64_t ib  = i / QK8_0;
    const int     iqs = i % QK8_0;

    const T   d = x[ib].d;
    const int q = x[ib].qs[iqs];

#ifdef FP16_AVAILABLE
    if (std::is_same::value) {
        return ((half) d)*((half) q);
    }
#endif // FP16_AVAILABLE

    return ((float) d)*((float) q);
}

// f16: plain load, converted to T on return.
template static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ vx, const int64_t i) {
    const half * x = (const half *) vx;

    return x[i];
}

// Selectors: map a K / V ggml type to the matching helper above (usable in constant expressions);
// nullptr for types without a helper.
template constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 :
        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 :
        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 :
        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 :
        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 :
        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 :
        nullptr;
}

template constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0 :
        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1 :
        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0 :
        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1 :
        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0 :
        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16 :
        nullptr;
}

constexpr __device__ dequantize_1_f16_t get_dequantize_1_f16(ggml_type type_V) {
    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 :
        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 :
        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 :
        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 :
        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 :
        type_V == GGML_TYPE_F16 ? dequantize_1_f16 :
        nullptr;
}

constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
    return type_V == GGML_TYPE_Q4_0 ? dequantize_1_q4_0 :
        type_V == GGML_TYPE_Q4_1 ? dequantize_1_q4_1 :
        type_V == GGML_TYPE_Q5_0 ? dequantize_1_q5_0 :
        type_V == GGML_TYPE_Q5_1 ? dequantize_1_q5_1 :
        type_V == GGML_TYPE_Q8_0 ? dequantize_1_q8_0 :
        type_V == GGML_TYPE_F16 ? dequantize_1_f16 :
        nullptr;
}

// Stream-k fixup pass, launched by launch_fattn after a stream_k launch of the main kernel when the
// blocks did not each cover whole tiles.
//
// Launch layout (as set up by launch_fattn): gridDim.x = number of blocks of the main kernel,
// blockIdx.y = j in [0, ncols1), blockIdx.z = c in [0, ncols2), blockDim.x = D (one output element
// per thread).
//
// The main kernel's work is the flat range [0, W) with W = iter_k*iter_j*(ne02/ncols2)*ne03; each
// tile is iter_k consecutive units along K (ne11 / FATTN_KQ_STRIDE), and block b owned
// [b*W/gridDim.x, (b+1)*W/gridDim.x). Only a block whose range began in the middle of a tile and ran
// to that tile's end does work here. It takes the value the main kernel left in dst (unnormalized,
// since it is divided by rowsum below) with its (max, rowsum) pair from dst_fixup[bidx0*ncols + jc].
// It then walks back over the preceding blocks that contributed to the same tile, merges their
// partial values and (max, rowsum) pairs with max-rescaled weights, and writes the normalized
// result back to dst.
//
// dst_fixup layout: 2*gridDim.x*ncols float2 (max, rowsum) pairs (i.e. 2*2*gridDim.x*ncols floats),
// followed by gridDim.x*ncols*D float partial output values.
template // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
    constexpr int ncols = ncols1*ncols2;

    const int bidx0 = blockIdx.x;
    const int j     = blockIdx.y;
    const int c     = blockIdx.z;
    const int jc    = j*ncols2 + c;
    const int tid   = threadIdx.x;

    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);

    const int iter_k = ne11 / FATTN_KQ_STRIDE;
    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;

    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;

    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
        return;
    }

    // Decompose the flat work index into (sequence, head group, tile column).
    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.

    if (jt*ncols1 + j >= ne01) {
        return;
    }

    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;

    // Load the partial result that needs a fixup:
    float dst_val = 0.0f;
    float max_val = 0.0f;
    float rowsum  = 0.0f;
    {
        dst_val = *dst;

        const float2 tmp = dst_fixup[bidx0*ncols + jc];
        max_val = tmp.x;
        rowsum  = tmp.y;
    }

    // Iterate over previous blocks and compute the combined results.
    // All CUDA blocks that get here must have a previous block that needs a fixup.
    int bidx = bidx0 - 1;
    int kbc_stop = kbc0;
    while(true) {
        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
        if (kbc == kbc_stop) { // Did not have any data.
            bidx--;
            kbc_stop = kbc;
            continue;
        }

        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];

        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];

        // Scale the current and new value accumulators depending on the max. values.
        const float max_val_new = fmaxf(max_val, tmp.x);

        const float diff_val = max_val - max_val_new;
        const float diff_add = tmp.x   - max_val_new;

        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;

        dst_val = scale_val*dst_val + scale_add*dst_add;
        rowsum  = scale_val*rowsum  + scale_add*tmp.y;

        max_val = max_val_new;

        // If this block started in a previous tile we are done and don't need to combine additional partial results.
        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
            break;
        }
        bidx--;
        kbc_stop = kbc;
    }

    // Write back final result:
    *dst = dst_val / rowsum;
}

// Merges the parallel_blocks partial results produced by a non-stream-k launch with
// parallel_blocks > 1 (see launch_fattn).
//
// Launch layout: grid = (Q->ne[1], Q->ne[2], Q->ne[3]) i.e. (column, head, sequence), blockDim.x = D
// (one output element per thread), dynamic shared memory = parallel_blocks*sizeof(float2) for the
// (max, rowsum) pairs of one output row.
//
// For each output row, VKQ_parts holds parallel_blocks consecutive D-vectors and VKQ_meta the
// matching (max, rowsum) pairs. Partial l is weighted by exp(max_l - max over all l); weights whose
// exponent is <= SOFTMAX_FTZ_THRESHOLD are flushed to zero. The weighted sum of values is divided
// by the weighted sum of rowsums.
template // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(D, 1)
#endif // !defined(GGML_USE_HIP)
static __global__ void flash_attn_combine_results(
        const float  * __restrict__ VKQ_parts,
        const float2 * __restrict__ VKQ_meta,
        float * __restrict__ dst,
        const int parallel_blocks) {
    // Dimension 0: threadIdx.x
    // Dimension 1: blockIdx.x
    // Dimension 2: blockIdx.y
    // Dimension 3: blockIdx.z
    // Memory layout is permuted with [0, 2, 1, 3]

    const int ne01 = gridDim.x;
    const int ne02 = gridDim.y;

    const int col      = blockIdx.x;
    const int head     = blockIdx.y;
    const int sequence = blockIdx.z;

    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;

    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
    VKQ_meta  += j_dst_unrolled * parallel_blocks;
    dst       += j_dst_unrolled *                 D;

    const int tid = threadIdx.x;
    __builtin_assume(tid < D); // The launch uses exactly D threads per block.

    // Stage this row's (max, rowsum) pairs in shared memory, copied float by float.
    extern __shared__ float2 meta[];
    for (int i = tid; i < 2*parallel_blocks; i += D) {
        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
    }

    __syncthreads();

    float kqmax = meta[0].x;
    for (int l = 1; l < parallel_blocks; ++l) {
        kqmax = max(kqmax, meta[l].x);
    }

    float VKQ_numerator   = 0.0f;
    float VKQ_denominator = 0.0f;
    for (int l = 0; l < parallel_blocks; ++l) {
        const float diff = meta[l].x - kqmax;
        float KQ_max_scale = expf(diff);
        // Branchless flush-to-zero: clear all bits of the weight when diff <= threshold.
        const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
        *((uint32_t *) &KQ_max_scale) &= ftz_mask;

        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
        VKQ_denominator += KQ_max_scale * meta[l].y;
    }

    dst[tid] = VKQ_numerator / VKQ_denominator;
}

// Reports which K/V type combinations are available for head size D, then aborts (never returns).
[[noreturn]]
static void on_no_fattn_vec_case(const int D) {
    if (D == 64) {
        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
        fprintf(stderr, "By default only f16 KV cache is supported.\n");
        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
        GGML_ABORT("fatal error");
    } else if (D == 128) {
        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
        fprintf(stderr, "Supported combinations:\n");
        fprintf(stderr, " - K == q4_0, V == q4_0, 4.50 BPV\n");
        fprintf(stderr, " - K == q8_0, V == q8_0, 8.50 BPV\n");
        fprintf(stderr, " - K == f16, V == f16, 16.00 BPV\n");
        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
        GGML_ABORT("fatal error");
    } else {
        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
        fprintf(stderr, "Only f16 is supported.\n");
        GGML_ABORT("fatal error");
    }
}

// Host-side launcher shared by the flash-attention kernels.
//
// dst is the attention output node: dst->src[0] = Q (F32), src[1] = K, src[2] = V (may be null only
// when DV == 512, the "is_mla" case), src[3] = optional F16 mask padded to >= GGML_PAD(Q->ne[1], 16)
// rows. The output must be F32 and K->ne[1] a multiple of FATTN_KQ_STRIDE. scale, max_bias and
// logit_softcap are read from dst->op_params[0..2].
//
// Parameters:
//   fattn_kernel       - kernel to launch; blockDim = (warp_size, nwarps, 1)
//   nbytes_shared      - dynamic shared memory of fattn_kernel (also used for the occupancy query)
//   KQ_row_granularity - K rows per unit when splitting K across parallel blocks; must divide
//                        K->ne[1] (only used when stream_k is false)
//   need_f16_K/V       - convert a non-F16 K / V to F16 in pool memory before launching
//   stream_k           - true: split the work over a fixed number of blocks and repair partial tiles
//                        with flash_attn_stream_k_fixup; false: split K across parallel_blocks and
//                        merge the partials with flash_attn_combine_results
//   warp_size          - blockDim.x of the main kernel
// All conversions and kernels are enqueued on ctx.stream(); the function does not explicitly
// synchronize.
template
void launch_fattn(
    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
    const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
) {
    constexpr int ncols = ncols1 * ncols2;

    const bool is_mla = DV == 512; // TODO better parameterization

    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    GGML_ASSERT(V || is_mla);

    const ggml_tensor * mask = dst->src[3];

    ggml_tensor * KQV = dst;

    GGML_ASSERT(Q->type == GGML_TYPE_F32);
    GGML_ASSERT(KQV->type == GGML_TYPE_F32);

    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));

    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");

    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");

    ggml_cuda_pool & pool = ctx.pool();
    cudaStream_t main_stream = ctx.stream();
    const int id  = ggml_cuda_get_device();
    const int cc  = ggml_cuda_info().devices[id].cc;
    const int nsm = ggml_cuda_info().devices[id].nsm;

    ggml_cuda_pool_alloc K_f16(pool);
    ggml_cuda_pool_alloc V_f16(pool);
    ggml_cuda_pool_alloc dst_tmp(pool);
    ggml_cuda_pool_alloc dst_tmp_meta(pool);

    const char * K_data = (const char *) K->data;
    size_t nb11 = K->nb[1];
    size_t nb12 = K->nb[2];
    size_t nb13 = K->nb[3];

    // Without V (is_mla), the V strides mirror K's.
    const char * V_data = V ? (const char *) V->data : nullptr;
    size_t nb21 = V ? V->nb[1] : nb11;
    size_t nb22 = V ? V->nb[2] : nb12;
    size_t nb23 = V ? V->nb[3] : nb13;

    // Convert K to F16 if requested. Contiguous data is converted in one pass and the strides are
    // rescaled from quantized to F16 bytes; otherwise the strided converter writes a packed layout.
    if (need_f16_K && K->type != GGML_TYPE_F16) {
        const size_t bs = ggml_blck_size(K->type);
        const size_t ts = ggml_type_size(K->type);

        K_f16.alloc(ggml_nelements(K));
        if (ggml_is_contiguously_allocated(K)) {
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);

            nb11 = nb11*bs*sizeof(half)/ts;
            nb12 = nb12*bs*sizeof(half)/ts;
            nb13 = nb13*bs*sizeof(half)/ts;
        } else {
            GGML_ASSERT(K->nb[0] == ts);
            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
            const int64_t s01 = nb11 / ts;
            const int64_t s02 = nb12 / ts;
            const int64_t s03 = nb13 / ts;
            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);

            nb11 = K->ne[0] * sizeof(half);
            nb12 = K->ne[1] * nb11;
            nb13 = K->ne[2] * nb12;
        }
        K_data = (char *) K_f16.ptr;
    }

    // Same conversion for V (skipped when there is no V tensor).
    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
        const size_t bs = ggml_blck_size(V->type);
        const size_t ts = ggml_type_size(V->type);

        V_f16.alloc(ggml_nelements(V));
        if (ggml_is_contiguously_allocated(V)) {
            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);

            nb21 = nb21*bs*sizeof(half)/ts;
            nb22 = nb22*bs*sizeof(half)/ts;
            nb23 = nb23*bs*sizeof(half)/ts;
        } else {
            GGML_ASSERT(V->nb[0] == ts);
            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
            const int64_t s01 = nb21 / ts;
            const int64_t s02 = nb22 / ts;
            const int64_t s03 = nb23 / ts;
            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);

            nb21 = V->ne[0] * sizeof(half);
            nb22 = V->ne[1] * nb21;
            nb23 = V->ne[2] * nb22;
        }
        V_data = (char *) V_f16.ptr;
    }

    int parallel_blocks = 1;

    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];

    const dim3 block_dim(warp_size, nwarps, 1);
    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));

    dim3 blocks_num;
    if (stream_k) {
        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
        const int max_blocks = max_blocks_per_sm*nsm;
        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);

        const int nblocks_stream_k = max_blocks;

        // Stream-k is always used on Ada Lovelace and newer, otherwise only when whole-tile
        // scheduling would keep fewer than 75% of the block slots busy.
        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;

        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
        blocks_num.y = 1;
        blocks_num.z = 1;

        // Fixup scratch as read by flash_attn_stream_k_fixup: per block, 2*ncols float2 (max, rowsum)
        // pairs plus ncols*DV partial values, i.e. blocks_num.x*ncols*(2*2 + DV) floats in total.
        // dst_tmp_meta counts float2 elements (like every other alloc() here), so convert the float
        // count instead of passing a byte count, which used to request 8x the needed memory.
        const size_t nfloats_fixup = size_t(blocks_num.x)*ncols*(2*2 + DV);
        dst_tmp_meta.alloc((nfloats_fixup*sizeof(float) + sizeof(float2) - 1) / sizeof(float2));
    } else {
        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.

        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
        parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);

        // parallel_blocks must not be larger than what the tensor size allows:
        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);

        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
        // Test whether parallel_blocks can be set to a higher value for better efficiency.
        const int blocks_per_wave = nsm * max_blocks_per_sm;
        int nwaves_best = 0;
        int efficiency_percent_best = 0;
        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
            const int nblocks_total = ntiles_total * parallel_blocks_test;
            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);

            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
            if (efficiency_percent_best >= 90 && nwaves > nwaves_best) {
                break;
            }

            if (efficiency_percent > efficiency_percent_best) {
                nwaves_best = nwaves;
                efficiency_percent_best = efficiency_percent;
                parallel_blocks = parallel_blocks_test;
            }
        }

        blocks_num.x = ntiles_x;
        blocks_num.y = parallel_blocks;
        blocks_num.z = Q->ne[2]*Q->ne[3];

        // Partial results: parallel_blocks output-sized value buffers plus one (max, rowsum) pair per
        // output row and parallel block, merged by flash_attn_combine_results.
        if (parallel_blocks > 1) {
            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
        }
    }

    float scale         = 1.0f;
    float max_bias      = 0.0f;
    float logit_softcap = 0.0f;

    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (logit_softcap != 0.0f) {
        scale /= logit_softcap;
    }

    // n_head_log2 is the largest power of two <= n_head; m0 / m1 are the slope bases derived from
    // max_bias (presumably ALiBi-style, applied inside the kernel).
    const uint32_t n_head      = Q->ne[2];
    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    GGML_ASSERT(block_dim.x % warp_size == 0);

    fattn_kernel<<>>(
        (const char *) Q->data,
        K_data,
        V_data,
        mask ? ((const char *) mask->data) : nullptr,
        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
        nb21, nb22, nb23,
        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
    );
    CUDA_CHECK(cudaGetLastError());

    if (stream_k) {
        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
            const dim3 block_dim_combine(DV, 1, 1);
            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};

            flash_attn_stream_k_fixup
                <<>>
                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
        }
    } else if (parallel_blocks > 1) {
        const dim3 block_dim_combine(DV, 1, 1);
        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
        const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);

        flash_attn_combine_results
            <<>>
            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
    }
    CUDA_CHECK(cudaGetLastError());
}