mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-19 00:57:41 +00:00
server : fix assistant prefilling when content is an array (#14360)
This commit is contained in:
@ -132,6 +132,28 @@ def test_chat_template():
|
|||||||
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("prefill,re_prefill", [
    # assistant prefill given as a plain string
    ("Whill", "Whill"),
    # assistant prefill given as an array of content parts:
    # the parts are expected to be concatenated into the prompt
    ([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
])
def test_chat_template_assistant_prefill(prefill, re_prefill):
    """Check that a trailing assistant message is appended verbatim to the
    rendered llama3 prompt (assistant prefilling), for both string content
    and array-of-parts content."""
    global server
    server.chat_template = "llama3"
    server.debug = True  # to get the "__verbose" object in the response
    server.start()

    request_body = {
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
            # last message is from the assistant -> treated as a prefill,
            # not as a completed turn
            {"role": "assistant", "content": prefill},
        ]
    }
    res = server.make_request("POST", "/chat/completions", data=request_body)

    assert res.status_code == 200
    assert "__verbose" in res.body
    expected_prompt = (
        "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
    )
    assert res.body["__verbose"]["prompt"] == expected_prompt
|
||||||
|
|
||||||
|
|
||||||
def test_apply_chat_template():
|
def test_apply_chat_template():
|
||||||
global server
|
global server
|
||||||
server.chat_template = "command-r"
|
server.chat_template = "command-r"
|
||||||
@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re
|
|||||||
[{"role": "system", "content": 123}],
|
[{"role": "system", "content": 123}],
|
||||||
# [{"content": "hello"}], # TODO: should not be a valid case
|
# [{"content": "hello"}], # TODO: should not be a valid case
|
||||||
[{"role": "system", "content": "test"}, {}],
|
[{"role": "system", "content": "test"}, {}],
|
||||||
|
[{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
|
||||||
])
|
])
|
||||||
def test_invalid_chat_completion_req(messages):
|
def test_invalid_chat_completion_req(messages):
|
||||||
global server
|
global server
|
||||||
|
@ -792,7 +792,13 @@ static json oaicompat_chat_params_parse(
|
|||||||
|
|
||||||
/* Append assistant prefilled message */
|
/* Append assistant prefilled message */
|
||||||
if (prefill_assistant_message) {
|
if (prefill_assistant_message) {
|
||||||
chat_params.prompt += last_message.content;
|
if (!last_message.content_parts.empty()) {
|
||||||
|
for (auto & p : last_message.content_parts) {
|
||||||
|
chat_params.prompt += p.text;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
chat_params.prompt += last_message.content;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
llama_params["chat_format"] = static_cast<int>(chat_params.format);
|
||||||
|
Reference in New Issue
Block a user