tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900)

* tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init return a unique_ptr to opaque type

* addressed clang-tidy lints in [test-]chat.*

* rm minja deps from util & common & move it to common/minja/

* add name & tool_call_id to common_chat_msg

* add common_chat_tool

* added json <-> tools, msgs conversions to chat.h

* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)

* fix deepseek r1 slow test (no longer <think> opening w/ new template)

* allow empty tools w/ auto + grammar

* fix & test server grammar & json_schema params w/ & w/o --jinja
This commit is contained in:
Olivier Chafik
2025-02-18 18:03:23 +00:00
committed by GitHub
parent 63ac128563
commit 63e489c025
18 changed files with 1385 additions and 993 deletions

View File

@ -1364,7 +1364,7 @@ llama-server: \
examples/server/index.html.hpp \ examples/server/index.html.hpp \
examples/server/loading.html.hpp \ examples/server/loading.html.hpp \
common/chat.cpp \ common/chat.cpp \
common/chat.hpp \ common/chat.h \
common/chat-template.hpp \ common/chat-template.hpp \
common/json.hpp \ common/json.hpp \
common/minja.hpp \ common/minja.hpp \

View File

@ -57,8 +57,7 @@ add_library(${TARGET} STATIC
arg.h arg.h
base64.hpp base64.hpp
chat.cpp chat.cpp
chat.hpp chat.h
chat-template.hpp
common.cpp common.cpp
common.h common.h
console.cpp console.cpp
@ -68,7 +67,8 @@ add_library(${TARGET} STATIC
llguidance.cpp llguidance.cpp
log.cpp log.cpp
log.h log.h
minja.hpp minja/chat-template.hpp
minja/minja.hpp
ngram-cache.cpp ngram-cache.cpp
ngram-cache.h ngram-cache.h
sampling.cpp sampling.cpp

View File

@ -2,6 +2,7 @@
#include "log.h" #include "log.h"
#include "sampling.h" #include "sampling.h"
#include "chat.h"
#include <algorithm> #include <algorithm>
#include <climits> #include <climits>

View File

@ -1,8 +1,433 @@
#include "chat.hpp" #include "chat.h"
#include "chat-template.hpp"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "log.h" #include "log.h"
#include "minja.hpp" #include "minja/chat-template.hpp"
#include "minja/minja.hpp"
#include <optional>
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
struct templates_params {
json messages;
json tools;
common_chat_tool_choice tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
if (tool_choice == "auto") {
return COMMON_CHAT_TOOL_CHOICE_AUTO;
}
if (tool_choice == "none") {
return COMMON_CHAT_TOOL_CHOICE_NONE;
}
if (tool_choice == "required") {
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
}
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
std::vector<common_chat_msg> msgs;
try {
if (!messages.is_array()) {
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
}
for (const auto & message : messages) {
if (!message.is_object()) {
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
}
common_chat_msg msg;
if (!message.contains("role")) {
throw std::runtime_error("Missing 'role' in message: " + message.dump());
}
msg.role = message.at("role");
if (message.contains("content")) {
const auto & content = message.at("content");
if (content.is_string()) {
msg.content = content;
} else if (content.is_array()) {
for (const auto & part : content) {
if (!part.contains("type")) {
throw std::runtime_error("Missing content part type: " + part.dump());
}
const auto & type = part.at("type");
if (type != "text") {
throw std::runtime_error("Unsupported content part type: " + type.dump());
}
common_chat_msg_content_part msg_part;
msg_part.type = type;
msg_part.text = part.at("text");
msg.content_parts.push_back(msg_part);
}
} else if (!content.is_null()) {
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
if (message.contains("reasoning_content")) {
msg.reasoning_content = message.at("reasoning_content");
}
if (message.contains("name")) {
msg.tool_name = message.at("name");
}
if (message.contains("tool_call_id")) {
msg.tool_call_id = message.at("tool_call_id");
}
if (message.contains("tool_calls")) {
for (const auto & tool_call : message.at("tool_calls")) {
common_chat_tool_call tc;
if (!tool_call.contains("type")) {
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
}
const auto & type = tool_call.at("type");
if (type != "function") {
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
}
if (!tool_call.contains("function")) {
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
}
const auto & fc = tool_call.at("function");
if (!fc.contains("name")) {
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
}
tc.name = fc.at("name");
tc.arguments = fc.at("arguments");
if (tool_call.contains("id")) {
tc.id = tool_call.at("id");
}
msg.tool_calls.push_back(tc);
}
}
msgs.push_back(msg);
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
}
return msgs;
}
template <>
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
json messages = json::array();
for (const auto & msg : msgs) {
if (!msg.content.empty() && !msg.content_parts.empty()) {
throw std::runtime_error("Cannot specify both content and content_parts");
}
json jmsg {
{"role", msg.role},
};
if (!msg.content.empty()) {
jmsg["content"] = msg.content;
} else if (!msg.content_parts.empty()) {
if (concat_typed_text) {
std::string text;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
continue;
}
if (!text.empty()) {
text += '\n';
}
text += part.text;
}
jmsg["content"] = text;
} else {
auto & parts = jmsg["content"] = json::array();
for (const auto & part : msg.content_parts) {
parts.push_back({
{"type", part.type},
{"text", part.text},
});
}
}
} else {
jmsg["content"] = json(); // null
}
if (!msg.reasoning_content.empty()) {
jmsg["reasoning_content"] = msg.reasoning_content;
}
if (!msg.tool_name.empty()) {
jmsg["name"] = msg.tool_name;
}
if (!msg.tool_call_id.empty()) {
jmsg["tool_call_id"] = msg.tool_call_id;
}
if (!msg.tool_calls.empty()) {
auto & tool_calls = jmsg["tool_calls"] = json::array();
for (const auto & tool_call : msg.tool_calls) {
json tc {
{"type", "function"},
{"function", {
{"name", tool_call.name},
{"arguments", tool_call.arguments},
}},
};
if (!tool_call.id.empty()) {
tc["id"] = tool_call.id;
}
tool_calls.push_back(tc);
}
}
messages.push_back(jmsg);
}
return messages;
}
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
return common_chat_msgs_parse_oaicompat(json::parse(messages));
}
template <>
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
std::vector<common_chat_tool> result;
try {
if (!tools.is_null()) {
if (!tools.is_array()) {
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
}
for (const auto & tool : tools) {
if (!tool.contains("type")) {
throw std::runtime_error("Missing tool type: " + tool.dump());
}
const auto & type = tool.at("type");
if (!type.is_string() || type != "function") {
throw std::runtime_error("Unsupported tool type: " + tool.dump());
}
if (!tool.contains("function")) {
throw std::runtime_error("Missing tool function: " + tool.dump());
}
const auto & function = tool.at("function");
result.push_back({
/* .name = */ function.at("name"),
/* .description = */ function.at("description"),
/* .parameters = */ function.at("parameters").dump(),
});
}
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
}
return result;
}
template <>
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
return common_chat_tools_parse_oaicompat(json::parse(tools));
}
template <>
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
if (tools.empty()) {
return json();
}
auto result = json::array();
for (const auto & tool : tools) {
result.push_back({
{"type", "function"},
{"function", {
{"name", tool.name},
{"description", tool.description},
{"parameters", json::parse(tool.parameters)},
}},
});
}
return result;
}
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
common_chat_msg msg;
msg.role = "user";
msg.content = "test";
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
common_chat_templates_inputs inputs;
inputs.messages = {msg};
common_chat_templates_apply(tmpls.get(), inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
std::string fmt_past_msg;
if (!past_msg.empty()) {
inputs.messages = past_msg;
inputs.add_generation_prompt = false;
fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
}
std::ostringstream ss;
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
inputs.messages.push_back(new_msg);
inputs.add_generation_prompt = add_ass;
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
auto add_simple_msg = [&](auto role, auto content) {
common_chat_msg msg;
msg.role = role;
msg.content = content;
inputs.messages.push_back(msg);
};
add_simple_msg("system", "You are a helpful assistant");
add_simple_msg("user", "Hello");
add_simple_msg("assistant", "Hi there");
add_simple_msg("user", "How are you?");
return common_chat_templates_apply(tmpls, inputs).prompt;
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
void common_chat_templates_free(struct common_chat_templates * tmpls) {
delete tmpls;
}
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
return tmpls->has_explicit_template;
}
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
if (variant != nullptr) {
if (strcmp(variant, "tool_use") == 0) {
if (tmpls->template_tool_use) {
return tmpls->template_tool_use->source().c_str();
}
return nullptr;
} else {
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
}
}
return tmpls->template_default->source().c_str();
}
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override,
const std::string & eos_token_override)
{
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
GGML_ASSERT(model != nullptr);
const auto * str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = CHATML_TEMPLATE_SRC;
}
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
if (model) {
const auto * vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
}
return std::string();
}
return common_token_to_piece(vocab, token, true);
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
}
common_chat_templates_ptr tmpls(new common_chat_templates());
tmpls->has_explicit_template = has_explicit_template;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
}
if (!template_tool_use_src.empty()) {
try {
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
}
}
return tmpls;
}
std::string common_chat_format_name(common_chat_format format) { std::string common_chat_format_name(common_chat_format format) {
switch (format) { switch (format) {
@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
json_error_locator() : position(0), found_error(false) {} json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string &, const json::exception &) override { bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
this->position = position - 1; this->position = position - 1;
this->found_error = true; this->found_error = true;
return false; return false;
} }
bool null() override { return true; } bool null() override { return true; } // NOLINT
bool boolean(bool) override { return true; } bool boolean(bool) override { return true; } // NOLINT
bool number_integer(number_integer_t) override { return true; } bool number_integer(number_integer_t) override { return true; } // NOLINT
bool number_unsigned(number_unsigned_t) override { return true; } bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
bool number_float(number_float_t, const string_t &) override { return true; } bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
bool string(string_t &) override { return true; } bool string(string_t &) override { return true; } // NOLINT
bool binary(binary_t &) override { return true; } bool binary(binary_t &) override { return true; } // NOLINT
bool start_object(std::size_t) override { return true; } bool start_object(std::size_t) override { return true; } // NOLINT
bool key(string_t &) override { return true; } bool key(string_t &) override { return true; } // NOLINT
bool end_object() override { return true; } bool end_object() override { return true; }
bool start_array(std::size_t) override { return true; } bool start_array(std::size_t) override { return true; } // NOLINT
bool end_array() override { return true; } bool end_array() override { return true; }
}; };
json_error_locator err_loc; json_error_locator err_loc;
@ -187,13 +612,20 @@ static std::string apply(
// tmpl_inputs.now = std::chrono::system_clock::now(); // tmpl_inputs.now = std::chrono::system_clock::now();
minja::chat_template_options tmpl_opts; minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false; // To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
tmpl_opts.use_eos_token = false; // instead of using `chat_template_options.use_bos_token = false`, since these tokens
// may be needed inside the template / between messages too.
return tmpl.apply(tmpl_inputs, tmpl_opts); auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
if (string_starts_with(result, tmpl.bos_token())) {
result = result.substr(tmpl.bos_token().size());
}
if (string_ends_with(result, tmpl.eos_token())) {
result = result.substr(0, result.size() - tmpl.eos_token().size());
}
return result;
} }
static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data; common_chat_params data;
auto tool_call_schemas = json::array(); auto tool_call_schemas = json::array();
@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
{"required", json::array({"tool_call"})}, {"required", json::array({"tool_call"})},
}; };
const auto schema = const auto schema =
inputs.tool_choice != "required" inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
? json { ? json {
{"anyOf", json::array({ {"anyOf", json::array({
tool_call, tool_call,
@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
return result; return result;
} }
static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data; common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required"; data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) { foreach_function(inputs.tools, [&](const json & tool) {
@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
} }
static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data; common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required"; data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) { foreach_function(inputs.tools, [&](const json & tool) {
@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
const auto & parameters_required = parameters.at("required"); const auto & parameters_required = parameters.at("required");
for (const auto & prop : expected_properties) { for (const auto & prop : expected_properties) {
if (!parameters_properties.contains(prop)) { if (!parameters_properties.contains(prop)) {
throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
} }
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
} }
} }
if (parameters_properties.size() != expected_properties.size()) { if (parameters_properties.size() != expected_properties.size()) {
@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
} }
} }
static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) { static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array(); auto builtin_tools = json::array();
common_chat_params data; common_chat_params data;
data.grammar_lazy = inputs.tool_choice != "required"; data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
if (name == "wolfram_alpha") { if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"}); expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") { } else if (name == "python" || name == "code_interpreter") {
@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
std::vector<std::string> kvs; std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) { for (const auto & [key, value] : parameters.at("properties").items()) {
kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
} }
tool_rules.push_back( tool_rules.push_back(
@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
auto arg_value_str = raw_args.substr(it_eq + 1); auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str); auto arg_value = json::parse(arg_value_str);
return { common_chat_msg msg;
/* .role = */ "assistant", msg.role = "assistant";
/* .content = */ match.prefix().str(), msg.content = match.prefix().str();
/* .tool_calls = */ { msg.tool_calls.push_back({
{ /* .name = */ name,
/* .name = */ match[1], /* .arguments = */ (json {
/* .arguments = */ (json { {arg_name, arg_value},
{arg_name, arg_value}, }).dump(),
}).dump(), /* .id = */ "",
/* .id = */ "", });
}, return msg;
},
};
} }
} }
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
} }
static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data; common_chat_params data;
if (inputs.tools.is_array() && !inputs.tools.empty()) { if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null(); data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) { foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function"); const auto & function = tool.at("function");
std::string name = function.at("name"); std::string name = function.at("name");
auto parameters = function.at("parameters"); auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters); auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call", tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n" "\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n"
@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
return msg; return msg;
} }
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
fprintf(stderr, "%s\n", __func__); LOG_DBG("%s\n", __func__);
common_chat_params data; common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"}, {"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
}); });
if (inputs.tools.is_array() && !inputs.tools.empty()) { if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required"; data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array(); auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) { foreach_function(inputs.tools, [&](const json & tool) {
@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
} }
static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_params data; common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
if (inputs.tools.is_array() && !inputs.tools.empty()) { if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != "required"; data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules; std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules; std::vector<std::string> subsequent_tool_rules;
@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
const auto & function = tool.at("function"); const auto & function = tool.at("function");
std::string name = function.at("name"); std::string name = function.at("name");
auto parameters = function.at("parameters"); auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters); auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
} }
} }
static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data; common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array(); json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name; std::string python_code_argument_name;
auto has_raw_python = false; auto has_raw_python = false;
data.grammar_lazy = inputs.tool_choice != "required"; data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) { foreach_function(inputs.tools, [&](const json & tool) {
@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
throw std::runtime_error("Missing type in python tool"); throw std::runtime_error("Missing type in python tool");
} }
has_raw_python = true; has_raw_python = true;
auto type = parameters.at("type"); const auto & type = parameters.at("type");
if (type == "object") { if (type == "object") {
auto properties = parameters.at("properties"); auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) { for (auto it = properties.begin(); it != properties.end(); ++it) {
@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
std::smatch match; std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) { if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str(); auto code = match[1].str();
return { common_chat_msg msg;
/* .role = */ "assistant", msg.role = "assistant";
/* .content = */ match.prefix().str(), msg.content = match.prefix().str();
/* .tool_calls = */ { msg.tool_calls.push_back({
{ /* .name = */ "python",
/* .name = */ "python", /* .arguments = */ (json {{"code", code}}).dump(),
/* .arguments = */ (json {{"code", code}}).dump(), /* .id = */ "",
/* .id = */ "", });
}, return msg;
}
};
} }
static std::regex function_regex(R"(<function=(\w+)>)"); static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)"); static std::regex close_regex(R"(</function>)");
@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
} }
static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data; common_chat_params data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)* // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
data.grammar_lazy = inputs.tool_choice != "required"; data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) { data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules; std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) { foreach_function(inputs.tools, [&](const json & tool) {
@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)"); std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)"); std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
common_chat_msg msg;
msg.role = "assistant";
auto end = input.end(); auto end = input.end();
std::sregex_iterator rend; std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern); std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) { if (rit == rend) {
return { msg.content = input;
/* .role = */ "assistant", return msg;
/* .content = */ input,
/* .tool_calls = */ {},
};
} }
common_chat_msg result; msg.content = rit->prefix();
result.role = "assistant";
result.content = rit->prefix();
auto it = rit->suffix().first; auto it = rit->suffix().first;
while (it != end) { while (it != end) {
@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
throw std::runtime_error("Failed to parse json tool call"); throw std::runtime_error("Failed to parse json tool call");
} }
const auto & arguments = call.at("arguments"); const auto & arguments = call.at("arguments");
result.tool_calls.push_back({ msg.tool_calls.push_back({
call.at("name"), call.at("name"),
arguments.dump(), arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
break; break;
} }
} }
return result; return msg;
} catch (const std::exception & e) { } catch (const std::exception & e) {
return { LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
/* .role = */ "assistant", common_chat_msg msg;
/* .content = */ input, msg.role = "assistant";
/* .tool_calls = */ {}, msg.content = input;
}; return msg;
} }
} }
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data; common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data; return data;
} }
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) { static common_chat_params common_chat_templates_apply_jinja(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
templates_params params;
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
? *tmpls->template_tool_use
: *tmpls->template_default;
const auto & src = tmpl.source(); const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps(); const auto & caps = tmpl.original_caps();
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
params.add_generation_prompt = inputs.add_generation_prompt;
params.extract_reasoning = inputs.extract_reasoning;
params.tool_choice = inputs.tool_choice;
params.grammar = inputs.grammar;
if (!inputs.json_schema.empty()) {
params.json_schema = json::parse(inputs.json_schema);
}
if (inputs.tools.is_array()) { if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
if (inputs.tool_choice != "none" && !inputs.grammar.empty()) { LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
params.parallel_tool_calls = false;
} else {
params.parallel_tool_calls = inputs.parallel_tool_calls;
}
if (params.tools.is_array()) {
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools"); throw std::runtime_error("Cannot specify grammar with tools");
} }
if (caps.supports_tool_calls && !caps.supports_tools) { if (caps.supports_tool_calls && !caps.supports_tools) {
@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
} }
// DeepSeek R1: use handler in all cases except json schema (thinking / tools). // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
if (src.find("<tool▁calls▁begin>") != std::string::npos && inputs.json_schema.is_null()) { if (src.find("<tool▁calls▁begin>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, inputs); return common_chat_params_init_deepseek_r1(tmpl, params);
} }
// Command R7B: : use handler in all cases except json schema (thinking / tools). // Command R7B: : use handler in all cases except json schema (thinking / tools).
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) { if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_command_r7b(tmpl, inputs); return common_chat_params_init_command_r7b(tmpl, params);
} }
// Use generic handler when mixing tools + JSON schema. // Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below. // TODO: support that mix in handlers below.
if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) { if ((params.tools.is_array() && params.json_schema.is_object())) {
return common_chat_params_init_generic(tmpl, inputs); return common_chat_params_init_generic(tmpl, params);
} }
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
if (src.find(">>>all") != std::string::npos) { if (src.find(">>>all") != std::string::npos) {
return common_chat_params_init_functionary_v3_2(tmpl, inputs); return common_chat_params_init_functionary_v3_2(tmpl, params);
} }
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
if (src.find(" functools[") != std::string::npos) { if (src.find(" functools[") != std::string::npos) {
return common_chat_params_init_firefunction_v2(tmpl, inputs); return common_chat_params_init_firefunction_v2(tmpl, params);
} }
// Plain handler (no tools) // Plain handler (no tools)
if (inputs.tools.is_null() || inputs.tool_choice == "none") { if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
return common_chat_params_init_without_tools(tmpl, inputs); return common_chat_params_init_without_tools(tmpl, params);
} }
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos) { if (src.find("<tool_call>") != std::string::npos) {
return common_chat_params_init_hermes_2_pro(tmpl, inputs); return common_chat_params_init_hermes_2_pro(tmpl, params);
} }
// Functionary v3.1 (w/ tools) // Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) { && src.find("<function=") != std::string::npos) {
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs); return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
} }
// Llama 3.1, 3.2, 3.3 (w/ tools) // Llama 3.1, 3.2, 3.3 (w/ tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) { if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools); return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
} }
// Mistral Nemo (w/ tools) // Mistral Nemo (w/ tools)
if (src.find("[TOOL_CALLS]") != std::string::npos) { if (src.find("[TOOL_CALLS]") != std::string::npos) {
return common_chat_params_init_mistral_nemo(tmpl, inputs); return common_chat_params_init_mistral_nemo(tmpl, params);
} }
// Generic fallback // Generic fallback
return common_chat_params_init_generic(tmpl, inputs); return common_chat_params_init_generic(tmpl, params);
}
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
static common_chat_params common_chat_templates_apply_legacy(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
continue;
}
if (!content.empty()) {
content += "\n";;
}
content += part.text;
}
contents.emplace_back(std::move(content));
}
for (size_t i = 0; i < contents.size(); ++i) {
const auto & msg = inputs.messages[i];
const auto & content = contents[i];
chat.push_back({msg.role.c_str(), content.c_str()});
alloc_size += (msg.role.size() + content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
const auto & src = tmpls->template_default->source();
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
}
common_chat_params params;
params.prompt = std::string(buf.data(), res);
if (!inputs.json_schema.empty()) {
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
} else {
params.grammar = inputs.grammar;
}
return params;
}
common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
GGML_ASSERT(tmpls != nullptr);
return inputs.use_jinja
? common_chat_templates_apply_jinja(tmpls, inputs)
: common_chat_templates_apply_legacy(tmpls, inputs);
} }
static common_chat_msg common_chat_parse_content_only(const std::string & input) { static common_chat_msg common_chat_parse_content_only(const std::string & input) {
return { common_chat_msg msg;
/* .role = */ "assistant", msg.role = "assistant";
/* .content = */ input, msg.content = input;
/* .tool_calls = */ {}, return msg;
};
} }
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) { common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {

134
common/chat.h Normal file
View File

@ -0,0 +1,134 @@
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <string>
#include <vector>
struct common_chat_templates;
struct common_chat_tool_call {
std::string name;
std::string arguments;
std::string id;
};
struct common_chat_msg_content_part {
std::string type;
std::string text;
};
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_chat_msg_content_part> content_parts = {};
std::vector<common_chat_tool_call> tool_calls = {};
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
};
struct common_chat_tool {
std::string name;
std::string description;
std::string parameters;
};
enum common_chat_tool_choice {
COMMON_CHAT_TOOL_CHOICE_AUTO,
COMMON_CHAT_TOOL_CHOICE_REQUIRED,
COMMON_CHAT_TOOL_CHOICE_NONE,
};
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_templates_inputs {
std::vector<common_chat_msg> messages;
std::string grammar;
std::string json_schema;
bool add_generation_prompt = true;
bool use_jinja = true;
// Parameters below only supported when use_jinja is true
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
bool extract_reasoning = true;
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
void common_chat_templates_free(struct common_chat_templates * tmpls);
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override = "",
const std::string & eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
struct common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const struct common_chat_templates * tmpls,
bool use_jinja);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
// Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);

View File

@ -1,55 +0,0 @@
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <json.hpp>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
struct common_chat_inputs {
json messages;
json tools;
json tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
json prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);

View File

@ -12,8 +12,6 @@
#include "json.hpp" #include "json.hpp"
#include "json-schema-to-grammar.h" #include "json-schema-to-grammar.h"
#include "llama.h" #include "llama.h"
#include "chat.hpp"
#include "chat-template.hpp"
#include <algorithm> #include <algorithm>
#include <cinttypes> #include <cinttypes>
@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
return text; return text;
} }
//
// Chat template utils
//
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
common_chat_inputs inputs;
inputs.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
common_chat_params_init(chat_template, inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
common_chat_inputs inputs;
inputs.messages = messages;
inputs.add_generation_prompt = add_ass;
return common_chat_params_init(tmpl, inputs).prompt;
}
int alloc_size = 0;
std::vector<llama_chat_message> chat;
for (const auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant", {}},
{"user", "Hello", {}},
{"assistant", "Hi there", {}},
{"user", "How are you?", {}},
};
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = CHATML_TEMPLATE_SRC;
}
}
auto vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
}
return std::string();
} else {
return common_token_to_piece(vocab, token, true);
}
};
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
try {
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
};
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
return {
has_explicit_template,
std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
nullptr,
};
}
}
// //
// KV cache utils // KV cache utils
// //

