ggml-quants : use qkxh in more places

Francis Couture-Harpin
2025-03-21 14:05:58 -04:00
parent 3be115100f
commit f86b8ff210


@@ -629,7 +629,6 @@ static float make_qkx2_quants(int n, int nmax, const float * restrict x, const f
 }

 struct fraction {
-    // float frac;
     float numer;
     float denom;
     int i;
@@ -677,7 +676,8 @@ struct k_heap {
     struct k_heap_cell * heap;
 };

-// build a max heap out of k_cells starting from node i
+// build a max heap out of k_cells starting from node i;
+// makes sure the node i is bigger than its children
 static void k_heap_build(struct k_heap * heap, int i) {
     const int n = heap->n;
     int max = i;
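Note on the comment added above: a max heap keeps every parent at least as large as its children, and k_heap_build restores that property starting from node i. The standalone sift-down below over a plain float array is only an illustrative sketch of that idea; the actual code compares k_heap_cell entries, whose fields are not shown in this diff.

    // Generic sift-down on a float array: after the call, node i is not smaller
    // than either of its children, and the subtree below it is a valid max heap
    // provided both child subtrees already were.
    static void sift_down(float * a, int n, int i) {
        for (;;) {
            int max   = i;
            int left  = 2*i + 1;
            int right = 2*i + 2;
            if (left  < n && a[left]  > a[max]) { max = left;  }
            if (right < n && a[right] > a[max]) { max = right; }
            if (max == i) {
                break; // heap property already holds at node i
            }
            float tmp = a[i]; a[i] = a[max]; a[max] = tmp;
            i = max; // keep sifting the displaced value down
        }
    }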
@@ -744,6 +744,7 @@ static void k_heap_init(struct k_heap * restrict k_heap, int k, const int8_t * r
     steps[mid_k] = 0.0f;
 }

+// for linear types which have a constant step of 1 between representable values
 static void k_heap_init_linear(struct k_heap * k_heap, int nmin, int nmax, struct k_heap_cell * restrict heap_cells, float * restrict odd) {
     GGML_ASSERT(k_heap && heap_cells && odd);
     nmin = MIN(0, nmin);
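The "linear types" in the new comment are the formats whose representable levels are the consecutive integers nmin..nmax scaled by the block scale, so neighbouring levels always differ by exactly one step. A tiny standalone example of that grid (the range matches the Q4_0 hunk later in this commit; the scale value is arbitrary and only for illustration):

    #include <stdio.h>

    int main(void) {
        const int   nmin = -8, nmax = 7; // the Q4_0 range used later in this commit
        const float d    = 0.25f;        // an arbitrary block scale for the example
        // the reconstructed values d*q form a uniform grid with constant step d*1
        for (int q = nmin; q <= nmax; ++q) {
            printf("q=%3d  value=%6.2f\n", q, d * (float)q);
        }
        return 0;
    }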
@@ -1004,6 +1005,7 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
     return negative_scale ? -scale : scale;
 }

+// exhaustive search with cumulative sums
 static float make_qkxh_quants(int n, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale) {
     const int nmin = -k_heap->mid_k; // TODO: maybe directly pass these
     const int nmax = k_heap->k + nmin - 1;
@@ -1027,9 +1029,7 @@ static float make_qkxh_quants(int n, const float * restrict x, const float * res
     }
     if (amax < GROUP_MAX_EPS) { // all zero
-        for (int i = 0; i < n; ++i) {
-            L[i] = 0;
-        }
+        memset(L, 0, n);
         return 0.0f;
     }
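The "cumulative sums" behind make_qkxh_quants are the two weighted sums that fully determine the best scale for a given assignment of quant levels; keeping them as running totals lets the exhaustive search move one element at a time in O(1) per step. Below is a minimal sketch of that core computation under the usual weighted-least-squares formulation; it is an illustration, not the actual search loop.

    #include <stdint.h>

    // For a fixed assignment L[i] in [nmin, nmax], the scale d minimizing
    // sum_i w[i]*(x[i] - d*L[i])^2 is sumlx/suml2, and the reachable error only
    // depends on sumlx*sumlx/suml2, which is what the search maximizes.
    static float scale_for_assignment(int n, const float * x, const float * w,
                                      const int8_t * L, float * score) {
        float sumlx = 0.0f; // sum of w[i] * x[i] * L[i]
        float suml2 = 0.0f; // sum of w[i] * L[i] * L[i]
        for (int i = 0; i < n; ++i) {
            sumlx += w[i] * x[i] * (float)L[i];
            suml2 += w[i] * (float)L[i] * (float)L[i];
        }
        *score = suml2 > 0.0f ? sumlx*sumlx/suml2 : 0.0f;
        return  suml2 > 0.0f ? sumlx/suml2 : 0.0f;
    }

In the real code these sums are updated incrementally as the heap pops the next threshold, rather than recomputed over the whole block as above.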
@@ -1304,7 +1304,7 @@ static float make_qkxss_quants(int n, int nmin, int nmax, const float * restrict
 }

 // non-linear exhaustive search with cumulative sums
-static float make_qkxs_nl_quants(int n, const float * restrict x, const float * restrict weights, uint8_t * restrict L, uint8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale, bool fast) {
+static float make_qkxh_nl_quants(int n, const float * restrict x, const float * restrict weights, uint8_t * restrict L, uint8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale, bool fast) {
     float sumlx = 0.0f;
     float suml2 = 0.0f;
     float amax = -1.0f;
@@ -1687,11 +1687,18 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
     uint8_t L[QK_K];
     uint8_t Laux[16];
+    int8_t Lsaux[16];
     float mins[QK_K/16];
     float scales[QK_K/16];
     float sw[QK_K/16];
     float weight[16];
-    uint8_t Ls[QK_K/16], Lm[QK_K/16];
+    int8_t Ls[QK_K/16], Lm[QK_K/16];
+
+    struct k_heap_cell heap_cells_s[QK_K/16];
+    float odd_s[16];
+    struct k_heap k_heap_s;
+    k_heap_init_linear(&k_heap_s, 0, 15, heap_cells_s, odd_s);

     for (int i = 0; i < nb; i++) {
         memset(sw, 0, QK_K/16*sizeof(float));
@@ -1706,8 +1713,8 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
         }

         float dm, mm;
-        dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
-        mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
+        dm = make_qkxh_quants(QK_K/16, scales, sw, Ls, Lsaux, &k_heap_s, false);
+        mm = make_qkxh_quants(QK_K/16, mins, sw, Lm, Lsaux, &k_heap_s, false);

         y[i].d = GGML_FP32_TO_FP16(dm);
         y[i].dmin = GGML_FP32_TO_FP16(mm);
@@ -2607,7 +2614,12 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
     float weight[QK4_0];
     int8_t L[QK4_0];
     int8_t Laux[QK4_0];
-    struct fraction Faux[8 * QK4_0];
+    // struct fraction Faux[8 * QK4_0];
+
+    struct k_heap_cell heap_cells[QK4_0];
+    float odd[16];
+    struct k_heap k_heap;
+    k_heap_init_linear(&k_heap, -8, 7, heap_cells, odd);

     float sum_x2 = 0;
     for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
@@ -2618,7 +2630,8 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
         const float * xb = x + QK4_0 * ib;
         const float * qw = quant_weights + QK4_0 * ib;
         for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
-        float d = make_qkxs_quants(QK4_0, -8, 7, xb, weight, L, Laux, Faux, true);
+        // float d = make_qkxs_quants(QK4_0, -8, 7, xb, weight, L, Laux, Faux, true);
+        float d = make_qkxh_quants(QK4_0, xb, weight, L, Laux, &k_heap, true);
         y[ib].d = GGML_FP32_TO_FP16(d);
         for (int j = 0; j < 16; ++j) {
             y[ib].qs[j] = L[j] | (L[j+16] << 4);
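The last context line above packs two 4-bit codes per byte: the first half of the block goes into the low nibbles and the second half into the high nibbles. A standalone round-trip check of that pattern, for illustration only:

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        uint8_t lo = 0x3, hi = 0xA;        // two already-offset 4-bit codes
        uint8_t packed = lo | (hi << 4);   // same shape as qs[j] = L[j] | (L[j+16] << 4)
        assert((packed & 0x0F) == lo);     // unpack the low nibble
        assert((packed >> 4)   == hi);     // unpack the high nibble
        return 0;
    }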
@@ -2697,7 +2710,12 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
     float weight[QK5_0];
     int8_t L[QK5_0];
     int8_t Laux[QK5_0];
-    struct fraction Faux[16 * QK5_0];
+    // struct fraction Faux[16 * QK5_0];
+
+    struct k_heap_cell heap_cells[QK5_0];
+    float odd[32];
+    struct k_heap k_heap;
+    k_heap_init_linear(&k_heap, -16, 15, heap_cells, odd);

     float sum_x2 = 0;
     for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j];
@@ -2708,7 +2726,8 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
         const float * xb = x + QK5_0 * ib;
         const float * qw = quant_weights + QK5_0 * ib;
         for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
-        float d = make_qkxs_quants(QK5_0, -16, 15, xb, weight, L, Laux, Faux, true);
+        float d = make_qkxh_quants(QK5_0, xb, weight, L, Laux, &k_heap, true);
+        // float d = make_qkxs_quants(QK5_0, -16, 15, xb, weight, L, Laux, Faux, true);
         y[ib].d = GGML_FP32_TO_FP16(d);

         uint32_t qh = 0;
@@ -5505,7 +5524,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
             scales[ib] = 0;
             continue;
         }
-        float d = make_qkxs_nl_quants(block_size, xb, weight, Lb, Laux, k_heap, true, !quant_weights);
+        float d = make_qkxh_nl_quants(block_size, xb, weight, Lb, Laux, k_heap, true, !quant_weights);
         scales[ib] = d;
         float abs_d = fabsf(d);
         if (abs_d > amax_scale) {
@@ -5516,19 +5535,13 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
     if (super_block_size/block_size > 1) {
         int nb = super_block_size/block_size;
         memset(scales_h, 0, ((nb+7)/8)*sizeof(uint16_t));
+        // TODO: use make_qkxh_quants
         float d = -max_scale/32;
         dh[0] = GGML_FP32_TO_FP16(d);
         float id = d ? 1/d : 0.f;
         for (int ib = 0; ib < super_block_size/block_size; ++ib) {
             int l = nearest_int(id*scales[ib]);
             l = MAX(-32, MIN(31, l));
-            // float dl = d * l;
-            // float idl = dl ? 1/dl : 0.f;
-            // uint8_t * Lb = L + ib*block_size;
-            // const float * xb = x + ib*block_size;
-            // for (int j = 0; j < block_size; ++j) {
-            //     Lb[j] = best_index_int8(16, values, idl*xb[j]);
-            // }
             l += 32;
             uint8_t l_l = l & 0xf;
             uint8_t l_h = l >> 4;
@@ -5538,12 +5551,6 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
         }
     } else {
         dh[0] = GGML_FP32_TO_FP16(scales[0]);
-        // if (ntry > 0) {
-        //     float id = scales[0] ? 1/scales[0] : 0;
-        //     for (int j = 0; j < super_block_size; ++j) {
-        //         L[j] = best_index_int8(16, values, id*x[j]);
-        //     }
-        // }
     }

     for (int i = 0; i < super_block_size/32; ++i) {