Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-08-13 20:07:41 -04:00)
cuda: refactored ssm_scan and use CUB (#13291)
* cuda: refactored ssm_scan to use CUB
* fixed compilation error when not using CUB
* assign L to constant and use size_t instead of int
* deduplicated functions
* change min blocks per mp to 1
* use CUB load and store warp transpose
* suppress clang warning
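For context, the core of the change is the cub::BlockLoad / cub::BlockStore warp-transpose pattern: each thread keeps its per-state values in registers (a "blocked" arrangement) while the actual global-memory transactions stay coalesced. Below is a minimal standalone sketch of that pattern (illustrative only; the kernel name, block size and item count are assumptions, not taken from the patch):

#include <cub/cub.cuh>
#include <cuda_runtime.h>

constexpr int BLOCK_THREADS = 128;  // assumed block size (plays the role of splitD)
constexpr int ITEMS         = 16;   // assumed items per thread (plays the role of N)

// Hypothetical kernel: load a tile in warp-transposed order, scale it, store it back.
__global__ void scale_tile(const float * __restrict__ in, float * __restrict__ out, float factor) {
    using BlockLoad  = cub::BlockLoad <float, BLOCK_THREADS, ITEMS, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStore = cub::BlockStore<float, BLOCK_THREADS, ITEMS, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    // Load and store never run at the same time, so one shared allocation serves both.
    union TempStorage {
        typename BlockLoad::TempStorage  load;
        typename BlockStore::TempStorage store;
    };
    __shared__ TempStorage temp;

    const int block_offset = blockIdx.x * BLOCK_THREADS * ITEMS;

    float reg[ITEMS];                                   // per-thread register tile ("blocked" arrangement)
    BlockLoad(temp.load).Load(in + block_offset, reg);  // coalesced load, transposed within each warp

#pragma unroll
    for (int i = 0; i < ITEMS; ++i) {
        reg[i] *= factor;                               // stand-in for the per-state recurrence
    }

    __syncthreads();                                    // required before reusing the shared temp storage
    BlockStore(temp.store).Store(out + block_offset, reg);
}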
@@ -1,87 +1,117 @@
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

#ifdef USE_CUB
#include <cub/cub.cuh>
using namespace cub;
#endif // USE_CUB

#include "ssm-scan.cuh"

template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
// We would like to keep pragma unroll for cases where L_template is not 0,
// so we suppress the clang transformation warning.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
                 const int32_t * __restrict__ src6, float * __restrict__ dst,
                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
                 const int64_t s_off, const int64_t d_inner, const int64_t L) {
                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
    const size_t L = L_template == 0 ? L_param : L_template;
    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);

    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    const int bidx = blockIdx.x; // split along B (sequences)
    const int bidy = blockIdx.y; // split along D (d_inner)
    const int tid = threadIdx.x;
    const int wid = tid / 32;
    const int wtid = tid % 32;

    extern __shared__ float smem[];
    const int stride_sA = N + 1;
    const int stride_ss0 = N + 1;
    float * smem_A = smem;
    float * smem_s0 = smem_A + splitD * stride_sA;

    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
    const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
    const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
    const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
    const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
    const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
    float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
    float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);

    const int stride_s0 = src0_nb2 / sizeof(float);
    const int stride_x = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_A = src3_nb1 / sizeof(float);
    const int stride_B = src4_nb2 / sizeof(float);
    const int stride_C = src5_nb2 / sizeof(float);
    const int stride_s = stride_s0;
    const int stride_y = d_inner;

    // can N not be 16? for example 32?
    if (N == 16) {
    float regA[N];
    float regs0[N];

    __shared__ float smemB[N];
    __shared__ float smemC[N];

#ifdef USE_CUB
    using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    union CubTempStorage {
        typename BlockLoad::TempStorage load_temp;
        typename BlockStore::TempStorage store_temp;
    };
    __shared__ CubTempStorage cub_temp_storage;

    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
    const int stride_s0 = src0_nb2 / sizeof(float);
    const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
        for (size_t i = 0; i < splitD / 4; i += 2) {
            float value = A_block[(wid * warp_size + i) * stride_A + wtid];
            // todo: bank conflict
            // I am always confused with how to use the swizzling method to solve
            // bank conflit. Hoping somebody can tell me.
            smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
        }
#pragma unroll
        for (size_t i = 0; i < splitD / 4; i += 2) {
            float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
            smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
        }
    for (size_t n = 0; n < N; ++n)
    {
        regA[n] = A_block[threadIdx.x * stride_A + n];
        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
    }
#endif

    __syncthreads();

    for (int64_t i = 0; i < L; i++) {
        float dt_soft_plus = dt_block[i * stride_dt + tid];
        if (dt_soft_plus <= 20.0f) {
            dt_soft_plus = log1pf(exp(dt_soft_plus));
        }
        float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
        float sumf = 0.0f;
#pragma unroll
        for (size_t j = 0; j < N; j++) {
            float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
                          (B_block[i * stride_B + j] * x_dt);
            sumf += state * C_block[i * stride_C + j];
            if (i == L - 1) {
                s_block[tid * stride_s + j] = state;
            } else {
                smem_s0[tid * stride_ss0 + j] = state;
            }
    for (size_t i = 0; i < L; i++)
    {
        if (threadIdx.x < N)
        {
            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
        }
        __syncthreads();
        y_block[i * stride_y + tid] = sumf;

        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
        if (dt_soft_plus <= 20.0f)
        {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

        float sumf = 0.0f;
#pragma unroll
        for (size_t n = 0; n < N; n++)
        {
            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
            sumf += state * smemC[n];
            regs0[n] = state;
        }
        y_block[i * stride_y + threadIdx.x] = sumf;
    }

#ifdef USE_CUB
    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
    const int stride_s = stride_s0;
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        s_block[threadIdx.x * stride_s + n] = regs0[n];
    }
#endif
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__

// assumes as many threads as d_state
template <int splitH, int d_state>
@@ -201,11 +231,11 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    const int threads = 128;
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2
        if (d_state == 128) {
            const int threads = 128;
            GGML_ASSERT(d_state % threads == 0);
            // NOTE: can be any power of two between 4 and 64
            const int splitH = 16;
@@ -229,7 +259,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
            GGML_ABORT("doesn't support d_state!=(128 or 256).");
        }
    } else {
        const int threads = 128;
        // Mamba-1
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
@@ -237,10 +266,63 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
        if (d_state == 16) {
            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
                src0, src1, src2, src3, src4, src5, src6, dst,
            switch (n_tok)
            {
                case 1:
                    ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 2:
                    ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 3:
                    ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 4:
                    ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 5:
                    ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 6:
                    ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 7:
                    ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 8:
                    ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                default:
                    ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
            }
        } else {
            GGML_ABORT("doesn't support d_state!=16.");
        }
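The switch over n_tok in the final hunk is what the "assign L to constant" bullet refers to: the runtime token count is promoted to the compile-time L_template parameter so the per-token loop can be fully unrolled, with L_template == 0 kept as the runtime fallback. A minimal standalone sketch of that dispatch pattern (hypothetical kernel and launcher names, not the patch itself):

#include <cuda_runtime.h>

// Hypothetical kernel: L_template > 0 bakes the trip count into the binary so the
// loop below is fully unrolled; L_template == 0 falls back to the runtime length
// (the case where clang would emit the "-Wpass-failed" warning the patch suppresses).
template <int L_template>
__global__ void scan_sketch(const float * __restrict__ x, float * __restrict__ y, int L_param) {
    const int L = L_template == 0 ? L_param : L_template;
    float acc = 0.0f;
#pragma unroll
    for (int i = 0; i < L; ++i) {
        acc += x[i * blockDim.x + threadIdx.x];  // one value per token, strided by block width
    }
    y[threadIdx.x] = acc;
}

// Hypothetical launcher: pick a compile-time trip count for small token counts.
static void launch_scan_sketch(const float * x, float * y, int n_tok, cudaStream_t stream) {
    const int threads = 128;
    switch (n_tok) {
        case 1:  scan_sketch<1><<<1, threads, 0, stream>>>(x, y, n_tok); break;
        case 2:  scan_sketch<2><<<1, threads, 0, stream>>>(x, y, n_tok); break;
        default: scan_sketch<0><<<1, threads, 0, stream>>>(x, y, n_tok); break;  // runtime L
    }
}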