gguf-py : add Numpy MXFP4 de/quantization support (#15111)

* gguf-py : add MXFP4 de/quantization support

* ggml-quants : handle zero amax for MXFP4
This commit is contained in:
compilade
2025-08-08 17:48:26 -04:00
committed by GitHub
parent 4850b52aed
commit e54d41befc
3 changed files with 67 additions and 9 deletions

View File

@@ -228,8 +228,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
d = max / -8
with np.errstate(divide="ignore"):
id = np.where(d == 0, 0, 1 / d)
# FIXME: Q4_0's reference rounding is cursed and depends on FMA
qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
@@ -300,8 +299,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
d = max / -16
with np.errstate(divide="ignore"):
id = np.where(d == 0, 0, 1 / d)
# FIXME: Q5_0's reference rounding is cursed and depends on FMA
q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
qs = q.reshape((n_blocks, 2, cls.block_size // 2))
qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
@@ -655,6 +653,57 @@ class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0):
return (d * qs.astype(np.float32))
class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
# e2m1 values (doubled)
# ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
@staticmethod
# see ggml_e8m0_to_fp32_half in ggml-impl.h
def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
return bits.view(np.float32)
@classmethod
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
n_blocks = blocks.shape[0]
d = abs(blocks).max(axis=-1, keepdims=True)
with np.errstate(divide="ignore"):
e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
d = cls.e8m0_to_fp32_half(e)
kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
best = np.argmin(errs, axis=-1, keepdims=True)
qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))
qs = qs.reshape((n_blocks, cls.block_size // 2))
return np.concatenate([e, qs], axis=-1)
@classmethod
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
n_blocks = blocks.shape[0]
e, qs = np.hsplit(blocks, [1])
d = cls.e8m0_to_fp32_half(e)
qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
qs = (qs & np.uint8(0x0F)).view(np.int8)
kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))
return (d * qs.astype(np.float32))
class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
ksigns: bytes = (
b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"