mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-06 01:05:03 -04:00
cuda : implement bf16 cpy ops and enable bf16 cont (#14763)
* implement bf16 cpy ops and enable bf16 cont * deduplicate copy functions * deduplicate checks
This commit is contained in:
@@ -2,24 +2,13 @@
|
||||
|
||||
#include "ggml-common.h"
|
||||
|
||||
static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
|
||||
*dst = __float2half(*src);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
|
||||
*dst = *src;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
|
||||
*dst = *src;
|
||||
template<typename src_t, typename dst_t>
|
||||
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
|
||||
if constexpr (std::is_same_v<src_t, dst_t>) {
|
||||
*dst = *src;
|
||||
} else {
|
||||
*dst = float(*src);
|
||||
}
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
|
||||
@@ -230,22 +219,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
|
||||
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
|
||||
convert_f32_f32((const float *)cxi, (float *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
||||
convert_f32_f16((const float *)cxi, (half *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
|
||||
convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
|
||||
convert_f16_f16((const half *)cxi, (half *)cdsti);
|
||||
}
|
||||
|
||||
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
||||
convert_f16_f32((const half *)cxi, (float *)cdsti);
|
||||
template<typename src_t, typename dst_t>
|
||||
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
|
||||
convert_flt((const src_t *)cxi, (dst_t *)cdsti);
|
||||
}
|
||||
|
Reference in New Issue
Block a user