mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-16 05:02:58 -04:00
gguf-py : add Numpy MXFP4 de/quantization support (#15111)
* gguf-py : add MXFP4 de/quantization support * ggml-quants : handle zero amax for MXFP4
This commit is contained in:
@@ -228,8 +228,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
|
||||
d = max / -8
|
||||
with np.errstate(divide="ignore"):
|
||||
id = np.where(d == 0, 0, 1 / d)
|
||||
# FIXME: Q4_0's reference rounding is cursed and depends on FMA
|
||||
qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
|
||||
qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
|
||||
|
||||
qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
|
||||
qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
|
||||
@@ -300,8 +299,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
|
||||
d = max / -16
|
||||
with np.errstate(divide="ignore"):
|
||||
id = np.where(d == 0, 0, 1 / d)
|
||||
# FIXME: Q5_0's reference rounding is cursed and depends on FMA
|
||||
q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
|
||||
q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
|
||||
|
||||
qs = q.reshape((n_blocks, 2, cls.block_size // 2))
|
||||
qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
|
||||
@@ -655,6 +653,57 @@ class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0):
|
||||
return (d * qs.astype(np.float32))
|
||||
|
||||
|
||||
class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
|
||||
# e2m1 values (doubled)
|
||||
# ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
||||
kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
|
||||
|
||||
@staticmethod
|
||||
# see ggml_e8m0_to_fp32_half in ggml-impl.h
|
||||
def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
|
||||
bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
|
||||
return bits.view(np.float32)
|
||||
|
||||
@classmethod
|
||||
def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d = abs(blocks).max(axis=-1, keepdims=True)
|
||||
|
||||
with np.errstate(divide="ignore"):
|
||||
e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
|
||||
|
||||
d = cls.e8m0_to_fp32_half(e)
|
||||
|
||||
kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
|
||||
|
||||
errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
|
||||
best = np.argmin(errs, axis=-1, keepdims=True)
|
||||
|
||||
qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
|
||||
qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))
|
||||
|
||||
qs = qs.reshape((n_blocks, cls.block_size // 2))
|
||||
|
||||
return np.concatenate([e, qs], axis=-1)
|
||||
|
||||
@classmethod
|
||||
def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
e, qs = np.hsplit(blocks, [1])
|
||||
|
||||
d = cls.e8m0_to_fp32_half(e)
|
||||
|
||||
qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
|
||||
qs = (qs & np.uint8(0x0F)).view(np.int8)
|
||||
|
||||
kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
|
||||
qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))
|
||||
|
||||
return (d * qs.astype(np.float32))
|
||||
|
||||
|
||||
class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
|
||||
ksigns: bytes = (
|
||||
b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
|
||||
|
Reference in New Issue
Block a user