From f324a3b715d5c1081c110ce459f8a8486fb1ee89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Tue, 5 Aug 2025 20:43:36 +0200 Subject: [PATCH] chat : only remove double bos/eos if added (#15086) * only remove double bos/eos if added * fix tests --- common/chat.cpp | 20 ++++++++++++++++++-- common/chat.h | 2 ++ tests/test-chat-template.cpp | 6 +++--- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index c5a840e80..9ba743d1c 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -126,6 +126,8 @@ std::vector common_chat_msg_diff::compute_diffs(const comm typedef minja::chat_template common_chat_template; struct common_chat_templates { + bool add_bos; + bool add_eos; bool has_explicit_template; // Model had builtin template or template overridde was specified. std::unique_ptr template_default; // always set (defaults to chatml) std::unique_ptr template_tool_use; @@ -143,6 +145,8 @@ struct templates_params { bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); json extra_context; + bool add_bos; + bool add_eos; }; common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { @@ -445,6 +449,8 @@ std::string common_chat_format_single( common_chat_templates_inputs inputs; inputs.use_jinja = use_jinja; + inputs.add_bos = tmpls->add_bos; + inputs.add_eos = tmpls->add_eos; std::string fmt_past_msg; if (!past_msg.empty()) { @@ -469,6 +475,8 @@ std::string common_chat_format_single( std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) { common_chat_templates_inputs inputs; inputs.use_jinja = use_jinja; + inputs.add_bos = tmpls->add_bos; + inputs.add_eos = tmpls->add_eos; auto add_simple_msg = [&](auto role, auto content) { common_chat_msg msg; msg.role = role; @@ -546,6 +554,8 @@ common_chat_templates_ptr common_chat_templates_init( } std::string token_bos = bos_token_override; std::string token_eos = eos_token_override; + bool add_bos = false; + bool add_eos = false; if (model) { const auto * vocab = llama_model_get_vocab(model); const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { @@ -560,9 +570,13 @@ common_chat_templates_ptr common_chat_templates_init( }; token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); + add_bos = llama_vocab_get_add_bos(vocab); + add_eos = llama_vocab_get_add_eos(vocab); } common_chat_templates_ptr tmpls(new common_chat_templates()); tmpls->has_explicit_template = has_explicit_template; + tmpls->add_bos = add_bos; + tmpls->add_eos = add_eos; try { tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); } catch (const std::exception & e) { @@ -748,10 +762,10 @@ static std::string apply( // instead of using `chat_template_options.use_bos_token = false`, since these tokens // may be needed inside the template / between messages too. auto result = tmpl.apply(tmpl_inputs, tmpl_opts); - if (string_starts_with(result, tmpl.bos_token())) { + if (inputs.add_bos && string_starts_with(result, tmpl.bos_token())) { result = result.substr(tmpl.bos_token().size()); } - if (string_ends_with(result, tmpl.eos_token())) { + if (inputs.add_eos && string_ends_with(result, tmpl.eos_token())) { result = result.substr(0, result.size() - tmpl.eos_token().size()); } return result; @@ -1731,6 +1745,8 @@ static common_chat_params common_chat_templates_apply_jinja( params.enable_thinking = inputs.enable_thinking; params.grammar = inputs.grammar; params.now = inputs.now; + params.add_bos = inputs.add_bos; + params.add_eos = inputs.add_eos; params.extra_context = json::object(); for (auto el : inputs.chat_template_kwargs) { diff --git a/common/chat.h b/common/chat.h index ca807c145..512f03a4e 100644 --- a/common/chat.h +++ b/common/chat.h @@ -127,6 +127,8 @@ struct common_chat_templates_inputs { bool enable_thinking = true; std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); std::map chat_template_kwargs; + bool add_bos = false; + bool add_eos = false; }; struct common_chat_params { diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index a0a50f988..321ae7306 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -61,7 +61,7 @@ int main(void) { /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - /* .expected_output_jinja= */ "", + /* .expected_output_jinja= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -85,7 +85,7 @@ int main(void) { /* .name= */ "mlabonne/AlphaMonarch-7B", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", - /* .expected_output_jinja= */ "", + /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", /* .bos_token= */ "", /* .eos_token= */ "", }, @@ -99,7 +99,7 @@ int main(void) { /* .name= */ "OrionStarAI/Orion-14B-Chat", /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", - /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", /* .bos_token= */ "", /* .eos_token= */ "", },