mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-04 08:15:55 -04:00
ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (#14435)
ggml-ci
This commit is contained in:
@@ -3327,8 +3327,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_CONT:
|
||||
return op->src[0]->type != GGML_TYPE_BF16;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
return true;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
// TODO: support batching
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
|
||||
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
|
||||
case GGML_OP_SOFT_MAX_BACK: {
|
||||
float max_bias = 0.0f;
|
||||
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
|
||||
@@ -3375,6 +3382,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
|
Reference in New Issue
Block a user