llama : Add Gemma 3 support (+ experimental vision capability) (#12343)

* llama : Add Gemma 3 text-only support

* fix python coding style

* fix compile on ubuntu

* python: fix style

* fix ubuntu compile

* fix build on ubuntu (again)

* fix ubuntu build, finally

* clip : Experimental support for Gemma 3 vision (#12344)

* clip : Experimental support for Gemma 3 vision

* fix build

* PRId64
This commit is contained in:
Xuan-Son Nguyen
2025-03-12 09:30:24 +01:00
committed by GitHub
parent bf69cfe62f
commit 7841fc723e
11 changed files with 1202 additions and 10 deletions

View File

@ -861,6 +861,9 @@ class Model:
for token_id, token_data in added_tokens_decoder.items():
token_id = int(token_id)
token: str = token_data["content"]
if token_id >= vocab_size:
logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
continue
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
if tokens[token_id] != token.encode("utf-8"):
logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}')
@ -3322,6 +3325,83 @@ class Gemma2Model(Model):
return [(self.map_tensor_name(name), data_torch)]
@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(Model):
model_arch = gguf.MODEL_ARCH.GEMMA3
has_vision: bool = False
# we need to merge the text_config into the root level of hparams
def __init__(self, *args, **kwargs):
hparams = Model.load_hparams(kwargs["dir_model"])
if "text_config" in hparams:
hparams = {**hparams, **hparams["text_config"]}
kwargs["hparams"] = hparams
super().__init__(*args, **kwargs)
if "vision_config" in hparams:
logger.info("Has vision encoder, but it will be ignored")
self.has_vision = True
def write(self):
super().write()
if self.has_vision:
logger.info("NOTE: this script only convert the language model to GGUF")
logger.info(" for the vision model, please use gemma3_convert_encoder_to_gguf.py")
def set_vocab(self):
self._set_vocab_sentencepiece()
self.gguf_writer.add_add_space_prefix(False)
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]
# some default values are not specified in the hparams
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
# both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
assert hparams.get("attn_logit_softcapping") is None
assert hparams.get("final_logit_softcapping") is None
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
if hparams.get("rope_scaling") is not None:
assert hparams["rope_scaling"]["rope_type"] == "linear"
# important: this rope_scaling is only applied for global layers, and not used by 1B model
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("language_model."):
name = name.replace("language_model.", "")
elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
# ignore vision tensors
return []
# remove OOV (out-of-vocabulary) rows in token_embd
if "embed_tokens.weight" in name:
vocab = self._create_vocab_sentencepiece()
tokens = vocab[0]
data_torch = data_torch[:len(tokens)]
# ref code in Gemma3RMSNorm
# output = output * (1.0 + self.weight.float())
if name.endswith("norm.weight"):
data_torch = data_torch + 1
return [(self.map_tensor_name(name), data_torch)]
@Model.register("Starcoder2ForCausalLM")
class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2

View File

@ -51,6 +51,13 @@ install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
set(TARGET llama-gemma3-cli)
add_executable(${TARGET} gemma3-cli.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
set(TARGET llama-llava-clip-quantize-cli)
add_executable(${TARGET} clip-quantize-cli.cpp)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)

View File

@ -0,0 +1,30 @@
# Gemma 3 vision
> [!IMPORTANT]
>
> This is very experimental, only used for demo purpose.
## How to get mmproj.gguf?
```bash
cd gemma-3-4b-it
python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
# output file is mmproj.gguf
```
## How to run it?
What you need:
- The text model GGUF, can be converted using `convert_hf_to_gguf.py`
- The mmproj file from step above
- An image file
```bash
# build
cmake -B build
cmake --build build --target llama-gemma3-cli
# run it
./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
```

View File

