From 92ecdcc06a4c405a415bcaa0cb772bc560aa23b1 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 19 May 2025 13:04:14 +0200 Subject: [PATCH] mtmd : add vision support for llama 4 (#13282) * wip llama 4 conversion * rm redundant __init__ * fix conversion * fix conversion * test impl * try this * reshape patch_embeddings_0 * fix view * rm ffn_post_norm * cgraph ok * f32 for pos embd * add image marker tokens * Llama4UnfoldConvolution * correct pixel shuffle * fix merge conflicts * correct * add debug_graph * logits matched, but it still preceives the image incorrectly * fix style * add image_grid_pinpoints * handle llama 4 preprocessing * rm load_image_size * rm unused line * fix * small fix 2 * add test & docs * fix llava-1.6 test * test: add notion of huge models * add comment * add warn about degraded quality --- convert_hf_to_gguf.py | 21 +++ docs/multimodal.md | 3 + gguf-py/gguf/constants.py | 20 +-- gguf-py/gguf/tensor_mapping.py | 19 ++- tools/mtmd/clip-impl.h | 74 +++++++++- tools/mtmd/clip.cpp | 246 ++++++++++++++++++++++++++++----- tools/mtmd/clip.h | 4 - tools/mtmd/mtmd.cpp | 97 +++++++++---- tools/mtmd/tests.sh | 22 ++- 9 files changed, 424 insertions(+), 82 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index b5402cbea..15e019a10 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -308,6 +308,7 @@ class ModelBase: gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED, gguf.MODEL_TENSOR.POSNET_NORM1, gguf.MODEL_TENSOR.POSNET_NORM2, + gguf.MODEL_TENSOR.V_ENC_EMBD_POS, ) ) or not new_name.endswith(".weight") @@ -2092,6 +2093,26 @@ class Llama4Model(LlamaModel): return super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Llama4ForConditionalGeneration") +class Llama4VisionModel(VisionModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"]) + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"])) + assert self.hparams["hidden_act"] == "gelu" + self.gguf_writer.add_vision_use_gelu(True) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if "multi_modal_projector" in name or "vision_model" in name: + # process vision tensors + if "positional_embedding_vlm" in name and ".weight" not in name: + name += ".weight" + return [(self.map_tensor_name(name), data_torch)] + return [] + + @ModelBase.register("Mistral3ForConditionalGeneration") class Mistral3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.LLAMA diff --git a/docs/multimodal.md b/docs/multimodal.md index 80014ba1c..054778e91 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -74,4 +74,7 @@ NOTE: some models may require large context window, for example: `-c 8192` (tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF (tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF (tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF + +# Llama 4 Scout +(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF ``` diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 21af0a9a2..7aed8b83e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -482,14 +482,15 @@ class MODEL_TENSOR(IntEnum): V_ENC_EMBD_CLS = auto() V_ENC_EMBD_PATCH = auto() V_ENC_EMBD_POS = auto() + V_ENC_INPUT_NORM = auto() V_ENC_ATTN_Q = auto() V_ENC_ATTN_Q_NORM = auto() V_ENC_ATTN_K = auto() V_ENC_ATTN_K_NORM = auto() V_ENC_ATTN_V = auto() - V_ENC_INPUT_NORM = auto() - V_ENC_OUTPUT = auto() - V_ENC_OUTPUT_NORM = auto() + V_ENC_ATTN_O = auto() + V_ENC_ATTN_O_NORM = auto() + V_ENC_POST_ATTN_NORM = auto() V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() V_ENC_FFN_DOWN = auto() @@ -749,8 +750,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm", MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v", MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", - MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out", - MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2", + MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", + MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", @@ -785,14 +787,15 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_ENC_EMBD_CLS, MODEL_TENSOR.V_ENC_EMBD_PATCH, MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_INPUT_NORM, MODEL_TENSOR.V_ENC_ATTN_Q, MODEL_TENSOR.V_ENC_ATTN_Q_NORM, MODEL_TENSOR.V_ENC_ATTN_K, MODEL_TENSOR.V_ENC_ATTN_K_NORM, MODEL_TENSOR.V_ENC_ATTN_V, - MODEL_TENSOR.V_ENC_INPUT_NORM, - MODEL_TENSOR.V_ENC_OUTPUT, - MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_ATTN_O, + MODEL_TENSOR.V_ENC_ATTN_O_NORM, + MODEL_TENSOR.V_ENC_POST_ATTN_NORM, MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, MODEL_TENSOR.V_ENC_FFN_DOWN, @@ -2180,6 +2183,7 @@ class VisionProjectorType: GEMMA3 = "gemma3" IDEFICS3 = "idefics3" PIXTRAL = "pixtral" + LLAMA4 = "llama4" QWEN2VL = "qwen2vl_merger" QWEN25VL = "qwen2.5vl_merger" INTERNVL = "internvl" diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index dadb4fd94..0298f8b46 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -902,10 +902,12 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ_FC: ( "model.connector.modality_projection.proj", # SmolVLM + "multi_modal_projector.linear_1", # llama 4 ), MODEL_TENSOR.V_MMPROJ_MLP: ( "model.mm_projector.mlp.mlp.{bid}", + "vision_model.vision_adapter.mlp.fc{bid}", # llama 4 "mlp1.{bid}", # InternVL ), @@ -915,6 +917,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_EMBD_CLS: ( "vision_tower.vision_model.embeddings.class_embedding", + "vision_model.class_embedding", # llama 4 ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -922,6 +925,7 @@ class TensorNameMap: "vpm.embeddings.patch_embedding", "model.vision_model.embeddings.patch_embedding", # SmolVLM "vision_tower.patch_conv", # pixtral + "vision_model.patch_embedding.linear", # llama 4 "visual.patch_embed.proj", # qwen2vl ), @@ -929,12 +933,14 @@ class TensorNameMap: "vision_tower.vision_model.embeddings.position_embedding", "vpm.embeddings.position_embedding", "model.vision_model.embeddings.position_embedding", # SmolVLM + "vision_model.positional_embedding_vlm", # llama 4 ), MODEL_TENSOR.V_ENC_ATTN_Q: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", "vpm.encoder.layers.{bid}.self_attn.q_proj", "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4 "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated ), @@ -947,6 +953,7 @@ class TensorNameMap: "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", "vpm.encoder.layers.{bid}.self_attn.k_proj", "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4 "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated ), @@ -959,6 +966,7 @@ class TensorNameMap: "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", "vpm.encoder.layers.{bid}.self_attn.v_proj", "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4 "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral "visual.blocks.{bid}.attn.v", # qwen2vl, generated ), @@ -969,23 +977,26 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.layer_norm1", "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral + "vision_model.model.layers.{bid}.input_layernorm", # llama4 "visual.blocks.{bid}.norm1", # qwen2vl ), - MODEL_TENSOR.V_ENC_OUTPUT: ( + MODEL_TENSOR.V_ENC_ATTN_O: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL "vpm.encoder.layers.{bid}.self_attn.out_proj", "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral "visual.blocks.{bid}.attn.proj", # qwen2vl ), - MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL "vpm.encoder.layers.{bid}.layer_norm2", "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM + "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4 "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral "visual.blocks.{bid}.norm2", # qwen2vl ), @@ -995,6 +1006,7 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.mlp.fc1", "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc1", # llama4 "visual.blocks.{bid}.mlp.fc1", # qwen2vl "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl ), @@ -1009,6 +1021,7 @@ class TensorNameMap: "vpm.encoder.layers.{bid}.mlp.fc2", "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc2", # llama4 "visual.blocks.{bid}.mlp.fc2", # qwen2vl "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl ), @@ -1024,11 +1037,13 @@ class TensorNameMap: MODEL_TENSOR.V_PRE_NORM: ( "vision_tower.vision_model.pre_layrnorm", "vision_tower.ln_pre", # pixtral + "vision_model.layernorm_pre", # llama4 ), MODEL_TENSOR.V_POST_NORM: ( "vision_tower.vision_model.post_layernorm", "model.vision_model.post_layernorm", # SmolVLM + "vision_model.layernorm_post", # llama4 "visual.merger.ln_q", # qwen2vl ), diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 23036ba72..7b7d2df39 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -44,7 +45,7 @@ // tensor name constants // -#define TN_POS_EMBD "%s.position_embd.weight" +#define TN_POS_EMBD "v.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" #define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat #define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" @@ -110,6 +111,7 @@ enum projector_type { PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, PROJECTOR_TYPE_INTERNVL, + PROJECTOR_TYPE_LLAMA4, PROJECTOR_TYPE_UNKNOWN, }; @@ -125,6 +127,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, { PROJECTOR_TYPE_INTERNVL, "internvl"}, + { PROJECTOR_TYPE_LLAMA4, "llama4"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { @@ -240,6 +243,11 @@ struct clip_image_u8_batch { struct clip_image_f32_batch { std::vector entries; + // for llava-uhd style models, we need to know the grid size + // note: entries.size() == grid_x * grid_y + 1 (one overview image) + int grid_x = 0; + int grid_y = 0; + clip_image_f32_batch clone() const { clip_image_f32_batch new_batch; new_batch.entries.reserve(entries.size()); @@ -358,6 +366,70 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { } } +// +// debugging +// + +static void print_tensor_shape(ggml_tensor * t) { + printf("%s.shape = [", t->name); + for (int i = 0; i < ggml_n_dims(t); ++i) { + printf("%" PRId64, t->ne[i]); + if (i < ggml_n_dims(t) - 1) { + printf(", "); + } + } + printf("]\n"); +} + +static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) { + ggml_type type = t->type; + int64_t * ne = t->ne; + size_t * nb = t->nb; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + printf("%s.data: [\n", t->name); + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + printf(" ..., \n"); + i2 = ne[2] - n; + } + printf(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + printf(" ..., \n"); + i1 = ne[1] - n; + } + printf(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + printf("..., "); + i0 = ne[0] - n; + } + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + float v; + if (type == GGML_TYPE_F16) { + v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]); + } else if (type == GGML_TYPE_F32) { + v = *(float *) &data[i]; + } else if (type == GGML_TYPE_I32) { + v = (float) *(int32_t *) &data[i]; + } else if (type == GGML_TYPE_I16) { + v = (float) *(int16_t *) &data[i]; + } else if (type == GGML_TYPE_I8) { + v = (float) *(int8_t *) &data[i]; + } else { + GGML_ABORT("fatal error"); + } + printf("%8.4f", v); + if (i0 < ne[0] - 1) printf(", "); + } + printf("],\n"); + } + printf(" ],\n"); + } + printf(" ]\n"); + } +} + // // API used internally with mtmd // diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 128a95cc1..eba07f6c8 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -359,9 +359,12 @@ struct clip_ctx { int max_nodes = 8192; ggml_backend_sched_ptr sched; - clip_image_size load_image_size; + // for debugging + bool debug_graph = false; + std::vector debug_print_tensors; clip_ctx(clip_context_params & ctx_params) { + debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr; backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); if (!backend_cpu) { throw std::runtime_error("failed to initialize CPU backend"); @@ -440,7 +443,7 @@ struct clip_graph { }; ctx0_ptr.reset(ggml_init(params)); ctx0 = ctx0_ptr.get(); - gf = ggml_new_graph(ctx0); + gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false); } ggml_cgraph * build_siglip() { @@ -522,7 +525,7 @@ struct clip_graph { ggml_set_input(pos_w); auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { - return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta); + return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true); }; ggml_tensor * inp = build_inp(); @@ -936,6 +939,101 @@ struct clip_graph { return gf; } + ggml_cgraph * build_llama4() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_pos = n_patches + 1; // +1 for [CLS] + + // 2D input positions + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * inp = build_inp_raw(); + + // Llama4UnfoldConvolution + { + ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0, + patch_size, patch_size, 3, n_embd); + inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type); + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + cb(inp, "patch_conv", -1); + } + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + // build ViT with 2D position embeddings + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + // first half is X axis and second half is Y axis + // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312 + // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441 + return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + }; + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + model.position_embeddings, + add_pos); + + // remove CLS token + cur = ggml_view_2d(ctx0, cur, + n_embd, n_patches, + ggml_row_size(cur->type, n_embd), 0); + + // pixel shuffle + // based on Llama4VisionPixelShuffleMLP + // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151 + { + const int scale_factor = model.hparams.proj_scale_factor; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + GGML_ASSERT(scale_factor > 0); + GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images + cur = ggml_reshape_4d(ctx0, cur, + n_embd * scale_factor, + n_patches_x / scale_factor, + n_patches_y, + bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + n_patches_x / scale_factor, + n_patches_y / scale_factor, + bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // flatten to 2D + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + n_patches / scale_factor / scale_factor); + cb(cur, "pixel_shuffle", -1); + } + + // based on Llama4VisionMLP2 (always uses GELU activation, no bias) + { + cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur); + cur = ggml_gelu(ctx0, cur); + cb(cur, "adapter_mlp", -1); + } + + // Llama4MultiModalProjector + cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); + cb(cur, "projected", -1); + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; + } + // this graph is used by llava, granite and glm // due to having embedding_stack (used by granite), we cannot reuse build_vit ggml_cgraph * build_llava() { @@ -1315,11 +1413,15 @@ private: // utility functions // - void cb(ggml_tensor * cur, const char * name, int il) const { - // TODO: implement this - GGML_UNUSED(cur); - GGML_UNUSED(name); - GGML_UNUSED(il); + void cb(ggml_tensor * cur0, const char * name, int il) const { + if (ctx->debug_graph) { + ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0)); + std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name; + ggml_set_name(cur, cur_name.c_str()); + ggml_set_output(cur); + ggml_build_forward_expand(gf, cur); + ctx->debug_print_tensors.push_back(cur); + } } // build vision transformer (ViT) cgraph @@ -1630,9 +1732,10 @@ private: static ggml_tensor * build_rope_2d( ggml_context * ctx0, ggml_tensor * cur, - ggml_tensor * pos_h, - ggml_tensor * pos_w, - const float freq_base + ggml_tensor * pos_a, // first half + ggml_tensor * pos_b, // second half + const float freq_base, + const bool interleave_freq ) { const int64_t n_dim = cur->ne[0]; const int64_t n_head = cur->ne[1]; @@ -1646,7 +1749,9 @@ private: // ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2) // then for the second half, we use freq_scale to shift the inv_freq // ^ why? replace (2i) with (2i+1) in the above equation - const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim); + const float freq_scale_odd = interleave_freq + ? std::pow(freq_base, (float)-2/n_dim) + : 1.0; // first half ggml_tensor * first; @@ -1659,7 +1764,7 @@ private: first = ggml_rope_ext( ctx0, first, - pos_h, // positions + pos_a, // positions nullptr, // freq factors n_dim/2, // n_dims 0, 0, freq_base, @@ -1679,7 +1784,7 @@ private: second = ggml_rope_ext( ctx0, second, - pos_w, // positions + pos_b, // positions nullptr, // freq factors n_dim/2, // n_dims 0, 0, freq_base, @@ -1723,6 +1828,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_internvl(); } break; + case PROJECTOR_TYPE_LLAMA4: + { + res = graph.build_llama4(); + } break; default: { res = graph.build_llava(); @@ -1926,6 +2035,21 @@ struct clip_model_loader { hparams.warmup_image_size = hparams.patch_size * 8; get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern); } break; + case PROJECTOR_TYPE_LLAMA4: + { + hparams.rope_theta = 10000.0f; + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor); + + // borrowed from llava-1.6 + const int isize = hparams.image_size; + hparams.image_grid_pinpoints = { + isize, isize*2, // 336, 672 + isize*2, isize, // 672, 336 + isize*2, isize*2, // 672, 672 + isize*3, isize, // 1008, 336 + isize, isize*3, // 336, 1008 + }; + } break; default: break; } @@ -1946,6 +2070,10 @@ struct clip_model_loader { LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str()); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); + + if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) { + LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__); + } } } @@ -2001,7 +2129,7 @@ struct clip_model_loader { vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); - vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false); + vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false); // layers vision_model.layers.resize(hparams.n_layer); @@ -2182,6 +2310,12 @@ struct clip_model_loader { vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); } break; + case PROJECTOR_TYPE_LLAMA4: + { + vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR); + vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); + } break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -2328,14 +2462,6 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p return ctx_clip; } -void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) { - ctx_clip->load_image_size = *load_image_size; // copy -} - -struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) { - return &ctx_clip->load_image_size; -} - struct clip_image_size * clip_image_size_init() { struct clip_image_size * load_image_size = new struct clip_image_size(); load_image_size->width = 448; @@ -2849,7 +2975,7 @@ private: // used by llava 1.6 with custom list of pinpoints static clip_image_size select_best_resolution(const std::vector & pinpoints, const clip_image_size & original_size) { - std::vector possible_resolutions; + std::vector possible_resolutions; // TODO @ngxson : construct this inside hparams, not here for (size_t i = 0; i < pinpoints.size(); i += 2) { possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]}); } @@ -2916,12 +3042,6 @@ private: } }; -// TODO @ngxson : decprecate the load_image_size singleton pattern -int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { - const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size); - return inst.grid_size.width; -} - // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { @@ -2943,9 +3063,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); res_imgs->entries.push_back(std::move(res)); } + + res_imgs->grid_x = inst.grid_size.width; + res_imgs->grid_y = inst.grid_size.height; return true; - } - else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { + + } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) { clip_image_u8 resized; auto patch_size = params.patch_size * 2; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size); @@ -2971,8 +3094,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); res_imgs->entries.push_back(std::move(img_f32)); return true; - } - else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { + + } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { clip_image_u8 resized_image; auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size); image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height); @@ -2980,6 +3103,22 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); res_imgs->entries.push_back(std::move(img_f32)); return true; + + } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { + GGML_ASSERT(!params.image_grid_pinpoints.empty()); + auto const inst = llava_uhd::get_slice_instructions(ctx, original_size); + std::vector imgs = llava_uhd::slice_image(img, inst); + + for (size_t i = 0; i < imgs.size(); ++i) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std); + res_imgs->entries.push_back(std::move(res)); + } + + res_imgs->grid_x = inst.grid_size.width; + res_imgs->grid_y = inst.grid_size.height; + return true; + } // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) @@ -3098,6 +3237,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im const auto & params = ctx->vision_model.hparams; int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); + int scale_factor = ctx->vision_model.hparams.proj_scale_factor; if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 @@ -3136,6 +3276,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1); int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1); n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row + } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) { + n_patches /= (scale_factor * scale_factor); } return n_patches; @@ -3247,6 +3389,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // build the inference graph + ctx->debug_print_tensors.clear(); ggml_backend_sched_reset(ctx->sched.get()); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); @@ -3261,8 +3404,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int n_pos = num_patches + (model.class_embedding ? 1 : 0); - const int pos_w = ctx->load_image_size.width / patch_size; - const int pos_h = ctx->load_image_size.height / patch_size; + const int pos_w = image_size_width / patch_size; + const int pos_h = image_size_height / patch_size; const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl @@ -3528,6 +3671,23 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima { // do nothing } break; + case PROJECTOR_TYPE_LLAMA4: + { + // set the 2D positions + int n_patches_per_col = image_size_width / patch_size; + std::vector pos_data(num_patches + 1, 0); // +1 for the [CLS] token + // last pos is always kept 0, it's for CLS + // dimension H + for (int i = 0; i < num_patches; i++) { + pos_data[i] = (i / n_patches_per_col) + 1; + } + set_input_i32("pos_h", pos_data); + // dimension W + for (int i = 0; i < num_patches; i++) { + pos_data[i] = (i % n_patches_per_col) + 1; + } + set_input_i32("pos_w", pos_data); + } break; default: GGML_ABORT("Unknown projector type"); } @@ -3548,6 +3708,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return false; } + // print debug nodes + if (ctx->debug_graph) { + LOG_INF("\n\n---\n\n"); + LOG_INF("\n\nDebug graph:\n\n"); + for (ggml_tensor * t : ctx->debug_print_tensors) { + std::vector data(ggml_nbytes(t)); + ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t)); + print_tensor_shape(t); + print_tensor_data(t, data.data(), 3); + } + } + // the last node is the embedding tensor ggml_tensor * embeddings = ggml_graph_node(gf, -1); @@ -3596,6 +3768,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->vision_model.projection->ne[1]; case PROJECTOR_TYPE_INTERNVL: return ctx->vision_model.mm_3_w->ne[1]; + case PROJECTOR_TYPE_LLAMA4: + return ctx->vision_model.mm_model_proj->ne[1]; default: GGML_ABORT("Unknown projector type"); } diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index 2d70eec94..e7a1c0782 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -47,10 +47,6 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * // this should be equal to the embedding dimension of the text model int clip_n_mmproj_embd(const struct clip_ctx * ctx); -int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); -void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); -struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); - struct clip_image_size * clip_image_size_init(void); struct clip_image_u8 * clip_image_u8_init (void); struct clip_image_f32 * clip_image_f32_init(void); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 2a852d9c1..1234dbb46 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -42,6 +42,7 @@ enum mtmd_slice_tmpl { MTMD_SLICE_TMPL_NONE, MTMD_SLICE_TMPL_MINICPMV_2_5, MTMD_SLICE_TMPL_MINICPMV_2_6, + MTMD_SLICE_TMPL_LLAMA4, // TODO @ngxson : add support for idefics (SmolVLM) }; @@ -64,15 +65,19 @@ struct mtmd_context { int n_threads; std::string image_marker; - // for minicpmv, we need special tokens in-between slices + // for llava-uhd style models, we need special tokens in-between slices + // minicpmv calls them "slices", llama 4 calls them "tiles" mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE; llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image llama_token tok_slices_start = LLAMA_TOKEN_NULL; // start of all slices llama_token tok_slices_end = LLAMA_TOKEN_NULL; // end of all slices - llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice - llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice + llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start + llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice end + llama_token tok_sli_img_mid = LLAMA_TOKEN_NULL; // between 2 slices llama_token tok_row_end = LLAMA_TOKEN_NULL; // end of row + bool tok_row_end_trail = false; + bool ov_img_first = false; bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE @@ -96,6 +101,7 @@ struct mtmd_context { use_mrope = clip_is_qwen2vl(ctx_clip); + projector_type proj = clip_get_projector_type(ctx_clip); int minicpmv_version = clip_is_minicpmv(ctx_clip); if (minicpmv_version == 2) { // minicpmv 2.5 format: @@ -108,6 +114,8 @@ struct mtmd_context { tok_sli_img_start = tok_ov_img_start; tok_sli_img_end = tok_ov_img_end; tok_row_end = lookup_token("\n"); + tok_row_end_trail = false; // no trailing end-of-row token + ov_img_first = true; } else if (minicpmv_version == 3 || minicpmv_version == 4) { // minicpmv 2.6 format: @@ -118,9 +126,25 @@ struct mtmd_context { tok_sli_img_start = lookup_token(""); tok_sli_img_end = lookup_token(""); tok_row_end = lookup_token("\n"); + tok_row_end_trail = false; // no trailing end-of-row token + ov_img_first = true; } else if (minicpmv_version != 0) { GGML_ASSERT(false && "unsupported minicpmv version"); + } else if (proj == PROJECTOR_TYPE_LLAMA4) { + // llama 4 format: + // <|image_start|> + // (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|> + // (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|> + // ... <|tile_y_separator|> <-- trailing end-of-row token + // <|image|> (overview) <-- overview image is last + // <|image_end|> + slice_tmpl = MTMD_SLICE_TMPL_LLAMA4; + tok_ov_img_start = lookup_token("<|image|>"); + tok_sli_img_mid = lookup_token("<|tile_x_separator|>"); + tok_row_end = lookup_token("<|tile_y_separator|>"); + tok_row_end_trail = true; // add trailing end-of-row token + ov_img_first = false; // overview image is last } } @@ -243,16 +267,18 @@ int32_t mtmd_tokenize(mtmd_context * ctx, // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md marker_modified = ctx->image_marker + "[IMG_END]"; string_replace_all(prompt_modified, ctx->image_marker, marker_modified); - } - else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) { + } else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) { // <|vision_start|> ... (image embeddings) ... <|vision_end|> marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>"; string_replace_all(prompt_modified, ctx->image_marker, marker_modified); - } + } else if (proj_type == PROJECTOR_TYPE_LLAMA4) { + // (more details in mtmd_context constructor) + marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>"; + string_replace_all(prompt_modified, ctx->image_marker, marker_modified); - else if (proj_type == PROJECTOR_TYPE_INTERNVL) { + } else if (proj_type == PROJECTOR_TYPE_INTERNVL) { // ... (image embeddings) ... marker_modified = "" + ctx->image_marker + ""; string_replace_all(prompt_modified, ctx->image_marker, marker_modified); @@ -328,7 +354,6 @@ int32_t mtmd_tokenize(mtmd_context * ctx, img_u8->ny = bitmaps[i_img]->ny; img_u8->buf.resize(bitmaps[i_img]->data.size()); std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3); - clip_image_size img_u8_size{img_u8->nx, img_u8->ny}; // preprocess image clip_image_f32_batch batch_f32; @@ -338,28 +363,40 @@ int32_t mtmd_tokenize(mtmd_context * ctx, return 2; } - if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) { + // handle llava-uhd style preprocessing + if ( + ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 + || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 + || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 + ) { // split batch into chunks of single images auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id); GGML_ASSERT(chunks.size() > 0); - // add overview image - add_text_chunk({ctx->tok_ov_img_start}); - output->entries.emplace_back(std::move(chunks.front())); + auto ov_chunk = std::move(chunks.front()); chunks.erase(chunks.begin()); - add_text_chunk({ctx->tok_ov_img_end}); - // add slices + // add overview image (first) + if (ctx->ov_img_first) { + if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_ov_img_start}); + } + output->entries.emplace_back(std::move(ov_chunk)); + if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_ov_img_end}); + } + } + + // add slices (or tiles) if (!chunks.empty()) { - clip_add_load_image_size(ctx->ctx_clip, &img_u8_size); - int n_col = clip_uhd_num_image_embeds_col(ctx->ctx_clip); - int n_row = (int)chunks.size() / n_col; - GGML_ASSERT(n_row * n_col == (int)chunks.size()); + const int n_col = batch_f32.grid_x; + const int n_row = batch_f32.grid_y; if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_slices_start}); } for (int y = 0; y < n_row; y++) { for (int x = 0; x < n_col; x++) { + const bool is_last_in_row = (x == n_col - 1); if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_sli_img_start}); } @@ -367,8 +404,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx, if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_sli_img_end}); } + if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_sli_img_mid}); + } } - if (ctx->tok_row_end != LLAMA_TOKEN_NULL && y != n_row - 1) { + if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) { add_text_chunk({ctx->tok_row_end}); } } @@ -377,6 +417,17 @@ int32_t mtmd_tokenize(mtmd_context * ctx, } } + // add overview image (last) + if (!ctx->ov_img_first) { + if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_ov_img_start}); + } + output->entries.emplace_back(std::move(ov_chunk)); + if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { + add_text_chunk({ctx->tok_ov_img_end}); + } + } + } else { size_t n_tokens = 0; for (const auto & entry : batch_f32.entries) { @@ -427,14 +478,6 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; - // only effective for minicpmv and qwen2vl, other models will ignore load_image_size - { - clip_image_size slice_size{ - image_tokens->batch_f32.entries[0]->nx, - image_tokens->batch_f32.entries[0]->ny}; - clip_add_load_image_size(ctx->ctx_clip, &slice_size); - } - if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 05ac7a04d..15a37b0d2 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -21,6 +21,13 @@ if [ "${1:-}" = "big" ]; then echo "Include BIG models..." fi +RUN_HUGE_TESTS=false +if [ "${1:-}" = "huge" ]; then + RUN_HUGE_TESTS=true + RUN_BIG_TESTS=true + echo "Include BIG models..." +fi + ############### arr_bin=() @@ -42,7 +49,7 @@ add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0" add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M" add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M" add_test "llama-mtmd-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna" -add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K" "vicuna" +add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M" "vicuna" add_test "llama-mtmd-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M" add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K" @@ -60,10 +67,17 @@ if [ "$RUN_BIG_TESTS" = true ]; then add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M" add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M" - add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M" + add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M" + add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M" # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra - # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big +fi + +# to test the huge models, run: ./tests.sh huge +# this will run both the big and huge models +# huge models are > 32B parameters +if [ "$RUN_HUGE_TESTS" = true ]; then + add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" + add_test "llama-mtmd-cli" "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S" fi # these models always give the wrong answer, not sure why