diff --git a/common/chat.cpp b/common/chat.cpp index 0c777d7a7..c5a840e80 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1646,7 +1646,7 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { "|" // match 5 (function name again) ); - if (auto res = builder.try_find_regex(open_regex)) { + while (auto res = builder.try_find_regex(open_regex)) { const auto & block_start = res->groups[1]; std::string block_end = block_start.empty() ? "" : "```"; @@ -1668,7 +1668,6 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { builder.consume_literal(block_end); builder.consume_spaces(); } - builder.add_content(builder.consume_rest()); } else { throw common_chat_msg_partial_exception("failed to parse tool call"); } @@ -1693,11 +1692,10 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { builder.consume_spaces(); } } - builder.add_content(builder.consume_rest()); } - } else { - builder.add_content(builder.consume_rest()); } + + builder.add_content(builder.consume_rest()); } static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 6ebf1464d..73c98bfa2 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -953,6 +953,33 @@ static void test_template_output_parsers() { /* is_partial= */ false, {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + // Test multiple tool calls + common_chat_msg message_assist_multiple_calls; + message_assist_multiple_calls.role = "assistant"; + message_assist_multiple_calls.content = ""; + message_assist_multiple_calls.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""}); + message_assist_multiple_calls.tool_calls.push_back({"python", "{\"code\":\"print('hello')\"}", ""}); + + assert_msg_equals( + message_assist_multiple_calls, + common_chat_parse( + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "\n" + "\n" + "{\"name\": \"python\", \"arguments\": {\"code\":\"print('hello')\"}}\n" + "", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + + assert_msg_equals( + message_assist_multiple_calls, + common_chat_parse( + "{\"arg1\": 1}\n" + "{\"code\":\"print('hello')\"}", + /* is_partial= */ false, + {COMMON_CHAT_FORMAT_HERMES_2_PRO})); + assert_msg_equals( simple_assist_msg( "This is not a tool call:", @@ -1039,6 +1066,22 @@ static void test_template_output_parsers() { "\n" "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" ""); + + // Test multiple tool calls with template + common_chat_msg message_assist_multiple_calls_template; + message_assist_multiple_calls_template.role = "assistant"; + message_assist_multiple_calls_template.content = ""; + message_assist_multiple_calls_template.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""}); + message_assist_multiple_calls_template.tool_calls.push_back({"python", "{\"code\":\"print('test')\"}", ""}); + + test_templates(tmpls.get(), end_tokens, message_assist_multiple_calls_template, tools, + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "\n" + "\n" + "{\"name\": \"python\", \"arguments\": {\"code\":\"print('test')\"}}\n" + ""); + test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools, "\n" "{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n"