@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
#define TN_MINICPMV_QUERY "resampler.query"
@ -162,6 +164,7 @@ enum projector_type {
PROJECTOR_TYPE_RESAMPLER,
PROJECTOR_TYPE_GLM_EDGE,
PROJECTOR_TYPE_MERGER,
PROJECTOR_TYPE_GEMMA3,
PROJECTOR_TYPE_UNKNOWN,
};
@ -172,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
};
@ -298,7 +302,7 @@ static projector_type clip_projector_type_from_string(const std::string & name)
return kv.first;
}
}
return PROJECTOR_TYPE_UNKNOWN;
throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
}
#ifdef CLIP_DEBUG_FUNCTIONS
@ -555,6 +559,10 @@ struct clip_vision_model {
struct ggml_tensor * mm_model_ln_kv_b;
struct ggml_tensor * mm_model_ln_post_w;
struct ggml_tensor * mm_model_ln_post_b;
// gemma3
struct ggml_tensor * mm_input_proj_w;
struct ggml_tensor * mm_soft_emb_norm_w;
};
struct clip_ctx {
@ -569,7 +577,7 @@ struct clip_ctx {
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
int32_t max_feature_layer;
int32_t max_feature_layer; // unused in newer models like gemma3
float image_mean[3];
float image_std[3];
bool use_gelu = false;
@ -595,7 +603,7 @@ struct clip_ctx {
ggml_backend_sched_ptr sched;
struct clip_image_size * load_image_size;
struct clip_image_size * load_image_size = nullptr;
clip_ctx(clip_context_params & ctx_params) {
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
@ -631,7 +639,159 @@ struct clip_ctx {
}
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size = hparams.image_size;
int image_size_width = image_size;
int image_size_height = image_size;
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int n_layer = hparams.n_layer;
const float eps = hparams.eps;
GGML_ASSERT(imgs->size == 1); // batch_size == 1
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
/*.no_alloc =*/ true,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// input raw
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
ggml_set_name(inp_raw, "inp_raw");
ggml_set_input(inp_raw);
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
inp = ggml_add(ctx0, inp, model.patch_bias);
// position embeddings
struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
// loop over layers
for (int il = 0; il < n_layer; il++) {
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
// layernorm1
{
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
}
// self-attention
{
struct ggml_tensor * Q =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
struct ggml_tensor * K =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
struct ggml_tensor * V =
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
KQ = ggml_soft_max_inplace(ctx0, KQ);
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
}
// attention output
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
// re-add the layer input, e.g., residual
cur = ggml_add(ctx0, cur, embeddings);
embeddings = cur; // embeddings = residual, cur = hidden_states
// layernorm2
{
cur = ggml_norm(ctx0, cur, eps);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
}
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
// siglip uses gelu
cur = ggml_gelu(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
// residual 2
cur = ggml_add(ctx0, embeddings, cur);
embeddings = cur;
}
// post-layernorm
if (ctx->has_post_norm) {
embeddings = ggml_norm(ctx0, embeddings, eps);
ggml_set_name(embeddings, "post_ln");
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
}
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
const int batch_size = 1;
const int mm_tokens_per_image = 256; // default value for gemma3
const int tokens_per_side = sqrt(mm_tokens_per_image);
const int patches_per_image = sqrt(num_patches);
const int kernel_size = patches_per_image / tokens_per_side;
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
// doing a pool2d to reduce the number of output tokens to 256
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
// apply norm before projection
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
// apply projection
embeddings = ggml_mul_mat(ctx0,
ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
embeddings);
}
// build the graph
ggml_build_forward_expand(gf, embeddings);
ggml_free(ctx0);
return gf;
}
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
return nullptr;
@ -1177,7 +1337,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
} else {
GGML_ABORT("fatel error");
}
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
}
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@ -1199,6 +1360,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
return gf;
}
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
return clip_image_build_graph_siglip(ctx, imgs);
} else {
// TODO: we should have one build_* function per model
return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
}
}
// read and create ggml_context containing the tensors and their data
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
return clip_init(fname, clip_context_params{
@ -1358,8 +1528,12 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
GGML_ASSERT(new_clip->has_vision_encoder);
GGML_ASSERT(!new_clip->has_text_encoder);
idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
try {
idx = get_key_idx(ctx, KEY_USE_GELU);
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
} catch (std::runtime_error & /*e*/) {
new_clip->use_gelu = false;
}
try {
idx = get_key_idx(ctx, KEY_USE_SILU);
@ -1567,11 +1741,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
}
try {
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
} catch(const std::exception& /*e*/) {
vision_model.patch_embeddings_0 = nullptr;
}
try {
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
} catch(const std::exception& /*e*/) {
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
vision_model.position_embeddings = nullptr;
}
try {
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
} catch(const std::exception& /*e*/) {
@ -1682,6 +1862,10 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
}
else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
vision_model.mm_input_proj_w = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
}
else {
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
@ -2223,7 +2407,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
return true;
}
if (ctx->has_glm_projector) {
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
res_imgs->size = 1;
res_imgs->data = new clip_image_f32[res_imgs->size];
clip_image_u8 resized_image;
@ -2748,6 +2932,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
}
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
// do nothing
}
else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
@ -2960,6 +3147,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
return ctx->vision_model.mm_1_b->ne[0];
}
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
return ctx->vision_model.mm_input_proj_w->ne[0];
}
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));

