Mirror of https://github.com/ggml-org/llama.cpp.git
CUDA: deduplicate FlashAttention code (#7352)
@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "softmax.cuh"
 
 template <typename T>
@@ -23,17 +24,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
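For context, the removed inline code fully determines what the shared helper has to compute: a slope of 1 when ALiBi is disabled, and otherwise a per-head power of m0 or m1 depending on whether the head index falls below n_head_log2. Below is a minimal sketch of a get_alibi_slope device function consistent with the deleted block; the actual definition lives in common.cuh and its exact signature and types are an assumption here.

// Sketch of the deduplicated ALiBi slope helper, reconstructed from the
// removed inline code above; the real definition in common.cuh may differ.
static __device__ __forceinline__ float get_alibi_slope(
        const float max_bias, const int h, const int n_head_log2, const float m0, const float m1) {
    if (max_bias <= 0.0f) {
        return 1.0f; // ALiBi disabled: a slope of 1 leaves the mask unscaled
    }
    // Heads below n_head_log2 use base m0 with exponents 1..n_head_log2;
    // the remaining heads use base m1 with odd exponents 1, 3, 5, ...
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
    return powf(base, exph);
}

The kernel then calls this once per row with the head index rowx/nrows_y, as in the added line of the diff, so the same slope computation can be shared between the softmax and FlashAttention kernels instead of being duplicated inline.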