convert : improve model arch handling (#13122)
* convert : improve model arch handling
* use AutoConfig
* rm trust_remote_code
* Update convert_hf_to_gguf.py
* fix self.block_count for vision
* fix NomicBertModel
convert_hf_to_gguf.py

@@ -16,6 +16,7 @@ from pathlib import Path
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig
 
 import math
 import numpy as np
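The only functional change in this hunk is the new `AutoConfig` import. For orientation, a minimal sketch of what the converter now relies on (the model path is a placeholder, not taken from the patch):

    # Minimal sketch: AutoConfig resolves a model's config class and
    # to_dict() flattens it into a plain dict, comparable to reading
    # config.json directly but with library defaults filled in.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("/path/to/hf-model")  # placeholder path
    hparams = config.to_dict()
    print(hparams.get("architectures"), hparams.get("num_hidden_layers"))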
@@ -66,8 +67,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +77,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
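`block_count` and `tensor_map` move from assignments in `ModelBase.__init__` to bare annotations that subclasses are expected to initialize; the later `TextModel` and `VisionModel` hunks do exactly that with their own layer-count keys. A toy illustration of the pattern, with invented class names:

    # Toy illustration (invented names): the base class only annotates
    # the attribute; each subclass assigns it from its own config keys.
    class BaseSketch:
        block_count: int  # annotated, deliberately not assigned here

    class TextSketch(BaseSketch):
        def __init__(self, hparams: dict):
            self.block_count = hparams["num_hidden_layers"]

    class VisionSketch(BaseSketch):
        def __init__(self, hparams: dict):
            # some vision configs name the layer count "depth"
            self.block_count = hparams.get("num_hidden_layers", hparams.get("depth"))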
@@ -113,8 +116,6 @@ class ModelBase:
         if not self.is_safetensors:
             self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -417,15 +418,13 @@ class ModelBase:
 
     @staticmethod
     def load_hparams(dir_model: Path):
+        try:
+            return AutoConfig.from_pretrained(dir_model).to_dict()
+        except Exception as e:
+            logger.warning(f"Failed to load model config from {dir_model}: {e}")
+            logger.warning("Trying to load config.json instead")
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-            architectures = hparams.get("architectures")
-            if "text_config" in hparams:
-                hparams = {**hparams, **hparams["text_config"]}
-            if architectures is not None:
-                # preserve "architectures" from root level config
-                hparams["architectures"] = architectures
-            return hparams
+            return json.load(f)
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
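In isolation, the new loading flow is: try `AutoConfig` first, and on any failure fall back to reading `config.json` directly. A standalone sketch with its own imports (the real method lives on `ModelBase`):

    import json
    import logging
    from pathlib import Path

    from transformers import AutoConfig

    logger = logging.getLogger(__name__)

    # Sketch of the new load_hparams control flow: AutoConfig first,
    # raw config.json as the fallback for models it cannot parse.
    def load_hparams(dir_model: Path) -> dict:
        try:
            return AutoConfig.from_pretrained(dir_model).to_dict()
        except Exception as e:
            logger.warning(f"Failed to load model config from {dir_model}: {e}")
            logger.warning("Trying to load config.json instead")
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

Note that the `text_config` flattening this method used to do is gone; it reappears in `TextModel.__init__` below, so `VisionModel` now sees the untouched nested config.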
@@ -454,6 +453,23 @@ class ModelBase:
 
 
 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    @classmethod
+    def __init_subclass__(cls):
+        # can't use an abstract property, because overriding it without type errors
+        # would require using decorated functions instead of simply defining the property
+        if "model_arch" not in cls.__dict__:
+            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
     def set_vocab(self):
         self._set_vocab_gpt2()
 
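The `__init_subclass__` hook turns a forgotten `model_arch` into an error at class-definition time instead of a failure deep into conversion. A self-contained sketch of the mechanism, with a string standing in for `gguf.MODEL_ARCH`:

    # Sketch: subclass creation fails immediately if model_arch is absent.
    class Base:
        @classmethod
        def __init_subclass__(cls):
            if "model_arch" not in cls.__dict__:
                raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")

    class Good(Base):
        model_arch = "llama"  # stand-in for gguf.MODEL_ARCH.LLAMA

    try:
        class Bad(Base):  # defines no model_arch of its own
            pass
    except TypeError as e:
        print(e)  # Missing property 'model_arch' for 'Bad'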
@@ -1070,9 +1086,9 @@ class VisionModel(ModelBase):
         if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
-        # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
 
         if "vision_config" not in self.hparams:
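The text tower's embedding size is now resolved by overlaying `text_config` onto the root config, so nested multimodal configs yield the text value rather than relying on the old hard-coded 128-layer tensor map. A sketch with an invented nested config:

    # Invented nested config for illustration (numbers made up):
    hparams = {
        "text_config": {"hidden_size": 2048},
        "vision_config": {"hidden_size": 1152},
    }

    text_config = {**hparams, **hparams["text_config"]}
    n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
    assert n_embd_text == 2048  # the text tower's width, not the vision tower's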
@@ -1081,6 +1097,9 @@ class VisionModel(ModelBase):
         self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]
 
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
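The vision `block_count` now goes through the usual alias list, extended with `depth`. Assuming `find_hparam` simply returns the first key present (a simplification of the real helper), the lookup behaves like this:

    from typing import Any

    # Simplified find_hparam: return the value of the first key that exists.
    def find_hparam(hparams: dict, keys: list[str]) -> Any:
        for key in keys:
            if key in hparams:
                return hparams[key]
        raise KeyError(f"could not find any of: {keys}")

    vision_hparams = {"depth": 32}  # some vision configs call the layer count "depth"
    block_count = find_hparam(vision_hparams, ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
    assert block_count == 32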
@@ -1098,7 +1117,7 @@ class VisionModel(ModelBase):
         self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
         self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
         self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_block_count(self.block_count)
         self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
 
         # preprocessor config
@@ -1719,23 +1738,12 @@ class StableLMModel(TextModel):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-        # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
-                and self.hparams.get("model_type") == "mistral":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
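The deleted SmolVLM2/Pixtral workarounds are obsolete because `AutoConfig` instantiates the model's real config class, which materializes defaults that a sparse `config.json` omits. A hedged illustration using `MistralConfig` (any concrete config class would do):

    # Illustration: config classes carry defaults, so fields missing from
    # the on-disk config.json still appear after to_dict().
    from transformers import MistralConfig

    cfg = MistralConfig()  # constructed with no explicit fields at all
    print(cfg.to_dict()["num_attention_heads"])  # the default is present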
@@ -1898,11 +1906,7 @@ class LlavaVisionModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "pixtral":
-            # fix missing config.json values
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
-            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
-            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = 12 # see tokenizer_config.json
         else:
@@ -1913,7 +1917,6 @@ class LlavaVisionModel(VisionModel):
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
-            # default values below are taken from HF tranformers code
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
             self.gguf_writer.add_vision_use_silu(True)
 
@@ -1944,13 +1947,12 @@ class LlavaVisionModel(VisionModel):
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing some keys in config.json
-        # default values are taken from transformers code
-        if self.hparams["model_type"] == "smolvlm_vision":
+        if self.hparams["model_type"] == "smolvlm_vision":
+            # fix for SmolVLM2, missing some keys in config.json
+            # default values are taken from transformers code
             self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
             self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3505,6 +3507,8 @@ class RobertaModel(BertModel):
 
 @ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
         hparams = kwargs.pop("hparams", None)
         if hparams is None:
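`NomicBertModel` must now restate `model_arch` even though `BertModel` already defines it, because the `__init_subclass__` check inspects `cls.__dict__`, which never contains inherited names. In plain Python:

    class A:
        model_arch = "bert"

    class B(A):
        pass

    print("model_arch" in A.__dict__)  # True
    print("model_arch" in B.__dict__)  # False: inherited, not redeclared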
@@ -5849,6 +5853,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n
 
 
+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()
 
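Traced on an invented nested config, the helper falls back to the root-level `architectures` and only lets a sub-config override it when that sub-config carries its own list:

    # Invented config; mirrors get_model_architecture's logic for TEXT mode.
    hparams = {
        "architectures": ["SomeForConditionalGeneration"],
        "text_config": {"architectures": ["MistralForCausalLM"]},
        "vision_config": {},
    }

    arch = hparams["architectures"][0]  # root-level default
    text_arch = hparams.get("text_config", {}).get("architectures")
    if text_arch is not None:  # TEXT-mode override
        arch = text_arch[0]
    print(arch)  # MistralForCausalLM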
@@ -5901,16 +5918,15 @@ def main() -> None:
 
     logger.info(f"Loading model: {dir_model.name}")
 
-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
 
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
        try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError:
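For context, both touched code paths are reachable from the command line (paths are placeholders):

    python convert_hf_to_gguf.py /path/to/hf-model --outtype f16   # text model
    python convert_hf_to_gguf.py /path/to/hf-model --mmproj        # vision projector

With `--mmproj`, `model_type` is `ModelType.VISION`, so `get_model_architecture` consults `vision_config` before falling back to the root-level `architectures`.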