View File

@ -0,0 +1,341 @@
#include "arg.h"
#include "log.h"
#include "common.h"
#include "sampling.h"
#include "clip.h"
#include "stb_image.h"
#include "llama.h"
#include "ggml.h"
#include "console.h"
#include <vector>
#include <limits.h>
#include <inttypes.h>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <signal.h>
#endif
static bool g_is_generating = false;
/**
* Please note that this is NOT a production-ready stuff.
* It is a playground for trying Gemma 3 vision capabilities.
* For contributors: please keep this code simple and easy to understand.
*/
static void show_additional_info(int /*argc*/, char ** argv) {
LOG(
"Experimental CLI for using Gemma 3 vision model\n\n"
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
" -m and --mmproj are required\n"
" --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
argv[0]
);
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
if (g_is_generating) {
g_is_generating = false;
} else {
console::cleanup();
LOG("\nInterrupted by user\n");
_exit(130);
}
}
}
#endif
struct gemma3_context {
struct clip_ctx * ctx_clip = NULL;
common_init_result llama_init;
llama_model * model;
llama_context * lctx;
const llama_vocab * vocab;
llama_batch batch;
int n_threads = 1;
llama_pos n_past = 0;
gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
model = llama_init.model.get();
lctx = llama_init.context.get();
vocab = llama_model_get_vocab(model);
n_threads = params.cpuparams.n_threads;
batch = llama_batch_init(params.n_batch, 0, 1);
init_clip_model(params);
}
void init_clip_model(common_params & params) {
const char * clip_path = params.mmproj.c_str();
ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
}
~gemma3_context() {
clip_free(ctx_clip);
}
};
struct decode_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};
static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
common_batch_clear(ctx.batch);
for (llama_token & t : tokens) {
common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
}
if (logits_last) {
ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
}
// LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
if (llama_decode(ctx.lctx, ctx.batch)) {
LOG_ERR("Failed to decode text\n");
return 1;
}
return 0;
}
static int eval_image(gemma3_context & ctx, std::string & fname) {
std::vector<float> image_embd_v;
int n_embd = llama_model_n_embd(ctx.model);
int n_tokens = 256;
image_embd_v.resize(n_tokens * n_embd);
bool ok;
struct clip_image_u8 * img_u8 = clip_image_u8_init();
ok = clip_image_load_from_file(fname.c_str(), img_u8);
if (!ok) {
LOG_ERR("Unable to load image %s\n", fname.c_str());
clip_image_u8_free(img_u8);
return 2; // non-fatal error
}
clip_image_f32_batch batch_f32;
ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
if (!ok) {
LOG_ERR("Unable to preprocess image\n");
clip_image_f32_batch_free(&batch_f32);
clip_image_u8_free(img_u8);
return 1;
}
int64_t t0 = ggml_time_ms();
LOG("Encoding image %s\n", fname.c_str());
ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
if (!ok) {
LOG_ERR("Unable to encode image\n");
clip_image_f32_batch_free(&batch_f32);
clip_image_u8_free(img_u8);
return 1;
}
LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
clip_image_f32_batch_free(&batch_f32);
clip_image_u8_free(img_u8);
// decode image embeddings
int64_t t1 = ggml_time_ms();
eval_text(ctx, "<start_of_image>");
llama_set_causal_attn(ctx.lctx, false);
decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
if (llama_decode(ctx.lctx, batch_img.batch)) {
LOG_ERR("failed to decode image\n");
return 1;
}
ctx.n_past += n_tokens;
llama_set_causal_attn(ctx.lctx, true);
eval_text(ctx, "<end_of_image>");
LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
return 0;
}
static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
for (int i = 0; i < n_predict; i++) {
if (i > n_predict || !g_is_generating) {
printf("\n");
break;
}
llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
common_sampler_accept(smpl, token_id, true);
if (llama_vocab_is_eog(ctx.vocab, token_id)) {
printf("\n");
break; // end of generation
}
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
fflush(stdout);
// eval the token
common_batch_clear(ctx.batch);
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
if (llama_decode(ctx.lctx, ctx.batch)) {
LOG_ERR("failed to decode token\n");
return 1;
}
}
return 0;
}
int main(int argc, char ** argv) {
ggml_time_init();
common_params params;
params.sampling.temp = 0.2; // lower temp by default for better quality
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
return 1;
}
common_init();
if (params.mmproj.empty()) {
show_additional_info(argc, argv);
return 1;
}
gemma3_context ctx(params);
printf("%s: %s\n", __func__, params.model.c_str());
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
// ctrl+C handling
{
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = sigint_handler;
sigemptyset (&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
sigaction(SIGINT, &sigint_action, NULL);
#elif defined (_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
};
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
}
if (eval_text(ctx, "<bos>")) {
return 1;
}
if (is_single_turn) {
g_is_generating = true;
if (eval_text(ctx, "<start_of_turn>user\n")) {
return 1;
}
for (auto & fname : params.image) {
if (eval_image(ctx, fname)) {
return 1;
}
}
if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
return 1;
}
if (generate_response(ctx, smpl, n_predict)) {
return 1;
}
} else {
LOG("\n Running in chat mode, available commands:");
LOG("\n /image <path> load an image");
LOG("\n /clear clear the chat history");
LOG("\n /quit or /exit exit the program");
LOG("\n");
if (eval_text(ctx, "<start_of_turn>user\n")) {
return 1;
}
while (true) {
g_is_generating = false;
LOG("\n> ");
console::set_display(console::user_input);
std::string line;
console::readline(line, false);
console::set_display(console::reset);
line = string_strip(line);
if (line.empty()) {
continue;
}
if (line == "/quit" || line == "/exit") {
break;
}
if (line == "/clear") {
ctx.n_past = 0;
llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
LOG("Chat history cleared\n\n");
continue;
}
g_is_generating = true;
if (line.find("/image") == 0) {
std::string image = line.substr(7);
int res = eval_image(ctx, image);
if (res == 2) {
continue; // image not found
}
if (res) {
return 1;
}
continue;
}
if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
return 1;
}
if (generate_response(ctx, smpl, n_predict)) {
return 1;
}
if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
return 1;
}
}
}
return 0;
}

