diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h index 53ac38130..16d0a8efc 100644 --- a/examples/llava/clip-impl.h +++ b/examples/llava/clip-impl.h @@ -17,22 +17,15 @@ #define KEY_FTYPE "general.file_type" #define KEY_NAME "general.name" #define KEY_DESCRIPTION "general.description" -#define KEY_HAS_TEXT_ENC "clip.has_text_encoder" -#define KEY_HAS_VIS_ENC "clip.has_vision_encoder" -#define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" -#define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" -#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" -#define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" #define KEY_USE_SILU "clip.use_silu" -#define KEY_N_EMBD "clip.%s.embedding_length" -#define KEY_N_FF "clip.%s.feed_forward_length" -#define KEY_N_BLOCK "clip.%s.block_count" -#define KEY_N_HEAD "clip.%s.attention.head_count" -#define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" -#define KEY_PROJ_DIM "clip.%s.projection_dim" -#define KEY_TOKENS "tokenizer.ggml.tokens" +#define KEY_N_EMBD "clip.vision.embedding_length" +#define KEY_N_FF "clip.vision.feed_forward_length" +#define KEY_N_BLOCK "clip.vision.block_count" +#define KEY_N_HEAD "clip.vision.attention.head_count" +#define KEY_LAYER_NORM_EPS "clip.vision.attention.layer_norm_epsilon" +#define KEY_PROJ_DIM "clip.vision.projection_dim" #define KEY_IMAGE_SIZE "clip.vision.image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" @@ -96,9 +89,9 @@ enum projector_type { PROJECTOR_TYPE_MLP_NORM, PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, - PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_MINICPMV, PROJECTOR_TYPE_GLM_EDGE, - PROJECTOR_TYPE_MERGER, + PROJECTOR_TYPE_QWEN2VL, PROJECTOR_TYPE_GEMMA3, PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, @@ -109,9 +102,9 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_MLP, "mlp" }, { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, - { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_MINICPMV, "resampler"}, { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, - { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, + { PROJECTOR_TYPE_QWEN2VL, "qwen2vl_merger"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp index da8a590f0..e8c01c68a 100644 --- a/examples/llava/clip.cpp +++ b/examples/llava/clip.cpp @@ -308,13 +308,8 @@ struct clip_vision_model { }; struct clip_ctx { - bool has_text_encoder = false; - bool has_vision_encoder = false; bool has_llava_projector = false; - bool has_minicpmv_projector = false; - bool has_glm_projector = false; - bool has_qwen2vl_merger = false; - int minicpmv_version = 2; + int minicpmv_version = 0; struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -373,23 +368,20 @@ struct clip_ctx { } }; -static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) { +static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) { const auto & model = ctx->vision_model; const auto & hparams = model.hparams; - const int image_size = hparams.image_size; - int image_size_width = image_size; - int image_size_height = image_size; + int image_size_width = img.nx; + int image_size_height = img.ny; - const int patch_size = hparams.patch_size; - const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); - const int hidden_size = hparams.hidden_size; - const int n_head = hparams.n_head; - const int d_head = hidden_size / n_head; - const int n_layer = hparams.n_layer; - const float eps = hparams.eps; - - GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1 + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); + const int hidden_size = hparams.hidden_size; + const int n_head = hparams.n_head; + const int d_head = hidden_size / n_head; + const int n_layer = hparams.n_layer; + const float eps = hparams.eps; struct ggml_init_params params = { /*.mem_size =*/ ctx->buf_compute_meta.size(), @@ -621,15 +613,14 @@ static ggml_tensor * build_rope_2d( return cur; } -static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) { +static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) { const auto & model = ctx->vision_model; const auto & hparams = model.hparams; GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL); - GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1 - int image_size_width = imgs.entries[0]->nx; - int image_size_height = imgs.entries[0]->ny; + int image_size_width = img.nx; + int image_size_height = img.ny; const int patch_size = hparams.patch_size; const int n_patches_x = image_size_width / patch_size; @@ -772,18 +763,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i } static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { - if (!ctx->has_vision_encoder) { - LOG_ERR("This gguf file seems to have no vision encoder\n"); - return nullptr; - } - const auto & model = ctx->vision_model; const auto & hparams = model.hparams; const int image_size = hparams.image_size; int image_size_width = image_size; int image_size_height = image_size; - if (ctx->has_minicpmv_projector) { + + if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height); image_size_width = load_image_size.width; image_size_height = load_image_size.height; @@ -792,7 +779,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im image_size_height = imgs.entries[0]->ny; } } - else if (ctx->has_qwen2vl_merger) { + + else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { // use the image's native resolution when image is avaible if (is_inf) { // if (imgs->data->nx && imgs->data->ny) { @@ -800,12 +788,13 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im image_size_height = imgs.entries[0]->ny; } } + const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int patches_w = image_size_width / patch_size; const int patches_h = image_size_height / patch_size; const int num_positions = num_patches + (model.class_embedding ? 1 : 0); - const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions; + const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions; const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; @@ -814,7 +803,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im const int batch_size = imgs.entries.size(); - if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { + if (ctx->has_llava_projector + || ctx->proj_type == PROJECTOR_TYPE_MINICPMV + || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { GGML_ASSERT(batch_size == 1); } @@ -835,8 +826,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - if (ctx->has_qwen2vl_merger) { - GGML_ASSERT(image_size_width % (patch_size * 2) == 0); + if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { + GGML_ASSERT(image_size_width % (patch_size * 2) == 0); GGML_ASSERT(image_size_height % (patch_size * 2) == 0); auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); @@ -865,29 +856,26 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im struct ggml_tensor * embeddings = inp; struct ggml_tensor * pos_embed = nullptr; - if (ctx->has_llava_projector) { - // concat class_embeddings and patch_embeddings - if (model.class_embedding) { - embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); - ggml_set_name(embeddings, "embeddings"); - ggml_set_input(embeddings); - embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, - embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); - embeddings = ggml_acc(ctx0, embeddings, inp, - embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); - } + // concat class_embeddings and patch_embeddings + if (model.class_embedding) { + embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); + embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros + embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); + embeddings = ggml_acc(ctx0, embeddings, inp, + embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); } struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); ggml_set_name(positions, "positions"); ggml_set_input(positions); - if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding + if (ctx->proj_type != PROJECTOR_TYPE_QWEN2VL) { // qwen2vl does NOT use learned position embeddings embeddings = ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); } - if (ctx->has_minicpmv_projector) { + if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { int pos_w = image_size_width/patch_size; int pos_h = image_size_height/patch_size; if (ctx->minicpmv_version == 2) { @@ -941,7 +929,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); - if (ctx->has_qwen2vl_merger) { + if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { Q = ggml_rope_multi( ctx0, Q, positions, nullptr, d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); @@ -953,7 +941,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); - if (ctx->has_qwen2vl_merger) { + if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { K = ggml_rope_multi( ctx0, K, positions, nullptr, d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); @@ -1218,106 +1206,98 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im } } // minicpmv projector - else if (ctx->has_minicpmv_projector) - { - if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { - struct ggml_tensor * q = model.mm_model_query; - { // layernorm - q = ggml_norm(ctx0, q, eps); - q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b); - } - struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); - { // layernorm - v = ggml_norm(ctx0, v, eps); - v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b); - } - struct ggml_tensor * k; - { // position - // q = ggml_add(ctx0, q, model.mm_model_pos_embed); - k = ggml_add(ctx0, v, pos_embed); - } - - { // attention - int hidden_size = 4096; - const int d_head = 128; - int n_head = hidden_size/d_head; - int num_query = 96; - if (ctx->minicpmv_version == 2) { - hidden_size = 4096; - n_head = hidden_size/d_head; - num_query = 96; - } - else if (ctx->minicpmv_version == 3) { - hidden_size = 3584; - n_head = hidden_size/d_head; - num_query = 64; - } - else if (ctx->minicpmv_version == 4) { - hidden_size = 3584; - n_head = hidden_size/d_head; - num_query = 64; - } - - struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); - struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); - struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); - // permute - Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size); - Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size); - K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); - K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); - V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); - V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); - V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); - KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size); - - embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b); - } - { // layernorm - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b); - } - embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings); + else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { + struct ggml_tensor * q = model.mm_model_query; + { // layernorm + q = ggml_norm(ctx0, q, eps); + q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b); } - else { - GGML_ASSERT(false); + struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); + { // layernorm + v = ggml_norm(ctx0, v, eps); + v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b); } + struct ggml_tensor * k; + { // position + // q = ggml_add(ctx0, q, model.mm_model_pos_embed); + k = ggml_add(ctx0, v, pos_embed); + } + + { // attention + int hidden_size = 4096; + const int d_head = 128; + int n_head = hidden_size/d_head; + int num_query = 96; + if (ctx->minicpmv_version == 2) { + hidden_size = 4096; + n_head = hidden_size/d_head; + num_query = 96; + } + else if (ctx->minicpmv_version == 3) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } + else if (ctx->minicpmv_version == 4) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } + + struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); + struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); + struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); + // permute + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size); + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size); + + embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b); + } + { // layernorm + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b); + } + embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings); } + // glm projector - else if (ctx->has_glm_projector) { - if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { - size_t gridsz = (size_t)sqrt(embeddings->ne[1]); - embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); - embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); - embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); - embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); - embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); - embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); - //GLU - { - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); - embeddings = ggml_gelu_inplace(ctx0, embeddings); - struct ggml_tensor * x = embeddings; - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); - x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); - embeddings = ggml_silu_inplace(ctx0, embeddings); - embeddings = ggml_mul(ctx0, embeddings,x); - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); - } - } else { - GGML_ABORT("fatal error"); + else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + size_t gridsz = (size_t)sqrt(embeddings->ne[1]); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); + embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); + embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); + embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); + embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); + // GLU + { + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + embeddings = ggml_gelu_inplace(ctx0, embeddings); + struct ggml_tensor * x = embeddings; + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); + x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); + embeddings = ggml_silu_inplace(ctx0, embeddings); + embeddings = ggml_mul(ctx0, embeddings,x); + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); } } - else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + + else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); @@ -1343,11 +1323,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: { - res = clip_image_build_graph_siglip(ctx, imgs); + GGML_ASSERT(imgs.entries.size() == 1); + res = clip_image_build_graph_siglip(ctx, *imgs.entries[0]); } break; case PROJECTOR_TYPE_PIXTRAL: { - res = clip_image_build_graph_pixtral(ctx, imgs); + GGML_ASSERT(imgs.entries.size() == 1); + res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]); } break; default: { @@ -1419,8 +1401,8 @@ struct clip_model_loader { auto & hparams = ctx_clip.vision_model.hparams; // projector type + std::string proj_type; { - std::string proj_type; get_string(KEY_PROJ_TYPE, proj_type, false); if (!proj_type.empty()) { ctx_clip.proj_type = clip_projector_type_from_string(proj_type); @@ -1432,33 +1414,27 @@ struct clip_model_loader { // other hparams { - get_bool(KEY_HAS_TEXT_ENC, ctx_clip.has_text_encoder, false); - get_bool(KEY_HAS_VIS_ENC, ctx_clip.has_vision_encoder, false); - GGML_ASSERT(ctx_clip.has_vision_encoder); - GGML_ASSERT(!ctx_clip.has_text_encoder); - - // legacy keys, use KEY_PROJ_TYPE instead - get_bool(KEY_HAS_LLAVA_PROJ, ctx_clip.has_llava_projector, false); - get_bool(KEY_HAS_MINICPMV_PROJ, ctx_clip.has_minicpmv_projector, false); get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); - get_bool(KEY_HAS_GLM_PROJ, ctx_clip.has_glm_projector, false); - get_bool(KEY_HAS_QWEN2VL_MERGER, ctx_clip.has_qwen2vl_merger, false); - // !!! do NOT extend the list above, use KEY_PROJ_TYPE instead get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false); get_bool(KEY_USE_SILU, ctx_clip.use_silu, false); - get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size); - get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head); - get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate); - get_u32(string_format(KEY_N_BLOCK, "vision"), hparams.n_layer); - get_u32(string_format(KEY_PROJ_DIM, "vision"), hparams.projection_dim); - get_f32(string_format(KEY_LAYER_NORM_EPS, "vision"), hparams.eps); - get_u32(KEY_IMAGE_SIZE, hparams.image_size); - get_u32(KEY_PATCH_SIZE, hparams.patch_size); - get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); + get_u32(KEY_N_EMBD, hparams.hidden_size); + get_u32(KEY_N_HEAD, hparams.n_head); + get_u32(KEY_N_FF, hparams.n_intermediate); + get_u32(KEY_N_BLOCK, hparams.n_layer); + get_u32(KEY_PROJ_DIM, hparams.projection_dim); + get_f32(KEY_LAYER_NORM_EPS, hparams.eps); + get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_PATCH_SIZE, hparams.patch_size); + get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); + ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP + || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM + || ctx_clip.proj_type == PROJECTOR_TYPE_LDP + || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2; + { std::string mm_patch_merge_type; get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false); @@ -1491,32 +1467,56 @@ struct clip_model_loader { for (auto & layer : vision_feature_layer) { hparams.vision_feature_layer.insert(layer); } - // Calculate the deepest feature layer based on hparams and projector type - ctx_clip.max_feature_layer = get_deepest_feature_layer(&ctx_clip); - LOG_INF("%s: text_encoder: %d\n", __func__, ctx_clip.has_text_encoder); - LOG_INF("%s: vision_encoder: %d\n", __func__, ctx_clip.has_vision_encoder); - LOG_INF("%s: llava_projector: %d\n", __func__, ctx_clip.has_llava_projector); - LOG_INF("%s: minicpmv_projector: %d\n", __func__, ctx_clip.has_minicpmv_projector); + // Calculate the deepest feature layer based on hparams and projector type + // NOTE: This is only used by build_graph_legacy() + { + // Get the index of the second to last layer; this is the default for models that have a llava projector + int n_layer = hparams.n_layer - 1; + int deepest_feature_layer = -1; + + if (ctx_clip.proj_type == PROJECTOR_TYPE_MINICPMV + || ctx_clip.proj_type == PROJECTOR_TYPE_GLM_EDGE + || ctx_clip.proj_type == PROJECTOR_TYPE_QWEN2VL) { + n_layer += 1; + } + + // If we set explicit vision feature layers, only go up to the deepest one + // NOTE: only used by granite-vision models for now + for (const auto & feature_layer : hparams.vision_feature_layer) { + if (feature_layer > deepest_feature_layer) { + deepest_feature_layer = feature_layer; + } + } + ctx_clip.max_feature_layer = deepest_feature_layer < 0 ? n_layer : deepest_feature_layer; + } + + // model-specific params + switch (ctx_clip.proj_type) { + case PROJECTOR_TYPE_MINICPMV: + { + if (ctx_clip.minicpmv_version == 0) { + ctx_clip.minicpmv_version = 2; // default to 2 if not set + } + } break; + case PROJECTOR_TYPE_IDEFICS3: + { + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); + } break; + case PROJECTOR_TYPE_PIXTRAL: + { + hparams.rope_theta = 10000.0f; + } break; + default: + break; + } + + LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str()); + LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector); LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version); - LOG_INF("%s: glm_projector: %d\n", __func__, ctx_clip.has_glm_projector); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); } - - // model-specific params - switch (ctx_clip.proj_type) { - case PROJECTOR_TYPE_IDEFICS3: - { - get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); - } break; - case PROJECTOR_TYPE_PIXTRAL: - { - hparams.rope_theta = 10000.0f; - } break; - default: - break; - } } void load_tensors() { @@ -1569,9 +1569,6 @@ struct clip_model_loader { vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false); vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); - if (vision_model.patch_embeddings_1 == nullptr) { - ctx_clip.has_qwen2vl_merger = false; - } vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false); @@ -1669,7 +1666,7 @@ struct clip_model_loader { vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight")); vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias")); } break; - case PROJECTOR_TYPE_RESAMPLER: + case PROJECTOR_TYPE_MINICPMV: { // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); @@ -1702,7 +1699,7 @@ struct clip_model_loader { vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight")); vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); } break; - case PROJECTOR_TYPE_MERGER: + case PROJECTOR_TYPE_QWEN2VL: { vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); @@ -2479,11 +2476,6 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { - if (!ctx->has_vision_encoder) { - LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); - return false; - } - clip_image_size original_size{img->nx, img->ny}; bool pad_to_square = true; auto & params = ctx->vision_model.hparams; @@ -2504,7 +2496,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } return true; } - else if (ctx->has_qwen2vl_merger) { + else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { clip_image_u8 resized; auto patch_size = clip_get_patch_size(ctx) * 2; int nx = ceil((float)img->nx / patch_size) * patch_size; @@ -2518,7 +2510,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); return true; } - else if (ctx->has_glm_projector + else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type == PROJECTOR_TYPE_GEMMA3 || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { clip_image_u8 resized_image; @@ -2646,7 +2638,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { n_patches /= 4; - } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { + } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { if (ctx->minicpmv_version == 2) { n_patches = 96; } @@ -2656,7 +2648,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i else if (ctx->minicpmv_version == 4) { n_patches = 64; } - } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + else { + GGML_ABORT("Unknown minicpmv version"); + } + } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); @@ -2761,11 +2756,6 @@ static std::vector> get_2d_sincos_pos_embed(int embed_dim, co } bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { - if (!ctx->has_vision_encoder) { - LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); - return false; - } - clip_image_f32_batch imgs; clip_image_f32_ptr img_copy(clip_image_f32_init()); *img_copy = *img; @@ -2776,20 +2766,11 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3 bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) { const clip_image_f32_batch & imgs = *imgs_c_ptr; - - if (!ctx->has_vision_encoder) { - LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); - return false; - } - int batch_size = imgs.entries.size(); - if (ctx->has_llava_projector) { - GGML_ASSERT(batch_size == 1); // TODO: support multiple images - } - if (ctx->has_minicpmv_projector) { - GGML_ASSERT(batch_size == 1); - } - if (ctx->has_glm_projector) { + + if (ctx->has_llava_projector + || ctx->proj_type == PROJECTOR_TYPE_MINICPMV + || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { GGML_ASSERT(batch_size == 1); } @@ -2799,21 +2780,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); // set inputs - const auto & model = ctx->vision_model; + const auto & model = ctx->vision_model; const auto & hparams = model.hparams; - // TODO @ngxson : this is ugly, need to refactor later - bool support_dynamic_size = ctx->has_minicpmv_projector - || ctx->has_qwen2vl_merger - || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL; + const int image_size_width = imgs.entries[0]->nx; + const int image_size_height = imgs.entries[0]->ny; - const int image_size = hparams.image_size; - int image_size_width = image_size; - int image_size_height = image_size; - if (support_dynamic_size) { - image_size_width = imgs.entries[0]->nx; - image_size_height = imgs.entries[0]->ny; - } const int patch_size = hparams.patch_size; const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); const int num_positions = num_patches + (model.class_embedding ? 1 : 0); @@ -2839,14 +2811,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima for (size_t i = 0; i < imgs.entries.size(); i++) { const int nx = imgs.entries[i]->nx; const int ny = imgs.entries[i]->ny; - - if (ctx->has_glm_projector - || ctx->has_llava_projector - || ctx->proj_type == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { - GGML_ASSERT(nx == image_size && ny == image_size); - } - const int n = nx * ny; for (int b = 0; b < batch_size; b++) { @@ -2864,13 +2828,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); } - if (ctx->has_minicpmv_projector) { + + if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { { // inspired from siglip: // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - int* positions_data = (int*)malloc(ggml_nbytes(positions)); + std::vector pos_data(ggml_nelements(positions)); + int * data = pos_data.data(); int bucket_coords_h[1024]; int bucket_coords_w[1024]; for (int i = 0; i < pos_h; i++){ @@ -2881,11 +2847,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } for (int i = 0, id = 0; i < pos_h; i++){ for (int j = 0; j < pos_w; j++){ - positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; + data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; } } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); - free(positions_data); + ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions)); } { @@ -2903,30 +2868,28 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima else if (ctx->minicpmv_version == 4) { embed_dim = 3584; } + else { + GGML_ABORT("Unknown minicpmv version"); + } + + // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos? auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); - float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); - for(int i=0;i < pos_w * pos_h; ++i){ - for(int j=0; j < embed_dim; ++j){ - pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j]; + std::vector pos_data(ggml_nelements(pos_embed)); + float * data = pos_data.data(); + for(int i = 0; i < pos_w * pos_h; ++i){ + for(int j = 0; j < embed_dim; ++j){ + data[i * embed_dim + j] = pos_embed_t[i][j]; } } - ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed)); - free(pos_embed_data); + ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed)); } } else { - if (model.class_embedding) { - struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); + // non-minicpmv models - void* zero_mem = malloc(ggml_nbytes(embeddings)); - memset(zero_mem, 0, ggml_nbytes(embeddings)); - ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); - free(zero_mem); - } - - if (ctx->has_qwen2vl_merger) { + if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) { struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); const int pw = image_size_width / patch_size; @@ -2978,6 +2941,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos)); } else { + // llava and other models struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); @@ -2987,7 +2951,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); - if (!ctx->has_glm_projector) { + if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) { struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); // The patches vector is used to get rows to index into the embeds with; // we should skip dim 0 only if we have CLS to avoid going out of bounds @@ -3166,7 +3130,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->vision_model.mm_2_b->ne[0]; case PROJECTOR_TYPE_MLP_NORM: return ctx->vision_model.mm_3_b->ne[0]; - case PROJECTOR_TYPE_RESAMPLER: + case PROJECTOR_TYPE_MINICPMV: if (ctx->minicpmv_version == 2) { return 4096; } else if (ctx->minicpmv_version == 3) { @@ -3174,36 +3138,33 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { } else if (ctx->minicpmv_version == 4) { return 3584; } - break; // Should not happen if version is valid + GGML_ABORT("Unknown minicpmv version"); case PROJECTOR_TYPE_GLM_EDGE: return ctx->vision_model.mm_model_mlp_3_w->ne[1]; - case PROJECTOR_TYPE_MERGER: + case PROJECTOR_TYPE_QWEN2VL: return ctx->vision_model.mm_1_b->ne[0]; case PROJECTOR_TYPE_GEMMA3: return ctx->vision_model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: return ctx->vision_model.projection->ne[1]; default: - break; // Fall through to throw + GGML_ABORT("Unknown projector type"); } - - std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; - throw std::runtime_error(string_format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); } int clip_is_minicpmv(const struct clip_ctx * ctx) { - if (ctx->has_minicpmv_projector) { + if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { return ctx->minicpmv_version; } return 0; } bool clip_is_glm(const struct clip_ctx * ctx) { - return ctx->has_glm_projector; + return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE; } bool clip_is_qwen2vl(const struct clip_ctx * ctx) { - return ctx->has_qwen2vl_merger; + return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL; } bool clip_is_llava(const struct clip_ctx * ctx) { @@ -3214,29 +3175,6 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) { return ctx->proj_type == PROJECTOR_TYPE_GEMMA3; } -// Determine the number of encoder layers to iterate over -int get_deepest_feature_layer(const struct clip_ctx * ctx) { - // Get the index of the second to last layer; this is the - // default for models that have a llava projector - const auto & hparams = ctx->vision_model.hparams; - int n_layer = hparams.n_layer - 1; - int deepest_feature_layer = -1; - - // Handle other projectors; incrementing here indicates that we - // should use the last encoder layer for the vision features. - if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { - n_layer += 1; - } - - // If we set explicit vision feature layers, only go up to the deepest one - for (const auto & feature_layer : hparams.vision_feature_layer) { - if (feature_layer > deepest_feature_layer) { - deepest_feature_layer = feature_layer; - } - } - return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer; -} - bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { clip_image_f32 clip_img; clip_img.buf.resize(h * w * 3); diff --git a/examples/llava/clip.h b/examples/llava/clip.h index 5fc45d3e2..6ba42ad89 100644 --- a/examples/llava/clip.h +++ b/examples/llava/clip.h @@ -114,8 +114,6 @@ CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); CLIP_API bool clip_is_llava(const struct clip_ctx * ctx); CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx); -CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx); - CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);