View File

@ -616,62 +616,6 @@ std::string common_detokenize(
const std::vector<llama_token> & tokens, const std::vector<llama_token> & tokens,
bool special = true); bool special = true);
//
// Chat template utils
//
struct common_tool_call {
std::string name;
std::string arguments;
std::string id;
};
// same with llama_chat_message, but uses std::string
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
std::string reasoning_content = "";
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
namespace minja {
class chat_template;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & chat,
bool add_ass,
bool use_jinja);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const common_chat_template & tmpl, bool use_jinja);
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
// //
// KV cache utils // KV cache utils
// //

View File

@ -4,7 +4,7 @@
#include "log.h" #include "log.h"
#include "sampling.h" #include "sampling.h"
#include "llama.h" #include "llama.h"
#include "chat-template.hpp" #include "chat.h"
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
} }
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
auto chat_templates = common_chat_templates_from_model(model, params.chat_template); auto chat_templates = common_chat_templates_init(model, params.chat_template);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
} }
// auto enable conversation mode if chat template is available // auto enable conversation mode if chat template is available
const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default; const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
if (has_chat_template) { if (has_chat_template) {
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
// print chat template example in conversation mode // print chat template example in conversation mode
if (params.conversation_mode) { if (params.conversation_mode) {
if (params.enable_chat_template) { if (params.enable_chat_template) {
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
} else { } else {
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
} }
@ -264,9 +264,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content, {}}; common_chat_msg new_msg;
auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); new_msg.role = role;
chat_msgs.push_back({role, content, {}}); new_msg.content = content;
auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back(new_msg);
LOG_DBG("formatted: '%s'\n", formatted.c_str()); LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted; return formatted;
}; };
@ -755,11 +757,14 @@ int main(int argc, char ** argv) {
// check for reverse prompt using special tokens // check for reverse prompt using special tokens
llama_token last_token = common_sampler_last(smpl); llama_token last_token = common_sampler_last(smpl);
if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) { for (auto token : antiprompt_token) {
if (params.interactive) { if (token == last_token) {
is_interacting = true; if (params.interactive) {
is_interacting = true;
}
is_antiprompt = true;
break;
} }
is_antiprompt = true;
} }
if (is_antiprompt) { if (is_antiprompt) {

View File

@ -24,7 +24,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "chat-template.hpp" #include "chat.h"
#include "common.h" #include "common.h"
#include "json.hpp" #include "json.hpp"
#include "linenoise.cpp/linenoise.h" #include "linenoise.cpp/linenoise.h"
@ -557,7 +557,7 @@ class LlamaData {
llama_model_ptr model; llama_model_ptr model;
llama_sampler_ptr sampler; llama_sampler_ptr sampler;
llama_context_ptr context; llama_context_ptr context;
std::vector<llama_chat_message> messages; std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
std::list<std::string> msg_strs; std::list<std::string> msg_strs;
std::vector<char> fmtted; std::vector<char> fmtted;
@ -834,44 +834,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
} }
// Function to apply the chat template and resize `formatted` if needed // Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) { common_chat_templates_inputs inputs;
json messages = json::array(); for (const auto & msg : llama_data.messages) {
for (const auto & msg : llama_data.messages) { common_chat_msg cmsg;
messages.push_back({ cmsg.role = msg.role;
{"role", msg.role}, cmsg.content = msg.content;
{"content", msg.content}, inputs.messages.push_back(cmsg);
});
}
try {
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages;
tmpl_inputs.add_generation_prompt = append;
minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false;
tmpl_opts.use_eos_token = false;
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} catch (const std::exception & e) {
printe("failed to render the chat template: %s\n", e.what());
return -1;
}
}
int result = llama_chat_apply_template(
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
} }
inputs.add_generation_prompt = append;
inputs.use_jinja = use_jinja;
return result; auto chat_params = common_chat_templates_apply(tmpls, inputs);
// TODO: use other params for tool calls.
auto result = chat_params.prompt;
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} }
// Function to tokenize the prompt // Function to tokenize the prompt
@ -1015,8 +994,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
} }
// Helper function to apply the chat template and handle errors // Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
if (new_len < 0) { if (new_len < 0) {
printe("failed to apply the chat template\n"); printe("failed to apply the chat template\n");
return -1; return -1;
@ -1078,8 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) {
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
int prev_len = 0; int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), ""); auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
GGML_ASSERT(chat_templates.template_default);
static const bool stdout_a_terminal = is_stdout_a_terminal(); static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) { while (true) {
// Get user input // Get user input
@ -1090,7 +1068,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
add_message("user", user.empty() ? user_input : user, llama_data); add_message("user", user.empty() ? user_input : user, llama_data);
int new_len; int new_len;
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) { if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
return 1; return 1;
} }
@ -1105,7 +1083,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
} }
add_message("assistant", response, llama_data); add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) { if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
return 1; return 1;
} }
} }

