more cleaning on python code

younesbelkada
2025-07-03 18:09:30 +04:00
parent fdd5cff4ba
commit 14c37ec047
5 changed files with 204 additions and 0 deletions


@@ -6535,6 +6535,144 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
        super().set_gguf_parameters()
        self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
    model_arch = gguf.MODEL_ARCH.FALCON_H1

    def __init__(self, *args, **kwargs):
        # Set the hparam prefixes for Falcon Mamba2
        self.hparam_prefixes = ["mamba"]
        # Initialize the base Mamba2Model
        super().__init__(*args, **kwargs)

        # Use Llama conversion for attention
        self._transformer_model_class = LlamaModel

        # n_group and d_inner are used during reshape_tensors for mamba2
        self.d_model = self.find_hparam(["hidden_size", "d_model"])
        self.n_group = self.find_hparam(["n_groups"])
        self.d_inner = self.find_hparam(["expand"]) * self.d_model

        # Initialize any Falcon Mamba2 specific attributes
        self.has_attention = True  # Falcon Mamba2 has attention components

        # Load Falcon-H1 multipliers from hyperparameters
        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
        self.intermediate_size = self.find_hparam(["intermediate_size"])
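        # Note: the *_multiplier values above are optional in the HF config; when present they are
        # written out as GGUF metadata by set_gguf_parameters() rather than folded into tensor data.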
    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
        prefixed = []
        for pfx in self.hparam_prefixes:
            prefixed.extend(
                "_".join([pfx, k])
                for k in keys
            )
        keys = list(keys) + prefixed
        return super().find_hparam(keys, *args, **kwargs)
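        # For example, with hparam_prefixes = ["mamba"], a call like find_hparam(["d_state"]) looks up
        # both "d_state" and "mamba_d_state", so Falcon-H1's prefixed Mamba2 keys resolve through the
        # same helper as the unprefixed ones.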
    def _generate_mup_vector(self, block_id: int) -> torch.Tensor:
        zxbcdt_multipliers = self.hparams["ssm_multipliers"]
        intermediate_size = self.hparams["mamba_d_ssm"]
        groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
        vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"])

        mup_vector = torch.ones(1, 1, vector_shape)
        mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0]
        mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1]
        mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2]
        mup_vector[:, :, 2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size] *= zxbcdt_multipliers[3]
        mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size:] *= zxbcdt_multipliers[4]
        return mup_vector
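        # The vector spans segments of size [d_ssm, d_ssm, n_groups*d_state, n_groups*d_state, n_heads],
        # each scaled by the corresponding ssm_multipliers entry (ordering follows the zxbcdt naming above).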
    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for name, tensor in super().get_tensors():
            if name.startswith("model.backbone") or name.startswith("model.lm_head"):
                name = name.removeprefix("model.")
            yield name, tensor

            if self.ssm_multipliers is not None:
                # Insert MUP vector after mamba.dt_bias
                if "mamba.dt_bias" in name:
                    block_match = re.search(r"(?:model\.layers\.)?(\d+)\.mamba\.dt_bias", name)
                    if block_match:
                        block_id = int(block_match.group(1))
                        # Generate MUP vector with correct name format
                        mup_tensor = self._generate_mup_vector(block_id)
                        mup_name = f"blk.{block_id}.ssm_mup_vec"
                        logger.debug(f"Inserting MUP vector for block {block_id}: {mup_name}")
                        yield mup_name, mup_tensor
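        # For example, a source tensor named "model.layers.3.mamba.dt_bias" is followed by an extra
        # "blk.3.ssm_mup_vec" tensor, so each block's MUP scaling vector ends up in the GGUF file.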
    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        ## General Params ##
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])

        ## Mamba mixer params ##
        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
        self.gguf_writer.add_ssm_group_count(self.n_group)
        self.gguf_writer.add_ssm_inner_size(self.d_inner)
        self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"]))
        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))

        ## Attention params ##
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in self.hparams else self.hparams["num_attention_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_key_length(self.hparams["head_dim"])
        self.gguf_writer.add_value_length(self.hparams["head_dim"])
        self.gguf_writer.add_float32("falcon-h1.key_multiplier", self.hparams["key_multiplier"])

        ## Other params ##
        self.gguf_writer.add_float32("falcon-h1.lm_head_multiplier", self.hparams["lm_head_multiplier"])
        self.gguf_writer.add_float32("falcon-h1.embedding_multiplier", self.hparams["embedding_multiplier"])

        ## Validation ##
        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"

        # Add Falcon Mamba2 specific configuration
        self.gguf_writer.add_uint32("falcon-h1.ssm.mamba_chunk_size", self.hparams["mamba_chunk_size"])
        self.gguf_writer.add_uint32("falcon-h1.attention.head_dim", self.hparams["head_dim"])
        self.gguf_writer.add_uint32("falcon-h1.ssm.mamba_d_ssm", self.hparams["mamba_d_ssm"])
        self.gguf_writer.add_uint32("falcon-h1.num_attention_heads", self.find_hparam(["num_attention_heads"]))
        self.gguf_writer.add_uint32("falcon-h1.num_key_value_heads",
                                    self.find_hparam(["num_key_value_heads"], optional=True) or
                                    self.find_hparam(["num_attention_heads"]))

        # Add multipliers as metadata instead of tensors
        self.gguf_writer.add_float32("falcon-h1.attention_in_multiplier", self.attention_in_multiplier)
        self.gguf_writer.add_float32("falcon-h1.attention_out_multiplier", self.attention_out_multiplier)
        self.gguf_writer.add_float32("falcon-h1.ssm_in_multiplier", self.ssm_in_multiplier)
        self.gguf_writer.add_float32("falcon-h1.ssm_out_multiplier", self.ssm_out_multiplier)

        # Add MLP multipliers
        if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2:
            self.gguf_writer.add_float32("falcon-h1.mlp_gate_multiplier", self.mlp_multipliers[0])
            self.gguf_writer.add_float32("falcon-h1.mlp_down_multiplier", self.mlp_multipliers[1])

        # Add has MuP flag if SSM multipliers are present
        if self.ssm_multipliers is not None:
            self.gguf_writer.add_bool("falcon-h1.ssm.has_mup", True)

        # Add any other Falcon Mamba2 specific configuration
        self.gguf_writer.add_bool("falcon-h1.mamba_use_mlp", self.find_hparam(["mamba_use_mlp"], optional=True))
        self.gguf_writer.add_bool("falcon-h1.mamba_norm_before_gate", self.find_hparam(["mamba_norm_before_gate"], optional=True))
        self.gguf_writer.add_bool("falcon-h1.mamba_rms_norm", self.find_hparam(["mamba_rms_norm"], optional=True))
        self.gguf_writer.add_float32("falcon-h1.rope_theta", self.find_hparam(["rope_theta"], optional=True))
###### CONVERSION LOGIC ######
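For reference, a minimal standalone sketch of the MUP vector layout produced by _generate_mup_vector, using made-up toy hyperparameter values (not taken from any real Falcon-H1 config); the z/x/B/C/dt segment labels are an assumption inferred from the "zxbcdt" variable name:

# Toy stand-ins for mamba_d_ssm, mamba_n_groups, mamba_d_state, mamba_n_heads (hypothetical values)
d_ssm, n_groups, d_state, n_heads = 8, 2, 4, 4
gts = n_groups * d_state
# Segment boundaries matching the slicing in _generate_mup_vector: [z | x | B | C | dt]
bounds = [0, d_ssm, 2 * d_ssm, 2 * d_ssm + gts, 2 * d_ssm + 2 * gts, 2 * d_ssm + 2 * gts + n_heads]
print(bounds)  # [0, 8, 16, 24, 32, 36] -> five segments, one per ssm_multipliers entry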