llama : improve sep token handling (#14272)

Sigbjørn Skjæret
2025-06-20 14:04:09 +02:00
committed by GitHub
parent e28c1b93fd
commit 88fc854b4b
15 changed files with 161 additions and 29 deletions

ci/run.sh

@@ -779,7 +779,7 @@ function gg_run_rerank_tiny {
     model_f16="${path_models}/ggml-model-f16.gguf"

     # for this model, the SEP token is "</s>"
-    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?</s></s>hi\nwhat is panda?</s></s>it's a bear\nwhat is panda?</s></s>The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
+    (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log

     # sample output
     # rerank score 0: 0.029
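
The ci prompt no longer hard-codes the model's "</s></s>" separator: llama-embedding now splits each rerank line on the classification separator (default "\t", see the new --cls-separator option below) and inserts the model's own EOS/SEP tokens, so the script stays correct for models with different special tokens.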

common/arg.cpp

@@ -2706,6 +2706,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.embd_sep = value;
         }
     ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    add_opt(common_arg(
+        {"--cls-separator"}, "STRING",
+        "separator of classification sequences (default \\t) for example \"<#seq#>\"",
+        [](common_params & params, const std::string & value) {
+            params.cls_sep = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
     add_opt(common_arg(
         {"--host"}, "HOST",
         string_format("ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: %s)", params.hostname.c_str()),
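
For prompts where a literal tab is awkward, the separator can be overridden, e.g. llama-embedding --pooling rank --cls-separator "<#seq#>" -p "query<#seq#>document" (an illustrative invocation; only the flag and its default come from this change).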

common/common.h

@@ -358,6 +358,7 @@ struct common_params {
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
     std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
     std::string embd_sep = "\n"; // separator of embeddings
+    std::string cls_sep = "\t"; // separator of classification sequences

     // server params

     int32_t port = 8080; // server listens on this network port

convert_hf_to_gguf.py

@@ -2145,7 +2145,6 @@ class Llama4Model(LlamaModel):
     def set_vocab(self):
         self._set_vocab_gpt2()
-        self.gguf_writer.add_add_bos_token(True)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3918,9 +3917,6 @@ class BertModel(TextModel):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)
-

 @ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification")
 class DistilBertModel(BertModel):
@@ -3962,8 +3958,6 @@ class RobertaModel(BertModel):
         bpe_tok_path = self.dir_model / "tokenizer.json"
         if bpe_tok_path.exists():
             self._set_vocab_gpt2()
-            self.gguf_writer.add_add_bos_token(True)
-            self.gguf_writer.add_add_eos_token(True)

             # we need this to validate the size of the token_type embeddings
             # though currently we are passing all zeros to the token_type embeddings
@@ -4848,8 +4842,6 @@ class JinaBertV2Model(BertModel):
             self.gguf_writer.add_token_type_count(2)
         else:
             raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
-        self.gguf_writer.add_add_bos_token(True)
-        self.gguf_writer.add_add_eos_token(True)

 @ModelBase.register("OpenELMForCausalLM")
@@ -5451,9 +5443,6 @@ class T5Model(TextModel):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-

     def set_gguf_parameters(self):
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")
@@ -5591,9 +5580,6 @@ class T5EncoderModel(TextModel):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)
-        self.gguf_writer.add_add_bos_token(False)
-        self.gguf_writer.add_add_eos_token(True)
-

     def set_gguf_parameters(self):
         if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
             logger.warning("Couldn't find context length in config.json, assuming default value of 512")
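
The removed add_add_bos_token/add_add_eos_token calls are intentional: SpecialVocab (see gguf-py/gguf/vocab.py below) now detects these flags from the tokenizer's post-processor and writes them itself in add_to_gguf, so setting them by hand here would clobber the detected values.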

examples/embedding/embedding.cpp

@@ -133,10 +133,36 @@ int main(int argc, char ** argv) {
     // max batch size
     const uint64_t n_batch = params.n_batch;

+    // get added sep and eos token, if any
+    const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
+    const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
+
     // tokenize the prompts and trim
     std::vector<std::vector<int32_t>> inputs;
     for (const auto & prompt : prompts) {
-        auto inp = common_tokenize(ctx, prompt, true, true);
+        std::vector<llama_token> inp;
+
+        // split classification pairs and insert expected separator tokens
+        if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
+            std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
+            std::string final_prompt;
+
+            for (size_t i = 0; i < pairs.size(); i++) {
+                final_prompt += pairs[i];
+                if (i != pairs.size() - 1) {
+                    if (!added_eos_token.empty()) {
+                        final_prompt += added_eos_token;
+                    }
+                    if (!added_sep_token.empty()) {
+                        final_prompt += added_sep_token;
+                    }
+                }
+            }
+
+            inp = common_tokenize(ctx, final_prompt, true, true);
+        } else {
+            inp = common_tokenize(ctx, prompt, true, true);
+        }
+
         if (inp.size() > n_batch) {
             LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
                     __func__, (long long int) inp.size(), (long long int) n_batch);
@@ -145,11 +171,11 @@ int main(int argc, char ** argv) {
         inputs.push_back(inp);
     }

-    // check if the last token is SEP
+    // check if the last token is SEP/EOS
     // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
     for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
-            LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
+        if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
+            LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
             LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
         }
     }
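
As a worked sketch of the pair-splitting loop above, assuming a model whose added EOS and SEP tokens are both "</s>" (true for the rerank model used in ci/run.sh):

    // sketch of the joining logic, with hard-coded stand-ins for the
    // values normally obtained via llama_vocab_get_text()
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
        const std::string added_eos_token = "</s>"; // assumed model EOS
        const std::string added_sep_token = "</s>"; // assumed model SEP
        const std::vector<std::string> pairs = { "what is panda?", "hi" };

        std::string final_prompt;
        for (size_t i = 0; i < pairs.size(); i++) {
            final_prompt += pairs[i];
            if (i != pairs.size() - 1) {
                final_prompt += added_eos_token; // EOS closes the first sequence
                final_prompt += added_sep_token; // SEP introduces the second one
            }
        }

        std::cout << final_prompt << std::endl; // what is panda?</s></s>hi
        return 0;
    }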

gguf-py/gguf/constants.py

@@ -198,6 +198,7 @@ class Keys:
         MASK_ID = "tokenizer.ggml.mask_token_id"
         ADD_BOS = "tokenizer.ggml.add_bos_token"
         ADD_EOS = "tokenizer.ggml.add_eos_token"
+        ADD_SEP = "tokenizer.ggml.add_sep_token"
         ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
         REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
         PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"

gguf-py/gguf/gguf_writer.py

@@ -891,6 +891,9 @@ class GGUFWriter:
     def add_add_eos_token(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_EOS, value)

+    def add_add_sep_token(self, value: bool) -> None:
+        self.add_bool(Keys.Tokenizer.ADD_SEP, value)
+
     def add_add_space_prefix(self, value: bool) -> None:
         self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
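
With the writer hook in place, a converter can record the flag as writer.add_add_sep_token(True); at load time the C++ side reads the same tokenizer.ggml.add_sep_token key, handled in src/llama-vocab.cpp below.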

gguf-py/gguf/vocab.py

@@ -119,6 +119,7 @@ class SpecialVocab:
                 logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

     def _try_load_from_tokenizer_json(self, path: Path) -> bool:
+        tokenizer = None
         tokenizer_file = path / 'tokenizer.json'
         if tokenizer_file.is_file():
             with open(tokenizer_file, encoding = 'utf-8') as f:
@@ -152,11 +153,87 @@ class SpecialVocab:
             added_tokens = tokenizer.get('added_tokens', {})
         else:
             added_tokens = {}
+        tokenizer_config = None
         tokenizer_config_file = path / 'tokenizer_config.json'
-        if not tokenizer_config_file.is_file():
-            return True
-        with open(tokenizer_config_file, encoding = 'utf-8') as f:
-            tokenizer_config = json.load(f)
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, encoding = 'utf-8') as f:
+                tokenizer_config = json.load(f)
+        if tokenizer:
+            special_bos = (tokenizer_config or {}).get('bos_token')
+            special_cls = (tokenizer_config or {}).get('cls_token')
+            special_eos = (tokenizer_config or {}).get('eos_token')
+            special_sep = (tokenizer_config or {}).get('sep_token')
+            if not special_bos and special_cls and tokenizer_config:
+                tokenizer_config['bos_token'] = special_bos = special_cls
+            if not special_eos and special_sep and tokenizer_config:
+                tokenizer_config['eos_token'] = special_eos = special_sep
+            post_processor = tokenizer.get('post_processor', {})
+            for processor in post_processor.get('processors', [post_processor]):
+                if processor.get('type') == 'RobertaProcessing':
+                    self.add_special_token['bos'] = True
+                    self.add_special_token['eos'] = True
+                    self.add_special_token['sep'] = True
+                    if not special_cls and tokenizer_config:
+                        special_cls = processor.get('cls', [special_bos])[0]
+                        tokenizer_config['cls_token'] = special_cls
+                    if not special_sep and tokenizer_config:
+                        special_sep = processor.get('sep', [special_eos])[0]
+                        tokenizer_config['sep_token'] = special_sep
+                    continue
+                # Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
+                # Only works with simple templates, **will** get it wrong on unusual sequences
+                if processor.get('type') == 'TemplateProcessing':
+                    tmpl_single = processor.get('single', [])
+                    tmpl_pair = processor.get('pair', [])
+                    special_first = None
+                    special_last = None
+                    if len(tmpl_single) > 1:
+                        if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
+                            if not tokenizer_config:
+                                special_bos = special_first
+                            self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
+                            if special_first not in (special_bos, special_cls):
+                                logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
+                        if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
+                            if not tokenizer_config:
+                                special_eos = special_last
+                            self.add_special_token['eos'] = True if special_last == special_eos else False
+                            if special_last != special_eos:
+                                logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
+                    if tmpl_pair:
+                        seq_start = 1 if tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
+                        seq_stop = -1 if tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
+                        if seq_start == 0 or seq_stop is None:
+                            logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
+                        if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
+                            tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
+                            tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
+                            if tmpl_a != 'A' or tmpl_b != 'B':
+                                logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
+                            # A [sep] [eos] B
+                            if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
+                                add_sep = False
+                                if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
+                                    if special_entry in (special_sep, special_eos) and not special_last:
+                                        add_sep = True
+                                    if special_entry not in (special_sep, special_eos):
+                                        logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
+                                else:
+                                    logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
+                                if len(tmpl_pair) == 2:
+                                    if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
+                                        if special_entry in (special_sep, special_eos):
+                                            add_sep = True
+                                        if special_entry not in (special_sep, special_eos):
+                                            logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
+                                    else:
+                                        logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
+                                self.add_special_token['sep'] = add_sep
+                                if add_sep and not special_sep and tokenizer_config:
+                                    tokenizer_config['sep_token'] = special_eos
+                    continue
+        if not tokenizer_config:
+            return True
         chat_template_alt = None
         chat_template_file = path / 'chat_template.json'
         if chat_template_file.is_file():
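
As a worked example of the heuristic: a RoBERTa-style pair template <s> $A </s> </s> $B </s> with single template <s> $A </s> comes out as add-BOS and add-EOS true (the leading <s> and trailing </s> of the single template) plus add-SEP true, because the second </s> between the sequences matches the sep/eos tokens. As the comment warns, the parsing is deliberately crude: unrecognized tokens in the middle of a pair template only produce warnings.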

include/llama.h

@@ -1044,6 +1044,7 @@ extern "C" {
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
+    LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);

     LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
     LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
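
A minimal consumer of the new accessor (a sketch; the model path is a placeholder):

    #include "llama.h"

    int main() {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
        if (model == nullptr) {
            return 1;
        }

        const llama_vocab * vocab = llama_model_get_vocab(model);

        // true when the tokenizer expects a SEP token between sequence pairs,
        // mirroring the new tokenizer.ggml.add_sep_token GGUF key
        if (llama_vocab_get_add_sep(vocab)) {
            // insert llama_vocab_sep(vocab) between the two sequences
        }

        llama_model_free(model);
        llama_backend_free();
        return 0;
    }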

src/llama-arch.cpp

@@ -198,6 +198,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" },
     { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" },
     { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" },
+    { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" },
     { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
     { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" },
     { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },

src/llama-arch.h

@@ -194,6 +194,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_MASK_ID,
     LLM_KV_TOKENIZER_ADD_BOS,
     LLM_KV_TOKENIZER_ADD_EOS,
+    LLM_KV_TOKENIZER_ADD_SEP,
     LLM_KV_TOKENIZER_ADD_PREFIX,
     LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
     LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,

src/llama-model-saver.cpp

@@ -228,6 +228,7 @@ void llama_model_saver::add_kv_from_model() {
     // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
     add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
     add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
+    add_kv(LLM_KV_TOKENIZER_ADD_SEP, vocab.get_add_sep());
     add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
     add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
     add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());

src/llama-vocab.cpp

@@ -1269,6 +1269,7 @@ struct llama_vocab::impl {
    bool add_space_prefix = false;
    bool add_bos = false;
    bool add_eos = false;
+   bool add_sep = false;
    bool ignore_merges = false;
    bool clean_spaces = false;  // clean_up_tokenization_spaces
    bool remove_extra_whitespaces = false;
@@ -1421,6 +1422,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
            special_sep_id  = 102;
            special_pad_id  = 0;
            special_mask_id = 103;
+
+           add_sep = true;
        } else if (tokenizer_model == "gpt2") {
            type = LLAMA_VOCAB_TYPE_BPE;
@@ -1550,12 +1553,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                tokenizer_pre == "jina-es" ||
                tokenizer_pre == "jina-de" ||
                tokenizer_pre == "gigachat" ||
-               tokenizer_pre == "jina-v1-en" ||
                tokenizer_pre == "jina-v2-es" ||
-               tokenizer_pre == "jina-v2-de" ||
+               tokenizer_pre == "jina-v2-de") {
+               pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+           } else if (
+               tokenizer_pre == "jina-v1-en" ||
                tokenizer_pre == "jina-v2-code" ||
                tokenizer_pre == "roberta-bpe") {
                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+               add_sep = true;
            } else if (
                tokenizer_pre == "refact") {
                pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
@@ -1665,6 +1671,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
            clean_spaces = true;
            add_bos = true;
            add_eos = false;
+           add_sep = true;
        } else if (type == LLAMA_VOCAB_TYPE_UGM) {
            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
            add_bos = false;
@@ -1801,7 +1808,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
        }
    }

-   // Handle add_bos and add_eos
+   // Handle add_bos, add_eos and add_sep
    {
        bool temp = true;
@@ -1811,6 +1818,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
        if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
            add_eos = temp;
        }
+       if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) {
+           add_sep = temp;
+       }
    }

    // auto-detect special tokens by text
@@ -3000,6 +3010,10 @@ bool llama_vocab::get_add_eos() const {
    return pimpl->add_eos;
 }

+bool llama_vocab::get_add_sep() const {
+   return pimpl->add_sep;
+}
+
 bool llama_vocab::get_ignore_merges() const {
    return pimpl->ignore_merges;
 }
@@ -3191,6 +3205,10 @@ bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
    return vocab->get_add_eos();
 }

+bool llama_vocab_get_add_sep(const struct llama_vocab * vocab) {
+   return vocab->get_add_sep();
+}
+
 llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
    return vocab->token_fim_pre();
 }

src/llama-vocab.h

@@ -74,6 +74,7 @@ struct llama_vocab {
     bool get_add_space_prefix () const;
     bool get_add_bos () const;
     bool get_add_eos () const;
+    bool get_add_sep () const;
     bool get_ignore_merges () const;
     bool get_clean_spaces () const;
     bool get_remove_extra_whitespaces () const;

tools/server/utils.hpp

@@ -271,12 +271,20 @@ static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_
     }

     result.reserve(doc.size() + query.size() + 4);
-    result.push_back(llama_vocab_bos(vocab));
+    if (llama_vocab_get_add_bos(vocab)) {
+        result.push_back(llama_vocab_bos(vocab));
+    }
     result.insert(result.end(), query.begin(), query.end());
-    result.push_back(eos_token);
-    result.push_back(llama_vocab_sep(vocab));
+    if (llama_vocab_get_add_eos(vocab)) {
+        result.push_back(eos_token);
+    }
+    if (llama_vocab_get_add_sep(vocab)) {
+        result.push_back(llama_vocab_sep(vocab));
+    }
     result.insert(result.end(), doc.begin(), doc.end());
-    result.push_back(eos_token);
+    if (llama_vocab_get_add_eos(vocab)) {
+        result.push_back(eos_token);
+    }
     return result;
 }
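
For a vocab where all three flags are true (the rerank model in ci/run.sh is one such case), this reproduces the previous unconditional layout, BOS + query + EOS + SEP + doc + EOS; vocabs that do not add one of these tokens now simply skip it instead of always receiving the full set.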