diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8fcff0de7..868bb6826 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3782,44 +3782,93 @@ class BertModel(TextModel): from sentencepiece import sentencepiece_model_pb2 as model tokenizer_path = self.dir_model / 'sentencepiece.bpe.model' + + tokenizer_json = {} + tokenizer_config_json = {} if not tokenizer_path.is_file(): - raise FileNotFoundError(f"File not found: {tokenizer_path}") + tokenizer_path = self.dir_model / 'tokenizer.json' + tokenizer_config_path = self.dir_model / 'tokenizer_config.json' - sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] - sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) - assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") - add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix - remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces - precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + from base64 import b64decode + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - tokenizer = SentencePieceProcessor() - tokenizer.LoadFromFile(str(tokenizer_path)) + with open(tokenizer_path, "r", encoding="utf-8") as fp: + tokenizer_json = json.load(fp) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + if tokenizer_config_path.is_file(): + with open(tokenizer_config_path, "r", encoding="utf-8") as fp: + tokenizer_config_json = json.load(fp) + + add_prefix = tokenizer.add_prefix_space + remove_whitespaces = tokenizer.clean_up_tokenization_spaces + precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"]) + + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size) + else: + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size - for token_id in range(tokenizer.vocab_size()): - piece = tokenizer.IdToPiece(token_id) - text = piece.encode("utf-8") - score = tokenizer.GetScore(token_id) + if isinstance(tokenizer, SentencePieceProcessor): + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) - toktype = SentencePieceTokenTypes.NORMAL - if tokenizer.IsUnknown(token_id): - toktype = SentencePieceTokenTypes.UNKNOWN - elif tokenizer.IsControl(token_id): - toktype = SentencePieceTokenTypes.CONTROL - elif tokenizer.IsUnused(token_id): - toktype = SentencePieceTokenTypes.UNUSED - elif tokenizer.IsByte(token_id): - toktype = SentencePieceTokenTypes.BYTE + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE - tokens[token_id] = text - scores[token_id] = score - toktypes[token_id] = toktype + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + else: + added_vocab = tokenizer.get_added_vocab() + unk_token = tokenizer_config_json.get("unk_token") + unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3)) + + for token_id in range(vocab_size): + piece = tokenizer._convert_id_to_token(token_id) + text = piece.encode("utf-8") + score = tokenizer_json["model"]["vocab"][token_id][1] + + toktype = SentencePieceTokenTypes.NORMAL + if token_id == unk_token_id: + toktype = SentencePieceTokenTypes.UNKNOWN + elif token_id in tokenizer.all_special_ids: + toktype = SentencePieceTokenTypes.CONTROL + elif token_id in added_vocab.values(): + toktype = SentencePieceTokenTypes.USER_DEFINED + # No reliable way to detect this, but jina doesn't have any + # elif tokenizer.IsByte(token_id): + # toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype if vocab_size > len(tokens): pad_count = vocab_size - len(tokens) @@ -3829,15 +3878,16 @@ class BertModel(TextModel): scores.append(-1000.0) toktypes.append(SentencePieceTokenTypes.UNUSED) - # realign tokens (see HF tokenizer code) - tokens = [b'', b'', b'', b''] + tokens[3:-1] - scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1] - toktypes = [ - SentencePieceTokenTypes.CONTROL, - SentencePieceTokenTypes.CONTROL, - SentencePieceTokenTypes.CONTROL, - SentencePieceTokenTypes.UNKNOWN, - ] + toktypes[3:-1] + if isinstance(tokenizer, SentencePieceProcessor): + # realign tokens (see HF tokenizer code) + tokens = [b'', b'', b'', b''] + tokens[3:-1] + scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1] + toktypes = [ + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.UNKNOWN, + ] + toktypes[3:-1] self.gguf_writer.add_tokenizer_model("t5") self.gguf_writer.add_tokenizer_pre("default") diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 635b61f22..3ee2b2064 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1036,6 +1036,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.POS_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 48167dd64..d0dad7036 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -157,6 +157,7 @@ class TensorNameMap: "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 "encoder.layers.{bid}.attn.Wqkv", # nomic-bert + "encoder.layers.{bid}.mixer.Wqkv", # jina "model.layers.{bid}.self_attn.qkv_proj", # phi3 "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm @@ -224,6 +225,7 @@ class TensorNameMap: "model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.{bid}.attention.wo", # internlm2 "encoder.layers.{bid}.attn.out_proj", # nomic-bert + "encoder.layers.{bid}.mixer.out_proj", # jina "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx "encoder.layers.{bid}.self_attention.dense", # chatglm diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 2bb18c85f..c0590e105 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -450,6 +450,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4a4618a2b..ecaae6bf0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2132,7 +2132,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - if (arch == LLM_ARCH_BERT) { + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (!layer.wqkv) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); @@ -2141,12 +2144,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } else { - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - } - - if (arch == LLM_ARCH_NOMIC_BERT_MOE) { - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -5910,36 +5907,11 @@ struct llm_build_bert : public llm_graph_context { ggml_tensor * Vcur; // self-attention - if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, - model.layers[il].attn_q_norm, - model.layers[il].attn_q_norm_b, - LLM_NORM, il); - } - - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, - model.layers[il].attn_k_norm, - model.layers[il].attn_k_norm_b, - LLM_NORM, il); - } - - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - // compute Q and K and RoPE them + if (model.layers[il].wqkv) { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); - if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) { + if (model.layers[il].bqkv) { cur = ggml_add(ctx0, cur, model.layers[il].bqkv); cb(cur, "bqkv", il); } @@ -5947,11 +5919,32 @@ struct llm_build_bert : public llm_graph_context { Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE + if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,