mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 11:45:21 +00:00
convert : experimental support for --mmproj
flag (#13023)
* convert : experimental support for `--mmproj` flag * fix bad ctrl+f replace * fix style * split into subclasses TextModel and VisionModel * rename Mode --> ModelBase * small fix * correct CLIP_VISION arch name (because existing GGUF already use it) * Apply suggestions from code review Co-authored-by: compilade <git@compilade.net> * fix Mistral3Model * fix typo Co-authored-by: compilade <git@compilade.net> --------- Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@ -24,7 +24,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
|
||||
import gguf
|
||||
|
||||
# reuse model definitions from convert_hf_to_gguf.py
|
||||
from convert_hf_to_gguf import LazyTorchTensor, Model
|
||||
from convert_hf_to_gguf import LazyTorchTensor, ModelBase
|
||||
|
||||
logger = logging.getLogger("lora-to-gguf")
|
||||
|
||||
@ -340,11 +340,11 @@ if __name__ == '__main__':
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.info(f"Loading base model: {dir_base_model.name}")
|
||||
hparams = Model.load_hparams(dir_base_model)
|
||||
hparams = ModelBase.load_hparams(dir_base_model)
|
||||
|
||||
with torch.inference_mode():
|
||||
try:
|
||||
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
||||
model_class = ModelBase.from_model_architecture(hparams["architectures"][0])
|
||||
except NotImplementedError:
|
||||
logger.error(f"Model {hparams['architectures'][0]} is not supported")
|
||||
sys.exit(1)
|
||||
|
@ -50,7 +50,6 @@
|
||||
// tensor name constants
|
||||
//
|
||||
|
||||
#define TN_TOKEN_EMBD "%s.token_embd.weight"
|
||||
#define TN_POS_EMBD "%s.position_embd.weight"
|
||||
#define TN_CLASS_EMBD "v.class_embd"
|
||||
#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat
|
||||
@ -66,8 +65,6 @@
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
||||
#define TN_LN_PRE "%s.pre_ln.%s"
|
||||
#define TN_LN_POST "%s.post_ln.%s"
|
||||
#define TN_TEXT_PROJ "text_projection.weight"
|
||||
#define TN_VIS_PROJ "visual_projection.weight"
|
||||
#define TN_LLAVA_PROJ "mm.%d.%s"
|
||||
#define TN_MVLM_PROJ_MLP "mm.model.mlp.%d.%s"
|
||||
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
||||
|
@ -218,17 +218,37 @@ class Keys:
|
||||
TYPE = "adapter.type"
|
||||
LORA_ALPHA = "adapter.lora.alpha"
|
||||
|
||||
class ClipVision:
|
||||
PROJECTOR_TYPE = "clip.projector_type"
|
||||
HAS_VISION_ENCODER = "clip.has_vision_encoder"
|
||||
HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
|
||||
IMAGE_SIZE = "clip.vision.image_size"
|
||||
PATCH_SIZE = "clip.vision.patch_size"
|
||||
EMBEDDING_LENGTH = "clip.vision.embedding_length"
|
||||
FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length"
|
||||
PROJECTION_DIM = "clip.vision.projection_dim"
|
||||
BLOCK_COUNT = "clip.vision.block_count"
|
||||
IMAGE_MEAN = "clip.vision.image_mean"
|
||||
IMAGE_STD = "clip.vision.image_std"
|
||||
USE_GELU = "clip.use_gelu"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "clip.vision.attention.head_count"
|
||||
LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon"
|
||||
|
||||
#
|
||||
# recommended mapping of model tensor names for storage in gguf
|
||||
#
|
||||
|
||||
|
||||
class GGUFType:
|
||||
MODEL = "model"
|
||||
ADAPTER = "adapter"
|
||||
MODEL = "model"
|
||||
ADAPTER = "adapter"
|
||||
CLIP_VISION = "clip-vision"
|
||||
|
||||
|
||||
class MODEL_ARCH(IntEnum):
|
||||
CLIP_VISION = auto() # dummy arch for clip.cpp
|
||||
LLAMA = auto()
|
||||
LLAMA4 = auto()
|
||||
DECI = auto()
|
||||
@ -297,6 +317,16 @@ class MODEL_ARCH(IntEnum):
|
||||
BAILINGMOE = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
MLP = auto()
|
||||
LDP = auto()
|
||||
LDPV2 = auto()
|
||||
RESAMPLER = auto()
|
||||
GLM_EDGE = auto()
|
||||
MERGER = auto()
|
||||
GEMMA3 = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
TOKEN_EMBD = auto()
|
||||
TOKEN_EMBD_NORM = auto()
|
||||
@ -436,9 +466,41 @@ class MODEL_TENSOR(IntEnum):
|
||||
POSNET_ATTN_K = auto()
|
||||
POSNET_ATTN_V = auto()
|
||||
POSNET_ATTN_OUT = auto()
|
||||
# vision
|
||||
V_MMPROJ = auto()
|
||||
V_MMPROJ_FC = auto()
|
||||
V_MMPROJ_MLP = auto()
|
||||
V_MMPROJ_PEG = auto()
|
||||
V_ENC_EMBD_CLS = auto()
|
||||
V_ENC_EMBD_PATCH = auto()
|
||||
V_ENC_EMBD_POS = auto()
|
||||
V_ENC_ATTN_Q = auto()
|
||||
V_ENC_ATTN_K = auto()
|
||||
V_ENC_ATTN_V = auto()
|
||||
V_ENC_INPUT_NORM = auto()
|
||||
V_ENC_OUTPUT = auto()
|
||||
V_ENC_OUTPUT_NORM = auto()
|
||||
V_ENC_FFN_UP = auto()
|
||||
V_ENC_FFN_DOWN = auto()
|
||||
V_PRE_NORM = auto()
|
||||
V_POST_NORM = auto()
|
||||
V_MM_INP_PROJ = auto() # gemma3
|
||||
V_MM_SOFT_EMB_NORM = auto() # gemma3
|
||||
V_RESMPL_POS_EMBD_K = auto() # minicpmv
|
||||
V_RESMPL_ATTN_Q = auto() # minicpmv
|
||||
V_RESMPL_ATTN_K = auto() # minicpmv
|
||||
V_RESMPL_ATTN_V = auto() # minicpmv
|
||||
V_RESMPL_ATTN_OUT = auto() # minicpmv
|
||||
V_RESMPL_KV = auto() # minicpmv
|
||||
V_RESMPL_KV_NORM = auto() # minicpmv
|
||||
V_RESMPL_POST_NORM = auto() # minicpmv
|
||||
V_RESMPL_Q_NORM = auto() # minicpmv
|
||||
V_RESMPL_PROJ = auto() # minicpmv
|
||||
V_RESMPL_QUERY = auto() # minicpmv
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.CLIP_VISION: "clip", # dummy arch for clip.cpp
|
||||
MODEL_ARCH.LLAMA: "llama",
|
||||
MODEL_ARCH.LLAMA4: "llama4",
|
||||
MODEL_ARCH.DECI: "deci",
|
||||
@ -507,6 +569,16 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.BAILINGMOE: "bailingmoe",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
VISION_PROJECTOR_TYPE.MLP: "mlp",
|
||||
VISION_PROJECTOR_TYPE.LDP: "ldp",
|
||||
VISION_PROJECTOR_TYPE.LDPV2: "ldpv2",
|
||||
VISION_PROJECTOR_TYPE.RESAMPLER: "resampler",
|
||||
VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter",
|
||||
VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger",
|
||||
VISION_PROJECTOR_TYPE.GEMMA3: "gemma3",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.TOKEN_EMBD: "token_embd",
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm",
|
||||
@ -646,9 +718,72 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
|
||||
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
|
||||
# vision
|
||||
MODEL_TENSOR.V_MMPROJ: "mm.{bid}",
|
||||
MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc",
|
||||
MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}",
|
||||
MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}",
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd",
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd",
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd",
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q",
|
||||
MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k",
|
||||
MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
|
||||
MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2",
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
|
||||
MODEL_TENSOR.V_POST_NORM: "v.post_ln",
|
||||
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
|
||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
|
||||
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k",
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_V: "resampler.attn.v",
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_OUT: "resampler.attn.out",
|
||||
MODEL_TENSOR.V_RESMPL_KV: "resampler.kv",
|
||||
MODEL_TENSOR.V_RESMPL_KV_NORM: "resampler.ln_kv",
|
||||
MODEL_TENSOR.V_RESMPL_POST_NORM: "resampler.ln_post",
|
||||
MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q",
|
||||
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
|
||||
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_ARCH.CLIP_VISION: [
|
||||
MODEL_TENSOR.V_MMPROJ,
|
||||
MODEL_TENSOR.V_MMPROJ_FC,
|
||||
MODEL_TENSOR.V_MMPROJ_MLP,
|
||||
MODEL_TENSOR.V_MMPROJ_PEG,
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS,
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH,
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS,
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q,
|
||||
MODEL_TENSOR.V_ENC_ATTN_K,
|
||||
MODEL_TENSOR.V_ENC_ATTN_V,
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_OUTPUT,
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM,
|
||||
MODEL_TENSOR.V_ENC_FFN_UP,
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN,
|
||||
MODEL_TENSOR.V_PRE_NORM,
|
||||
MODEL_TENSOR.V_POST_NORM,
|
||||
MODEL_TENSOR.V_MM_INP_PROJ,
|
||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
|
||||
MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_Q,
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_K,
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_V,
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_OUT,
|
||||
MODEL_TENSOR.V_RESMPL_KV,
|
||||
MODEL_TENSOR.V_RESMPL_KV_NORM,
|
||||
MODEL_TENSOR.V_RESMPL_POST_NORM,
|
||||
MODEL_TENSOR.V_RESMPL_Q_NORM,
|
||||
MODEL_TENSOR.V_RESMPL_PROJ,
|
||||
MODEL_TENSOR.V_RESMPL_QUERY,
|
||||
],
|
||||
MODEL_ARCH.LLAMA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -886,6 +886,150 @@ class TensorNameMap:
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT: (
|
||||
"backbone.posnet.{bid}.proj_out", # wavtokenizer
|
||||
),
|
||||
|
||||
#############################################################################
|
||||
## Vision encoder
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ: (
|
||||
"multi_modal_projector.linear_{bid}",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_FC: (
|
||||
"model.connector.modality_projection.proj", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_MLP: (
|
||||
"model.mm_projector.mlp.mlp.{bid}",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MMPROJ_PEG: (
|
||||
"model.mm_projector.peg.peg.{bid}",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_CLS: (
|
||||
"vision_tower.vision_model.embeddings.class_embedding",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
|
||||
"vision_tower.vision_model.embeddings.patch_embedding",
|
||||
"vpm.embeddings.patch_embedding",
|
||||
"model.vision_model.embeddings.patch_embedding", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_EMBD_POS: (
|
||||
"vision_tower.vision_model.embeddings.position_embedding",
|
||||
"vpm.embeddings.position_embedding",
|
||||
"model.vision_model.embeddings.position_embedding", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_Q: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.q_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_K: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.k_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_ATTN_V: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.v_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_INPUT_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
|
||||
"vpm.encoder.layers.{bid}.layer_norm1",
|
||||
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_OUTPUT: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
|
||||
"vpm.encoder.layers.{bid}.self_attn.out_proj",
|
||||
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
|
||||
"vpm.encoder.layers.{bid}.layer_norm2",
|
||||
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_UP: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
|
||||
"vpm.encoder.layers.{bid}.mlp.fc1",
|
||||
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_ENC_FFN_DOWN: (
|
||||
"vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
|
||||
"vpm.encoder.layers.{bid}.mlp.fc2",
|
||||
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_PRE_NORM: (
|
||||
"vision_tower.vision_model.pre_layrnorm",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_POST_NORM: (
|
||||
"vision_tower.vision_model.post_layernorm",
|
||||
"model.vision_model.post_layernorm", # SmolVLM
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_INP_PROJ: (
|
||||
"multi_modal_projector.mm_input_projection",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
|
||||
"multi_modal_projector.mm_soft_emb_norm",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: (
|
||||
"resampler.pos_embed_k",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_Q: (
|
||||
"resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_K: (
|
||||
"resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_V: (
|
||||
"resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
|
||||
"resampler.attn.out_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_KV: (
|
||||
"resampler.kv_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_POST_NORM: (
|
||||
"resampler.ln_post",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_KV_NORM: (
|
||||
"resampler.ln_kv",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_Q_NORM: (
|
||||
"resampler.ln_q",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_PROJ: (
|
||||
"resampler.proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.V_RESMPL_QUERY: (
|
||||
"resampler.query",
|
||||
),
|
||||
}
|
||||
|
||||
# architecture-specific block mappings
|
||||
|
Reference in New Issue
Block a user