From 8846aace4934ad29651ea61b8c7e3f6b0556e3d2 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 26 Jun 2025 19:34:02 +0200 Subject: [PATCH] model : gemma3n text-only (#14400) * gemma3n * add llm_graph_input_one --- convert_hf_to_gguf.py | 124 +++++++- gguf-py/gguf/constants.py | 75 +++++ gguf-py/gguf/gguf_writer.py | 18 ++ gguf-py/gguf/tensor_mapping.py | 64 ++++ src/llama-arch.cpp | 54 ++++ src/llama-arch.h | 17 ++ src/llama-graph.cpp | 23 +- src/llama-graph.h | 16 +- src/llama-hparams.h | 6 + src/llama-kv-cache-unified.cpp | 30 +- src/llama-model.cpp | 517 +++++++++++++++++++++++++++++++++ src/llama-model.h | 22 ++ src/llama-quant.cpp | 9 +- 13 files changed, 960 insertions(+), 15 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bbf8b30ff..4f2339a02 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -310,6 +310,8 @@ class ModelBase: gguf.MODEL_TENSOR.POSNET_NORM2, gguf.MODEL_TENSOR.V_ENC_EMBD_POS, gguf.MODEL_TENSOR.A_ENC_EMBD_POS, + gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF, + gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF, ) ) or not new_name.endswith(".weight") @@ -320,7 +322,11 @@ class ModelBase: self.match_model_tensor_name(new_name, key, bid) for key in ( gguf.MODEL_TENSOR.TOKEN_EMBD, + gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD, gguf.MODEL_TENSOR.OUTPUT, + gguf.MODEL_TENSOR.ALTUP_ROUTER, + gguf.MODEL_TENSOR.LAUREL_L, + gguf.MODEL_TENSOR.LAUREL_R, ) ): if self.ftype in ( @@ -921,13 +927,16 @@ class TextModel(ModelBase): tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.find_hparam([ + "vocab_size_per_layer_input", # gemma3n + "vocab_size", + ], optional=True) or tokenizer.vocab_size() tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size - for token_id in range(tokenizer.vocab_size()): + for token_id in range(vocab_size): piece = tokenizer.IdToPiece(token_id) text = piece.encode("utf-8") score = tokenizer.GetScore(token_id) @@ -942,6 +951,10 @@ class TextModel(ModelBase): elif tokenizer.IsByte(token_id): toktype = SentencePieceTokenTypes.BYTE + if token_id >= vocab_size: + logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}') + break + tokens[token_id] = text scores[token_id] = score toktypes[token_id] = toktype @@ -4217,6 +4230,7 @@ class Gemma2Model(TextModel): @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration") class Gemma3Model(TextModel): model_arch = gguf.MODEL_ARCH.GEMMA3 + norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value def set_vocab(self): self._set_vocab_sentencepiece() @@ -4238,9 +4252,8 @@ class Gemma3Model(TextModel): self.gguf_writer.add_value_length(hparams.get("head_dim", 256)) self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers - # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3 + # attn_logit_softcapping is removed in Gemma3 assert hparams.get("attn_logit_softcapping") is None - assert hparams.get("final_logit_softcapping") is None self.gguf_writer.add_sliding_window(hparams["sliding_window"]) self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4)) if hparams.get("rope_scaling") is not None: @@ -4252,7 +4265,7 @@ class Gemma3Model(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused - if name.startswith("language_model."): + if "language_model." in name: name = name.replace("language_model.", "") elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ @@ -4267,8 +4280,9 @@ class Gemma3Model(TextModel): # ref code in Gemma3RMSNorm # output = output * (1.0 + self.weight.float()) + # note: this is not the case on gemma3n if name.endswith("norm.weight"): - data_torch = data_torch + 1 + data_torch = data_torch + self.norm_shift return [(self.map_tensor_name(name), data_torch)] @@ -4325,6 +4339,104 @@ class Gemma3VisionModel(MmprojModel): return [] # skip other tensors +@ModelBase.register("Gemma3nForConditionalGeneration") +class Gemma3NModel(Gemma3Model): + model_arch = gguf.MODEL_ARCH.GEMMA3N + norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code + + _altup_proj: list[Tensor] = [] + _altup_unembd: list[Tensor] = [] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs" + self._altup_proj = [ + torch.Tensor(), # to be replaced + torch.Tensor(), # to be replaced + torch.Tensor(), # to be replaced + ] + self._altup_unembd = [ + torch.Tensor(), # to be replaced + torch.Tensor(), # to be replaced + torch.Tensor(), # to be replaced + ] + + def set_vocab(self): + with open(self.dir_model / "chat_template.jinja") as f: + # quick hack to make sure chat template is added + self.gguf_writer.add_chat_template(f.read()) + super().set_vocab() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"]) + self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"]) + self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"]) + self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"]) + + activation_sparsity_scale = [] + for s in self.hparams["activation_sparsity_pattern"]: + normal_dist = torch.distributions.normal.Normal(0, 1) + std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32)) + activation_sparsity_scale.append(std_multiplier.item()) + self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale) + + sliding_window_pattern = [] + for t in self.hparams["layer_types"]: + sliding_window_pattern.append(t == "sliding_attention") + self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern) + + def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None: + has_all = all(m.numel() > 0 for m in matrices) + if not has_all: + return None + else: + return torch.stack(matrices, dim=0) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.endswith("_scale"): + name = name + ".weight" + + # TODO: implement self.prediction_coefs.weight.clamp_(...) + + if "language_model." not in name: + return [] # skip non-language model tensors + + if "altup_unembed_projections" in name: + data_torch = data_torch.to(device="cpu") + if ".0." in name: + self._altup_unembd[0] = data_torch + elif ".1." in name: + self._altup_unembd[1] = data_torch + elif ".2." in name: + self._altup_unembd[2] = data_torch + else: + raise ValueError(f"Unknown name: {name}") + out = self._stack_matrices(self._altup_unembd) + if out is not None: + return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)] + else: + return [] + + if "altup_projections" in name: + data_torch = data_torch.to(device="cpu") + if ".0." in name: + self._altup_proj[0] = data_torch + elif ".1." in name: + self._altup_proj[1] = data_torch + elif ".2." in name: + self._altup_proj[2] = data_torch + else: + raise ValueError(f"Unknown name: {name}") + out = self._stack_matrices(self._altup_proj) + if out is not None: + return [(self.map_tensor_name("model.altup_projections.weight"), out)] + else: + return [] + + return super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Starcoder2ForCausalLM") class StarCoder2Model(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER2 diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 0429b0aaf..fb75143b0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -118,6 +118,10 @@ class Keys: EMBEDDING_SCALE = "{arch}.embedding_scale" TOKEN_SHIFT_COUNT = "{arch}.token_shift_count" INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step" + ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale" + ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx" + ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs" + EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -142,6 +146,8 @@ class Keys: SCALE = "{arch}.attention.scale" KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" + SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers" + SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -314,6 +320,7 @@ class MODEL_ARCH(IntEnum): GEMMA = auto() GEMMA2 = auto() GEMMA3 = auto() + GEMMA3N = auto() STARCODER2 = auto() RWKV6 = auto() RWKV6QWEN2 = auto() @@ -399,6 +406,22 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_NORM = auto() ATTN_K_NORM = auto() LAYER_OUT_NORM = auto() + PER_LAYER_TOKEN_EMBD = auto() # gemma3n + PER_LAYER_MODEL_PROJ = auto() # gemma3n + PER_LAYER_INP_GATE = auto() # gemma3n + PER_LAYER_PROJ = auto() # gemma3n + PER_LAYER_PROJ_NORM = auto() # gemma3n + PER_LAYER_POST_NORM = auto() # gemma3n + ALTUP_PROJ = auto() # gemma3n + ALTUP_UNEMBD_PROJ = auto() # gemma3n + ALTUP_CORRECT_COEF = auto() # gemma3n + ALTUP_CORRECT_SCALE = auto() # gemma3n + ALTUP_PREDICT_COEF = auto() # gemma3n + ALTUP_ROUTER = auto() # gemma3n + ALTUP_ROUTER_NORM = auto() # gemma3n + LAUREL_L = auto() # gemma3n + LAUREL_R = auto() # gemma3n + LAUREL_POST_NORM = auto() # gemma3n SSM_IN = auto() SSM_CONV1D = auto() SSM_X = auto() @@ -597,6 +620,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.GEMMA: "gemma", MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.GEMMA3N: "gemma3n", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", @@ -682,6 +706,22 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n + MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n + MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n + MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj", # gemma3n + MODEL_TENSOR.ALTUP_PROJ: "altup_proj", # gemma3n + MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", # gemma3n + MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", # gemma3n + MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", # gemma3n + MODEL_TENSOR.ALTUP_CORRECT_COEF: "blk.{bid}.altup_correct_coef", # gemma3n + MODEL_TENSOR.ALTUP_CORRECT_SCALE: "blk.{bid}.altup_correct_scale", # gemma3n + MODEL_TENSOR.ALTUP_PREDICT_COEF: "blk.{bid}.altup_predict_coef", # gemma3n + MODEL_TENSOR.ALTUP_ROUTER: "blk.{bid}.altup_router", # gemma3n + MODEL_TENSOR.ALTUP_ROUTER_NORM: "blk.{bid}.altup_router_norm", # gemma3n + MODEL_TENSOR.LAUREL_L: "blk.{bid}.laurel_l", # gemma3n + MODEL_TENSOR.LAUREL_R: "blk.{bid}.laurel_r", # gemma3n + MODEL_TENSOR.LAUREL_POST_NORM: "blk.{bid}.laurel_post_norm", # gemma3n MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", @@ -1486,6 +1526,41 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_PRE_NORM, MODEL_TENSOR.FFN_POST_NORM, ], + MODEL_ARCH.GEMMA3N: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + # altup / laurel + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD, + MODEL_TENSOR.PER_LAYER_MODEL_PROJ, + MODEL_TENSOR.PER_LAYER_INP_GATE, + MODEL_TENSOR.PER_LAYER_PROJ, + MODEL_TENSOR.PER_LAYER_PROJ_NORM, + MODEL_TENSOR.PER_LAYER_POST_NORM, + MODEL_TENSOR.ALTUP_PROJ, + MODEL_TENSOR.ALTUP_UNEMBD_PROJ, + MODEL_TENSOR.ALTUP_CORRECT_COEF, + MODEL_TENSOR.ALTUP_CORRECT_SCALE, + MODEL_TENSOR.ALTUP_PREDICT_COEF, + MODEL_TENSOR.ALTUP_ROUTER, + MODEL_TENSOR.ALTUP_ROUTER_NORM, + MODEL_TENSOR.LAUREL_L, + MODEL_TENSOR.LAUREL_R, + MODEL_TENSOR.LAUREL_POST_NORM, + ], MODEL_ARCH.STARCODER2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b9b63d052..d32cd479a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -672,6 +672,18 @@ class GGUFWriter: def add_decoder_start_token_id(self, id: int) -> None: self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id) + def add_embedding_length_per_layer_input(self, value: int) -> None: + self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value) + + def add_altup_active_idx(self, val: int) -> None: + self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val) + + def add_altup_num_inputs(self, val: int) -> None: + self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val) + + def add_activation_sparsity_scale(self, values: Sequence[float]) -> None: + self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values) + def add_head_count(self, count: int | Sequence[int]) -> None: if isinstance(count, int): self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) @@ -702,6 +714,12 @@ class GGUFWriter: def add_clamp_kqv(self, value: float) -> None: self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value) + def add_shared_kv_layers(self, value: float) -> None: + self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value) + + def add_sliding_window_pattern(self, value: Sequence[bool]) -> None: + self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value) + def add_logit_scale(self, value: float) -> None: self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 79f044d2a..b30f77dbe 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -480,6 +480,70 @@ class TensorNameMap: "encoder.layer.{bid}.layer_norm_2" # jina-v2-code ), + MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: ( + "model.embed_tokens_per_layer", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_MODEL_PROJ: ( + "model.per_layer_model_projection", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_PROJ_NORM: ( + "model.per_layer_projection_norm", # gemma3n + ), + + MODEL_TENSOR.ALTUP_PROJ: ( + "model.altup_projections", # gemma3n + ), + + MODEL_TENSOR.ALTUP_UNEMBD_PROJ: ( + "model.altup_unembed_projections", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_INP_GATE: ( + "model.layers.{bid}.per_layer_input_gate", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_PROJ: ( + "model.layers.{bid}.per_layer_projection", # gemma3n + ), + + MODEL_TENSOR.PER_LAYER_POST_NORM: ( + "model.layers.{bid}.post_per_layer_input_norm", # gemma3n + ), + + MODEL_TENSOR.ALTUP_CORRECT_COEF: ( + "model.layers.{bid}.altup.correction_coefs", # gemma3n + ), + + MODEL_TENSOR.ALTUP_CORRECT_SCALE: ( + "model.layers.{bid}.altup.correct_output_scale", # gemma3n + ), + + MODEL_TENSOR.ALTUP_PREDICT_COEF: ( + "model.layers.{bid}.altup.prediction_coefs", # gemma3n + ), + + MODEL_TENSOR.ALTUP_ROUTER: ( + "model.layers.{bid}.altup.modality_router", # gemma3n + ), + + MODEL_TENSOR.ALTUP_ROUTER_NORM: ( + "model.layers.{bid}.altup.router_norm", # gemma3n + ), + + MODEL_TENSOR.LAUREL_L: ( + "model.layers.{bid}.laurel.linear_left", # gemma3n + ), + + MODEL_TENSOR.LAUREL_R: ( + "model.layers.{bid}.laurel.linear_right", # gemma3n + ), + + MODEL_TENSOR.LAUREL_POST_NORM: ( + "model.layers.{bid}.laurel.post_laurel_norm", # gemma3n + ), + MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 8dadef204..435e3b9ba 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -42,6 +42,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA, "gemma" }, { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, + { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, { LLM_ARCH_XVERSE, "xverse" }, @@ -932,6 +933,42 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, }, }, + { + LLM_ARCH_GEMMA3N, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" }, + { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" }, + { LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" }, + { LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" }, + { LLM_TENSOR_ALTUP_PROJ, "altup_proj" }, + { LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" }, + { LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" }, + { LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" }, + { LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" }, + { LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" }, + { LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" }, + { LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" }, + { LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" }, + { LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" }, + { LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" }, + { LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" }, + }, + }, { LLM_ARCH_STARCODER2, { @@ -1749,6 +1786,23 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + // altup / laurel (gemma 3n) + {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 5b0230c15..9181ad053 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -46,6 +46,7 @@ enum llm_arch { LLM_ARCH_GEMMA, LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, + LLM_ARCH_GEMMA3N, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, LLM_ARCH_XVERSE, @@ -269,6 +270,22 @@ enum llm_tensor { LLM_TENSOR_LAYER_OUT_NORM, LLM_TENSOR_POST_ATTN_NORM, LLM_TENSOR_POST_MLP_NORM, + LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n + LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n + LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n + LLM_TENSOR_PER_LAYER_PROJ, // gemma3n + LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n + LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n + LLM_TENSOR_ALTUP_PROJ, // gemma3n + LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n + LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n + LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n + LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n + LLM_TENSOR_ALTUP_ROUTER, // gemma3n + LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n + LLM_TENSOR_LAUREL_L, // gemma3n + LLM_TENSOR_LAUREL_R, // gemma3n + LLM_TENSOR_LAUREL_POST_NORM, // gemma3n LLM_TENSOR_SSM_IN, LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 48589a50a..71ee431a9 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -350,6 +350,12 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { } } +void llm_graph_input_one::set_input(const llama_ubatch *) { + GGML_ASSERT(one && ggml_nelements(one) == 1); + float f_one = 1.0f; + ggml_backend_tensor_set(one, &f_one, 0, sizeof(float)); +} + // // llm_graph_context // @@ -1267,8 +1273,14 @@ ggml_tensor * llm_graph_context::build_attn( // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); - ggml_build_forward_expand(gf, k_cur); - ggml_build_forward_expand(gf, v_cur); + + if (k_cur) { + ggml_build_forward_expand(gf, k_cur); + } + + if (v_cur) { + ggml_build_forward_expand(gf, v_cur); + } const auto * mctx_iswa = static_cast(mctx); @@ -1276,9 +1288,12 @@ ggml_tensor * llm_graph_context::build_attn( const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); - // store to KV cache - { + // optionally store to KV cache + if (k_cur) { ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il)); + } + + if (v_cur) { ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il)); } diff --git a/src/llama-graph.h b/src/llama-graph.h index b433f266d..4b1ec354d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -329,6 +329,17 @@ public: const llama_memory_hybrid_context * mctx; }; +// TODO: remove this when ggml_scale_add is implemented +class llm_graph_input_one : public llm_graph_input_i { +public: + llm_graph_input_one() {} + virtual ~llm_graph_input_one() = default; + + void set_input(const llama_ubatch *) override; + + ggml_tensor * one = nullptr; // F32 +}; + // // llm_graph_result // @@ -589,14 +600,15 @@ struct llm_graph_context { llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const; + // note: if k_cur or v_cur are not provided, they will not be stored in the memory ggml_tensor * build_attn( llm_graph_input_attn_kv_unified_iswa * inp, ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] - ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional ggml_tensor * kq_b, ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 7b315a9a7..e85afe145 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -143,6 +143,12 @@ struct llama_hparams { uint32_t n_attn_temp_floor_scale = 8192; float f_attn_temp_scale = 0.1; + // gemma3n altup + uint32_t n_altup = 4; // altup_num_inputs + uint32_t i_altup_act = 0; // altup_active_idx + uint32_t laurel_rank = 64; + uint32_t n_embd_altup = 256; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index b506d32ed..8517b722a 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified( GGML_ASSERT(kv_size % n_pad == 0); + // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] + auto n_layer_cache = hparams.n_layer; + if (model.arch == LLM_ARCH_GEMMA3N) { + n_layer_cache = 20; + } + // create a context for each buffer type std::map ctx_map; auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( cells.resize(kv_size); - for (uint32_t il = 0; il < hparams.n_layer; il++) { + for (uint32_t il = 0; il < n_layer_cache; il++) { if (filter && !filter(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); continue; @@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified( layers.push_back({ il, k, v }); } + // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] + if (model.arch == LLM_ARCH_GEMMA3N) { + LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1); + + for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) { + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + continue; + } + + const bool is_swa = hparams.is_swa(il); + const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1); + + GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end()); + map_layer_ids[il] = map_layer_ids[il_reuse]; + + LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa); + } + } + // allocate tensors and initialize the buffers to avoid NaNs in the padding for (auto it : ctx_map) { auto * buft = it.first; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c2835ce67..fc39195ed 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -103,6 +103,8 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; + case LLM_TYPE_E2B: return "E2B"; + case LLM_TYPE_E4B: return "E4B"; default: return "?B"; } } @@ -1017,6 +1019,24 @@ void llama_model::load_hparams(llama_model_loader & ml) { ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); } break; + case LLM_ARCH_GEMMA3N: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(5); + + hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_scale_train_swa = 1.0f; + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_E2B; break; + case 35: type = LLM_TYPE_E4B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2950,6 +2970,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GEMMA3N: + { + const int64_t n_altup = hparams.n_altup; + const int64_t laurel_rank = hparams.laurel_rank; + const int64_t n_embd_altup = hparams.n_embd_altup; + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + + altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // altup & laurel + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); + layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); + layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); + layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); + layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); + layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); + layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); + layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8980,6 +9056,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context { } }; +struct llm_build_gemma3n_iswa : public llm_graph_context { + const llama_model & model; + ggml_cgraph * gf; + + const int64_t n_embd_head; + const int64_t n_embd_altup; + const int64_t n_altup; + const int i_altup_act; + const int n_layer_kv = 20; // number of layers having KV [KV_REUSE] + const int n_layer_sparsity = 10; // number of layers using activation sparsity + const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) + + ggml_tensor * one; // containing single element 1.0f + + llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) + : llm_graph_context(params), + model(model), + gf(gf), + n_embd_head(model.hparams.n_embd_head_k), + n_embd_altup(model.hparams.n_embd_altup), + n_altup(model.hparams.n_altup), + i_altup_act(model.hparams.i_altup_act) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // TODO: remove this when ggml_scale_add is implemented + one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + { + auto inp = std::make_unique(); + inp->one = one; + res->add_input(std::move(inp)); + } + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + if (ubatch.token) { + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + } + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: is causal == true correct? might need some changes + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] + ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); + + // inpL now has only 1 altup, project it to the rest of the altups + // these "added" altups will be concat to the last dim of inpL + { + ggml_tensor * target_magnitude = calc_magnitude(inpL); + ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1); + ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1] + ggml_tensor * new_magnitude = calc_magnitude(altup_added); + altup_added = ggml_div(ctx0, + ggml_mul(ctx0, altup_added, target_magnitude), + new_magnitude); + inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] + cb(inpL, "inp_stacked", -1); + } + + // inpL now has shape: [n_embd, n_tokens, n_altup] + // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + + for (int il = 0; il < n_layer; ++il) { + // this block is made to be closely resemble Gemma3p5DecoderLayer on python code + const bool has_kv = (il < n_layer_kv); + + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup] + ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup] + + // predicted value will go through self-attention and laurel + ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens] + cur = active_prediction; + cb(cur, "active_prediction", il); + + // norm + cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // laurel + ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens] + + // self-attention + if (has_kv) { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); + + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + cb(Vcur, "Vcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); + } else { + // no KV layers + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens] + cb(cur, "attn_gated", il); + + ggml_tensor * attn_laurel = ggml_scale(ctx0, + ggml_add(ctx0, cur, laurel_out), + 1.0f / sqrtf(2.0f)); // [n_embd, n_tokens] + cb(attn_laurel, "attn_laurel", il); + + cur = build_norm(attn_laurel, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur); + ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur); + + if (il < n_layer_sparsity) { + // apply activation sparsity + gate_proj = gaussian_topk(gate_proj); + } + gate_proj = ggml_gelu(ctx0, gate_proj); + + cur = ggml_mul(ctx0, up_proj, gate_proj); + cur = build_lora_mm(model.layers[il].ffn_down, cur); + cb(cur, "ffn_out", il); + } + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens] + cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il); + + ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup] + + ggml_tensor * first_prediction; // [n_embd, n_tokens] + { + first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] + first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); + first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); + first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] + cb(first_prediction, "first_prediction_gated", il); + ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens] + first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens] + cb(first_prediction, "first_prediction_scaled", il); + + first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens] + first_prediction = build_norm(first_prediction, + model.layers[il].per_layer_post_norm, NULL, + LLM_NORM_RMS, il); + cb(first_prediction, "first_prediction_out", il); + } + + // equivalent to python code: corrected_predictions[1:] += first_prediction + { + ggml_tensor * slice_first = view_2d_slice(corrected, 0); + ggml_tensor * slice_rest = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1, + ggml_row_size(corrected->type, n_embd), + ggml_row_size(corrected->type, n_embd*n_tokens), + n_embd*n_tokens*ggml_element_size(corrected)); + ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1] + corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup] + } + + cur = corrected; // [n_embd, n_tokens, n_altup] + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; // [n_embd, n_tokens, n_altup] + + // cur now has multiple altup(s), we want to merge them back to 1 altup + { + ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens] + // do a view to skip the first slice (active altup) + ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, + ggml_row_size(cur->type, n_embd), + ggml_row_size(cur->type, n_embd*n_tokens), + n_embd*n_tokens*ggml_element_size(cur)); + ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1] + ggml_tensor * new_magnitude = calc_magnitude(altup_unembd); + altup_unembd = ggml_div(ctx0, + ggml_mul(ctx0, altup_unembd, target_magnitude), + new_magnitude); + cb(altup_unembd, "altup_unembd", -1); + + // equivalent to torch.mean(hidden_states, dim=0) + cur = view_2d_slice(cur, 0); // [n_embd, n_tokens] + for (int i = 0; i < n_altup - 1; ++i) { + cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i)); + } + cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens] + cb(cur, "unembd_merged", -1); + } + + // cur now has shape: [n_embd, n_tokens] + + // TODO: move this to right after the last KV layer + { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + + { + // final logit soft-capping + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * calc_magnitude(ggml_tensor * x) { + return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); + } + + // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim + ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int)x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], + ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); + } + + // equivalent to get_per_layer_inputs() in python code + // output shape: [n_embd_altup, n_layer, n_tokens] + ggml_tensor * get_per_layer_inputs() { + auto inp = std::make_unique(); + ggml_tensor * inp_per_layer; + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_input(inp->tokens); + res->t_tokens = inp->tokens; + inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup)); + cb(inp_per_layer, "inp_per_layer_selected", -1); + } else { + GGML_ABORT("TODO: support embd input"); + } + res->add_input(std::move(inp)); + return inp_per_layer; + } + + // equivalent to project_per_layer_inputs() in python code + // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim + // output shape: [n_embd_altup, n_tokens, n_layer] + ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { + const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd); + const float per_layer_input_scale = 1.0f / sqrtf(2.0f); + + ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds); + per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); + per_layer_proj = build_norm(per_layer_proj, + model.per_layer_proj_norm, NULL, + LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens] + cb(per_layer_proj, "per_layer_proj", -1); + + inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); + cb(inp_per_layer, "inp_per_layer", -1); + + // permute to shape: [n_embd_altup, n_tokens, n_layer] + inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); + return inp_per_layer; + } + + // input cur shape: [n_altup, n_tokens] + // output shape: [n_altup, n_tokens] + ggml_tensor * laurel(ggml_tensor * cur, int il) { + ggml_tensor * tmp = cur; + tmp = build_lora_mm(model.layers[il].laurel_l, tmp); + tmp = build_lora_mm(model.layers[il].laurel_r, tmp); + tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il); + tmp = ggml_add(ctx0, tmp, cur); + cb(tmp, "laurel_out", il); + return tmp; + } + + // input x shape: [n_embd, n_tokens] + // output shape: [n_embd, n_tokens] + ggml_tensor * gaussian_topk(ggml_tensor * x) { + ggml_tensor * mean = ggml_mean(ctx0, x); + ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, + ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))), + 1.0f / (float)(x->ne[0] - 1) + )); + ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul)); + return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x)); + } + + // + // altup functions + // + + // equivalent to compute_router_modalities() in python code + // input x shape: [n_embd, n_tokens] + // output shape: [n_altup, n_tokens] + ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) { + ggml_tensor * router_inputs = build_norm(x, + model.layers[il].altup_router_norm, NULL, + LLM_NORM_RMS, il); + + // router_input_scale + router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd); + + ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs); + return ggml_tanh(ctx0, output); // [n_altup, n_tokens] + } + + // input cur shape: [n_embd, n_tokens, n_altup] + // output shape: [n_embd, n_tokens, n_altup] + ggml_tensor * altup_predict(ggml_tensor * cur, int il) { + ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens] + ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] + cb(modalities, "modalities", il); + + ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities); + cb(all_coefs, "all_coefs", il); + // first dim now having n_altup^2 elements, we reshape it to 2D (so we end up with 3D tensor) + all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens); + + // permute to [n_altup, n_embd, n_tokens] + ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3)); + ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens] + + // final shape must be the same as cur: [n_embd, n_tokens, n_altup] + predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3)); + predictions = ggml_add(ctx0, predictions, cur); + cb(predictions, "predictions", il); + + return predictions; + } + + // input predictions shape: [n_embd, n_tokens, n_altup] + // input activated shape: [n_embd, n_tokens] + // output shape: [n_embd, n_tokens, n_altup] + ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { + ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] + cb(modalities, "modalities", il); + + ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); + ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] + cb(innovation, "innovation", il); + + ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens] + all_coefs = ggml_add(ctx0, all_coefs, one); + cb(all_coefs, "all_coefs", il); + all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup] + all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup] + + innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1); + ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup] + corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup] + cb(corrected, "corrected", il); + + return corrected; + } +}; + // TODO: move up next to build_starcoder struct llm_build_starcoder2 : public llm_graph_context { llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { @@ -13974,6 +14486,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_GEMMA3N: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_STARCODER2: { llm = std::make_unique(*this, params, gf); @@ -14295,6 +14811,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA3N: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: diff --git a/src/llama-model.h b/src/llama-model.h index 06e6c6879..40063b790 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -95,6 +95,8 @@ enum llm_type { LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_30B_A3B, LLM_TYPE_235B_A22B, + LLM_TYPE_E2B, + LLM_TYPE_E4B, }; std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); @@ -316,6 +318,19 @@ struct llama_layer { struct ggml_tensor * ffn_up_scale = nullptr; struct ggml_tensor * ffn_down_scale = nullptr; + // altup & laurel + struct ggml_tensor * per_layer_inp_gate = nullptr; + struct ggml_tensor * per_layer_proj = nullptr; + struct ggml_tensor * per_layer_post_norm = nullptr; + struct ggml_tensor * altup_correct_coef = nullptr; + struct ggml_tensor * altup_correct_scale = nullptr; + struct ggml_tensor * altup_predict_coef = nullptr; + struct ggml_tensor * altup_router = nullptr; + struct ggml_tensor * altup_router_norm = nullptr; + struct ggml_tensor * laurel_l = nullptr; + struct ggml_tensor * laurel_r = nullptr; + struct ggml_tensor * laurel_post_norm = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -354,6 +369,13 @@ struct llama_model { struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; + // gemma3n altup + struct ggml_tensor * tok_embd_per_layer = nullptr; + struct ggml_tensor * altup_proj = nullptr; + struct ggml_tensor * altup_unembd_proj = nullptr; + struct ggml_tensor * per_layer_model_proj = nullptr; + struct ggml_tensor * per_layer_proj_norm = nullptr; + std::vector layers; llama_model_params params; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 43229e193..f4b5713d7 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -223,7 +223,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = GGML_TYPE_Q6_K; } } - } else if (name == "token_embd.weight") { + } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") { if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { new_type = qs.params->token_embedding_type; } else { @@ -830,6 +830,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // NOTE: can't use LLM_TN here because the layer number is not known quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + // these are very small (e.g. 4x4) + quantize &= name.find("altup") == std::string::npos; + quantize &= name.find("laurel") == std::string::npos; + + // these are not too big so keep them as it is + quantize &= name.find("per_layer_model_proj") == std::string::npos; + // do not quantize positional embeddings and token types (BERT) quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");