CUDA: set_rows + cpy.cu refactor (#14712)

This commit is contained in:
Aman Gupta
2025-07-18 14:54:18 +08:00
committed by GitHub
parent 8f974bc1e9
commit f9a31eea06
4 changed files with 396 additions and 244 deletions

View File

@ -0,0 +1,251 @@
#pragma once
#include "ggml-common.h"
static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) {
*dst = __float2half(*src);
}
static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) {
*dst = *src;
}
static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) {
*dst = *src;
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = x[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (x[0 + j] - vmin)*id;
const float x1 = (x[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
float min = x[0];
float max = x[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = x[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (x[0 + j] - min)*id;
const float x1 = (x[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = x[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = x[j]*id;
y->qs[j] = roundf(x0);
}
}
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
y->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = x[0 + j]*x[0 + j];
const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
// Wrapper functions for cpy.cu compatibility
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
}
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
convert_f32_f32((const float *)cxi, (float *)cdsti);
}
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
convert_f32_f16((const float *)cxi, (half *)cdsti);
}
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti);
}
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
convert_f16_f16((const half *)cxi, (half *)cdsti);
}
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
convert_f16_f32((const half *)cxi, (float *)cdsti);
}

View File

@ -1,46 +1,12 @@
#include "cpy.cuh"
#include "dequantize.cuh"
#include "cpy-utils.cuh"
#ifdef GGML_USE_MUSA
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
float * dsti = (float *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
half * dsti = (half *) cdsti;
*dsti = __float2half(*xi);
}
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
half * dsti = (half *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
float * dsti = (float *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@ -71,29 +37,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
cpy_1(cx + x_offset, cdst + dst_offset);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = xi[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = xi[j]*id;
dsti->qs[j] = roundf(x0);
}
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);
@ -106,139 +49,6 @@ static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
}
}
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_0 * dsti = (block_q4_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_1 * dsti = (block_q4_1 *) cdsti;
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = xi[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (xi[0 + j] - vmin)*id;
const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_0 * dsti = (block_q5_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_1 * dsti = (block_q5_1 *) cdsti;
float min = xi[0];
float max = xi[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = xi[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (xi[0 + j] - min)*id;
const float x1 = (xi[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
float * cdstf = (float *)(cdsti);
@ -252,53 +62,6 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
}
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
dsti->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = xi[0 + j]*xi[0 + j];
const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,

View File

@ -3226,8 +3226,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
} break;
case GGML_OP_SET_ROWS:
{
#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)")
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) &&
return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
op->src[0]->type == GGML_TYPE_F32 &&
op->src[1]->type == GGML_TYPE_I64;
} break;

View File

@ -1,4 +1,5 @@
#include "set-rows.cuh"
#include "cpy-utils.cuh"
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
@ -10,17 +11,93 @@ __device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
template<>
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
*dst_h = __float2half(*src_f);
convert_f32_f16(src_f, dst_h);
}
template<>
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
*dst_b = *src_f;
convert_f32_bf16(src_f, dst_b);
}
template<>
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
*dst_f = *src_f;
convert_f32_f32(src_f, dst_f);
}
// Generic quantized set_rows kernel template
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static __global__ void k_set_rows_quant(
const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t s10, const int64_t s11, const int64_t s12,
const int64_t s1, const int64_t s2, const int64_t s3) {
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
if (i >= ne_total) {
return;
}
const int64_t i_base = i * qk;
const int64_t i03 = i_base / (ne00 * ne01 * ne02);
const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
const int64_t i12 = i03 % ne12;
const int64_t i11 = i02 % ne11;
const int64_t i10 = i01;
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);
const float * src_block = src0_row + i00;
block_type * dst_block = dst_row_ptr + i00 / qk;
quantize_func(src_block, dst_block);
}
// Template dispatch function for quantized set_rows
template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
const float * src0_d, const int64_t * src1_d, block_type * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12,
const size_t nb1, const size_t nb2, const size_t nb3,
cudaStream_t stream) {
GGML_ASSERT(ne00 % qk == 0);
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
const dim3 grid_size(num_blocks);
const int64_t s01 = nb01/sizeof(float);
const int64_t s02 = nb02/sizeof(float);
const int64_t s03 = nb03/sizeof(float);
const int64_t s10 = nb10/sizeof(int64_t);
const int64_t s11 = nb11/sizeof(int64_t);
const int64_t s12 = nb12/sizeof(int64_t);
const int64_t s1 = nb1;
const int64_t s2 = nb2;
const int64_t s3 = nb3;
if (ne_total > 0) {
k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
s01, s02, s03,
s10, s11, s12,
s1, s2, s3);
}
}
template<typename src_t, typename dst_t>
@ -145,7 +222,67 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q4_0) {
set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
src0_d, src1_d, (block_q4_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q4_1) {
set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
src0_d, src1_d, (block_q4_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q5_0) {
set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
src0_d, src1_d, (block_q5_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q5_1) {
set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
src0_d, src1_d, (block_q5_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_Q8_0) {
set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
src0_d, src1_d, (block_q8_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else if (dst->type == GGML_TYPE_IQ4_NL) {
set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
src0_d, src1_d, (block_iq4_nl*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
stream
);
} else {
GGML_ABORT("unsupported type");
GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
}
}