mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 12:05:03 +00:00
llama : Add Gemma 3 support (+ experimental vision capability) (#12343)
* llama : Add Gemma 3 text-only support * fix python coding style * fix compile on ubuntu * python: fix style * fix ubuntu compile * fix build on ubuntu (again) * fix ubuntu build, finally * clip : Experimental support for Gemma 3 vision (#12344) * clip : Experimental support for Gemma 3 vision * fix build * PRId64
This commit is contained in:
@ -861,6 +861,9 @@ class Model:
|
||||
for token_id, token_data in added_tokens_decoder.items():
|
||||
token_id = int(token_id)
|
||||
token: str = token_data["content"]
|
||||
if token_id >= vocab_size:
|
||||
logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
|
||||
continue
|
||||
if toktypes[token_id] != SentencePieceTokenTypes.UNUSED:
|
||||
if tokens[token_id] != token.encode("utf-8"):
|
||||
logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}')
|
||||
@ -3322,6 +3325,83 @@ class Gemma2Model(Model):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
|
||||
class Gemma3Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA3
|
||||
has_vision: bool = False
|
||||
|
||||
# we need to merge the text_config into the root level of hparams
|
||||
def __init__(self, *args, **kwargs):
|
||||
hparams = Model.load_hparams(kwargs["dir_model"])
|
||||
if "text_config" in hparams:
|
||||
hparams = {**hparams, **hparams["text_config"]}
|
||||
kwargs["hparams"] = hparams
|
||||
super().__init__(*args, **kwargs)
|
||||
if "vision_config" in hparams:
|
||||
logger.info("Has vision encoder, but it will be ignored")
|
||||
self.has_vision = True
|
||||
|
||||
def write(self):
|
||||
super().write()
|
||||
if self.has_vision:
|
||||
logger.info("NOTE: this script only convert the language model to GGUF")
|
||||
logger.info(" for the vision model, please use gemma3_convert_encoder_to_gguf.py")
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
self.gguf_writer.add_add_space_prefix(False)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
block_count = hparams["num_hidden_layers"]
|
||||
|
||||
# some default values are not specified in the hparams
|
||||
self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072))
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 8))
|
||||
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-6))
|
||||
self.gguf_writer.add_key_length(hparams.get("head_dim", 256))
|
||||
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
|
||||
# both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
|
||||
assert hparams.get("attn_logit_softcapping") is None
|
||||
assert hparams.get("final_logit_softcapping") is None
|
||||
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
|
||||
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
|
||||
if hparams.get("rope_scaling") is not None:
|
||||
assert hparams["rope_scaling"]["rope_type"] == "linear"
|
||||
# important: this rope_scaling is only applied for global layers, and not used by 1B model
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if name.startswith("language_model."):
|
||||
name = name.replace("language_model.", "")
|
||||
elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
|
||||
or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
|
||||
# ignore vision tensors
|
||||
return []
|
||||
|
||||
# remove OOV (out-of-vocabulary) rows in token_embd
|
||||
if "embed_tokens.weight" in name:
|
||||
vocab = self._create_vocab_sentencepiece()
|
||||
tokens = vocab[0]
|
||||
data_torch = data_torch[:len(tokens)]
|
||||
|
||||
# ref code in Gemma3RMSNorm
|
||||
# output = output * (1.0 + self.weight.float())
|
||||
if name.endswith("norm.weight"):
|
||||
data_torch = data_torch + 1
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("Starcoder2ForCausalLM")
|
||||
class StarCoder2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.STARCODER2
|
||||
|
@ -51,6 +51,13 @@ install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-gemma3-cli)
|
||||
add_executable(${TARGET} gemma3-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
set(TARGET llama-llava-clip-quantize-cli)
|
||||
add_executable(${TARGET} clip-quantize-cli.cpp)
|
||||
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
|
||||
|
30
examples/llava/README-gemma3.md
Normal file
30
examples/llava/README-gemma3.md
Normal file
@ -0,0 +1,30 @@
|
||||
# Gemma 3 vision
|
||||
|
||||
> [!IMPORTANT]
|
||||
>
|
||||
> This is very experimental, only used for demo purpose.
|
||||
|
||||
## How to get mmproj.gguf?
|
||||
|
||||
```bash
|
||||
cd gemma-3-4b-it
|
||||
python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
|
||||
|
||||
# output file is mmproj.gguf
|
||||
```
|
||||
|
||||
## How to run it?
|
||||
|
||||
What you need:
|
||||
- The text model GGUF, can be converted using `convert_hf_to_gguf.py`
|
||||
- The mmproj file from step above
|
||||
- An image file
|
||||
|
||||
```bash
|
||||
# build
|
||||
cmake -B build
|
||||
cmake --build build --target llama-gemma3-cli
|
||||
|
||||
# run it
|
||||
./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
|
||||
```
|
@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
|
||||
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
|
||||
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
|
||||
#define TN_IMAGE_NEWLINE "model.image_newline"
|
||||
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
|
||||
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
|
||||
|
||||
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
|
||||
#define TN_MINICPMV_QUERY "resampler.query"
|
||||
@ -162,6 +164,7 @@ enum projector_type {
|
||||
PROJECTOR_TYPE_RESAMPLER,
|
||||
PROJECTOR_TYPE_GLM_EDGE,
|
||||
PROJECTOR_TYPE_MERGER,
|
||||
PROJECTOR_TYPE_GEMMA3,
|
||||
PROJECTOR_TYPE_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -172,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
|
||||
{ PROJECTOR_TYPE_RESAMPLER, "resampler"},
|
||||
{ PROJECTOR_TYPE_GLM_EDGE, "adapter"},
|
||||
{ PROJECTOR_TYPE_MERGER, "qwen2vl_merger"},
|
||||
{ PROJECTOR_TYPE_GEMMA3, "gemma3"},
|
||||
};
|
||||
|
||||
|
||||
@ -298,7 +302,7 @@ static projector_type clip_projector_type_from_string(const std::string & name)
|
||||
return kv.first;
|
||||
}
|
||||
}
|
||||
return PROJECTOR_TYPE_UNKNOWN;
|
||||
throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
|
||||
}
|
||||
|
||||
#ifdef CLIP_DEBUG_FUNCTIONS
|
||||
@ -555,6 +559,10 @@ struct clip_vision_model {
|
||||
struct ggml_tensor * mm_model_ln_kv_b;
|
||||
struct ggml_tensor * mm_model_ln_post_w;
|
||||
struct ggml_tensor * mm_model_ln_post_b;
|
||||
|
||||
// gemma3
|
||||
struct ggml_tensor * mm_input_proj_w;
|
||||
struct ggml_tensor * mm_soft_emb_norm_w;
|
||||
};
|
||||
|
||||
struct clip_ctx {
|
||||
@ -569,7 +577,7 @@ struct clip_ctx {
|
||||
struct clip_vision_model vision_model;
|
||||
projector_type proj_type = PROJECTOR_TYPE_MLP;
|
||||
|
||||
int32_t max_feature_layer;
|
||||
int32_t max_feature_layer; // unused in newer models like gemma3
|
||||
float image_mean[3];
|
||||
float image_std[3];
|
||||
bool use_gelu = false;
|
||||
@ -595,7 +603,7 @@ struct clip_ctx {
|
||||
|
||||
ggml_backend_sched_ptr sched;
|
||||
|
||||
struct clip_image_size * load_image_size;
|
||||
struct clip_image_size * load_image_size = nullptr;
|
||||
|
||||
clip_ctx(clip_context_params & ctx_params) {
|
||||
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
|
||||
@ -631,7 +639,159 @@ struct clip_ctx {
|
||||
}
|
||||
};
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
|
||||
const auto & model = ctx->vision_model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int image_size = hparams.image_size;
|
||||
int image_size_width = image_size;
|
||||
int image_size_height = image_size;
|
||||
|
||||
const int patch_size = hparams.patch_size;
|
||||
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
|
||||
const int hidden_size = hparams.hidden_size;
|
||||
const int n_head = hparams.n_head;
|
||||
const int d_head = hidden_size / n_head;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const float eps = hparams.eps;
|
||||
|
||||
GGML_ASSERT(imgs->size == 1); // batch_size == 1
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ ctx->buf_compute_meta.size(),
|
||||
/*.mem_buffer =*/ ctx->buf_compute_meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx0 = ggml_init(params);
|
||||
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
// input raw
|
||||
struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
|
||||
ggml_set_name(inp_raw, "inp_raw");
|
||||
ggml_set_input(inp_raw);
|
||||
|
||||
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
|
||||
inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
|
||||
inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
|
||||
inp = ggml_add(ctx0, inp, model.patch_bias);
|
||||
|
||||
// position embeddings
|
||||
struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
|
||||
|
||||
// loop over layers
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
|
||||
|
||||
// layernorm1
|
||||
{
|
||||
cur = ggml_norm(ctx0, cur, eps);
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
|
||||
struct ggml_tensor * Q =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
|
||||
|
||||
Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
|
||||
Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
|
||||
|
||||
struct ggml_tensor * K =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
|
||||
|
||||
K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
|
||||
K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
|
||||
|
||||
struct ggml_tensor * V =
|
||||
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
|
||||
|
||||
V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
|
||||
V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
|
||||
|
||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||
KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
|
||||
KQ = ggml_soft_max_inplace(ctx0, KQ);
|
||||
|
||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
|
||||
KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
|
||||
KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
|
||||
}
|
||||
|
||||
// attention output
|
||||
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
|
||||
|
||||
// re-add the layer input, e.g., residual
|
||||
cur = ggml_add(ctx0, cur, embeddings);
|
||||
|
||||
embeddings = cur; // embeddings = residual, cur = hidden_states
|
||||
|
||||
// layernorm2
|
||||
{
|
||||
cur = ggml_norm(ctx0, cur, eps);
|
||||
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
|
||||
}
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
|
||||
|
||||
// siglip uses gelu
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
|
||||
cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
|
||||
|
||||
// residual 2
|
||||
cur = ggml_add(ctx0, embeddings, cur);
|
||||
|
||||
embeddings = cur;
|
||||
}
|
||||
|
||||
// post-layernorm
|
||||
if (ctx->has_post_norm) {
|
||||
embeddings = ggml_norm(ctx0, embeddings, eps);
|
||||
ggml_set_name(embeddings, "post_ln");
|
||||
|
||||
embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
|
||||
}
|
||||
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
const int batch_size = 1;
|
||||
const int mm_tokens_per_image = 256; // default value for gemma3
|
||||
const int tokens_per_side = sqrt(mm_tokens_per_image);
|
||||
const int patches_per_image = sqrt(num_patches);
|
||||
const int kernel_size = patches_per_image / tokens_per_side;
|
||||
|
||||
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
|
||||
embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
|
||||
|
||||
// doing a pool2d to reduce the number of output tokens to 256
|
||||
embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
|
||||
embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
|
||||
|
||||
// apply norm before projection
|
||||
embeddings = ggml_rms_norm(ctx0, embeddings, eps);
|
||||
embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
|
||||
|
||||
// apply projection
|
||||
embeddings = ggml_mul_mat(ctx0,
|
||||
ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
|
||||
embeddings);
|
||||
}
|
||||
|
||||
// build the graph
|
||||
ggml_build_forward_expand(gf, embeddings);
|
||||
|
||||
ggml_free(ctx0);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
if (!ctx->has_vision_encoder) {
|
||||
LOG_ERR("This gguf file seems to have no vision encoder\n");
|
||||
return nullptr;
|
||||
@ -1177,7 +1337,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
} else {
|
||||
GGML_ABORT("fatel error");
|
||||
}
|
||||
} else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
|
||||
|
||||
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
|
||||
@ -1199,6 +1360,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
|
||||
return gf;
|
||||
}
|
||||
|
||||
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
return clip_image_build_graph_siglip(ctx, imgs);
|
||||
} else {
|
||||
// TODO: we should have one build_* function per model
|
||||
return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
|
||||
}
|
||||
}
|
||||
|
||||
// read and create ggml_context containing the tensors and their data
|
||||
struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
||||
return clip_init(fname, clip_context_params{
|
||||
@ -1358,8 +1528,12 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
|
||||
GGML_ASSERT(new_clip->has_vision_encoder);
|
||||
GGML_ASSERT(!new_clip->has_text_encoder);
|
||||
|
||||
idx = get_key_idx(ctx, KEY_USE_GELU);
|
||||
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
|
||||
try {
|
||||
idx = get_key_idx(ctx, KEY_USE_GELU);
|
||||
new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
|
||||
} catch (std::runtime_error & /*e*/) {
|
||||
new_clip->use_gelu = false;
|
||||
}
|
||||
|
||||
try {
|
||||
idx = get_key_idx(ctx, KEY_USE_SILU);
|
||||
@ -1567,11 +1741,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
|
||||
}
|
||||
|
||||
try {
|
||||
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
|
||||
} catch(const std::exception& /*e*/) {
|
||||
vision_model.patch_embeddings_0 = nullptr;
|
||||
}
|
||||
|
||||
try {
|
||||
vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
|
||||
} catch(const std::exception& /*e*/) {
|
||||
LOG_ERR("%s: failed to load vision model tensors\n", __func__);
|
||||
vision_model.position_embeddings = nullptr;
|
||||
}
|
||||
|
||||
try {
|
||||
vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
|
||||
} catch(const std::exception& /*e*/) {
|
||||
@ -1682,6 +1862,10 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
|
||||
vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
|
||||
vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
|
||||
}
|
||||
else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
vision_model.mm_input_proj_w = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
|
||||
vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
|
||||
}
|
||||
else {
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
@ -2223,7 +2407,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
|
||||
return true;
|
||||
}
|
||||
|
||||
if (ctx->has_glm_projector) {
|
||||
if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
res_imgs->size = 1;
|
||||
res_imgs->data = new clip_image_f32[res_imgs->size];
|
||||
clip_image_u8 resized_image;
|
||||
@ -2748,6 +2932,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
||||
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
|
||||
free(positions_data);
|
||||
}
|
||||
else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
// do nothing
|
||||
}
|
||||
else {
|
||||
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
|
||||
|
||||
@ -2960,6 +3147,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
|
||||
return ctx->vision_model.mm_1_b->ne[0];
|
||||
}
|
||||
if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
|
||||
return ctx->vision_model.mm_input_proj_w->ne[0];
|
||||
}
|
||||
|
||||
std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
|
||||
throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
|
||||
|
341
examples/llava/gemma3-cli.cpp
Normal file
341
examples/llava/gemma3-cli.cpp
Normal file
@ -0,0 +1,341 @@
|
||||
#include "arg.h"
|
||||
#include "log.h"
|
||||
#include "common.h"
|
||||
#include "sampling.h"
|
||||
#include "clip.h"
|
||||
#include "stb_image.h"
|
||||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
#include "console.h"
|
||||
|
||||
#include <vector>
|
||||
#include <limits.h>
|
||||
#include <inttypes.h>
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
#include <signal.h>
|
||||
#include <unistd.h>
|
||||
#elif defined (_WIN32)
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#ifndef NOMINMAX
|
||||
#define NOMINMAX
|
||||
#endif
|
||||
#include <windows.h>
|
||||
#include <signal.h>
|
||||
#endif
|
||||
|
||||
static bool g_is_generating = false;
|
||||
|
||||
/**
|
||||
* Please note that this is NOT a production-ready stuff.
|
||||
* It is a playground for trying Gemma 3 vision capabilities.
|
||||
* For contributors: please keep this code simple and easy to understand.
|
||||
*/
|
||||
|
||||
static void show_additional_info(int /*argc*/, char ** argv) {
|
||||
LOG(
|
||||
"Experimental CLI for using Gemma 3 vision model\n\n"
|
||||
"Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
|
||||
" -m and --mmproj are required\n"
|
||||
" --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
|
||||
argv[0]
|
||||
);
|
||||
}
|
||||
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
|
||||
static void sigint_handler(int signo) {
|
||||
if (signo == SIGINT) {
|
||||
if (g_is_generating) {
|
||||
g_is_generating = false;
|
||||
} else {
|
||||
console::cleanup();
|
||||
LOG("\nInterrupted by user\n");
|
||||
_exit(130);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
struct gemma3_context {
|
||||
struct clip_ctx * ctx_clip = NULL;
|
||||
common_init_result llama_init;
|
||||
|
||||
llama_model * model;
|
||||
llama_context * lctx;
|
||||
const llama_vocab * vocab;
|
||||
llama_batch batch;
|
||||
|
||||
int n_threads = 1;
|
||||
llama_pos n_past = 0;
|
||||
|
||||
gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
|
||||
model = llama_init.model.get();
|
||||
lctx = llama_init.context.get();
|
||||
vocab = llama_model_get_vocab(model);
|
||||
n_threads = params.cpuparams.n_threads;
|
||||
batch = llama_batch_init(params.n_batch, 0, 1);
|
||||
init_clip_model(params);
|
||||
}
|
||||
|
||||
void init_clip_model(common_params & params) {
|
||||
const char * clip_path = params.mmproj.c_str();
|
||||
ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
|
||||
}
|
||||
|
||||
~gemma3_context() {
|
||||
clip_free(ctx_clip);
|
||||
}
|
||||
};
|
||||
|
||||
struct decode_embd_batch {
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id> seq_id_0;
|
||||
std::vector<llama_seq_id *> seq_ids;
|
||||
std::vector<int8_t> logits;
|
||||
llama_batch batch;
|
||||
decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
|
||||
pos .resize(n_tokens);
|
||||
n_seq_id.resize(n_tokens);
|
||||
seq_ids .resize(n_tokens + 1);
|
||||
logits .resize(n_tokens);
|
||||
seq_id_0.resize(1);
|
||||
seq_id_0[0] = seq_id;
|
||||
seq_ids [n_tokens] = nullptr;
|
||||
batch = {
|
||||
/*n_tokens =*/ n_tokens,
|
||||
/*tokens =*/ nullptr,
|
||||
/*embd =*/ embd,
|
||||
/*pos =*/ pos.data(),
|
||||
/*n_seq_id =*/ n_seq_id.data(),
|
||||
/*seq_id =*/ seq_ids.data(),
|
||||
/*logits =*/ logits.data(),
|
||||
};
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
batch.pos [i] = pos_0 + i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id [i] = seq_id_0.data();
|
||||
batch.logits [i] = false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
|
||||
llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
|
||||
common_batch_clear(ctx.batch);
|
||||
for (llama_token & t : tokens) {
|
||||
common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
|
||||
}
|
||||
if (logits_last) {
|
||||
ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
|
||||
}
|
||||
// LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
|
||||
if (llama_decode(ctx.lctx, ctx.batch)) {
|
||||
LOG_ERR("Failed to decode text\n");
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int eval_image(gemma3_context & ctx, std::string & fname) {
|
||||
std::vector<float> image_embd_v;
|
||||
int n_embd = llama_model_n_embd(ctx.model);
|
||||
int n_tokens = 256;
|
||||
image_embd_v.resize(n_tokens * n_embd);
|
||||
|
||||
bool ok;
|
||||
struct clip_image_u8 * img_u8 = clip_image_u8_init();
|
||||
ok = clip_image_load_from_file(fname.c_str(), img_u8);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to load image %s\n", fname.c_str());
|
||||
clip_image_u8_free(img_u8);
|
||||
return 2; // non-fatal error
|
||||
}
|
||||
|
||||
clip_image_f32_batch batch_f32;
|
||||
ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to preprocess image\n");
|
||||
clip_image_f32_batch_free(&batch_f32);
|
||||
clip_image_u8_free(img_u8);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
LOG("Encoding image %s\n", fname.c_str());
|
||||
ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
|
||||
if (!ok) {
|
||||
LOG_ERR("Unable to encode image\n");
|
||||
clip_image_f32_batch_free(&batch_f32);
|
||||
clip_image_u8_free(img_u8);
|
||||
return 1;
|
||||
}
|
||||
LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
|
||||
|
||||
clip_image_f32_batch_free(&batch_f32);
|
||||
clip_image_u8_free(img_u8);
|
||||
|
||||
// decode image embeddings
|
||||
int64_t t1 = ggml_time_ms();
|
||||
eval_text(ctx, "<start_of_image>");
|
||||
llama_set_causal_attn(ctx.lctx, false);
|
||||
decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
|
||||
if (llama_decode(ctx.lctx, batch_img.batch)) {
|
||||
LOG_ERR("failed to decode image\n");
|
||||
return 1;
|
||||
}
|
||||
ctx.n_past += n_tokens;
|
||||
llama_set_causal_attn(ctx.lctx, true);
|
||||
eval_text(ctx, "<end_of_image>");
|
||||
LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
|
||||
for (int i = 0; i < n_predict; i++) {
|
||||
if (i > n_predict || !g_is_generating) {
|
||||
printf("\n");
|
||||
break;
|
||||
}
|
||||
|
||||
llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
|
||||
common_sampler_accept(smpl, token_id, true);
|
||||
|
||||
if (llama_vocab_is_eog(ctx.vocab, token_id)) {
|
||||
printf("\n");
|
||||
break; // end of generation
|
||||
}
|
||||
|
||||
printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
|
||||
fflush(stdout);
|
||||
|
||||
// eval the token
|
||||
common_batch_clear(ctx.batch);
|
||||
common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
|
||||
if (llama_decode(ctx.lctx, ctx.batch)) {
|
||||
LOG_ERR("failed to decode token\n");
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
params.sampling.temp = 0.2; // lower temp by default for better quality
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
if (params.mmproj.empty()) {
|
||||
show_additional_info(argc, argv);
|
||||
return 1;
|
||||
}
|
||||
|
||||
gemma3_context ctx(params);
|
||||
printf("%s: %s\n", __func__, params.model.c_str());
|
||||
|
||||
bool is_single_turn = !params.prompt.empty() && !params.image.empty();
|
||||
|
||||
struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
|
||||
int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
|
||||
|
||||
// ctrl+C handling
|
||||
{
|
||||
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
|
||||
struct sigaction sigint_action;
|
||||
sigint_action.sa_handler = sigint_handler;
|
||||
sigemptyset (&sigint_action.sa_mask);
|
||||
sigint_action.sa_flags = 0;
|
||||
sigaction(SIGINT, &sigint_action, NULL);
|
||||
#elif defined (_WIN32)
|
||||
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
|
||||
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
|
||||
};
|
||||
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (eval_text(ctx, "<bos>")) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (is_single_turn) {
|
||||
g_is_generating = true;
|
||||
if (eval_text(ctx, "<start_of_turn>user\n")) {
|
||||
return 1;
|
||||
}
|
||||
for (auto & fname : params.image) {
|
||||
if (eval_image(ctx, fname)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
|
||||
return 1;
|
||||
}
|
||||
if (generate_response(ctx, smpl, n_predict)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
} else {
|
||||
LOG("\n Running in chat mode, available commands:");
|
||||
LOG("\n /image <path> load an image");
|
||||
LOG("\n /clear clear the chat history");
|
||||
LOG("\n /quit or /exit exit the program");
|
||||
LOG("\n");
|
||||
|
||||
if (eval_text(ctx, "<start_of_turn>user\n")) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
while (true) {
|
||||
g_is_generating = false;
|
||||
LOG("\n> ");
|
||||
console::set_display(console::user_input);
|
||||
std::string line;
|
||||
console::readline(line, false);
|
||||
console::set_display(console::reset);
|
||||
line = string_strip(line);
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (line == "/quit" || line == "/exit") {
|
||||
break;
|
||||
}
|
||||
if (line == "/clear") {
|
||||
ctx.n_past = 0;
|
||||
llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
|
||||
LOG("Chat history cleared\n\n");
|
||||
continue;
|
||||
}
|
||||
g_is_generating = true;
|
||||
if (line.find("/image") == 0) {
|
||||
std::string image = line.substr(7);
|
||||
int res = eval_image(ctx, image);
|
||||
if (res == 2) {
|
||||
continue; // image not found
|
||||
}
|
||||
if (res) {
|
||||
return 1;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
|
||||
return 1;
|
||||
}
|
||||
if (generate_response(ctx, smpl, n_predict)) {
|
||||
return 1;
|
||||
}
|
||||
if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
307
examples/llava/gemma3_convert_encoder_to_gguf.py
Normal file
307
examples/llava/gemma3_convert_encoder_to_gguf.py
Normal file
@ -0,0 +1,307 @@
|
||||
import gguf
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import torch
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
from typing import cast, ContextManager, Any, Iterator
|
||||
from pathlib import Path
|
||||
from torch import Tensor
|
||||
|
||||
logger = logging.getLogger("gemma3-mmproj")
|
||||
|
||||
|
||||
# (copied from convert_hf_to_gguf.py)
|
||||
# tree of lazy tensors
|
||||
class LazyTorchTensor(gguf.LazyBase):
|
||||
_tensor_type = torch.Tensor
|
||||
# to keep the type-checker happy
|
||||
dtype: torch.dtype
|
||||
shape: torch.Size
|
||||
|
||||
# only used when converting a torch.Tensor to a np.ndarray
|
||||
_dtype_map: dict[torch.dtype, type] = {
|
||||
torch.float16: np.float16,
|
||||
torch.float32: np.float32,
|
||||
}
|
||||
|
||||
# used for safetensors slices
|
||||
# ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
|
||||
# TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
|
||||
_dtype_str_map: dict[str, torch.dtype] = {
|
||||
"F64": torch.float64,
|
||||
"F32": torch.float32,
|
||||
"BF16": torch.bfloat16,
|
||||
"F16": torch.float16,
|
||||
# "U64": torch.uint64,
|
||||
"I64": torch.int64,
|
||||
# "U32": torch.uint32,
|
||||
"I32": torch.int32,
|
||||
# "U16": torch.uint16,
|
||||
"I16": torch.int16,
|
||||
"U8": torch.uint8,
|
||||
"I8": torch.int8,
|
||||
"BOOL": torch.bool,
|
||||
"F8_E4M3": torch.float8_e4m3fn,
|
||||
"F8_E5M2": torch.float8_e5m2,
|
||||
}
|
||||
|
||||
def numpy(self) -> gguf.LazyNumpyTensor:
|
||||
dtype = self._dtype_map[self.dtype]
|
||||
return gguf.LazyNumpyTensor(
|
||||
meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
|
||||
args=(self,),
|
||||
func=(lambda s: s.numpy())
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
|
||||
return torch.empty(size=shape, dtype=dtype, device="meta")
|
||||
|
||||
@classmethod
|
||||
def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
|
||||
dtype = cls._dtype_str_map[st_slice.get_dtype()]
|
||||
shape: tuple[int, ...] = tuple(st_slice.get_shape())
|
||||
lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
|
||||
return cast(torch.Tensor, lazy)
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
del types # unused
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
if func is torch.Tensor.numpy:
|
||||
return args[0].numpy()
|
||||
|
||||
return cls._wrap_fn(func)(*args, **kwargs)
|
||||
|
||||
|
||||
class Gemma3VisionTower:
|
||||
hparams: dict
|
||||
gguf_writer: gguf.GGUFWriter
|
||||
fname_out: Path
|
||||
ftype: gguf.LlamaFileType
|
||||
|
||||
@staticmethod
|
||||
def load_hparams(dir_model: Path):
|
||||
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
@staticmethod
|
||||
def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
|
||||
part_names: list[str] = []
|
||||
for filename in os.listdir(dir_model):
|
||||
if filename.startswith(prefix) and filename.endswith(suffix):
|
||||
part_names.append(filename)
|
||||
part_names.sort()
|
||||
return part_names
|
||||
|
||||
def __init__(self,
|
||||
dir_model: Path,
|
||||
fname_out: Path,
|
||||
ftype: gguf.LlamaFileType,
|
||||
is_big_endian: bool,):
|
||||
hparams = Gemma3VisionTower.load_hparams(dir_model)
|
||||
self.hparams = hparams
|
||||
self.fname_out = fname_out
|
||||
self.ftype = ftype
|
||||
endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
|
||||
self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)
|
||||
|
||||
text_config = hparams["text_config"]
|
||||
vision_config = hparams["vision_config"]
|
||||
|
||||
assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
|
||||
assert text_config is not None
|
||||
assert vision_config is not None
|
||||
|
||||
self.gguf_writer.add_string ("clip.projector_type", "gemma3")
|
||||
self.gguf_writer.add_bool ("clip.has_text_encoder", False)
|
||||
self.gguf_writer.add_bool ("clip.has_vision_encoder", True)
|
||||
self.gguf_writer.add_bool ("clip.has_llava_projector", False) # legacy
|
||||
self.gguf_writer.add_uint32 ("clip.vision.image_size", vision_config["image_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.patch_size", vision_config["patch_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.embedding_length", vision_config["hidden_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", vision_config["intermediate_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.projection_dim", text_config["hidden_size"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.block_count", vision_config["num_hidden_layers"])
|
||||
self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
|
||||
self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
|
||||
# default values taken from HF tranformers code
|
||||
self.gguf_writer.add_array ("clip.vision.image_mean", [0.5, 0.5, 0.5])
|
||||
self.gguf_writer.add_array ("clip.vision.image_std", [0.5, 0.5, 0.5])
|
||||
self.gguf_writer.add_bool ("clip.use_gelu", True)
|
||||
|
||||
# load tensors
|
||||
for name, data_torch in self.get_tensors(dir_model):
|
||||
# convert any unsupported data types to float32
|
||||
if data_torch.dtype not in (torch.float16, torch.float32):
|
||||
data_torch = data_torch.to(torch.float32)
|
||||
self.add_tensor(name, data_torch)
|
||||
|
||||
def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
|
||||
part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
|
||||
tensor_names_from_parts: set[str] = set()
|
||||
for part_name in part_names:
|
||||
logger.info(f"gguf: loading model part '{part_name}'")
|
||||
from safetensors import safe_open
|
||||
ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
|
||||
with ctx as model_part:
|
||||
tensor_names_from_parts.update(model_part.keys())
|
||||
|
||||
for name in model_part.keys():
|
||||
data = model_part.get_slice(name)
|
||||
data = LazyTorchTensor.from_safetensors_slice(data)
|
||||
yield name, data
|
||||
|
||||
def add_tensor(self, name: str, data_torch: Tensor):
|
||||
is_1d = len(data_torch.shape) == 1
|
||||
is_embd = ".embeddings." in name
|
||||
old_dtype = data_torch.dtype
|
||||
can_quantize = not is_1d and not is_embd
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
|
||||
# this is to support old checkpoint
|
||||
# TODO: remove this when we have the final model
|
||||
name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
|
||||
name = name.replace("multimodal_projector.", "multi_modal_projector.")
|
||||
|
||||
# filter only vision tensors
|
||||
if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
|
||||
return
|
||||
# prefix
|
||||
name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
|
||||
name = name.replace("vision_tower.vision_model.", "v.")
|
||||
# projector and input embd
|
||||
name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
|
||||
name = name.replace(".embeddings.position_embedding.", ".position_embd.")
|
||||
name = name.replace(
|
||||
"multi_modal_projector.mm_input_projection_weight",
|
||||
"mm.input_projection.weight"
|
||||
)
|
||||
name = name.replace(
|
||||
"multi_modal_projector.mm_soft_emb_norm.weight",
|
||||
"mm.soft_emb_norm.weight"
|
||||
)
|
||||
name = name.replace("post_layernorm.", "post_ln.")
|
||||
# each block
|
||||
name = name.replace(".self_attn.k_proj.", ".attn_k.")
|
||||
name = name.replace(".self_attn.v_proj.", ".attn_v.")
|
||||
name = name.replace(".self_attn.q_proj.", ".attn_q.")
|
||||
name = name.replace(".self_attn.out_proj.", ".attn_out.")
|
||||
name = name.replace(".layer_norm1.", ".ln1.")
|
||||
name = name.replace(".layer_norm2.", ".ln2.")
|
||||
name = name.replace(".mlp.fc1.", ".ffn_down.")
|
||||
name = name.replace(".mlp.fc2.", ".ffn_up.")
|
||||
|
||||
if can_quantize:
|
||||
if self.ftype == gguf.LlamaFileType.ALL_F32:
|
||||
data_qtype = gguf.GGMLQuantizationType.F32
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
|
||||
data_qtype = gguf.GGMLQuantizationType.Q8_0
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {self.ftype}")
|
||||
|
||||
# corrent norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
|
||||
# the other norm values are part of SigLIP model, and they are already correct
|
||||
# ref code: Gemma3RMSNorm
|
||||
if "soft_emb_norm.weight" in name:
|
||||
logger.info(f"Correcting norm value for '{name}'")
|
||||
data_torch = data_torch + 1
|
||||
|
||||
data = data_torch.numpy()
|
||||
|
||||
try:
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
except Exception as e:
|
||||
logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
|
||||
data_qtype = gguf.GGMLQuantizationType.F16
|
||||
data = gguf.quants.quantize(data, data_qtype)
|
||||
|
||||
# reverse shape to make it similar to the internal ggml dimension order
|
||||
shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
|
||||
logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
|
||||
|
||||
self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
|
||||
|
||||
def write(self):
|
||||
self.gguf_writer.write_header_to_file(path=self.fname_out)
|
||||
self.gguf_writer.write_kv_data_to_file()
|
||||
self.gguf_writer.write_tensors_to_file(progress=True)
|
||||
self.gguf_writer.close()
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Convert Gemma 3 vision tower safetensors to GGUF format",)
|
||||
parser.add_argument(
|
||||
"--outfile", type=Path, default="mmproj.gguf",
|
||||
help="path to write to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
|
||||
help="output format",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bigendian", action="store_true",
|
||||
help="model is executed on big endian machine",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model", type=Path,
|
||||
help="directory containing model file",
|
||||
nargs="?",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true",
|
||||
help="increase output verbosity",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.model is None:
|
||||
parser.error("the following arguments are required: model")
|
||||
return args
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
dir_model = args.model
|
||||
|
||||
if not dir_model.is_dir():
|
||||
logger.error(f'Error: {args.model} is not a directory')
|
||||
sys.exit(1)
|
||||
|
||||
ftype_map: dict[str, gguf.LlamaFileType] = {
|
||||
"f32": gguf.LlamaFileType.ALL_F32,
|
||||
"f16": gguf.LlamaFileType.MOSTLY_F16,
|
||||
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
|
||||
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
|
||||
}
|
||||
|
||||
logger.info(f"Loading model: {dir_model.name}")
|
||||
|
||||
with torch.inference_mode():
|
||||
gemma3_vision_tower = Gemma3VisionTower(
|
||||
dir_model=dir_model,
|
||||
fname_out=args.outfile,
|
||||
ftype=ftype_map[args.outtype],
|
||||
is_big_endian=args.bigendian,
|
||||
)
|
||||
gemma3_vision_tower.write()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
@ -253,6 +253,7 @@ class MODEL_ARCH(IntEnum):
|
||||
MINICPM3 = auto()
|
||||
GEMMA = auto()
|
||||
GEMMA2 = auto()
|
||||
GEMMA3 = auto()
|
||||
STARCODER2 = auto()
|
||||
RWKV6 = auto()
|
||||
RWKV6QWEN2 = auto()
|
||||
@ -440,6 +441,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.MINICPM3: "minicpm3",
|
||||
MODEL_ARCH.GEMMA: "gemma",
|
||||
MODEL_ARCH.GEMMA2: "gemma2",
|
||||
MODEL_ARCH.GEMMA3: "gemma3",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
|
||||
@ -1077,6 +1079,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.GEMMA3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_PRE_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
],
|
||||
MODEL_ARCH.STARCODER2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -36,6 +36,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
{ LLM_ARCH_GEMMA3, "gemma3" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
{ LLM_ARCH_MAMBA, "mamba" },
|
||||
{ LLM_ARCH_XVERSE, "xverse" },
|
||||
@ -766,6 +767,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_STARCODER2,
|
||||
{
|
||||
|
@ -40,6 +40,7 @@ enum llm_arch {
|
||||
LLM_ARCH_MINICPM3,
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_GEMMA3,
|
||||
LLM_ARCH_STARCODER2,
|
||||
LLM_ARCH_MAMBA,
|
||||
LLM_ARCH_XVERSE,
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
@ -864,6 +865,23 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 26: type = LLM_TYPE_1B; break;
|
||||
case 34: type = LLM_TYPE_4B; break;
|
||||
case 48: type = LLM_TYPE_12B; break;
|
||||
case 62: type = LLM_TYPE_27B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
|
||||
hparams.f_attention_scale = type == LLM_TYPE_27B
|
||||
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
|
||||
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@ -2454,6 +2472,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
@ -3650,6 +3697,7 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
|
||||
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
|
||||
LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
|
||||
LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale);
|
||||
LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str());
|
||||
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
|
||||
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
|
||||
@ -3923,6 +3971,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_PHIMOE:
|
||||
case LLM_ARCH_GEMMA:
|
||||
case LLM_ARCH_GEMMA2:
|
||||
case LLM_ARCH_GEMMA3:
|
||||
case LLM_ARCH_STARCODER2:
|
||||
case LLM_ARCH_OPENELM:
|
||||
case LLM_ARCH_GPTNEOX:
|
||||
|
147
src/llama.cpp
147
src/llama.cpp
@ -4978,6 +4978,149 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_gemma3() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
|
||||
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k;
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
||||
|
||||
// important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings)
|
||||
if (ubatch.token) {
|
||||
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
|
||||
cb(inpL, "inp_scaled", -1);
|
||||
}
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
// gemma3 requires different mask for layers using sliding window (SWA)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask(true);
|
||||
struct ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(true);
|
||||
|
||||
// "5-to-1 interleaved attention"
|
||||
// 5 layers of local attention followed by 1 layer of global attention
|
||||
static const int sliding_window_pattern = 6;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
const bool is_sliding = (il + 1) % sliding_window_pattern;
|
||||
const float freq_base_l = is_sliding ? 10000.0f : freq_base;
|
||||
const float freq_scale_l = is_sliding ? 1.0f : freq_scale;
|
||||
struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
|
||||
Qcur = llm_build_norm(ctx0, Qcur, hparams,
|
||||
model.layers[il].attn_q_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens);
|
||||
Kcur = llm_build_norm(ctx0, Kcur, hparams,
|
||||
model.layers[il].attn_k_norm,
|
||||
NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
Kcur, Vcur, Qcur, KQ_mask_l, n_tokens, kv_head, n_kv, hparams.f_attention_scale, cb, il);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].attn_post_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
|
||||
cb(sa_out, "sa_out", il);
|
||||
|
||||
cur = llm_build_norm(ctx0, sa_out, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_GELU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.layers[il].ffn_post_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "ffn_post_norm", -1);
|
||||
|
||||
cur = ggml_add(ctx0, cur, sa_out);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_starcoder2() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
|
||||
@ -8298,6 +8441,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_gemma2();
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA3:
|
||||
{
|
||||
result = llm.build_gemma3();
|
||||
} break;
|
||||
case LLM_ARCH_STARCODER2:
|
||||
{
|
||||
result = llm.build_starcoder2();
|
||||
|
Reference in New Issue
Block a user