server: inject date_string in llama 3.x template + fix date for firefunction v2 (#12802)

* Inject date_string in llama 3.x + fix date for firefunction v2

https://github.com/ggml-org/llama.cpp/issues/12729

* move/fix detection of functionary v3.1 before llama 3.x, fix & test their non-tool mode

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* generate more tokens in test_completion_with_required_tool_tiny_fast to avoid truncation

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Author: Olivier Chafik
Date: 2025-05-15 02:39:51 +01:00
Committed by: GitHub
Parent: e3a9421b78
Commit: aa48e373f2

5 changed files with 185 additions and 112 deletions

common/chat.cpp

@@ -6,6 +6,15 @@
 #include <optional>
 
+static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) {
+    auto time = std::chrono::system_clock::to_time_t(now);
+    auto local_time = *std::localtime(&time);
+    std::ostringstream ss;
+    ss << std::put_time(&local_time, format.c_str());
+    auto res = ss.str();
+    return res;
+}
+
 typedef minja::chat_template common_chat_template;
 
 struct common_chat_templates {
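
For reference, the two format strings this commit feeds to the new helper use standard strftime directives, which std::put_time shares with Python's strftime, so their output can be previewed with a minimal Python sketch (fixed example time; month names depend on locale):

    import datetime

    now = datetime.datetime(2025, 5, 15, 13, 0, 0)   # fixed example time
    print(now.strftime("%d %b %Y"))                  # "15 May 2025" -> llama 3.x "date_string"
    print(now.strftime("%b %d %Y %H:%M:%S GMT"))     # "May 15 2025 13:00:00 GMT" -> firefunction v2 "datetime"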
@@ -24,6 +33,7 @@ struct templates_params {
     std::string grammar;
     bool add_generation_prompt = true;
     bool extract_reasoning = true;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
@@ -939,78 +949,83 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }
 
-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
-    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-        std::vector<std::string> tool_rules;
-
-        auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
-            if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
-                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
-                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
-                expect_tool_parameters(name, parameters, {"query"});
-            } else if (name == "python" || name == "code_interpreter") {
-                // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
-                expect_tool_parameters(name, parameters, {"code"});
-            } else {
-                return false;
-            }
-
-            std::vector<std::string> kvs;
-            for (const auto & [key, value] : parameters.at("properties").items()) {
-                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
-            }
-
-            tool_rules.push_back(
-                builder.add_rule(
-                    name + "-call",
-                    "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
-            builtin_tools.push_back(name);
-
-            return true;
-        };
-
-        foreach_function(inputs.tools, [&](const json & tool) {
-            const auto & function = tool.at("function");
-            std::string name = function.at("name");
-            auto parameters = function.at("parameters");
-            builder.resolve_refs(parameters);
-
-            // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
-            if (allow_python_tag_builtin_tools) {
-                handle_builtin_tool(name, parameters);
-            }
-            tool_rules.push_back(
-                builder.add_rule(
-                    name + "-call",
-                    "\"{\" space "
-                    "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
-                    " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
-                    " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
-                    "\"}\" space"));
-        });
-        // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
-        data.grammar_triggers.push_back({
-            COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
-            "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
-        });
-        if (!builtin_tools.empty()) {
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
-            data.preserved_tokens.push_back("<|python_tag|>");
-        }
-        // Allow a few empty lines on top of the usual constrained json schema space rule.
-        builder.add_rule("root", string_join(tool_rules, " | "));
-    });
-    data.additional_stops.push_back("<|eom_id|>");
+    if (!inputs.tools.is_null()) {
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+
+            auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
+                if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
+                    expect_tool_parameters(name, parameters, {"query"});
+                } else if (name == "python" || name == "code_interpreter") {
+                    // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py
+                    expect_tool_parameters(name, parameters, {"code"});
+                } else {
+                    return false;
+                }
+
+                std::vector<std::string> kvs;
+                for (const auto & [key, value] : parameters.at("properties").items()) {
+                    kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
+                }
+
+                tool_rules.push_back(
+                    builder.add_rule(
+                        name + "-call",
+                        "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\""));
+                builtin_tools.push_back(name);
+
+                return true;
+            };
+
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                std::string name = function.at("name");
+                auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
+
+                // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime
+                if (allow_python_tag_builtin_tools) {
+                    handle_builtin_tool(name, parameters);
+                }
+                tool_rules.push_back(
+                    builder.add_rule(
+                        name + "-call",
+                        "\"{\" space "
+                        "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? "
+                        " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space "
+                        " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " "
+                        "\"}\" space"));
+            });
+            // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name.
+            data.grammar_triggers.push_back({
+                COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+                "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*",
+            });
+            if (!builtin_tools.empty()) {
+                data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+                data.preserved_tokens.push_back("<|python_tag|>");
+            }
+            // Allow a few empty lines on top of the usual constrained json schema space rule.
+            builder.add_rule("root", string_join(tool_rules, " | "));
+            data.additional_stops.push_back("<|eom_id|>");
+        });
+        data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
+            ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
+            : COMMON_CHAT_FORMAT_LLAMA_3_X;
+    } else {
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    }
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {
+        {"date_string", format_time(inputs.now, "%d %b %Y")},
         {"tools_in_user_message", false},
         {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools},
     });
-    data.format = allow_python_tag_builtin_tools && !builtin_tools.empty()
-        ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS
-        : COMMON_CHAT_FORMAT_LLAMA_3_X;
     return data;
 }
 
 static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) {
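
With date_string injected, the system header rendered by the stock Llama 3.x templates should carry the request's actual date rather than the template's baked-in fallback date. Illustratively (the exact wording comes from the model's own Jinja template, so treat this as an assumption):

    Cutting Knowledge Date: December 2023
    Today Date: 15 May 2025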
@@ -1150,7 +1165,7 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
-        {"datetime", "Jan 29 2025 13:00:00 GMT"},
+        {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
@@ -1285,55 +1300,59 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
 static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
-    json tools = inputs.tools.is_null() ? inputs.tools : json::array();
-    std::string python_code_argument_name;
-    auto has_raw_python = false;
 
-    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
-    data.grammar = build_grammar([&](const common_grammar_builder & builder) {
-        std::vector<std::string> tool_rules;
-        foreach_function(inputs.tools, [&](const json & tool) {
-            const auto & function = tool.at("function");
-            const auto & parameters = function.at("parameters");
-            std::string name = function.at("name");
-            if (name == "python" || name == "ipython") {
-                if (!parameters.contains("type")) {
-                    throw std::runtime_error("Missing type in python tool");
-                }
-                has_raw_python = true;
-                const auto & type = parameters.at("type");
-                if (type == "object") {
-                    auto properties = parameters.at("properties");
-                    for (auto it = properties.begin(); it != properties.end(); ++it) {
-                        if (it.value().at("type") == "string") {
-                            if (!python_code_argument_name.empty()) {
-                                throw std::runtime_error("Multiple string arguments found in python tool");
-                            }
-                            python_code_argument_name = it.key();
-                        }
-                    }
-                    if (python_code_argument_name.empty()) {
-                        throw std::runtime_error("No string argument found in python tool");
-                    }
-                } else if (type != "string") {
-                    throw std::runtime_error("Invalid type in python tool: " + type.dump());
-                }
-            }
-            tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
-        });
-        if (has_raw_python) {
-            tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
-            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
-            data.preserved_tokens.push_back("<|python_tag|>");
-        }
-        auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
-        builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
-        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
-    });
+    if (!inputs.tools.is_null()) {
+        std::string python_code_argument_name;
+        auto has_raw_python = false;
+
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                const auto & parameters = function.at("parameters");
+                std::string name = function.at("name");
+                if (name == "python" || name == "ipython") {
+                    if (!parameters.contains("type")) {
+                        throw std::runtime_error("Missing type in python tool");
+                    }
+                    has_raw_python = true;
+                    const auto & type = parameters.at("type");
+                    if (type == "object") {
+                        auto properties = parameters.at("properties");
+                        for (auto it = properties.begin(); it != properties.end(); ++it) {
+                            if (it.value().at("type") == "string") {
+                                if (!python_code_argument_name.empty()) {
+                                    throw std::runtime_error("Multiple string arguments found in python tool");
+                                }
+                                python_code_argument_name = it.key();
+                            }
+                        }
+                        if (python_code_argument_name.empty()) {
+                            throw std::runtime_error("No string argument found in python tool");
+                        }
+                    } else if (type != "string") {
+                        throw std::runtime_error("Invalid type in python tool: " + type.dump());
+                    }
+                }
+                tool_rules.push_back(builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\" space"));
+            });
+            if (has_raw_python) {
+                tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
+                data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"});
+                data.preserved_tokens.push_back("<|python_tag|>");
+            }
+            auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space";
+            builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+            data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<function="});
+        });
+        data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
+    } else {
+        data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    }
 
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     // TODO: if (has_raw_python)
-    data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1;
     return data;
 }
 
 static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::string & input) {
@@ -1593,6 +1612,7 @@ static common_chat_params common_chat_templates_apply_jinja(
     params.extract_reasoning = inputs.extract_reasoning;
     params.tool_choice = inputs.tool_choice;
     params.grammar = inputs.grammar;
+    params.now = inputs.now;
     if (!inputs.json_schema.empty()) {
         params.json_schema = json::parse(inputs.json_schema);
     }
@@ -1644,21 +1664,21 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_firefunction_v2(tmpl, params);
     }
 
-    // Plain handler (no tools)
-    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
-        return common_chat_params_init_without_tools(tmpl, params);
-    }
-
     // Functionary v3.1 (w/ tools)
     if (src.find("<|start_header_id|>") != std::string::npos
         && src.find("<function=") != std::string::npos) {
         return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
     }
 
-    // Llama 3.1, 3.2, 3.3 (w/ tools)
+    // Llama 3.1, 3.2, 3.3 (also requires date_string so using it even w/o tools)
     if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
         auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-        return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
+        return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
+    }
+
+    // Plain handler (no tools)
+    if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
+        return common_chat_params_init_without_tools(tmpl, params);
     }
 
     // Mistral Nemo (w/ tools)
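
Note the dispatch reordering above: the plain no-tools fallback used to run before template detection, so a Llama 3.x template applied without tools never reached the handler that injects date_string. It now runs after the Functionary v3.1 and Llama 3.x checks, which lets those templates be rendered by their dedicated handlers even when no tools are supplied.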

common/chat.h

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "common.h"
+#include <chrono>
 #include <string>
 #include <vector>
 
@@ -71,6 +72,7 @@ struct common_chat_templates_inputs {
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
     bool extract_reasoning = true;
+    std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
 
 struct common_chat_params {
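
Because now is an ordinary defaulted field, callers that need reproducible prompts (for example the template tests) can pin it to a fixed std::chrono::system_clock::time_point instead of relying on the wall-clock time captured at construction.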

tests/test-chat.cpp

@@ -832,7 +832,9 @@ static void test_template_output_parsers() {
         assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
                       common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
                       common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                      common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
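
The added assertion covers the non-tool mode this commit fixes: a Functionary v3.1 template applied without tools should now fall back to COMMON_CHAT_FORMAT_CONTENT_ONLY instead of the tool-calling format.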

tools/server/tests/unit/test_template.py (new file)

@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+import pytest
+
+# ensure grandparent path is in sys.path
+from pathlib import Path
+import sys
+
+from unit.test_tool_call import TEST_TOOL
+path = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(path))
+
+import datetime
+from utils import *
+
+server: ServerProcess
+
+TIMEOUT_SERVER_START = 15*60
+
+
+@pytest.fixture(autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.tinyllama2()
+    server.model_alias = "tinyllama-2"
+    server.server_port = 8081
+    server.n_slots = 1
+
+
+@pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]])
+@pytest.mark.parametrize("template_name,format", [
+    ("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"),
+    ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"),
+])
+def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
+    global server
+    server.jinja = True
+    server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
+    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+
+    res = server.make_request("POST", "/apply-template", data={
+        "messages": [
+            {"role": "user", "content": "What is today?"},
+        ],
+        "tools": tools,
+    })
+    assert res.status_code == 200
+    prompt = res.body["prompt"]
+    today_str = datetime.date.today().strftime(format)
+    assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})"
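
Outside the pytest harness, the same check can be reproduced against a live server. A minimal sketch, assuming a llama-server built from this commit is running with --jinja and a Llama 3.x chat template on localhost:8080, and that the third-party requests package is installed:

    import datetime
    import requests

    res = requests.post("http://localhost:8080/apply-template", json={
        "messages": [{"role": "user", "content": "What is today?"}],
    })
    res.raise_for_status()
    prompt = res.json()["prompt"]
    # the rendered prompt should now embed today's date, e.g. "15 May 2025"
    assert datetime.date.today().strftime("%d %b %Y") in prompt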

tools/server/tests/unit/test_tool_call.py

@@ -109,7 +109,7 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
 ])
 def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
     global server
-    n_predict = 512
+    n_predict = 1024
     # server = ServerPreset.stories15m_moe()
     server.jinja = True
     server.n_predict = n_predict