mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
llama : add support for jina-reranker-v2 (#13900)
This commit is contained in:
@ -3782,44 +3782,93 @@ class BertModel(TextModel):
|
||||
from sentencepiece import sentencepiece_model_pb2 as model
|
||||
|
||||
tokenizer_path = self.dir_model / 'sentencepiece.bpe.model'
|
||||
|
||||
tokenizer_json = {}
|
||||
tokenizer_config_json = {}
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
tokenizer_path = self.dir_model / 'tokenizer.json'
|
||||
tokenizer_config_path = self.dir_model / 'tokenizer_config.json'
|
||||
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
if not tokenizer_path.is_file():
|
||||
raise FileNotFoundError(f"File not found: {tokenizer_path}")
|
||||
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
|
||||
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
|
||||
from base64 import b64decode
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
|
||||
|
||||
tokenizer = SentencePieceProcessor()
|
||||
tokenizer.LoadFromFile(str(tokenizer_path))
|
||||
with open(tokenizer_path, "r", encoding="utf-8") as fp:
|
||||
tokenizer_json = json.load(fp)
|
||||
|
||||
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
|
||||
if tokenizer_config_path.is_file():
|
||||
with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
|
||||
tokenizer_config_json = json.load(fp)
|
||||
|
||||
add_prefix = tokenizer.add_prefix_space
|
||||
remove_whitespaces = tokenizer.clean_up_tokenization_spaces
|
||||
precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
|
||||
|
||||
vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
|
||||
else:
|
||||
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
|
||||
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
|
||||
assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
|
||||
|
||||
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
|
||||
remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
|
||||
precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
|
||||
|
||||
tokenizer = SentencePieceProcessor()
|
||||
tokenizer.LoadFromFile(str(tokenizer_path))
|
||||
|
||||
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
|
||||
|
||||
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
|
||||
scores: list[float] = [-10000.0] * vocab_size
|
||||
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
|
||||
|
||||
for token_id in range(tokenizer.vocab_size()):
|
||||
piece = tokenizer.IdToPiece(token_id)
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer.GetScore(token_id)
|
||||
if isinstance(tokenizer, SentencePieceProcessor):
|
||||
for token_id in range(tokenizer.vocab_size()):
|
||||
piece = tokenizer.IdToPiece(token_id)
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer.GetScore(token_id)
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if tokenizer.IsUnknown(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif tokenizer.IsControl(token_id):
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif tokenizer.IsUnused(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
elif tokenizer.IsByte(token_id):
|
||||
toktype = SentencePieceTokenTypes.BYTE
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if tokenizer.IsUnknown(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif tokenizer.IsControl(token_id):
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif tokenizer.IsUnused(token_id):
|
||||
toktype = SentencePieceTokenTypes.UNUSED
|
||||
elif tokenizer.IsByte(token_id):
|
||||
toktype = SentencePieceTokenTypes.BYTE
|
||||
|
||||
tokens[token_id] = text
|
||||
scores[token_id] = score
|
||||
toktypes[token_id] = toktype
|
||||
tokens[token_id] = text
|
||||
scores[token_id] = score
|
||||
toktypes[token_id] = toktype
|
||||
else:
|
||||
added_vocab = tokenizer.get_added_vocab()
|
||||
unk_token = tokenizer_config_json.get("unk_token")
|
||||
unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
|
||||
|
||||
for token_id in range(vocab_size):
|
||||
piece = tokenizer._convert_id_to_token(token_id)
|
||||
text = piece.encode("utf-8")
|
||||
score = tokenizer_json["model"]["vocab"][token_id][1]
|
||||
|
||||
toktype = SentencePieceTokenTypes.NORMAL
|
||||
if token_id == unk_token_id:
|
||||
toktype = SentencePieceTokenTypes.UNKNOWN
|
||||
elif token_id in tokenizer.all_special_ids:
|
||||
toktype = SentencePieceTokenTypes.CONTROL
|
||||
elif token_id in added_vocab.values():
|
||||
toktype = SentencePieceTokenTypes.USER_DEFINED
|
||||
# No reliable way to detect this, but jina doesn't have any
|
||||
# elif tokenizer.IsByte(token_id):
|
||||
# toktype = SentencePieceTokenTypes.BYTE
|
||||
|
||||
tokens[token_id] = text
|
||||
scores[token_id] = score
|
||||
toktypes[token_id] = toktype
|
||||
|
||||
if vocab_size > len(tokens):
|
||||
pad_count = vocab_size - len(tokens)
|
||||
@ -3829,15 +3878,16 @@ class BertModel(TextModel):
|
||||
scores.append(-1000.0)
|
||||
toktypes.append(SentencePieceTokenTypes.UNUSED)
|
||||
|
||||
# realign tokens (see HF tokenizer code)
|
||||
tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
|
||||
scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
|
||||
toktypes = [
|
||||
SentencePieceTokenTypes.CONTROL,
|
||||
SentencePieceTokenTypes.CONTROL,
|
||||
SentencePieceTokenTypes.CONTROL,
|
||||
SentencePieceTokenTypes.UNKNOWN,
|
||||
] + toktypes[3:-1]
|
||||
if isinstance(tokenizer, SentencePieceProcessor):
|
||||
# realign tokens (see HF tokenizer code)
|
||||
tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
|
||||
scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
|
||||
toktypes = [
|
||||
SentencePieceTokenTypes.CONTROL,
|
||||
SentencePieceTokenTypes.CONTROL,
|
||||
SentencePieceTokenTypes.CONTROL,
|
||||
SentencePieceTokenTypes.UNKNOWN,
|
||||
] + toktypes[3:-1]
|
||||
|
||||
self.gguf_writer.add_tokenizer_model("t5")
|
||||
self.gguf_writer.add_tokenizer_pre("default")
|
||||
|
@ -1036,6 +1036,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.POS_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
|
@ -157,6 +157,7 @@ class TensorNameMap:
|
||||
"h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||
"encoder.layers.{bid}.mixer.Wqkv", # jina
|
||||
"model.layers.{bid}.self_attn.qkv_proj", # phi3
|
||||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
@ -224,6 +225,7 @@ class TensorNameMap:
|
||||
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
||||
"model.layers.{bid}.attention.wo", # internlm2
|
||||
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
||||
"encoder.layers.{bid}.mixer.out_proj", # jina
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
|
||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||
|
@ -450,6 +450,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
|
||||
{ LLM_TENSOR_POS_EMBD, "position_embd" },
|
||||
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
|
||||
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
|
@ -2132,7 +2132,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
if (arch == LLM_ARCH_BERT) {
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
if (!layer.wqkv) {
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
|
||||
|
||||
@ -2141,12 +2144,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
|
||||
} else {
|
||||
layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
|
||||
layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
|
||||
}
|
||||
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
|
||||
@ -5910,36 +5907,11 @@ struct llm_build_bert : public llm_graph_context {
|
||||
ggml_tensor * Vcur;
|
||||
|
||||
// self-attention
|
||||
if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
|
||||
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
|
||||
|
||||
if (model.layers[il].attn_q_norm) {
|
||||
Qcur = build_norm(Qcur,
|
||||
model.layers[il].attn_q_norm,
|
||||
model.layers[il].attn_q_norm_b,
|
||||
LLM_NORM, il);
|
||||
}
|
||||
|
||||
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
|
||||
|
||||
if (model.layers[il].attn_k_norm) {
|
||||
Kcur = build_norm(Kcur,
|
||||
model.layers[il].attn_k_norm,
|
||||
model.layers[il].attn_k_norm_b,
|
||||
LLM_NORM, il);
|
||||
}
|
||||
|
||||
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
} else {
|
||||
// compute Q and K and RoPE them
|
||||
if (model.layers[il].wqkv) {
|
||||
cur = build_lora_mm(model.layers[il].wqkv, cur);
|
||||
cb(cur, "wqkv", il);
|
||||
|
||||
if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
|
||||
if (model.layers[il].bqkv) {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
|
||||
cb(cur, "bqkv", il);
|
||||
}
|
||||
@ -5947,11 +5919,32 @@ struct llm_build_bert : public llm_graph_context {
|
||||
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
|
||||
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
|
||||
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
|
||||
} else {
|
||||
Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
|
||||
Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
|
||||
Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
if (model.layers[il].attn_q_norm) {
|
||||
Qcur = build_norm(Qcur,
|
||||
model.layers[il].attn_q_norm,
|
||||
model.layers[il].attn_q_norm_b,
|
||||
LLM_NORM, il);
|
||||
}
|
||||
|
||||
if (model.layers[il].attn_k_norm) {
|
||||
Kcur = build_norm(Kcur,
|
||||
model.layers[il].attn_k_norm,
|
||||
model.layers[il].attn_k_norm_b,
|
||||
LLM_NORM, il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
// RoPE
|
||||
if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
|
Reference in New Issue
Block a user