From e54d41befcc1575f4c898c5ff4ef43970cead75f Mon Sep 17 00:00:00 2001 From: compilade Date: Fri, 8 Aug 2025 17:48:26 -0400 Subject: [PATCH] gguf-py : add Numpy MXFP4 de/quantization support (#15111) * gguf-py : add MXFP4 de/quantization support * ggml-quants : handle zero amax for MXFP4 --- ggml/src/ggml-quants.c | 2 +- gguf-py/gguf/quants.py | 57 +++++++++++++++++++++++++++++++++--- gguf-py/tests/test_quants.py | 17 ++++++++--- 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index a57d2a16d..94f6405ca 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -288,7 +288,7 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE } } - const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127); + const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0; const float d = GGML_E8M0_TO_FP32_HALF(e); diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index 3c8ba82e1..31845ea6e 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -228,8 +228,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0): d = max / -8 with np.errstate(divide="ignore"): id = np.where(d == 0, 0, 1 / d) - # FIXME: Q4_0's reference rounding is cursed and depends on FMA - qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15) + qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15) qs = qs.reshape((n_blocks, 2, cls.block_size // 2)) qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4)) @@ -300,8 +299,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0): d = max / -16 with np.errstate(divide="ignore"): id = np.where(d == 0, 0, 1 / d) - # FIXME: Q5_0's reference rounding is cursed and depends on FMA - q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31) + q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31) qs = q.reshape((n_blocks, 2, cls.block_size // 2)) qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4)) @@ -655,6 +653,57 @@ class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0): return (d * qs.astype(np.float32)) +class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4): + # e2m1 values (doubled) + # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12) + + @staticmethod + # see ggml_e8m0_to_fp32_half in ggml-impl.h + def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray: + bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23)) + return bits.view(np.float32) + + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d = abs(blocks).max(axis=-1, keepdims=True) + + with np.errstate(divide="ignore"): + e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8) + + d = cls.e8m0_to_fp32_half(e) + + kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16)) + + errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1))) + best = np.argmin(errs, axis=-1, keepdims=True) + + qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8) + qs = qs[:, 0] | (qs[:, 1] << np.uint8(4)) + + qs = qs.reshape((n_blocks, cls.block_size // 2)) + + return np.concatenate([e, qs], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + e, qs = np.hsplit(blocks, [1]) + + d = cls.e8m0_to_fp32_half(e) + + qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1)) + qs = (qs & np.uint8(0x0F)).view(np.int8) + + kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16) + qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size)) + + return (d * qs.astype(np.float32)) + + class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS): ksigns: bytes = ( b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f" diff --git a/gguf-py/tests/test_quants.py b/gguf-py/tests/test_quants.py index f04d5acce..172fa0018 100755 --- a/gguf-py/tests/test_quants.py +++ b/gguf-py/tests/test_quants.py @@ -67,6 +67,7 @@ class GGMLQuants: "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K", "q5_K", "q6_K", "tq1_0", "tq2_0", + "mxfp4", "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m", "iq4_nl", "iq4_xs", ): @@ -140,14 +141,21 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) return False -def do_test(libggml_path: Path, quick: bool = False): +def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None): ggml_quants = GGMLQuants(libggml_path) np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n}) r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False) + # test zero blocks + r[0, 0, :] = 0 + ## Maybe test infinities? (can make NANs, not really useful in practice) + # r[0, 1, 0] = np.inf + # r[0, 2, 0] = -np.inf + # r[0, 3, 0] = np.inf + # r[0, 3, 1] = -np.inf - for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()): + for qtype in ((GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()) if user_type is None else (user_type,)): has_dequantize = False has_quantize = False @@ -228,11 +236,12 @@ def do_test(libggml_path: Path, quick: bool = False): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation") - parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so") + parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so") parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary") + parser.add_argument("--type", type=str, help="The quant type to test (all by default)") args = parser.parse_args() logging.basicConfig(level=logging.DEBUG) - do_test(args.libggml, args.quick) + do_test(args.libggml, args.quick, GGMLQuantizationType[args.type.upper()] if args.type is not None else None)