mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
server
: inject date_string in llama 3.x template + fix date for firefunction v2 (#12802)
* Inject date_string in llama 3.x + fix for functionary v2 https://github.com/ggml-org/llama.cpp/issues/12729 * move/fix detection of functionary v3.1 before llama 3.x, fix & test their non-tool mode Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * generate more tokens in test_completion_with_required_tool_tiny_fast to avoid truncation --------- Co-authored-by: ochafik <ochafik@google.com> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
@ -6,6 +6,15 @@
|
||||
|
||||
#include <optional>
|
||||
|
||||
static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
|
||||
auto time = std::chrono::system_clock::to_time_t(now);
|
||||
auto local_time = *std::localtime(&time);
|
||||
std::ostringstream ss;
|
||||
ss << std::put_time(&local_time, format.c_str());
|
||||
auto res = ss.str();
|
||||
return res;
|
||||
}
|
||||
|
||||
typedef minja::chat_template common_chat_template;
|
||||
|
||||
struct common_chat_templates {
|
||||
@ -24,6 +33,7 @@ struct templates_params {
|
||||
std::string grammar;
|
||||
bool add_generation_prompt = true;
|
||||
bool extract_reasoning = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
|
||||
@ -939,9 +949,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
|
||||
auto builtin_tools = json::array();
|
||||
common_chat_params data;
|
||||
if (!inputs.tools.is_null()) {
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
@ -1002,15 +1013,19 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
|
||||
}
|
||||
// Allow a few empty lines on top of the usual constrained json schema space rule.
|
||||
builder.add_rule("root", string_join(tool_rules, " | "));
|
||||
});
|
||||
data.additional_stops.push_back("<|eom_id|>");
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
|
||||
? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
|
||||
: COMMON_CHAT_FORMAT_LLAMA_3_X;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
|
||||
{"date_string", format_time(inputs.now, "%d %b %Y")},
|
||||
{"tools_in_user_message", false},
|
||||
{"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
|
||||
});
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
|
||||
@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
|
||||
LOG_DBG("%s\n", __func__);
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
|
||||
{"datetime", "Jan 29 2025 13:00:00 GMT"},
|
||||
{"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
|
||||
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
|
||||
});
|
||||
if (inputs.tools.is_array() && !inputs.tools.empty()) {
|
||||
@ -1285,7 +1300,8 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
|
||||
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||
common_chat_params data;
|
||||
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
|
||||
|
||||
if (!inputs.tools.is_null()) {
|
||||
std::string python_code_argument_name;
|
||||
auto has_raw_python = false;
|
||||
|
||||
@ -1330,10 +1346,13 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
|
||||
builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
|
||||
});
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
} else {
|
||||
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
||||
}
|
||||
|
||||
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
|
||||
// TODO: if (has_raw_python)
|
||||
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
|
||||
return data;
|
||||
}
|
||||
static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
|
||||
@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
params.extract_reasoning = inputs.extract_reasoning;
|
||||
params.tool_choice = inputs.tool_choice;
|
||||
params.grammar = inputs.grammar;
|
||||
params.now = inputs.now;
|
||||
if (!inputs.json_schema.empty()) {
|
||||
params.json_schema = json::parse(inputs.json_schema);
|
||||
}
|
||||
@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_firefunction_v2(tmpl, params);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Functionary v3.1 (w/ tools)
|
||||
if (src.find("<|start_header_id|>") != std::string::npos
|
||||
&& src.find("<function=") != std::string::npos) {
|
||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
||||
}
|
||||
|
||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
||||
// Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
|
||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
||||
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
|
||||
}
|
||||
|
||||
// Plain handler (no tools)
|
||||
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||
return common_chat_params_init_without_tools(tmpl, params);
|
||||
}
|
||||
|
||||
// Mistral Nemo (w/ tools)
|
||||
|
@ -3,6 +3,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <chrono>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
|
||||
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
||||
bool parallel_tool_calls = false;
|
||||
bool extract_reasoning = true;
|
||||
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
||||
};
|
||||
|
||||
struct common_chat_params {
|
||||
|
@ -833,6 +833,8 @@ static void test_template_output_parsers() {
|
||||
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
|
||||
common_chat_templates_apply(tmpls.get(), inputs_tools).format);
|
||||
assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
|
||||
common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
|
||||
|
||||
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
|
||||
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
|
||||
|
49
tools/server/tests/unit/test_template.py
Normal file
49
tools/server/tests/unit/test_template.py
Normal file
@ -0,0 +1,49 @@
|
||||
#!/usr/bin/env python
|
||||
import pytest
|
||||
|
||||
# ensure grandparent path is in sys.path
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
from unit.test_tool_call import TEST_TOOL
|
||||
path = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(path))
|
||||
|
||||
import datetime
|
||||
from utils import *
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
TIMEOUT_SERVER_START = 15*60
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinyllama2()
|
||||
server.model_alias = "tinyllama-2"
|
||||
server.server_port = 8081
|
||||
server.n_slots = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
|
||||
@pytest.mark.parametrize("template_name,format", [
|
||||
("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
|
||||
("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
|
||||
])
|
||||
def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
|
||||
global server
|
||||
server.jinja = True
|
||||
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
|
||||
server.start(timeout_seconds=TIMEOUT_SERVER_START)
|
||||
|
||||
res = server.make_request("POST", "/apply-template", data={
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is today?"},
|
||||
],
|
||||
"tools": tools,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
prompt = res.body["prompt"]
|
||||
|
||||
today_str = datetime.date.today().strftime(format)
|
||||
assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
|
@ -109,7 +109,7 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
||||
])
|
||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
||||
global server
|
||||
n_predict = 512
|
||||
n_predict = 1024
|
||||
# server = ServerPreset.stories15m_moe()
|
||||
server.jinja = True
|
||||
server.n_predict = n_predict
|
||||
|
Reference in New Issue
Block a user