View File

@ -0,0 +1,307 @@
import gguf
import argparse
import logging
import sys
import torch
import json
import os
import numpy as np
from typing import cast, ContextManager, Any, Iterator
from pathlib import Path
from torch import Tensor
logger = logging.getLogger("gemma3-mmproj")
# (copied from convert_hf_to_gguf.py)
# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
# to keep the type-checker happy
dtype: torch.dtype
shape: torch.Size
# only used when converting a torch.Tensor to a np.ndarray
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
}
# used for safetensors slices
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
_dtype_str_map: dict[str, torch.dtype] = {
"F64": torch.float64,
"F32": torch.float32,
"BF16": torch.bfloat16,
"F16": torch.float16,
# "U64": torch.uint64,
"I64": torch.int64,
# "U32": torch.uint32,
"I32": torch.int32,
# "U16": torch.uint16,
"I16": torch.int16,
"U8": torch.uint8,
"I8": torch.int8,
"BOOL": torch.bool,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}
def numpy(self) -> gguf.LazyNumpyTensor:
dtype = self._dtype_map[self.dtype]
return gguf.LazyNumpyTensor(
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
args=(self,),
func=(lambda s: s.numpy())
)
@classmethod
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
return torch.empty(size=shape, dtype=dtype, device="meta")
@classmethod
def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
dtype = cls._dtype_str_map[st_slice.get_dtype()]
shape: tuple[int, ...] = tuple(st_slice.get_shape())
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
return cast(torch.Tensor, lazy)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
del types # unused
if kwargs is None:
kwargs = {}
if func is torch.Tensor.numpy:
return args[0].numpy()
return cls._wrap_fn(func)(*args, **kwargs)
class Gemma3VisionTower:
hparams: dict
gguf_writer: gguf.GGUFWriter
fname_out: Path
ftype: gguf.LlamaFileType
@staticmethod
def load_hparams(dir_model: Path):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
part_names: list[str] = []
for filename in os.listdir(dir_model):
if filename.startswith(prefix) and filename.endswith(suffix):
part_names.append(filename)
part_names.sort()
return part_names
def __init__(self,
dir_model: Path,
fname_out: Path,
ftype: gguf.LlamaFileType,
is_big_endian: bool,):
hparams = Gemma3VisionTower.load_hparams(dir_model)
self.hparams = hparams
self.fname_out = fname_out
self.ftype = ftype
endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)
text_config = hparams["text_config"]
vision_config = hparams["vision_config"]
assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
assert text_config is not None
assert vision_config is not None
self.gguf_writer.add_string ("clip.projector_type", "gemma3")
self.gguf_writer.add_bool ("clip.has_text_encoder", False)
self.gguf_writer.add_bool ("clip.has_vision_encoder", True)
self.gguf_writer.add_bool ("clip.has_llava_projector", False) # legacy
self.gguf_writer.add_uint32 ("clip.vision.image_size", vision_config["image_size"])
self.gguf_writer.add_uint32 ("clip.vision.patch_size", vision_config["patch_size"])
self.gguf_writer.add_uint32 ("clip.vision.embedding_length", vision_config["hidden_size"])
self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", vision_config["intermediate_size"])
self.gguf_writer.add_uint32 ("clip.vision.projection_dim", text_config["hidden_size"])
self.gguf_writer.add_uint32 ("clip.vision.block_count", vision_config["num_hidden_layers"])
self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
# default values taken from HF tranformers code
self.gguf_writer.add_array ("clip.vision.image_mean", [0.5, 0.5, 0.5])
self.gguf_writer.add_array ("clip.vision.image_std", [0.5, 0.5, 0.5])
self.gguf_writer.add_bool ("clip.use_gelu", True)
# load tensors
for name, data_torch in self.get_tensors(dir_model):
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
self.add_tensor(name, data_torch)
def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
tensor_names_from_parts: set[str] = set()
for part_name in part_names:
logger.info(f"gguf: loading model part '{part_name}'")
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
with ctx as model_part:
tensor_names_from_parts.update(model_part.keys())
for name in model_part.keys():
data = model_part.get_slice(name)
data = LazyTorchTensor.from_safetensors_slice(data)
yield name, data
def add_tensor(self, name: str, data_torch: Tensor):
is_1d = len(data_torch.shape) == 1
is_embd = ".embeddings." in name
old_dtype = data_torch.dtype
can_quantize = not is_1d and not is_embd
data_qtype = gguf.GGMLQuantizationType.F32
# this is to support old checkpoint
# TODO: remove this when we have the final model
name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
name = name.replace("multimodal_projector.", "multi_modal_projector.")
# filter only vision tensors
if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
return
# prefix
name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
name = name.replace("vision_tower.vision_model.", "v.")
# projector and input embd
name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
name = name.replace(".embeddings.position_embedding.", ".position_embd.")
name = name.replace(
"multi_modal_projector.mm_input_projection_weight",
"mm.input_projection.weight"
)
name = name.replace(
"multi_modal_projector.mm_soft_emb_norm.weight",
"mm.soft_emb_norm.weight"
)
name = name.replace("post_layernorm.", "post_ln.")
# each block
name = name.replace(".self_attn.k_proj.", ".attn_k.")
name = name.replace(".self_attn.v_proj.", ".attn_v.")
name = name.replace(".self_attn.q_proj.", ".attn_q.")
name = name.replace(".self_attn.out_proj.", ".attn_out.")
name = name.replace(".layer_norm1.", ".ln1.")
name = name.replace(".layer_norm2.", ".ln2.")
name = name.replace(".mlp.fc1.", ".ffn_down.")
name = name.replace(".mlp.fc2.", ".ffn_up.")
if can_quantize:
if self.ftype == gguf.LlamaFileType.ALL_F32:
data_qtype = gguf.GGMLQuantizationType.F32
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
data_qtype = gguf.GGMLQuantizationType.F16
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
data_qtype = gguf.GGMLQuantizationType.Q8_0
else:
raise ValueError(f"Unsupported file type: {self.ftype}")
# corrent norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
# the other norm values are part of SigLIP model, and they are already correct
# ref code: Gemma3RMSNorm
if "soft_emb_norm.weight" in name:
logger.info(f"Correcting norm value for '{name}'")
data_torch = data_torch + 1
data = data_torch.numpy()
try:
data = gguf.quants.quantize(data, data_qtype)
except Exception as e:
logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
def write(self):
self.gguf_writer.write_header_to_file(path=self.fname_out)
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close()
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert Gemma 3 vision tower safetensors to GGUF format",)
parser.add_argument(
"--outfile", type=Path, default="mmproj.gguf",
help="path to write to",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
help="output format",
)
parser.add_argument(
"--bigendian", action="store_true",
help="model is executed on big endian machine",
)
parser.add_argument(
"model", type=Path,
help="directory containing model file",
nargs="?",
)
parser.add_argument(
"--verbose", action="store_true",
help="increase output verbosity",
)
args = parser.parse_args()
if args.model is None:
parser.error("the following arguments are required: model")
return args
def main() -> None:
args = parse_args()
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
dir_model = args.model
if not dir_model.is_dir():
logger.error(f'Error: {args.model} is not a directory')
sys.exit(1)
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
}
logger.info(f"Loading model: {dir_model.name}")
with torch.inference_mode():
gemma3_vision_tower = Gemma3VisionTower(
dir_model=dir_model,
fname_out=args.outfile,
ftype=ftype_map[args.outtype],
is_big_endian=args.bigendian,
)
gemma3_vision_tower.write()
if __name__ == '__main__':
main()

