mtmd : add **vision** support for Mistral Small 3.1 (#13231)

* convert ok

* load ok, missing patch merger

* ah sheet it works

* update llava/readme

* add test

* fix test
This commit is contained in:
Xuan-Son Nguyen
2025-05-01 17:05:42 +02:00
committed by GitHub
parent 13c9a3319b
commit 8936784f7a
9 changed files with 112 additions and 15 deletions

View File

@ -1899,7 +1899,10 @@ class LlamaModel(TextModel):
raise ValueError(f"Unprocessed experts: {experts}") raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("LlavaForConditionalGeneration") @ModelBase.register(
"LlavaForConditionalGeneration", # pixtral
"Mistral3ForConditionalGeneration", # mistral small 3.1
)
class LlavaVisionModel(VisionModel): class LlavaVisionModel(VisionModel):
img_break_tok_id = -1 img_break_tok_id = -1
@ -1908,17 +1911,38 @@ class LlavaVisionModel(VisionModel):
if self.hparams["model_type"] == "pixtral": if self.hparams["model_type"] == "pixtral":
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5) self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = 12 # see tokenizer_config.json self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
logger.info(f"Image break token id: {self.img_break_tok_id}")
else: else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
def get_token_id(self, token: str) -> int:
tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
with open(tokenizer_config_file, "r", encoding="utf-8") as f:
added_tokens_decoder = json.load(f)['added_tokens_decoder']
for id_, token_data in added_tokens_decoder.items():
if token_data["content"] == token:
return int(id_)
raise ValueError(f"Token '{token}' not found in tokenizer config.")
def set_gguf_parameters(self): def set_gguf_parameters(self):
super().set_gguf_parameters() super().set_gguf_parameters()
hparams = self.hparams hparams = self.hparams
if hparams["model_type"] == "pixtral": if hparams["model_type"] == "pixtral":
self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL) self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
self.gguf_writer.add_vision_use_silu(True)
# hidden_act
if hparams["hidden_act"] == "silu":
self.gguf_writer.add_vision_use_silu(True)
elif hparams["hidden_act"] == "gelu":
self.gguf_writer.add_vision_use_gelu(True)
else:
raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
# spatial_merge_size
if "spatial_merge_size" in self.global_config:
self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused del bid # unused

View File

@ -34,6 +34,9 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
# Pixtral 12B # Pixtral 12B
llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
# Mistral Small 3.1 24B (IQ2_M quantization)
llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
``` ```
## How it works and what is `mmproj`? ## How it works and what is `mmproj`?
@ -73,3 +76,4 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla
- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB)) - SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB)) - SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)

View File

@ -31,6 +31,7 @@
#define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_TYPE "clip.projector_type" #define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl #define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl #define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
@ -68,9 +69,11 @@
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline" #define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_NORM "mm.input_norm.weight"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3 #define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
// mimicpmv // mimicpmv

View File