View File

@ -329,9 +329,6 @@ struct server_task {
} }
// process "json_schema" and "grammar" // process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
if (data.contains("json_schema") && !data.contains("grammar")) { if (data.contains("json_schema") && !data.contains("grammar")) {
try { try {
auto schema = json_value(data, "json_schema", json::object()); auto schema = json_value(data, "json_schema", json::object());
@ -1807,7 +1804,7 @@ struct server_context {
// Necessary similarity of prompt for slot selection // Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f; float slot_prompt_similarity = 0.0f;
common_chat_templates chat_templates; common_chat_templates_ptr chat_templates;
~server_context() { ~server_context() {
// Clear any sampling context // Clear any sampling context
@ -1891,45 +1888,17 @@ struct server_context {
llama_init_dft.context.reset(); llama_init_dft.context.reset();
} }
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) { chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja);
} catch (const std::exception & e) {
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_from_model(model, "chatml"); chat_templates = common_chat_templates_init(model, "chatml");
} else {
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
} }
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
return true; return true;
} }
bool validate_builtin_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}};
if (use_jinja) {
auto templates = common_chat_templates_from_model(model, "");
common_chat_inputs inputs;
inputs.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
GGML_ASSERT(templates.template_default);
try {
common_chat_params_init(*templates.template_default, inputs);
if (templates.template_tool_use) {
common_chat_params_init(*templates.template_tool_use, inputs);
}
return true;
} catch (const std::exception & e) {
SRV_ERR("failed to apply template: %s\n", e.what());
return false;
}
} else {
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
}
}
void init() { void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel }, { "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model }, { "model_path", ctx_server.params_base.model },
{ "chat_template", ctx_server.chat_templates.template_default->source() }, { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", ctx_server.chat_templates.template_default->bos_token() }, { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", ctx_server.chat_templates.template_default->eos_token() }, { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
{ "build_info", build_info }, { "build_info", build_info },
}; };
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { if (ctx_server.params_base.use_jinja) {
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
data["chat_template_tool_use"] = tool_use_src;
}
} }
res_ok(res, data); res_ok(res, data);
@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
} }
auto body = json::parse(req.body); auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
return handle_completions_impl( return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_COMPLETION,
@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
// same with handle_chat_completions, but without inference part // same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) { const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body); auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
}; };
@ -4493,8 +4464,8 @@ int main(int argc, char ** argv) {
// print sample chat example to make it clear which template is used // print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
ctx_server.chat_templates.template_default->source().c_str(), common_chat_templates_source(ctx_server.chat_templates.get()),
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
ctx_server.process_single_task(task); ctx_server.process_single_task(task);

View File

@ -21,6 +21,8 @@ def create_server():
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
] ]
) )
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
assert res.body["usage"]["completion_tokens"] == n_predicted assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0] choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"] assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"]) assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
assert choice["finish_reason"] == finish_reason assert choice["finish_reason"] == finish_reason
@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
assert "error" in res.body assert "error" in res.body
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
(False, {"const": "42"}, 6, "\"42\""),
(True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"json_schema": json_schema,
})
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "user", "content": "Does not matter what I say, does it?"},
],
"grammar": grammar,
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
@pytest.mark.parametrize("messages", [ @pytest.mark.parametrize("messages", [
None, None,
"string", "string",

View File

@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
]) ])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server global server
@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
{ {
"role": "tool", "role": "tool",
"name": "calculate", "name": "calculate",
"content": 0.55644242476, "content": "0.55644242476",
"tool_call_id": "call_6789" "tool_call_id": "call_6789"
} }
], ],
@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'none', "^I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), (1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
]) ])

