ggml : remove q1_3 and q2_2

* llama : remove the separate scale tensors of BitNet b1.58

They won't be needed, since the remaining ternary quant types have
built-in scales.
This commit is contained in:
Francis Couture-Harpin
2024-08-02 19:52:19 -04:00
parent 45719a2472
commit 04eec58112
12 changed files with 45 additions and 693 deletions

View File

@@ -121,55 +121,3 @@ def quantize_q8_0(data: np.ndarray):
return __quantize_q8_0_lazy(data)
else:
return __quantize_q8_0_array(data)
__q1_3_block_size, __q1_3_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q1_3]
def can_quantize_to_q1_3(n: np.ndarray) -> bool:
return n.shape[-1] % __q1_3_block_size == 0
def __quantize_q1_3_shape_change(s: tuple[int, ...]) -> tuple[int, ...]:
return (*s[:-1], s[-1] // __q1_3_block_size * __q1_3_type_size)
def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray:
shape = n.shape
assert shape[-1] % __q1_3_block_size == 0
n_blocks = n.size // __q1_3_block_size
blocks = n.reshape((n_blocks, __q1_3_block_size)).astype(np.float32, copy=False)
# assuming the weights are pre-scaled
blocks = (np.sign(blocks).astype(np.int8) + 1).view(np.uint8)
q48, rest = np.hsplit(blocks, (48,))
q12, q4 = np.hsplit(rest, (12,))
pow3 = np.array([1, 3, 9, 27])
q48 = q48.reshape((n_blocks, 12, 4))
q48 = np.sum(q48 * pow3.reshape((1, 1, 4)), axis=2, keepdims=True).reshape((n_blocks, 12))
q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True)
q48 = q48 + (q12 * 81)
q = np.concatenate([q48, q4], axis=1)
q = (((q.astype(np.uint16) * 256) + (243 - 1)) // 243).astype(np.uint8)
return q.reshape(__quantize_q1_3_shape_change(shape))
def __quantize_q1_3_array(n: np.ndarray) -> np.ndarray:
return __apply_over_grouped_rows(__quantize_q1_3_rows, arr=n, otype=np.uint8, oshape=__quantize_q1_3_shape_change(n.shape))
__quantize_q1_3_lazy = LazyNumpyTensor._wrap_fn(
__quantize_q1_3_array,
meta_noop=(np.uint8, __quantize_q1_3_shape_change),
)
def quantize_q1_3(data: np.ndarray):
if type(data) is LazyNumpyTensor:
return __quantize_q1_3_lazy(data)
else:
return __quantize_q1_3_array(data)