diff --git a/tools/server/tests/unit/test_template.py b/tools/server/tests/unit/test_template.py index cf9f96a7f..7bb857b33 100644 --- a/tools/server/tests/unit/test_template.py +++ b/tools/server/tests/unit/test_template.py @@ -47,3 +47,28 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]): today_str = datetime.date.today().strftime(format) assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})" + + +@pytest.mark.parametrize("add_generation_prompt", [False, True]) +@pytest.mark.parametrize("template_name,expected_generation_prompt", [ + ("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"), +]) +def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool): + global server + server.jinja = True + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + + res = server.make_request("POST", "/apply-template", data={ + "messages": [ + {"role": "user", "content": "What is today?"}, + ], + "add_generation_prompt": add_generation_prompt, + }) + assert res.status_code == 200 + prompt = res.body["prompt"] + + if add_generation_prompt: + assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})" + else: + assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})" diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 91efcfef0..ee33f76c2 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -731,6 +731,7 @@ static json oaicompat_chat_params_parse( inputs.grammar = grammar; inputs.use_jinja = opt.use_jinja; inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); inputs.reasoning_format = opt.reasoning_format; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { throw std::runtime_error("Cannot use custom grammar constraints with tools.");