llama : Add Gemma 3 support (+ experimental vision capability) (#12343)

* llama : Add Gemma 3 text-only support

* fix python coding style

* fix compile on ubuntu

* python: fix style

* fix ubuntu compile

* fix build on ubuntu (again)

* fix ubuntu build, finally

* clip : Experimental support for Gemma 3 vision (#12344)

* clip : Experimental support for Gemma 3 vision

* fix build

* PRId64
This commit is contained in:
Xuan-Son Nguyen
2025-03-12 09:30:24 +01:00
committed by GitHub
parent bf69cfe62f
commit 7841fc723e
11 changed files with 1202 additions and 10 deletions

View File

@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY "resampler.query"
@ -162,6 +164,7 @@ enum projector_type {
PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_UNKNOWN,
};
@ -172,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
};
@ -298,7 +302,7 @@ static projector_type clip_projector_type_from_string(const std::string & name)
return kv.first;
}
}
return PROJECTOR_TYPE_UNKNOWN;
throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
}
#ifdef CLIP_DEBUG_FUNCTIONS
@ -555,6 +559,10 @@ struct clip_vision_model {
struct ggml_tensor * mm_model_ln_kv_b;
struct ggml_tensor * mm_model_ln_post_w;
struct ggml_tensor * mm_model_ln_post_b;
// gemma3
struct ggml_tensor * mm_input_proj_w;
struct ggml_tensor * mm_soft_emb_norm_w;
};
struct clip_ctx {
@ -569,7 +577,7 @@ struct clip_ctx {
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
int32_t max_feature_layer;
int32_t max_feature_layer; // unused in newer models like gemma3
float image_mean[3];
float image_std[3];
bool use_gelu = false;
@ -595,7 +603,7 @@ struct clip_ctx {
ggml_backend_sched_ptr sched;
struct clip_image_size * load_image_size;
struct clip_image_size * load_image_size = nullptr;
clip_ctx(clip_context_params & ctx_params) {
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
@ -631,7 +639,159 @@ struct clip_ctx {
}
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size = hparams.image_size;
int image_size_width = image_size;
int image_size_height = image_size;
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
GGML_ASSERT(imgs->size == 1); // batch_size == 1
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// input raw
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
inp = ggml_add(ctx0, inp, model.patch_bias);
// position embeddings
struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
// loop over layers
for (int il = 0; il < n_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
// layernorm1
{
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
}
// self-attention
{
struct ggml_tensor * Q =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
struct ggml_tensor * K =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
struct ggml_tensor * V =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
KQ = ggml_soft_max_inplace(ctx0, KQ);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
}
// attention output
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, embeddings);
embeddings = cur; // embeddings = residual, cur = hidden_states
// layernorm2
{
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
}
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
// siglip uses gelu
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// post-layernorm
if (ctx->has_post_norm) {
embeddings = ggml_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "post_ln");
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
const int batch_size = 1;
const int mm_tokens_per_image = 256; // default value for gemma3
const int tokens_per_side = sqrt(mm_tokens_per_image);
const int patches_per_image = sqrt(num_patches);
const int kernel_size = patches_per_image / tokens_per_side;
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
// doing a pool2d to reduce the number of output tokens to 256
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
// apply norm before projection
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
// apply projection
embeddings = ggml_mul_mat(ctx0,
ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
embeddings);
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
ggml_free(ctx0);
return gf;
}
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
return nullptr;
@ -1177,7 +1337,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
} else {
GGML_ABORT("fatel error");
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
}
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@ -1199,6 +1360,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
return gf;
}
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
return clip_image_build_graph_siglip(ctx, imgs);
} else {
// TODO: we should have one build_* function per model
return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
}
}
// read and create ggml_context containing the tensors and their data
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
return clip_init(fname, clip_context_params{
@ -1358,8 +1528,12 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
GGML_ASSERT(new_clip->has_vision_encoder);
GGML_ASSERT(!new_clip->has_text_encoder);
idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
try {
idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
} catch (std::runtime_error & /*e*/) {
new_clip->use_gelu = false;
}
try {
idx = get_key_idx(ctx, KEY_USE_SILU);
@ -1567,11 +1741,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
}
try {
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
} catch(const std::exception& /*e*/) {
vision_model.patch_embeddings_0 = nullptr;
}
try {
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
} catch(const std::exception& /*e*/) {
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
vision_model.position_embeddings = nullptr;
}
try {
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
} catch(const std::exception& /*e*/) {
@ -1682,6 +1862,10 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
}
else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
vision_model.mm_input_proj_w = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
}
else {
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
@ -2223,7 +2407,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
return true;
}
if (ctx->has_glm_projector) {
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
res_imgs->size = 1;
res_imgs->data = new clip_image_f32[res_imgs->size];
clip_image_u8 resized_image;
@ -2748,6 +2932,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
}
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
// do nothing
}
else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
@ -2960,6 +3147,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
return ctx->vision_model.mm_input_proj_w->ne[0];
}
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));