diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c7222207e..8015b0d4e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3226,8 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_SET_ROWS: { -#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") - return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && +#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64; } break; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index d8b3e63e1..3fade72b8 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1(const float * src_f, hal *dst_h = __float2half(*src_f); } +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, nv_bfloat16 * dst_b) { + *dst_b = *src_f; +} + template<> __device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { *dst_f = *src_f; @@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { nb1, nb2, nb3, stream ); + } else if (dst->type == GGML_TYPE_BF16) { + set_rows_cuda( + src0_d, src1_d, (nv_bfloat16*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); } else { GGML_ABORT("unsupported type"); }