mtmd : Support Pixtral 12B (#13065)

* add pixtral text model (vision is wip)

* cgraph ok, just missing 2D RoPE

* fix bad rebase

* first working version

* fix problem with img_break token

* support dynamic image size

* update docs

* update test script
Xuan-Son Nguyen
2025-04-23 20:21:59 +02:00
committed by GitHub
parent eb1776b15a
commit ecda2ec4b3
14 changed files with 643 additions and 31 deletions

@@ -163,7 +163,8 @@ struct clip_hparams {
patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
float eps;
float eps = 1e-6;
float rope_theta = 0.0;
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
@@ -187,11 +188,17 @@ struct clip_layer {
struct ggml_tensor * ln_1_b = nullptr;
// ff
struct ggml_tensor * ff_i_w = nullptr;
struct ggml_tensor * ff_i_b = nullptr;
struct ggml_tensor * ff_i_w = nullptr; // legacy naming
struct ggml_tensor * ff_i_b = nullptr; // legacy naming
struct ggml_tensor * ff_o_w = nullptr; // legacy naming
struct ggml_tensor * ff_o_b = nullptr; // legacy naming
struct ggml_tensor * ff_o_w = nullptr;
struct ggml_tensor * ff_o_b = nullptr;
struct ggml_tensor * ff_up_w = nullptr;
struct ggml_tensor * ff_up_b = nullptr;
struct ggml_tensor * ff_gate_w = nullptr;
struct ggml_tensor * ff_gate_b = nullptr;
struct ggml_tensor * ff_down_w = nullptr;
struct ggml_tensor * ff_down_b = nullptr;
// layernorm 2
struct ggml_tensor * ln_2_w = nullptr;
@@ -297,6 +304,9 @@ struct clip_vision_model {
// gemma3
struct ggml_tensor * mm_input_proj_w = nullptr;
struct ggml_tensor * mm_soft_emb_norm_w = nullptr;
// pixtral
struct ggml_tensor * token_embd_img_break = nullptr;
};
struct clip_ctx {
@@ -329,6 +339,7 @@ struct clip_ctx {
ggml_backend_t backend_cpu;
ggml_backend_buffer_ptr buf;
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_image_size load_image_size;
@@ -544,6 +555,218 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
return gf;
}
// implementation of the 2D RoPE without adding a new op in ggml
static ggml_tensor * build_rope_2d(
ggml_cgraph * gf,
ggml_context * ctx0,
ggml_tensor * cur,
ggml_tensor * pos_h,
ggml_tensor * pos_w,
const float freq_base
) {
ggml_tensor * tmp;
const int64_t n_dim = cur->ne[0];
const int64_t n_head = cur->ne[1];
const int64_t n_pos = cur->ne[2];
// for example, if we have a cur tensor of shape (n_dim=8, n_head, n_pos)
// we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
// first half of cur will use 1e-0, 1e-2 (even)
// second half of cur will use 1e-1, 1e-3 (odd)
//
// for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
// then for the second half, we use freq_scale to shift the inv_freq
// ^ why? replace (2i) with (2i+1) in the above equation
const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
// first half
{
cur = ggml_rope_ext_inplace(
ctx0,
cur,
pos_h, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
1.0f, 0.0f, 1.0f, 0.0f, 0.0f
);
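// ^ trailing args (per the ggml_rope_ext signature): mode = 0 (normal RoPE), n_ctx_orig = 0,
//   freq_base, then freq_scale = 1.0, ext_factor = 0.0, attn_factor = 1.0, beta_fast/beta_slow = 0.0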
}
// second half
{
tmp = ggml_view_3d(ctx0, cur,
n_dim/2, n_head, n_pos,
ggml_row_size(cur->type, n_dim),
ggml_row_size(cur->type, n_dim*n_head),
n_dim/2 * ggml_element_size(cur));
tmp = ggml_rope_ext_inplace(
ctx0,
tmp,
pos_w, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
freq_scale_odd,
0.0f, 1.0f, 0.0f, 0.0f
);
// calculate inplace (modify cur directly)
ggml_build_forward_expand(gf, tmp);
}
return cur;
}
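As a sanity check of the even/odd frequency trick described in the comments above, here is a minimal standalone sketch (illustrative only, not part of this commit; n_dim and freq_base are arbitrary example values). Standard RoPE rotates dimension pair i at inv_freq = freq_base^(-2i/n_dim); rotating only n_dim/2 dims reproduces the even pairs, and scaling freq_scale by freq_base^(-2/n_dim) shifts each of them to the matching odd pair:

#include <cmath>
#include <cstdio>

int main() {
    const int    n_dim     = 8;
    const double freq_base = 10000.0;
    // same scale factor as freq_scale_odd in build_rope_2d
    const double freq_scale_odd = std::pow(freq_base, -2.0/n_dim);
    for (int i = 0; i < n_dim/4; i++) {
        // rotating n_dim/2 dims: freq_base^(-2i/(n_dim/2)) == freq_base^(-2(2i)/n_dim)
        const double even = std::pow(freq_base, -2.0*i/(n_dim/2));
        // scaling by freq_scale_odd replaces (2i) with (2i+1) in the exponent
        const double odd  = freq_scale_odd * even;
        std::printf("i=%d even=%g (expect %g) odd=%g (expect %g)\n", i,
                    even, std::pow(freq_base, -2.0*(2*i)/n_dim),
                    odd,  std::pow(freq_base, -2.0*(2*i + 1)/n_dim));
    }
    return 0;
}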
static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
int image_size_width = imgs.entries[0]->nx;
int image_size_height = imgs.entries[0]->ny;
const int patch_size = hparams.patch_size;
const int n_patches_x = image_size_width / patch_size;
const int n_patches_y = image_size_height / patch_size;
const int num_patches = n_patches_x * n_patches_y;
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx0_ptr(ggml_init(params));
auto ctx0 = ctx0_ptr.get();
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// input raw
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
// 2D input positions
struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
ggml_set_name(pos_h, "pos_h");
ggml_set_input(pos_h);
struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
ggml_set_name(pos_w, "pos_w");
ggml_set_input(pos_w);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
struct ggml_tensor * embeddings = inp;
// pre-layer norm
embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.pre_ln_w);
// loop over layers
for (int il = 0; il < n_layer; il++) {
struct ggml_tensor * cur = embeddings;
// pre-attention norm
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_1_w);
// self-attention
{
struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
}
// re-add the layer input, i.e., the residual connection
cur = ggml_add(ctx0, cur, embeddings);
embeddings = cur; // embeddings = residual, cur = hidden_states
// pre-ffn norm
cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_2_w);
// feed-forward
{
ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
}
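// i.e. FFN(x) = ff_down(SiLU(ff_gate(x)) * ff_up(x)) -- a SwiGLU-style MLP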
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// LlavaMultiModalProjector (with GELU activation)
{
embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
embeddings = ggml_gelu(ctx0, embeddings);
embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
}
// arrangement of the [IMG_BREAK] token
{
// not efficient, but works
// the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
// and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
// after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
const int n_embd_text = embeddings->ne[0];
const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
tok = ggml_add(ctx0, tok, model.token_embd_img_break);
cur = ggml_concat(ctx0, cur, tok, 1);
embeddings = ggml_view_2d(ctx0, cur,
n_embd_text, n_tokens_output,
ggml_row_size(cur->type, n_embd_text), 0);
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
return gf;
}
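For intuition on the [IMG_BREAK] arrangement above, here is a toy sketch (hypothetical 3x2 patch grid, independent of ggml, not part of this commit) that prints the resulting token order; the ggml_view_2d keeps only num_patches + n_patches_y - 1 tokens, which drops the break appended after the last row:

#include <cstdio>

int main() {
    const int n_patches_x = 3;
    const int n_patches_y = 2;
    for (int y = 0; y < n_patches_y; y++) {
        for (int x = 0; x < n_patches_x; x++) {
            std::printf("patch(%d,%d) ", y, x);
        }
        if (y + 1 < n_patches_y) {
            std::printf("[IMG_BREAK] "); // one break per row; the final one is cut off by the view
        }
    }
    std::printf("\n-> %d output tokens\n", n_patches_x*n_patches_y + n_patches_y - 1);
    return 0;
}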
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
@@ -1118,6 +1341,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
{
res = clip_image_build_graph_siglip(ctx, imgs);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
res = clip_image_build_graph_pixtral(ctx, imgs);
} break;
default:
{
// TODO: we should have one build_* function per model
@@ -1279,6 +1506,10 @@ struct clip_model_loader {
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
hparams.rope_theta = 10000.0f;
} break;
default:
break;
}
@@ -1350,16 +1581,26 @@ struct clip_model_loader {
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
// new naming
layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
// legacy naming (in and out are reversed! don't ask me why)
layer.ff_i_w = layer.ff_down_w;
layer.ff_o_w = layer.ff_up_w;
layer.ff_i_b = layer.ff_down_b;
layer.ff_o_b = layer.ff_up_b;
}
switch (ctx_clip.proj_type) {
@@ -1475,6 +1716,15 @@ struct clip_model_loader {
{
vision_model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
// [IMG_BREAK] token embedding
vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
} break;
default:
GGML_ASSERT(false && "unknown projector type");
}
@@ -1517,18 +1767,17 @@ struct clip_model_loader {
}
void alloc_compute_meta() {
ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
// create a fake batch
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
clip_image_size image_size;
image_size.width = clip_get_image_size(&ctx_clip);
image_size.height = clip_get_image_size(&ctx_clip);
int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
img->nx = n_patches;
img->ny = n_patches;
img->buf.resize(n_patches * image_size.width * image_size.height * 3);
image_size.width = ctx_clip.vision_model.hparams.image_size;
image_size.height = ctx_clip.vision_model.hparams.image_size;
img->nx = image_size.width;
img->ny = image_size.height;
img->buf.resize(image_size.width * image_size.height * 3);
batch.entries.push_back(std::move(img));
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
@@ -1916,6 +2165,26 @@ struct image_manipulation {
}
}
// calculate the size of the **resized** image, while preserving the aspect ratio
// the calculated size is rounded up to the nearest multiple of align_size
// if the height or width exceeds max_dimension, the image is scaled down so that the longer side equals max_dimension
static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
return {0, 0};
}
float scale = std::min(1.0f, std::min(static_cast<float>(max_dimension) / inp_size.width,
static_cast<float>(max_dimension) / inp_size.height));
float target_width_f = static_cast<float>(inp_size.width) * scale;
float target_height_f = static_cast<float>(inp_size.height) * scale;
int aligned_width = GGML_PAD((int)target_width_f, align_size);
int aligned_height = GGML_PAD((int)target_height_f, align_size);
return {aligned_width, aligned_height};
}
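// worked example (illustrative values): a 1920x1080 input with align_size = 16 and
// max_dimension = 1024 gives scale = min(1024/1920, 1024/1080) ~ 0.533, so the target
// is ~1024 x ~576; rounding up to multiples of 16 yields {1024, 576}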
private:
static inline int clip(int x, int lower, int upper) {
return std::max(lower, std::min(x, upper));
@@ -2247,8 +2516,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
if (ctx->has_glm_projector
else if (ctx->has_glm_projector
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
clip_image_u8 resized_image;
@@ -2260,6 +2528,15 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
clip_image_u8 resized_image;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
clip_image_f32_ptr img_f32(clip_image_f32_init());
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
@@ -2387,6 +2664,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
n_patches = 256;
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
} else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
int n_patches_x = img->nx / params.patch_size;
int n_patches_y = img->ny / params.patch_size;
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
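// e.g. a 512x512 image with patch_size = 16 gives 32x32 patches -> 32*32 + 32 - 1 = 1055 tokens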
}
return n_patches;
@@ -2540,10 +2821,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
float * data = (float *)malloc(ggml_nbytes(inp_raw));
// TODO @ngxson : this whole code block is ugly, will need to be refactored
for (size_t i = 0; i < imgs.entries.size(); i++) {
const int nx = imgs.entries[i]->nx;
const int ny = imgs.entries[i]->ny;
if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
if (ctx->has_glm_projector
|| ctx->has_llava_projector
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
GGML_ASSERT(nx == image_size && ny == image_size);
}
@@ -2657,6 +2943,24 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
// do nothing
}
else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
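// note: image_size_width / patch_size is the number of patches along a row (the patch-column count);
// e.g. for a 3-patch-wide image: pos_h = 0,0,0,1,1,1,... (row index), pos_w = 0,1,2,0,1,2,... (column index)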
std::vector<int> pos_data(num_positions);
struct ggml_tensor * pos;
// dimension H
pos = ggml_graph_get_tensor(gf, "pos_h");
for (int i = 0; i < num_positions; i++) {
pos_data[i] = i / n_patches_per_col;
}
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
// dimension W
pos = ggml_graph_get_tensor(gf, "pos_w");
for (int i = 0; i < num_positions; i++) {
pos_data[i] = i % n_patches_per_col;
}
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
}
else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
@@ -2849,6 +3153,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_LDPV2:
return ctx->vision_model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP:
case PROJECTOR_TYPE_PIXTRAL:
return ctx->vision_model.mm_2_b->ne[0];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];