mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-08-18 05:56:00 -04:00)
ggml : fix FA mask dim 2 and 3 (#14505)
* ggml : fix FA mask dim 2 and 3
  ggml-ci
* backends : unsupport batched FA in CUDA and Vulkan
  ggml-ci
* vulkan : disable FA for mask->ne[2] != 1
@@ -3674,7 +3674,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
         GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
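For context, a minimal sketch (not part of the commit; the shapes are illustrative assumptions) of what the soft_max mask checks above allow once the 3-D restriction is dropped: the mask only has to divide the input in dim 2, using the usual ggml ne[] order.

/* Illustrative only: mirrors the checks in ggml_soft_max_impl shown above. */
#include <assert.h>
#include <stdint.h>

int main(void) {
    int64_t a_ne[4]    = { 512, 32, 8, 2 };  /* scores: n_kv, n_tokens, n_head, n_seq (hypothetical) */
    int64_t mask_ne[4] = { 512, 32, 1, 2 };  /* mask broadcast across heads via dim 2 */

    assert(mask_ne[0] == a_ne[0]);           /* same row length */
    assert(mask_ne[1] >= a_ne[1]);           /* mask rows may be padded */
    assert(a_ne[2] % mask_ne[2] == 0);       /* dim-2 broadcast */
    /* with ggml_is_3d(mask) removed, mask_ne[3] is no longer forced to be 1 */
    return 0;
}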
@@ -4704,12 +4703,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
 
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
 
-        GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {
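A similar sketch (again with hypothetical shapes, not taken from the commit) for the corrected flash-attention checks: after the fix, mask dim 2 is validated against q dim 2 and mask dim 3 against q dim 3, rather than pairing mask dim 2 with q dim 3.

/* Illustrative only: the dim pairing enforced by ggml_flash_attn_ext after the fix. */
#include <assert.h>
#include <stdint.h>

int main(void) {
    int64_t q_ne[4]    = { 128, 64, 8, 2 };  /* head_dim, n_queries, n_head, n_seq (hypothetical) */
    int64_t mask_ne[4] = { 512, 64, 1, 2 };  /* n_kv, padded n_queries, broadcast over heads, per-seq */

    assert(q_ne[2] % mask_ne[2] == 0);       /* was: q->ne[3] % mask->ne[2] */
    assert(q_ne[3] % mask_ne[3] == 0);       /* now checks dim 3 against dim 3 */
    return 0;
}

In practice this lets a single mask be shared across attention heads (dim 2) while still differing per sequence (dim 3).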