model : Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture (#12466)

* Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture

- Adds MoE-based embedding model supporting multilingual embeddings.
- Selects architecture variant based on hyperparameter detection (MoE layers).
- Removes unnecessary subclass initialization checks for clarity.

https://www.nomic.ai/blog/posts/nomic-embed-text-v2

Co-authored-by: Jared Van Bortel <jared@nomic.ai>

* fix tokenizer

* don't rename this tensor

---------

Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
AT
2025-04-28 15:52:15 -04:00
committed by GitHub
parent eaea325324
commit 5f5e39e1ba
9 changed files with 247 additions and 110 deletions

View File

@ -78,7 +78,7 @@ class ModelBase:
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
use_temp_file: bool = False, eager: bool = False,
metadata_override: Path | None = None, model_name: str | None = None,
split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
@ -454,13 +454,6 @@ class ModelBase:
class TextModel(ModelBase):
@classmethod
def __init_subclass__(cls):
# can't use an abstract property, because overriding it without type errors
# would require using decorated functions instead of simply defining the property
if "model_arch" not in cls.__dict__:
raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
def set_vocab(self):
self._set_vocab_gpt2()
@ -3373,14 +3366,7 @@ class BertModel(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("RobertaModel")
class RobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _xlmroberta_tokenizer_init(self) -> None:
# we need the pad_token_id to know how to chop down position_embd matrix
if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
self._position_offset = 1 + pad_token_id
@ -3389,82 +3375,7 @@ class RobertaModel(BertModel):
else:
self._position_offset = None
def set_vocab(self):
"""Support BPE tokenizers for roberta models"""
bpe_tok_path = self.dir_model / "tokenizer.json"
if bpe_tok_path.exists():
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
# we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings
# "Sequence A" or "Sequence B"
self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
else:
return super().set_vocab()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# if name starts with "roberta.", remove the prefix
# e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
if name.startswith("roberta."):
name = name[8:]
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None:
data_torch = data_torch[self._position_offset:,:]
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("NomicBertModel")
class NomicBertModel(BertModel):
model_arch = gguf.MODEL_ARCH.NOMIC_BERT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# the HF config claims n_ctx=8192, but it uses RoPE scaling
self.hparams["n_ctx"] = 2048
# SwigLU activation
assert self.hparams["activation_function"] == "swiglu"
# this doesn't do anything in the HF version
assert self.hparams["causal"] is False
# no bias tensors
assert self.hparams["qkv_proj_bias"] is False
assert self.hparams["mlp_fc1_bias"] is False
assert self.hparams["mlp_fc2_bias"] is False
# norm at end of layer
assert self.hparams["prenorm"] is False
# standard RoPE
assert self.hparams["rotary_emb_fraction"] == 1.0
assert self.hparams["rotary_emb_interleaved"] is False
assert self.hparams["rotary_emb_scale_base"] is None
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# we need the pad_token_id to know how to chop down position_embd matrix
if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
self._position_offset = 1 + pad_token_id
if "max_position_embeddings" in self.hparams:
self.hparams["max_position_embeddings"] -= self._position_offset
else:
self._position_offset = None
def set_vocab(self):
def _xlmroberta_set_vocab(self) -> None:
# to avoid TypeError: Descriptors cannot be created directly
# exception when importing sentencepiece_model_pb2
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@ -3546,6 +3457,138 @@ class XLMRobertaModel(BertModel):
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
@ModelBase.register("RobertaModel")
class RobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# we need the pad_token_id to know how to chop down position_embd matrix
if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
self._position_offset = 1 + pad_token_id
if "max_position_embeddings" in self.hparams:
self.hparams["max_position_embeddings"] -= self._position_offset
else:
self._position_offset = None
def set_vocab(self):
"""Support BPE tokenizers for roberta models"""
bpe_tok_path = self.dir_model / "tokenizer.json"
if bpe_tok_path.exists():
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)
# we need this to validate the size of the token_type embeddings
# though currently we are passing all zeros to the token_type embeddings
# "Sequence A" or "Sequence B"
self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
else:
return super().set_vocab()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# if name starts with "roberta.", remove the prefix
# e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
if name.startswith("roberta."):
name = name[8:]
# position embeddings start at pad_token_id + 1, so just chop down the weight tensor
if name == "embeddings.position_embeddings.weight":
if self._position_offset is not None:
data_torch = data_torch[self._position_offset:,:]
return super().modify_tensors(data_torch, name, bid)
@ModelBase.register("NomicBertModel")
class NomicBertModel(BertModel):
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
hparams = kwargs.pop("hparams", None)
if hparams is None:
hparams = ModelBase.load_hparams(dir_model)
self.is_moe = bool(hparams.get("moe_every_n_layers"))
self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
self._tokenizer_is_xlmroberta = self._is_tokenizer_xlmroberta()
if self._tokenizer_is_xlmroberta:
self._xlmroberta_tokenizer_init()
# the HF config claims n_ctx=8192, but it uses RoPE scaling
self.hparams["n_ctx"] = 2048
assert self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu"
# this doesn't do anything in the HF version
assert self.hparams["causal"] is False
# no bias tensors unless MoE
assert self.hparams["qkv_proj_bias"] == self.is_moe
assert self.hparams["mlp_fc1_bias"] == self.is_moe
assert self.hparams["mlp_fc2_bias"] == self.is_moe
# norm at end of layer
assert self.hparams["prenorm"] is False
# standard RoPE
assert self.hparams["rotary_emb_fraction"] == 1.0
assert self.hparams["rotary_emb_interleaved"] is False
assert self.hparams["rotary_emb_scale_base"] is None
def set_vocab(self) -> None:
if self._tokenizer_is_xlmroberta:
return self._xlmroberta_set_vocab()
return super().set_vocab()
def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]:
# If the tensor is an experts bias tensor, skip it by returning an empty list.
if "mlp.experts.bias" in name:
return [] # Explicitly return an empty list.
if "mlp.experts.mlp.w1" in name:
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
name += ".weight"
if "mlp.experts.mlp.w2" in name:
data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
data_torch = data_torch.transpose(1, 2)
name += ".weight"
return [(self.map_tensor_name(name), data_torch)]
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
if self.is_moe:
self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
self.gguf_writer.add_expert_count(self.hparams["num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
def _is_tokenizer_xlmroberta(self) -> bool:
with open(self.dir_model / "tokenizer.json") as f:
tokenizer_json = json.load(f)
toktyp = tokenizer_json["model"]["type"]
if toktyp == "Unigram":
return True
if toktyp == "WordPiece":
return False
raise ValueError(f"unknown tokenizer: {toktyp}")
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._xlmroberta_tokenizer_init()
def set_vocab(self):
self._xlmroberta_set_vocab()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# if name starts with "roberta.", remove the prefix
# e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main

View File

@ -104,6 +104,7 @@ class Keys:
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
@ -267,6 +268,7 @@ class MODEL_ARCH(IntEnum):
REFACT = auto()
BERT = auto()
NOMIC_BERT = auto()
NOMIC_BERT_MOE = auto()
JINA_BERT_V2 = auto()
BLOOM = auto()
STABLELM = auto()
@ -521,6 +523,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
@ -960,6 +963,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.NOMIC_BERT_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.TOKEN_TYPES,
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.JINA_BERT_V2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,

View File

@ -728,6 +728,9 @@ class GGUFWriter:
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
def add_moe_every_n_layers(self, value: int) -> None:
self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
def add_swin_norm(self, value: bool) -> None:
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)

View File

@ -290,6 +290,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"language_model.model.layers.{bid}.feed_forward.router", # llama4
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -322,6 +323,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.mlp.up_proj", # plamo
"model.layers.{bid}.feed_forward.w3", # internlm2
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
"encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe
"model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
@ -337,6 +339,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
"language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
),
MODEL_TENSOR.FFN_UP_SHEXP: (
@ -418,6 +421,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
"model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged)
"language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (

View File

@ -19,6 +19,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
@ -106,6 +107,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@ -472,6 +474,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_NOMIC_BERT_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_JINA_BERT_V2,
{

View File

@ -23,6 +23,7 @@ enum llm_arch {
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_JINA_BERT_V2,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
@ -110,6 +111,7 @@ enum llm_kv {
LLM_KV_EXPERT_WEIGHTS_SCALE,
LLM_KV_EXPERT_WEIGHTS_NORM,
LLM_KV_EXPERT_GATING_FUNC,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,

View File

@ -925,28 +925,35 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(gate, "ffn_moe_gate", il);
ggml_tensor * experts = nullptr;
if (gate_exps) {
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate", il);
} else {
cur = up;
}
switch (type_op) {
case LLM_FFN_SILU:
{
gate = ggml_silu(ctx0, gate);
cb(gate, "ffn_moe_silu", il);
cur = ggml_silu(ctx0, cur);
cb(cur, "ffn_moe_silu", il);
} break;
case LLM_FFN_GELU:
{
gate = ggml_gelu(ctx0, gate);
cb(gate, "ffn_moe_gelu", il);
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_moe_gelu", il);
} break;
default:
GGML_ABORT("fatal error");
}
ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
cb(par, "ffn_moe_gate_par", il);
if (gate_exps) {
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
cb(cur, "ffn_moe_gate_par", il);
}
ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
if (!weight_before_ffn) {

View File

@ -66,6 +66,7 @@ struct llama_hparams {
float expert_weights_scale = 0.0;
bool expert_weights_norm = false;
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
uint32_t moe_every_n_layers = 0;
float f_norm_eps;
float f_norm_rms_eps;

View File

@ -695,10 +695,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0);
if (hparams.n_layer == 12 && hparams.n_embd == 768) {
type = LLM_TYPE_137M;
@ -2057,6 +2059,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
} break;
case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
@ -2090,20 +2093,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
}
if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
if (arch == LLM_ARCH_BERT) {
if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
} else {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) {
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
} else {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
}
}
layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
@ -5730,6 +5744,11 @@ struct llm_build_bert : public llm_graph_context {
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
}
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
@ -5782,13 +5801,29 @@ struct llm_build_bert : public llm_graph_context {
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
if (model.arch == LLM_ARCH_BERT) {
if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) {
// MoE branch
cur = build_moe_ffn(cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
nullptr,
model.layers[il].ffn_down_exps,
nullptr,
hparams.n_expert,
hparams.n_expert_used,
LLM_FFN_GELU,
false, false,
0.0f,
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
cb(cur, "ffn_moe_out", il);
} else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, il);
cb(cur, "ffn_out", il);
} else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
@ -5796,6 +5831,7 @@ struct llm_build_bert : public llm_graph_context {
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
} else {
cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
@ -5803,8 +5839,8 @@ struct llm_build_bert : public llm_graph_context {
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
}
cb(cur, "ffn_out", il);
// attentions bypass the intermediate layer
cur = ggml_add(ctx0, cur, ffn_inp);
@ -12843,6 +12879,7 @@ llm_graph_result_ptr llama_model::build_graph(
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
llm = std::make_unique<llm_build_bert>(*this, params, gf);
} break;
@ -13201,6 +13238,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_DBRX:
case LLM_ARCH_BERT:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
case LLM_ARCH_STABLELM:
case LLM_ARCH_BITNET:
case LLM_ARCH_QWEN: