Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-06-29 04:35:05 +00:00
vulkan: scalar flash attention implementation (#13324)
* vulkan: scalar flash attention implementation
* vulkan: always use fp32 for scalar flash attention
* vulkan: use vector loads in scalar flash attention shader
* vulkan: remove PV matrix, helps with register usage
* vulkan: reduce register usage in scalar FA, but perf may be slightly worse
* vulkan: load each Q value once. optimize O reduction. more tuning
* vulkan: support q4_0/q8_0 KV in scalar FA
* CI: increase timeout to accommodate newly-supported tests
* vulkan: for scalar FA, select between 1 and 8 rows
* vulkan: avoid using Float16 capability in scalar FA
@@ -275,6 +275,7 @@ struct vk_device_struct {
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
     bool subgroup_add;
+    bool subgroup_shuffle;

     bool integer_dot_product;

@@ -402,12 +403,20 @@ struct vk_device_struct {
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;

+    // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
+    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
+
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];

     vk_pipeline pipeline_flash_attn_split_k_reduce;

     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
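The [GGML_TYPE_COUNT][2][2][2] arrays above encode four selection axes, spelled out in the comment. A minimal sketch of that indexing convention (the stub type and helper are ours, for illustration only; the patch indexes these arrays directly, as the ggml_vk_flash_attn hunk further down shows):

#include <cstddef>

struct vk_pipeline_stub {};  // stand-in for vk_pipeline; helper name is hypothetical

// table[type][acc][rows][align] mirrors the declarations above:
//   acc   0 = f16 accumulator, 1 = f32 accumulator
//   rows  0 = large-row variant, 1 = small-row variant
//   align 0 = unaligned, 1 = KV length aligned to the column tile
template <std::size_t NTYPES>
vk_pipeline_stub & pick_fa_pipeline(vk_pipeline_stub (&table)[NTYPES][2][2][2],
                                    std::size_t kv_type, bool f32acc, bool small_rows, bool aligned) {
    return table[kv_type][f32acc][small_rows][aligned];
}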
@@ -1581,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events

 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
-static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
+static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
+
+static uint32_t get_fa_num_small_rows(bool scalar) {
+    return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
+}
+
+static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);

+    if (scalar) {
+        if (small_rows) {
+            return {scalar_flash_attention_num_small_rows, 64};
+        } else {
+            return {scalar_flash_attention_num_large_rows, 32};
+        }
+    }
+
     // small rows, large cols
     if (small_rows) {
-        return {flash_attention_num_small_rows, 64};
+        return {get_fa_num_small_rows(scalar), 32};
     }

     // small cols to reduce register count
     if (ggml_is_quantized(type) || D == 256) {
         return {64, 32};
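For readers following fa_rows_cols: each pair is {Q rows per workgroup, KV columns per iteration}. A standalone restatement of the selection rules, runnable without a Vulkan device (the quantized flag stands in for ggml_is_quantized(type); the final fall-through value is an assumption, since the hunk ends before it):

#include <array>
#include <cstdint>
#include <cstdio>

static std::array<uint32_t, 2> fa_rows_cols_sketch(bool scalar, uint32_t D, bool quantized, bool small_rows) {
    if (scalar) {
        if (small_rows) {
            return {1, 64};   // scalar_flash_attention_num_small_rows
        }
        return {8, 32};       // scalar_flash_attention_num_large_rows
    }
    if (small_rows) {
        return {32, 32};      // flash_attention_num_small_rows
    }
    if (quantized || D == 256) {
        return {64, 32};      // small cols to reduce register count
    }
    return {64, 64};          // assumed default; the hunk ends before this branch
}

int main() {
    auto t = fa_rows_cols_sketch(/*scalar=*/true, /*D=*/128, /*quantized=*/false, /*small_rows=*/false);
    std::printf("scalar large-row tile: %u rows x %u cols\n", t[0], t[1]);  // 8 x 32
}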
@@ -1882,65 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) {
            parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
    };

+    auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
+    };
+
+    auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+        // For large number of rows, 128 invocations seems to work best.
+        // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
+        // can't use 256 for D==80.
+        // For scalar, use 128 (arbitrary)
+        uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
+        auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
+
+        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
+        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
+        const uint32_t D_lsb = D ^ (D & (D-1));
+        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
+
+        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
+        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+    };

+#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \
+
+#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
+
+    CREATE_FA(GGML_TYPE_F16, f16, true, )
+    CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
+    CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     if (device->coopmat2) {

-        auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-            return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
-        };
-
-        auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
-            // For large number of rows, 128 invocations seems to work best.
-            // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
-            // can't use 256 for D==80.
-            uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
-            auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
-            // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
-            GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
-            return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
-        };
-
-#define CREATE_FA2(TYPE, NAMELC, D) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
-
-#define CREATE_FA(TYPE, NAMELC) \
-        CREATE_FA2(TYPE, NAMELC, 64) \
-        CREATE_FA2(TYPE, NAMELC, 80) \
-        CREATE_FA2(TYPE, NAMELC, 96) \
-        CREATE_FA2(TYPE, NAMELC, 112) \
-        CREATE_FA2(TYPE, NAMELC, 128) \
-        CREATE_FA2(TYPE, NAMELC, 256)
-
-        CREATE_FA(GGML_TYPE_F16, f16)
-        CREATE_FA(GGML_TYPE_Q4_0, q4_0)
-        CREATE_FA(GGML_TYPE_Q4_1, q4_1)
-        CREATE_FA(GGML_TYPE_Q5_0, q5_0)
-        CREATE_FA(GGML_TYPE_Q5_1, q5_1)
-        CREATE_FA(GGML_TYPE_Q8_0, q8_0)
-        // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
-        //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
-        //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
-        //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
-        //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
-        //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
-        //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
-        //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
-        //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
-        //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
-        //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
-        //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
-        //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
-        //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
-        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
+        CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
+        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
+    }
+#endif
 #undef CREATE_FA2
 #undef CREATE_FA

+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    if (device->coopmat2) {

        // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
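The D_split expression in fa_spec_constants leans on a classic bit trick: D ^ (D & (D-1)) isolates the lowest set bit of D. A small self-contained check of the resulting split factors (a subgroup size of 32 is assumed for the example):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Dividing the lowest set bit by 4 reflects the 4-wide vector loads in the
// shader; the caps of 8 and the subgroup size mirror the patch.
static uint32_t d_split_for(uint32_t D, uint32_t subgroup_size) {
    const uint32_t D_lsb = D ^ (D & (D - 1));   // e.g. 80 -> 16, 112 -> 16
    return std::min(std::min(subgroup_size, 8u), D_lsb / 4);
}

int main() {
    for (uint32_t D : {64u, 80u, 96u, 112u, 128u, 256u}) {
        std::printf("D=%3u -> D_split=%u\n", D, d_split_for(D, 32));
    }
    // Prints 8, 4, 8, 4, 8, 8 for the six supported head sizes.
}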
@@ -2837,6 +2863,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                            (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);

+    device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                               (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
+
     const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;

     device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
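The new subgroup_shuffle flag is a standard Vulkan 1.1 capability query. A rough equivalent using the plain C API instead of vulkan.hpp (a sketch; the patch reuses the vk11_props already fetched during device setup):

#include <vulkan/vulkan.h>

// Subgroup shuffle must be available in compute stages for the scalar FA path.
static bool supports_compute_subgroup_shuffle(VkPhysicalDevice phys) {
    VkPhysicalDeviceSubgroupProperties subgroup_props = {};
    subgroup_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;

    VkPhysicalDeviceProperties2 props2 = {};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    props2.pNext = &subgroup_props;

    vkGetPhysicalDeviceProperties2(phys, &props2);

    return (subgroup_props.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) != 0 &&
           (subgroup_props.supportedOperations & VK_SUBGROUP_FEATURE_SHUFFLE_BIT) != 0;
}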
@@ -5709,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    assert(q->type == GGML_TYPE_F32);
    assert(k->type == v->type);

+    bool scalar = !ctx->device->coopmat2;
+
+    uint32_t gqa_ratio = 1;
+    uint32_t qk_ratio = neq2 / nek2;
+    uint32_t workgroups_x = (uint32_t)neq1;
+    uint32_t workgroups_y = (uint32_t)neq2;
+    uint32_t workgroups_z = (uint32_t)neq3;
+
+    // For scalar FA, we can use the "large" size to accommodate gqa.
+    // For coopmat FA, we always use the small size (which is still pretty large for gqa).
+    const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
+
+    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
+        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
+        // and change addressing calculations to index Q's dimension 2.
+        gqa_ratio = qk_ratio;
+        N = gqa_ratio;
+        workgroups_y /= N;
+    }
+
    vk_pipeline *pipelines;
    // XXX TODO other backends may be changing accumulator precision to default to f32 soon
-    bool f32acc = dst->op_params[3] == GGML_PREC_F32;
-    bool small_rows = N <= flash_attention_num_small_rows;
-    switch (D) {
-    case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
-    case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
-    case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
-    case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
-    case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
-    case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
-    default:
-        assert(!"unsupported D value");
-        return;
+    bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
+    bool small_rows = N <= get_fa_num_small_rows(scalar);

+    if (scalar) {
+        switch (D) {
+        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
+        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
+        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
+        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
+        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
+        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
+        default:
+            GGML_ASSERT(!"unsupported D value");
+            return;
        }
+    } else {
+        switch (D) {
+        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
+        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
+        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
+        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
+        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
+        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
+        default:
+            GGML_ASSERT(!"unsupported D value");
+            return;
+        }
+    }
    assert(pipelines);
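To see what the GQA branch above does, a worked example with hypothetical shapes: 32 Q heads sharing 8 K/V heads while decoding a single token.

#include <cstdint>
#include <cstdio>

// The equality checks on nev2/neq3/nek3/nev3 from the patch are omitted
// here for brevity; the shapes are made up.
int main() {
    uint32_t N = 1;                        // tokens being decoded
    uint32_t neq2 = 32, nek2 = 8;          // Q heads, K/V heads (hypothetical)
    uint32_t workgroups_y = neq2;

    uint32_t qk_ratio = neq2 / nek2;       // 4 Q heads per K/V head
    const uint32_t max_gqa = 8;            // scalar_flash_attention_num_large_rows

    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2) {
        N = qk_ratio;                      // the 4 sibling Q heads become 4 rows
        workgroups_y /= N;                 // 32 -> 8 workgroups in y
    }
    std::printf("N=%u workgroups_y=%u\n", N, workgroups_y);  // N=4 workgroups_y=8
}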
@@ -5740,27 +5806,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
    vk_pipeline pipeline = pipelines[aligned];
    assert(pipeline);

-    uint32_t gqa_ratio = 1;
-    uint32_t qk_ratio = neq2 / nek2;
-    uint32_t workgroups_x = (uint32_t)neq1;
-    uint32_t workgroups_y = (uint32_t)neq2;
-    uint32_t workgroups_z = (uint32_t)neq3;
-
-    if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
-        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
-        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
-        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
-        // and change addressing calculations to index Q's dimension 2.
-        gqa_ratio = qk_ratio;
-        N = gqa_ratio;
-        workgroups_y /= N;
-    }
-
    uint32_t split_kv = KV;
    uint32_t split_k = 1;

+    // Use a placeholder core count if one isn't available. split_k is a big help for perf.
+    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
+
    // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
+    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
        // Try to run two workgroups per SM.
        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
        if (split_k > 1) {
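The split_k heuristic above in isolation, with a hypothetical 32-core GPU and a 4096-token KV cache:

#include <cstdint>
#include <cstdio>

// Single-column dispatch (workgroups_x == 1) over a long KV cache; two
// workgroups per SM is the target occupancy. Numbers are made up.
int main() {
    uint32_t KV = 4096;
    uint32_t workgroups_x = 1, workgroups_y = 8;
    uint32_t reported_cores = 32;          // from the device; may be 0

    // Use a placeholder core count if one isn't available.
    const uint32_t shader_core_count = reported_cores ? reported_cores : 16;

    uint32_t split_k = 1;
    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
        split_k = shader_core_count * 2 / workgroups_y;   // 64 / 8 = 8
    }
    // Each split covers KV / split_k = 512 keys; the partial results are then
    // combined by the pipeline_flash_attn_split_k_reduce pass.
    std::printf("split_k=%u keys/split=%u\n", split_k, KV / split_k);
}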
@@ -9530,9 +9583,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
        case GGML_OP_FLASH_ATTN_EXT:
            {
                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-               if (!ggml_vk_get_device(ctx->device)->coopmat2) {
-                   return false;
-               }
+               auto device = ggml_vk_get_device(ctx->device);
+               bool coopmat2 = device->coopmat2;
                switch (op->src[0]->ne[0]) {
                case 64:
                case 80:
@@ -9540,7 +9592,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                case 112:
                case 128:
                case 256:
-               case 575: // DeepSeek MLA
                    break;
                default:
                    return false;
@@ -9566,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                switch (op->src[1]->type) {
                case GGML_TYPE_F16:
                case GGML_TYPE_Q4_0:
+               case GGML_TYPE_Q8_0:
+                   // supported in scalar and coopmat2 paths
+                   break;
                case GGML_TYPE_Q4_1:
                case GGML_TYPE_Q5_0:
                case GGML_TYPE_Q5_1:
-               case GGML_TYPE_Q8_0:
                // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
                //case GGML_TYPE_Q2_K:
                //case GGML_TYPE_Q3_K:
@@ -9585,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                //case GGML_TYPE_IQ3_S:
                //case GGML_TYPE_IQ4_XS:
                case GGML_TYPE_IQ4_NL:
+                   // currently supported only in coopmat2 path
+                   if (!coopmat2) {
+                       return false;
+                   }
                    break;
                default:
                    return false;
                }
+               if (!coopmat2 && !device->subgroup_shuffle) {
+                   // scalar FA uses subgroupShuffle
+                   return false;
+               }
                return true;
            }
        case GGML_OP_GET_ROWS:
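Taken together, the FLASH_ATTN_EXT checks after this patch reduce to a predicate like the following sketch (ours, not the backend's API; the coopmat2-only type list is abbreviated to the non-commented types from the switch above):

#include <cstdint>

enum class kv_type { f16, q4_0, q4_1, q5_0, q5_1, q8_0, iq4_nl };

static bool fa_supported(uint32_t D, kv_type t, bool coopmat2, bool subgroup_shuffle) {
    switch (D) {
        case 64: case 80: case 96: case 112: case 128: case 256: break;
        default: return false;                  // unsupported head size
    }
    switch (t) {
        case kv_type::f16:
        case kv_type::q4_0:
        case kv_type::q8_0:
            break;                              // supported in scalar and coopmat2 paths
        default:
            if (!coopmat2) return false;        // q4_1/q5_0/q5_1/iq4_nl need coopmat2
            break;
    }
    if (!coopmat2 && !subgroup_shuffle) {
        return false;                           // scalar FA uses subgroupShuffle
    }
    return true;
}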