Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-06-27 20:05:20 +00:00
convert : converting mmproj for Qwen2/2.5VL from convert_hf_to_gguf (#13209)
* wip
* qwen2.5vl ok
* vision: fix models missing "text_config"
* add test
* fix test repo name
* fix 32B model
* Revert "fix 32B model"
This reverts commit 651752f1ae.
* clarify about 32B
* rm qwen surgery script
* update llava/readme
* move V_ENC_EMBD_PATCH handling to Qwen2VLVisionModel

@@ -1089,6 +1089,8 @@ class VisionModel(ModelBase):
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
         # get n_embd of the text model
+        if "text_config" not in self.hparams:
+            self.hparams["text_config"] = {}
         text_config = {**self.hparams, **self.hparams["text_config"]}
         self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
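The two added lines make the merge safe for vision models whose config ships without a nested "text_config" block (otherwise the dict lookup would raise KeyError). A minimal illustration of the behaviour, with made-up hparams values:

```python
# Illustration only: mirrors the fallback logic added above.
hparams = {"hidden_size": 1536, "num_heads": 16}  # assumed config without "text_config"
if "text_config" not in hparams:
    hparams["text_config"] = {}
text_config = {**hparams, **hparams["text_config"]}  # nested keys override top-level ones when present
n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
assert n_embd_text > 0, "n_embd not found in hparams"  # passes: 1536
```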
@@ -2583,6 +2585,82 @@ class Qwen2VLModel(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+class Qwen2VLVisionModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        # rename config.json values
+        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
+        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
+        if "embed_dim" in self.hparams: # qwen2vl
+            self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
+            self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if self.global_config['model_type'] == 'qwen2_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
+        elif self.global_config['model_type'] == 'qwen2_5_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
+            self.gguf_writer.add_vision_use_silu(True)
+            # find n_wa_pattern (window attention pattern)
+            fullatt_block_indexes = hparams.get("fullatt_block_indexes")
+            assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl"
+            n_wa_pattern = fullatt_block_indexes[0] + 1
+            # validate n_wa_pattern
+            for i in range(1, len(fullatt_block_indexes)):
+                if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
+                    raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
+            self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+        else:
+            raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}")
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, name, n_dims  # unused
+        if ".patch_embd." in new_name:
+            return gguf.GGMLQuantizationType.F16
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("visual."):
+            # process visual tensors
+            # split QKV tensors if needed
+            if ".qkv." in name:
+                if data_torch.ndim == 2: # weight
+                    c3, _ = data_torch.shape
+                else: # bias
+                    c3 = data_torch.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                return [
+                    (self.map_tensor_name(name.replace("qkv", "q")), wq),
+                    (self.map_tensor_name(name.replace("qkv", "k")), wk),
+                    (self.map_tensor_name(name.replace("qkv", "v")), wv),
+                ]
+            elif 'patch_embed.proj.weight' in name:
+                # split Conv3D into Conv2Ds
+                c1, c2, kt, kh, kw = data_torch.shape
+                del c1, c2, kh, kw  # unused
+                assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+                return [
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]),
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+                ]
+            else:
+                return [(self.map_tensor_name(name), data_torch)]
+        return []  # skip other tensors
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
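For Qwen2.5-VL, `n_wa_pattern` encodes the window/full attention cadence: one full-attention block every `n_wa_pattern` layers. A standalone worked example of the derivation used in `set_gguf_parameters()` above; the index list is an assumed, typical 32-block layout, not read from any config here:

```python
# Illustration of the n_wa_pattern derivation and validation.
fullatt_block_indexes = [7, 15, 23, 31]       # assumed: full attention on every 8th block
n_wa_pattern = fullatt_block_indexes[0] + 1   # -> 8
for i in range(1, len(fullatt_block_indexes)):
    if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
        raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
print(n_wa_pattern)  # 8, later stored as clip.vision.n_wa_pattern
```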
@@ -35,6 +35,16 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
 # Pixtral 12B
 llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
 
+# Qwen 2 VL
+llama-mtmd-cli -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF
+
+# Qwen 2.5 VL
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF
+
 # Mistral Small 3.1 24B (IQ2_M quantization)
 llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
 ```
@@ -60,7 +70,17 @@ Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advantages
 
 ## How to obtain `mmproj`
 
-Multimodal projector (`mmproj`) files are specific to each model architecture. Please refer to the relevant guide for instructions on how to obtain or create them:
+Multimodal projector (`mmproj`) files are specific to each model architecture.
 
+For the following models, you can use `convert_hf_to_gguf.py` with the `--mmproj` flag to get the `mmproj` file:
+- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
+- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+
+For older models, please refer to the relevant guide for instructions on how to obtain or create them:
+
 - [LLaVA](../../docs/multimodal/llava.md)
 - [MobileVLM](../../docs/multimodal/MobileVLM.md)
@@ -70,10 +90,3 @@ Multimodal projector (`mmproj`) files are specific to each model architecture. P
 - [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
 - [IBM Granite Vision](../../docs/multimodal/granitevision.md)
 - [Google Gemma 3](../../docs/multimodal/gemma3.md)
-
-For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
-- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
-- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
-- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
-- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
-- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
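A hedged example of driving the converter for one of the models listed above. Only the `--mmproj` flag is taken from this PR; the checkpoint path, the output name, and the use of `--outfile` are assumptions based on the converter's usual command line:

```python
# Sketch only: invoke the HF -> GGUF converter to produce an mmproj file.
import subprocess

subprocess.run([
    "python", "convert_hf_to_gguf.py",
    "path/to/Qwen2.5-VL-3B-Instruct",          # local Hugging Face checkpoint directory (placeholder)
    "--mmproj",                                # emit the multimodal projector added by this PR
    "--outfile", "mmproj-qwen2.5-vl-3b.gguf",  # assumed output filename
], check=True)
```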
@@ -1,217 +0,0 @@
-import argparse
-from typing import Dict, List, Optional
-
-import torch
-import numpy as np
-from gguf import *
-from transformers import (
-    AutoProcessor,
-    Qwen2VLConfig,
-    Qwen2VLProcessor,
-    Qwen2VLForConditionalGeneration,
-    Qwen2_5_VLConfig,  # type: ignore[reportAttributeAccessIssue]
-    Qwen2_5_VLForConditionalGeneration,  # type: ignore[reportAttributeAccessIssue]
-)
-
-
-VISION = "clip.vision"
-
-
-def k(raw_key: str, arch: str) -> str:
-    return raw_key.format(arch=arch)
-
-
-def get_n_wa_pattern(fullatt_block_indexes: Optional[List[int]]):
-    if fullatt_block_indexes is None:
-        return 0
-    n_wa = fullatt_block_indexes[0]
-    for a, b in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
-        if b - a - 1 != n_wa:
-            raise ValueError(
-                f"window/full attention layer should have fix pattern of "
-                f"for each full-attention layer followed by {n_wa} window-attention layers"
-            )
-    return n_wa + 1
-
-
-class VL2:
-
-    @staticmethod
-    def to_gguf_name(name: str) -> str:
-        og = name
-        name = name.replace("text_model", "t").replace("vision_model", "v")
-        name = name.replace("blocks", "blk").replace("embeddings.", "")
-        name = name.replace("attn.", "attn_")
-        name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.")
-        # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln")
-        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
-        name = name.replace("merger.mlp", 'mm')
-        print(f"[to_gguf_name] {og} --> {name}")
-        return name
-
-    @classmethod
-    def find_vision_tensors(cls, qwen2vl, dtype) -> Dict[str, np.ndarray]:
-        vision_model = qwen2vl.visual
-        tensor_map = {}
-        for name, ten in vision_model.state_dict().items():
-            ten = ten.numpy()
-            if 'qkv' in name:
-                if ten.ndim == 2: # weight
-                    c3, _ = ten.shape
-                else: # bias
-                    c3 = ten.shape[0]
-                assert c3 % 3 == 0
-                c = c3 // 3
-                wq = ten[:c]
-                wk = ten[c: c * 2]
-                wv = ten[c * 2:]
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv
-            elif 'merger' in name:
-                if name.endswith("ln_q.weight"):
-                    tensor_map['v.post_ln.weight'] = ten
-                elif name.endswith("ln_q.bias"):
-                    tensor_map['v.post_ln.bias'] = ten
-                else:
-                    # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias"
-                    tensor_map[cls.to_gguf_name(name)] = ten
-            elif 'patch_embed.proj.weight' in name:
-                # NOTE: split Conv3D into Conv2Ds
-                c1, c2, kt, kh, kw = ten.shape
-                assert kt == 2, "Current implmentation only support temporal_patch_size of 2"
-                tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
-                tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]
-            else:
-                tensor_map[cls.to_gguf_name(f"vision_model.{name}")] = ten
-
-        for new_name, ten in tensor_map.items():
-            if ten.ndim <= 1 or new_name.endswith("_norm.weight"):
-                tensor_map[new_name] = ten.astype(np.float32)
-            else:
-                tensor_map[new_name] = ten.astype(dtype)
-        tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)  # dummy tensor, just here as a placeholder
-        return tensor_map
-
-
-class VL25(VL2):
-
-    @staticmethod
-    def to_gguf_name(name: str) -> str:
-        og = name
-        name = name.replace("text_model", "t").replace("vision_model", "v")
-        name = name.replace("blocks", "blk").replace("embeddings.", "")
-        name = name.replace("attn.", "attn_")
-        name = name.replace("mlp.down_proj", "ffn_down").replace("mlp.up_proj", "ffn_up")
-        name = name.replace("mlp.gate_proj", "ffn_gate").replace("proj.", "out.")
-        name = name.replace("norm1", "ln1").replace("norm2", "ln2")
-        name = name.replace("merger.mlp", 'mm')
-        print(f"[vl25][to_gguf_name] {og} --> {name}")
-        return name
-
-
-def main(args):
-    if args.data_type == 'fp32':
-        dtype = torch.float32
-        np_dtype = np.float32
-        ftype = 0
-    elif args.data_type == 'fp16':
-        dtype = torch.float16
-        np_dtype = np.float16
-        ftype = 1
-    else:
-        raise ValueError()
-
-    local_model = False
-    model_path = ""
-    model_name = args.model_name
-    print("model_name: ", model_name)
-    if args.model_type == "qwen2vl":
-        qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=dtype, device_map="cpu"
-        )
-        cfg: Qwen2VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
-        vcfg = cfg.vision_config
-    else:
-        qwen2vl = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=dtype, device_map="cpu"
-        )
-        cfg: Qwen2_5_VLConfig = qwen2vl.config  # type: ignore[reportAssignmentType]
-        vcfg = cfg.vision_config
-
-    if os.path.isdir(model_name):
-        local_model = True
-        if model_name.endswith(os.sep):
-            model_name = model_name[:-1]
-        model_path = model_name
-        model_name = os.path.basename(model_name)
-    fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf"
-
-    fout = GGUFWriter(path=fname_out, arch="clip")
-    fout.add_description("image encoder for Qwen2VL")
-
-    fout.add_file_type(ftype)
-    fout.add_bool("clip.has_text_encoder", False)
-    fout.add_bool("clip.has_vision_encoder", True)
-    fout.add_bool("clip.has_qwen2vl_merger", True)
-
-    print(cfg.vision_config)
-    if 'silu' in cfg.vision_config.hidden_act.lower():
-        fout.add_bool("clip.use_silu", True)
-        fout.add_bool("clip.use_gelu", False)
-    elif 'gelu' in cfg.vision_config.hidden_act.lower():
-        fout.add_bool("clip.use_silu", False)
-        fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower())
-    else:
-        raise ValueError()
-
-    if args.model_type == "qwen2.5vl":
-        fout.add_uint32("clip.vision.n_wa_pattern", get_n_wa_pattern(vcfg.fullatt_block_indexes))
-        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.hidden_size)
-        fout.add_uint32("clip.vision.projection_dim", vcfg.out_hidden_size)
-        fout.add_string("clip.projector_type", "qwen2.5vl_merger")
-    else:
-        fout.add_string("clip.projector_type", "qwen2vl_merger")
-        fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim)
-        fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size)
-
-    if args.model_type == "qwen2.5vl":
-        tensor_map = VL25.find_vision_tensors(qwen2vl, np_dtype)
-    else:
-        tensor_map = VL2.find_vision_tensors(qwen2vl, np_dtype)
-    for name, data in tensor_map.items():
-        fout.add_tensor(name, data)
-
-    fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
-    fout.add_uint32("clip.vision.image_size", 14 * 40)  # some reasonable size that is divable by (14*2)
-    fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
-    fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
-    fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
-    fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0)  # not sure what this does, put 0 here as a placeholder
-    fout.add_name(model_name)
-    """
-    HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig,
-        it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`.
-    """
-
-    if local_model:
-        processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path)
-    else:
-        processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name)
-    fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean)  # type: ignore[reportAttributeAccessIssue]
-    fout.add_array("clip.vision.image_std", processor.image_processor.image_std)  # type: ignore[reportAttributeAccessIssue]
-
-    fout.write_header_to_file()
-    fout.write_kv_data_to_file()
-    fout.write_tensors_to_file()
-    fout.close()
-    print("save model as: ", fname_out)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct")
-    parser.add_argument("--model_type", nargs='?', choices=['qwen2vl', 'qwen2.5vl'], default="qwen2vl")
-    parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32")
-    args = parser.parse_args()
-    main(args)
@@ -36,12 +36,6 @@ add_test() {
     arr_tmpl+=("$tmpl")
 }
 
-add_test_big() {
-    if [ "$RUN_BIG_TESTS" = true ]; then
-        add_test "$@"
-    fi
-}
-
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
@@ -58,8 +52,16 @@ add_test "llama-mtmd-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
-add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
-add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+if [ "$RUN_BIG_TESTS" = true ]; then
+    add_test "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
+    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
+    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
+fi
 
 # these models always give the wrong answer, not sure why
 # add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
@@ -234,6 +234,7 @@ class Keys:
         SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
         USE_GELU = "clip.use_gelu"
         USE_SILU = "clip.use_silu"
+        N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
 
         class Attention:
             HEAD_COUNT = "clip.vision.attention.head_count"
@@ -2162,6 +2163,8 @@ class VisionProjectorType:
     GEMMA3 = "gemma3"
     IDEFICS3 = "idefics3"
    PIXTRAL = "pixtral"
+    QWEN2VL = "qwen2vl_merger"
+    QWEN25VL = "qwen2.5vl_merger"
 
 
 # Items here are (block size, type size)
@@ -984,6 +984,9 @@ class GGUFWriter:
     def add_vision_projector_scale_factor(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
 
+    def add_vision_n_wa_pattern(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
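A small hedged sketch of the new writer method in isolation. The output filename and the value 8 are illustrative; the surrounding write calls mirror the ones used by the removed surgery script earlier in this commit:

```python
# Sketch: write a minimal GGUF file that carries the new window-attention key.
import gguf

writer = gguf.GGUFWriter(path="mmproj-test.gguf", arch="clip")  # assumed output name
writer.add_vision_n_wa_pattern(8)  # stored under Keys.ClipVision.N_WA_PATTERN ("clip.vision.n_wa_pattern")
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
```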
@@ -896,6 +896,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ: (
             "multi_modal_projector.linear_{bid}",
+            "visual.merger.mlp.{bid}", # qwen2vl
         ),
 
         MODEL_TENSOR.V_MMPROJ_FC: (
@@ -919,6 +920,7 @@ class TensorNameMap:
             "vpm.embeddings.patch_embedding",
             "model.vision_model.embeddings.patch_embedding", # SmolVLM
             "vision_tower.patch_conv", # pixtral
+            "visual.patch_embed.proj", # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -932,6 +934,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.q_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
+            "visual.blocks.{bid}.attn.q", # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_K: (
@@ -939,6 +942,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.k_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
+            "visual.blocks.{bid}.attn.k", # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_V: (
@@ -946,6 +950,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.v_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
+            "visual.blocks.{bid}.attn.v", # qwen2vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -953,6 +958,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.layer_norm1",
             "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
+            "visual.blocks.{bid}.norm1", # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_OUTPUT: (
@@ -960,6 +966,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
+            "visual.blocks.{bid}.attn.proj", # qwen2vl
         ),
 
         MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
@@ -967,17 +974,24 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.layer_norm2",
             "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
             "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
+            "visual.blocks.{bid}.norm2", # qwen2vl
         ),
 
+        # some namings are messed up because the original llava code swapped fc1 and fc2
+        # we have no better way to fix it, just be careful
+        # new models like pixtral use the correct naming
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
             "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
             "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
+            "visual.blocks.{bid}.mlp.fc2", # qwen2vl
+            "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_ENC_FFN_GATE: (
             "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
+            "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_ENC_FFN_DOWN: (
@@ -985,6 +999,8 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.mlp.fc2",
             "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
             "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
+            "visual.blocks.{bid}.mlp.fc1", # qwen2vl
+            "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
         ),
 
         MODEL_TENSOR.V_PRE_NORM: (
@@ -995,6 +1011,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_POST_NORM: (
             "vision_tower.vision_model.post_layernorm",
             "model.vision_model.post_layernorm", # SmolVLM
+            "visual.merger.ln_q", # qwen2vl
         ),
 
         MODEL_TENSOR.V_MM_INP_PROJ: (
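With these entries in place, the shared tensor-name mapping can resolve the raw Qwen2-VL checkpoint names without the old surgery script. A hedged sketch of that lookup; the block count and the expected output are illustrative, and the exact GGUF name ultimately comes from `gguf.TENSOR_NAMES`:

```python
# Sketch: resolve a Qwen2-VL vision tensor name through TensorNameMap.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 32)  # 32 blocks is an assumed depth
gguf_name = tmap.get_name("visual.blocks.0.attn.q.weight", try_suffixes=(".weight", ".bias"))
print(gguf_name)  # expected: the V_ENC_ATTN_Q name for block 0, e.g. "v.blk.0.attn_q.weight"
```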