ggml-quants : remove slower qsort-based cumulative search

This commit is contained in:
Francis Couture-Harpin
2025-03-22 12:07:28 -04:00
parent 3e4b675c9f
commit af23abd3cb

View File

@ -861,150 +861,6 @@ static struct fraction k_heap_pop(struct k_heap * k_heap) {
}; };
} }
// exhaustive search with cumulative sums
// Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions
static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux, bool signed_scale) {
float max = x[0];
float min = x[0];
float w_amax = weights[0] * fabsf(x[0]);
int max_i = 0;
int w_amax_i = 0;
int min_i = 0;
for (int i = 1; i < n; ++i) {
if (x[i] < min) { min = x[i]; min_i = i; }
if (x[i] > max) { max = x[i]; max_i = i; }
// Find the most important value
const float w = weights[i];
const float wax = w * fabsf(x[i]);
if (wax > w_amax) {
w_amax = wax;
w_amax_i = i;
}
}
const int amax_i = fabsf(min) > fabsf(max) ? min_i : max_i;
const float amax = fabsf(x[amax_i]);
if (amax < GROUP_MAX_EPS) { // all zero
for (int i = 0; i < n; ++i) {
L[i] = 0;
}
return 0.0f;
}
bool negative_scale = false;
if (signed_scale && -nmin != nmax) {
// the max side should have the biggest range
// NOTE: this is not always the best sign
if ((x[amax_i] < 0.0f) == (-nmin < nmax)) {
// [-4, 3] ==> [-3, 4]
const int tmp = nmin;
const float ftmp = min;
nmin = -nmax;
nmax = -tmp;
min = -max;
max = -ftmp;
negative_scale = true;
}
}
// Find the max range in [0, amax_range] which doesn't result in clamping.
// This is the range from the side which would clamp first (biggest ratio of max to nmax).
int amax_range;
float range_max;
if (fabsf(-max * nmin) < fabsf(-min * nmax)) {
amax_range = MAX(0, -nmin);
range_max = fabsf(min);
} else {
amax_range = MAX(0, nmax);
range_max = fabsf(max);
}
float sumlx = 0.0f;
float suml2 = 0.0f;
float scale = 0.0f;
float best = 0.0f;
float best_denom = 1.0f;
if (amax_range > 1) {
// The smallest non-redundant iscale makes the first clamped value half+1 its max integer value.
// Proof: anything smaller has a representable vector with values twice as big.
const float iscale = ((float)((amax_range >> 1) + 1))/range_max * (negative_scale ? -1.0f : 1.0f);
for (int i = 0; i < n; ++i) {
const float w = weights[i];
int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
if (negative_scale) { l = -l; }
Laux[i] = l;
L[i] = l;
suml2 += w * l * l;
sumlx += w * l * x[i];
}
best = sumlx * sumlx;
best_denom = suml2; // should never be zero
scale = sumlx / suml2;
} else {
for (int i = 0; i < n; ++i) {
Laux[i] = 0;
L[i] = 0;
}
}
const int imax_range = MAX(0, (x[w_amax_i] < 0.0f) ? -nmin : nmax);
const int max_odd = 2*(imax_range + 1) + 1;
const float wmax = fabsf(x[w_amax_i]);
int n_frac = 0;
for (int i = 0; i < n; ++i) {
// assuming nmin <= nmax
const int odd_max = MAX(abs(Laux[i]), x[i] < 0.0f ? -nmin : nmax);
const int odd_min = MAX(abs(Laux[i]), x[i] < 0.0f ? -nmax : nmin);
const float v = fabsf(x[i]);
const float v_max_odd = v * max_odd;
for (int j = odd_min; j < odd_max; ++j) {
const float odd = 2*j + 1;
if (wmax * odd < v_max_odd) {
Faux[n_frac++] = (struct fraction){
.numer=v,
.denom=odd,
.i=i,
};
} else {
// stop when the inverse scale would result in clamping the most important value
break;
}
}
}
qsort(Faux, n_frac, sizeof(struct fraction), compare_fractions_desc);
int best_p_i = -1; // consecutive with 0..n_frac
for (int i = 0; i < n_frac; ++i) {
// maximize the weighted cosine
const int ii = Faux[i].i;
const float w = weights ? weights[ii] : x[ii] * x[ii];
sumlx += w * Faux[i].numer;
suml2 += w * Faux[i].denom;
const float current = sumlx * sumlx;
Laux[ii] += x[ii] < 0.0f ? -1 : 1;
if (suml2 > 0.0f && Faux[i].numer > 0.0f && current * best_denom > best * suml2) {
best = current;
best_denom = suml2;
scale = sumlx / suml2;
if (i == best_p_i + 1) {
// reduce copies for consecutive bests
L[ii] += x[ii] < 0.0f ? -1 : 1;
} else {
for (int j = 0; j < n; ++j) {
L[j] = Laux[j];
}
}
best_p_i = i;
}
}
for (int i = 0; i < n; ++i) {
L[i] = negative_scale ? (-L[i] + nmax) : (L[i] + -nmin);
GGML_ASSERT(L[i] >= 0 && L[i] <= nmax - nmin);
}
return negative_scale ? -scale : scale;
}
// exhaustive search with cumulative sums // exhaustive search with cumulative sums
static float make_qkxh_quants(int n, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale) { static float make_qkxh_quants(int n, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale) {
const int nmin = MIN(0, -k_heap->mid_k); // TODO: maybe directly pass these const int nmin = MIN(0, -k_heap->mid_k); // TODO: maybe directly pass these
@ -1279,182 +1135,6 @@ static float make_qkxsh_quants(int n, int nmin, int nmax, const float * restrict
return scale; return scale;
} }
// Very similar to make_qkxs_quants, but the sign of the scale is not assumed to be the sign of the absmax value.
static float make_qkxss_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux) {
// start at zero
nmin = MIN(0, nmin);
nmax = MAX(0, nmax);
float amax = 0.0f;
float min = 0.0f;
float max = 0.0f;
float w_amax = 0.0f;
int amax_i = -1;
int w_amax_i = -1;
for (int i = 0; i < n; ++i) {
const float w = weights ? weights[i] : x[i] * x[i];
const float ax = fabsf(x[i]);
const float wax = w * ax;
if (ax > amax) { amax = ax; amax_i = i; }
if (x[i] > max) { max = x[i]; }
if (x[i] < min) { min = x[i]; }
// Find the most important value
if (wax > w_amax) { w_amax = wax; w_amax_i = i; }
}
if (amax < GROUP_MAX_EPS || amax_i < 0 || w_amax_i < 0) { // all zero
for (int i = 0; i < n; ++i) { L[i] = 0; }
return 0.0f;
}
// Use the side which will clamp first.
// The first clamped value is the absmax at the end of the common range.
// TODO: reduce the search space when one of the ranges is 0
const int amax_range = MIN(-nmin, nmax);
float sumlx_p = 0.0f;
float suml2_p = 0.0f;
float sumlx_n = 0.0f;
float suml2_n = 0.0f;
float scale = 0.0f;
float best = 0.0f;
float best_denom = 1.0f;
int best_i = -2; // not consecutive with 0..n_frac
// Pre-calculate the half-point for the common range.
// All smaller vectors have a representable vector with twice the values, and thus can be skipped.
if (amax_range > 1) {
const float iscale = ((float)((amax_range >> 1) + 1))/amax;
for (int i = 0; i < n; ++i) {
const float w = weights ? weights[i] : x[i] * x[i];
int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
Laux[i] = l;
suml2_p += w * l * l;
sumlx_p += w * l * x[i];
}
sumlx_n = -sumlx_p;
suml2_n = suml2_p;
const float current_p = sumlx_p * sumlx_p;
if (suml2_p > 0.0f && current_p * best_denom > best * suml2_p) {
best = current_p;
best_denom = suml2_p;
scale = sumlx_p / suml2_p;
for (int i = 0; i < n; ++i) {
L[i] = Laux[i];
}
best_i = -1; // right before 0 of the loop after sorting
}
} else {
for (int i = 0; i < n; ++i) {
Laux[i] = 0;
}
}
const int imax_range = MAX(nmax, -nmin);
const int max_odd = 2*(imax_range + 1) + 1;
const float wmax = fabsf(x[w_amax_i]);
int n_frac = 0;
for (int i = 0; i < n; ++i) {
// assuming nmin <= nmax
const int odd_max = MAX(nmax, -nmin);
const float v = fabsf(x[i]);
const float v_max_odd = v * max_odd;
for (int j = abs(Laux[i]); j < odd_max; ++j) {
const float odd = 2*j + 1;
const float wmax_odd = wmax * odd;
if (wmax_odd < v_max_odd) {
Faux[n_frac++] = (struct fraction){
.numer=v,
.denom=odd,
.i=i,
};
} else {
// stop when the inverse scale would result in clamping the most important value
break;
}
}
}
qsort(Faux, n_frac, sizeof(struct fraction), compare_fractions_desc);
const float max_common_odd = (MIN(nmax, -nmin) * 2) + 1;
const float max_odd_p = (nmax * 2) + 1;
const float max_odd_n = (-nmin * 2) + 1;
for (int i = 0; i < n_frac; ++i) {
// maximize the weighted cosine similarity
const int ii = Faux[i].i;
const float w = weights ? weights[ii] : x[ii] * x[ii];
const float lx = w * Faux[i].numer;
const float odd = Faux[i].denom;
const float l2 = w * odd;
Laux[ii] += x[ii] < 0.0f ? -1 : 1;
float sumlx = 0.0f;
float proj = 0.0f;
float norm = 0.0f;
if (odd < max_common_odd) {
sumlx_p += lx;
suml2_p += l2;
sumlx_n -= lx;
suml2_n += l2;
sumlx = sumlx_p;
proj = sumlx_p * sumlx_p;
norm = suml2_p;
// avoid double-copying Laux in a single iteration
if (suml2_p != suml2_n && suml2_p * suml2_n > 0.0f) {
const float proj_n = sumlx_n * sumlx_n;
if (proj_n * norm > proj * suml2_n) {
sumlx = sumlx_n;
proj = proj_n;
norm = suml2_n;
}
}
} else if (x[ii] < 0.0f ? odd < max_odd_n : odd < max_odd_p) {
sumlx_p += lx;
suml2_p += l2;
sumlx = sumlx_p;
proj = sumlx_p * sumlx_p;
norm = suml2_p;
} else {
// outside the positive range means we're now into negatives
sumlx_n -= lx;
suml2_n += l2;
sumlx = sumlx_n;
proj = sumlx_n * sumlx_n;
norm = suml2_n;
}
if (norm > 0.0f && proj * best_denom > best * norm) {
best = proj;
best_denom = norm;
scale = sumlx / norm;
if (i == best_i + 1) {
// reduce copies for consecutive bests
L[ii] += x[ii] < 0.0f ? -1 : 1;
} else {
for (int j = 0; j < n; ++j) {
L[j] = Laux[j];
}
}
best_i = i;
}
}
if (scale < 0.0f) {
for (int i = 0; i < n; ++i) {
L[i] = MAX(nmin, MIN(-L[i], nmax)) - nmin;
}
} else {
for (int i = 0; i < n; ++i) {
L[i] = MAX(nmin, MIN(L[i], nmax)) - nmin;
}
}
return scale;
}
// non-linear exhaustive search with cumulative sums // non-linear exhaustive search with cumulative sums
static float make_qkxh_nl_quants(int n, const float * restrict x, const float * restrict weights, uint8_t * restrict L, uint8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale, bool fast) { static float make_qkxh_nl_quants(int n, const float * restrict x, const float * restrict weights, uint8_t * restrict L, uint8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale, bool fast) {
float sumlx = 0.0f; float sumlx = 0.0f;
@ -1924,7 +1604,6 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
int8_t L[QK_K]; int8_t L[QK_K];
int8_t Laux[16]; int8_t Laux[16];
// struct fraction Faux[16 * 4];
struct k_heap_cell heap_cells[16]; struct k_heap_cell heap_cells[16];
float odd[8]; float odd[8];
struct k_heap k_heap; struct k_heap k_heap;
@ -1942,7 +1621,6 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in
float max_scale = 0; float max_scale = 0;
float amax = 0; float amax = 0;
for (int j = 0; j < QK_K/16; ++j) { for (int j = 0; j < QK_K/16; ++j) {
// scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Laux, Faux, true);
scales[j] = make_qkxh_quants(16, x + 16*j, weights, L + 16*j, Laux, &k_heap, true); scales[j] = make_qkxh_quants(16, x + 16*j, weights, L + 16*j, Laux, &k_heap, true);
float scale = fabsf(scales[j]); float scale = fabsf(scales[j]);
if (scale > amax) { if (scale > amax) {
@ -2052,7 +1730,6 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
float weight[16]; float weight[16];
float sw[QK_K / 16]; float sw[QK_K / 16];
int8_t Ls[QK_K / 16]; int8_t Ls[QK_K / 16];
// struct fraction Faux[16 * 32];
struct k_heap_cell heap_cells[16]; struct k_heap_cell heap_cells[16];
float odd[8]; float odd[8];
struct k_heap k_heap; struct k_heap k_heap;
@ -2080,14 +1757,12 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
for (int l = 0; l < 16; ++l) sumw += weight[l]; for (int l = 0; l < 16; ++l) sumw += weight[l];
sw[j] = sumw; sw[j] = sumw;
// scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Laux, Faux, true);
scales[j] = make_qkxh_quants(16, x + 16*j, weight, L + 16*j, Laux, &k_heap, true); scales[j] = make_qkxh_quants(16, x + 16*j, weight, L + 16*j, Laux, &k_heap, true);
} }
memset(y[i].scales, 0, 12); memset(y[i].scales, 0, 12);
// float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Laux, Faux, true);
float d_block = make_qkxh_quants(QK_K/16, scales, sw, Ls, Laux, &k_heap_s, true); float d_block = make_qkxh_quants(QK_K/16, scales, sw, Ls, Laux, &k_heap_s, true);
for (int j = 0; j < QK_K/16; ++j) { for (int j = 0; j < QK_K/16; ++j) {
int l = Ls[j]; int l = Ls[j];
@ -2766,7 +2441,6 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
float weight[QK4_0]; float weight[QK4_0];
int8_t L[QK4_0]; int8_t L[QK4_0];
int8_t Laux[QK4_0]; int8_t Laux[QK4_0];
// struct fraction Faux[8 * QK4_0];
struct k_heap_cell heap_cells[QK4_0]; struct k_heap_cell heap_cells[QK4_0];
float odd[16]; float odd[16];
struct k_heap k_heap; struct k_heap k_heap;
@ -2782,7 +2456,6 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
const float * xb = x + QK4_0 * ib; const float * xb = x + QK4_0 * ib;
const float * qw = quant_weights + QK4_0 * ib; const float * qw = quant_weights + QK4_0 * ib;
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
// float d = make_qkxs_quants(QK4_0, -8, 7, xb, weight, L, Laux, Faux, true);
float d = make_qkxh_quants(QK4_0, xb, weight, L, Laux, &k_heap, true); float d = make_qkxh_quants(QK4_0, xb, weight, L, Laux, &k_heap, true);
y[ib].d = GGML_FP32_TO_FP16(d); y[ib].d = GGML_FP32_TO_FP16(d);
for (int j = 0; j < 16; ++j) { for (int j = 0; j < 16; ++j) {
@ -2862,7 +2535,6 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
float weight[QK5_0]; float weight[QK5_0];
int8_t L[QK5_0]; int8_t L[QK5_0];
int8_t Laux[QK5_0]; int8_t Laux[QK5_0];
// struct fraction Faux[16 * QK5_0];
struct k_heap_cell heap_cells[QK5_0]; struct k_heap_cell heap_cells[QK5_0];
float odd[32]; float odd[32];
struct k_heap k_heap; struct k_heap k_heap;
@ -2879,7 +2551,6 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
const float * qw = quant_weights + QK5_0 * ib; const float * qw = quant_weights + QK5_0 * ib;
for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); for (int j = 0; j < QK5_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
float d = make_qkxh_quants(QK5_0, xb, weight, L, Laux, &k_heap, true); float d = make_qkxh_quants(QK5_0, xb, weight, L, Laux, &k_heap, true);
// float d = make_qkxs_quants(QK5_0, -16, 15, xb, weight, L, Laux, Faux, true);
y[ib].d = GGML_FP32_TO_FP16(d); y[ib].d = GGML_FP32_TO_FP16(d);
uint32_t qh = 0; uint32_t qh = 0;