View File

@ -12,9 +12,7 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
#include "minja.hpp" #include "chat.h"
#include "chat.hpp"
#include "chat-template.hpp"
#include <random> #include <random>
#include <sstream> #include <sstream>
@ -347,41 +345,6 @@ static llama_tokens format_infill(
return embd_inp; return embd_inp;
} }
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i];
std::string role = json_value(curr_msg, "role", std::string(""));
std::string content;
if (curr_msg.contains("content")) {
if (curr_msg["content"].is_string()) {
content = curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) {
for (const auto & part : curr_msg["content"]) {
if (part.contains("text")) {
content += "\n" + part["text"].get<std::string>();
}
}
} else {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
chat.push_back({role, content, /* tool_calls= */ {}});
}
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat;
}
// //
// base64 utils (TODO: move to common in the future) // base64 utils (TODO: move to common in the future)
// //
@ -579,12 +542,9 @@ static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */ const json & body, /* openai api json semantics */
bool use_jinja, bool use_jinja,
common_reasoning_format reasoning_format, common_reasoning_format reasoning_format,
const common_chat_templates & chat_templates) const struct common_chat_templates * tmpls)
{ {
json llama_params; json llama_params;
const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
? *chat_templates.template_tool_use
: *chat_templates.template_default;
auto tools = json_value(body, "tools", json()); auto tools = json_value(body, "tools", json());
auto stream = json_value(body, "stream", false); auto stream = json_value(body, "stream", false);
@ -610,62 +570,58 @@ static json oaicompat_completion_params_parse(
llama_params["stop"] = json_value(body, "stop", json::array()); llama_params["stop"] = json_value(body, "stop", json::array());
} }
auto json_schema = json_value(body, "json_schema", json());
auto grammar = json_value(body, "grammar", std::string());
if (!json_schema.is_null() && !grammar.empty()) {
throw std::runtime_error("Cannot use both json_schema and grammar");
}
// Handle "response_format" field // Handle "response_format" field
if (body.contains("response_format")) { if (body.contains("response_format")) {
json response_format = json_value(body, "response_format", json::object()); json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string()); std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") { if (response_type == "json_object") {
llama_params["json_schema"] = json_value(response_format, "schema", json::object()); json_schema = json_value(response_format, "schema", json::object());
} else if (response_type == "json_schema") { } else if (response_type == "json_schema") {
json json_schema = json_value(response_format, "json_schema", json::object()); json json_schema = json_value(response_format, "json_schema", json::object());
llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); json_schema = json_value(json_schema, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") { } else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
} }
} }
// Apply chat template to the list of messages common_chat_templates_inputs inputs;
if (use_jinja) { inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
auto tool_choice = json_value(body, "tool_choice", std::string("auto")); inputs.tools = common_chat_tools_parse_oaicompat(tools);
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
throw std::runtime_error("Invalid tool_choice: " + tool_choice); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
} inputs.grammar = grammar;
if (tool_choice != "none" && llama_params.contains("grammar")) { inputs.add_generation_prompt = true;
throw std::runtime_error("Cannot use custom grammar constraints with tools."); inputs.use_jinja = use_jinja;
} inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
common_chat_inputs inputs; inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
inputs.messages = body.at("messages"); throw std::runtime_error("Cannot use custom grammar constraints with tools.");
inputs.tools = tools; }
inputs.tool_choice = tool_choice;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
inputs.parallel_tool_calls = false;
}
inputs.stream = stream;
// TODO: support mixing schema w/ tools beyond generic format.
inputs.json_schema = json_value(llama_params, "json_schema", json());
auto chat_params = common_chat_params_init(tmpl, inputs);
llama_params["chat_format"] = static_cast<int>(chat_params.format); // Apply chat template to the list of messages
llama_params["prompt"] = chat_params.prompt; auto chat_params = common_chat_templates_apply(tmpls, inputs);
llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_lazy"] = chat_params.grammar_lazy; llama_params["chat_format"] = static_cast<int>(chat_params.format);
auto grammar_triggers = json::array(); llama_params["prompt"] = chat_params.prompt;
for (const auto & trigger : chat_params.grammar_triggers) { llama_params["grammar"] = chat_params.grammar;
grammar_triggers.push_back({ llama_params["grammar_lazy"] = chat_params.grammar_lazy;
{"word", trigger.word}, auto grammar_triggers = json::array();
{"at_start", trigger.at_start}, for (const auto & trigger : chat_params.grammar_triggers) {
}); grammar_triggers.push_back({
} {"word", trigger.word},
llama_params["grammar_triggers"] = grammar_triggers; {"at_start", trigger.at_start},
llama_params["preserved_tokens"] = chat_params.preserved_tokens; });
for (const auto & stop : chat_params.additional_stops) { }
llama_params["stop"].push_back(stop); llama_params["grammar_triggers"] = grammar_triggers;
} llama_params["preserved_tokens"] = chat_params.preserved_tokens;
} else { for (const auto & stop : chat_params.additional_stops) {
llama_params["prompt"] = format_chat(tmpl, body.at("messages")); llama_params["stop"].push_back(stop);
} }
// Handle "n" field // Handle "n" field

