llama : add support for jina-reranker-v2 (#13900)

This commit is contained in:
Sigbjørn Skjæret
2025-05-29 21:42:31 +02:00
committed by GitHub
parent 2b131621e6
commit e83ba3e460
5 changed files with 119 additions and 72 deletions

View File

@ -3782,9 +3782,33 @@ class BertModel(TextModel):
from sentencepiece import sentencepiece_model_pb2 as model from sentencepiece import sentencepiece_model_pb2 as model
tokenizer_path = self.dir_model / 'sentencepiece.bpe.model' tokenizer_path = self.dir_model / 'sentencepiece.bpe.model'
tokenizer_json = {}
tokenizer_config_json = {}
if not tokenizer_path.is_file():
tokenizer_path = self.dir_model / 'tokenizer.json'
tokenizer_config_path = self.dir_model / 'tokenizer_config.json'
if not tokenizer_path.is_file(): if not tokenizer_path.is_file():
raise FileNotFoundError(f"File not found: {tokenizer_path}") raise FileNotFoundError(f"File not found: {tokenizer_path}")
from base64 import b64decode
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
with open(tokenizer_path, "r", encoding="utf-8") as fp:
tokenizer_json = json.load(fp)
if tokenizer_config_path.is_file():
with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
tokenizer_config_json = json.load(fp)
add_prefix = tokenizer.add_prefix_space
remove_whitespaces = tokenizer.clean_up_tokenization_spaces
precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
else:
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
@ -3802,6 +3826,7 @@ class BertModel(TextModel):
scores: list[float] = [-10000.0] * vocab_size scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
if isinstance(tokenizer, SentencePieceProcessor):
for token_id in range(tokenizer.vocab_size()): for token_id in range(tokenizer.vocab_size()):
piece = tokenizer.IdToPiece(token_id) piece = tokenizer.IdToPiece(token_id)
text = piece.encode("utf-8") text = piece.encode("utf-8")
@ -3820,6 +3845,30 @@ class BertModel(TextModel):
tokens[token_id] = text tokens[token_id] = text
scores[token_id] = score scores[token_id] = score
toktypes[token_id] = toktype toktypes[token_id] = toktype
else:
added_vocab = tokenizer.get_added_vocab()
unk_token = tokenizer_config_json.get("unk_token")
unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
for token_id in range(vocab_size):
piece = tokenizer._convert_id_to_token(token_id)
text = piece.encode("utf-8")
score = tokenizer_json["model"]["vocab"][token_id][1]
toktype = SentencePieceTokenTypes.NORMAL
if token_id == unk_token_id:
toktype = SentencePieceTokenTypes.UNKNOWN
elif token_id in tokenizer.all_special_ids:
toktype = SentencePieceTokenTypes.CONTROL
elif token_id in added_vocab.values():
toktype = SentencePieceTokenTypes.USER_DEFINED
# No reliable way to detect this, but jina doesn't have any
# elif tokenizer.IsByte(token_id):
# toktype = SentencePieceTokenTypes.BYTE
tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
if vocab_size > len(tokens): if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens) pad_count = vocab_size - len(tokens)
@ -3829,6 +3878,7 @@ class BertModel(TextModel):
scores.append(-1000.0) scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.UNUSED) toktypes.append(SentencePieceTokenTypes.UNUSED)
if isinstance(tokenizer, SentencePieceProcessor):
# realign tokens (see HF tokenizer code) # realign tokens (see HF tokenizer code)
tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1] tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1] scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]

View File

@ -1036,6 +1036,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.POS_EMBD, MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_OUT_NORM, MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_V,

View File

@ -157,6 +157,7 @@ class TensorNameMap:
"h.{bid}.attn.c_attn", # gpt2 "h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2 "transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert "encoder.layers.{bid}.attn.Wqkv", # nomic-bert
"encoder.layers.{bid}.mixer.Wqkv", # jina
"model.layers.{bid}.self_attn.qkv_proj", # phi3 "model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm "encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm "transformer.layers.{bid}.attn.qkv_proj", # openelm
@ -224,6 +225,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2 "model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert "encoder.layers.{bid}.attn.out_proj", # nomic-bert
"encoder.layers.{bid}.mixer.out_proj", # jina
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm "encoder.layers.{bid}.self_attention.dense", # chatglm

View File

@ -450,6 +450,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_POS_EMBD, "position_embd" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },

View File

@ -2132,7 +2132,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
for (int i = 0; i < n_layer; ++i) { for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i]; auto & layer = layers[i];
if (arch == LLM_ARCH_BERT) { layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
if (!layer.wqkv) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
@ -2141,12 +2144,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
} else {
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
}
if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
} }
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
@ -5910,8 +5907,23 @@ struct llm_build_bert : public llm_graph_context {
ggml_tensor * Vcur; ggml_tensor * Vcur;
// self-attention // self-attention
if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { if (model.layers[il].wqkv) {
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
if (model.layers[il].bqkv) {
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
}
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
} else {
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
}
if (model.layers[il].attn_q_norm) { if (model.layers[il].attn_q_norm) {
Qcur = build_norm(Qcur, Qcur = build_norm(Qcur,
@ -5920,8 +5932,6 @@ struct llm_build_bert : public llm_graph_context {
LLM_NORM, il); LLM_NORM, il);
} }
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
if (model.layers[il].attn_k_norm) { if (model.layers[il].attn_k_norm) {
Kcur = build_norm(Kcur, Kcur = build_norm(Kcur,
model.layers[il].attn_k_norm, model.layers[il].attn_k_norm,
@ -5929,29 +5939,12 @@ struct llm_build_bert : public llm_graph_context {
LLM_NORM, il); LLM_NORM, il);
} }
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
} else {
// compute Q and K and RoPE them
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
}
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
// RoPE
if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
Qcur = ggml_rope_ext( Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr, ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,