convert : avoid AutoConfig for Mamba and Mamba2 hparams

This commit is contained in:
Francis Couture-Harpin
2025-05-02 18:24:55 -04:00
parent 929fe85db3
commit d55b0d0621

View File

@ -4127,6 +4127,14 @@ class ARwkv7Model(Rwkv7Model):
class MambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA
def __init__(self, dir_model: Path, *args, **kwargs):
    """Set up the model, taking hparams straight from ``config.json``.

    Args:
        dir_model: Model directory; ``config.json`` is read from it when
            no ``hparams`` keyword argument is supplied.
        *args: Forwarded unchanged to the parent initializer.
        **kwargs: Forwarded to the parent initializer. An ``hparams``
            entry that is not ``None`` is used as-is and the file is
            not read.
    """
    # Avoid using AutoConfig for hparams: parse the raw JSON ourselves
    # unless the caller already provided the hyperparameters.
    raw_hparams = kwargs.pop("hparams", None)
    if raw_hparams is None:
        config_path = dir_model / "config.json"
        with open(config_path, "r", encoding="utf-8") as config_file:
            raw_hparams = json.load(config_file)
    super().__init__(dir_model, *args, hparams=raw_hparams, **kwargs)
def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 8
@ -4205,6 +4213,15 @@ class MambaModel(TextModel):
class Mamba2Model(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA2
def __init__(self, dir_model: Path, *args, **kwargs):
    """Set up the model, taking hparams straight from ``config.json``.

    Args:
        dir_model: Model directory; ``config.json`` is read from it when
            no ``hparams`` keyword argument is supplied.
        *args: Forwarded unchanged to the parent initializer.
        **kwargs: Forwarded to the parent initializer. An ``hparams``
            entry that is not ``None`` is used as-is and the file is
            not read.
    """
    # Avoid using AutoConfig for hparams
    # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
    config_values = kwargs.pop("hparams", None)
    if config_values is None:
        # No caller-supplied hparams: fall back to the raw JSON config.
        with open(dir_model / "config.json", "r", encoding="utf-8") as fh:
            config_values = json.loads(fh.read())
    super().__init__(dir_model, *args, hparams=config_values, **kwargs)
def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 16
@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
text_config = hparams.get("text_config", {})
vision_config = hparams.get("vision_config", {})
arch = hparams["architectures"][0]
arch = None
if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
arch = arches[0]
elif "ssm_cfg" in hparams:
# For non-hf Mamba and Mamba2 models
arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
if arch is None:
raise ValueError("Failed to detect model architecture")
return arch