Add --jinja and --chat-template-file flags

This commit is contained in:
ochafik
2024-12-30 03:40:34 +00:00
parent abd274a48f
commit e5113e8d74
12 changed files with 289 additions and 50 deletions

View File

@@ -4,22 +4,24 @@ from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja",
[
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False),
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,
@@ -102,6 +104,7 @@ def test_chat_completion_with_openai_library():
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_schema", "json_schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
({"type": "json_object"}, 10, "(\\{|John)+"),
({"type": "sound"}, 0, None),

View File

@@ -68,8 +68,9 @@ class ServerProcess:
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
response_format: str | None = None
lora_files: List[str] | None = None
chat_template_file: str | None = None
jinja: bool | None = None
disable_ctx_shift: int | None = False
draft_min: int | None = None
draft_max: int | None = None
@@ -154,6 +155,10 @@ class ServerProcess:
if self.lora_files:
for lora_file in self.lora_files:
server_args.extend(["--lora", lora_file])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
if self.jinja:
server_args.append("--jinja")
if self.disable_ctx_shift:
server_args.extend(["--no-context-shift"])
if self.api_key: