model : gemma3n text-only (#14400)

* gemma3n

* add llm_graph_input_one
Xuan-Son Nguyen
2025-06-26 19:34:02 +02:00
committed by GitHub
parent a01047b041
commit 8846aace49
13 changed files with 960 additions and 15 deletions

View File

@ -310,6 +310,8 @@ class ModelBase:
gguf.MODEL_TENSOR.POSNET_NORM2,
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
)
)
or not new_name.endswith(".weight")
@ -320,7 +322,11 @@ class ModelBase:
self.match_model_tensor_name(new_name, key, bid)
for key in (
gguf.MODEL_TENSOR.TOKEN_EMBD,
gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
gguf.MODEL_TENSOR.OUTPUT,
gguf.MODEL_TENSOR.ALTUP_ROUTER,
gguf.MODEL_TENSOR.LAUREL_L,
gguf.MODEL_TENSOR.LAUREL_R,
)
):
if self.ftype in (
@ -921,13 +927,16 @@ class TextModel(ModelBase):
tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))
-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+        vocab_size = self.find_hparam([
+            "vocab_size_per_layer_input", # gemma3n
+            "vocab_size",
+        ], optional=True) or tokenizer.vocab_size()
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
-        for token_id in range(tokenizer.vocab_size()):
+        for token_id in range(vocab_size):
piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8")
score = tokenizer.GetScore(token_id)
@ -942,6 +951,10 @@ class TextModel(ModelBase):
elif tokenizer.IsByte(token_id):
toktype = SentencePieceTokenTypes.BYTE
if token_id >= vocab_size:
logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
break
tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
@ -4217,6 +4230,7 @@ class Gemma2Model(TextModel):
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
class Gemma3Model(TextModel):
model_arch = gguf.MODEL_ARCH.GEMMA3
norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
def set_vocab(self):
self._set_vocab_sentencepiece()
@ -4238,9 +4252,8 @@ class Gemma3Model(TextModel):
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
-        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        assert hparams.get("final_logit_softcapping") is None
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
if hparams.get("rope_scaling") is not None:
@ -4252,7 +4265,7 @@ class Gemma3Model(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.startswith("language_model."):
if "language_model." in name:
name = name.replace("language_model.", "")
elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
@ -4267,8 +4280,9 @@ class Gemma3Model(TextModel):
# ref code in Gemma3RMSNorm
# output = output * (1.0 + self.weight.float())
# note: this is not the case on gemma3n
if name.endswith("norm.weight"):
-            data_torch = data_torch + 1
+            data_torch = data_torch + self.norm_shift
return [(self.map_tensor_name(name), data_torch)]
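Gemma3RMSNorm applies its weight as `output * (1.0 + weight)`, while llama.cpp's RMS norm multiplies by the stored weight directly, so the converter bakes `norm_shift` into the tensor at conversion time. A minimal sketch of the equivalence (hypothetical tensors):

```python
# Why the shift is folded in: llama.cpp evaluates rms_norm(x) * w, so the
# +1.0 used by Gemma3RMSNorm must already be baked into the stored weight.
import torch

def gemma3_rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return out * (1.0 + weight)          # Gemma3: norm_shift = 1.0

def runtime_rmsnorm(x: torch.Tensor, stored_weight: torch.Tensor, eps: float = 1e-6):
    out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return out * stored_weight           # what llama.cpp evaluates

x = torch.randn(2, 8)
w = torch.randn(8)
assert torch.allclose(gemma3_rmsnorm_ref(x, w), runtime_rmsnorm(x, w + 1.0))
```

For gemma3n the shift is 0.0, so the weight is stored unmodified.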
@ -4325,6 +4339,104 @@ class Gemma3VisionModel(MmprojModel):
return [] # skip other tensors
@ModelBase.register("Gemma3nForConditionalGeneration")
class Gemma3NModel(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA3N
norm_shift = 0.0 # same value as the scale_shift in the Python Gemma3p5RMSNorm implementation
_altup_proj: list[Tensor] = []
_altup_unembd: list[Tensor] = []
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
self._altup_proj = [
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
]
self._altup_unembd = [
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
torch.Tensor(), # to be replaced
]
def set_vocab(self):
with open(self.dir_model / "chat_template.jinja") as f:
# quick hack to make sure chat template is added
self.gguf_writer.add_chat_template(f.read())
super().set_vocab()
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
activation_sparsity_scale = []
for s in self.hparams["activation_sparsity_pattern"]:
normal_dist = torch.distributions.normal.Normal(0, 1)
std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
activation_sparsity_scale.append(std_multiplier.item())
self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
sliding_window_pattern = []
for t in self.hparams["layer_types"]:
sliding_window_pattern.append(t == "sliding_attention")
self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
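The stored activation-sparsity scale is the standard-normal z-score for the configured sparsity fraction, computed once at conversion so the runtime only needs `mean + z * std`. A quick check of the values involved (0.95 for gemma3n's first layers; 0.0 disables the cutoff):

```python
# The per-layer sparsity fraction s is converted to a z-score via the
# inverse CDF of a standard normal distribution.
import torch

normal = torch.distributions.normal.Normal(0, 1)
for s in (0.95, 0.0):
    z = normal.icdf(torch.tensor(s, dtype=torch.float32)).item()
    print(f"sparsity {s:.2f} -> std_multiplier {z:.10f}")
# sparsity 0.95 -> std_multiplier 1.6448533535  (f_sparsity_std_mul in llama-model.cpp)
# sparsity 0.00 -> std_multiplier -inf          (cutoff disabled, keeps everything)
```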
def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
has_all = all(m.numel() > 0 for m in matrices)
if not has_all:
return None
else:
return torch.stack(matrices, dim=0)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.endswith("_scale"):
name = name + ".weight"
# TODO: implement self.prediction_coefs.weight.clamp_(...)
if "language_model." not in name:
return [] # skip non-language model tensors
if "altup_unembed_projections" in name:
data_torch = data_torch.to(device="cpu")
if ".0." in name:
self._altup_unembd[0] = data_torch
elif ".1." in name:
self._altup_unembd[1] = data_torch
elif ".2." in name:
self._altup_unembd[2] = data_torch
else:
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_unembd)
if out is not None:
return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
else:
return []
if "altup_projections" in name:
data_torch = data_torch.to(device="cpu")
if ".0." in name:
self._altup_proj[0] = data_torch
elif ".1." in name:
self._altup_proj[1] = data_torch
elif ".2." in name:
self._altup_proj[2] = data_torch
else:
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_proj)
if out is not None:
return [(self.map_tensor_name("model.altup_projections.weight"), out)]
else:
return []
return super().modify_tensors(data_torch, name, bid)
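`altup_projections.{0,1,2}` (and likewise `altup_unembed_projections`) arrive as three separate `[n_embd, n_embd]` tensors in no guaranteed order, so `modify_tensors` buffers them and emits a single stacked tensor only once all slots are filled. A toy illustration of that buffering:

```python
# Tensors may arrive in any order; nothing is emitted until all three
# slots hold real data (numel() > 0).
import torch

slots = [torch.Tensor(), torch.Tensor(), torch.Tensor()]   # empty placeholders

def offer(idx: int, t: torch.Tensor):
    slots[idx] = t
    if all(s.numel() > 0 for s in slots):
        return torch.stack(slots, dim=0)    # [3, n_embd, n_embd]
    return None

assert offer(2, torch.randn(4, 4)) is None
assert offer(0, torch.randn(4, 4)) is None
stacked = offer(1, torch.randn(4, 4))
assert stacked is not None and stacked.shape == (3, 4, 4)
```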
@ModelBase.register("Starcoder2ForCausalLM")
class StarCoder2Model(TextModel):
model_arch = gguf.MODEL_ARCH.STARCODER2

View File

@ -118,6 +118,10 @@ class Keys:
EMBEDDING_SCALE = "{arch}.embedding_scale"
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"
@ -142,6 +146,8 @@ class Keys:
SCALE = "{arch}.attention.scale"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
@ -314,6 +320,7 @@ class MODEL_ARCH(IntEnum):
GEMMA = auto()
GEMMA2 = auto()
GEMMA3 = auto()
GEMMA3N = auto()
STARCODER2 = auto()
RWKV6 = auto()
RWKV6QWEN2 = auto()
@ -399,6 +406,22 @@ class MODEL_TENSOR(IntEnum):
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
LAYER_OUT_NORM = auto()
PER_LAYER_TOKEN_EMBD = auto() # gemma3n
PER_LAYER_MODEL_PROJ = auto() # gemma3n
PER_LAYER_INP_GATE = auto() # gemma3n
PER_LAYER_PROJ = auto() # gemma3n
PER_LAYER_PROJ_NORM = auto() # gemma3n
PER_LAYER_POST_NORM = auto() # gemma3n
ALTUP_PROJ = auto() # gemma3n
ALTUP_UNEMBD_PROJ = auto() # gemma3n
ALTUP_CORRECT_COEF = auto() # gemma3n
ALTUP_CORRECT_SCALE = auto() # gemma3n
ALTUP_PREDICT_COEF = auto() # gemma3n
ALTUP_ROUTER = auto() # gemma3n
ALTUP_ROUTER_NORM = auto() # gemma3n
LAUREL_L = auto() # gemma3n
LAUREL_R = auto() # gemma3n
LAUREL_POST_NORM = auto() # gemma3n
SSM_IN = auto()
SSM_CONV1D = auto()
SSM_X = auto()
@ -597,6 +620,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2",
@ -682,6 +706,22 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: "per_layer_model_proj", # gemma3n
MODEL_TENSOR.PER_LAYER_PROJ_NORM: "per_layer_proj_norm", # gemma3n
MODEL_TENSOR.ALTUP_UNEMBD_PROJ: "altup_unembd_proj", # gemma3n
MODEL_TENSOR.ALTUP_PROJ: "altup_proj", # gemma3n
MODEL_TENSOR.PER_LAYER_INP_GATE: "blk.{bid}.inp_gate", # gemma3n
MODEL_TENSOR.PER_LAYER_PROJ: "blk.{bid}.proj", # gemma3n
MODEL_TENSOR.PER_LAYER_POST_NORM: "blk.{bid}.post_norm", # gemma3n
MODEL_TENSOR.ALTUP_CORRECT_COEF: "blk.{bid}.altup_correct_coef", # gemma3n
MODEL_TENSOR.ALTUP_CORRECT_SCALE: "blk.{bid}.altup_correct_scale", # gemma3n
MODEL_TENSOR.ALTUP_PREDICT_COEF: "blk.{bid}.altup_predict_coef", # gemma3n
MODEL_TENSOR.ALTUP_ROUTER: "blk.{bid}.altup_router", # gemma3n
MODEL_TENSOR.ALTUP_ROUTER_NORM: "blk.{bid}.altup_router_norm", # gemma3n
MODEL_TENSOR.LAUREL_L: "blk.{bid}.laurel_l", # gemma3n
MODEL_TENSOR.LAUREL_R: "blk.{bid}.laurel_r", # gemma3n
MODEL_TENSOR.LAUREL_POST_NORM: "blk.{bid}.laurel_post_norm", # gemma3n
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
@ -1486,6 +1526,41 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.GEMMA3N: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
# altup / laurel
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
MODEL_TENSOR.PER_LAYER_MODEL_PROJ,
MODEL_TENSOR.PER_LAYER_INP_GATE,
MODEL_TENSOR.PER_LAYER_PROJ,
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
MODEL_TENSOR.ALTUP_PROJ,
MODEL_TENSOR.ALTUP_UNEMBD_PROJ,
MODEL_TENSOR.ALTUP_CORRECT_COEF,
MODEL_TENSOR.ALTUP_CORRECT_SCALE,
MODEL_TENSOR.ALTUP_PREDICT_COEF,
MODEL_TENSOR.ALTUP_ROUTER,
MODEL_TENSOR.ALTUP_ROUTER_NORM,
MODEL_TENSOR.LAUREL_L,
MODEL_TENSOR.LAUREL_R,
MODEL_TENSOR.LAUREL_POST_NORM,
],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View File

@ -672,6 +672,18 @@ class GGUFWriter:
def add_decoder_start_token_id(self, id: int) -> None:
self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
def add_embedding_length_per_layer_input(self, value: int) -> None:
self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)
def add_altup_active_idx(self, val: int) -> None:
self.add_uint32(Keys.LLM.ALTUP_ACTIVE_IDX.format(arch=self.arch), val)
def add_altup_num_inputs(self, val: int) -> None:
self.add_uint32(Keys.LLM.ALTUP_NUM_INPUTS.format(arch=self.arch), val)
def add_activation_sparsity_scale(self, values: Sequence[float]) -> None:
self.add_array(Keys.LLM.ACTIVATION_SPARSITY_SCALE.format(arch=self.arch), values)
def add_head_count(self, count: int | Sequence[int]) -> None:
if isinstance(count, int):
self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count)
@ -702,6 +714,12 @@ class GGUFWriter:
def add_clamp_kqv(self, value: float) -> None:
self.add_float32(Keys.Attention.CLAMP_KQV.format(arch=self.arch), value)
def add_shared_kv_layers(self, value: float) -> None:
self.add_float32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
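Each writer helper formats `self.arch` into the key template from `Keys`, so for this model the new metadata lands under architecture-prefixed keys. A small sketch resolving the templates added above:

```python
# Resolving the GGUF key templates for arch = "gemma3n".
ALTUP_ACTIVE_IDX       = "{arch}.altup.active_idx"
SHARED_KV_LAYERS       = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"

for key in (ALTUP_ACTIVE_IDX, SHARED_KV_LAYERS, SLIDING_WINDOW_PATTERN):
    print(key.format(arch="gemma3n"))
# gemma3n.altup.active_idx
# gemma3n.attention.shared_kv_layers
# gemma3n.attention.sliding_window_pattern
```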
def add_logit_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

View File

@ -480,6 +480,70 @@ class TensorNameMap:
"encoder.layer.{bid}.layer_norm_2" # jina-v2-code
),
MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: (
"model.embed_tokens_per_layer", # gemma3n
),
MODEL_TENSOR.PER_LAYER_MODEL_PROJ: (
"model.per_layer_model_projection", # gemma3n
),
MODEL_TENSOR.PER_LAYER_PROJ_NORM: (
"model.per_layer_projection_norm", # gemma3n
),
MODEL_TENSOR.ALTUP_PROJ: (
"model.altup_projections", # gemma3n
),
MODEL_TENSOR.ALTUP_UNEMBD_PROJ: (
"model.altup_unembed_projections", # gemma3n
),
MODEL_TENSOR.PER_LAYER_INP_GATE: (
"model.layers.{bid}.per_layer_input_gate", # gemma3n
),
MODEL_TENSOR.PER_LAYER_PROJ: (
"model.layers.{bid}.per_layer_projection", # gemma3n
),
MODEL_TENSOR.PER_LAYER_POST_NORM: (
"model.layers.{bid}.post_per_layer_input_norm", # gemma3n
),
MODEL_TENSOR.ALTUP_CORRECT_COEF: (
"model.layers.{bid}.altup.correction_coefs", # gemma3n
),
MODEL_TENSOR.ALTUP_CORRECT_SCALE: (
"model.layers.{bid}.altup.correct_output_scale", # gemma3n
),
MODEL_TENSOR.ALTUP_PREDICT_COEF: (
"model.layers.{bid}.altup.prediction_coefs", # gemma3n
),
MODEL_TENSOR.ALTUP_ROUTER: (
"model.layers.{bid}.altup.modality_router", # gemma3n
),
MODEL_TENSOR.ALTUP_ROUTER_NORM: (
"model.layers.{bid}.altup.router_norm", # gemma3n
),
MODEL_TENSOR.LAUREL_L: (
"model.layers.{bid}.laurel.linear_left", # gemma3n
),
MODEL_TENSOR.LAUREL_R: (
"model.layers.{bid}.laurel.linear_right", # gemma3n
),
MODEL_TENSOR.LAUREL_POST_NORM: (
"model.layers.{bid}.laurel.post_laurel_norm", # gemma3n
),
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",

View File

@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA, "gemma" },
{ LLM_ARCH_GEMMA2, "gemma2" },
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_XVERSE, "xverse" },
@ -932,6 +933,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
},
},
{
LLM_ARCH_GEMMA3N,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
{ LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
{ LLM_TENSOR_PER_LAYER_PROJ_NORM, "per_layer_proj_norm" },
{ LLM_TENSOR_ALTUP_UNEMBD_PROJ, "altup_unembd_proj" },
{ LLM_TENSOR_ALTUP_PROJ, "altup_proj" },
{ LLM_TENSOR_PER_LAYER_INP_GATE, "blk.%d.inp_gate" },
{ LLM_TENSOR_PER_LAYER_PROJ, "blk.%d.proj" },
{ LLM_TENSOR_PER_LAYER_POST_NORM, "blk.%d.post_norm" },
{ LLM_TENSOR_ALTUP_CORRECT_COEF, "blk.%d.altup_correct_coef" },
{ LLM_TENSOR_ALTUP_CORRECT_SCALE, "blk.%d.altup_correct_scale" },
{ LLM_TENSOR_ALTUP_PREDICT_COEF, "blk.%d.altup_predict_coef" },
{ LLM_TENSOR_ALTUP_ROUTER, "blk.%d.altup_router" },
{ LLM_TENSOR_ALTUP_ROUTER_NORM, "blk.%d.altup_router_norm" },
{ LLM_TENSOR_LAUREL_L, "blk.%d.laurel_l" },
{ LLM_TENSOR_LAUREL_R, "blk.%d.laurel_r" },
{ LLM_TENSOR_LAUREL_POST_NORM, "blk.%d.laurel_post_norm" },
},
},
{
LLM_ARCH_STARCODER2,
{
@ -1749,6 +1786,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
// altup / laurel (gemma 3n)
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_PER_LAYER_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_CORRECT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_CORRECT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_ALTUP_PREDICT_COEF, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ALTUP_ROUTER_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_LAUREL_L, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LAUREL_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
// this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},

View File

@ -46,6 +46,7 @@ enum llm_arch {
LLM_ARCH_GEMMA,
LLM_ARCH_GEMMA2,
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_XVERSE,
@ -269,6 +270,22 @@ enum llm_tensor {
LLM_TENSOR_LAYER_OUT_NORM,
LLM_TENSOR_POST_ATTN_NORM,
LLM_TENSOR_POST_MLP_NORM,
LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
LLM_TENSOR_PER_LAYER_INP_GATE, // gemma3n
LLM_TENSOR_PER_LAYER_PROJ, // gemma3n
LLM_TENSOR_PER_LAYER_PROJ_NORM, // gemma3n
LLM_TENSOR_PER_LAYER_POST_NORM, // gemma3n
LLM_TENSOR_ALTUP_PROJ, // gemma3n
LLM_TENSOR_ALTUP_UNEMBD_PROJ, // gemma3n
LLM_TENSOR_ALTUP_CORRECT_COEF, // gemma3n
LLM_TENSOR_ALTUP_CORRECT_SCALE, // gemma3n
LLM_TENSOR_ALTUP_PREDICT_COEF, // gemma3n
LLM_TENSOR_ALTUP_ROUTER, // gemma3n
LLM_TENSOR_ALTUP_ROUTER_NORM, // gemma3n
LLM_TENSOR_LAUREL_L, // gemma3n
LLM_TENSOR_LAUREL_R, // gemma3n
LLM_TENSOR_LAUREL_POST_NORM, // gemma3n
LLM_TENSOR_SSM_IN,
LLM_TENSOR_SSM_CONV1D,
LLM_TENSOR_SSM_X,

View File

@ -350,6 +350,12 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_one::set_input(const llama_ubatch *) {
GGML_ASSERT(one && ggml_nelements(one) == 1);
float f_one = 1.0f;
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
}
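`llm_graph_input_one` exists only to feed the literal constant 1.0f into the graph; it is consumed by `altup_correct` below as `coefs + 1`, since ggml has no fused scale-add op yet. A numpy sketch of the broadcast it enables:

```python
# ggml_add broadcasts the 1-element tensor across all_coefs,
# i.e. coefs + 1.0 without needing a new ggml op.
import numpy as np

one = np.array([1.0], dtype=np.float32)                # what set_input() writes
all_coefs = np.random.rand(4, 7).astype(np.float32)   # [n_altup, n_tokens]
assert np.allclose(all_coefs + one, all_coefs + 1.0)
```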
//
// llm_graph_context
//
@ -1267,8 +1273,14 @@ ggml_tensor * llm_graph_context::build_attn(
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
if (k_cur) {
ggml_build_forward_expand(gf, k_cur);
}
if (v_cur) {
ggml_build_forward_expand(gf, v_cur);
}
const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
@ -1276,9 +1288,12 @@ ggml_tensor * llm_graph_context::build_attn(
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
-    // store to KV cache
-    {
+    // optionally store to KV cache
+    if (k_cur) {
         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+    }
+    if (v_cur) {
         ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }

View File

@ -329,6 +329,17 @@ public:
const llama_memory_hybrid_context * mctx;
};
// TODO: remove this when ggml_scale_add is implemented
class llm_graph_input_one : public llm_graph_input_i {
public:
llm_graph_input_one() {}
virtual ~llm_graph_input_one() = default;
void set_input(const llama_ubatch *) override;
ggml_tensor * one = nullptr; // F32
};
//
// llm_graph_result
//
@ -589,14 +600,15 @@ struct llm_graph_context {
llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-        ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-        ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+        ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+        ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,

View File

@ -143,6 +143,12 @@ struct llama_hparams {
uint32_t n_attn_temp_floor_scale = 8192;
float f_attn_temp_scale = 0.1;
// gemma3n altup
uint32_t n_altup = 4; // altup_num_inputs
uint32_t i_altup_act = 0; // altup_active_idx
uint32_t laurel_rank = 64;
uint32_t n_embd_altup = 256;
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;

View File

@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
GGML_ASSERT(kv_size % n_pad == 0);
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
auto n_layer_cache = hparams.n_layer;
if (model.arch == LLM_ARCH_GEMMA3N) {
n_layer_cache = 20;
}
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
-                /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()),
+                /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
cells.resize(kv_size);
-    for (uint32_t il = 0; il < hparams.n_layer; il++) {
+    for (uint32_t il = 0; il < n_layer_cache; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
continue;
@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
layers.push_back({ il, k, v });
}
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
if (model.arch == LLM_ARCH_GEMMA3N) {
LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
continue;
}
const bool is_swa = hparams.is_swa(il);
const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
map_layer_ids[il] = map_layer_ids[il_reuse];
LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
}
}
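Only the first 20 layers own KV storage; every later layer is mapped onto the last cached layer of the same attention kind. A sketch of the mapping, assuming `set_swa_pattern(5)` marks every 5th layer as full attention (so layer 19 is full, layer 18 is SWA):

```python
# Reuse mapping for gemma3n E4B (35 layers, 20 cached): a shared layer
# reuses the last cached layer of the same attention kind.
n_layer, n_layer_cache = 35, 20
is_swa = lambda il: il % 5 < 4   # every 5th layer is full attention (assumption above)

for il in range(n_layer_cache, n_layer):
    il_reuse = n_layer_cache - (2 if is_swa(il) else 1)
    # SWA layers map onto layer 18 (SWA), full-attention layers onto layer 19 (full)
    assert is_swa(il) == is_swa(il_reuse)
print({il: n_layer_cache - (2 if is_swa(il) else 1) for il in range(n_layer_cache, n_layer)})
```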
// allocate tensors and initialize the buffers to avoid NaNs in the padding
for (auto it : ctx_map) {
auto * buft = it.first;

View File

@ -103,6 +103,8 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
case LLM_TYPE_30B_A3B: return "30B.A3B";
case LLM_TYPE_235B_A22B: return "235B.A22B";
case LLM_TYPE_E2B: return "E2B";
case LLM_TYPE_E4B: return "E4B";
default: return "?B";
}
}
@ -1017,6 +1019,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
: 1.0f / std::sqrt(float(hparams.n_embd_head_k));
} break;
case LLM_ARCH_GEMMA3N:
{
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(5);
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
hparams.f_attention_scale = 1.0f;
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 30: type = LLM_TYPE_E2B; break;
case 35: type = LLM_TYPE_E4B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_STARCODER2:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@ -2950,6 +2970,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_GEMMA3N:
{
const int64_t n_altup = hparams.n_altup;
const int64_t laurel_rank = hparams.laurel_rank;
const int64_t n_embd_altup = hparams.n_embd_altup;
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
// altup & laurel
layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0);
layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0);
layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0);
layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0);
layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0);
layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0);
layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0);
layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0);
layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_STARCODER2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -8980,6 +9056,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
}
};
struct llm_build_gemma3n_iswa : public llm_graph_context {
const llama_model & model;
ggml_cgraph * gf;
const int64_t n_embd_head;
const int64_t n_embd_altup;
const int64_t n_altup;
const int i_altup_act;
const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
const int n_layer_sparsity = 10; // number of layers using activation sparsity
const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
ggml_tensor * one; // containing single element 1.0f
llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
: llm_graph_context(params),
model(model),
gf(gf),
n_embd_head(model.hparams.n_embd_head_k),
n_embd_altup(model.hparams.n_embd_altup),
n_altup(model.hparams.n_altup),
i_altup_act(model.hparams.i_altup_act) {
ggml_tensor * cur;
ggml_tensor * inpL;
// TODO: remove this when ggml_scale_add is implemented
one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
{
auto inp = std::make_unique<llm_graph_input_one>();
inp->one = one;
res->add_input(std::move(inp));
}
inpL = build_inp_embd(model.tok_embd);
// important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
if (ubatch.token) {
inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
cb(inpL, "inp_scaled", -1);
}
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
// TODO: is causal == true correct? might need some changes
auto * inp_attn = build_attn_inp_kv_unified_iswa();
// inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
// inpL now has only 1 altup, project it to the rest of the altups
// these "added" altups will be concat to the last dim of inpL
{
ggml_tensor * target_magnitude = calc_magnitude(inpL);
ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
ggml_tensor * new_magnitude = calc_magnitude(altup_added);
altup_added = ggml_div(ctx0,
ggml_mul(ctx0, altup_added, target_magnitude),
new_magnitude);
inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
cb(inpL, "inp_stacked", -1);
}
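The block above expands the single input stream to `n_altup` streams while preserving per-token magnitude: each projected stream is rescaled so its L2 norm matches the input's. A numpy sketch of that renormalization:

```python
# Magnitude-preserving stacking, per token: project, then rescale each
# projected stream to the input's L2 norm.
import numpy as np

def calc_magnitude(x):                      # matches calc_magnitude() below
    return np.sqrt((x ** 2).sum(axis=0, keepdims=True))

n_embd, n_tokens, n_altup = 8, 3, 4
rng = np.random.default_rng(0)
inp   = rng.standard_normal((n_embd, n_tokens))
projs = rng.standard_normal((n_altup - 1, n_embd, n_embd))   # altup_proj

streams = [inp]
for p in projs:
    added = p @ inp
    added *= calc_magnitude(inp) / calc_magnitude(added)     # renormalize
    streams.append(added)
stacked = np.stack(streams, axis=-1)        # [n_embd, n_tokens, n_altup]
assert np.allclose(calc_magnitude(stacked[..., 1]), calc_magnitude(inp))
```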
// inpL now has shape: [n_embd, n_tokens, n_altup]
// inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
for (int il = 0; il < n_layer; ++il) {
// this block closely resembles Gemma3p5DecoderLayer in the Python code
const bool has_kv = (il < n_layer_kv);
const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
// predicted value will go through self-attention and laurel
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
cur = active_prediction;
cb(cur, "active_prediction", il);
// norm
cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// laurel
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
// self-attention
if (has_kv) {
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
cb(Qcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);
cb(Vcur, "Vcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cb(Kcur, "Kcur_pos", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
} else {
// no KV layers
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
cb(Qcur, "Qcur_normed", il);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur_pos", il);
cur = build_attn(inp_attn, gf,
model.layers[il].wo, NULL,
Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
}
cur = build_norm(cur,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);
cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
cb(cur, "attn_gated", il);
ggml_tensor * attn_laurel = ggml_scale(ctx0,
ggml_add(ctx0, cur, laurel_out),
1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
cb(attn_laurel, "attn_laurel", il);
cur = build_norm(attn_laurel,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
// feed-forward network
{
ggml_tensor * up_proj = build_lora_mm(model.layers[il].ffn_up, cur);
ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
if (il < n_layer_sparsity) {
// apply activation sparsity
gate_proj = gaussian_topk(gate_proj);
}
gate_proj = ggml_gelu(ctx0, gate_proj);
cur = ggml_mul(ctx0, up_proj, gate_proj);
cur = build_lora_mm(model.layers[il].ffn_down, cur);
cb(cur, "ffn_out", il);
}
cur = build_norm(cur,
model.layers[il].ffn_post_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "ffn_post_norm", il);
ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
ggml_tensor * first_prediction; // [n_embd, n_tokens]
{
first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
cb(first_prediction, "first_prediction_gated", il);
ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
cb(first_prediction, "first_prediction_scaled", il);
first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
first_prediction = build_norm(first_prediction,
model.layers[il].per_layer_post_norm, NULL,
LLM_NORM_RMS, il);
cb(first_prediction, "first_prediction_out", il);
}
// equivalent to python code: corrected_predictions[1:] += first_prediction
{
ggml_tensor * slice_first = view_2d_slice(corrected, 0);
ggml_tensor * slice_rest = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
ggml_row_size(corrected->type, n_embd),
ggml_row_size(corrected->type, n_embd*n_tokens),
n_embd*n_tokens*ggml_element_size(corrected));
ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
}
cur = corrected; // [n_embd, n_tokens, n_altup]
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL; // [n_embd, n_tokens, n_altup]
// cur now has multiple altup streams; we want to merge them back into a single one
{
ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
// do a view to skip the first slice (active altup)
ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
ggml_row_size(cur->type, n_embd),
ggml_row_size(cur->type, n_embd*n_tokens),
n_embd*n_tokens*ggml_element_size(cur));
ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
altup_unembd = ggml_div(ctx0,
ggml_mul(ctx0, altup_unembd, target_magnitude),
new_magnitude);
cb(altup_unembd, "altup_unembd", -1);
// equivalent to torch.mean(hidden_states, dim=0)
cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
for (int i = 0; i < n_altup - 1; ++i) {
cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
}
cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
cb(cur, "unembd_merged", -1);
}
// cur now has shape: [n_embd, n_tokens]
// TODO: move this to right after the last KV layer
{
// skip computing output for unused tokens
ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
}
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
{
// final logit soft-capping
cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
cur = ggml_tanh(ctx0, cur);
cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
}
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
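Unlike Gemma3, gemma3n keeps final logit soft-capping, which bounds the logits smoothly in `(-cap, +cap)` while staying near-linear around zero. A small sketch (cap = 30.0 here is assumed to be gemma3n's configured `final_logit_softcapping`):

```python
# Soft-capping as applied above: scale down, tanh, scale back up.
import numpy as np

def softcap(logits, cap=30.0):
    return cap * np.tanh(logits / cap)

x = np.array([-100.0, -10.0, 0.0, 10.0, 100.0])
print(softcap(x))   # values bounded in (-30, 30), near-linear around 0
```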
ggml_tensor * calc_magnitude(ggml_tensor * x) {
return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
}
// get a 2D slice view from a 3D tensor; idx indexes along the 3rd dim
ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
GGML_ASSERT(idx < (int)x->ne[2]);
return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
ggml_row_size(x->type, x->ne[0]),
idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
}
// equivalent to get_per_layer_inputs() in python code
// output shape: [n_embd_altup, n_layer, n_tokens]
ggml_tensor * get_per_layer_inputs() {
auto inp = std::make_unique<llm_graph_input_embd>();
ggml_tensor * inp_per_layer;
if (ubatch.token) {
inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
ggml_set_input(inp->tokens);
res->t_tokens = inp->tokens;
inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
cb(inp_per_layer, "inp_per_layer_selected", -1);
} else {
GGML_ABORT("TODO: support embd input");
}
res->add_input(std::move(inp));
return inp_per_layer;
}
// equivalent to project_per_layer_inputs() in python code
// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
// output shape: [n_embd_altup, n_tokens, n_layer]
ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
const float per_layer_input_scale = 1.0f / sqrtf(2.0f);
ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
per_layer_proj = build_norm(per_layer_proj,
model.per_layer_proj_norm, NULL,
LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
cb(per_layer_proj, "per_layer_proj", -1);
inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
cb(inp_per_layer, "inp_per_layer", -1);
// permute to shape: [n_embd_altup, n_tokens, n_layer]
inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
return inp_per_layer;
}
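Per token, each layer receives an `n_embd_altup`-dim input combining the per-layer embedding lookup (scaled by sqrt(n_embd_altup) in get_per_layer_inputs above) with a scaled projection of the main embedding, averaged with a 1/sqrt(2) factor. A numpy sketch for a single token (the RMS norm on the projection is omitted for brevity):

```python
# Per-layer input math, single token: combine the per-layer token embedding
# with a projection of the main embedding, then scale the sum by 1/sqrt(2).
import numpy as np

n_embd, n_layer, n_embd_altup = 16, 2, 4
rng = np.random.default_rng(5)
embd           = rng.standard_normal(n_embd)                            # main token embedding
per_layer_embd = rng.standard_normal((n_layer, n_embd_altup))           # tok_embd_per_layer row
proj           = rng.standard_normal((n_layer * n_embd_altup, n_embd))  # per_layer_model_proj

per_layer_proj = (proj @ embd).reshape(n_layer, n_embd_altup) / np.sqrt(n_embd)
inp_per_layer  = (per_layer_embd * np.sqrt(n_embd_altup) + per_layer_proj) / np.sqrt(2.0)
assert inp_per_layer.shape == (n_layer, n_embd_altup)
```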
// input cur shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens]
ggml_tensor * laurel(ggml_tensor * cur, int il) {
ggml_tensor * tmp = cur;
tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
tmp = ggml_add(ctx0, tmp, cur);
cb(tmp, "laurel_out", il);
return tmp;
}
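LAUREL is a learned low-rank residual branch: down-project to `laurel_rank`, project back up, RMS-norm, and add the input back. A numpy sketch of the shape flow (conceptual orientation, not ggml's transposed layout):

```python
# Learned augmented residual: x + rms_norm(x @ L @ R),
# with L: [n_embd, rank] and R: [rank, n_embd].
import numpy as np

def rms_norm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

n_embd, rank, n_tokens = 16, 4, 3
rng = np.random.default_rng(1)
L = rng.standard_normal((n_embd, rank)) * 0.1
R = rng.standard_normal((rank, n_embd)) * 0.1
x = rng.standard_normal((n_tokens, n_embd))

laurel_out = x + rms_norm(x @ L @ R)     # low-rank residual branch
assert laurel_out.shape == x.shape
```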
// input x shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens]
ggml_tensor * gaussian_topk(ggml_tensor * x) {
ggml_tensor * mean = ggml_mean(ctx0, x);
ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0,
ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
1.0f / (float)(x->ne[0] - 1)
));
ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
}
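gaussian_topk implements activation sparsity by modelling each row of gate activations as Gaussian: anything below `mean + std_multiplier * std` is cut to zero, so with icdf(0.95) roughly the top 5% of activations survives. A numpy sketch:

```python
# Per-row Gaussian cutoff, then a shifted ReLU.
import numpy as np

def gaussian_topk(x, std_multiplier=1.6448533535003662):
    mean = x.mean(axis=-1, keepdims=True)
    std  = x.std(axis=-1, keepdims=True, ddof=1)   # matches 1/(n-1) in the C++ code
    cutoff = mean + std_multiplier * std
    return np.maximum(x - cutoff, 0.0)             # relu(x - cutoff)

x = np.random.default_rng(2).standard_normal((4, 1000))
print((gaussian_topk(x) > 0).mean())               # ~0.05
```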
//
// altup functions
//
// equivalent to compute_router_modalities() in python code
// input x shape: [n_embd, n_tokens]
// output shape: [n_altup, n_tokens]
ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
ggml_tensor * router_inputs = build_norm(x,
model.layers[il].altup_router_norm, NULL,
LLM_NORM_RMS, il);
// router_input_scale
router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
}
// input cur shape: [n_embd, n_tokens, n_altup]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
cb(modalities, "modalities", il);
ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
cb(all_coefs, "all_coefs", il);
// the first dim now has n_altup^2 elements, so we reshape it to 2D (ending up with a 3D tensor)
all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
// permute to [n_altup, n_embd, n_tokens]
ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]
// final shape must be the same as cur: [n_embd, n_tokens, n_altup]
predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));
predictions = ggml_add(ctx0, predictions, cur);
cb(predictions, "predictions", il);
return predictions;
}
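altup_predict mixes the `n_altup` streams of each token with an input-dependent `n_altup x n_altup` coefficient matrix produced by the router, then adds the original streams back as a residual. A numpy sketch of the per-token mixing (the orientation follows the permute/mul_mat sequence above conceptually, not literally):

```python
# Per-token stream mixing: [n_embd, n_altup] @ [n_altup, n_altup], plus residual.
import numpy as np

n_embd, n_tokens, n_altup = 8, 3, 4
rng = np.random.default_rng(3)
cur       = rng.standard_normal((n_embd, n_tokens, n_altup))
all_coefs = rng.standard_normal((n_tokens, n_altup, n_altup))  # router output per token

predictions = np.empty_like(cur)
for t in range(n_tokens):
    predictions[:, t, :] = cur[:, t, :] @ all_coefs[t]
predictions += cur                      # residual add, as in the graph code
assert predictions.shape == cur.shape
```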
// input predictions shape: [n_embd, n_tokens, n_altup]
// input activated shape: [n_embd, n_tokens]
// output shape: [n_embd, n_tokens, n_altup]
ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
cb(modalities, "modalities", il);
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
cb(innovation, "innovation", il);
ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
all_coefs = ggml_add(ctx0, all_coefs, one);
cb(all_coefs, "all_coefs", il);
all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
cb(corrected, "corrected", il);
return corrected;
}
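altup_correct feeds the post-FFN activation back into every stream: the innovation (activation minus the active prediction) is added to each stream, weighted by a per-token router coefficient plus 1 (the `+ one` add above). A numpy sketch:

```python
# Each stream i gets predictions[..., i] + (coef_i + 1) * innovation,
# with per-token coefficients from the router.
import numpy as np

n_embd, n_tokens, n_altup, i_act = 8, 3, 4, 0
rng = np.random.default_rng(4)
predictions = rng.standard_normal((n_embd, n_tokens, n_altup))
activated   = rng.standard_normal((n_embd, n_tokens))
coefs       = rng.standard_normal((n_altup, n_tokens))     # router output

innovation = activated - predictions[:, :, i_act]          # [n_embd, n_tokens]
corrected  = predictions + innovation[:, :, None] * (coefs + 1.0).T[None, :, :]
assert corrected.shape == predictions.shape
```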
};
// TODO: move up next to build_starcoder
struct llm_build_starcoder2 : public llm_graph_context {
llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
@ -13974,6 +14486,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
} break;
case LLM_ARCH_GEMMA3N:
{
llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
} break;
case LLM_ARCH_STARCODER2:
{
llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
@ -14295,6 +14811,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_GEMMA:
case LLM_ARCH_GEMMA2:
case LLM_ARCH_GEMMA3:
case LLM_ARCH_GEMMA3N:
case LLM_ARCH_STARCODER2:
case LLM_ARCH_OPENELM:
case LLM_ARCH_GPTNEOX:

View File

@ -95,6 +95,8 @@ enum llm_type {
LLM_TYPE_17B_128E, // llama4 Maverick
LLM_TYPE_30B_A3B,
LLM_TYPE_235B_A22B,
LLM_TYPE_E2B,
LLM_TYPE_E4B,
};
std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
@ -316,6 +318,19 @@ struct llama_layer {
struct ggml_tensor * ffn_up_scale = nullptr;
struct ggml_tensor * ffn_down_scale = nullptr;
// altup & laurel
struct ggml_tensor * per_layer_inp_gate = nullptr;
struct ggml_tensor * per_layer_proj = nullptr;
struct ggml_tensor * per_layer_post_norm = nullptr;
struct ggml_tensor * altup_correct_coef = nullptr;
struct ggml_tensor * altup_correct_scale = nullptr;
struct ggml_tensor * altup_predict_coef = nullptr;
struct ggml_tensor * altup_router = nullptr;
struct ggml_tensor * altup_router_norm = nullptr;
struct ggml_tensor * laurel_l = nullptr;
struct ggml_tensor * laurel_r = nullptr;
struct ggml_tensor * laurel_post_norm = nullptr;
struct llama_layer_posnet posnet;
struct llama_layer_convnext convnext;
@ -354,6 +369,13 @@ struct llama_model {
struct ggml_tensor * conv1d = nullptr;
struct ggml_tensor * conv1d_b = nullptr;
// gemma3n altup
struct ggml_tensor * tok_embd_per_layer = nullptr;
struct ggml_tensor * altup_proj = nullptr;
struct ggml_tensor * altup_unembd_proj = nullptr;
struct ggml_tensor * per_layer_model_proj = nullptr;
struct ggml_tensor * per_layer_proj_norm = nullptr;
std::vector<llama_layer> layers;
llama_model_params params;

View File

@ -223,7 +223,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
new_type = GGML_TYPE_Q6_K;
}
}
-    } else if (name == "token_embd.weight") {
+    } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
new_type = qs.params->token_embedding_type;
} else {
@ -830,6 +830,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
// these are very small (e.g. 4x4)
quantize &= name.find("altup") == std::string::npos;
quantize &= name.find("laurel") == std::string::npos;
// these are not too big, so keep them as they are
quantize &= name.find("per_layer_model_proj") == std::string::npos;
// do not quantize positional embeddings and token types (BERT)
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");