mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-10 13:30:27 +00:00
Merge branch 'master' into compilade/mamba2
This commit is contained in:
@ -310,6 +310,8 @@ class ModelBase:
|
||||
gguf.MODEL_TENSOR.POSNET_NORM2,
|
||||
gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
|
||||
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
|
||||
gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
|
||||
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
|
||||
)
|
||||
)
|
||||
or not new_name.endswith(".weight")
|
||||
@ -320,7 +322,11 @@ class ModelBase:
|
||||
self.match_model_tensor_name(new_name, key, bid)
|
||||
for key in (
|
||||
gguf.MODEL_TENSOR.TOKEN_EMBD,
|
||||
gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
|
||||
gguf.MODEL_TENSOR.OUTPUT,
|
||||
gguf.MODEL_TENSOR.ALTUP_ROUTER,
|
||||
gguf.MODEL_TENSOR.LAUREL_L,
|
||||
gguf.MODEL_TENSOR.LAUREL_R,
|
||||
)
|
||||
):
|
||||
if self.ftype in (
|
||||
@ -921,13 +927,20 @@ class TextModel(ModelBase):
|
||||
tokenizer = SentencePieceProcessor()
|
||||
tokenizer.LoadFromFile(str(tokenizer_path))
|
||||
|
||||
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
|
||||
vocab_size = self.find_hparam([
|
||||
"vocab_size_per_layer_input", # gemma3n
|
||||
"vocab_size",
|
||||
], optional=True) or tokenizer.vocab_size()
|
||||
|
||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||
scores: list[float] = [-10000.0] * vocab_size
|
||||
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||
|
||||
for token_id in range(tokenizer.vocab_size()):
|
||||
if token_id >= vocab_size:
|
||||
logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
|
||||
break
|
||||
|
||||
piece = tokenizer.IdToPiece(token_id)
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer.GetScore(token_id)
|
||||
@ -2730,6 +2743,52 @@ class Qwen2Model(TextModel):
|
||||
yield from super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Ernie4_5_ForCausalLM")
|
||||
class Ernie4_5Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.ERNIE4_5
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
num_heads = self.hparams["num_attention_heads"]
|
||||
num_kv_heads = self.hparams["num_key_value_heads"]
|
||||
head_dim = self.hparams["head_dim"]
|
||||
|
||||
if "ernie." in name:
|
||||
name = name.replace("ernie.", "model.")
|
||||
# split the qkv weights
|
||||
# qkv_proj shape: [(num_heads + 2 * num_kv_heads) * head_dim, hidden_size]
|
||||
if "qkv_proj" in name:
|
||||
name_q = name.replace("qkv_proj.weight", "q_proj.weight")
|
||||
name_k = name.replace("qkv_proj.weight", "k_proj.weight")
|
||||
name_v = name.replace("qkv_proj.weight", "v_proj.weight")
|
||||
total_q_dim = num_heads * head_dim
|
||||
total_k_dim = num_kv_heads * head_dim
|
||||
total_v_dim = num_kv_heads * head_dim
|
||||
q_proj_weight, k_proj_weight, v_proj_weight = data_torch.split([total_q_dim, total_k_dim, total_v_dim], dim=0)
|
||||
return [
|
||||
(self.map_tensor_name(name_q), q_proj_weight),
|
||||
(self.map_tensor_name(name_k), k_proj_weight),
|
||||
(self.map_tensor_name(name_v), v_proj_weight)
|
||||
]
|
||||
# split the up_gate_proj into gate and up
|
||||
# up_gate_proj shape: [2 * intermediate_size, hidden_size]
|
||||
if "up_gate_proj" in name:
|
||||
name_up = name.replace("up_gate_proj.weight", "up_proj.weight")
|
||||
name_gate = name.replace("up_gate_proj.weight", "gate_proj.weight")
|
||||
dim_half = data_torch.shape[0] // 2
|
||||
gate_proj_weight, up_proj_weight = data_torch.split(dim_half, dim=0)
|
||||
return [
|
||||
(self.map_tensor_name(name_gate), gate_proj_weight),
|
||||
(self.map_tensor_name(name_up), up_proj_weight)
|
||||
]
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@ModelBase.register(
|
||||
"Qwen2VLModel",
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
@ -4217,6 +4276,7 @@ class Gemma2Model(TextModel):
|
||||
@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
|
||||
class Gemma3Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA3
|
||||
norm_shift = 1.0 # Gemma3RMSNorm adds 1.0 to the norm value
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
@ -4238,9 +4298,8 @@ class Gemma3Model(TextModel):
|
||||
self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers
|
||||
# both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
|
||||
# attn_logit_softcapping is removed in Gemma3
|
||||
assert hparams.get("attn_logit_softcapping") is None
|
||||
assert hparams.get("final_logit_softcapping") is None
|
||||
self.gguf_writer.add_sliding_window(hparams["sliding_window"])
|
||||
self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
|
||||
if hparams.get("rope_scaling") is not None:
|
||||
@ -4252,7 +4311,7 @@ class Gemma3Model(TextModel):
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
del bid # unused
|
||||
|
||||
if name.startswith("language_model."):
|
||||
if "language_model." in name:
|
||||
name = name.replace("language_model.", "")
|
||||
|
||||
elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
|
||||
@ -4267,8 +4326,9 @@ class Gemma3Model(TextModel):
|
||||
|
||||
# ref code in Gemma3RMSNorm
|
||||
# output = output * (1.0 + self.weight.float())
|
||||
# note: this is not the case on gemma3n
|
||||
if name.endswith("norm.weight"):
|
||||
data_torch = data_torch + 1
|
||||
data_torch = data_torch + self.norm_shift
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
@ -4325,6 +4385,104 @@ class Gemma3VisionModel(MmprojModel):
|
||||
return [] # skip other tensors
|
||||
|
||||
|
||||
@ModelBase.register("Gemma3nForConditionalGeneration")
|
||||
class Gemma3NModel(Gemma3Model):
|
||||
model_arch = gguf.MODEL_ARCH.GEMMA3N
|
||||
norm_shift = 0.0 # same value with Gemma3p5RMSNorm scale_shift on python code
|
||||
|
||||
_altup_proj: list[Tensor] = []
|
||||
_altup_unembd: list[Tensor] = []
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
|
||||
self._altup_proj = [
|
||||
torch.Tensor(), # to be replaced
|
||||
torch.Tensor(), # to be replaced
|
||||
torch.Tensor(), # to be replaced
|
||||
]
|
||||
self._altup_unembd = [
|
||||
torch.Tensor(), # to be replaced
|
||||
torch.Tensor(), # to be replaced
|
||||
torch.Tensor(), # to be replaced
|
||||
]
|
||||
|
||||
def set_vocab(self):
|
||||
with open(self.dir_model / "chat_template.jinja") as f:
|
||||
# quick hack to make sure chat template is added
|
||||
self.gguf_writer.add_chat_template(f.read())
|
||||
super().set_vocab()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
|
||||
self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
|
||||
self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
|
||||
self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
|
||||
|
||||
activation_sparsity_scale = []
|
||||
for s in self.hparams["activation_sparsity_pattern"]:
|
||||
normal_dist = torch.distributions.normal.Normal(0, 1)
|
||||
std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
|
||||
activation_sparsity_scale.append(std_multiplier.item())
|
||||
self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
|
||||
|
||||
sliding_window_pattern = []
|
||||
for t in self.hparams["layer_types"]:
|
||||
sliding_window_pattern.append(t == "sliding_attention")
|
||||
self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
|
||||
|
||||
def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
|
||||
has_all = all(m.numel() > 0 for m in matrices)
|
||||
if not has_all:
|
||||
return None
|
||||
else:
|
||||
return torch.stack(matrices, dim=0)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.endswith("_scale"):
|
||||
name = name + ".weight"
|
||||
|
||||
# TODO: implement self.prediction_coefs.weight.clamp_(...)
|
||||
|
||||
if "language_model." not in name:
|
||||
return [] # skip non-language model tensors
|
||||
|
||||
if "altup_unembed_projections" in name:
|
||||
data_torch = data_torch.to(device="cpu")
|
||||
if ".0." in name:
|
||||
self._altup_unembd[0] = data_torch
|
||||
elif ".1." in name:
|
||||
self._altup_unembd[1] = data_torch
|
||||
elif ".2." in name:
|
||||
self._altup_unembd[2] = data_torch
|
||||
else:
|
||||
raise ValueError(f"Unknown name: {name}")
|
||||
out = self._stack_matrices(self._altup_unembd)
|
||||
if out is not None:
|
||||
return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
|
||||
else:
|
||||
return []
|
||||
|
||||
if "altup_projections" in name:
|
||||
data_torch = data_torch.to(device="cpu")
|
||||
if ".0." in name:
|
||||
self._altup_proj[0] = data_torch
|
||||
elif ".1." in name:
|
||||
self._altup_proj[1] = data_torch
|
||||
elif ".2." in name:
|
||||
self._altup_proj[2] = data_torch
|
||||
else:
|
||||
raise ValueError(f"Unknown name: {name}")
|
||||
out = self._stack_matrices(self._altup_proj)
|
||||
if out is not None:
|
||||
return [(self.map_tensor_name("model.altup_projections.weight"), out)]
|
||||
else:
|
||||
return []
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("Starcoder2ForCausalLM")
|
||||
class StarCoder2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.STARCODER2
|
||||
|
Reference in New Issue
Block a user