mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
llama : support GEGLU for jina-bert-v2 (#14090)
This commit is contained in:
@ -4798,25 +4798,6 @@ class OlmoeModel(TextModel):
|
||||
class JinaBertV2Model(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.intermediate_size = self.hparams["intermediate_size"]
|
||||
|
||||
def get_tensors(self):
|
||||
for name, data in super().get_tensors():
|
||||
if 'gated_layer' in name:
|
||||
d1 = data[:self.intermediate_size, :]
|
||||
name1 = name.replace('gated_layers', 'gated_layers_w')
|
||||
name1 = name1.replace('up_gated_layer', 'gated_layers_v')
|
||||
d2 = data[self.intermediate_size:, :]
|
||||
name2 = name.replace('gated_layers', 'gated_layers_v')
|
||||
name2 = name2.replace('up_gated_layer', 'gated_layers_w')
|
||||
yield name1, d1
|
||||
yield name2, d2
|
||||
continue
|
||||
|
||||
yield name, data
|
||||
|
||||
def set_vocab(self):
|
||||
tokenizer_class = 'BertTokenizer'
|
||||
with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f:
|
||||
@ -4832,14 +4813,6 @@ class JinaBertV2Model(BertModel):
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# if name starts with "bert.", remove the prefix
|
||||
# e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||
if name.startswith("bert."):
|
||||
name = name[5:]
|
||||
|
||||
return super().modify_tensors(data_torch, name, bid)
|
||||
|
||||
|
||||
@ModelBase.register("OpenELMForCausalLM")
|
||||
class OpenELMModel(TextModel):
|
||||
|
@ -333,7 +333,9 @@ class TensorNameMap:
|
||||
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
|
||||
"encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe
|
||||
"model.layers.{bid}.mlp.c_fc", # starcoder2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 (split up/gate, no longer used)
|
||||
"encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU)
|
||||
"encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU)
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
@ -370,7 +372,7 @@ class TensorNameMap:
|
||||
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w1", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
|
||||
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used)
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
|
@ -650,6 +650,7 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
{
|
||||
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
@ -663,7 +664,7 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
{
|
||||
// Split into two equal parts
|
||||
int64_t split_point = cur->ne[0] / 2;
|
||||
// TODO: these conts should not be needed
|
||||
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
||||
|
||||
|
@ -2224,8 +2224,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0);
|
||||
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
|
||||
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
|
||||
@ -6043,7 +6043,7 @@ struct llm_build_bert : public llm_graph_context {
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
NULL,
|
||||
LLM_FFN_GELU, LLM_FFN_PAR, il);
|
||||
model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
cur = build_ffn(cur,
|
||||
|
Reference in New Issue
Block a user