From 792387b1642f85ccc5f35c088b172c6e5d8dcdf6 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 30 Apr 2025 12:51:40 +0200
Subject: [PATCH] wip

---
 convert_hf_to_gguf.py          | 75 ++++++++++++++++++++++++++++++++++
 gguf-py/gguf/constants.py      |  6 +++
 gguf-py/gguf/gguf_writer.py    |  3 ++
 gguf-py/gguf/tensor_mapping.py | 21 ++++++++++
 4 files changed, 105 insertions(+)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index b9cea7e46..37a9f36dc 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2555,6 +2555,81 @@ class Qwen2VLModel(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+class Qwen2VLVisionModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        # rename config.json values
+        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
+        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
+        self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
+        self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if self.global_config['model_type'] == 'qwen2_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
+        elif self.global_config['model_type'] == 'qwen2_5_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
+            self.gguf_writer.add_vision_use_silu(True)
+            # find n_wa_pattern (window attention pattern)
+            fullatt_block_indexes = hparams.get("fullatt_block_indexes")
+            assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl"
+            n_wa_pattern = fullatt_block_indexes[0] + 1
+            # validate n_wa_pattern
+            for i in range(1, len(fullatt_block_indexes)):
+                if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
+                    raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
+            self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+        else:
+            raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}")
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, name, n_dims  # unused
+        if ".patch_embd." in new_name:
+            return gguf.GGMLQuantizationType.F16
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("visual."):
+            # process visual tensors
+            # split QKV tensors if needed
+            if ".qkv." in name:
+                if data_torch.ndim == 2:  # weight
+                    c3, _ = data_torch.shape
+                else:  # bias
+                    c3 = data_torch.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                return [
+                    (self.map_tensor_name(name.replace("qkv", "q")), wq),
+                    (self.map_tensor_name(name.replace("qkv", "k")), wk),
+                    (self.map_tensor_name(name.replace("qkv", "v")), wv),
+                ]
+            elif 'patch_embed.proj.weight' in name:
+                # split Conv3D into Conv2Ds
+                c1, c2, kt, kh, kw = data_torch.shape
+                del c1, c2, kh, kw  # unused
+                assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+                return [
+                    (self.map_tensor_name(name), data_torch[:, :, 0, ...]),
+                    (self.map_tensor_name(name + '.1'), data_torch[:, :, 1, ...]),
+                ]
+            else:
+                return [(self.map_tensor_name(name), data_torch)]
+        return []  # skip other tensors
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 326ccdb07..ec3bd20e2 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -233,6 +233,7 @@ class Keys:
         IMAGE_STD = "clip.vision.image_std"
         USE_GELU = "clip.use_gelu"
         USE_SILU = "clip.use_silu"
+        N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl
 
     class Attention:
         HEAD_COUNT = "clip.vision.attention.head_count"
@@ -479,6 +480,7 @@ class MODEL_TENSOR(IntEnum):
     V_MMPROJ_PEG = auto()
     V_ENC_EMBD_CLS = auto()
     V_ENC_EMBD_PATCH = auto()
+    V_ENC_EMBD_PATCH1 = auto() # qwen2vl
     V_ENC_EMBD_POS = auto()
     V_ENC_ATTN_Q = auto()
     V_ENC_ATTN_K = auto()
@@ -734,6 +736,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
     MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
     MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
+    MODEL_TENSOR.V_ENC_EMBD_PATCH1: "v.patch_embd.weight.1", # qwen2vl
     MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
     MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
     MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
@@ -770,6 +773,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_MMPROJ_PEG,
         MODEL_TENSOR.V_ENC_EMBD_CLS,
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
+        MODEL_TENSOR.V_ENC_EMBD_PATCH1,
         MODEL_TENSOR.V_ENC_EMBD_POS,
         MODEL_TENSOR.V_ENC_ATTN_Q,
         MODEL_TENSOR.V_ENC_ATTN_K,
@@ -2155,6 +2159,8 @@ class VisionProjectorType:
     GEMMA3 = "gemma3"
     IDEFICS3 = "idefics3"
     PIXTRAL = "pixtral"
+    QWEN2VL = "qwen2vl_merger"
+    QWEN25VL = "qwen2.5vl_merger"
 
 
 # Items here are (block size, type size)
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index f22a6d4a3..b796c1129 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -981,6 +981,9 @@ class GGUFWriter:
     def add_vision_projector_scale_factor(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
 
+    def add_vision_n_wa_pattern(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 311d1ff69..0100d0f33 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -896,6 +896,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ: (
             "multi_modal_projector.linear_{bid}",
+            "visual.merger.mlp.{bid}", # qwen2vl
         ),
 
         MODEL_TENSOR.V_MMPROJ_FC: (
@@ -919,6 +920,11 @@ class TensorNameMap:
             "vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM "vision_tower.patch_conv", # pixtral + "visual.patch_embed.proj", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_EMBD_PATCH1: ( + "visual.patch_embed.proj.weight.1", # qwen2vl, generated ), MODEL_TENSOR.V_ENC_EMBD_POS: ( @@ -932,6 +938,7 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.self_attn.q_proj", "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral + "visual.blocks.{bid}.attn.q", # qwen2vl, generated ), MODEL_TENSOR.V_ENC_ATTN_K: ( @@ -939,6 +946,7 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.self_attn.k_proj", "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral + "visual.blocks.{bid}.attn.k", # qwen2vl, generated ), MODEL_TENSOR.V_ENC_ATTN_V: ( @@ -946,6 +954,7 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.self_attn.v_proj", "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral + "visual.blocks.{bid}.attn.v", # qwen2vl, generated ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -953,6 +962,7 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.layer_norm1", "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral + "visual.blocks.{bid}.norm1", # qwen2vl ), MODEL_TENSOR.V_ENC_OUTPUT: ( @@ -960,6 +970,7 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.self_attn.out_proj", "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral + "visual.blocks.{bid}.attn.proj", # qwen2vl ), MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( @@ -967,17 +978,24 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.layer_norm2", "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral + "visual.blocks.{bid}.norm2", # qwen2vl ), + # some namings are messed up because the original llava code swapped fc1 and fc2 + # we have no better way to fix it, just be careful + # new models like pixtral use the correct naming MODEL_TENSOR.V_ENC_FFN_UP: ( "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", "vpm.encoder.layers.{bid}.mlp.fc1", "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped) "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral + "visual.blocks.{bid}.mlp.fc2", # qwen2vl + "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl ), MODEL_TENSOR.V_ENC_FFN_GATE: ( "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral + "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl ), MODEL_TENSOR.V_ENC_FFN_DOWN: ( @@ -985,6 +1003,8 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.mlp.fc2", "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped) "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral + "visual.blocks.{bid}.mlp.fc1", # qwen2vl + "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl ), MODEL_TENSOR.V_PRE_NORM: ( @@ -995,6 +1015,7 @@ class TensorNameMap: MODEL_TENSOR.V_POST_NORM: ( "vision_tower.vision_model.post_layernorm", "model.vision_model.post_layernorm", # SmolVLM + "visual.merger.ln_q", # qwen2vl ), MODEL_TENSOR.V_MM_INP_PROJ: (