clip : improve projector naming (#13118)

* clip : improve projector naming

* no more kv has_llava_projector

* rm unused kv

* rm more unused
Author: Xuan-Son Nguyen
Date: 2025-04-26 22:39:47 +02:00
Committed by: GitHub
Parent: 77d5e9a76a
Commit: 4753791e70
3 changed files with 241 additions and 312 deletions


@@ -308,13 +308,8 @@ struct clip_vision_model {
};
struct clip_ctx {
bool has_text_encoder = false;
bool has_vision_encoder = false;
bool has_llava_projector = false;
bool has_minicpmv_projector = false;
bool has_glm_projector = false;
bool has_qwen2vl_merger = false;
int minicpmv_version = 2;
int minicpmv_version = 0;
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
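Note on the hunk above: the per-model `has_*` booleans are retired in favour of `proj_type` as the single source of truth; `has_llava_projector` is the one survivor, and it is now derived from `proj_type` in the loader instead of being read from its own gguf key. A minimal sketch of the resulting dispatch pattern (the helper name here is hypothetical, but the branch shape mirrors the checks this diff introduces in the graph builder and in clip_image_batch_encode):

    // sketch only, not part of the diff
    static bool needs_single_image_batch(const clip_ctx * ctx) {
        return ctx->has_llava_projector                    // derived from proj_type in the loader
            || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE;
    }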
@@ -373,23 +368,20 @@ struct clip_ctx {
}
};
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size = hparams.image_size;
int image_size_width = image_size;
int image_size_height = image_size;
int image_size_width = img.nx;
int image_size_height = img.ny;
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
@@ -621,15 +613,14 @@ static ggml_tensor * build_rope_2d(
return cur;
}
static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
int image_size_width = imgs.entries[0]->nx;
int image_size_height = imgs.entries[0]->ny;
int image_size_width = img.nx;
int image_size_height = img.ny;
const int patch_size = hparams.patch_size;
const int n_patches_x = image_size_width / patch_size;
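Like the siglip builder above, the pixtral builder now takes a single `clip_image_f32` and reads the resolution from the image itself, so non-square, native-resolution inputs flow through naturally; the batch-size check moves to the `clip_image_build_graph` dispatcher (shown further down). A minimal caller sketch under that contract:

    // the dispatcher unwraps the one-image batch before calling the builder
    GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
    ggml_cgraph * gf = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);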
@@ -772,18 +763,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
}
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
return nullptr;
}
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size = hparams.image_size;
int image_size_width = image_size;
int image_size_height = image_size;
if (ctx->has_minicpmv_projector) {
if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height);
image_size_width = load_image_size.width;
image_size_height = load_image_size.height;
@@ -792,7 +779,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
image_size_height = imgs.entries[0]->ny;
}
}
else if (ctx->has_qwen2vl_merger) {
else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
// use the image's native resolution when the image is available
if (is_inf) {
// if (imgs->data->nx && imgs->data->ny) {
@@ -800,12 +788,13 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
image_size_height = imgs.entries[0]->ny;
}
}
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int patches_w = image_size_width / patch_size;
const int patches_h = image_size_height / patch_size;
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
@@ -814,7 +803,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
const int batch_size = imgs.entries.size();
if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
if (ctx->has_llava_projector
|| ctx->proj_type == PROJECTOR_TYPE_MINICPMV
|| ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
GGML_ASSERT(batch_size == 1);
}
@@ -835,8 +826,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
if (ctx->has_qwen2vl_merger) {
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
@@ -865,29 +856,26 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
struct ggml_tensor * embeddings = inp;
struct ggml_tensor * pos_embed = nullptr;
if (ctx->has_llava_projector) {
// concat class_embeddings and patch_embeddings
if (model.class_embedding) {
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
ggml_set_name(embeddings, "embeddings");
ggml_set_input(embeddings);
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
embeddings = ggml_acc(ctx0, embeddings, inp,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
}
// concat class_embeddings and patch_embeddings
if (model.class_embedding) {
embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
embeddings = ggml_acc(ctx0, embeddings, inp,
embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
}
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
ggml_set_name(positions, "positions");
ggml_set_input(positions);
if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
if (ctx->proj_type != PROJECTOR_TYPE_QWEN2VL) { // qwen2vl does NOT use learned position embeddings
embeddings =
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
}
if (ctx->has_minicpmv_projector) {
if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
int pos_w = image_size_width/patch_size;
int pos_h = image_size_height/patch_size;
if (ctx->minicpmv_version == 2) {
@@ -941,7 +929,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
if (ctx->has_qwen2vl_merger) {
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
Q = ggml_rope_multi(
ctx0, Q, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
@@ -953,7 +941,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
if (ctx->has_qwen2vl_merger) {
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
K = ggml_rope_multi(
ctx0, K, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
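Qwen2-VL skips the learned position embedding and applies M-RoPE to Q and K instead, which is why `num_position_ids` above is `num_positions * 4`: the `positions` tensor carries four id sections that `ggml_rope_multi` splits according to `mrope_sections`. A hedged sketch of how the host side could fill those ids; the section order and traversal here are assumptions, not a copy of the real fill loop:

    // assumption: one num_positions-sized plane per section (pw * ph == num_positions here)
    std::vector<int> pos(num_positions * 4);
    for (int y = 0; y < ph; y++) {
        for (int x = 0; x < pw; x++) {
            const int i = y * pw + x;
            pos[i                    ] = y;
            pos[i + num_positions    ] = y; // height section (assumed)
            pos[i + num_positions * 2] = x; // width section (assumed)
            pos[i + num_positions * 3] = x;
        }
    }
    ggml_backend_tensor_set(positions, pos.data(), 0, ggml_nbytes(positions));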
@@ -1218,106 +1206,98 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
}
}
// minicpmv projector
else if (ctx->has_minicpmv_projector)
{
if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
struct ggml_tensor * q = model.mm_model_query;
{ // layernorm
q = ggml_norm(ctx0, q, eps);
q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
}
struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
{ // layernorm
v = ggml_norm(ctx0, v, eps);
v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
}
struct ggml_tensor * k;
{ // position
// q = ggml_add(ctx0, q, model.mm_model_pos_embed);
k = ggml_add(ctx0, v, pos_embed);
}
{ // attention
int hidden_size = 4096;
const int d_head = 128;
int n_head = hidden_size/d_head;
int num_query = 96;
if (ctx->minicpmv_version == 2) {
hidden_size = 4096;
n_head = hidden_size/d_head;
num_query = 96;
}
else if (ctx->minicpmv_version == 3) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
else if (ctx->minicpmv_version == 4) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
// permute
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
}
{ // layernorm
embeddings = ggml_norm(ctx0, embeddings, eps);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
}
embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
struct ggml_tensor * q = model.mm_model_query;
{ // layernorm
q = ggml_norm(ctx0, q, eps);
q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
}
else {
GGML_ASSERT(false);
struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
{ // layernorm
v = ggml_norm(ctx0, v, eps);
v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
}
struct ggml_tensor * k;
{ // position
// q = ggml_add(ctx0, q, model.mm_model_pos_embed);
k = ggml_add(ctx0, v, pos_embed);
}
{ // attention
int hidden_size = 4096;
const int d_head = 128;
int n_head = hidden_size/d_head;
int num_query = 96;
if (ctx->minicpmv_version == 2) {
hidden_size = 4096;
n_head = hidden_size/d_head;
num_query = 96;
}
else if (ctx->minicpmv_version == 3) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
else if (ctx->minicpmv_version == 4) {
hidden_size = 3584;
n_head = hidden_size/d_head;
num_query = 64;
}
struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
// permute
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
}
{ // layernorm
embeddings = ggml_norm(ctx0, embeddings, eps);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
}
embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
}
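// The resampler dimensions selected above depend only on minicpmv_version,
// with d_head fixed at 128. Summarized as a hypothetical helper (values
// copied from the branch above, not part of the diff):
//   version 2     -> hidden_size 4096, n_head 32, num_query 96
//   version 3 / 4 -> hidden_size 3584, n_head 28, num_query 64
// static void minicpmv_resampler_dims(int version, int & hidden_size, int & num_query) {
//     hidden_size = version == 2 ? 4096 : 3584;
//     num_query   = version == 2 ? 96   : 64;
// }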
// glm projector
else if (ctx->has_glm_projector) {
if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
//GLU
{
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
embeddings = ggml_norm(ctx0, embeddings, eps);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
embeddings = ggml_gelu_inplace(ctx0, embeddings);
struct ggml_tensor * x = embeddings;
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
embeddings = ggml_silu_inplace(ctx0, embeddings);
embeddings = ggml_mul(ctx0, embeddings,x);
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
}
} else {
GGML_ABORT("fatal error");
else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
// GLU
{
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
embeddings = ggml_norm(ctx0, embeddings, eps);
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
embeddings = ggml_gelu_inplace(ctx0, embeddings);
struct ggml_tensor * x = embeddings;
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
embeddings = ggml_silu_inplace(ctx0, embeddings);
embeddings = ggml_mul(ctx0, embeddings,x);
embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
}
}
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
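The Qwen2-VL merger concatenates each 2x2 group of patch embeddings along the feature dimension and quarters the token count before the two-layer `mm_0`/`mm_1` MLP, hence the reshape above. Shape flow, as a sketch:

    // [hidden_size, num_positions, batch] -> [hidden_size * 4, num_positions / 4, batch]
    //   -> mm_0 -> activation -> mm_1 -> [clip_n_mmproj_embd(ctx), num_positions / 4, batch]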
@@ -1343,11 +1323,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
{
res = clip_image_build_graph_siglip(ctx, imgs);
GGML_ASSERT(imgs.entries.size() == 1);
res = clip_image_build_graph_siglip(ctx, *imgs.entries[0]);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
res = clip_image_build_graph_pixtral(ctx, imgs);
GGML_ASSERT(imgs.entries.size() == 1);
res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
} break;
default:
{
@@ -1419,8 +1401,8 @@ struct clip_model_loader {
auto & hparams = ctx_clip.vision_model.hparams;
// projector type
std::string proj_type;
{
std::string proj_type;
get_string(KEY_PROJ_TYPE, proj_type, false);
if (!proj_type.empty()) {
ctx_clip.proj_type = clip_projector_type_from_string(proj_type);
@@ -1432,33 +1414,27 @@ struct clip_model_loader {
// other hparams
{
get_bool(KEY_HAS_TEXT_ENC, ctx_clip.has_text_encoder, false);
get_bool(KEY_HAS_VIS_ENC, ctx_clip.has_vision_encoder, false);
GGML_ASSERT(ctx_clip.has_vision_encoder);
GGML_ASSERT(!ctx_clip.has_text_encoder);
// legacy keys, use KEY_PROJ_TYPE instead
get_bool(KEY_HAS_LLAVA_PROJ, ctx_clip.has_llava_projector, false);
get_bool(KEY_HAS_MINICPMV_PROJ, ctx_clip.has_minicpmv_projector, false);
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false);
get_bool(KEY_HAS_GLM_PROJ, ctx_clip.has_glm_projector, false);
get_bool(KEY_HAS_QWEN2VL_MERGER, ctx_clip.has_qwen2vl_merger, false);
// !!! do NOT extend the list above, use KEY_PROJ_TYPE instead
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size);
get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head);
get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate);
get_u32(string_format(KEY_N_BLOCK, "vision"), hparams.n_layer);
get_u32(string_format(KEY_PROJ_DIM, "vision"), hparams.projection_dim);
get_f32(string_format(KEY_LAYER_NORM_EPS, "vision"), hparams.eps);
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_u32(KEY_N_EMBD, hparams.hidden_size);
get_u32(KEY_N_HEAD, hparams.n_head);
get_u32(KEY_N_FF, hparams.n_intermediate);
get_u32(KEY_N_BLOCK, hparams.n_layer);
get_u32(KEY_PROJ_DIM, hparams.projection_dim);
get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
get_u32(KEY_IMAGE_SIZE, hparams.image_size);
get_u32(KEY_PATCH_SIZE, hparams.patch_size);
get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP
|| ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM
|| ctx_clip.proj_type == PROJECTOR_TYPE_LDP
|| ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2;
{
std::string mm_patch_merge_type;
get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false);
@@ -1491,32 +1467,56 @@ struct clip_model_loader {
for (auto & layer : vision_feature_layer) {
hparams.vision_feature_layer.insert(layer);
}
// Calculate the deepest feature layer based on hparams and projector type
ctx_clip.max_feature_layer = get_deepest_feature_layer(&ctx_clip);
LOG_INF("%s: text_encoder: %d\n", __func__, ctx_clip.has_text_encoder);
LOG_INF("%s: vision_encoder: %d\n", __func__, ctx_clip.has_vision_encoder);
LOG_INF("%s: llava_projector: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_projector: %d\n", __func__, ctx_clip.has_minicpmv_projector);
// Calculate the deepest feature layer based on hparams and projector type
// NOTE: This is only used by build_graph_legacy()
{
// Get the index of the second to last layer; this is the default for models that have a llava projector
int n_layer = hparams.n_layer - 1;
int deepest_feature_layer = -1;
if (ctx_clip.proj_type == PROJECTOR_TYPE_MINICPMV
|| ctx_clip.proj_type == PROJECTOR_TYPE_GLM_EDGE
|| ctx_clip.proj_type == PROJECTOR_TYPE_QWEN2VL) {
n_layer += 1;
}
// If we set explicit vision feature layers, only go up to the deepest one
// NOTE: only used by granite-vision models for now
for (const auto & feature_layer : hparams.vision_feature_layer) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
}
ctx_clip.max_feature_layer = deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
}
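// Worked example for the block above (illustrative numbers, not from the
// diff): explicit vision feature layers {3, 7, 15, 26} (granite-vision style)
// give max_feature_layer = 26; a plain llava MLP model with n_layer = 24 and
// no explicit layers gives 24 - 1 = 23; a qwen2vl model of the same depth
// gets the + 1 bump and iterates all 24 layers.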
// model-specific params
switch (ctx_clip.proj_type) {
case PROJECTOR_TYPE_MINICPMV:
{
if (ctx_clip.minicpmv_version == 0) {
ctx_clip.minicpmv_version = 2; // default to 2 if not set
}
} break;
case PROJECTOR_TYPE_IDEFICS3:
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
hparams.rope_theta = 10000.0f;
} break;
default:
break;
}
LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
LOG_INF("%s: glm_projector: %d\n", __func__, ctx_clip.has_glm_projector);
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
}
// model-specific params
switch (ctx_clip.proj_type) {
case PROJECTOR_TYPE_IDEFICS3:
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
hparams.rope_theta = 10000.0f;
} break;
default:
break;
}
}
void load_tensors() {
@@ -1569,9 +1569,6 @@ struct clip_model_loader {
vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
if (vision_model.patch_embeddings_1 == nullptr) {
ctx_clip.has_qwen2vl_merger = false;
}
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
@@ -1669,7 +1666,7 @@ struct clip_model_loader {
vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
} break;
case PROJECTOR_TYPE_RESAMPLER:
case PROJECTOR_TYPE_MINICPMV:
{
// vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
@@ -1702,7 +1699,7 @@ struct clip_model_loader {
vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight"));
vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
} break;
case PROJECTOR_TYPE_MERGER:
case PROJECTOR_TYPE_QWEN2VL:
{
vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
@@ -2479,11 +2476,6 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
// returns the normalized float tensor for llava-1.5; for spatial_unpad with anyres processing (llava-1.6) it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
if (!ctx->has_vision_encoder) {
LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
return false;
}
clip_image_size original_size{img->nx, img->ny};
bool pad_to_square = true;
auto & params = ctx->vision_model.hparams;
@@ -2504,7 +2496,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
}
return true;
}
else if (ctx->has_qwen2vl_merger) {
else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
clip_image_u8 resized;
auto patch_size = clip_get_patch_size(ctx) * 2;
int nx = ceil((float)img->nx / patch_size) * patch_size;
@@ -2518,7 +2510,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
else if (ctx->has_glm_projector
else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
clip_image_u8 resized_image;
@@ -2646,7 +2638,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
n_patches /= 4;
} else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
} else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
if (ctx->minicpmv_version == 2) {
n_patches = 96;
}
@@ -2656,7 +2648,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
else if (ctx->minicpmv_version == 4) {
n_patches = 64;
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
else {
GGML_ABORT("Unknown minicpmv version");
}
} else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
int patch_size = params.patch_size * 2;
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
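Because the merger folds 2x2 patches into one token, the effective patch size doubles and both axes round up. Illustrative numbers (assumed, not from the diff): with `params.patch_size = 14` the effective size is 28, so a 336x448 image gives `x_patch = ceil(336/28) = 12`, `y_patch = ceil(448/28) = 16`, and `n_patches = x_patch * y_patch = 192` tokens after merging.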
@@ -2761,11 +2756,6 @@ static std::vector<std::vector<float>> get_2d_sincos_pos_embed(int embed_dim, co
}
bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
if (!ctx->has_vision_encoder) {
LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
return false;
}
clip_image_f32_batch imgs;
clip_image_f32_ptr img_copy(clip_image_f32_init());
*img_copy = *img;
@@ -2776,20 +2766,11 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
const clip_image_f32_batch & imgs = *imgs_c_ptr;
if (!ctx->has_vision_encoder) {
LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
return false;
}
int batch_size = imgs.entries.size();
if (ctx->has_llava_projector) {
GGML_ASSERT(batch_size == 1); // TODO: support multiple images
}
if (ctx->has_minicpmv_projector) {
GGML_ASSERT(batch_size == 1);
}
if (ctx->has_glm_projector) {
if (ctx->has_llava_projector
|| ctx->proj_type == PROJECTOR_TYPE_MINICPMV
|| ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
GGML_ASSERT(batch_size == 1);
}
@@ -2799,21 +2780,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
// set inputs
const auto & model = ctx->vision_model;
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
// TODO @ngxson : this is ugly, need to refactor later
bool support_dynamic_size = ctx->has_minicpmv_projector
|| ctx->has_qwen2vl_merger
|| ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
const int image_size_width = imgs.entries[0]->nx;
const int image_size_height = imgs.entries[0]->ny;
const int image_size = hparams.image_size;
int image_size_width = image_size;
int image_size_height = image_size;
if (support_dynamic_size) {
image_size_width = imgs.entries[0]->nx;
image_size_height = imgs.entries[0]->ny;
}
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
@@ -2839,14 +2811,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
for (size_t i = 0; i < imgs.entries.size(); i++) {
const int nx = imgs.entries[i]->nx;
const int ny = imgs.entries[i]->ny;
if (ctx->has_glm_projector
|| ctx->has_llava_projector
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
GGML_ASSERT(nx == image_size && ny == image_size);
}
const int n = nx * ny;
for (int b = 0; b < batch_size; b++) {
@@ -2864,13 +2828,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
}
if (ctx->has_minicpmv_projector) {
if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
{
// inspired from siglip:
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
std::vector<int> pos_data(ggml_nelements(positions));
int * data = pos_data.data();
int bucket_coords_h[1024];
int bucket_coords_w[1024];
for (int i = 0; i < pos_h; i++){
@@ -2881,11 +2847,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
}
for (int i = 0, id = 0; i < pos_h; i++){
for (int j = 0; j < pos_w; j++){
positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
}
}
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
}
{
@@ -2903,30 +2868,28 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
else if (ctx->minicpmv_version == 4) {
embed_dim = 3584;
}
else {
GGML_ABORT("Unknown minicpmv version");
}
// TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
for(int i=0;i < pos_w * pos_h; ++i){
for(int j=0; j < embed_dim; ++j){
pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
std::vector<float> pos_data(ggml_nelements(pos_embed));
float * data = pos_data.data();
for(int i = 0; i < pos_w * pos_h; ++i){
for(int j = 0; j < embed_dim; ++j){
data[i * embed_dim + j] = pos_embed_t[i][j];
}
}
ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
free(pos_embed_data);
ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed));
}
}
else {
if (model.class_embedding) {
struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
// non-minicpmv models
void* zero_mem = malloc(ggml_nbytes(embeddings));
memset(zero_mem, 0, ggml_nbytes(embeddings));
ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
free(zero_mem);
}
if (ctx->has_qwen2vl_merger) {
if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
const int pw = image_size_width / patch_size;
@@ -2978,6 +2941,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
}
else {
// llava and other models
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
@@ -2987,7 +2951,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
if (!ctx->has_glm_projector) {
if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) {
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
// The patches vector is used to get rows to index into the embeds with;
// we should skip dim 0 only if we have CLS to avoid going out of bounds
@@ -3166,7 +3130,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_2_b->ne[0];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];
case PROJECTOR_TYPE_RESAMPLER:
case PROJECTOR_TYPE_MINICPMV:
if (ctx->minicpmv_version == 2) {
return 4096;
} else if (ctx->minicpmv_version == 3) {
@@ -3174,36 +3138,33 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
} else if (ctx->minicpmv_version == 4) {
return 3584;
}
break; // Should not happen if version is valid
GGML_ABORT("Unknown minicpmv version");
case PROJECTOR_TYPE_GLM_EDGE:
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
case PROJECTOR_TYPE_MERGER:
case PROJECTOR_TYPE_QWEN2VL:
return ctx->vision_model.mm_1_b->ne[0];
case PROJECTOR_TYPE_GEMMA3:
return ctx->vision_model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
return ctx->vision_model.projection->ne[1];
default:
break; // Fall through to throw
GGML_ABORT("Unknown projector type");
}
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
throw std::runtime_error(string_format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
}
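// usage sketch (assumption: a typical caller sizing its output buffer;
// clip_n_patches is the existing public counterpart of the
// clip_n_patches_by_img shown earlier):
//   const int n_embd = clip_n_mmproj_embd(ctx);
//   std::vector<float> embd((size_t) clip_n_patches(ctx) * n_embd);
//   clip_image_encode(ctx, n_threads, &img_f32, embd.data());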
int clip_is_minicpmv(const struct clip_ctx * ctx) {
if (ctx->has_minicpmv_projector) {
if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
return ctx->minicpmv_version;
}
return 0;
}
bool clip_is_glm(const struct clip_ctx * ctx) {
return ctx->has_glm_projector;
return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE;
}
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
return ctx->has_qwen2vl_merger;
return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL;
}
bool clip_is_llava(const struct clip_ctx * ctx) {
@@ -3214,29 +3175,6 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
}
// Determine the number of encoder layers to iterate over
int get_deepest_feature_layer(const struct clip_ctx * ctx) {
// Get the index of the second to last layer; this is the
// default for models that have a llava projector
const auto & hparams = ctx->vision_model.hparams;
int n_layer = hparams.n_layer - 1;
int deepest_feature_layer = -1;
// Handle other projectors; incrementing here indicates that we
// should use the last encoder layer for the vision features.
if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
n_layer += 1;
}
// If we set explicit vision feature layers, only go up to the deepest one
for (const auto & feature_layer : hparams.vision_feature_layer) {
if (feature_layer > deepest_feature_layer) {
deepest_feature_layer = feature_layer;
}
}
return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3);