mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-02 05:15:47 +00:00
convert : avoid AutoConfig for Mamba and Mamba2 hparams
This commit is contained in:
@ -4127,6 +4127,14 @@ class ARwkv7Model(Rwkv7Model):
|
||||
class MambaModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.MAMBA
|
||||
|
||||
def __init__(self, dir_model: Path, *args, **kwargs):
    """Initialise the converter, reading hparams straight from config.json.

    If the caller did not pass an ``hparams`` keyword argument (or passed
    ``None``), the hyperparameters are loaded from ``config.json`` inside
    ``dir_model``. Everything else is forwarded unchanged to the parent
    initialiser.
    """
    # Avoid using AutoConfig for hparams
    if kwargs.get("hparams") is None:
        config_path = dir_model / "config.json"
        with open(config_path, "r", encoding="utf-8") as config_file:
            kwargs["hparams"] = json.load(config_file)
    super().__init__(dir_model, *args, **kwargs)
|
||||
|
||||
def set_vocab(self):
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
# Round vocab size to next multiple of 8
|
||||
@ -4205,6 +4213,15 @@ class MambaModel(TextModel):
|
||||
class Mamba2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.MAMBA2
|
||||
|
||||
def __init__(self, dir_model: Path, *args, **kwargs):
    """Initialise the converter, reading hparams straight from config.json.

    If the caller did not pass an ``hparams`` keyword argument (or passed
    ``None``), the hyperparameters are loaded from ``config.json`` inside
    ``dir_model``. Everything else is forwarded unchanged to the parent
    initialiser.
    """
    # Avoid using AutoConfig for hparams
    # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
    if kwargs.get("hparams") is None:
        config_path = dir_model / "config.json"
        with open(config_path, "r", encoding="utf-8") as config_file:
            kwargs["hparams"] = json.load(config_file)
    super().__init__(dir_model, *args, **kwargs)
|
||||
|
||||
def set_vocab(self):
|
||||
vocab_size = self.hparams["vocab_size"]
|
||||
# Round vocab size to next multiple of 16
|
||||
@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
|
||||
hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
|
||||
text_config = hparams.get("text_config", {})
|
||||
vision_config = hparams.get("vision_config", {})
|
||||
arch = hparams["architectures"][0]
|
||||
arch = None
|
||||
if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
|
||||
arch = arches[0]
|
||||
elif "ssm_cfg" in hparams:
|
||||
# For non-hf Mamba and Mamba2 models
|
||||
arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
|
||||
|
||||
# if "architectures" is found in the sub-config, use that instead
|
||||
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
|
||||
arch = text_config["architectures"][0]
|
||||
elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
|
||||
arch = vision_config["architectures"][0]
|
||||
if arch is None:
|
||||
raise ValueError("Failed to detect model architecture")
|
||||
return arch
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user