@ -172,6 +172,7 @@ struct clip_hparams {
std::unordered_set<int32_t> vision_feature_layer; std::unordered_set<int32_t> vision_feature_layer;
int32_t attn_window_size = 0; int32_t attn_window_size = 0;
int32_t n_wa_pattern = 0; int32_t n_wa_pattern = 0;
int32_t spatial_merge_size = 0;
}; };
struct clip_layer { struct clip_layer {
@ -232,6 +233,7 @@ struct clip_vision_model {
struct ggml_tensor * projection; struct ggml_tensor * projection;
// LLaVA projection // LLaVA projection
struct ggml_tensor * mm_input_norm_w = nullptr;
struct ggml_tensor * mm_0_w = nullptr; struct ggml_tensor * mm_0_w = nullptr;
struct ggml_tensor * mm_0_b = nullptr; struct ggml_tensor * mm_0_b = nullptr;
struct ggml_tensor * mm_2_w = nullptr; struct ggml_tensor * mm_2_w = nullptr;
@ -311,6 +313,7 @@ struct clip_vision_model {
// pixtral // pixtral
struct ggml_tensor * token_embd_img_break = nullptr; struct ggml_tensor * token_embd_img_break = nullptr;
struct ggml_tensor * mm_patch_merger_w = nullptr;
}; };
struct clip_ctx { struct clip_ctx {
@ -637,6 +640,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
const int d_head = hidden_size / n_head; const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer; const int n_layer = hparams.n_layer;
const float eps = hparams.eps; const float eps = hparams.eps;
const int n_merge = hparams.spatial_merge_size;
struct ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(), /*.mem_size =*/ ctx->buf_compute_meta.size(),
@ -721,7 +725,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
{ {
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur); ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur); ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu if (ctx->use_silu) {
gate_proj = ggml_silu(ctx0, gate_proj);
} else if (ctx->use_gelu) {
gate_proj = ggml_gelu(ctx0, gate_proj);
} else {
GGML_ABORT("Pixtral: Unsupported activation");
}
cur = ggml_mul(ctx0, up_proj, gate_proj); cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur); cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
} }
@ -732,14 +742,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
embeddings = cur; embeddings = cur;
} }
// LlavaMultiModalProjector (with GELU activation) // mistral small 3.1 patch merger
// ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
if (model.mm_patch_merger_w) {
GGML_ASSERT(hparams.spatial_merge_size > 0);
ggml_tensor * cur = embeddings;
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
// reshape image tokens to 2D grid
cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
cur = ggml_cont(ctx0, cur);
// torch.nn.functional.unfold is just an im2col under the hood
// we just need a dummy kernel to make it work
ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
// project to hidden_size
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
embeddings = cur;
}
// LlavaMultiModalProjector (always using GELU activation)
{ {
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); if (model.mm_1_b) {
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
}
embeddings = ggml_gelu(ctx0, embeddings); embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); if (model.mm_2_b) {
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
} }
// arrangement of the [IMG_BREAK] token // arrangement of the [IMG_BREAK] token
@ -749,11 +787,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows] // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
const int p_total = p_x * p_y;
const int n_embd_text = embeddings->ne[0]; const int n_embd_text = embeddings->ne[0];
const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y); ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y); ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
tok = ggml_add(ctx0, tok, model.token_embd_img_break); tok = ggml_add(ctx0, tok, model.token_embd_img_break);
cur = ggml_concat(ctx0, cur, tok, 1); cur = ggml_concat(ctx0, cur, tok, 1);
@ -1734,6 +1775,7 @@ struct clip_model_loader {
case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_PIXTRAL:
{ {
hparams.rope_theta = 10000.0f; hparams.rope_theta = 10000.0f;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break; } break;
case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN25VL:
{ {
@ -1957,11 +1999,14 @@ struct clip_model_loader {
case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_PIXTRAL:
{ {
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
// [IMG_BREAK] token embedding // [IMG_BREAK] token embedding
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK); vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
// for mistral small 3.1
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
} break; } break;
default: default:
GGML_ASSERT(false && "unknown projector type"); GGML_ASSERT(false && "unknown projector type");
@ -2926,8 +2971,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
n_patches /= ctx->vision_model.hparams.proj_scale_factor; n_patches /= ctx->vision_model.hparams.proj_scale_factor;
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
int n_patches_x = img->nx / params.patch_size; int n_merge = ctx->vision_model.hparams.spatial_merge_size;
int n_patches_y = img->ny / params.patch_size; int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
} }
@ -3484,7 +3530,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_model_peg_0_b->ne[0]; return ctx->vision_model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_PIXTRAL:
return ctx->vision_model.mm_2_b->ne[0]; return ctx->vision_model.mm_2_w->ne[1];
case PROJECTOR_TYPE_MLP_NORM: case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0]; return ctx->vision_model.mm_3_b->ne[0];
case PROJECTOR_TYPE_MINICPMV: case PROJECTOR_TYPE_MINICPMV:

View File

@ -94,6 +94,7 @@ struct mtmd_cli_context {
LOG_ERR("Model does not have chat template.\n"); LOG_ERR("Model does not have chat template.\n");
LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n"); LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n"); LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n");
LOG_ERR(" For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
exit(1); exit(1);
} }

View File

@ -59,6 +59,7 @@ add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
# to test the big models, run: ./tests.sh big # to test the big models, run: ./tests.sh big
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M" add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
# these models always give the wrong answer, not sure why # these models always give the wrong answer, not sure why
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M" # add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"

View File

