mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 12:05:03 +00:00
llama : Add Gemma 3 support (+ experimental vision capability) (#12343)
* llama : Add Gemma 3 text-only support * fix python coding style * fix compile on ubuntu * python: fix style * fix ubuntu compile * fix build on ubuntu (again) * fix ubuntu build, finally * clip : Experimental support for Gemma 3 vision (#12344) * clip : Experimental support for Gemma 3 vision * fix build * PRId64
This commit is contained in:
@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
|
||||
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
||||
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
|
||||
#define TN_IMAGE_NEWLINE "model.image_newline"
|
||||
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
|
||||
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
|
||||
|
||||
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
|
||||
#define TN_MINICPMV_QUERY "resampler.query"
|
||||
@ -162,6 +164,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_RESAMPLER,
|
||||
PROJECTOR_TYPE_GLM_EDGE,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_GEMMA3,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -172,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
|
||||
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
};
|
||||
|
||||
|
||||
@ -298,7 +302,7 @@ static projector_type clip_projector_type_from_string(const std::string & name)
|
||||
return kv.first;
|
||||
}
|
||||
}
|
||||
return PROJECTOR_TYPE_UNKNOWN;
|
||||
throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
|
||||
}
|
||||
|
||||
#ifdef CLIP_DEBUG_FUNCTIONS
|
||||
@ -555,6 +559,10 @@ struct clip_vision_model {
|
||||
struct ggml_tensor * mm_model_ln_kv_b;
|
||||
struct ggml_tensor * mm_model_ln_post_w;
|
||||
struct ggml_tensor * mm_model_ln_post_b;
|
||||
|
||||
// gemma3
|
||||
struct ggml_tensor * mm_input_proj_w;
|
||||
struct ggml_tensor * mm_soft_emb_norm_w;
|
||||
};
|
||||
|
||||
struct clip_ctx {
|
||||
@ -569,7 +577,7 @@ struct clip_ctx {
|
||||
struct clip_vision_model vision_model;
|
||||
projector_type proj_type = PROJECTOR_TYPE_MLP;
|
||||
|
||||
int32_t max_feature_layer;
|
||||
int32_t max_feature_layer; // unused in newer models like gemma3
|
||||
float image_mean[3];
|
||||
float image_std[3];
|
||||
bool use_gelu = false;
|
||||
@ -595,7 +603,7 @@ struct clip_ctx {
|
||||
|
||||
ggml_backend_sched_ptr sched;
|
||||
|
||||
struct clip_image_size * load_image_size;
|
||||
struct clip_image_size * load_image_size = nullptr;
|
||||
|
||||
clip_ctx(clip_context_params & ctx_params) {
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
@ -631,7 +639,159 @@ struct clip_ctx {
|
||||
}
|
||||
};
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
|
||||
const auto & model = ctx->vision_model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int image_size = hparams.image_size;
|
||||
int image_size_width = image_size;
|
||||
int image_size_height = image_size;
|
||||
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const float eps = hparams.eps;
|
||||
|
||||
GGML_ASSERT(imgs->size == 1); // batch_size == 1
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx->buf_compute_meta.size(),
|
||||
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
// input raw
|
||||
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
|
||||
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
||||
inp = ggml_add(ctx0, inp, model.patch_bias);
|
||||
|
||||
// position embeddings
|
||||
struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
|
||||
|
||||
// loop over layers
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
|
||||
|
||||
// layernorm1
|
||||
{
|
||||
cur = ggml_norm(ctx0, cur, eps);
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
|
||||
|
||||
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
|
||||
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
|
||||
|
||||
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
|
||||
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
|
||||
|
||||
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
|
||||
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
|
||||
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
|
||||
KQ = ggml_soft_max_inplace(ctx0, KQ);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
|
||||
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
|
||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
|
||||
}
|
||||
|
||||
// attention output
|
||||
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
|
||||
|
||||
// re-add the layer input, e.g., residual
|
||||
cur = ggml_add(ctx0, cur, embeddings);
|
||||
|
||||
embeddings = cur; // embeddings = residual, cur = hidden_states
|
||||
|
||||
// layernorm2
|
||||
{
|
||||
cur = ggml_norm(ctx0, cur, eps);
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
|
||||
}
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
||||
|
||||
// siglip uses gelu
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
|
||||
|
||||
// residual 2
|
||||
cur = ggml_add(ctx0, embeddings, cur);
|
||||
|
||||
embeddings = cur;
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
if (ctx->has_post_norm) {
|
||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "post_ln");
|
||||
|
||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
|
||||
}
|
||||
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
const int batch_size = 1;
|
||||
const int mm_tokens_per_image = 256; // default value for gemma3
|
||||
const int tokens_per_side = sqrt(mm_tokens_per_image);
|
||||
const int patches_per_image = sqrt(num_patches);
|
||||
const int kernel_size = patches_per_image / tokens_per_side;
|
||||
|
||||
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
|
||||
embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
|
||||
|
||||
// doing a pool2d to reduce the number of output tokens to 256
|
||||
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
|
||||
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
|
||||
|
||||
// apply norm before projection
|
||||
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
|
||||
embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
|
||||
|
||||
// apply projection
|
||||
embeddings = ggml_mul_mat(ctx0,
|
||||
ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
|
||||
embeddings);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
if (!ctx->has_vision_encoder) {
|
||||
LOG_ERR("This gguf file seems to have no vision encoder\n");
|
||||
return nullptr;
|
||||
@ -1177,7 +1337,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
} else {
|
||||
GGML_ABORT("fatel error");
|
||||
}
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
@ -1199,6 +1360,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
return gf;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
return clip_image_build_graph_siglip(ctx, imgs);
|
||||
} else {
|
||||
// TODO: we should have one build_* function per model
|
||||
return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
|
||||
}
|
||||
}
|
||||
|
||||
// read and create ggml_context containing the tensors and their data
|
||||
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
return clip_init(fname, clip_context_params{
|
||||
@ -1358,8 +1528,12 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
|
||||
GGML_ASSERT(new_clip->has_vision_encoder);
|
||||
GGML_ASSERT(!new_clip->has_text_encoder);
|
||||
|
||||
idx = get_key_idx(ctx, KEY_USE_GELU);
|
||||
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
|
||||
try {
|
||||
idx = get_key_idx(ctx, KEY_USE_GELU);
|
||||
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
|
||||
} catch (std::runtime_error & /*e*/) {
|
||||
new_clip->use_gelu = false;
|
||||
}
|
||||
|
||||
try {
|
||||
idx = get_key_idx(ctx, KEY_USE_SILU);
|
||||
@ -1567,11 +1741,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
|
||||
}
|
||||
|
||||
try {
|
||||
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
} catch(const std::exception& /*e*/) {
|
||||
vision_model.patch_embeddings_0 = nullptr;
|
||||
}
|
||||
|
||||
try {
|
||||
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
|
||||
} catch(const std::exception& /*e*/) {
|
||||
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
|
||||
vision_model.position_embeddings = nullptr;
|
||||
}
|
||||
|
||||
try {
|
||||
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
|
||||
} catch(const std::exception& /*e*/) {
|
||||
@ -1682,6 +1862,10 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
|
||||
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
vision_model.mm_input_proj_w = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
|
||||
vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
|
||||
}
|
||||
else {
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
@ -2223,7 +2407,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ctx->has_glm_projector) {
|
||||
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
res_imgs->size = 1;
|
||||
res_imgs->data = new clip_image_f32[res_imgs->size];
|
||||
clip_image_u8 resized_image;
|
||||
@ -2748,6 +2932,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
// do nothing
|
||||
}
|
||||
else {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
@ -2960,6 +3147,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
return ctx->vision_model.mm_input_proj_w->ne[0];
|
||||
}
|
||||
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
|
Reference in New Issue
Block a user