cuda : implement bf16 cpy ops and enable bf16 cont (#14763)

* implement bf16 cpy ops and enable bf16 cont

* deduplicate copy functions

* deduplicate checks
Author:    Sigbjørn Skjæret
Date:      2025-07-22 12:33:10 +02:00
Committer: GitHub
Parent:    8e6f8bc875
Commit:    e28c0b80c2
4 changed files with 49 additions and 124 deletions


@@ -2,24 +2,13 @@
#include "ggml-common.h"
static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
*dst = __float2half(*src);
}
static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
*dst = *src;
template<typename src_t, typename dst_t>
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
if constexpr (std::is_same_v<src_t, dst_t>) {
*dst = *src;
} else {
*dst = float(*src);
}
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
@@ -230,22 +219,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
}
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
convert_f32_f32((const float *)cxi, (float *)cdsti);
}
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
convert_f32_f16((const float *)cxi, (half *)cdsti);
}
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
}
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
convert_f16_f16((const half *)cxi, (half *)cdsti);
}
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
convert_f16_f32((const half *)cxi, (float *)cdsti);
template<typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
convert_flt((const src_t *)cxi, (dst_t *)cdsti);
}
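
For illustration, here is a minimal, self-contained sketch of the deduplicated converter in isolation. It reuses a `convert_flt`-style template as in the hunk above and adds a hypothetical `roundtrip_f32_bf16` kernel plus host driver (neither is part of the diff) to show the two behaviours: same-type copies take the plain-assignment branch, and mixed-type copies go through `float`. Compile with `nvcc -std=c++17` or newer, since `if constexpr` and `std::is_same_v` require C++17.

```cuda
// Standalone sketch (not part of the diff): a convert_flt-style template plus a
// hypothetical round-trip kernel to observe bf16 precision loss on the host.
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>

template<typename src_t, typename dst_t>
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
    if constexpr (std::is_same_v<src_t, dst_t>) {
        *dst = *src;        // same type: plain copy
    } else {
        *dst = float(*src); // different float types: convert via float
    }
}

// Hypothetical demo kernel: f32 -> bf16 -> f32, so the rounding introduced by
// the bf16 copy path is visible in the values written back as f32.
static __global__ void roundtrip_f32_bf16(const float * src, float * dst, const int n) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    nv_bfloat16 tmp;
    convert_flt(&src[i], &tmp);  // instantiates convert_flt<float, nv_bfloat16>
    convert_flt(&tmp, &dst[i]);  // instantiates convert_flt<nv_bfloat16, float>
}

int main() {
    const int n = 4;
    const float h_src[n] = {1.0f, 3.14159265f, -0.33333333f, 1234.5678f};
    float h_dst[n] = {0.0f};

    float * d_src = nullptr;
    float * d_dst = nullptr;
    cudaMalloc(&d_src, n*sizeof(float));
    cudaMalloc(&d_dst, n*sizeof(float));
    cudaMemcpy(d_src, h_src, n*sizeof(float), cudaMemcpyHostToDevice);

    roundtrip_f32_bf16<<<1, 32>>>(d_src, d_dst, n);
    cudaMemcpy(h_dst, d_dst, n*sizeof(float), cudaMemcpyDeviceToHost);

    for (int i = 0; i < n; i++) {
        printf("%.8f -> %.8f\n", h_src[i], h_dst[i]);
    }

    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}
```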


@@ -8,10 +8,10 @@
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
@@ -139,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
#endif
}
static void ggml_cpy_f16_f32_cuda(
template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
static void ggml_cpy_f32_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
static void ggml_cpy_f32_bf16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
static void ggml_cpy_f32_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
@@ -307,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
@@ -372,11 +333,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
}
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -403,9 +364,17 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -430,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
return nullptr;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
return (void*) cpy_flt<cpy_1_flt<float, float>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
return (void*) cpy_flt<cpy_1_flt<float, half>>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
} else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -458,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
return (void*) cpy_flt<cpy_1_flt<half, half>>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
return (void*) cpy_flt<cpy_1_flt<half, float>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
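
The host-side deduplication rests on a kernel-pointer-as-template-argument pattern: `cpy_1_flt<src_t, dst_t>` instantiates to an ordinary `__device__` function whose address becomes the `cpy_kernel_t` non-type template argument of `cpy_flt`, so a single templated launcher `ggml_cpy_flt_cuda<src_t, dst_t>` replaces the per-pair wrappers and covers every F32/F16/BF16 combination, including the new BF16 ones. The sketch below restates that pattern in reduced form, assuming contiguous data and dropping the ne*/nb* stride bookkeeping and the CUDA-graph indirection of the real kernels; it is an illustration, not the actual implementation.

```cuda
// Reduced sketch of the dispatch pattern used in the diff (strides and the
// cdst_indirect / graph_cpynode_index machinery are omitted for brevity).
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

// Element-wise copy/convert (the diff factors this through convert_flt; inlined here).
template<typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
    const src_t * src = (const src_t *) cxi;
    dst_t * dst = (dst_t *) cdsti;
    if constexpr (std::is_same_v<src_t, dst_t>) {
        *dst = *src;
    } else {
        *dst = float(*src);
    }
}

// The converter is baked in at compile time as a non-type template argument,
// so one __global__ kernel body serves every type pair.
template <cpy_kernel_t cpy_1>
static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
                               const int nb_src, const int nb_dst) {
    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= ne) {
        return;
    }
    cpy_1(cx + i*nb_src, cdst + i*nb_dst);
}

// One templated launcher replaces the per-pair wrappers; adding BF16 support
// means two more instantiations rather than new copies of this function.
template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(const char * cx, char * cdst, const int ne, cudaStream_t stream) {
    const int block_size = 256;
    const int num_blocks = (ne + block_size - 1) / block_size;
    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, block_size, 0, stream>>>
        (cx, cdst, ne, (int) sizeof(src_t), (int) sizeof(dst_t));
}

// Example instantiation for an F16 -> BF16 contiguous copy:
//   ggml_cpy_flt_cuda<half, nv_bfloat16>((const char *) src, (char *) dst, ne, stream);
```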


@@ -3242,13 +3242,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
{
ggml_type src0_type = op->src[0]->type;
ggml_type src1_type = op->src[1]->type;
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
(src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
@@ -3284,12 +3280,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
return true;
}
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
return true;
}
@@ -3370,7 +3360,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->ne[1] % 128 == 0;
}
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
return true;
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
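
The previously separate per-pair CPY checks (and the new BF16 combinations) are now covered by one condition that accepts any of the nine F32/F16/BF16 source/destination pairs. An equivalent restatement with a hypothetical helper (not part of the diff) makes the intent explicit; and because every such pair now has a copy kernel, the GGML_OP_CONT case can return true unconditionally instead of excluding BF16 sources.

```cuda
// Hypothetical helper restating the consolidated CPY check; not in the diff.
static bool is_cpy_float_type(ggml_type t) {
    return t == GGML_TYPE_F32 || t == GGML_TYPE_F16 || t == GGML_TYPE_BF16;
}

// ...inside the GGML_OP_CPY case of ggml_backend_cuda_device_supports_op:
// if (is_cpy_float_type(src0_type) && is_cpy_float_type(src1_type)) {
//     return true;
// }
```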


@@ -4,24 +4,8 @@
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
template<typename src_t, typename dst_t>
__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
GGML_UNUSED(src_f);
GGML_UNUSED(dst_f);
}
template<>
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
convert_f32_f16(src_f, dst_h);
}
template<>
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
convert_f32_bf16(src_f, dst_b);
}
template<>
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
convert_f32_f32(src_f, dst_f);
__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
convert_flt(src_f, dst_f);
}
// Generic quantized set_rows kernel template