@ -231,6 +231,7 @@ class Keys:
BLOCK_COUNT = "clip.vision.block_count" BLOCK_COUNT = "clip.vision.block_count"
IMAGE_MEAN = "clip.vision.image_mean" IMAGE_MEAN = "clip.vision.image_mean"
IMAGE_STD = "clip.vision.image_std" IMAGE_STD = "clip.vision.image_std"
SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
USE_GELU = "clip.use_gelu" USE_GELU = "clip.use_gelu"
USE_SILU = "clip.use_silu" USE_SILU = "clip.use_silu"
@ -491,6 +492,7 @@ class MODEL_TENSOR(IntEnum):
V_ENC_FFN_DOWN = auto() V_ENC_FFN_DOWN = auto()
V_PRE_NORM = auto() V_PRE_NORM = auto()
V_POST_NORM = auto() V_POST_NORM = auto()
V_MM_INP_NORM = auto()
V_MM_INP_PROJ = auto() # gemma3 V_MM_INP_PROJ = auto() # gemma3
V_MM_SOFT_EMB_NORM = auto() # gemma3 V_MM_SOFT_EMB_NORM = auto() # gemma3
V_RESMPL_POS_EMBD_K = auto() # minicpmv V_RESMPL_POS_EMBD_K = auto() # minicpmv
@ -505,6 +507,7 @@ class MODEL_TENSOR(IntEnum):
V_RESMPL_PROJ = auto() # minicpmv V_RESMPL_PROJ = auto() # minicpmv
V_RESMPL_QUERY = auto() # minicpmv V_RESMPL_QUERY = auto() # minicpmv
V_TOK_EMBD_IMG_BREAK = auto() # pixtral V_TOK_EMBD_IMG_BREAK = auto() # pixtral
V_MM_PATCH_MERGER = auto() # mistral small 3.1
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -747,6 +750,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
MODEL_TENSOR.V_POST_NORM: "v.post_ln", MODEL_TENSOR.V_POST_NORM: "v.post_ln",
MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm", MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k", MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q", MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
@ -760,6 +764,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj", MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query", MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
} }
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -783,6 +788,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_PRE_NORM, MODEL_TENSOR.V_PRE_NORM,
MODEL_TENSOR.V_POST_NORM, MODEL_TENSOR.V_POST_NORM,
MODEL_TENSOR.V_MM_INP_PROJ, MODEL_TENSOR.V_MM_INP_PROJ,
MODEL_TENSOR.V_MM_INP_NORM,
MODEL_TENSOR.V_MM_SOFT_EMB_NORM, MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
MODEL_TENSOR.V_RESMPL_POS_EMBD_K, MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
MODEL_TENSOR.V_RESMPL_ATTN_Q, MODEL_TENSOR.V_RESMPL_ATTN_Q,
@ -796,6 +802,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.V_RESMPL_PROJ, MODEL_TENSOR.V_RESMPL_PROJ,
MODEL_TENSOR.V_RESMPL_QUERY, MODEL_TENSOR.V_RESMPL_QUERY,
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK, MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
MODEL_TENSOR.V_MM_PATCH_MERGER,
], ],
MODEL_ARCH.LLAMA: [ MODEL_ARCH.LLAMA: [
MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD,

View File

@ -972,6 +972,9 @@ class GGUFWriter:
def add_vision_image_std(self, values: Sequence[float]) -> None: def add_vision_image_std(self, values: Sequence[float]) -> None:
self.add_array(Keys.ClipVision.IMAGE_STD, values) self.add_array(Keys.ClipVision.IMAGE_STD, values)
def add_vision_spatial_merge_size(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
def add_vision_use_gelu(self, value: bool) -> None: def add_vision_use_gelu(self, value: bool) -> None:
self.add_bool(Keys.ClipVision.USE_GELU, value) self.add_bool(Keys.ClipVision.USE_GELU, value)

View File

@ -1001,6 +1001,10 @@ class TensorNameMap:
"multi_modal_projector.mm_input_projection", "multi_modal_projector.mm_input_projection",
), ),
MODEL_TENSOR.V_MM_INP_NORM: (
"multi_modal_projector.norm",
),
MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
"multi_modal_projector.mm_soft_emb_norm", "multi_modal_projector.mm_soft_emb_norm",
), ),
@ -1052,6 +1056,10 @@ class TensorNameMap:
MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: ( MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
"v.token_embd.img_break", # for pixtral, this is a generated vector "v.token_embd.img_break", # for pixtral, this is a generated vector
), ),
MODEL_TENSOR.V_MM_PATCH_MERGER: (
"multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
),
} }
# architecture-specific block mappings # architecture-specific block mappings