Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-27 03:55:20 +00:00)
tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900)
* tool-call refactoring: moved common_chat_* to chat.h; common_chat_templates_init now returns a unique_ptr to an opaque type
* addressed clang-tidy lints in [test-]chat.*
* rm minja deps from util & common & move it to common/minja/
* add name & tool_call_id to common_chat_msg
* add common_chat_tool
* added json <-> tools, msgs conversions to chat.h
* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)
* fix deepseek r1 slow test (no longer <think> opening w/ new template)
* allow empty tools w/ auto + grammar
* fix & test server grammar & json_schema params w/ & w/o --jinja
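
For orientation, a rough usage sketch of the refactored API as it appears in the diff below. The weather tool, its JSON schema, and the surrounding helper function are invented for illustration; common_chat_templates_ptr is the new unique_ptr to the opaque common_chat_templates type mentioned above:

    #include "chat.h"

    // Hypothetical helper: build a tool-call-ready prompt for a loaded model.
    std::string build_prompt(const llama_model * model) {
        // Opaque templates object owned by a unique_ptr (falls back to chatml if the model has no template).
        common_chat_templates_ptr tmpls = common_chat_templates_init(model, /* chat_template_override= */ "");

        common_chat_msg msg;
        msg.role    = "user";
        msg.content = "What is the weather in Tokyo?";

        common_chat_tool tool;                                // invented example tool
        tool.name        = "get_weather";
        tool.description = "Get the current weather for a city";
        tool.parameters  = R"({"type":"object","properties":{"city":{"type":"string"}},"required":["city"]})";

        common_chat_templates_inputs inputs;
        inputs.messages    = {msg};
        inputs.tools       = {tool};
        inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;    // empty tools w/ auto is now allowed too

        common_chat_params params = common_chat_templates_apply(tmpls.get(), inputs);
        // params.format later selects the matching parser for the model output:
        //   common_chat_msg reply = common_chat_parse(generated_text, params.format);
        return params.prompt;
    }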
Makefile
@@ -1364,7 +1364,7 @@ llama-server: \
 	examples/server/index.html.hpp \
 	examples/server/loading.html.hpp \
 	common/chat.cpp \
-	common/chat.hpp \
+	common/chat.h \
 	common/chat-template.hpp \
 	common/json.hpp \
 	common/minja.hpp \

common/CMakeLists.txt
@@ -57,8 +57,7 @@ add_library(${TARGET} STATIC
     arg.h
     base64.hpp
     chat.cpp
-    chat.hpp
-    chat-template.hpp
+    chat.h
     common.cpp
     common.h
     console.cpp
@@ -68,7 +67,8 @@ add_library(${TARGET} STATIC
     llguidance.cpp
     log.cpp
    log.h
-    minja.hpp
+    minja/chat-template.hpp
+    minja/minja.hpp
     ngram-cache.cpp
     ngram-cache.h
     sampling.cpp

common/arg.cpp
@@ -2,6 +2,7 @@
 
 #include "log.h"
 #include "sampling.h"
+#include "chat.h"
 
 #include <algorithm>
 #include <climits>

common/chat.cpp
@@ -1,8 +1,433 @@
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
 #include "json-schema-to-grammar.h"
 #include "log.h"
-#include "minja.hpp"
+#include "minja/chat-template.hpp"
+#include "minja/minja.hpp"
+
+#include <optional>
+
+typedef minja::chat_template common_chat_template;
+
+struct common_chat_templates {
+    bool has_explicit_template; // Model had builtin template or template override was specified.
+    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
+    std::unique_ptr<common_chat_template> template_tool_use;
+};
+
+struct templates_params {
+    json messages;
+    json tools;
+    common_chat_tool_choice tool_choice;
+    json json_schema;
+    bool parallel_tool_calls;
+    bool stream;
+    std::string grammar;
+    bool add_generation_prompt = true;
+    bool extract_reasoning = true;
+};
+
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
+    if (tool_choice == "auto") {
+        return COMMON_CHAT_TOOL_CHOICE_AUTO;
+    }
+    if (tool_choice == "none") {
+        return COMMON_CHAT_TOOL_CHOICE_NONE;
+    }
+    if (tool_choice == "required") {
+        return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+    }
+    throw std::runtime_error("Invalid tool_choice: " + tool_choice);
+}
+
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
+    std::vector<common_chat_msg> msgs;
+
+    try {
+        if (!messages.is_array()) {
+            throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
+        }
+
+        for (const auto & message : messages) {
+            if (!message.is_object()) {
+                throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
+            }
+
+            common_chat_msg msg;
+            if (!message.contains("role")) {
+                throw std::runtime_error("Missing 'role' in message: " + message.dump());
+            }
+            msg.role = message.at("role");
+
+            if (message.contains("content")) {
+                const auto & content = message.at("content");
+                if (content.is_string()) {
+                    msg.content = content;
+                } else if (content.is_array()) {
+                    for (const auto & part : content) {
+                        if (!part.contains("type")) {
+                            throw std::runtime_error("Missing content part type: " + part.dump());
+                        }
+                        const auto & type = part.at("type");
+                        if (type != "text") {
+                            throw std::runtime_error("Unsupported content part type: " + type.dump());
+                        }
+                        common_chat_msg_content_part msg_part;
+                        msg_part.type = type;
+                        msg_part.text = part.at("text");
+                        msg.content_parts.push_back(msg_part);
+                    }
+                } else if (!content.is_null()) {
+                    throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+                }
+            } else {
+                throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
+            }
+            if (message.contains("reasoning_content")) {
+                msg.reasoning_content = message.at("reasoning_content");
+            }
+            if (message.contains("name")) {
+                msg.tool_name = message.at("name");
+            }
+            if (message.contains("tool_call_id")) {
+                msg.tool_call_id = message.at("tool_call_id");
+            }
+            if (message.contains("tool_calls")) {
+                for (const auto & tool_call : message.at("tool_calls")) {
+                    common_chat_tool_call tc;
+                    if (!tool_call.contains("type")) {
+                        throw std::runtime_error("Missing tool call type: " + tool_call.dump());
+                    }
+                    const auto & type = tool_call.at("type");
+                    if (type != "function") {
+                        throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
+                    }
+                    if (!tool_call.contains("function")) {
+                        throw std::runtime_error("Missing tool call function: " + tool_call.dump());
+                    }
+                    const auto & fc = tool_call.at("function");
+                    if (!fc.contains("name")) {
+                        throw std::runtime_error("Missing tool call name: " + tool_call.dump());
+                    }
+                    tc.name = fc.at("name");
+                    tc.arguments = fc.at("arguments");
+                    if (tool_call.contains("id")) {
+                        tc.id = tool_call.at("id");
+                    }
+                    msg.tool_calls.push_back(tc);
+                }
+            }
+
+            msgs.push_back(msg);
+        }
+    } catch (const std::exception & e) {
+        throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
+    }
+
+    return msgs;
+}
+
+template <>
+json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
+    json messages = json::array();
+    for (const auto & msg : msgs) {
+        if (!msg.content.empty() && !msg.content_parts.empty()) {
+            throw std::runtime_error("Cannot specify both content and content_parts");
+        }
+        json jmsg {
+            {"role", msg.role},
+        };
+        if (!msg.content.empty()) {
+            jmsg["content"] = msg.content;
+        } else if (!msg.content_parts.empty()) {
+            if (concat_typed_text) {
+                std::string text;
+                for (const auto & part : msg.content_parts) {
+                    if (part.type != "text") {
+                        LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
+                        continue;
+                    }
+                    if (!text.empty()) {
+                        text += '\n';
+                    }
+                    text += part.text;
+                }
+                jmsg["content"] = text;
+            } else {
+                auto & parts = jmsg["content"] = json::array();
+                for (const auto & part : msg.content_parts) {
+                    parts.push_back({
+                        {"type", part.type},
+                        {"text", part.text},
+                    });
+                }
+            }
+        } else {
+            jmsg["content"] = json(); // null
+        }
+        if (!msg.reasoning_content.empty()) {
+            jmsg["reasoning_content"] = msg.reasoning_content;
+        }
+        if (!msg.tool_name.empty()) {
+            jmsg["name"] = msg.tool_name;
+        }
+        if (!msg.tool_call_id.empty()) {
+            jmsg["tool_call_id"] = msg.tool_call_id;
+        }
+        if (!msg.tool_calls.empty()) {
+            auto & tool_calls = jmsg["tool_calls"] = json::array();
+            for (const auto & tool_call : msg.tool_calls) {
+                json tc {
+                    {"type", "function"},
+                    {"function", {
+                        {"name", tool_call.name},
+                        {"arguments", tool_call.arguments},
+                    }},
+                };
+                if (!tool_call.id.empty()) {
+                    tc["id"] = tool_call.id;
+                }
+                tool_calls.push_back(tc);
+            }
+        }
+        messages.push_back(jmsg);
+    }
+    return messages;
+}
+
+template <>
+std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
+    return common_chat_msgs_parse_oaicompat(json::parse(messages));
+}
+
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
+    std::vector<common_chat_tool> result;
+
+    try {
+        if (!tools.is_null()) {
+            if (!tools.is_array()) {
+                throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
+            }
+            for (const auto & tool : tools) {
+                if (!tool.contains("type")) {
+                    throw std::runtime_error("Missing tool type: " + tool.dump());
+                }
+                const auto & type = tool.at("type");
+                if (!type.is_string() || type != "function") {
+                    throw std::runtime_error("Unsupported tool type: " + tool.dump());
+                }
+                if (!tool.contains("function")) {
+                    throw std::runtime_error("Missing tool function: " + tool.dump());
+                }
+
+                const auto & function = tool.at("function");
+                result.push_back({
+                    /* .name = */ function.at("name"),
+                    /* .description = */ function.at("description"),
+                    /* .parameters = */ function.at("parameters").dump(),
+                });
+            }
+        }
+    } catch (const std::exception & e) {
+        throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
+    }
+
+    return result;
+}
+
+template <>
+std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
+    return common_chat_tools_parse_oaicompat(json::parse(tools));
+}
+
+template <>
+json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
+    if (tools.empty()) {
+        return json();
+    }
+
+    auto result = json::array();
+    for (const auto & tool : tools) {
+        result.push_back({
+            {"type", "function"},
+            {"function", {
+                {"name", tool.name},
+                {"description", tool.description},
+                {"parameters", json::parse(tool.parameters)},
+            }},
+        });
+    }
+    return result;
+}
+
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
+    if (use_jinja) {
+        try {
+            common_chat_msg msg;
+            msg.role = "user";
+            msg.content = "test";
+
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
+
+            common_chat_templates_inputs inputs;
+            inputs.messages = {msg};
+
+            common_chat_templates_apply(tmpls.get(), inputs);
+            return true;
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
+            return false;
+        }
+    }
+    llama_chat_message chat[] = {{"user", "test"}};
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
+    return res >= 0;
+}
+
+std::string common_chat_format_single(
+        const struct common_chat_templates * tmpls,
+        const std::vector<common_chat_msg> & past_msg,
+        const common_chat_msg & new_msg,
+        bool add_ass,
+        bool use_jinja) {
+
+    common_chat_templates_inputs inputs;
+    inputs.use_jinja = use_jinja;
+
+    std::string fmt_past_msg;
+    if (!past_msg.empty()) {
+        inputs.messages = past_msg;
+        inputs.add_generation_prompt = false;
+        fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+    }
+    std::ostringstream ss;
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    };
+    // format chat with new_msg
+    inputs.messages.push_back(new_msg);
+    inputs.add_generation_prompt = add_ass;
+    auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
+}
+
+std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    inputs.use_jinja = use_jinja;
+    auto add_simple_msg = [&](auto role, auto content) {
+        common_chat_msg msg;
+        msg.role = role;
+        msg.content = content;
+        inputs.messages.push_back(msg);
+    };
+    add_simple_msg("system", "You are a helpful assistant");
+    add_simple_msg("user", "Hello");
+    add_simple_msg("assistant", "Hi there");
+    add_simple_msg("user", "How are you?");
+    return common_chat_templates_apply(tmpls, inputs).prompt;
+}
+
+#define CHATML_TEMPLATE_SRC \
+    "{%- for message in messages -%}\n" \
+    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
+    "{%- endfor -%}\n" \
+    "{%- if add_generation_prompt -%}\n" \
+    "  {{- '<|im_start|>assistant\n' -}}\n" \
+    "{%- endif -%}"
+
+void common_chat_templates_free(struct common_chat_templates * tmpls) {
+    delete tmpls;
+}
+
+bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
+    return tmpls->has_explicit_template;
+}
+
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
+    if (variant != nullptr) {
+        if (strcmp(variant, "tool_use") == 0) {
+            if (tmpls->template_tool_use) {
+                return tmpls->template_tool_use->source().c_str();
+            }
+            return nullptr;
+        } else {
+            LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
+        }
+    }
+    return tmpls->template_default->source().c_str();
+}
+
+common_chat_templates_ptr common_chat_templates_init(
+        const struct llama_model * model,
+        const std::string & chat_template_override,
+        const std::string & bos_token_override,
+        const std::string & eos_token_override)
+{
+    std::string default_template_src;
+    std::string template_tool_use_src;
+
+    bool has_explicit_template = !chat_template_override.empty();
+    if (chat_template_override.empty()) {
+        GGML_ASSERT(model != nullptr);
+        const auto * str = llama_model_chat_template(model, /* name */ nullptr);
+        if (str) {
+            default_template_src = str;
+            has_explicit_template = true;
+        }
+        str = llama_model_chat_template(model, /* name */ "tool_use");
+        if (str) {
+            template_tool_use_src = str;
+            has_explicit_template = true;
+        }
+    } else {
+        default_template_src = chat_template_override;
+    }
+    if (default_template_src.empty() || default_template_src == "chatml") {
+        if (!template_tool_use_src.empty()) {
+            default_template_src = template_tool_use_src;
+        } else {
+            default_template_src = CHATML_TEMPLATE_SRC;
+        }
+    }
+    std::string token_bos = bos_token_override;
+    std::string token_eos = eos_token_override;
+    if (model) {
+        const auto * vocab = llama_model_get_vocab(model);
+        const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
+            if (token == LLAMA_TOKEN_NULL) {
+                if (default_template_src.find(jinja_variable_name) != std::string::npos
+                    || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
+                    LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
+                }
+                return std::string();
+            }
+            return common_token_to_piece(vocab, token, true);
+        };
+        token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
+        token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
+    }
+    common_chat_templates_ptr tmpls(new common_chat_templates());
+    tmpls->has_explicit_template = has_explicit_template;
+    try {
+        tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
+        tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
+    }
+    if (!template_tool_use_src.empty()) {
+        try {
+            tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
+        } catch (const std::exception & e) {
+            LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
+        }
+    }
+    return tmpls;
+}
+
 std::string common_chat_format_name(common_chat_format format) {
     switch (format) {
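
For reference, a minimal sketch of what the CHATML_TEMPLATE_SRC fallback above renders for a single user message "Hello" with add_generation_prompt = true, assuming minja honors the {%- -%} whitespace trimming as written:

    <|im_start|>user
    Hello<|im_end|>
    <|im_start|>assistant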
@@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
 
         json_error_locator() : position(0), found_error(false) {}
 
-        bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
+        bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
            this->position = position - 1;
            this->found_error = true;
            return false;
        }
-        bool null() override { return true; }
+        bool null() override { return true; } // NOLINT
-        bool boolean(bool) override { return true; }
+        bool boolean(bool) override { return true; } // NOLINT
-        bool number_integer(number_integer_t) override { return true; }
+        bool number_integer(number_integer_t) override { return true; } // NOLINT
-        bool number_unsigned(number_unsigned_t) override { return true; }
+        bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
-        bool number_float(number_float_t, const string_t &) override { return true; }
+        bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
-        bool string(string_t &) override { return true; }
+        bool string(string_t &) override { return true; } // NOLINT
-        bool binary(binary_t &) override { return true; }
+        bool binary(binary_t &) override { return true; } // NOLINT
-        bool start_object(std::size_t) override { return true; }
+        bool start_object(std::size_t) override { return true; } // NOLINT
-        bool key(string_t &) override { return true; }
+        bool key(string_t &) override { return true; } // NOLINT
         bool end_object() override { return true; }
-        bool start_array(std::size_t) override { return true; }
+        bool start_array(std::size_t) override { return true; } // NOLINT
         bool end_array() override { return true; }
     };
     json_error_locator err_loc;
@@ -187,13 +612,20 @@ static std::string apply(
     // tmpl_inputs.now = std::chrono::system_clock::now();
 
     minja::chat_template_options tmpl_opts;
-    tmpl_opts.use_bos_token = false;
-    tmpl_opts.use_eos_token = false;
-    return tmpl.apply(tmpl_inputs, tmpl_opts);
+    // To avoid double BOS / EOS tokens, we're manually removing beginning / trailing tokens
+    // instead of using `chat_template_options.use_bos_token = false`, since these tokens
+    // may be needed inside the template / between messages too.
+    auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+    if (string_starts_with(result, tmpl.bos_token())) {
+        result = result.substr(tmpl.bos_token().size());
+    }
+    if (string_ends_with(result, tmpl.eos_token())) {
+        result = result.substr(0, result.size() - tmpl.eos_token().size());
+    }
+    return result;
 }
 
-static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
 
     auto tool_call_schemas = json::array();
@@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
         {"required", json::array({"tool_call"})},
     };
     const auto schema =
-        inputs.tool_choice != "required"
+        inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
         ? json {
             {"anyOf", json::array({
                 tool_call,
@@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
     return result;
 }
 
-static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
     return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
 }
 
-static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         auto schemas = json::array();
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     const auto & parameters_required = parameters.at("required");
     for (const auto & prop : expected_properties) {
         if (!parameters_properties.contains(prop)) {
-            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
+            throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
         }
         if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
-            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
+            throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
         }
     }
     if (parameters_properties.size() != expected_properties.size()) {
@@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
     }
 }
 
-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
     auto builtin_tools = json::array();
     common_chat_params data;
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
 
         auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
-            if (name == "wolfram_alpha") {
+            if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
                 // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
-                expect_tool_parameters(name, parameters, {"query"});
-            } else if (name == "web_search" || name == "brave_search") {
                 // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
                 expect_tool_parameters(name, parameters, {"query"});
             } else if (name == "python" || name == "code_interpreter") {
@@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
 
             std::vector<std::string> kvs;
             for (const auto & [key, value] : parameters.at("properties").items()) {
-                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
+                kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
             }
 
             tool_rules.push_back(
@@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
             auto arg_value_str = raw_args.substr(it_eq + 1);
             auto arg_value = json::parse(arg_value_str);
 
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ match.prefix().str(),
-                /* .tool_calls = */ {
-                    {
-                        /* .name = */ match[1],
-                        /* .arguments = */ (json {
-                            {arg_name, arg_value},
-                        }).dump(),
-                        /* .id = */ "",
-                    },
-                },
-            };
+            common_chat_msg msg;
+            msg.role = "assistant";
+            msg.content = match.prefix().str();
+            msg.tool_calls.push_back({
+                /* .name = */ name,
+                /* .arguments = */ (json {
+                    {arg_name, arg_value},
+                }).dump(),
+                /* .id = */ "",
+            });
+            return msg;
         }
     }
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
 
-static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> tool_rules;
             foreach_function(inputs.tools, [&](const json & tool) {
                 const auto & function = tool.at("function");
                 std::string name = function.at("name");
                 auto parameters = function.at("parameters");
+                builder.resolve_refs(parameters);
                 auto args_rule = builder.add_schema(name + "-args", parameters);
                 tool_rules.push_back(builder.add_rule(name + "-call",
                     "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n"
@@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
     return msg;
 }
 
-static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
-    fprintf(stderr, "%s\n", __func__);
+static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    LOG_DBG("%s\n", __func__);
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
         {"datetime", "Jan 29 2025 13:00:00 GMT"},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     });
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             auto schemas = json::array();
             foreach_function(inputs.tools, [&](const json & tool) {
@@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
     return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
 }
 
-static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
     // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
-        data.grammar_lazy = inputs.tool_choice != "required";
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
             std::vector<std::string> first_tool_rules;
             std::vector<std::string> subsequent_tool_rules;
@@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
             const auto & function = tool.at("function");
             std::string name = function.at("name");
             auto parameters = function.at("parameters");
+            builder.resolve_refs(parameters);
             auto args_rule = builder.add_schema(name + "-args", parameters);
             first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
             subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
     }
 }
 
-static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
     // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
     common_chat_params data;
     json tools = inputs.tools.is_null() ? inputs.tools : json::array();
     std::string python_code_argument_name;
     auto has_raw_python = false;
 
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
                 throw std::runtime_error("Missing type in python tool");
             }
             has_raw_python = true;
-            auto type = parameters.at("type");
+            const auto & type = parameters.at("type");
             if (type == "object") {
                 auto properties = parameters.at("properties");
                 for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
     std::smatch match;
     if (std::regex_search(input, match, python_tag_regex)) {
         auto code = match[1].str();
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ match.prefix().str(),
-            /* .tool_calls = */ {
-                {
-                    /* .name = */ "python",
-                    /* .arguments = */ (json {{"code", code}}).dump(),
-                    /* .id = */ "",
-                },
-            },
-        };
+        common_chat_msg msg;
+        msg.role = "assistant";
+        msg.content = match.prefix().str();
+        msg.tool_calls.push_back({
+            /* .name = */ "python",
+            /* .arguments = */ (json {{"code", code}}).dump(),
+            /* .id = */ "",
+        });
+        return msg;
     }
     static std::regex function_regex(R"(<function=(\w+)>)");
     static std::regex close_regex(R"(</function>)");
@@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
     return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
 }
 
-static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
-    data.grammar_lazy = inputs.tool_choice != "required";
+    data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
     data.grammar = build_grammar([&](const common_grammar_builder & builder) {
         std::vector<std::string> tool_rules;
         foreach_function(inputs.tools, [&](const json & tool) {
@@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
         std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
         std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
 
+        common_chat_msg msg;
+        msg.role = "assistant";
+
         auto end = input.end();
         std::sregex_iterator rend;
         std::sregex_iterator rit(input.begin(), end, start_pattern);
         if (rit == rend) {
-            return {
-                /* .role = */ "assistant",
-                /* .content = */ input,
-                /* .tool_calls = */ {},
-            };
+            msg.content = input;
+            return msg;
         }
 
-        common_chat_msg result;
-        result.role = "assistant";
-        result.content = rit->prefix();
+        msg.content = rit->prefix();
 
         auto it = rit->suffix().first;
         while (it != end) {
@@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
                 throw std::runtime_error("Failed to parse json tool call");
             }
             const auto & arguments = call.at("arguments");
-            result.tool_calls.push_back({
+            msg.tool_calls.push_back({
                 call.at("name"),
                 arguments.dump(),
                 // arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
@@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
                 break;
             }
         }
-        return result;
+        return msg;
     } catch (const std::exception & e) {
-        return {
-            /* .role = */ "assistant",
-            /* .content = */ input,
-            /* .tool_calls = */ {},
-        };
+        LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+        common_chat_msg msg;
+        msg.role = "assistant";
+        msg.content = input;
+        return msg;
     }
 }
 
-static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
     data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
|
|||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
|
static common_chat_params common_chat_templates_apply_jinja(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
const struct common_chat_templates_inputs & inputs)
|
||||||
|
{
|
||||||
|
templates_params params;
|
||||||
|
params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
|
||||||
|
const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
|
||||||
|
? *tmpls->template_tool_use
|
||||||
|
: *tmpls->template_default;
|
||||||
const auto & src = tmpl.source();
|
const auto & src = tmpl.source();
|
||||||
const auto & caps = tmpl.original_caps();
|
const auto & caps = tmpl.original_caps();
|
||||||
|
params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
|
||||||
|
params.add_generation_prompt = inputs.add_generation_prompt;
|
||||||
|
params.extract_reasoning = inputs.extract_reasoning;
|
||||||
|
params.tool_choice = inputs.tool_choice;
|
||||||
|
params.grammar = inputs.grammar;
|
||||||
|
if (!inputs.json_schema.empty()) {
|
||||||
|
params.json_schema = json::parse(inputs.json_schema);
|
||||||
|
}
|
||||||
|
|
||||||
if (inputs.tools.is_array()) {
|
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
|
||||||
if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
|
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
|
||||||
|
params.parallel_tool_calls = false;
|
||||||
|
} else {
|
||||||
|
params.parallel_tool_calls = inputs.parallel_tool_calls;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.tools.is_array()) {
|
||||||
|
if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
|
||||||
throw std::runtime_error("Cannot specify grammar with tools");
|
throw std::runtime_error("Cannot specify grammar with tools");
|
||||||
}
|
}
|
||||||
if (caps.supports_tool_calls && !caps.supports_tools) {
|
if (caps.supports_tool_calls && !caps.supports_tools) {
|
||||||
@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
|
||||||
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && inputs.json_schema.is_null()) {
|
if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) {
|
||||||
return common_chat_params_init_deepseek_r1(tmpl, inputs);
|
return common_chat_params_init_deepseek_r1(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
// Command R7B: : use handler in all cases except json schema (thinking / tools).
|
||||||
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
|
if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
|
||||||
return common_chat_params_init_command_r7b(tmpl, inputs);
|
return common_chat_params_init_command_r7b(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use generic handler when mixing tools + JSON schema.
|
// Use generic handler when mixing tools + JSON schema.
|
||||||
// TODO: support that mix in handlers below.
|
// TODO: support that mix in handlers below.
|
||||||
if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
|
if ((params.tools.is_array() && params.json_schema.is_object())) {
|
||||||
return common_chat_params_init_generic(tmpl, inputs);
|
return common_chat_params_init_generic(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
|
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
|
||||||
if (src.find(">>>all") != std::string::npos) {
|
if (src.find(">>>all") != std::string::npos) {
|
||||||
return common_chat_params_init_functionary_v3_2(tmpl, inputs);
|
return common_chat_params_init_functionary_v3_2(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
|
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
|
||||||
if (src.find(" functools[") != std::string::npos) {
|
if (src.find(" functools[") != std::string::npos) {
|
||||||
return common_chat_params_init_firefunction_v2(tmpl, inputs);
|
return common_chat_params_init_firefunction_v2(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Plain handler (no tools)
|
// Plain handler (no tools)
|
||||||
if (inputs.tools.is_null() || inputs.tool_choice == "none") {
|
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
|
||||||
return common_chat_params_init_without_tools(tmpl, inputs);
|
return common_chat_params_init_without_tools(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||||
if (src.find("<tool_call>") != std::string::npos) {
|
if (src.find("<tool_call>") != std::string::npos) {
|
||||||
return common_chat_params_init_hermes_2_pro(tmpl, inputs);
|
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Functionary v3.1 (w/ tools)
|
// Functionary v3.1 (w/ tools)
|
||||||
if (src.find("<|start_header_id|>") != std::string::npos
|
if (src.find("<|start_header_id|>") != std::string::npos
|
||||||
&& src.find("<function=") != std::string::npos) {
|
&& src.find("<function=") != std::string::npos) {
|
||||||
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
|
return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
// Llama 3.1, 3.2, 3.3 (w/ tools)
|
||||||
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||||
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
|
||||||
return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
|
return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mistral Nemo (w/ tools)
|
// Mistral Nemo (w/ tools)
|
||||||
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
if (src.find("[TOOL_CALLS]") != std::string::npos) {
|
||||||
return common_chat_params_init_mistral_nemo(tmpl, inputs);
|
return common_chat_params_init_mistral_nemo(tmpl, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generic fallback
|
// Generic fallback
|
||||||
return common_chat_params_init_generic(tmpl, inputs);
|
return common_chat_params_init_generic(tmpl, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
|
||||||
|
static common_chat_params common_chat_templates_apply_legacy(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
const struct common_chat_templates_inputs & inputs)
|
||||||
|
{
|
||||||
|
int alloc_size = 0;
|
||||||
|
std::vector<llama_chat_message> chat;
|
||||||
|
std::vector<std::string> contents;
|
||||||
|
for (const auto & msg : inputs.messages) {
|
||||||
|
auto content = msg.content;
|
||||||
|
for (const auto & part : msg.content_parts) {
|
||||||
|
if (part.type != "text") {
|
||||||
|
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!content.empty()) {
|
||||||
|
content += "\n";;
|
||||||
|
}
|
||||||
|
content += part.text;
|
||||||
|
}
|
||||||
|
contents.emplace_back(std::move(content));
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < contents.size(); ++i) {
|
||||||
|
const auto & msg = inputs.messages[i];
|
||||||
|
const auto & content = contents[i];
|
||||||
|
chat.push_back({msg.role.c_str(), content.c_str()});
|
||||||
|
alloc_size += (msg.role.size() + content.size()) * 1.25;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<char> buf(alloc_size);
|
||||||
|
|
||||||
|
// run the first time to get the total output length
|
||||||
|
const auto & src = tmpls->template_default->source();
|
||||||
|
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||||
|
|
||||||
|
// error: chat template is not supported
|
||||||
|
if (res < 0) {
|
||||||
|
// if the custom "tmpl" is not supported, we throw an error
|
||||||
|
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
|
||||||
|
throw std::runtime_error("this custom template is not supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
// if it turns out that our buffer is too small, we resize it
|
||||||
|
if ((size_t) res > buf.size()) {
|
||||||
|
buf.resize(res);
|
||||||
|
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_params params;
|
||||||
|
params.prompt = std::string(buf.data(), res);
|
||||||
|
if (!inputs.json_schema.empty()) {
|
||||||
|
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
|
||||||
|
} else {
|
||||||
|
params.grammar = inputs.grammar;
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
common_chat_params common_chat_templates_apply(
|
||||||
|
const struct common_chat_templates * tmpls,
|
||||||
|
const struct common_chat_templates_inputs & inputs)
|
||||||
|
{
|
||||||
|
GGML_ASSERT(tmpls != nullptr);
|
||||||
|
return inputs.use_jinja
|
||||||
|
? common_chat_templates_apply_jinja(tmpls, inputs)
|
||||||
|
: common_chat_templates_apply_legacy(tmpls, inputs);
|
||||||
}
|
}
|
||||||
|
|
||||||
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
|
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
|
||||||
return {
|
common_chat_msg msg;
|
||||||
/* .role = */ "assistant",
|
msg.role = "assistant";
|
||||||
/* .content = */ input,
|
msg.content = input;
|
||||||
/* .tool_calls = */ {},
|
return msg;
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {
|
||||||
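The legacy route above uses llama.cpp's usual two-pass pattern with llama_chat_apply_template: one call to learn the required output length, then resize and call again. A minimal standalone sketch of that idiom (not part of the patch; it assumes linking against llama.cpp, and that "chatml" is accepted as a built-in template name by the legacy detector):

    #include "llama.h"

    #include <stdexcept>
    #include <string>
    #include <vector>

    static std::string render_chatml(const std::vector<llama_chat_message> & chat) {
        std::vector<char> buf(1024);
        // first pass: returns the required length, which may exceed the buffer
        int32_t res = llama_chat_apply_template("chatml", chat.data(), chat.size(),
                                                /* add_ass= */ true, buf.data(), buf.size());
        if (res < 0) {
            throw std::runtime_error("chat template not supported by the legacy route");
        }
        if ((size_t) res > buf.size()) {
            buf.resize(res);
            // second pass: the output is now guaranteed to fit
            res = llama_chat_apply_template("chatml", chat.data(), chat.size(),
                                            /* add_ass= */ true, buf.data(), buf.size());
        }
        return std::string(buf.data(), res);
    }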
common/chat.h (new file, 134 lines)
@ -0,0 +1,134 @@
+// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
+
+#pragma once
+
+#include "common.h"
+#include <string>
+#include <vector>
+
+struct common_chat_templates;
+
+struct common_chat_tool_call {
+    std::string name;
+    std::string arguments;
+    std::string id;
+};
+
+struct common_chat_msg_content_part {
+    std::string type;
+    std::string text;
+};
+
+struct common_chat_msg {
+    std::string role;
+    std::string content;
+    std::vector<common_chat_msg_content_part> content_parts = {};
+    std::vector<common_chat_tool_call> tool_calls = {};
+    std::string reasoning_content;
+    std::string tool_name;
+    std::string tool_call_id;
+};
+
+struct common_chat_tool {
+    std::string name;
+    std::string description;
+    std::string parameters;
+};
+
+enum common_chat_tool_choice {
+    COMMON_CHAT_TOOL_CHOICE_AUTO,
+    COMMON_CHAT_TOOL_CHOICE_REQUIRED,
+    COMMON_CHAT_TOOL_CHOICE_NONE,
+};
+
+enum common_chat_format {
+    COMMON_CHAT_FORMAT_CONTENT_ONLY,
+    COMMON_CHAT_FORMAT_GENERIC,
+    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+    COMMON_CHAT_FORMAT_LLAMA_3_X,
+    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
+    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
+    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
+    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
+    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
+    COMMON_CHAT_FORMAT_HERMES_2_PRO,
+    COMMON_CHAT_FORMAT_COMMAND_R7B,
+    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
+
+    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
+};
+
+struct common_chat_templates_inputs {
+    std::vector<common_chat_msg> messages;
+    std::string grammar;
+    std::string json_schema;
+    bool add_generation_prompt = true;
+    bool use_jinja = true;
+    // Parameters below only supported when use_jinja is true
+    std::vector<common_chat_tool> tools;
+    common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
+    bool parallel_tool_calls = false;
+    bool extract_reasoning = true;
+};
+
+struct common_chat_params {
+    common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    std::string prompt;
+    std::string grammar;
+    bool grammar_lazy = false;
+    std::vector<common_grammar_trigger> grammar_triggers;
+    std::vector<std::string> preserved_tokens;
+    std::vector<std::string> additional_stops;
+};
+
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
+
+void common_chat_templates_free(struct common_chat_templates * tmpls);
+
+struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
+
+typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
+
+common_chat_templates_ptr common_chat_templates_init(
+    const struct llama_model * model,
+    const std::string & chat_template_override,
+    const std::string & bos_token_override = "",
+    const std::string & eos_token_override = "");
+
+bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
+const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
+
+struct common_chat_params common_chat_templates_apply(
+    const struct common_chat_templates * tmpls,
+    const struct common_chat_templates_inputs & inputs);
+
+// Format single message, while taking into account the position of that message in chat history
+std::string common_chat_format_single(
+    const struct common_chat_templates * tmpls,
+    const std::vector<common_chat_msg> & past_msg,
+    const common_chat_msg & new_msg,
+    bool add_ass,
+    bool use_jinja);
+
+// Returns an example of formatted chat
+std::string common_chat_format_example(
+    const struct common_chat_templates * tmpls,
+    bool use_jinja);
+
+std::string common_chat_format_name(common_chat_format format);
+common_chat_msg common_chat_parse(const std::string & input, common_chat_format format);
+
+common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
+
+// Parses a JSON array of messages in OpenAI's chat completion API format.
+// T can be std::string containing JSON or nlohmann::ordered_json
+template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
+template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
+
+// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
+// T can be std::string containing JSON or nlohmann::ordered_json
+template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
+template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
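For orientation, a rough end-to-end sketch of the new API (not part of the patch; it assumes a loaded llama_model * model, and run_inference is a placeholder for the real decode loop):

    #include "chat.h"

    std::string run_inference(const std::string & prompt); // placeholder, not a real API

    common_chat_msg chat_once(const llama_model * model, const std::string & user_text) {
        common_chat_templates_ptr tmpls = common_chat_templates_init(model, /* chat_template_override= */ "");

        common_chat_templates_inputs inputs;
        common_chat_msg msg;
        msg.role    = "user";
        msg.content = user_text;
        inputs.messages.push_back(msg);

        common_chat_params cparams = common_chat_templates_apply(tmpls.get(), inputs);
        // cparams.prompt is what gets tokenized; cparams.grammar (if non-empty)
        // constrains sampling; cparams.format tells the parser what to expect.
        std::string output = run_inference(cparams.prompt);
        return common_chat_parse(output, cparams.format);
    }

Note that the handle is a unique_ptr with a custom deleter, so the opaque common_chat_templates is freed automatically when tmpls goes out of scope.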
common/chat.hpp (deleted)
@ -1,55 +0,0 @@
-// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
-
-#pragma once
-
-#include "common.h"
-#include <json.hpp>
-#include <optional>
-#include <string>
-#include <vector>
-
-using json = nlohmann::ordered_json;
-
-struct common_chat_inputs {
-    json messages;
-    json tools;
-    json tool_choice;
-    json json_schema;
-    bool parallel_tool_calls;
-    bool stream;
-    std::string grammar;
-    bool add_generation_prompt = true;
-    bool extract_reasoning = true;
-};
-
-enum common_chat_format {
-    COMMON_CHAT_FORMAT_CONTENT_ONLY,
-    COMMON_CHAT_FORMAT_GENERIC,
-    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
-    COMMON_CHAT_FORMAT_LLAMA_3_X,
-    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
-    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
-    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
-    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-    COMMON_CHAT_FORMAT_HERMES_2_PRO,
-    COMMON_CHAT_FORMAT_COMMAND_R7B,
-    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
-
-    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
-};
-
-struct common_chat_params {
-    common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    json prompt;
-    std::string grammar;
-    bool grammar_lazy = false;
-    std::vector<common_grammar_trigger> grammar_triggers;
-    std::vector<std::string> preserved_tokens;
-    std::vector<std::string> additional_stops;
-};
-
-struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
-std::string common_chat_format_name(common_chat_format format);
-common_chat_msg common_chat_parse(const std::string & input, common_chat_format format);
common/common.cpp
@ -12,8 +12,6 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "chat.hpp"
-#include "chat-template.hpp"

 #include <algorithm>
 #include <cinttypes>
@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
     return text;
 }

-//
-// Chat template utils
-//
-
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
-    if (use_jinja) {
-        try {
-            auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            common_chat_params_init(chat_template, inputs);
-            return true;
-        } catch (const std::exception & e) {
-            LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
-            return false;
-        }
-    }
-    llama_chat_message chat[] = {{"user", "test"}};
-    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
-    return res >= 0;
-}
-
-std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
-    const std::vector<common_chat_msg> & msgs,
-    bool add_ass,
-    bool use_jinja) {
-    if (use_jinja) {
-        auto messages = json::array();
-        for (const auto & msg : msgs) {
-            messages.push_back({{"role", msg.role}, {"content", msg.content}});
-        }
-        common_chat_inputs inputs;
-        inputs.messages = messages;
-        inputs.add_generation_prompt = add_ass;
-        return common_chat_params_init(tmpl, inputs).prompt;
-    }
-
-    int alloc_size = 0;
-    std::vector<llama_chat_message> chat;
-    for (const auto & msg : msgs) {
-        chat.push_back({msg.role.c_str(), msg.content.c_str()});
-        alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
-    }
-
-    std::vector<char> buf(alloc_size);
-
-    // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-
-    // error: chat template is not supported
-    if (res < 0) {
-        // if the custom "tmpl" is not supported, we throw an error
-        // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
-        throw std::runtime_error("this custom template is not supported");
-    }
-
-    // if it turns out that our buffer is too small, we resize it
-    if ((size_t) res > buf.size()) {
-        buf.resize(res);
-        res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-    }
-
-    std::string formatted_chat(buf.data(), res);
-    return formatted_chat;
-}
-
-std::string common_chat_format_single(
-    const common_chat_template & tmpl,
-    const std::vector<common_chat_msg> & past_msg,
-    const common_chat_msg & new_msg,
-    bool add_ass,
-    bool use_jinja) {
-    std::ostringstream ss;
-    auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
-    std::vector<common_chat_msg> chat_new(past_msg);
-    // if the past_msg ends with a newline, we must preserve it in the formatted version
-    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
-        ss << "\n";
-    };
-    // format chat with new_msg
-    chat_new.push_back(new_msg);
-    auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
-    // get the diff part
-    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
-    return ss.str();
-}
-
-std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
-    std::vector<common_chat_msg> msgs = {
-        {"system", "You are a helpful assistant", {}},
-        {"user", "Hello", {}},
-        {"assistant", "Hi there", {}},
-        {"user", "How are you?", {}},
-    };
-    return common_chat_apply_template(tmpl, msgs, true, use_jinja);
-}
-
-#define CHATML_TEMPLATE_SRC \
-    "{%- for message in messages -%}\n" \
-    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
-    "{%- endfor -%}\n" \
-    "{%- if add_generation_prompt -%}\n" \
-    "  {{- '<|im_start|>assistant\n' -}}\n" \
-    "{%- endif -%}"
-
-common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
-{
-    std::string default_template_src;
-    std::string template_tool_use_src;
-
-    bool has_explicit_template = !chat_template_override.empty();
-    if (chat_template_override.empty()) {
-        auto str = llama_model_chat_template(model, /* name */ nullptr);
-        if (str) {
-            default_template_src = str;
-            has_explicit_template = true;
-        }
-        str = llama_model_chat_template(model, /* name */ "tool_use");
-        if (str) {
-            template_tool_use_src = str;
-            has_explicit_template = true;
-        }
-    } else {
-        default_template_src = chat_template_override;
-    }
-    if (default_template_src.empty() || default_template_src == "chatml") {
-        if (!template_tool_use_src.empty()) {
-            default_template_src = template_tool_use_src;
-        } else {
-            default_template_src = CHATML_TEMPLATE_SRC;
-        }
-    }
-    auto vocab = llama_model_get_vocab(model);
-    const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
-        if (token == LLAMA_TOKEN_NULL) {
-            if (default_template_src.find(jinja_variable_name) != std::string::npos
-                || template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
-                LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
-            }
-            return std::string();
-        } else {
-            return common_token_to_piece(vocab, token, true);
-        }
-    };
-    auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
-    auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
-    try {
-        return {
-            has_explicit_template,
-            std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
-            template_tool_use_src.empty()
-                ? nullptr
-                : std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
-        };
-    } catch (const std::exception & e) {
-        LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
-        return {
-            has_explicit_template,
-            std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
-            nullptr,
-        };
-    }
-}
-
 //
 // KV cache utils
 //
common/common.h
@ -616,62 +616,6 @@ std::string common_detokenize(
     const std::vector<llama_token> & tokens,
     bool special = true);

-//
-// Chat template utils
-//
-
-struct common_tool_call {
-    std::string name;
-    std::string arguments;
-    std::string id;
-};
-
-// same with llama_chat_message, but uses std::string
-struct common_chat_msg {
-    std::string role;
-    std::string content;
-    std::vector<common_tool_call> tool_calls;
-    std::string reasoning_content = "";
-};
-
-// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
-bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
-
-namespace minja {
-    class chat_template;
-}
-
-typedef minja::chat_template common_chat_template;
-
-struct common_chat_templates {
-    bool has_explicit_template; // Model had builtin template or template override was specified.
-    std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
-    std::unique_ptr<common_chat_template> template_tool_use;
-};
-
-// CPP wrapper for llama_chat_apply_template
-// If the built-in template is not supported, we default to chatml
-// If the custom "tmpl" is not supported, we throw an error
-std::string common_chat_apply_template(
-    const common_chat_template & tmpl,
-    const std::vector<common_chat_msg> & chat,
-    bool add_ass,
-    bool use_jinja);
-
-// Format single message, while taking into account the position of that message in chat history
-std::string common_chat_format_single(
-    const common_chat_template & tmpl,
-    const std::vector<common_chat_msg> & past_msg,
-    const common_chat_msg & new_msg,
-    bool add_ass,
-    bool use_jinja);
-
-// Returns an example of formatted chat
-std::string common_chat_format_example(
-    const common_chat_template & tmpl, bool use_jinja);
-
-common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
-
 //
 // KV cache utils
 //
examples/main/main.cpp
@ -4,7 +4,7 @@
 #include "log.h"
 #include "sampling.h"
 #include "llama.h"
-#include "chat-template.hpp"
+#include "chat.h"

 #include <cstdio>
 #include <cstring>
@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
     }

     const llama_vocab * vocab = llama_model_get_vocab(model);
-    auto chat_templates = common_chat_templates_from_model(model, params.chat_template);
+    auto chat_templates = common_chat_templates_init(model, params.chat_template);

     LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);

@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
     }

     // auto enable conversation mode if chat template is available
-    const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default;
+    const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
     if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
         if (has_chat_template) {
             LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str());
+            LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
         } else {
             LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
         }
@ -264,9 +264,11 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;

     auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
-        common_chat_msg new_msg{role, content, {}};
-        auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja);
-        chat_msgs.push_back({role, content, {}});
+        common_chat_msg new_msg;
+        new_msg.role = role;
+        new_msg.content = content;
+        auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
+        chat_msgs.push_back(new_msg);
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
         return formatted;
     };
@ -755,11 +757,14 @@ int main(int argc, char ** argv) {

             // check for reverse prompt using special tokens
             llama_token last_token = common_sampler_last(smpl);
-            if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) {
+            for (auto token : antiprompt_token) {
+                if (token == last_token) {
                     if (params.interactive) {
                         is_interacting = true;
                     }
                     is_antiprompt = true;
+                    break;
+                }
             }

             if (is_antiprompt) {
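The chat_add_and_format lambda above leans on common_chat_format_single, which returns only the suffix a new message adds to the formatted transcript, so previously evaluated tokens stay valid. A hedged sketch of the same pattern in isolation (tmpls as returned by common_chat_templates_init; history is the caller's running message list):

    static std::string append_user_turn(const common_chat_templates * tmpls,
                                        std::vector<common_chat_msg> & history,
                                        const std::string & text) {
        common_chat_msg msg;
        msg.role    = "user";
        msg.content = text;
        // formats history and history+msg, and returns just the difference
        std::string delta = common_chat_format_single(tmpls, history, msg,
                                                      /* add_ass= */ true, /* use_jinja= */ false);
        history.push_back(msg);
        return delta;
    }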
examples/run/run.cpp
@ -24,7 +24,7 @@
 #include <string>
 #include <vector>

-#include "chat-template.hpp"
+#include "chat.h"
 #include "common.h"
 #include "json.hpp"
 #include "linenoise.cpp/linenoise.h"
@ -557,7 +557,7 @@ class LlamaData {
     llama_model_ptr   model;
     llama_sampler_ptr sampler;
     llama_context_ptr context;
-    std::vector<llama_chat_message> messages;
+    std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
     std::list<std::string> msg_strs;
     std::vector<char> fmtted;

@ -834,44 +834,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }

 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
-    if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
-    }
-    int result = llama_chat_apply_template(
-        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
-        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
-    if (append && result > static_cast<int>(llama_data.fmtted.size())) {
-        llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
-                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
-                                           llama_data.fmtted.size());
-    }
-
-    return result;
+static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    for (const auto & msg : llama_data.messages) {
+        common_chat_msg cmsg;
+        cmsg.role    = msg.role;
+        cmsg.content = msg.content;
+        inputs.messages.push_back(cmsg);
+    }
+    inputs.add_generation_prompt = append;
+    inputs.use_jinja = use_jinja;
+
+    auto chat_params = common_chat_templates_apply(tmpls, inputs);
+    // TODO: use other params for tool calls.
+    auto result = chat_params.prompt;
+    llama_data.fmtted.resize(result.size() + 1);
+    memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+    return result.size();
 }

 // Function to tokenize the prompt
@ -1015,8 +994,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }

 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
-    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
+static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+    const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
         return -1;
@ -1078,8 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) {
 static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
-    GGML_ASSERT(chat_templates.template_default);
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@ -1090,7 +1068,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_

         add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
             return 1;
         }

@ -1105,7 +1083,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
         }

         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
             return 1;
         }
     }
examples/server/server.cpp
@ -329,9 +329,6 @@ struct server_task {
         }

         // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
-            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
-        }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema = json_value(data, "json_schema", json::object());
@ -1807,7 +1804,7 @@ struct server_context {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;

-    common_chat_templates chat_templates;
+    common_chat_templates_ptr chat_templates;

     ~server_context() {
         // Clear any sampling context
@ -1891,43 +1888,15 @@ struct server_context {
             llama_init_dft.context.reset();
         }

-        if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
-            SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
-            chat_templates = common_chat_templates_from_model(model, "chatml");
-        } else {
-            chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
-        }
-        GGML_ASSERT(chat_templates.template_default.get() != nullptr);
-
-        return true;
-    }
-
-    bool validate_builtin_chat_template(bool use_jinja) const {
-        llama_chat_message chat[] = {{"user", "test"}};
-
-        if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
-        } else {
-            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
-            const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
-            return chat_res > 0;
-        }
+        chat_templates = common_chat_templates_init(model, params_base.chat_template);
+        try {
+            common_chat_format_example(chat_templates.get(), params.use_jinja);
+        } catch (const std::exception & e) {
+            SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
+            chat_templates = common_chat_templates_init(model, "chatml");
+        }
+
+        return true;
     }

     void init() {
@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model },
-            { "chat_template",               ctx_server.chat_templates.template_default->source() },
-            { "bos_token",                   ctx_server.chat_templates.template_default->bos_token() },
-            { "eos_token",                   ctx_server.chat_templates.template_default->eos_token() },
+            { "chat_template",               common_chat_templates_source(ctx_server.chat_templates.get()) },
+            { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
+            { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
            { "build_info",                  build_info },
         };
-        if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
-            data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
+        if (ctx_server.params_base.use_jinja) {
+            if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
+                data["chat_template_tool_use"] = tool_use_src;
+            }
         }

         res_ok(res, data);
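With the template sources now behind an opaque handle, probing for the optional tool-use variant reduces to a null check on common_chat_templates_source. A small sketch (hypothetical caller, assuming a loaded llama_model * model; names as declared in common/chat.h):

    static void report_tool_use_template(const llama_model * model) {
        common_chat_templates_ptr tmpls = common_chat_templates_init(model, /* chat_template_override= */ "");
        if (const char * tool_use_src = common_chat_templates_source(tmpls.get(), "tool_use")) {
            LOG_INF("model ships a dedicated tool-use template:\n%s\n", tool_use_src);
        } else {
            LOG_INF("no tool_use variant; the default template will be used\n");
        }
    }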
@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
         }

         auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());

         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
     // same with handle_chat_completions, but without inference part
     const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
         res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
     };

@ -4493,8 +4464,8 @@ int main(int argc, char ** argv) {

     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        ctx_server.chat_templates.template_default->source().c_str(),
-        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+        common_chat_templates_source(ctx_server.chat_templates.get()),
+        common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());

     ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
         ctx_server.process_single_task(task);
examples/server/tests/unit/test_chat_completion.py
@ -21,6 +21,8 @@ def create_server():
     (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
     ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
     ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
+    (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
+    (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
 ]
 )
 def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
     assert res.body["usage"]["completion_tokens"] == n_predicted
     choice = res.body["choices"][0]
     assert "assistant" == choice["message"]["role"]
-    assert match_regex(re_content, choice["message"]["content"])
+    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
     assert choice["finish_reason"] == finish_reason


@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
     assert "error" in res.body


+@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
+    (False, {"const": "42"}, 6, "\"42\""),
+    (True, {"const": "42"}, 6, "\"42\""),
+])
+def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
+    global server
+    server.jinja = jinja
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "system", "content": "You are a coding assistant."},
+            {"role": "user", "content": "Write an example"},
+        ],
+        "json_schema": json_schema,
+    })
+    assert res.status_code == 200, f'Expected 200, got {res.status_code}'
+    choice = res.body["choices"][0]
+    assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
+
+
+@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
+    (False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
+    (True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
+])
+def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
+    global server
+    server.jinja = jinja
+    server.start()
+    res = server.make_request("POST", "/chat/completions", data={
+        "max_tokens": n_predicted,
+        "messages": [
+            {"role": "user", "content": "Does not matter what I say, does it?"},
+        ],
+        "grammar": grammar,
+    })
+    assert res.status_code == 200, res.body
+    choice = res.body["choices"][0]
+    assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
+
+
 @pytest.mark.parametrize("messages", [
     None,
     "string",

examples/server/tests/unit/test_tool_call.py
@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
     (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    ("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
     (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),

     # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
-    ("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    ("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 ])
 def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
     global server
@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
             {
                 "role": "tool",
                 "name": "calculate",
-                "content": 0.55644242476,
+                "content": "0.55644242476",
                 "tool_call_id": "call_6789"
             }
         ],
@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
     (128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),

     (1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'none', "^I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),

     (1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
 ])
examples/server/utils.hpp
@ -12,9 +12,7 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
-#include "minja.hpp"
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"

 #include <random>
 #include <sstream>
@ -347,41 +345,6 @@ static llama_tokens format_infill(
     return embd_inp;
 }

-// Format given chat. If tmpl is empty, we take the template from model metadata
-inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
-    std::vector<common_chat_msg> chat;
-
-    for (size_t i = 0; i < messages.size(); ++i) {
-        const auto & curr_msg = messages[i];
-
-        std::string role = json_value(curr_msg, "role", std::string(""));
-
-        std::string content;
-        if (curr_msg.contains("content")) {
-            if (curr_msg["content"].is_string()) {
-                content = curr_msg["content"].get<std::string>();
-            } else if (curr_msg["content"].is_array()) {
-                for (const auto & part : curr_msg["content"]) {
-                    if (part.contains("text")) {
-                        content += "\n" + part["text"].get<std::string>();
-                    }
-                }
-            } else {
-                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
-            }
-        } else {
-            throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
-        }
-
-        chat.push_back({role, content, /* tool_calls= */ {}});
-    }
-
-    const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
-    LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
-
-    return formatted_chat;
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@ -579,12 +542,9 @@ static json oaicompat_completion_params_parse(
     const json & body, /* openai api json semantics */
     bool use_jinja,
     common_reasoning_format reasoning_format,
-    const common_chat_templates & chat_templates)
+    const struct common_chat_templates * tmpls)
 {
     json llama_params;
-    const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
-        ? *chat_templates.template_tool_use
-        : *chat_templates.template_default;

     auto tools = json_value(body, "tools", json());
     auto stream = json_value(body, "stream", false);
@ -610,43 +570,42 @@ static json oaicompat_completion_params_parse(
         llama_params["stop"] = json_value(body, "stop", json::array());
     }

+    auto json_schema = json_value(body, "json_schema", json());
+    auto grammar = json_value(body, "grammar", std::string());
+    if (!json_schema.is_null() && !grammar.empty()) {
+        throw std::runtime_error("Cannot use both json_schema and grammar");
+    }
+
     // Handle "response_format" field
     if (body.contains("response_format")) {
         json response_format      = json_value(body, "response_format", json::object());
         std::string response_type = json_value(response_format, "type", std::string());
         if (response_type == "json_object") {
-            llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+            json_schema = json_value(response_format, "schema", json::object());
         } else if (response_type == "json_schema") {
             json json_schema = json_value(response_format, "json_schema", json::object());
-            llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
+            json_schema = json_value(json_schema, "schema", json::object());
         } else if (!response_type.empty() && response_type != "text") {
             throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
         }
     }

-    // Apply chat template to the list of messages
-    if (use_jinja) {
-        auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
-        if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") {
-            throw std::runtime_error("Invalid tool_choice: " + tool_choice);
-        }
-        if (tool_choice != "none" && llama_params.contains("grammar")) {
-            throw std::runtime_error("Cannot use custom grammar constraints with tools.");
-        }
-        common_chat_inputs inputs;
-        inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
-        inputs.messages = body.at("messages");
-        inputs.tools = tools;
-        inputs.tool_choice = tool_choice;
-        inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
-        if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
-            LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
-            inputs.parallel_tool_calls = false;
-        }
-        inputs.stream = stream;
-        // TODO: support mixing schema w/ tools beyond generic format.
-        inputs.json_schema = json_value(llama_params, "json_schema", json());
-        auto chat_params = common_chat_params_init(tmpl, inputs);
+    common_chat_templates_inputs inputs;
+    inputs.messages              = common_chat_msgs_parse_oaicompat(body.at("messages"));
+    inputs.tools                 = common_chat_tools_parse_oaicompat(tools);
+    inputs.tool_choice           = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+    inputs.json_schema           = json_schema.is_null() ? "" : json_schema.dump();
+    inputs.grammar               = grammar;
+    inputs.add_generation_prompt = true;
+    inputs.use_jinja             = use_jinja;
+    inputs.parallel_tool_calls   = json_value(body, "parallel_tool_calls", false);
+    inputs.extract_reasoning     = reasoning_format != COMMON_REASONING_FORMAT_NONE;
+    if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
+        throw std::runtime_error("Cannot use custom grammar constraints with tools.");
+    }

-        llama_params["chat_format"] = static_cast<int>(chat_params.format);
-        llama_params["prompt"] = chat_params.prompt;
+    // Apply chat template to the list of messages
+    auto chat_params = common_chat_templates_apply(tmpls, inputs);
+
+    llama_params["chat_format"] = static_cast<int>(chat_params.format);
+    llama_params["prompt"] = chat_params.prompt;
@ -664,9 +623,6 @@ static json oaicompat_completion_params_parse(
         for (const auto & stop : chat_params.additional_stops) {
             llama_params["stop"].push_back(stop);
         }
-    } else {
-        llama_params["prompt"] = format_chat(tmpl, body.at("messages"));
-    }

     // Handle "n" field
     int n_choices = json_value(body, "n", 1);
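The rewritten parser above replaces raw json plumbing with the typed conversion helpers from common/chat.h. A rough sketch of those helpers in isolation (not from the patch; the request body shown is hypothetical):

    #include "chat.h"
    #include "json.hpp"

    using json = nlohmann::ordered_json;

    static void inspect_request(const std::string & body_str) {
        json body = json::parse(body_str);

        // OpenAI-style "messages" array -> typed common_chat_msg values
        auto msgs = common_chat_msgs_parse_oaicompat(body.at("messages"));

        // "auto" / "required" / "none" -> common_chat_tool_choice enum
        auto choice = common_chat_tool_choice_parse_oaicompat(body.value("tool_choice", std::string("auto")));

        // round-trip back to JSON, e.g. for logging
        json echoed = common_chat_msgs_to_json_oaicompat<json>(msgs);
        (void) choice;
        (void) echoed;
    }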
@ -1,13 +1,14 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
#undef NDEBUG
|
#undef NDEBUG
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "chat-template.hpp"
|
#include "chat.h"
|
||||||
|
|
||||||
static std::string normalize_newlines(const std::string & s) {
|
static std::string normalize_newlines(const std::string & s) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
@@ -18,6 +19,13 @@ static std::string normalize_newlines(const std::string & s) {
 #endif
 }

+static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
+    common_chat_msg msg;
+    msg.role = role;
+    msg.content = content;
+    return msg;
+}
+
 int main(void) {
     std::vector<llama_chat_message> conversation {
         {"system", "You are a helpful assistant"},
@@ -50,7 +58,7 @@ int main(void) {
         /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
         /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
         /* .expected_output_jinja= */ "",
-        /* .bos_token= */ "",
+        /* .bos_token= */ "<s>",
         /* .eos_token= */ "</s>",
     },
     {
@@ -73,7 +81,7 @@ int main(void) {
         /* .name= */ "mlabonne/AlphaMonarch-7B",
         /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
         /* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
-        /* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
+        /* .expected_output_jinja= */ "",
         /* .bos_token= */ "<s>",
         /* .eos_token= */ "</s>",
     },
@@ -87,7 +95,7 @@ int main(void) {
         /* .name= */ "OrionStarAI/Orion-14B-Chat",
         /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
         /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
-        /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
+        /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: ",
         /* .bos_token= */ "",
         /* .eos_token= */ "</s>",
     },
@@ -304,12 +312,9 @@ int main(void) {
         }
     }

-    json messages = json::array();
+    std::vector<common_chat_msg> messages;
     for (const auto & msg : conversation) {
-        messages.push_back({
-            {"role", msg.role},
-            {"content", msg.content},
-        });
+        messages.push_back(simple_msg(msg.role, msg.content));
     }
     for (const auto & test_case : test_cases) {
         if (!test_case.supported_with_jinja) {
@@ -317,8 +322,13 @@ int main(void) {
         }
         printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
         try {
-            minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token);
-            auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt));
+            auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
+            common_chat_templates_inputs inputs;
+            inputs.use_jinja = true;
+            inputs.messages = messages;
+            inputs.add_generation_prompt = add_generation_prompt;
+            auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
+            output = normalize_newlines(output);
             auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
             if (output != expected_output) {
                 printf("Expected:\n%s\n", expected_output.c_str());
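
The updated jinja expectations above (Mistral's bos_token now "<s>", AlphaMonarch and Orion outputs losing their leading "<s>" and trailing "</s>") imply that the jinja path drops a template-emitted leading BOS and trailing EOS, leaving the tokenizer to add them once, while inner occurrences survive. A hedged sketch of that behavior follows; the template is illustrative, not from the test suite.

// Hedged sketch: per the expected_output_jinja changes above, a template-emitted
// leading bos_token should be stripped from the rendered prompt, inner tokens kept.
#include "chat.h"

#include <cstdio>
#include <string>

int main() {
    // Illustrative template that emits bos/eos tokens itself.
    std::string tmpl_src =
        "{{ bos_token }}{% for m in messages %}{{ m['content'] + eos_token }}{% endfor %}ASSISTANT:";
    auto tmpls = common_chat_templates_init(
        /* model= */ nullptr, tmpl_src.c_str(), /* bos= */ "<s>", /* eos= */ "</s>");

    common_chat_templates_inputs inputs;
    inputs.use_jinja = true;
    common_chat_msg msg;
    msg.role    = "user";
    msg.content = "Hello";
    inputs.messages = { msg };
    inputs.add_generation_prompt = false;

    // The raw render is "<s>Hello</s>ASSISTANT:"; if the expectations above hold,
    // the prompt comes back as "Hello</s>ASSISTANT:" (leading "<s>" dropped,
    // inner "</s>" preserved).
    auto prompt = common_chat_templates_apply(tmpls.get(), inputs).prompt;
    printf("prompt: %s\n", prompt.c_str());
    return 0;
}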
@@ -336,11 +346,11 @@ int main(void) {
     // test llama_chat_format_single for system message
     printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
     std::vector<common_chat_msg> chat2;
-    common_chat_msg sys_msg{"system", "You are a helpful assistant", {}};
+    auto sys_msg = simple_msg("system", "You are a helpful assistant");

     auto fmt_sys = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
-        auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false);
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
+        auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
         printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
@@ -360,14 +370,14 @@ int main(void) {

     // test llama_chat_format_single for user message
     printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
-    chat2.push_back({"system", "You are a helpful assistant", {}});
-    chat2.push_back({"user", "Hello", {}});
-    chat2.push_back({"assistant", "I am assistant", {}});
-    common_chat_msg new_msg{"user", "How are you", {}};
+    chat2.push_back(simple_msg("system", "You are a helpful assistant"));
+    chat2.push_back(simple_msg("user", "Hello"));
+    chat2.push_back(simple_msg("assistant", "I am assistant"));
+    auto new_msg = simple_msg("user", "How are you");

-    auto fmt_single = [&](std::string tmpl_str) {
-        minja::chat_template tmpl(tmpl_str, "", "");
-        auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false);
+    auto fmt_single = [&](const std::string & tmpl_str) {
+        auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
+        auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
         printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
         printf("-------------------------\n");
         return output;
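
For the single-message path exercised here, common_chat_format_single returns only the text to append for one new message given the chat so far, optionally followed by the assistant prompt when add_ass is true. A hedged, self-contained sketch under the same assumptions as the earlier examples; the template is again illustrative.

// Hedged sketch of the single-message path tested above: given the past chat,
// return only the delta to append for one new message.
#include "chat.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    // Illustrative template (not from the test suite).
    std::string tmpl_src = "{% for m in messages %}[{{ m['role'] }}] {{ m['content'] }}\n{% endfor %}";
    auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_src.c_str());

    std::vector<common_chat_msg> chat;
    common_chat_msg past;
    past.role    = "user";
    past.content = "Hello";
    chat.push_back(past);

    common_chat_msg new_msg;
    new_msg.role    = "user";
    new_msg.content = "How are you";

    // add_ass=false: no assistant generation prompt appended after the delta.
    // With this template, the delta should be "[user] How are you\n".
    auto delta = common_chat_format_single(tmpls.get(), chat, new_msg, /* add_ass= */ false, /* use_jinja= */ true);
    printf("delta: %s\n", delta.c_str());
    return 0;
}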
File diff suppressed because it is too large