CUDA: add softmax broadcast (#14475)

* CUDA: add softmax broadcast

* Pass by const ref

* Review: Use blockDims for indexing, remove designated initializers

* Add TODO for noncontigous input/output
This commit is contained in:
Aman Gupta
2025-07-02 20:34:24 +08:00
committed by Georgi Gerganov
parent 12a81af45f
commit 55a1c5a5fd
2 changed files with 92 additions and 35 deletions

View File

@@ -3329,13 +3329,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
// TODO: support batching
if (op->src[0]->ne[3] != 1) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
return true;
case GGML_OP_SOFT_MAX_BACK: {
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));