View File

@ -253,6 +253,7 @@ class MODEL_ARCH(IntEnum):
MINICPM3 = auto()
GEMMA = auto()
GEMMA2 = auto()
GEMMA3 = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
@ -440,6 +441,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.MINICPM3: "minicpm3",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@ -1077,6 +1079,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GEMMA3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View File

@ -36,6 +36,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MINICPM3, "minicpm3" },
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@ -766,6 +767,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{

View File

@ -40,6 +40,7 @@ enum llm_arch {
LLM_ARCH_MINICPM3,
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,

View File

@ -9,6 +9,7 @@
#include <algorithm>
#include <cassert>
#include <cstring>
#include <cmath>
#include <functional>
#include <map>
#include <sstream>
@ -864,6 +865,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_GEMMA3:
{
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 26: type = LLM_TYPE_1B; break;
case 34: type = LLM_TYPE_4B; break;
case 48: type = LLM_TYPE_12B; break;
case 62: type = LLM_TYPE_27B; break;
default: type = LLM_TYPE_UNKNOWN;
}
hparams.f_attention_scale = type == LLM_TYPE_27B
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@ -2454,6 +2472,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_GEMMA3:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
@ -3650,6 +3697,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
@ -3923,6 +3971,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
case LLM_ARCH_PHIMOE:
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:

View File

@ -4978,6 +4978,149 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_gemma3() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
const int64_t n_embd_head_k = hparams.n_embd_head_k;
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
if (ubatch.token) {
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
}
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
// gemma3 requires different mask for layers using sliding window (SWA)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
// "5-to-1 interleaved attention"
// 5 layers of local attention followed by 1 layer of global attention
static const int sliding_window_pattern = 6;
for (int il = 0; il < n_layer; ++il) {
const bool is_sliding = (il + 1) % sliding_window_pattern;
const float freq_base_l = is_sliding ? 10000.0f : freq_base;
const float freq_scale_l = is_sliding ? 1.0f : freq_scale;
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
Qcur = llm_build_norm(ctx0, Qcur, hparams,
model.layers[il].attn_q_norm,
NULL,
LLM_NORM_RMS, cb, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
Kcur = llm_build_norm(ctx0, Kcur, hparams,
model.layers[il].attn_k_norm,
NULL,
LLM_NORM_RMS, cb, il);
cb(Kcur, "Kcur_normed", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
}
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_post_norm", il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);
cur = llm_build_norm(ctx0, sa_out, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
// feed-forward network
{
cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "ffn_post_norm", -1);
cur = ggml_add(ctx0, cur, sa_out);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
struct ggml_cgraph * build_starcoder2() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@ -8298,6 +8441,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_gemma2();
} break;
case LLM_ARCH_GEMMA3:
{
result = llm.build_gemma3();
} break;
case LLM_ARCH_STARCODER2:
{
result = llm.build_starcoder2();