From 03f582ae8fccecff225c30a2802461b44761e822 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 26 May 2025 08:03:57 -0700 Subject: [PATCH] server: fix streaming crashes (#13786) * add preludes to content on partial regex match * allow all parsers to parse non-tool-call content. * tweak order of <|python_tag|> vs common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) { +std::optional common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) { auto m = regex.search(input_, from == std::string::npos ? pos_ : from); if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { return std::nullopt; } + auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); + pos_ = m.groups[0].end; + + if (add_prelude_to_content) { + add_content(prelude); + } if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { if (is_partial()) { throw common_chat_msg_partial_exception(regex.str()); } return std::nullopt; } - auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); - pos_ = m.groups[0].end; - return find_regex_result{prelude, m.groups}; } diff --git a/common/chat-parser.h b/common/chat-parser.h index b21b32b8a..5d53f2df1 100644 --- a/common/chat-parser.h +++ b/common/chat-parser.h @@ -30,6 +30,7 @@ class common_chat_msg_parser { const std::string & healing_marker() const { return healing_marker_; } const bool & is_partial() const { return is_partial_; } const common_chat_msg & result() const { return result_; } + const common_chat_syntax & syntax() const { return syntax_; } void move_to(size_t pos) { if (pos > input_.size()) { @@ -77,7 +78,7 @@ class common_chat_msg_parser { std::vector groups; }; - std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos); + std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); bool try_consume_literal(const std::string & literal); diff --git a/common/chat.cpp b/common/chat.cpp index 2e6a964bb..7584639b0 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -656,7 +656,6 @@ static void parse_json_tool_calls( } from = std::string::npos; - builder.add_content(res->prelude); auto maybe_raw_python = name == "python" && allow_raw_python; if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { @@ -686,7 +685,6 @@ static void parse_json_tool_calls( }; if (block_open) { if (auto res = builder.try_find_regex(*block_open)) { - builder.add_content(res->prelude); parse_tool_calls(); } else { builder.add_content(builder.consume_rest()); @@ -699,7 +697,6 @@ static void parse_json_tool_calls( static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { static const std::vector> args_paths = {{"arguments"}}; if (auto res = builder.try_find_regex(prefix)) { - builder.add_content(res->prelude); builder.move_back(rstrip_prefix); auto tool_calls = builder.consume_json_with_dumped_args(args_paths); if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { @@ -835,6 +832,10 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp return data; } static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } static const std::vector> content_paths = { {"response"}, }; @@ -907,6 +908,11 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat return data; } static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); parse_prefixed_json_tool_call_array(builder, prefix); } @@ -1001,7 +1007,6 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { if (auto res = builder.try_find_regex(start_action_regex)) { // If we didn't extract thoughts, prelude includes them. - builder.add_content(res->prelude); auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); for (const auto & tool_call : tool_calls.value) { std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; @@ -1016,11 +1021,7 @@ static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { } builder.consume_regex(end_action_regex); } else if (auto res = builder.try_find_regex(start_response_regex)) { - // If we didn't extract thoughts, prelude includes them. - builder.add_content(res->prelude); - if (auto res = builder.try_find_regex(end_response_regex)) { - builder.add_content(res->prelude); - } else { + if (!builder.try_find_regex(end_response_regex)) { builder.add_content(builder.consume_rest()); throw common_chat_msg_partial_exception(end_response_regex.str()); } @@ -1128,6 +1129,11 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te return data; } static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex function_regex( "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); static const common_regex close_regex("\\}\\s*"); @@ -1138,8 +1144,6 @@ static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool w if (with_builtin_tools) { static const common_regex builtin_call_regex("<\\|python_tag\\|>"); if (auto res = builder.try_find_regex(builtin_call_regex)) { - builder.add_content(res->prelude); - auto fun_res = builder.consume_regex(function_name_regex); auto function_name = builder.str(fun_res.groups[1]); @@ -1255,6 +1259,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ } static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); static const common_regex tool_calls_end("<|tool▁calls▁end|>"); @@ -1316,6 +1324,10 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c return data; } static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } static const common_regex prefix(regex_escape(" functools[")); parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); } @@ -1457,15 +1469,12 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con return data; } static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. - static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - builder.add_content(res->prelude); - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); return; } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); static const common_regex function_regex(R"()"); static const common_regex close_regex(R"()"); @@ -1477,6 +1486,12 @@ static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser function_regex, close_regex, std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; + } } static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1595,6 +1610,10 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat } static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } static const common_regex open_regex( "(?:" @@ -1616,8 +1635,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { ); if (auto res = builder.try_find_regex(open_regex)) { - builder.add_content(res->prelude); - const auto & block_start = res->groups[1]; std::string block_end = block_start.empty() ? "" : "```"; @@ -1853,10 +1870,10 @@ static void common_chat_parse_content_only(common_chat_msg_parser & builder) { builder.add_content(builder.consume_rest()); } -static void common_chat_parse(common_chat_msg_parser & builder, common_chat_format format) { - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(format), builder.input().c_str()); +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - switch (format) { + switch (builder.syntax().format) { case COMMON_CHAT_FORMAT_CONTENT_ONLY: common_chat_parse_content_only(builder); break; @@ -1891,7 +1908,7 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form common_chat_parse_command_r7b(builder); break; default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(format)); + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } builder.finish(); } @@ -1899,7 +1916,7 @@ static void common_chat_parse(common_chat_msg_parser & builder, common_chat_form common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { common_chat_msg_parser builder(input, is_partial, syntax); try { - common_chat_parse(builder, syntax.format); + common_chat_parse(builder); } catch (const common_chat_msg_partial_exception & ex) { LOG_DBG("Partial parse: %s\n", ex.what()); if (!is_partial) { diff --git a/common/chat.h b/common/chat.h index 3e2cbbaae..f6b1d0ffc 100644 --- a/common/chat.h +++ b/common/chat.h @@ -144,6 +144,7 @@ struct common_chat_syntax { // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) bool reasoning_in_content = false; bool thinking_forced_open = false; + bool parse_tool_calls = true; }; // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 5f542f022..95d516991 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -401,9 +401,12 @@ static common_chat_msg simple_assist_msg(const std::string & content, const std: } return msg; } -const common_chat_msg message_assist = simple_assist_msg("Hello, world!\nWhat's up?"); -const common_chat_msg message_assist_empty = simple_assist_msg(""); -const common_chat_msg message_assist_thoughts_unparsed_deepseek = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); +const common_chat_msg message_assist = simple_assist_msg("Hello, world!\nWhat's up?"); +const common_chat_msg message_assist_empty = simple_assist_msg(""); +const common_chat_msg message_assist_thoughts_unparsed_deepseek = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); +const common_chat_msg message_assist_thoughts_unparsed_md = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```"); +const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}"); + const common_chat_msg message_assist_thoughts_unparsed_r7b = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?"); const common_chat_msg message_assist_thoughts = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"); const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); @@ -591,8 +594,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts_unparsed_deepseek, common_chat_parse( @@ -619,8 +620,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts_call_idx, common_chat_parse( @@ -632,8 +631,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts_no_content, common_chat_parse( @@ -644,8 +641,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools, @@ -675,6 +670,18 @@ static void test_template_output_parsers() { // Generic tool calls doesn't generate / parse content-only messages symmetrically. + assert_equals( + simple_assist_msg("{ \"tool_call\" : { \"name\" : \"t"), + common_chat_parse( + "{ \"tool_call\" : { \"name\" : \"t", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_GENERIC, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ false, + /* .thinking_forced_open = */ true, + /* .parse_tool_calls = */ false, + })); assert_equals( message_assist_empty, common_chat_parse( @@ -776,11 +783,9 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals( - simple_assist_msg(""), + simple_assist_msg("Let's call something\n"), common_chat_parse( "Let's call something\n" "{\"name", @@ -788,8 +793,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_call_thoughts, common_chat_parse( @@ -979,7 +982,34 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, + })); + assert_msg_equals(message_assist_thoughts, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + })); + assert_msg_equals(message_assist_thoughts_unparsed_md, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```", + /* is_partial= */ false, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ true, + /* .thinking_forced_open = */ false, + /* .parse_tool_calls = */ false, + })); + assert_msg_equals(message_assist_thoughts_unparsed_md_partial, + common_chat_parse( + "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```", + /* is_partial= */ true, + { + /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, + /* .reasoning_in_content = */ true, /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts_unopened_unparsed, @@ -989,8 +1019,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts, common_chat_parse( @@ -1073,6 +1101,13 @@ static void test_template_output_parsers() { {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1})); } + assert_equals( + message_assist_call, + common_chat_parse( + "{\"arg1\": 1}<", + /* is_partial= */ true, + {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1})); + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "{\"arg1\": 1}"); @@ -1187,8 +1222,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts_unopened_unparsed, common_chat_parse( @@ -1197,8 +1230,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts, common_chat_parse( @@ -1252,8 +1283,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); assert_msg_equals(message_assist_thoughts, common_chat_parse( @@ -1295,8 +1324,6 @@ static void test_template_output_parsers() { { /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, })); test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" diff --git a/tools/server/server.cpp b/tools/server/server.cpp index fcab1dfaa..fe6c685ec 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -364,6 +364,7 @@ struct server_task { params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format; params.oaicompat_chat_syntax.reasoning_in_content = params.stream; params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); + params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); } { diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index fc9f7071e..8456a02e6 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -735,8 +735,11 @@ static json oaicompat_chat_params_parse( inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); inputs.reasoning_format = opt.reasoning_format; inputs.enable_thinking = opt.enable_thinking; - if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { - throw std::runtime_error("Cannot use custom grammar constraints with tools."); + if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { + if (body.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + llama_params["parse_tool_calls"] = true; } // if the assistant message appears at the end of list, we do not add end-of-turn token