mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 11:45:21 +00:00
gguf-py : add support for sub_type (in arrays) in GGUFWriter add_key_value method (#13561)
This commit is contained in:
@ -49,6 +49,7 @@ class TensorInfo:
|
||||
class GGUFValue:
|
||||
value: Any
|
||||
type: GGUFValueType
|
||||
sub_type: GGUFValueType | None = None
|
||||
|
||||
|
||||
class WriterState(Enum):
|
||||
@ -238,7 +239,7 @@ class GGUFWriter:
|
||||
|
||||
for key, val in kv_data.items():
|
||||
kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False)
|
||||
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True)
|
||||
kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type)
|
||||
|
||||
fout.write(kv_bytes)
|
||||
|
||||
@ -268,11 +269,11 @@ class GGUFWriter:
|
||||
fout.flush()
|
||||
self.state = WriterState.TI_DATA
|
||||
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
|
||||
def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None:
|
||||
if any(key in kv_data for kv_data in self.kv_data):
|
||||
raise ValueError(f'Duplicated key name {key!r}')
|
||||
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype)
|
||||
self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type)
|
||||
|
||||
def add_uint8(self, key: str, val: int) -> None:
|
||||
self.add_key_value(key,val, GGUFValueType.UINT8)
|
||||
@ -1022,7 +1023,7 @@ class GGUFWriter:
|
||||
pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>'
|
||||
return struct.pack(f'{pack_prefix}{fmt}', value)
|
||||
|
||||
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes:
|
||||
def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes:
|
||||
kv_data = bytearray()
|
||||
|
||||
if add_vtype:
|
||||
@ -1043,7 +1044,9 @@ class GGUFWriter:
|
||||
if len(val) == 0:
|
||||
raise ValueError("Invalid GGUF metadata array. Empty array")
|
||||
|
||||
if isinstance(val, bytes):
|
||||
if sub_type is not None:
|
||||
ltype = sub_type
|
||||
elif isinstance(val, bytes):
|
||||
ltype = GGUFValueType.UINT8
|
||||
else:
|
||||
ltype = GGUFValueType.get_type(val[0])
|
||||
|
@ -1521,19 +1521,21 @@ class GGUFEditorWindow(QMainWindow):
|
||||
continue
|
||||
|
||||
# Apply changes if any
|
||||
sub_type = None
|
||||
if field.name in self.metadata_changes:
|
||||
value_type, value = self.metadata_changes[field.name]
|
||||
if value_type == GGUFValueType.ARRAY:
|
||||
# Handle array values
|
||||
element_type, array_values = value
|
||||
writer.add_array(field.name, array_values)
|
||||
else:
|
||||
writer.add_key_value(field.name, value, value_type)
|
||||
sub_type, value = value
|
||||
else:
|
||||
# Copy original value
|
||||
value = field.contents()
|
||||
if value is not None and field.types:
|
||||
writer.add_key_value(field.name, value, field.types[0])
|
||||
value_type = field.types[0]
|
||||
if value_type == GGUFValueType.ARRAY:
|
||||
sub_type = field.types[-1]
|
||||
|
||||
if value is not None:
|
||||
writer.add_key_value(field.name, value, value_type, sub_type=sub_type)
|
||||
|
||||
# Add new metadata
|
||||
for key, (value_type, value) in self.metadata_changes.items():
|
||||
@ -1541,7 +1543,12 @@ class GGUFEditorWindow(QMainWindow):
|
||||
if self.reader.get_field(key) is not None:
|
||||
continue
|
||||
|
||||
writer.add_key_value(key, value, value_type)
|
||||
sub_type = None
|
||||
if value_type == GGUFValueType.ARRAY:
|
||||
# Handle array values
|
||||
sub_type, value = value
|
||||
|
||||
writer.add_key_value(key, value, value_type, sub_type=sub_type)
|
||||
|
||||
# Add tensors (including data)
|
||||
for tensor in self.reader.tensors:
|
||||
|
@ -24,6 +24,7 @@ class MetadataDetails(NamedTuple):
|
||||
type: gguf.GGUFValueType
|
||||
value: Any
|
||||
description: str = ''
|
||||
sub_type: gguf.GGUFValueType | None = None
|
||||
|
||||
|
||||
def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
|
||||
@ -57,7 +58,9 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
||||
logger.debug(f'Removing {field.name}')
|
||||
continue
|
||||
|
||||
old_val = MetadataDetails(field.types[0], field.contents())
|
||||
val_type = field.types[0]
|
||||
sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None
|
||||
old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type)
|
||||
val = new_metadata.get(field.name, old_val)
|
||||
|
||||
if field.name in new_metadata:
|
||||
@ -67,7 +70,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
|
||||
logger.debug(f'Copying {field.name}')
|
||||
|
||||
if val.value is not None:
|
||||
writer.add_key_value(field.name, val.value, val.type)
|
||||
writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type)
|
||||
|
||||
if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
|
||||
logger.debug('Adding chat template(s)')
|
||||
|
@ -1,6 +1,6 @@
|
||||
[tool.poetry]
|
||||
name = "gguf"
|
||||
version = "0.16.3"
|
||||
version = "0.17.0"
|
||||
description = "Read and write ML models in GGUF for GGML"
|
||||
authors = ["GGML <ggml@ggml.ai>"]
|
||||
packages = [
|
||||
|
@ -1,3 +1,3 @@
|
||||
numpy~=1.26.4
|
||||
PySide6~=6.9.0
|
||||
gguf>=0.16.0
|
||||
gguf>=0.17.0
|
||||
|
Reference in New Issue
Block a user