Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-27 03:55:20 +00:00)
mtmd : add vision support for Mistral Small 3.1 (#13231)
* convert ok
* load ok, missing patch merger
* ah sheet it works
* update llava/readme
* add test
* fix test
@@ -1899,7 +1899,10 @@ class LlamaModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")


-@ModelBase.register("LlavaForConditionalGeneration")
+@ModelBase.register(
+    "LlavaForConditionalGeneration", # pixtral
+    "Mistral3ForConditionalGeneration", # mistral small 3.1
+)
 class LlavaVisionModel(VisionModel):
     img_break_tok_id = -1

@@ -1908,17 +1911,38 @@ class LlavaVisionModel(VisionModel):
         if self.hparams["model_type"] == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
-            self.img_break_tok_id = 12 # see tokenizer_config.json
+            self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+            logger.info(f"Image break token id: {self.img_break_tok_id}")
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")

+    def get_token_id(self, token: str) -> int:
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+            added_tokens_decoder = json.load(f)['added_tokens_decoder']
+            for id_, token_data in added_tokens_decoder.items():
+                if token_data["content"] == token:
+                    return int(id_)
+        raise ValueError(f"Token '{token}' not found in tokenizer config.")
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
-            self.gguf_writer.add_vision_use_silu(True)
+
+            # hidden_act
+            if hparams["hidden_act"] == "silu":
+                self.gguf_writer.add_vision_use_silu(True)
+            elif hparams["hidden_act"] == "gelu":
+                self.gguf_writer.add_vision_use_gelu(True)
+            else:
+                raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+
+            # spatial_merge_size
+            if "spatial_merge_size" in self.global_config:
+                self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"])

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
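The converter change above replaces the hard-coded [IMG_BREAK] id with a lookup in tokenizer_config.json. As a rough standalone illustration of the structure that get_token_id walks (the neighbouring ids and fields in this snippet are made up; only the [IMG_BREAK] id 12 comes from the old hard-coded value):

```python
import json

# Sketch of the assumed tokenizer_config.json layout that the lookup relies on.
example_config = json.loads("""
{
  "added_tokens_decoder": {
    "10": {"content": "[IMG]",       "special": true},
    "12": {"content": "[IMG_BREAK]", "special": true},
    "13": {"content": "[IMG_END]",   "special": true}
  }
}
""")

def find_token_id(config: dict, token: str) -> int:
    # keys of added_tokens_decoder are string token ids, values describe the token
    for id_, token_data in config["added_tokens_decoder"].items():
        if token_data["content"] == token:
            return int(id_)
    raise ValueError(f"Token {token!r} not found in tokenizer config.")

print(find_token_id(example_config, "[IMG_BREAK]"))  # -> 12 in this made-up example
```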
@@ -34,6 +34,9 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF

 # Pixtral 12B
 llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
+
+# Mistral Small 3.1 24B (IQ2_M quantization)
+llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
 ```

 ## How it works and what is `mmproj`?
@@ -73,3 +76,4 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla
 - SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
 - SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
 - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
@@ -31,6 +31,7 @@
 #define KEY_FEATURE_LAYER "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
 #define KEY_PROJ_TYPE "clip.projector_type"
+#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"

 #define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
 #define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
@@ -68,9 +69,11 @@
 #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
 #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE "model.image_newline"
+#define TN_MM_INP_NORM "mm.input_norm.weight"
 #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
 #define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
+#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
 #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral

 // mimicpmv
@@ -172,6 +172,7 @@ struct clip_hparams {
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
+    int32_t spatial_merge_size = 0;
 };

 struct clip_layer {
@@ -232,6 +233,7 @@ struct clip_vision_model {
     struct ggml_tensor * projection;

     // LLaVA projection
+    struct ggml_tensor * mm_input_norm_w = nullptr;
     struct ggml_tensor * mm_0_w = nullptr;
     struct ggml_tensor * mm_0_b = nullptr;
     struct ggml_tensor * mm_2_w = nullptr;
@@ -311,6 +313,7 @@ struct clip_vision_model {

     // pixtral
     struct ggml_tensor * token_embd_img_break = nullptr;
+    struct ggml_tensor * mm_patch_merger_w = nullptr;
 };

 struct clip_ctx {
@@ -637,6 +640,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     const int d_head = hidden_size / n_head;
     const int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
+    const int n_merge = hparams.spatial_merge_size;

     struct ggml_init_params params = {
         /*.mem_size =*/ ctx->buf_compute_meta.size(),
@@ -721,7 +725,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         {
             ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
             ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
-            gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
+            if (ctx->use_silu) {
+                gate_proj = ggml_silu(ctx0, gate_proj);
+            } else if (ctx->use_gelu) {
+                gate_proj = ggml_gelu(ctx0, gate_proj);
+            } else {
+                GGML_ABORT("Pixtral: Unsupported activation");
+            }
             cur = ggml_mul(ctx0, up_proj, gate_proj);
             cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
         }
@@ -732,14 +742,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         embeddings = cur;
     }

-    // LlavaMultiModalProjector (with GELU activation)
+    // mistral small 3.1 patch merger
+    // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+    if (model.mm_patch_merger_w) {
+        GGML_ASSERT(hparams.spatial_merge_size > 0);
+
+        ggml_tensor * cur = embeddings;
+        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
+
+        // reshape image tokens to 2D grid
+        cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
+        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
+        cur = ggml_cont(ctx0, cur);
+
+        // torch.nn.functional.unfold is just an im2col under the hood
+        // we just need a dummy kernel to make it work
+        ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
+        cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
+
+        // project to hidden_size
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+        cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
+        embeddings = cur;
+    }
+
+    // LlavaMultiModalProjector (always using GELU activation)
     {
         embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        if (model.mm_1_b) {
+            embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        }

         embeddings = ggml_gelu(ctx0, embeddings);
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        if (model.mm_2_b) {
+            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        }
     }

     // arrangement of the [IMG_BREAK] token
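The comments in the patch-merger block above point at the trick it relies on: torch.nn.functional.unfold in the reference Mistral3 patch merger is just an im2col, i.e. every non-overlapping n_merge x n_merge neighbourhood of patch embeddings is concatenated into one vector before the linear merging layer. A rough numpy sketch of that idea (all shapes and weights are invented for illustration; the element order inside each block may differ from torch's unfold, but the grouping is the same):

```python
import numpy as np

# Hypothetical sizes for illustration only.
n_patches_y, n_patches_x, hidden = 4, 6, 8
n_merge = 2  # spatial_merge_size
rng = np.random.default_rng(0)

grid = rng.standard_normal((n_patches_y, n_patches_x, hidden)).astype(np.float32)

# "unfold" / im2col over non-overlapping n_merge x n_merge blocks:
# each block of n_merge*n_merge patch embeddings becomes one concatenated vector.
blocks = (grid
          .reshape(n_patches_y // n_merge, n_merge, n_patches_x // n_merge, n_merge, hidden)
          .transpose(0, 2, 1, 3, 4)                  # [block_y, block_x, my, mx, hidden]
          .reshape(-1, n_merge * n_merge * hidden))  # [n_merged_tokens, n_merge^2 * hidden]

# linear merging layer projecting back to hidden (random weight, no bias,
# mirroring mm.patch_merger.weight in the diff)
merger_w = rng.standard_normal((hidden, n_merge * n_merge * hidden)).astype(np.float32)
merged = blocks @ merger_w.T                         # [n_merged_tokens, hidden]

assert merged.shape == ((n_patches_y // n_merge) * (n_patches_x // n_merge), hidden)
```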
@@ -749,11 +787,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
         // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]

+        const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
+        const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+        const int p_total = p_x * p_y;
         const int n_embd_text = embeddings->ne[0];
-        const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
+        const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row

-        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
-        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
+        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
+        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
         tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
         tok = ggml_add(ctx0, tok, model.token_embd_img_break);
         cur = ggml_concat(ctx0, cur, tok, 1);
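The p_x / p_y bookkeeping above exists because spatial merging shrinks the patch grid before the [IMG_BREAK] tokens are inserted, so break tokens are counted per merged row. A small numpy sketch of the arrangement and the resulting token count (sizes are arbitrary):

```python
import numpy as np

# Arbitrary example sizes: a merged grid of p_y rows by p_x patches.
n_embd_text, p_x, p_y = 8, 5, 3
rng = np.random.default_rng(1)

rows = rng.standard_normal((p_y, p_x, n_embd_text)).astype(np.float32)   # merged patch grid
img_break = rng.standard_normal((n_embd_text,)).astype(np.float32)       # [IMG_BREAK] embedding

# append one [IMG_BREAK] to every row, flatten row-major, then drop the trailing one
with_breaks = np.concatenate(
    [rows, np.broadcast_to(img_break, (p_y, 1, n_embd_text))], axis=1)
tokens = with_breaks.reshape(-1, n_embd_text)[:-1]

assert tokens.shape[0] == p_x * p_y + p_y - 1   # matches n_tokens_output above
```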
@@ -1734,6 +1775,7 @@ struct clip_model_loader {
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     hparams.rope_theta = 10000.0f;
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                 } break;
             case PROJECTOR_TYPE_QWEN25VL:
                 {
@@ -1957,11 +1999,14 @@ struct clip_model_loader {
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
-                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
                     vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
-                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
                     // [IMG_BREAK] token embedding
                     vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                    // for mistral small 3.1
+                    vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
+                    vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
@@ -2926,8 +2971,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
     } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         n_patches /= ctx->vision_model.hparams.proj_scale_factor;
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
-        int n_patches_x = img->nx / params.patch_size;
-        int n_patches_y = img->ny / params.patch_size;
+        int n_merge = ctx->vision_model.hparams.spatial_merge_size;
+        int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+        int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
         n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
     }

@@ -3484,7 +3530,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         return ctx->vision_model.mm_model_peg_0_b->ne[0];
     case PROJECTOR_TYPE_MLP:
     case PROJECTOR_TYPE_PIXTRAL:
-        return ctx->vision_model.mm_2_b->ne[0];
+        return ctx->vision_model.mm_2_w->ne[1];
     case PROJECTOR_TYPE_MLP_NORM:
         return ctx->vision_model.mm_3_b->ne[0];
     case PROJECTOR_TYPE_MINICPMV:
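The clip_n_mmproj_embd change above follows from the projector biases becoming optional: with no mm_2_b guaranteed to exist, the output embedding size is read from the second weight matrix instead (in ggml's layout, mm_2_w->ne[1] corresponds to the output dimension). A tiny sketch of the same fallback in the usual [out, in] weight convention (sizes are hypothetical):

```python
import numpy as np

# Hypothetical projector shapes; HF/torch stores a linear layer's weight as [n_out, n_in].
n_in, n_out = 1024, 5120
mm_2_w = np.zeros((n_out, n_in), dtype=np.float32)
mm_2_b = None  # the bias may be absent for Mistral Small 3.1

# old logic read the size from the bias; the new logic falls back to the weight
n_mmproj_embd = mm_2_b.shape[0] if mm_2_b is not None else mm_2_w.shape[0]
assert n_mmproj_embd == n_out
```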
@@ -94,6 +94,7 @@ struct mtmd_cli_context {
         LOG_ERR("Model does not have chat template.\n");
         LOG_ERR(" For old llava models, you may need to use '--chat-template vicuna'\n");
         LOG_ERR(" For MobileVLM models, use '--chat-template deepseek'\n");
+        LOG_ERR(" For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
         exit(1);
     }

@@ -59,6 +59,7 @@ add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"

 # to test the big models, run: ./tests.sh big
 add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"

 # these models always give the wrong answer, not sure why
 # add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
@@ -231,6 +231,7 @@ class Keys:
         BLOCK_COUNT = "clip.vision.block_count"
         IMAGE_MEAN = "clip.vision.image_mean"
         IMAGE_STD = "clip.vision.image_std"
+        SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size"
         USE_GELU = "clip.use_gelu"
         USE_SILU = "clip.use_silu"

@@ -491,6 +492,7 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_FFN_DOWN = auto()
     V_PRE_NORM = auto()
     V_POST_NORM = auto()
+    V_MM_INP_NORM = auto()
     V_MM_INP_PROJ = auto() # gemma3
     V_MM_SOFT_EMB_NORM = auto() # gemma3
     V_RESMPL_POS_EMBD_K = auto() # minicpmv
@@ -505,6 +507,7 @@ class MODEL_TENSOR(IntEnum):
     V_RESMPL_PROJ = auto() # minicpmv
     V_RESMPL_QUERY = auto() # minicpmv
     V_TOK_EMBD_IMG_BREAK = auto() # pixtral
+    V_MM_PATCH_MERGER = auto() # mistral small 3.1


 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -747,6 +750,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_PRE_NORM: "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM: "v.post_ln",
     MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection",
+    MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm",
     MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm",
     MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k",
     MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q",
@@ -760,6 +764,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj",
     MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query",
     MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral
+    MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1
 }

 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -783,6 +788,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_PRE_NORM,
         MODEL_TENSOR.V_POST_NORM,
         MODEL_TENSOR.V_MM_INP_PROJ,
+        MODEL_TENSOR.V_MM_INP_NORM,
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
         MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
         MODEL_TENSOR.V_RESMPL_ATTN_Q,
@@ -796,6 +802,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_RESMPL_PROJ,
         MODEL_TENSOR.V_RESMPL_QUERY,
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
+        MODEL_TENSOR.V_MM_PATCH_MERGER,
     ],
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -972,6 +972,9 @@ class GGUFWriter:
     def add_vision_image_std(self, values: Sequence[float]) -> None:
         self.add_array(Keys.ClipVision.IMAGE_STD, values)

+    def add_vision_spatial_merge_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
+
     def add_vision_use_gelu(self, value: bool) -> None:
         self.add_bool(Keys.ClipVision.USE_GELU, value)

@@ -1001,6 +1001,10 @@ class TensorNameMap:
             "multi_modal_projector.mm_input_projection",
         ),

+        MODEL_TENSOR.V_MM_INP_NORM: (
+            "multi_modal_projector.norm",
+        ),
+
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
             "multi_modal_projector.mm_soft_emb_norm",
         ),
@@ -1052,6 +1056,10 @@ class TensorNameMap:
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
             "v.token_embd.img_break", # for pixtral, this is a generated vector
         ),
+
+        MODEL_TENSOR.V_MM_PATCH_MERGER: (
+            "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
+        ),
     }

     # architecture-specific block mappings