CUDA: add conv_2d_transpose (#14287)

* CUDA: add conv_2d_transpose

* remove direct include of cuda_fp16

* Review: add brackets for readability, remove ggml_set_param and add asserts
This commit is contained in:
Aman Gupta
2025-06-20 22:48:24 +08:00
committed by GitHub
parent 22015b2092
commit c959f462a0
4 changed files with 134 additions and 0 deletions

View File

@@ -12,6 +12,7 @@
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
#include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh"
@@ -2341,6 +2342,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_2D:
ggml_cuda_conv_2d_transpose_p0(ctx, dst);
break;
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst);
break;
@@ -3252,6 +3256,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
}
case GGML_OP_IM2COL:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS: