convert : experimental support for --mmproj flag (#13023)

* convert : experimental support for `--mmproj` flag

* fix bad ctrl+f replace

* fix style

* split into subclasses TextModel and VisionModel

* rename Model --> ModelBase

* small fix

* correct CLIP_VISION arch name (because existing GGUF files already use it)

* Apply suggestions from code review

Co-authored-by: compilade <git@compilade.net>

* fix Mistral3Model

* fix typo

Co-authored-by: compilade <git@compilade.net>

---------

Co-authored-by: compilade <git@compilade.net>
Author: Xuan-Son Nguyen
Date: 2025-04-20 23:29:36 +02:00
Committed by: GitHub
Parent commit: 6602304814
This commit: 2016f07bd1
5 changed files with 663 additions and 295 deletions

File diff suppressed because it is too large.
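The suppressed diff is presumably where the commit message's split into TextModel and VisionModel lands. For orientation, the sketch below shows the registry pattern that ModelBase.from_model_architecture() (used in the LoRA conversion script further down) relies on; it is a self-contained toy, not the actual convert_hf_to_gguf.py code, and the registered architecture string is an illustrative assumption.

# Self-contained sketch, assuming a registry keyed on the HF "architectures" string;
# not the real convert_hf_to_gguf.py classes. TextModel and VisionModel are the
# subclass names from the commit message; everything else here is illustrative.
from __future__ import annotations

class ModelBase:
    _model_classes: dict[str, type[ModelBase]] = {}

    @classmethod
    def register(cls, *names: str):
        # class decorator: map one or more architecture strings to a converter class
        def wrapper(model_cls: type[ModelBase]) -> type[ModelBase]:
            for name in names:
                cls._model_classes[name] = model_cls
            return model_cls
        return wrapper

    @classmethod
    def from_model_architecture(cls, arch: str) -> type[ModelBase]:
        try:
            return cls._model_classes[arch]
        except KeyError:
            raise NotImplementedError(f"Architecture {arch!r} is not supported") from None

class TextModel(ModelBase):    # text converters
    pass

class VisionModel(ModelBase):  # vision/mmproj converters
    pass

@ModelBase.register("LlamaForCausalLM")  # hypothetical registration
class LlamaTextSketch(TextModel):
    pass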


@@ -24,7 +24,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
import gguf
# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, Model
from convert_hf_to_gguf import LazyTorchTensor, ModelBase
logger = logging.getLogger("lora-to-gguf")
@@ -340,11 +340,11 @@ if __name__ == '__main__':
sys.exit(1)
else:
logger.info(f"Loading base model: {dir_base_model.name}")
hparams = Model.load_hparams(dir_base_model)
hparams = ModelBase.load_hparams(dir_base_model)
with torch.inference_mode():
try:
model_class = Model.from_model_architecture(hparams["architectures"][0])
model_class = ModelBase.from_model_architecture(hparams["architectures"][0])
except NotImplementedError:
logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1)


@@ -50,7 +50,6 @@
// tensor name constants
//
#define TN_TOKEN_EMBD "%s.token_embd.weight"
#define TN_POS_EMBD "%s.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
#define TN_PATCH_EMBD "v.patch_embd.weight" // not renaming tensor with ".0" postfix for backward compat
@@ -66,8 +65,6 @@
#define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LN_PRE "%s.pre_ln.%s"
#define TN_LN_POST "%s.post_ln.%s"
#define TN_TEXT_PROJ "text_projection.weight"
#define TN_VIS_PROJ "visual_projection.weight"
#define TN_LLAVA_PROJ "mm.%d.%s"
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
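For orientation, the C-side TN_* format strings above expand to the same names that the gguf-py TENSOR_NAMES entries added later in this commit produce. A small Python illustration with an assumed prefix and block index:

# Illustration only: TN_LN_2 ("%s.blk.%d.ln2.%s") filled in for the vision encoder
# matches TENSOR_NAMES[MODEL_TENSOR.V_ENC_OUTPUT_NORM] ("v.blk.{bid}.ln2") plus a suffix.
prefix, bid, suffix = "v", 0, "weight"                    # assumed values
print("%s.blk.%d.ln2.%s" % (prefix, bid, suffix))         # -> v.blk.0.ln2.weight
print("v.blk.{bid}.ln2".format(bid=bid) + "." + suffix)   # -> v.blk.0.ln2.weight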


@@ -218,6 +218,24 @@ class Keys:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
class ClipVision:
    PROJECTOR_TYPE = "clip.projector_type"
    HAS_VISION_ENCODER = "clip.has_vision_encoder"
    HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
    IMAGE_SIZE = "clip.vision.image_size"
    PATCH_SIZE = "clip.vision.patch_size"
    EMBEDDING_LENGTH = "clip.vision.embedding_length"
    FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length"
    PROJECTION_DIM = "clip.vision.projection_dim"
    BLOCK_COUNT = "clip.vision.block_count"
    IMAGE_MEAN = "clip.vision.image_mean"
    IMAGE_STD = "clip.vision.image_std"
    USE_GELU = "clip.use_gelu"

    class Attention:
        HEAD_COUNT = "clip.vision.attention.head_count"
        LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon"
#
# recommended mapping of model tensor names for storage in gguf
#
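A hedged sketch of how a converter could emit the new clip.* metadata using gguf-py's generic key/value writer methods; the file name and all numeric values are placeholders for an arbitrary ViT configuration, and this excerpt does not show which calls convert_hf_to_gguf.py actually uses.

import gguf

# Sketch: open a writer for the (new) "clip" arch and add the ClipVision keys.
writer = gguf.GGUFWriter("mmproj.gguf", gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.CLIP_VISION])
writer.add_type(gguf.GGUFType.CLIP_VISION)  # "clip-vision", added just below in this commit
writer.add_string(gguf.Keys.ClipVision.PROJECTOR_TYPE, "mlp")
writer.add_bool(gguf.Keys.ClipVision.HAS_VISION_ENCODER, True)
writer.add_bool(gguf.Keys.ClipVision.HAS_LLAVA_PROJECTOR, True)
writer.add_uint32(gguf.Keys.ClipVision.IMAGE_SIZE, 336)
writer.add_uint32(gguf.Keys.ClipVision.PATCH_SIZE, 14)
writer.add_uint32(gguf.Keys.ClipVision.EMBEDDING_LENGTH, 1024)
writer.add_uint32(gguf.Keys.ClipVision.FEED_FORWARD_LENGTH, 4096)
writer.add_uint32(gguf.Keys.ClipVision.PROJECTION_DIM, 768)
writer.add_uint32(gguf.Keys.ClipVision.BLOCK_COUNT, 24)
writer.add_array(gguf.Keys.ClipVision.IMAGE_MEAN, [0.481, 0.457, 0.408])
writer.add_array(gguf.Keys.ClipVision.IMAGE_STD, [0.268, 0.261, 0.276])
writer.add_bool(gguf.Keys.ClipVision.USE_GELU, False)
writer.add_uint32(gguf.Keys.ClipVision.Attention.HEAD_COUNT, 16)
writer.add_float32(gguf.Keys.ClipVision.Attention.LAYERNORM_EPS, 1e-5)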
@@ -226,9 +244,11 @@ class Keys:
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
CLIP_VISION = "clip-vision"
class MODEL_ARCH(IntEnum):
CLIP_VISION = auto() # dummy arch for clip.cpp
LLAMA = auto()
LLAMA4 = auto()
DECI = auto()
@@ -297,6 +317,16 @@ class MODEL_ARCH(IntEnum):
BAILINGMOE = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
MLP = auto()
LDP = auto()
LDPV2 = auto()
RESAMPLER = auto()
GLM_EDGE = auto()
MERGER = auto()
GEMMA3 = auto()
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
TOKEN_EMBD_NORM = auto()
@@ -436,9 +466,41 @@ class MODEL_TENSOR(IntEnum):
POSNET_ATTN_K = auto()
POSNET_ATTN_V = auto()
POSNET_ATTN_OUT = auto()
# vision
V_MMPROJ = auto()
V_MMPROJ_FC = auto()
V_MMPROJ_MLP = auto()
V_MMPROJ_PEG = auto()
V_ENC_EMBD_CLS = auto()
V_ENC_EMBD_PATCH = auto()
V_ENC_EMBD_POS = auto()
V_ENC_ATTN_Q = auto()
V_ENC_ATTN_K = auto()
V_ENC_ATTN_V = auto()
V_ENC_INPUT_NORM = auto()
V_ENC_OUTPUT = auto()
V_ENC_OUTPUT_NORM = auto()
V_ENC_FFN_UP = auto()
V_ENC_FFN_DOWN = auto()
V_PRE_NORM = auto()
V_POST_NORM = auto()
V_MM_INP_PROJ = auto() # gemma3
V_MM_SOFT_EMB_NORM = auto() # gemma3
V_RESMPL_POS_EMBD_K = auto() # minicpmv
V_RESMPL_ATTN_Q = auto() # minicpmv
V_RESMPL_ATTN_K = auto() # minicpmv
V_RESMPL_ATTN_V = auto() # minicpmv
V_RESMPL_ATTN_OUT = auto() # minicpmv
V_RESMPL_KV = auto() # minicpmv
V_RESMPL_KV_NORM = auto() # minicpmv
V_RESMPL_POST_NORM = auto() # minicpmv
V_RESMPL_Q_NORM = auto() # minicpmv
V_RESMPL_PROJ = auto() # minicpmv
V_RESMPL_QUERY = auto() # minicpmv
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.LLAMA4: "llama4",
MODEL_ARCH.DECI: "deci",
@@ -507,6 +569,16 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.BAILINGMOE: "bailingmoe",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
VISION_PROJECTOR_TYPE.MLP: "mlp",
VISION_PROJECTOR_TYPE.LDP: "ldp",
VISION_PROJECTOR_TYPE.LDPV2: "ldpv2",
VISION_PROJECTOR_TYPE.RESAMPLER: "resampler",
VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter",
VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger",
VISION_PROJECTOR_TYPE.GEMMA3: "gemma3",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
@@ -646,9 +718,72 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
# vision
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}",
MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2",
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k",
MODEL_TENSOR.V_RESMPL_ATTN_V: "resampler.attn.v",
MODEL_TENSOR.V_RESMPL_ATTN_OUT: "resampler.attn.out",
MODEL_TENSOR.V_RESMPL_KV: "resampler.kv",
MODEL_TENSOR.V_RESMPL_KV_NORM: "resampler.ln_kv",
MODEL_TENSOR.V_RESMPL_POST_NORM: "resampler.ln_post",
MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q",
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_ARCH.CLIP_VISION: [
MODEL_TENSOR.V_MMPROJ,
MODEL_TENSOR.V_MMPROJ_FC,
MODEL_TENSOR.V_MMPROJ_MLP,
MODEL_TENSOR.V_MMPROJ_PEG,
MODEL_TENSOR.V_ENC_EMBD_CLS,
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_POS,
MODEL_TENSOR.V_ENC_ATTN_Q,
MODEL_TENSOR.V_ENC_ATTN_K,
MODEL_TENSOR.V_ENC_ATTN_V,
MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_OUTPUT,
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_DOWN,
MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM,
MODEL_TENSOR.V_MM_INP_PROJ,
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
MODEL_TENSOR.V_RESMPL_ATTN_Q,
MODEL_TENSOR.V_RESMPL_ATTN_K,
MODEL_TENSOR.V_RESMPL_ATTN_V,
MODEL_TENSOR.V_RESMPL_ATTN_OUT,
MODEL_TENSOR.V_RESMPL_KV,
MODEL_TENSOR.V_RESMPL_KV_NORM,
MODEL_TENSOR.V_RESMPL_POST_NORM,
MODEL_TENSOR.V_RESMPL_Q_NORM,
MODEL_TENSOR.V_RESMPL_PROJ,
MODEL_TENSOR.V_RESMPL_QUERY,
],
MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
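Combined with the TENSOR_NAMES entries above, the new MODEL_TENSORS arch entry enumerates every GGUF tensor name a CLIP_VISION file may carry. A small illustration (the block index 0 is arbitrary):

import gguf

# Print the first few GGUF tensor names defined for the new CLIP_VISION arch;
# names without a {bid} placeholder are returned unchanged by str.format().
for t in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.CLIP_VISION][:5]:
    print(gguf.TENSOR_NAMES[t].format(bid=0))
# mm.0
# mm.model.fc
# mm.model.mlp.0
# mm.model.peg.0
# v.class_embd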


@@ -886,6 +886,150 @@ class TensorNameMap:
MODEL_TENSOR.POSNET_ATTN_OUT: (
"backbone.posnet.{bid}.proj_out", # wavtokenizer
),
#############################################################################
## Vision encoder
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
),
MODEL_TENSOR.V_MMPROJ_FC: (
"model.connector.modality_projection.proj", # SmolVLM
),
MODEL_TENSOR.V_MMPROJ_MLP: (
"model.mm_projector.mlp.mlp.{bid}",
),
MODEL_TENSOR.V_MMPROJ_PEG: (
"model.mm_projector.peg.peg.{bid}",
),
MODEL_TENSOR.V_ENC_EMBD_CLS: (
"vision_tower.vision_model.embeddings.class_embedding",
),
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
"vision_tower.vision_model.embeddings.patch_embedding",
"vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM
),
MODEL_TENSOR.V_ENC_EMBD_POS: (
"vision_tower.vision_model.embeddings.position_embedding",
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
"vpm.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
),
MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"vpm.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
),
MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"vpm.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
"vpm.encoder.layers.{bid}.layer_norm1",
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
),
MODEL_TENSOR.V_ENC_OUTPUT: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
),
MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"vpm.encoder.layers.{bid}.layer_norm2",
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
),
MODEL_TENSOR.V_ENC_FFN_UP: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
"vpm.encoder.layers.{bid}.mlp.fc1",
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM
),
MODEL_TENSOR.V_ENC_FFN_DOWN: (
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
"vpm.encoder.layers.{bid}.mlp.fc2",
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM
),
MODEL_TENSOR.V_PRE_NORM: (
"vision_tower.vision_model.pre_layrnorm",
),
MODEL_TENSOR.V_POST_NORM: (
"vision_tower.vision_model.post_layernorm",
"model.vision_model.post_layernorm", # SmolVLM
),
MODEL_TENSOR.V_MM_INP_PROJ: (
"multi_modal_projector.mm_input_projection",
),
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
"multi_modal_projector.mm_soft_emb_norm",
),
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: (
"resampler.pos_embed_k",
),
MODEL_TENSOR.V_RESMPL_ATTN_Q: (
"resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj
),
MODEL_TENSOR.V_RESMPL_ATTN_K: (
"resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj
),
MODEL_TENSOR.V_RESMPL_ATTN_V: (
"resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj
),
MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
"resampler.attn.out_proj",
),
MODEL_TENSOR.V_RESMPL_KV: (
"resampler.kv_proj",
),
MODEL_TENSOR.V_RESMPL_POST_NORM: (
"resampler.ln_post",
),
MODEL_TENSOR.V_RESMPL_KV_NORM: (
"resampler.ln_kv",
),
MODEL_TENSOR.V_RESMPL_Q_NORM: (
"resampler.ln_q",
),
MODEL_TENSOR.V_RESMPL_PROJ: (
"resampler.proj",
),
MODEL_TENSOR.V_RESMPL_QUERY: (
"resampler.query",
),
}
# architecture-specific block mappings
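Together with the MODEL_TENSORS entry for CLIP_VISION, these mappings let TensorNameMap translate Hugging Face vision-tower tensor names into the GGUF names expected on the C++ side. A hedged usage sketch (the block count and the tensor name are illustrative):

import gguf

# Build the name map for the new CLIP_VISION arch and translate one HF tensor name.
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 24)  # 24 encoder blocks, assumed
hf_name = "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight"
print(tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))  # -> v.blk.0.attn_q.weight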