diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 8fcff0de7..868bb6826 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -3782,44 +3782,93 @@ class BertModel(TextModel):
from sentencepiece import sentencepiece_model_pb2 as model
tokenizer_path = self.dir_model / 'sentencepiece.bpe.model'
+
+ tokenizer_json = {}
+ tokenizer_config_json = {}
if not tokenizer_path.is_file():
- raise FileNotFoundError(f"File not found: {tokenizer_path}")
+ tokenizer_path = self.dir_model / 'tokenizer.json'
+ tokenizer_config_path = self.dir_model / 'tokenizer_config.json'
- sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
- sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
- assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
+ if not tokenizer_path.is_file():
+ raise FileNotFoundError(f"File not found: {tokenizer_path}")
- add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
- remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
- precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+ from base64 import b64decode
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
- tokenizer = SentencePieceProcessor()
- tokenizer.LoadFromFile(str(tokenizer_path))
+ with open(tokenizer_path, "r", encoding="utf-8") as fp:
+ tokenizer_json = json.load(fp)
- vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+ if tokenizer_config_path.is_file():
+ with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
+ tokenizer_config_json = json.load(fp)
+
+ add_prefix = tokenizer.add_prefix_space
+ remove_whitespaces = tokenizer.clean_up_tokenization_spaces
+ precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
+
+ vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
+ else:
+ sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
+ sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+ assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
+
+ add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+ remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+ precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+
+ tokenizer = SentencePieceProcessor()
+ tokenizer.LoadFromFile(str(tokenizer_path))
+
+ vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
- for token_id in range(tokenizer.vocab_size()):
- piece = tokenizer.IdToPiece(token_id)
- text = piece.encode("utf-8")
- score = tokenizer.GetScore(token_id)
+ if isinstance(tokenizer, SentencePieceProcessor):
+ for token_id in range(tokenizer.vocab_size()):
+ piece = tokenizer.IdToPiece(token_id)
+ text = piece.encode("utf-8")
+ score = tokenizer.GetScore(token_id)
- toktype = SentencePieceTokenTypes.NORMAL
- if tokenizer.IsUnknown(token_id):
- toktype = SentencePieceTokenTypes.UNKNOWN
- elif tokenizer.IsControl(token_id):
- toktype = SentencePieceTokenTypes.CONTROL
- elif tokenizer.IsUnused(token_id):
- toktype = SentencePieceTokenTypes.UNUSED
- elif tokenizer.IsByte(token_id):
- toktype = SentencePieceTokenTypes.BYTE
+ toktype = SentencePieceTokenTypes.NORMAL
+ if tokenizer.IsUnknown(token_id):
+ toktype = SentencePieceTokenTypes.UNKNOWN
+ elif tokenizer.IsControl(token_id):
+ toktype = SentencePieceTokenTypes.CONTROL
+ elif tokenizer.IsUnused(token_id):
+ toktype = SentencePieceTokenTypes.UNUSED
+ elif tokenizer.IsByte(token_id):
+ toktype = SentencePieceTokenTypes.BYTE
- tokens[token_id] = text
- scores[token_id] = score
- toktypes[token_id] = toktype
+ tokens[token_id] = text
+ scores[token_id] = score
+ toktypes[token_id] = toktype
+ else:
+ added_vocab = tokenizer.get_added_vocab()
+ unk_token = tokenizer_config_json.get("unk_token")
+ unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
+
+ for token_id in range(vocab_size):
+ piece = tokenizer._convert_id_to_token(token_id)
+ text = piece.encode("utf-8")
+ score = tokenizer_json["model"]["vocab"][token_id][1]
+
+ toktype = SentencePieceTokenTypes.NORMAL
+ if token_id == unk_token_id:
+ toktype = SentencePieceTokenTypes.UNKNOWN
+ elif token_id in tokenizer.all_special_ids:
+ toktype = SentencePieceTokenTypes.CONTROL
+ elif token_id in added_vocab.values():
+ toktype = SentencePieceTokenTypes.USER_DEFINED
+                # No reliable way to detect byte tokens from tokenizer.json, but jina has none
+ # elif tokenizer.IsByte(token_id):
+ # toktype = SentencePieceTokenTypes.BYTE
+
+ tokens[token_id] = text
+ scores[token_id] = score
+ toktypes[token_id] = toktype
if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
@@ -3829,15 +3878,16 @@ class BertModel(TextModel):
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.UNUSED)
- # realign tokens (see HF tokenizer code)
-        tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
- scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
- toktypes = [
- SentencePieceTokenTypes.CONTROL,
- SentencePieceTokenTypes.CONTROL,
- SentencePieceTokenTypes.CONTROL,
- SentencePieceTokenTypes.UNKNOWN,
- ] + toktypes[3:-1]
+ if isinstance(tokenizer, SentencePieceProcessor):
+ # realign tokens (see HF tokenizer code)
+            tokens = [b'<s>', b'<pad>', b'</s>', b'<unk>'] + tokens[3:-1]
+ scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1]
+ toktypes = [
+ SentencePieceTokenTypes.CONTROL,
+ SentencePieceTokenTypes.CONTROL,
+ SentencePieceTokenTypes.CONTROL,
+ SentencePieceTokenTypes.UNKNOWN,
+ ] + toktypes[3:-1]
self.gguf_writer.add_tokenizer_model("t5")
self.gguf_writer.add_tokenizer_pre("default")
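For reference, a minimal sketch of the tokenizer.json layout that the fallback path above relies on (file name and token index are illustrative): an HF unigram tokenizer stores each piece together with its score as a [piece, score] pair, so no SentencePiece protobuf is needed.

    import json

    # Load the HF tokenizer definition, as probed by the converter above.
    with open("tokenizer.json", encoding="utf-8") as fp:
        tok = json.load(fp)

    assert tok["model"]["type"] == "Unigram"
    piece, score = tok["model"]["vocab"][42]                  # e.g. ("▁the", -3.25)
    charsmap_b64 = tok["normalizer"]["precompiled_charsmap"]  # base64; b64decode()'d above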
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 635b61f22..3ee2b2064 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -1036,6 +1036,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.POS_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_OUT_NORM,
+ MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 48167dd64..d0dad7036 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -157,6 +157,7 @@ class TensorNameMap:
"h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
+ "encoder.layers.{bid}.mixer.Wqkv", # jina
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
@@ -224,6 +225,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
+ "encoder.layers.{bid}.mixer.out_proj", # jina
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 2bb18c85f..c0590e105 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -450,6 +450,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_TYPES, "token_types" },
{ LLM_TENSOR_POS_EMBD, "position_embd" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 4a4618a2b..ecaae6bf0 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2132,7 +2132,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
- if (arch == LLM_ARCH_BERT) {
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+ layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+ if (!layer.wqkv) {
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0);
@@ -2141,12 +2144,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
- } else {
- layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
- }
-
- if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
- layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
}
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
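The loader change above replaces the per-architecture branches with a probe-and-fallback: the fused QKV tensor is created as optional (TENSOR_NOT_REQUIRED), and the split Q/K/V tensors are required only when it is absent. A rough Python rendering of the same control flow (names hypothetical):

    def load_attn_tensors(weights: dict, i: int) -> dict:
        out = {}
        # optional fused projection, mirroring TENSOR_NOT_REQUIRED above
        out["wqkv"] = weights.get(f"blk.{i}.attn_qkv.weight")
        out["bqkv"] = weights.get(f"blk.{i}.attn_qkv.bias")
        if out["wqkv"] is None:
            # no fused tensor: the split tensors become mandatory, as in the C++ above
            for k in ("q", "k", "v"):
                out[f"w{k}"] = weights[f"blk.{i}.attn_{k}.weight"]
                out[f"b{k}"] = weights[f"blk.{i}.attn_{k}.bias"]
        return out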
@@ -5910,36 +5907,11 @@ struct llm_build_bert : public llm_graph_context {
ggml_tensor * Vcur;
// self-attention
- if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) {
- Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
-
- if (model.layers[il].attn_q_norm) {
- Qcur = build_norm(Qcur,
- model.layers[il].attn_q_norm,
- model.layers[il].attn_q_norm_b,
- LLM_NORM, il);
- }
-
- Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
-
- if (model.layers[il].attn_k_norm) {
- Kcur = build_norm(Kcur,
- model.layers[il].attn_k_norm,
- model.layers[il].attn_k_norm_b,
- LLM_NORM, il);
- }
-
- Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
-
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
- } else {
- // compute Q and K and RoPE them
+ if (model.layers[il].wqkv) {
cur = build_lora_mm(model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
- if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+ if (model.layers[il].bqkv) {
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
}
@@ -5947,11 +5919,32 @@ struct llm_build_bert : public llm_graph_context {
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+ } else {
+ Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
+ Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
+ Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+ }
- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+ if (model.layers[il].attn_q_norm) {
+ Qcur = build_norm(Qcur,
+ model.layers[il].attn_q_norm,
+ model.layers[il].attn_q_norm_b,
+ LLM_NORM, il);
+ }
+ if (model.layers[il].attn_k_norm) {
+ Kcur = build_norm(Kcur,
+ model.layers[il].attn_k_norm,
+ model.layers[il].attn_k_norm_b,
+ LLM_NORM, il);
+ }
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ // RoPE
+ if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
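For intuition, a NumPy sketch of the split performed by the ggml_view_2d calls above (toy sizes; the real code creates views into the fused projection without copying): each output row holds n_embd values of Q followed by n_embd_gqa values each of K and V.

    import numpy as np

    n_tokens, n_embd, n_embd_gqa = 4, 8, 8  # toy sizes (assumption)
    cur = np.random.randn(n_tokens, n_embd + 2 * n_embd_gqa).astype(np.float32)

    Qcur = cur[:, :n_embd]                     # offset 0
    Kcur = cur[:, n_embd:n_embd + n_embd_gqa]  # offset sizeof(float)*n_embd
    Vcur = cur[:, n_embd + n_embd_gqa:]        # offset sizeof(float)*(n_embd + n_embd_gqa)
    assert Kcur.shape == Vcur.shape == (n_tokens, n_embd_gqa)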