From d0060fc498595c7a2b7c1e28662cc11f298e6c9a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 21 Feb 2025 15:05:03 -0500 Subject: [PATCH] ggml-quants : better and faster make_qkxs_quants --- ggml/src/ggml-quants.c | 193 ++++++++++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 638766ae8..29ffb608b 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -660,97 +660,148 @@ static inline int compare_fractions_desc(const void * a, const void * b) { // exhaustive search with cumulative sums // Need Faux to have room for n*(max(abs(nmin), abs(nmax))) fractions -static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, struct fraction * restrict Faux, bool signed_scale) { - float max = 0.0f; - float amax = 0.0f; - for (int i = 0; i < n; ++i) { - float ax = fabsf(x[i]); - if (ax > amax) { - amax = ax; - max = x[i]; - } - } - bool negative_scale = false; - if (signed_scale && -nmin != nmax) { - // the max side should have the biggest range - if ((max < 0.0f) == (-nmin < nmax)) { - // [-4, 3] ==> [-3, 4] - int tmp = nmin; - nmin = -nmax; - nmax = -tmp; - negative_scale = true; +static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct fraction * restrict Faux, bool signed_scale) { + const int orig_nmin = nmin; + const int orig_nmax = nmax; + float max = x[0]; + float min = x[0]; + float w_amax = weights[0] * fabsf(x[0]); + int max_i = 0; + int w_amax_i = 0; + int min_i = 0; + for (int i = 1; i < n; ++i) { + if (x[i] < min) { min = x[i]; min_i = i; } + if (x[i] > max) { max = x[i]; max_i = i; } + // Find the most important value + const float w = weights[i]; + const float wax = w * fabsf(x[i]); + if (wax > w_amax) { + w_amax = wax; + w_amax_i = i; } } + const int amax_i = fabsf(min) > fabsf(max) ? min_i : max_i; + const float amax = fabsf(x[amax_i]); + if (amax < GROUP_MAX_EPS) { // all zero for (int i = 0; i < n; ++i) { L[i] = 0; } return 0.0f; } + bool negative_scale = false; + if (signed_scale && -nmin != nmax) { + // the max side should have the biggest range + // FIXME: this is incorrect when the weights[.] do not sort in the same order as fabsf(x[.]) + // or is it some other condition? + if ((x[amax_i] < 0.0f) == (-nmin < nmax)) { + // [-4, 3] ==> [-3, 4] + const int tmp = nmin; + const float ftmp = min; + nmin = -nmax; + nmax = -tmp; + min = -max; + max = -ftmp; + negative_scale = true; + } + } + + // Find the max range in [0, amax_range] which doesn't result in clamping. + // This is the range from the side which would clamp first (biggest ratio of max to nmax). + int amax_range; + float range_max; + if (fabsf(-max * nmin) < fabsf(-min * nmax)) { + amax_range = MAX(0, -nmin); + range_max = fabsf(min); + } else { + amax_range = MAX(0, nmax); + range_max = fabsf(max); + } + float sumlx = 0.0f; + float suml2 = 0.0f; + float scale = 0.0f; + float best = 0.0f; + float best_denom = 1.0f; + if (amax_range > 1) { + // The smallest non-redundant iscale makes the first clamped value half+1 its max integer value. + // Proof: anything smaller has a representable vector with values twice as big. + const float iscale = ((float)(amax_range / 2 + 1))/range_max * (negative_scale ? -1.0f : 1.0f); + for (int i = 0; i < n; ++i) { + const float w = weights[i]; + int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax)); + if (negative_scale) { l = -l; } + Laux[i] = l; + L[i] = l; + suml2 += w * l * l; + sumlx += w * l * x[i]; + } + best = sumlx * sumlx; + best_denom = suml2; // should never be zero + scale = sumlx / suml2; + } else { + for (int i = 0; i < n; ++i) { + Laux[i] = 0; + L[i] = 0; + } + } + + const int imax_range = MAX(0, (x[w_amax_i] < 0.0f) ? -nmin : nmax); + const int max_odd = 2*(imax_range + 1) + 1; + const float wmax = fabsf(x[w_amax_i]); int n_frac = 0; for (int i = 0; i < n; ++i) { // assuming nmin <= nmax - const int odd_max = MAX(0, x[i] < 0 ? -nmin : nmax); - const int odd_min = MAX(0, x[i] < 0 ? -nmax : nmin); + const int odd_max = MAX(abs(Laux[i]), x[i] < 0.0f ? -nmin : nmax); + const int odd_min = MAX(abs(Laux[i]), x[i] < 0.0f ? -nmax : nmin); const float v = fabsf(x[i]); - // fprintf(stderr, "%s: i=%d, odd_min=%d, odd_max=%d\n", __func__, i, odd_min, odd_max); + const float v_max_odd = v * max_odd; for (int j = odd_min; j < odd_max; ++j) { const float odd = 2*j + 1; - Faux[n_frac++] = (struct fraction){ - .numer=v, - .denom=odd, - .i=i, - }; + if (wmax * odd < v_max_odd) { + Faux[n_frac++] = (struct fraction){ + .numer=v, + .denom=odd, + .i=i, + }; + } else { + // stop when the inverse scale would result in clamping the max (FIXME: most important) value + break; + } } } qsort(Faux, n_frac, sizeof(struct fraction), compare_fractions_desc); - float iscale = 0.0f; - { - float sumlx = 0.0f; - float suml2 = 0.0f; - float best = 0.0f; - float best_denom = 1.0f; - for (int i = 0; i < n_frac; ++i) { - // maximize the weighted cosine - const int ii = Faux[i].i; - const float w = weights ? weights[ii] : x[ii] * x[ii]; - sumlx += w * Faux[i].numer; - suml2 += w * Faux[i].denom; - const float current = sumlx * sumlx; - // fprintf(stderr, "%s: Faux[%d]=(%f/%f) * %f, square(sumlx)=%f, suml2=%f, k*cos2=%f\n", __func__, i, Faux[i].numer, Faux[i].denom, Faux[i].weight, current, suml2, current / suml2); - // use the last in case of equality - // FIXME: > or >= ?? Why does [0, 0, 1] rounds to [0, 0, 0] with >= ? - if (suml2 > 0.0f && current * best_denom > best * suml2) { - best = current; - best_denom = suml2; - iscale = Faux[i].numer > 0.0f ? Faux[i].denom / (2.0f * Faux[i].numer) : 0.0f; - if (!isfinite(iscale)) { - fprintf(stderr, "%s: iscale is not finite, %f/(2*%f)\n", __func__, Faux[i].denom, Faux[i].numer); + int best_p_i = -1; // consecutive with 0..n_frac + for (int i = 0; i < n_frac; ++i) { + // maximize the weighted cosine + const int ii = Faux[i].i; + const float w = weights ? weights[ii] : x[ii] * x[ii]; + sumlx += w * Faux[i].numer; + suml2 += w * Faux[i].denom; + const float current = sumlx * sumlx; + Laux[ii] += x[ii] < 0.0f ? -1 : 1; + if (suml2 > 0.0f && Faux[i].numer > 0.0f && current * best_denom > best * suml2) { + best = current; + best_denom = suml2; + scale = sumlx / suml2; + if (i == best_p_i + 1) { + // reduce copies for consecutive bests + L[ii] += x[ii] < 0.0f ? -1 : 1; + } else { + for (int j = 0; j < n; ++j) { + L[j] = Laux[j]; } } + best_p_i = i; } } - // (very) small fudging necessary because floats otherwise round to nearest even - iscale = iscale * ((float)((1 << 23) + 1) / (float)(1 << 23)); - - float sumlx = 0.0f; - float suml2 = 0.0f; for (int i = 0; i < n; ++i) { - // Rounding away from zero is assumed by the search algorithm above. - int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax)); - if (negative_scale) { - l = -l; - } - L[i] = negative_scale ? l + nmax : l - nmin; - float w = weights ? weights[i] : x[i] * x[i]; - // weighted projection scale - sumlx += w * x[i] * l; - suml2 += w * l * l; + L[i] = negative_scale ? (-L[i] + nmax) : (L[i] + -nmin); + GGML_ASSERT(L[i] >= 0 && L[i] <= nmax - nmin); } - return suml2 > 0.0f ? sumlx / suml2 : 0.0f; + return negative_scale ? -scale : scale; } // non-linear exhaustive search with cumulative sums @@ -1234,6 +1285,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in const int nb = k / QK_K; int8_t L[QK_K]; + int8_t Laux[16]; struct fraction Faux[16 * 4]; float scales[QK_K / 16]; float weights[16]; @@ -1247,7 +1299,7 @@ void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, in float max_scale = 0; float amax = 0; for (int j = 0; j < QK_K/16; ++j) { - scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Faux, true); + scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weights, L + 16*j, Laux, Faux, true); // scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); float scale = fabsf(scales[j]); if (scale > amax) { @@ -1367,6 +1419,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri const int nb = n_per_row / QK_K; int8_t L[QK_K]; + int8_t Laux[16]; float scales[QK_K / 16]; float weight[16]; float sw[QK_K / 16]; @@ -1391,14 +1444,14 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri sw[j] = sumw; // scales[j] = make_qx_quants(16, 4, x + 16*j, L + 16*j, 1, weight); - scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Faux, true); + scales[j] = make_qkxs_quants(16, -4, 3, x + 16*j, weight, L + 16*j, Laux, Faux, true); } memset(y[i].scales, 0, 12); // float d_block = make_qx_quants(QK_K/16, 32, scales, Ls, 1, sw); - float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Faux, true); + float d_block = make_qkxs_quants(QK_K/16, -32, 31, scales, sw, Ls, Laux, Faux, true); for (int j = 0; j < QK_K/16; ++j) { int l = Ls[j]; if (j < 8) { @@ -4856,11 +4909,11 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block for (int j = 0; j < block_size; ++j) weight[j] = sqrtf(sigma2 + xb[j]*xb[j]); // for (int j = 0; j < block_size; ++j) weight[j] = 1; } - float amax = 0, max = 0; + float amax = 0; for (int j = 0; j < block_size; ++j) { float ax = fabsf(xb[j]); if (ax > amax) { - amax = ax; max = xb[j]; + amax = ax; } } if (amax < GROUP_MAX_EPS) {