mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-20 17:49:18 +00:00
cuda : add set rows for bf16 (#14664)
This commit is contained in:
@ -3226,8 +3226,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SET_ROWS:
|
case GGML_OP_SET_ROWS:
|
||||||
{
|
{
|
||||||
#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
|
#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
|
||||||
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
|
||||||
op->src[0]->type == GGML_TYPE_F32 &&
|
op->src[0]->type == GGML_TYPE_F32 &&
|
||||||
op->src[1]->type == GGML_TYPE_I64;
|
op->src[1]->type == GGML_TYPE_I64;
|
||||||
} break;
|
} break;
|
||||||
|
@ -10,6 +10,11 @@ __device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, hal
|
|||||||
*dst_h = __float2half(*src_f);
|
*dst_h = __float2half(*src_f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Specialization of set_rows_1 for a float source and an nv_bfloat16 destination:
// converts the single float at src_f to nv_bfloat16 and stores it at dst_b.
// The conversion is spelled out with __float2bfloat16 (mirroring the
// __float2half call in the half specialization) instead of relying on the
// implicit float -> nv_bfloat16 converting constructor, which is not available
// when __CUDA_NO_BFLOAT16_CONVERSIONS__ is defined.
template<>
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
    *dst_b = __float2bfloat16(*src_f);
}
|
||||||
|
|
||||||
template<>
|
template<>
|
||||||
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
|
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
|
||||||
*dst_f = *src_f;
|
*dst_f = *src_f;
|
||||||
@ -124,6 +129,16 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|||||||
nb1, nb2, nb3,
|
nb1, nb2, nb3,
|
||||||
stream
|
stream
|
||||||
);
|
);
|
||||||
|
} else if (dst->type == GGML_TYPE_BF16) {
|
||||||
|
set_rows_cuda(
|
||||||
|
src0_d, src1_d, (nv_bfloat16*)dst->data,
|
||||||
|
ne00, ne01, ne02, ne03,
|
||||||
|
ne10, ne11, ne12, ne13,
|
||||||
|
nb01, nb02, nb03,
|
||||||
|
nb10, nb11, nb12,
|
||||||
|
nb1, nb2, nb3,
|
||||||
|
stream
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
GGML_ABORT("unsupported type");
|
GGML_ABORT("unsupported type");
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user