mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 03:55:20 +00:00
mtmd : Support Pixtral 12B (#13065)
* add pixtral text model (vision is wip) * cgraph ok, just missing 2D RoPE * fix bad rebase * first working version * fix problem with img_break token * support dynamic image size * update docs * update test script
This commit is contained in:
@ -14,6 +14,28 @@ The naming and structure related to multimodal support have evolved, which might
|
||||
- [#12849](https://github.com/ggml-org/llama.cpp/pull/12849): `libmtmd` was introduced as a replacement for `llava.cpp`. Its goals include providing a single, unified command-line interface, improving the user/developer experience (UX/DX), and supporting both audio and image inputs.
|
||||
- [#13012](https://github.com/ggml-org/llama.cpp/pull/13012): `mtmd-cli` was added, consolidating the various model-specific CLIs into a single tool powered by `libmtmd`.
|
||||
|
||||
## Pre-quantized models
|
||||
|
||||
These are ready-to-use models; most of them come with `Q4_K_M` quantization by default:
|
||||
|
||||
```sh
|
||||
# Gemma 3
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
|
||||
|
||||
# SmolVLM
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-256M-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM-500M-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF
|
||||
llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
|
||||
|
||||
# Pixtral 12B
|
||||
llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
|
||||
```
|
||||
|
||||
## How it works and what is `mmproj`?
|
||||
|
||||
Multimodal support in `llama.cpp` works by encoding images into embeddings using a separate model component, and then feeding these embeddings into the language model.
|
||||
@ -45,3 +67,9 @@ Multimodal projector (`mmproj`) files are specific to each model architecture. P
|
||||
- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
|
||||
- [IBM Granite Vision](../../docs/multimodal/granitevision.md)
|
||||
- [Google Gemma 3](../../docs/multimodal/gemma3.md)
|
||||
|
||||
For the following models, you can use `convert_hf_to_gguf.py` with the `--mmproj` flag to get the `mmproj` file:
|
||||
- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
|
||||
- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
|
||||
- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
|
||||
- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
|
||||
|
@ -60,6 +60,7 @@
|
||||
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
|
||||
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
|
||||
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
|
||||
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
|
||||
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
|
||||
#define TN_LN_1 "%s.blk.%d.ln1.%s"
|
||||
#define TN_LN_2 "%s.blk.%d.ln2.%s"
|
||||
@ -73,6 +74,7 @@
|
||||
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
|
||||
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
|
||||
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
|
||||
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
|
||||
|
||||
// minicpmv
|
||||
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
|
||||
@ -101,6 +103,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_GEMMA3,
|
||||
PROJECTOR_TYPE_IDEFICS3,
|
||||
PROJECTOR_TYPE_PIXTRAL,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -113,6 +116,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
|
||||
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
|
||||
};
|
||||
|
||||
static projector_type clip_projector_type_from_string(const std::string & str) {
|
||||
|
@ -163,7 +163,8 @@ struct clip_hparams {
|
||||
|
||||
patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
|
||||
|
||||
float eps;
|
||||
float eps = 1e-6;
|
||||
float rope_theta = 0.0;
|
||||
|
||||
std::vector<int32_t> image_grid_pinpoints;
|
||||
int32_t image_crop_resolution;
|
||||
@ -187,11 +188,17 @@ struct clip_layer {
|
||||
struct ggml_tensor * ln_1_b = nullptr;
|
||||
|
||||
// ff
|
||||
struct ggml_tensor * ff_i_w = nullptr;
|
||||
struct ggml_tensor * ff_i_b = nullptr;
|
||||
struct ggml_tensor * ff_i_w = nullptr; // legacy naming
|
||||
struct ggml_tensor * ff_i_b = nullptr; // legacy naming
|
||||
struct ggml_tensor * ff_o_w = nullptr; // legacy naming
|
||||
struct ggml_tensor * ff_o_b = nullptr; // legacy naming
|
||||
|
||||
struct ggml_tensor * ff_o_w = nullptr;
|
||||
struct ggml_tensor * ff_o_b = nullptr;
|
||||
struct ggml_tensor * ff_up_w = nullptr;
|
||||
struct ggml_tensor * ff_up_b = nullptr;
|
||||
struct ggml_tensor * ff_gate_w = nullptr;
|
||||
struct ggml_tensor * ff_gate_b = nullptr;
|
||||
struct ggml_tensor * ff_down_w = nullptr;
|
||||
struct ggml_tensor * ff_down_b = nullptr;
|
||||
|
||||
// layernorm 2
|
||||
struct ggml_tensor * ln_2_w = nullptr;
|
||||
@ -297,6 +304,9 @@ struct clip_vision_model {
|
||||
// gemma3
|
||||
struct ggml_tensor * mm_input_proj_w = nullptr;
|
||||
struct ggml_tensor * mm_soft_emb_norm_w = nullptr;
|
||||
|
||||
// pixtral
|
||||
struct ggml_tensor * token_embd_img_break = nullptr;
|
||||
};
|
||||
|
||||
struct clip_ctx {
|
||||
@ -329,6 +339,7 @@ struct clip_ctx {
|
||||
ggml_backend_t backend_cpu;
|
||||
ggml_backend_buffer_ptr buf;
|
||||
|
||||
int max_nodes = 8192;
|
||||
ggml_backend_sched_ptr sched;
|
||||
|
||||
clip_image_size load_image_size;
|
||||
@ -544,6 +555,218 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
|
||||
return gf;
|
||||
}
|
||||
|
||||
// implementation of the 2D RoPE without adding a new op in ggml
|
||||
static ggml_tensor * build_rope_2d(
|
||||
ggml_cgraph * gf,
|
||||
ggml_context * ctx0,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * pos_h,
|
||||
ggml_tensor * pos_w,
|
||||
const float freq_base
|
||||
) {
|
||||
ggml_tensor * tmp;
|
||||
const int64_t n_dim = cur->ne[0];
|
||||
const int64_t n_head = cur->ne[1];
|
||||
const int64_t n_pos = cur->ne[2];
|
||||
|
||||
// for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
|
||||
// we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
|
||||
// first half of cur will use 1e-0, 1e-2 (even)
|
||||
// second half of cur will use 1e-1, 1e-3 (odd)
|
||||
//
|
||||
// for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
|
||||
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
|
||||
// then for the second half, we use freq_scale to shift the inv_freq
|
||||
// ^ why? replace (2i) with (2i+1) in the above equation
|
||||
const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
|
||||
|
||||
// first half
|
||||
{
|
||||
cur = ggml_rope_ext_inplace(
|
||||
ctx0,
|
||||
cur,
|
||||
pos_h, // positions
|
||||
nullptr, // freq factors
|
||||
n_dim/2, // n_dims
|
||||
0, 0, freq_base,
|
||||
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
|
||||
);
|
||||
}
|
||||
|
||||
// second half
|
||||
{
|
||||
tmp = ggml_view_3d(ctx0, cur,
|
||||
n_dim/2, n_head, n_pos,
|
||||
ggml_row_size(cur->type, n_dim),
|
||||
ggml_row_size(cur->type, n_dim*n_head),
|
||||
n_dim/2 * ggml_element_size(cur));
|
||||
tmp = ggml_rope_ext_inplace(
|
||||
ctx0,
|
||||
tmp,
|
||||
pos_w, // positions
|
||||
nullptr, // freq factors
|
||||
n_dim/2, // n_dims
|
||||
0, 0, freq_base,
|
||||
freq_scale_odd,
|
||||
0.0f, 1.0f, 0.0f, 0.0f
|
||||
);
|
||||
// calculate inplace (modify cur directly)
|
||||
ggml_build_forward_expand(gf, tmp);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
// Build the compute graph for the Pixtral vision tower and its projector.
//
// Exactly one image is supported per call (asserted). The image size is taken
// from the batch entry, so the graph adapts to dynamically sized inputs.
//
// Graph inputs the caller has to fill before computing:
//   "inp_raw" : F32, (width, height, 3)
//   "pos_h"   : I32, one entry per patch, position fed to the first RoPE half
//   "pos_w"   : I32, one entry per patch, position fed to the second RoPE half
//
// Graph output (last node): (n_embd_text, num_patches + n_patches_y - 1),
// i.e. the projected patch embeddings with one [IMG_BREAK] embedding appended
// after every patch row except the last one.
static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
    const auto & model   = ctx->vision_model;
    const auto & hparams = model.hparams;

    GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
    GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1

    int image_size_width  = imgs.entries[0]->nx;
    int image_size_height = imgs.entries[0]->ny;

    const int patch_size  = hparams.patch_size;
    const int n_patches_x = image_size_width  / patch_size;
    const int n_patches_y = image_size_height / patch_size;
    const int num_patches = n_patches_x * n_patches_y;
    const int hidden_size = hparams.hidden_size;
    const int n_head      = hparams.n_head;
    const int d_head      = hidden_size / n_head;
    const int n_layer     = hparams.n_layer;
    const float eps       = hparams.eps;

    // only tensor metadata lives in this context (no_alloc = true)
    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
        /*.no_alloc   =*/ true,
    };

    ggml_context_ptr ctx0_ptr(ggml_init(params));
    auto ctx0 = ctx0_ptr.get();

    struct ggml_cgraph * gf = ggml_new_graph(ctx0);

    // input raw
    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
    ggml_set_name(inp_raw, "inp_raw");
    ggml_set_input(inp_raw);

    // 2D input positions
    struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
    ggml_set_name(pos_h, "pos_h");
    ggml_set_input(pos_h);
    struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
    ggml_set_name(pos_w, "pos_w");
    ggml_set_input(pos_w);

    // patch embedding: non-overlapping conv (stride == kernel == patch_size),
    // then flatten the patch grid and put the feature dimension first
    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));

    struct ggml_tensor * embeddings = inp;

    // pre-layer norm
    embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.pre_ln_w);

    // loop over layers
    for (int il = 0; il < n_layer; il++) {
        struct ggml_tensor * cur = embeddings;

        // pre-attention norm
        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_1_w);

        // self-attention
        {
            struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);

            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
            Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));

            struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);

            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
            K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));

            struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);

            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));

            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
            // no mask, scale by 1/sqrt(d_head)
            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);

            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);

            cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
        }

        // re-add the layer input, e.g., residual
        cur = ggml_add(ctx0, cur, embeddings);

        embeddings = cur; // embeddings = residual, cur = hidden_states

        // pre-ffn norm
        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_2_w);

        // feed-forward (gated: down(up(x) * silu(gate(x))))
        {
            ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
            ggml_tensor * up_proj   = ggml_mul_mat(ctx0, model.layers[il].ff_up_w,   cur);
            gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
            cur = ggml_mul(ctx0, up_proj, gate_proj);
            cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
        }

        // residual 2
        cur = ggml_add(ctx0, embeddings, cur);

        embeddings = cur;
    }

    // LlavaMultiModalProjector (with GELU activation)
    {
        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);

        embeddings = ggml_gelu(ctx0, embeddings);
        embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
    }

    // arrangement of the [IMG_BREAK] token
    {
        // not efficient, but works
        // the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
        // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
        // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]

        const int n_embd_text     = embeddings->ne[0];
        const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row

        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);

        // Broadcast the [IMG_BREAK] embedding to one copy per patch row.
        // `tok_shape` only supplies the target shape; its (never initialised)
        // contents are not read. The previous approach multiplied an
        // uninitialised tensor by 0.0 to "clear" it, which propagates NaN/Inf
        // if the scratch memory happens to contain any.
        // NOTE(review): assumes token_embd_img_break has the same element type
        // as `embeddings`, which the old ggml_add + ggml_concat path also
        // relied on; confirm for non-F32 mmproj files.
        ggml_tensor * tok_shape = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
        ggml_tensor * tok       = ggml_repeat(ctx0, model.token_embd_img_break, tok_shape);

        cur = ggml_concat(ctx0, cur, tok, 1);

        // flatten back to 2D and drop the trailing [IMG_BREAK] of the last row
        embeddings = ggml_view_2d(ctx0, cur,
            n_embd_text, n_tokens_output,
            ggml_row_size(cur->type, n_embd_text), 0);
    }

    // build the graph
    ggml_build_forward_expand(gf, embeddings);

    return gf;
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
|
||||
if (!ctx->has_vision_encoder) {
|
||||
LOG_ERR("This gguf file seems to have no vision encoder\n");
|
||||
@ -1118,6 +1341,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
{
|
||||
res = clip_image_build_graph_siglip(ctx, imgs);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
{
|
||||
res = clip_image_build_graph_pixtral(ctx, imgs);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
// TODO: we should have one build_* function per model
|
||||
@ -1279,6 +1506,10 @@ struct clip_model_loader {
|
||||
{
|
||||
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
{
|
||||
hparams.rope_theta = 10000.0f;
|
||||
} break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@ -1350,16 +1581,26 @@ struct clip_model_loader {
|
||||
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
|
||||
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
|
||||
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
|
||||
layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
|
||||
layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
|
||||
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
|
||||
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
|
||||
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
|
||||
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
|
||||
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
|
||||
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
|
||||
layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
|
||||
layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
|
||||
|
||||
// new naming
|
||||
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
|
||||
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
|
||||
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
|
||||
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
|
||||
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
|
||||
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
|
||||
|
||||
// legacy naming (the in and out is reversed! don't ask me why)
|
||||
layer.ff_i_w = layer.ff_down_w;
|
||||
layer.ff_o_w = layer.ff_up_w;
|
||||
layer.ff_i_b = layer.ff_down_b;
|
||||
layer.ff_o_b = layer.ff_up_b;
|
||||
}
|
||||
|
||||
switch (ctx_clip.proj_type) {
|
||||
@ -1475,6 +1716,15 @@ struct clip_model_loader {
|
||||
{
|
||||
vision_model.projection = get_tensor(TN_MM_PROJECTOR);
|
||||
} break;
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
{
|
||||
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
|
||||
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
|
||||
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
// [IMG_BREAK] token embedding
|
||||
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown projector type");
|
||||
}
|
||||
@ -1517,18 +1767,17 @@ struct clip_model_loader {
|
||||
}
|
||||
|
||||
void alloc_compute_meta() {
|
||||
ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
|
||||
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
|
||||
|
||||
// create a fake batch
|
||||
clip_image_f32_batch batch;
|
||||
clip_image_f32_ptr img(clip_image_f32_init());
|
||||
clip_image_size image_size;
|
||||
image_size.width = clip_get_image_size(&ctx_clip);
|
||||
image_size.height = clip_get_image_size(&ctx_clip);
|
||||
int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
|
||||
img->nx = n_patches;
|
||||
img->ny = n_patches;
|
||||
img->buf.resize(n_patches * image_size.width * image_size.height * 3);
|
||||
image_size.width = ctx_clip.vision_model.hparams.image_size;
|
||||
image_size.height = ctx_clip.vision_model.hparams.image_size;
|
||||
img->nx = image_size.width;
|
||||
img->ny = image_size.height;
|
||||
img->buf.resize(image_size.width * image_size.height * 3);
|
||||
batch.entries.push_back(std::move(img));
|
||||
|
||||
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
|
||||
@ -1916,6 +2165,26 @@ struct image_manipulation {
|
||||
}
|
||||
}
|
||||
|
||||
// calculate the size of the **resized** image, while preserving the aspect ratio
|
||||
// the calculated size will be aligned to the nearest multiple of align_size
|
||||
// if H or W size is larger than max_dimension, it will be resized to max_dimension
|
||||
static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
|
||||
if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
|
||||
return {0, 0};
|
||||
}
|
||||
|
||||
float scale = std::min(1.0f, std::min(static_cast<float>(max_dimension) / inp_size.width,
|
||||
static_cast<float>(max_dimension) / inp_size.height));
|
||||
|
||||
float target_width_f = static_cast<float>(inp_size.width) * scale;
|
||||
float target_height_f = static_cast<float>(inp_size.height) * scale;
|
||||
|
||||
int aligned_width = GGML_PAD((int)target_width_f, align_size);
|
||||
int aligned_height = GGML_PAD((int)target_height_f, align_size);
|
||||
|
||||
return {aligned_width, aligned_height};
|
||||
}
|
||||
|
||||
private:
|
||||
static inline int clip(int x, int lower, int upper) {
|
||||
return std::max(lower, std::min(x, upper));
|
||||
@ -2247,8 +2516,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ctx->has_glm_projector
|
||||
else if (ctx->has_glm_projector
|
||||
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|
||||
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
clip_image_u8 resized_image;
|
||||
@ -2260,6 +2528,15 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
return true;
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
||||
clip_image_u8 resized_image;
|
||||
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
|
||||
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
|
||||
clip_image_f32_ptr img_f32(clip_image_f32_init());
|
||||
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
|
||||
res_imgs->entries.push_back(std::move(img_f32));
|
||||
return true;
|
||||
}
|
||||
|
||||
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
|
||||
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
|
||||
@ -2387,6 +2664,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
|
||||
n_patches = 256;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
||||
int n_patches_x = img->nx / params.patch_size;
|
||||
int n_patches_y = img->ny / params.patch_size;
|
||||
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
|
||||
}
|
||||
|
||||
return n_patches;
|
||||
@ -2540,10 +2821,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
|
||||
float * data = (float *)malloc(ggml_nbytes(inp_raw));
|
||||
|
||||
// TODO @ngxson : this whole code block is ugly, will need to be refactored
|
||||
for (size_t i = 0; i < imgs.entries.size(); i++) {
|
||||
const int nx = imgs.entries[i]->nx;
|
||||
const int ny = imgs.entries[i]->ny;
|
||||
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
|
||||
|
||||
if (ctx->has_glm_projector
|
||||
|| ctx->has_llava_projector
|
||||
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|
||||
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
GGML_ASSERT(nx == image_size && ny == image_size);
|
||||
}
|
||||
|
||||
@ -2657,6 +2943,24 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
|
||||
// do nothing
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
||||
// set the 2D positions
|
||||
int n_patches_per_col = image_size_width / patch_size;
|
||||
std::vector<int> pos_data(num_positions);
|
||||
struct ggml_tensor * pos;
|
||||
// dimension H
|
||||
pos = ggml_graph_get_tensor(gf, "pos_h");
|
||||
for (int i = 0; i < num_positions; i++) {
|
||||
pos_data[i] = i / n_patches_per_col;
|
||||
}
|
||||
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
|
||||
// dimension W
|
||||
pos = ggml_graph_get_tensor(gf, "pos_w");
|
||||
for (int i = 0; i < num_positions; i++) {
|
||||
pos_data[i] = i % n_patches_per_col;
|
||||
}
|
||||
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
|
||||
}
|
||||
else {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
@ -2849,6 +3153,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
case PROJECTOR_TYPE_LDPV2:
|
||||
return ctx->vision_model.mm_model_peg_0_b->ne[0];
|
||||
case PROJECTOR_TYPE_MLP:
|
||||
case PROJECTOR_TYPE_PIXTRAL:
|
||||
return ctx->vision_model.mm_2_b->ne[0];
|
||||
case PROJECTOR_TYPE_MLP_NORM:
|
||||
return ctx->vision_model.mm_3_b->ne[0];
|
||||
|
@ -190,6 +190,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
// https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
|
||||
marker_modified = "<fake_token_around_image><global-img>" + ctx->image_marker + "<fake_token_around_image>";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
|
||||
} else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
|
||||
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
|
||||
marker_modified = ctx->image_marker + "[IMG_END]";
|
||||
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
|
||||
}
|
||||
|
||||
// llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
|
||||
@ -219,7 +224,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
|
||||
for (auto & entry : batch_f32.entries) {
|
||||
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
|
||||
image_tokens->nx = clip_n_patches(ctx->ctx_clip);
|
||||
image_tokens->nx = clip_n_patches_by_img(ctx->ctx_clip, entry.get());
|
||||
image_tokens->ny = 1;
|
||||
image_tokens->batch_f32.entries.push_back(std::move(entry));
|
||||
image_tokens->id = id;
|
||||
@ -313,8 +318,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
|
||||
}
|
||||
|
||||
} else {
|
||||
size_t n_tokens = 0;
|
||||
for (const auto & entry : batch_f32.entries) {
|
||||
n_tokens += clip_n_patches_by_img(ctx->ctx_clip, entry.get());
|
||||
}
|
||||
|
||||
mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
|
||||
image_tokens->nx = clip_n_patches(ctx->ctx_clip) * batch_f32.entries.size(); // TODO @ngxson : use clip_n_patches_by_image
|
||||
image_tokens->nx = n_tokens;
|
||||
image_tokens->ny = 1; // TODO
|
||||
image_tokens->batch_f32 = std::move(batch_f32);
|
||||
image_tokens->id = bitmaps[i_img].id; // optional
|
||||
@ -382,7 +392,7 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
|
||||
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
|
||||
const auto & entries = image_tokens->batch_f32.entries;
|
||||
for (size_t i = 0; i < entries.size(); i++) {
|
||||
int n_tokens_per_image = clip_n_patches(ctx->ctx_clip);
|
||||
int n_tokens_per_image = clip_n_patches_by_img(ctx->ctx_clip, entries[i].get());
|
||||
ok = clip_image_encode(
|
||||
ctx->ctx_clip,
|
||||
ctx->n_threads,
|
||||
|
@ -13,6 +13,14 @@ mkdir -p $SCRIPT_DIR/output
|
||||
PROJ_ROOT="$SCRIPT_DIR/../.."
|
||||
cd $PROJ_ROOT
|
||||
|
||||
# Check if the first argument is "big", then run test with big models
|
||||
# This is useful if we're running the script on a larger machine, so we can test the big models
|
||||
RUN_BIG_TESTS=false
|
||||
if [ "${1:-}" = "big" ]; then
|
||||
RUN_BIG_TESTS=true
|
||||
echo "Include BIG models..."
|
||||
fi
|
||||
|
||||
###############
|
||||
|
||||
arr_bin=()
|
||||
@ -28,6 +36,12 @@ add_test() {
|
||||
arr_tmpl+=("$tmpl")
|
||||
}
|
||||
|
||||
# Register a test that needs a large model.
# Forwards all arguments to add_test, but only when the script was started in
# "big" mode (RUN_BIG_TESTS=true); otherwise it is a successful no-op.
add_test_big() {
    if [ "$RUN_BIG_TESTS" != true ]; then
        return 0
    fi
    add_test "$@"
}
|
||||
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
|
||||
add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
|
||||
@ -42,6 +56,9 @@ add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
|
||||
add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
|
||||
add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
|
||||
|
||||
# to test the big models, run: ./tests.sh big
|
||||
add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
|
||||
|
||||
# these models always give the wrong answer, not sure why
|
||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
|
||||
# add_test "llama-mtmd-cli" "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"
|
||||
|
Reference in New Issue
Block a user