llama : support Jamba hybrid Transformer-Mamba models (#7531)
* wip: llama : separate recurrent states from the KV cache
  This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states.
* llama : use std::find for seq_nodes in llama_rs_cache
* llama : state checkpoints for recurrent models
* llama : correctly handle more edge cases for the rs cache
* llama : rename many llama_kv_cache_* functions
* llama : remove useless return value for some llama_cache_* functions
* llama : rethink recurrent state cell counts
* llama : begin work on support for variable GQA
  This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads.
* llama : gracefully fail when not finding hybrid slot
* llama : support Jamba
* llama : fix BERT inference without KV cache
* convert-hf : check for unprocessed Jamba experts
* convert-hf : support Mini-Jamba conversion
* llama : fix Jamba quantization sanity checks
* llama : sequence-length-aware batch splitting
* llama : use equal-sequence-length sub-batches for recurrent models (see the sketch after this list)
* ggml : simplify SSM-related operators
* llama : make recurrent state slot allocation contiguous
* llama : adapt internal uses of batches to llama_ubatch
* llama : fix batch split output count for embeddings
* llama : minimize swaps when reordering logits
  This reduces overhead when running hellaswag on thousands of sequences with very small 100k-parameter Mamba models.
* llama : fix edge case finding batch seq_id of split recurrent cell
  This was otherwise a problem when running the HellaSwag benchmark with small batch sizes, making it crash.
* llama : avoid copies for simple batch splits
* ggml : make ggml_ssm_scan not modify its source tensors
* llama : fix shared recurrent tail cell count for small ubatch sizes
  Otherwise it was impossible to run the 'parallel' example with '-ub 1' with a Mamba or Jamba model.
* llama : fix .base() compilation error on Windows
* llama : allow doing the equivalent of SSM_CONV with SUM_ROWS and MUL
* ggml : allow GGML_OP_CONCAT to work on non-contiguous tensors
  The implementation already supported it, and this makes Mamba's conv step slightly faster.
* mamba : fix non-contiguous usage of ggml_silu
* llama : session saving and reloading for hybrid models
* convert_hf : fix Jamba conversion
* llama : fix mixed signedness comparison
* llama : use unused n_embd_k_gqa in k_shift
  This also slightly reduces the diff from the master branch.
* llama : begin renaming llama_past back to llama_kv_cache
* llama : remove implicit recurrent state rollbacks
* llama : partially apply clang-format style
* convert : fix jamba conv1d shape squeezing
* graph : add back hybrid memory graph input
  But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually).
* model : add Jamba to Mamba-specific hparams printing
* jamba : remove redundant nullptr initializations
* model : remove unnecessary prefix for tensor loading constants
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* model : use ggml_swiglu_split for Mamba
  Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
* model : make falcon-h1 use shared mamba2 layer builder
* memory : avoid referring to KV in recurrent cache logs
* gguf-py : avoid adding duplicate tensor mappings for Jamba
  Some of the tensor names are common with Llama4.

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
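Several of the bullets above revolve around batch splitting for recurrent models: since a Mamba layer carries a single rolling state per sequence, every sequence in a sub-batch must advance by the same number of tokens so the states stay in lockstep. A toy sketch of that splitting idea (a hypothetical helper for illustration only, not the actual llama_ubatch code):

# Toy sketch of equal-sequence-length sub-batch splitting (hypothetical
# helper, not the llama.cpp implementation).
def equal_len_ubatches(seqs: dict[int, list[int]], n_ubatch: int):
    pending = {sid: toks[:] for sid, toks in seqs.items() if toks}
    while pending:
        # every sequence in this ubatch contributes the same token count
        take = min(min(len(t) for t in pending.values()), max(1, n_ubatch // len(pending)))
        yield {sid: toks[:take] for sid, toks in pending.items()}
        pending = {sid: toks[take:] for sid, toks in pending.items() if len(toks) > take}

# two sequences of lengths 5 and 3 with n_ubatch=4 are split into
# sub-batches of 2+2 tokens, then 1+1, then the longer sequence alone
for ub in equal_len_ubatches({0: [1, 2, 3, 4, 5], 1: [6, 7, 8]}, 4):
    print({sid: len(t) for sid, t in ub.items()})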
@@ -4974,6 +4974,123 @@ class Mamba2Model(TextModel):
        yield (new_name, data_torch)

@ModelBase.register("JambaForCausalLM")
|
||||
class JambaModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.JAMBA
|
||||
|
||||
def get_vocab_base_pre(self, tokenizer) -> str:
|
||||
del tokenizer # unused
|
||||
|
||||
return "gpt-2"
|
||||
|
||||
def set_vocab(self):
|
||||
if (self.dir_model / "tokenizer.model").is_file():
|
||||
# Using Jamba's tokenizer.json causes errors on model load
|
||||
# (something about "byte not found in vocab"),
|
||||
# but there's a working tokenizer.model
|
||||
self._set_vocab_sentencepiece()
|
||||
else:
|
||||
# Some Jamba models only have a tokenizer.json, which works.
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
    def set_gguf_parameters(self):
        d_model = self.find_hparam(["hidden_size", "mamba_d_model"])
        d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4
        d_inner = self.hparams["mamba_expand"] * d_model
        d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16
        # ceiling division
        # ref: https://stackoverflow.com/a/17511341/22827863
        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
        dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16)
        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6
        n_kv_head = self.hparams["num_key_value_heads"]
        attn_offset = self.hparams["attn_layer_offset"]
        attn_period = self.hparams["attn_layer_period"]
        n_kv_vec = [0 for _ in range(attn_offset)] + [
            n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count)
        ]

        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"]))
        self.gguf_writer.add_embedding_length(d_model)
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(n_kv_vec)
        self.gguf_writer.add_ssm_conv_kernel(d_conv)
        self.gguf_writer.add_ssm_inner_size(d_inner)
        self.gguf_writer.add_ssm_state_size(d_state)
        self.gguf_writer.add_ssm_time_step_rank(dt_rank)
        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
        self.gguf_writer.add_expert_count(self.hparams["num_experts"])
        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
        self.gguf_writer.add_file_type(self.ftype)
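Two lines in set_gguf_parameters are worth unpacking: `-(d_model // -16)` is the negative-floor-division trick for ceiling division, and n_kv_vec encodes the hybrid layout as a per-layer KV head count, with 0 for every Mamba layer. A quick self-contained check (the numbers are illustrative, not read from any real config.json):

# ceiling division via floor division of the negation
d_model = 4096
assert -(d_model // -16) == 256  # ceil(4096 / 16)
assert -(100 // -16) == 7        # ceil(100 / 16); plain 100 // 16 would give 6

# per-layer KV head counts: only attention layers report heads
n_kv_head, attn_offset, attn_period, block_count = 8, 4, 8, 32
n_kv_vec = [0 for _ in range(attn_offset)] + [
    n_kv_head if (i - attn_offset) % attn_period == 0 else 0
    for i in range(attn_offset, block_count)
]
assert [i for i, n in enumerate(n_kv_vec) if n > 0] == [4, 12, 20, 28]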
    _experts: list[dict[str, Tensor]] | None = None
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

        # Mini-Jamba
        name = name.replace(".moe.", ".feed_forward.")
        if bid is not None:
            moe_offset = self.hparams["expert_layer_offset"]
            moe_period = self.hparams["expert_layer_period"]

            if not (bid >= moe_offset and (bid - moe_offset) % moe_period == 0):
                name = name.replace(".experts.0.", ".")

        # process the experts separately
        if ".feed_forward.experts." in name:
            n_experts = self.hparams["num_experts"]

            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:

                # merge the experts into a single 3d tensor
                for wid in ["down_proj", "gate_proj", "up_proj"]:
                    datas: list[Tensor] = []

                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)

                    # using the same merged name as qwen2moe
                    merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight"

                    new_name = self.map_tensor_name(merged_name)

                    yield new_name, data_torch
            return

        new_name = self.map_tensor_name(name)

        if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid):
            data_torch = data_torch.squeeze()

        if name.endswith(".A_log"):
            logger.debug("A_log --> A ==> " + new_name)
            data_torch = -torch.exp(data_torch)

        yield (new_name, data_torch)
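Two transformations in modify_tensors deserve a note: the per-expert 2D weights are stacked into one 3D tensor per projection (so GGUF stores a single tensor per layer, under the same merged name qwen2moe uses), and Mamba's A_log is folded into A = -exp(A_log), which makes every entry of A negative so the SSM recurrence decays rather than diverges. A minimal standalone sketch with made-up shapes:

import torch

n_experts, d_ff, d_model = 4, 16, 8  # toy sizes, not Jamba's real dimensions

# per-expert 2D weights as they appear in the HF checkpoint
experts = [torch.randn(d_ff, d_model) for _ in range(n_experts)]

# merged the same way modify_tensors does with torch.stack(datas, dim=0)
merged = torch.stack(experts, dim=0)
assert merged.shape == (n_experts, d_ff, d_model)

# A_log -> A: negating the exponential guarantees strictly negative values
A_log = torch.randn(3, 5)
A = -torch.exp(A_log)
assert bool((A < 0).all())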
    def prepare_tensors(self):
        super().prepare_tensors()

        if self._experts is not None:
            # flatten `list[dict[str, Tensor]]` into `list[str]`
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("CohereForCausalLM")
|
||||
class CommandR2Model(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.COMMAND_R
|
||||
|
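With JambaModel registered, a Jamba checkpoint converts like any other supported HF model. A typical invocation (paths and output names are placeholders):

python convert_hf_to_gguf.py /path/to/Jamba-v0.1 --outfile jamba.gguf --outtype f16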