View File

@ -1,13 +1,14 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <regex>
#undef NDEBUG #undef NDEBUG
#include <cassert> #include <cassert>
#include "llama.h" #include "llama.h"
#include "common.h" #include "common.h"
#include "chat-template.hpp" #include "chat.h"
static std::string normalize_newlines(const std::string & s) { static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32 #ifdef _WIN32
@ -18,6 +19,13 @@ static std::string normalize_newlines(const std::string & s) {
#endif #endif
} }
static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
common_chat_msg msg;
msg.role = role;
msg.content = content;
return msg;
}
int main(void) { int main(void) {
std::vector<llama_chat_message> conversation { std::vector<llama_chat_message> conversation {
{"system", "You are a helpful assistant"}, {"system", "You are a helpful assistant"},
@ -50,7 +58,7 @@ int main(void) {
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]", /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
/* .expected_output_jinja= */ "", /* .expected_output_jinja= */ "",
/* .bos_token= */ "", /* .bos_token= */ "<s>",
/* .eos_token= */ "</s>", /* .eos_token= */ "</s>",
}, },
{ {
@ -72,8 +80,8 @@ int main(void) {
{ {
/* .name= */ "mlabonne/AlphaMonarch-7B", /* .name= */ "mlabonne/AlphaMonarch-7B",
/* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n", /* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
/* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n", /* .expected_output_jinja= */ "",
/* .bos_token= */ "<s>", /* .bos_token= */ "<s>",
/* .eos_token= */ "</s>", /* .eos_token= */ "</s>",
}, },
@ -87,7 +95,7 @@ int main(void) {
/* .name= */ "OrionStarAI/Orion-14B-Chat", /* .name= */ "OrionStarAI/Orion-14B-Chat",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>", /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>", /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: ",
/* .bos_token= */ "", /* .bos_token= */ "",
/* .eos_token= */ "</s>", /* .eos_token= */ "</s>",
}, },
@ -304,12 +312,9 @@ int main(void) {
} }
} }
json messages = json::array(); std::vector<common_chat_msg> messages;
for (const auto & msg : conversation) { for (const auto & msg : conversation) {
messages.push_back({ messages.push_back(simple_msg(msg.role, msg.content));
{"role", msg.role},
{"content", msg.content},
});
} }
for (const auto & test_case : test_cases) { for (const auto & test_case : test_cases) {
if (!test_case.supported_with_jinja) { if (!test_case.supported_with_jinja) {
@ -317,8 +322,13 @@ int main(void) {
} }
printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
try { try {
minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token); auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt)); common_chat_templates_inputs inputs;
inputs.use_jinja = true;
inputs.messages = messages;
inputs.add_generation_prompt = add_generation_prompt;
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
output = normalize_newlines(output);
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja); auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
if (output != expected_output) { if (output != expected_output) {
printf("Expected:\n%s\n", expected_output.c_str()); printf("Expected:\n%s\n", expected_output.c_str());
@ -336,11 +346,11 @@ int main(void) {
// test llama_chat_format_single for system message // test llama_chat_format_single for system message
printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
std::vector<common_chat_msg> chat2; std::vector<common_chat_msg> chat2;
common_chat_msg sys_msg{"system", "You are a helpful assistant", {}}; auto sys_msg = simple_msg("system", "You are a helpful assistant");
auto fmt_sys = [&](std::string tmpl_str) { auto fmt_sys = [&](std::string tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", ""); auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n"); printf("-------------------------\n");
return output; return output;
@ -360,14 +370,14 @@ int main(void) {
// test llama_chat_format_single for user message // test llama_chat_format_single for user message
printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
chat2.push_back({"system", "You are a helpful assistant", {}}); chat2.push_back(simple_msg("system", "You are a helpful assistant"));
chat2.push_back({"user", "Hello", {}}); chat2.push_back(simple_msg("user", "Hello"));
chat2.push_back({"assistant", "I am assistant", {}}); chat2.push_back(simple_msg("assistant", "I am assistant"));
common_chat_msg new_msg{"user", "How are you", {}}; auto new_msg = simple_msg("user", "How are you");
auto fmt_single = [&](std::string tmpl_str) { auto fmt_single = [&](const std::string & tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", ""); auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n"); printf("-------------------------\n");
return output; return output;

File diff suppressed because it is too large Load Diff