mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-07 01:16:35 -04:00
ggml : fix FA mask dim 2 and 3 (#14505)
* ggml : fix FA mask dim 2 and 3 ggml-ci * backends : unsupport batched FA in CUDA and Vulkan ggml-ci * vulkan : disable FA for mask->ne[2] != 1
This commit is contained in:
@@ -3390,7 +3390,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
|
||||
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
|
||||
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
Reference in New Issue
Block a user