Merge branch 'master' into gg/llama-kv-cache

ggml-ci
Author: Georgi Gerganov
Date:   2025-02-20 14:26:43 +02:00
34 changed files with 2103 additions and 1218 deletions

.gitignore

@@ -98,6 +98,7 @@ examples/server/*.css.hpp
examples/server/*.html.hpp
examples/server/*.js.hpp
examples/server/*.mjs.hpp
+examples/server/*.gz.hpp
!build_64.sh
!examples/*.bat
!examples/*/*.kts

CONTRIBUTING.md

@@ -1,5 +1,6 @@
# Pull requests (for contributors)
+- llama.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier
- Test your changes:
- Execute [the full CI locally on your machine](ci/README.md) before publishing
- Verify that the perplexity and the performance are not affected negatively by your changes (use `llama-perplexity` and `llama-bench`)

Makefile

@@ -1364,7 +1364,7 @@ llama-server: \
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat.cpp \
-common/chat.hpp \
+common/chat.h \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \

common/CMakeLists.txt

@@ -57,8 +57,7 @@ add_library(${TARGET} STATIC
arg.h
base64.hpp
chat.cpp
-chat.hpp
+chat.h
-chat-template.hpp
common.cpp
common.h
console.cpp
@@ -68,7 +67,8 @@ add_library(${TARGET} STATIC
llguidance.cpp
log.cpp
log.h
-minja.hpp
+minja/chat-template.hpp
+minja/minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp

common/arg.cpp

@@ -2,6 +2,7 @@
#include "log.h"
#include "sampling.h"
+#include "chat.h"
#include <algorithm>
#include <climits>
@@ -2501,5 +2502,53 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-1.5b-default"},
string_format("use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
params.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-3b-default"},
string_format("use default Qwen 2.5 Coder 3B (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
params.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--fim-qwen-7b-default"},
string_format("use default Qwen 2.5 Coder 7B (note: can download weights from the internet)"),
[](common_params & params) {
params.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024;
params.n_batch = 1024;
params.n_ctx = 0;
params.n_cache_reuse = 256;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
return ctx_arg;
}
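Usage note (inferred from the option definitions above, not stated elsewhere in this diff): each preset is intended as a one-flag server setup, e.g. `llama-server --fim-qwen-1.5b-default`, which should be equivalent to pointing the server at the corresponding ggml-org Qwen2.5-Coder GGUF on Hugging Face with port 8012, full GPU offload, flash attention, and the batch/context/cache-reuse settings shown.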

common/chat.cpp

@@ -1,8 +1,433 @@
-#include "chat.hpp"
-#include "chat-template.hpp"
+#include "chat.h"
#include "json-schema-to-grammar.h"
#include "log.h"
-#include "minja.hpp"
+#include "minja/chat-template.hpp"
+#include "minja/minja.hpp"
+#include <optional>
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template overridde was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
struct templates_params {
json messages;
json tools;
common_chat_tool_choice tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) {
if (tool_choice == "auto") {
return COMMON_CHAT_TOOL_CHOICE_AUTO;
}
if (tool_choice == "none") {
return COMMON_CHAT_TOOL_CHOICE_NONE;
}
if (tool_choice == "required") {
return COMMON_CHAT_TOOL_CHOICE_REQUIRED;
}
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
std::vector<common_chat_msg> msgs;
try {
if (!messages.is_array()) {
throw std::runtime_error("Expected 'messages' to be an array, got " + messages.dump());
}
for (const auto & message : messages) {
if (!message.is_object()) {
throw std::runtime_error("Expected 'message' to be an object, got " + message.dump());
}
common_chat_msg msg;
if (!message.contains("role")) {
throw std::runtime_error("Missing 'role' in message: " + message.dump());
}
msg.role = message.at("role");
if (message.contains("content")) {
const auto & content = message.at("content");
if (content.is_string()) {
msg.content = content;
} else if (content.is_array()) {
for (const auto & part : content) {
if (!part.contains("type")) {
throw std::runtime_error("Missing content part type: " + part.dump());
}
const auto & type = part.at("type");
if (type != "text") {
throw std::runtime_error("Unsupported content part type: " + type.dump());
}
common_chat_msg_content_part msg_part;
msg_part.type = type;
msg_part.text = part.at("text");
msg.content_parts.push_back(msg_part);
}
} else if (!content.is_null()) {
throw std::runtime_error("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Expected 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
if (message.contains("reasoning_content")) {
msg.reasoning_content = message.at("reasoning_content");
}
if (message.contains("name")) {
msg.tool_name = message.at("name");
}
if (message.contains("tool_call_id")) {
msg.tool_call_id = message.at("tool_call_id");
}
if (message.contains("tool_calls")) {
for (const auto & tool_call : message.at("tool_calls")) {
common_chat_tool_call tc;
if (!tool_call.contains("type")) {
throw std::runtime_error("Missing tool call type: " + tool_call.dump());
}
const auto & type = tool_call.at("type");
if (type != "function") {
throw std::runtime_error("Unsupported tool call type: " + tool_call.dump());
}
if (!tool_call.contains("function")) {
throw std::runtime_error("Missing tool call function: " + tool_call.dump());
}
const auto & fc = tool_call.at("function");
if (!fc.contains("name")) {
throw std::runtime_error("Missing tool call name: " + tool_call.dump());
}
tc.name = fc.at("name");
tc.arguments = fc.at("arguments");
if (tool_call.contains("id")) {
tc.id = tool_call.at("id");
}
msg.tool_calls.push_back(tc);
}
}
msgs.push_back(msg);
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2));
}
return msgs;
}
template <>
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
json messages = json::array();
for (const auto & msg : msgs) {
if (!msg.content.empty() && !msg.content_parts.empty()) {
throw std::runtime_error("Cannot specify both content and content_parts");
}
json jmsg {
{"role", msg.role},
};
if (!msg.content.empty()) {
jmsg["content"] = msg.content;
} else if (!msg.content_parts.empty()) {
if (concat_typed_text) {
std::string text;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
LOG_WRN("Ignoring content part type: %s\n", part.type.c_str());
continue;
}
if (!text.empty()) {
text += '\n';
}
text += part.text;
}
jmsg["content"] = text;
} else {
auto & parts = jmsg["content"] = json::array();
for (const auto & part : msg.content_parts) {
parts.push_back({
{"type", part.type},
{"text", part.text},
});
}
}
} else {
jmsg["content"] = json(); // null
}
if (!msg.reasoning_content.empty()) {
jmsg["reasoning_content"] = msg.reasoning_content;
}
if (!msg.tool_name.empty()) {
jmsg["name"] = msg.tool_name;
}
if (!msg.tool_call_id.empty()) {
jmsg["tool_call_id"] = msg.tool_call_id;
}
if (!msg.tool_calls.empty()) {
auto & tool_calls = jmsg["tool_calls"] = json::array();
for (const auto & tool_call : msg.tool_calls) {
json tc {
{"type", "function"},
{"function", {
{"name", tool_call.name},
{"arguments", tool_call.arguments},
}},
};
if (!tool_call.id.empty()) {
tc["id"] = tool_call.id;
}
tool_calls.push_back(tc);
}
}
messages.push_back(jmsg);
}
return messages;
}
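A minimal round-trip sketch for the two message helpers above (illustrative only, not part of the commit; the tool name and arguments are made up, and `json` is assumed to be this file's nlohmann alias):

```cpp
// Hypothetical example: serialize a message carrying one tool call with
// common_chat_msgs_to_json_oaicompat, then parse it back.
static void example_msgs_roundtrip() {
    common_chat_msg msg;
    msg.role = "assistant";

    common_chat_tool_call tc;
    tc.name      = "get_weather";                 // made-up tool call
    tc.arguments = R"({"location":"Paris"})";
    tc.id        = "call_0";
    msg.tool_calls.push_back(tc);

    // Empty content serializes as JSON null; the tool call is wrapped as
    // {"type":"function","function":{"name":...,"arguments":...},"id":...}.
    json jmsgs = common_chat_msgs_to_json_oaicompat<json>({msg});

    // Back to the C++ representation used throughout common/.
    std::vector<common_chat_msg> parsed = common_chat_msgs_parse_oaicompat<json>(jmsgs);
    GGML_ASSERT(parsed.size() == 1 && parsed[0].tool_calls.size() == 1);
}
```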
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const std::string & messages) {
return common_chat_msgs_parse_oaicompat(json::parse(messages));
}
template <>
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
std::vector<common_chat_tool> result;
try {
if (!tools.is_null()) {
if (!tools.is_array()) {
throw std::runtime_error("Expected 'tools' to be an array, got " + tools.dump());
}
for (const auto & tool : tools) {
if (!tool.contains("type")) {
throw std::runtime_error("Missing tool type: " + tool.dump());
}
const auto & type = tool.at("type");
if (!type.is_string() || type != "function") {
throw std::runtime_error("Unsupported tool type: " + tool.dump());
}
if (!tool.contains("function")) {
throw std::runtime_error("Missing tool function: " + tool.dump());
}
const auto & function = tool.at("function");
result.push_back({
/* .name = */ function.at("name"),
/* .description = */ function.at("description"),
/* .parameters = */ function.at("parameters").dump(),
});
}
}
} catch (const std::exception & e) {
throw std::runtime_error("Failed to parse tools: " + std::string(e.what()) + "; tools = " + tools.dump(2));
}
return result;
}
template <>
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const std::string & tools) {
return common_chat_tools_parse_oaicompat(json::parse(tools));
}
template <>
json common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools) {
if (tools.empty()) {
return json();
}
auto result = json::array();
for (const auto & tool : tools) {
result.push_back({
{"type", "function"},
{"function", {
{"name", tool.name},
{"description", tool.description},
{"parameters", json::parse(tool.parameters)},
}},
});
}
return result;
}
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
common_chat_msg msg;
msg.role = "user";
msg.content = "test";
auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl);
common_chat_templates_inputs inputs;
inputs.messages = {msg};
common_chat_templates_apply(tmpls.get(), inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
std::string fmt_past_msg;
if (!past_msg.empty()) {
inputs.messages = past_msg;
inputs.add_generation_prompt = false;
fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt;
}
std::ostringstream ss;
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
inputs.messages.push_back(new_msg);
inputs.add_generation_prompt = add_ass;
auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt;
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja) {
common_chat_templates_inputs inputs;
inputs.use_jinja = use_jinja;
auto add_simple_msg = [&](auto role, auto content) {
common_chat_msg msg;
msg.role = role;
msg.content = content;
inputs.messages.push_back(msg);
};
add_simple_msg("system", "You are a helpful assistant");
add_simple_msg("user", "Hello");
add_simple_msg("assistant", "Hi there");
add_simple_msg("user", "How are you?");
return common_chat_templates_apply(tmpls, inputs).prompt;
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
void common_chat_templates_free(struct common_chat_templates * tmpls) {
delete tmpls;
}
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls) {
return tmpls->has_explicit_template;
}
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant) {
if (variant != nullptr) {
if (strcmp(variant, "tool_use") == 0) {
if (tmpls->template_tool_use) {
return tmpls->template_tool_use->source().c_str();
}
return nullptr;
} else {
LOG_DBG("%s: unknown template variant: %s\n", __func__, variant);
}
}
return tmpls->template_default->source().c_str();
}
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override,
const std::string & eos_token_override)
{
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
GGML_ASSERT(model != nullptr);
const auto * str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = CHATML_TEMPLATE_SRC;
}
}
std::string token_bos = bos_token_override;
std::string token_eos = eos_token_override;
if (model) {
const auto * vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name);
}
return std::string();
}
return common_token_to_piece(vocab, token, true);
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
}
common_chat_templates_ptr tmpls(new common_chat_templates());
tmpls->has_explicit_template = has_explicit_template;
try {
tmpls->template_default = std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template (defaulting to chatml): %s \n", __func__, e.what());
tmpls->template_default = std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos);
}
if (!template_tool_use_src.empty()) {
try {
tmpls->template_tool_use = std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos);
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse tool use chat template (ignoring it): %s\n", __func__, e.what());
}
}
return tmpls;
}
std::string common_chat_format_name(common_chat_format format) {
switch (format) {
@@ -38,22 +463,22 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons
json_error_locator() : position(0), found_error(false) {}
-bool parse_error(std::size_t position, const std::string &, const json::exception &) override {
+bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT
this->position = position - 1;
this->found_error = true;
return false;
}
-bool null() override { return true; }
+bool null() override { return true; } // NOLINT
-bool boolean(bool) override { return true; }
+bool boolean(bool) override { return true; } // NOLINT
-bool number_integer(number_integer_t) override { return true; }
+bool number_integer(number_integer_t) override { return true; } // NOLINT
-bool number_unsigned(number_unsigned_t) override { return true; }
+bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT
-bool number_float(number_float_t, const string_t &) override { return true; }
+bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT
-bool string(string_t &) override { return true; }
+bool string(string_t &) override { return true; } // NOLINT
-bool binary(binary_t &) override { return true; }
+bool binary(binary_t &) override { return true; } // NOLINT
-bool start_object(std::size_t) override { return true; }
+bool start_object(std::size_t) override { return true; } // NOLINT
-bool key(string_t &) override { return true; }
+bool key(string_t &) override { return true; } // NOLINT
bool end_object() override { return true; }
-bool start_array(std::size_t) override { return true; }
+bool start_array(std::size_t) override { return true; } // NOLINT
bool end_array() override { return true; }
};
json_error_locator err_loc;
@@ -187,13 +612,20 @@ static std::string apply(
// tmpl_inputs.now = std::chrono::system_clock::now();
minja::chat_template_options tmpl_opts;
-tmpl_opts.use_bos_token = false;
-tmpl_opts.use_eos_token = false;
-return tmpl.apply(tmpl_inputs, tmpl_opts);
+// To avoid double BOS / EOS tokens, we're manually removing begining / trailing tokens
+// instead of using `chat_template_options.use_bos_token = false`, since these tokens
+// may be needed inside the template / between messages too.
+auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
+if (string_starts_with(result, tmpl.bos_token())) {
+result = result.substr(tmpl.bos_token().size());
+}
+if (string_ends_with(result, tmpl.eos_token())) {
+result = result.substr(0, result.size() - tmpl.eos_token().size());
+}
+return result;
}
-static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto tool_call_schemas = json::array();
@@ -247,7 +679,7 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp
{"required", json::array({"tool_call"})},
};
const auto schema =
-inputs.tool_choice != "required"
+inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED
? json {
{"anyOf", json::array({
tool_call,
@@ -303,9 +735,9 @@ static common_chat_msg common_chat_parse_generic(const std::string & input) {
return result;
}
-static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
-data.grammar_lazy = inputs.tool_choice != "required";
+data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -348,9 +780,9 @@ static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input)
return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]");
}
-static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
-data.grammar_lazy = inputs.tool_choice != "required";
+data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -455,10 +887,10 @@ static void expect_tool_parameters(const std::string & name, const json & parame
const auto & parameters_required = parameters.at("required");
for (const auto & prop : expected_properties) {
if (!parameters_properties.contains(prop)) {
-throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop);
+throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT
}
if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) {
-throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop);
+throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT
}
}
if (parameters_properties.size() != expected_properties.size()) {
@@ -466,18 +898,16 @@ static void expect_tool_parameters(const std::string & name, const json & parame
}
}
-static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct common_chat_inputs & inputs, bool allow_python_tag_builtin_tools) {
+static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) {
auto builtin_tools = json::array();
common_chat_params data;
-data.grammar_lazy = inputs.tool_choice != "required";
+data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
auto handle_builtin_tool = [&](const std::string & name, const json & parameters) {
-if (name == "wolfram_alpha") {
+if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py
-expect_tool_parameters(name, parameters, {"query"});
-} else if (name == "web_search" || name == "brave_search") {
// https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py
expect_tool_parameters(name, parameters, {"query"});
} else if (name == "python" || name == "code_interpreter") {
@@ -489,7 +919,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com
std::vector<std::string> kvs;
for (const auto & [key, value] : parameters.at("properties").items()) {
-kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value));
+kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT
}
tool_rules.push_back(
@@ -560,34 +990,33 @@ static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bo
auto arg_value_str = raw_args.substr(it_eq + 1);
auto arg_value = json::parse(arg_value_str);
-return {
-/* .role = */ "assistant",
-/* .content = */ match.prefix().str(),
-/* .tool_calls = */ {
-{
-/* .name = */ match[1],
-/* .arguments = */ (json {
-{arg_name, arg_value},
-}).dump(),
-/* .id = */ "",
-},
-},
-};
+common_chat_msg msg;
+msg.role = "assistant";
+msg.content = match.prefix().str();
+msg.tool_calls.push_back({
+/* .name = */ name,
+/* .arguments = */ (json {
+{arg_name, arg_value},
+}).dump(),
+/* .id = */ "",
+});
+return msg;
}
}
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
-static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
-data.grammar_lazy = inputs.tool_choice != "required" && inputs.json_schema.is_null();
+data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
+builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"\"<tool▁call▁begin>function<tool▁sep>" + name + "\\n"
@@ -666,15 +1095,15 @@ static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input,
return msg;
}
-static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
-fprintf(stderr, "%s\n", __func__);
+LOG_DBG("%s\n", __func__);
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, {
{"datetime", "Jan 29 2025 13:00:00 GMT"},
{"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
});
if (inputs.tools.is_array() && !inputs.tools.empty()) {
-data.grammar_lazy = inputs.tool_choice != "required";
+data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
@@ -712,14 +1141,14 @@ static common_chat_msg common_chat_parse_firefunction_v2(const std::string & inp
return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1);
}
-static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) {
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2;
if (inputs.tools.is_array() && !inputs.tools.empty()) {
-data.grammar_lazy = inputs.tool_choice != "required";
+data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> first_tool_rules;
std::vector<std::string> subsequent_tool_rules;
@@ -727,6 +1156,7 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
+builder.resolve_refs(parameters);
auto args_rule = builder.add_schema(name + "-args", parameters);
first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule));
subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule));
@@ -795,14 +1225,14 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in
}
}
-static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
common_chat_params data;
json tools = inputs.tools.is_null() ? inputs.tools : json::array();
std::string python_code_argument_name;
auto has_raw_python = false;
-data.grammar_lazy = inputs.tool_choice != "required";
+data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
@@ -814,7 +1244,7 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con
throw std::runtime_error("Missing type in python tool");
}
has_raw_python = true;
-auto type = parameters.at("type");
+const auto & type = parameters.at("type");
if (type == "object") {
auto properties = parameters.at("properties");
for (auto it = properties.begin(); it != properties.end(); ++it) {
@@ -854,17 +1284,15 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
std::smatch match;
if (std::regex_search(input, match, python_tag_regex)) {
auto code = match[1].str();
-return {
-/* .role = */ "assistant",
-/* .content = */ match.prefix().str(),
-/* .tool_calls = */ {
-{
-/* .name = */ "python",
-/* .arguments = */ (json {{"code", code}}).dump(),
-/* .id = */ "",
-},
-}
-};
+common_chat_msg msg;
+msg.role = "assistant";
+msg.content = match.prefix().str();
+msg.tool_calls.push_back({
+/* .name = */ "python",
+/* .arguments = */ (json {{"code", code}}).dump(),
+/* .id = */ "",
+});
+return msg;
}
static std::regex function_regex(R"(<function=(\w+)>)");
static std::regex close_regex(R"(</function>)");
@@ -872,10 +1300,10 @@ static common_chat_msg common_chat_parse_functionary_v3_1_llama_3_1(const std::s
return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex);
}
-static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
-data.grammar_lazy = inputs.tool_choice != "required";
+data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
@@ -908,20 +1336,18 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
std::regex middle_pattern(R"([\n\s]*</tool_call>[\n\s]*<tool_call>)");
std::regex end_pattern(R"([\n\s]*</tool_call>[\n\s]*$)");
+common_chat_msg msg;
+msg.role = "assistant";
auto end = input.end();
std::sregex_iterator rend;
std::sregex_iterator rit(input.begin(), end, start_pattern);
if (rit == rend) {
-return {
-/* .role = */ "assistant",
-/* .content = */ input,
-/* .tool_calls = */ {},
-};
+msg.content = input;
+return msg;
}
-common_chat_msg result;
-result.role = "assistant";
-result.content = rit->prefix();
+msg.content = rit->prefix();
auto it = rit->suffix().first;
while (it != end) {
@@ -930,7 +1356,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
throw std::runtime_error("Failed to parse json tool call");
}
const auto & arguments = call.at("arguments");
-result.tool_calls.push_back({
+msg.tool_calls.push_back({
call.at("name"),
arguments.dump(),
// arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
@@ -947,17 +1373,17 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input)
break;
}
}
-return result;
+return msg;
} catch (const std::exception & e) {
-return {
-/* .role = */ "assistant",
-/* .content = */ input,
-/* .tool_calls = */ {},
-};
+LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what());
+common_chat_msg msg;
+msg.role = "assistant";
+msg.content = input;
+return msg;
}
}
-static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt);
data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
@@ -973,12 +1399,35 @@ static common_chat_params common_chat_params_init_without_tools(const common_cha
return data;
}
-common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & inputs) {
+static common_chat_params common_chat_templates_apply_jinja(
+const struct common_chat_templates * tmpls,
+const struct common_chat_templates_inputs & inputs)
+{
+templates_params params;
+params.tools = common_chat_tools_to_json_oaicompat<json>(inputs.tools);
+const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use
+? *tmpls->template_tool_use
+: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
+params.messages = common_chat_msgs_to_json_oaicompat<json>(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
+params.add_generation_prompt = inputs.add_generation_prompt;
+params.extract_reasoning = inputs.extract_reasoning;
+params.tool_choice = inputs.tool_choice;
+params.grammar = inputs.grammar;
+if (!inputs.json_schema.empty()) {
+params.json_schema = json::parse(inputs.json_schema);
+}
-if (inputs.tools.is_array()) {
-if (inputs.tool_choice != "none" && !inputs.grammar.empty()) {
+if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
+LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
+params.parallel_tool_calls = false;
+} else {
+params.parallel_tool_calls = inputs.parallel_tool_calls;
+}
+if (params.tools.is_array()) {
+if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) {
throw std::runtime_error("Cannot specify grammar with tools");
}
if (caps.supports_tool_calls && !caps.supports_tools) {
@@ -987,68 +1436,135 @@ common_chat_params common_chat_params_init(const common_chat_template & tmpl, co
}
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
-if (src.find("<tool▁calls▁begin>") != std::string::npos && inputs.json_schema.is_null()) {
+if (src.find("<tool▁calls▁begin>") != std::string::npos && params.json_schema.is_null()) {
-return common_chat_params_init_deepseek_r1(tmpl, inputs);
+return common_chat_params_init_deepseek_r1(tmpl, params);
}
// Command R7B: : use handler in all cases except json schema (thinking / tools).
-if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && inputs.json_schema.is_null()) {
+if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) {
-return common_chat_params_init_command_r7b(tmpl, inputs);
+return common_chat_params_init_command_r7b(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
-if ((!inputs.tools.is_array() && inputs.json_schema.is_object())) {
+if ((params.tools.is_array() && params.json_schema.is_object())) {
-return common_chat_params_init_generic(tmpl, inputs);
+return common_chat_params_init_generic(tmpl, params);
}
// Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases.
if (src.find(">>>all") != std::string::npos) {
-return common_chat_params_init_functionary_v3_2(tmpl, inputs);
+return common_chat_params_init_functionary_v3_2(tmpl, params);
}
// Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases.
if (src.find(" functools[") != std::string::npos) {
-return common_chat_params_init_firefunction_v2(tmpl, inputs);
+return common_chat_params_init_firefunction_v2(tmpl, params);
}
// Plain handler (no tools)
-if (inputs.tools.is_null() || inputs.tool_choice == "none") {
+if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
-return common_chat_params_init_without_tools(tmpl, inputs);
+return common_chat_params_init_without_tools(tmpl, params);
}
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos) {
-return common_chat_params_init_hermes_2_pro(tmpl, inputs);
+return common_chat_params_init_hermes_2_pro(tmpl, params);
}
// Functionary v3.1 (w/ tools)
if (src.find("<|start_header_id|>") != std::string::npos
&& src.find("<function=") != std::string::npos) {
-return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, inputs);
+return common_chat_params_init_functionary_v3_1_llama_3_1(tmpl, params);
}
// Llama 3.1, 3.2, 3.3 (w/ tools)
if (src.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos;
-return common_chat_params_init_llama_3_1_tool_calls(tmpl, inputs, allow_python_tag_builtin_tools);
+return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools);
}
// Mistral Nemo (w/ tools)
if (src.find("[TOOL_CALLS]") != std::string::npos) {
-return common_chat_params_init_mistral_nemo(tmpl, inputs);
+return common_chat_params_init_mistral_nemo(tmpl, params);
}
// Generic fallback
-return common_chat_params_init_generic(tmpl, inputs);
+return common_chat_params_init_generic(tmpl, params);
}
// Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template.
static common_chat_params common_chat_templates_apply_legacy(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
if (part.type != "text") {
LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str());
continue;
}
if (!content.empty()) {
content += "\n";;
}
content += part.text;
}
contents.emplace_back(std::move(content));
}
for (size_t i = 0; i < contents.size(); ++i) {
const auto & msg = inputs.messages[i];
const auto & content = contents[i];
chat.push_back({msg.role.c_str(), content.c_str()});
alloc_size += (msg.role.size() + content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
const auto & src = tmpls->template_default->source();
int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size());
}
common_chat_params params;
params.prompt = std::string(buf.data(), res);
if (!inputs.json_schema.empty()) {
params.grammar = json_schema_to_grammar(json::parse(inputs.json_schema));
} else {
params.grammar = inputs.grammar;
}
return params;
}
common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs)
{
GGML_ASSERT(tmpls != nullptr);
return inputs.use_jinja
? common_chat_templates_apply_jinja(tmpls, inputs)
: common_chat_templates_apply_legacy(tmpls, inputs);
}
static common_chat_msg common_chat_parse_content_only(const std::string & input) {
-return {
-/* .role = */ "assistant",
-/* .content = */ input,
-/* .tool_calls = */ {},
-};
+common_chat_msg msg;
+msg.role = "assistant";
+msg.content = input;
+return msg;
}
common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) {

common/chat.h

@@ -0,0 +1,134 @@
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <string>
#include <vector>
struct common_chat_templates;
struct common_chat_tool_call {
std::string name;
std::string arguments;
std::string id;
};
struct common_chat_msg_content_part {
std::string type;
std::string text;
};
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_chat_msg_content_part> content_parts = {};
std::vector<common_chat_tool_call> tool_calls = {};
std::string reasoning_content;
std::string tool_name;
std::string tool_call_id;
};
struct common_chat_tool {
std::string name;
std::string description;
std::string parameters;
};
enum common_chat_tool_choice {
COMMON_CHAT_TOOL_CHOICE_AUTO,
COMMON_CHAT_TOOL_CHOICE_REQUIRED,
COMMON_CHAT_TOOL_CHOICE_NONE,
};
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_templates_inputs {
std::vector<common_chat_msg> messages;
std::string grammar;
std::string json_schema;
bool add_generation_prompt = true;
bool use_jinja = true;
// Parameters below only supported when use_jinja is true
std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false;
bool extract_reasoning = true;
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::string prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
void common_chat_templates_free(struct common_chat_templates * tmpls);
struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } };
typedef std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;
common_chat_templates_ptr common_chat_templates_init(
const struct llama_model * model,
const std::string & chat_template_override,
const std::string & bos_token_override = "",
const std::string & eos_token_override = "");
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
struct common_chat_params common_chat_templates_apply(
const struct common_chat_templates * tmpls,
const struct common_chat_templates_inputs & inputs);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const struct common_chat_templates * tmpls,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const struct common_chat_templates * tmpls,
bool use_jinja);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
// Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
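A minimal usage sketch of the API declared above (illustrative only, not part of the commit; it assumes a loaded `llama_model *` and uses only the declarations in this header):

```cpp
#include "chat.h"

// Render a prompt for a single user turn using the model's built-in chat
// template (falling back to ChatML when the model does not ship one).
static std::string make_prompt(const llama_model * model) {
    common_chat_templates_ptr tmpls = common_chat_templates_init(model, /* chat_template_override= */ "");

    common_chat_templates_inputs inputs;
    inputs.use_jinja             = true;
    inputs.add_generation_prompt = true;

    common_chat_msg user;
    user.role    = "user";
    user.content = "Hello";
    inputs.messages.push_back(user);

    common_chat_params params = common_chat_templates_apply(tmpls.get(), inputs);
    // params.format later tells common_chat_parse() how to turn the raw model
    // output back into a common_chat_msg (content, reasoning, tool calls).
    return params.prompt;
}
```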

common/chat.hpp

@@ -1,55 +0,0 @@
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
#pragma once
#include "common.h"
#include <json.hpp>
#include <optional>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
struct common_chat_inputs {
json messages;
json tools;
json tool_choice;
json json_schema;
bool parallel_tool_calls;
bool stream;
std::string grammar;
bool add_generation_prompt = true;
bool extract_reasoning = true;
};
enum common_chat_format {
COMMON_CHAT_FORMAT_CONTENT_ONLY,
COMMON_CHAT_FORMAT_GENERIC,
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
struct common_chat_params {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
json prompt;
std::string grammar;
bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops;
};
struct common_chat_params common_chat_params_init(const common_chat_template & tmpl, const struct common_chat_inputs & params);
std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);

common/common.cpp

@@ -12,8 +12,6 @@
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h"
-#include "chat.hpp"
-#include "chat-template.hpp"
#include <algorithm>
#include <cinttypes>
@@ -1768,174 +1766,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
return text;
}
//
// Chat template utils
//
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) {
if (use_jinja) {
try {
auto chat_template = common_chat_template(tmpl, "<s>", "</s>");
common_chat_inputs inputs;
inputs.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
common_chat_params_init(chat_template, inputs);
return true;
} catch (const std::exception & e) {
LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what());
return false;
}
}
llama_chat_message chat[] = {{"user", "test"}};
const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0;
}
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & msgs,
bool add_ass,
bool use_jinja) {
if (use_jinja) {
auto messages = json::array();
for (const auto & msg : msgs) {
messages.push_back({{"role", msg.role}, {"content", msg.content}});
}
common_chat_inputs inputs;
inputs.messages = messages;
inputs.add_generation_prompt = add_ass;
return common_chat_params_init(tmpl, inputs).prompt;
}
int alloc_size = 0;
std::vector<llama_chat_message> chat;
for (const auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
}
std::vector<char> buf(alloc_size);
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
// if the custom "tmpl" is not supported, we throw an error
// this check is somewhat redundant (intentionally so), since we cannot be sure the user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
}
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size());
}
std::string formatted_chat(buf.data(), res);
return formatted_chat;
}
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja) {
std::ostringstream ss;
auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(tmpl, past_msg, false, use_jinja);
std::vector<common_chat_msg> chat_new(past_msg);
// if the past_msg ends with a newline, we must preserve it in the formatted version
if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
ss << "\n";
};
// format chat with new_msg
chat_new.push_back(new_msg);
auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja);
// get the diff part
ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
return ss.str();
}
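A hypothetical walk-through of the suffix computation above, with made-up rendered strings (the actual markers depend on the template in use):

```cpp
#include <string>

// Illustrative only: common_chat_format_single returns just the part that
// new_msg (plus the assistant prefix, when add_ass is set) adds to the render.
static std::string format_single_example() {
    const std::string fmt_past = "<|im_start|>user\nHi<|im_end|>\n";
    const std::string fmt_new  = fmt_past + "<|im_start|>assistant\nHello<|im_end|>\n";
    return fmt_new.substr(fmt_past.size()); // "<|im_start|>assistant\nHello<|im_end|>\n"
}
```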
std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) {
std::vector<common_chat_msg> msgs = {
{"system", "You are a helpful assistant", {}},
{"user", "Hello", {}},
{"assistant", "Hi there", {}},
{"user", "How are you?", {}},
};
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}
#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
std::string default_template_src;
std::string template_tool_use_src;
bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
if (str) {
default_template_src = str;
has_explicit_template = true;
}
str = llama_model_chat_template(model, /* name */ "tool_use");
if (str) {
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = CHATML_TEMPLATE_SRC;
}
}
auto vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
|| template_tool_use_src.find(jinja_variable_name) != std::string::npos) {
LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name);
}
return std::string();
} else {
return common_token_to_piece(vocab, token, true);
}
};
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
try {
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
};
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
return {
has_explicit_template,
std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
nullptr,
};
}
}
// //
// KV cache utils // KV cache utils
// //

View File

@ -178,10 +178,10 @@ struct common_params_speculative {
int32_t n_ctx = 0; // draft context size int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
float p_split = 0.1f; // speculative decoding split probability float p_split = 0.1f; // speculative decoding split probability
float p_min = 0.9f; // minimum speculative decoding probability (greedy) float p_min = 0.75f; // minimum speculative decoding probability (greedy)
struct cpu_params cpuparams; struct cpu_params cpuparams;
struct cpu_params cpuparams_batch; struct cpu_params cpuparams_batch;
@ -616,62 +616,6 @@ std::string common_detokenize(
const std::vector<llama_token> & tokens, const std::vector<llama_token> & tokens,
bool special = true); bool special = true);
//
// Chat template utils
//
struct common_tool_call {
std::string name;
std::string arguments;
std::string id;
};
// same with llama_chat_message, but uses std::string
struct common_chat_msg {
std::string role;
std::string content;
std::vector<common_tool_call> tool_calls;
std::string reasoning_content = "";
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
namespace minja {
class chat_template;
}
typedef minja::chat_template common_chat_template;
struct common_chat_templates {
bool has_explicit_template; // Model had builtin template or template override was specified.
std::unique_ptr<common_chat_template> template_default; // always set (defaults to chatml)
std::unique_ptr<common_chat_template> template_tool_use;
};
// CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string common_chat_apply_template(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & chat,
bool add_ass,
bool use_jinja);
// Format single message, while taking into account the position of that message in chat history
std::string common_chat_format_single(
const common_chat_template & tmpl,
const std::vector<common_chat_msg> & past_msg,
const common_chat_msg & new_msg,
bool add_ass,
bool use_jinja);
// Returns an example of formatted chat
std::string common_chat_format_example(
const common_chat_template & tmpl, bool use_jinja);
common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override);
// //
// KV cache utils // KV cache utils
// //

View File

@ -252,11 +252,6 @@ llama_tokens common_speculative_gen_draft(
// add drafted token for each sequence // add drafted token for each sequence
const llama_token id = cur_p->data[0].id; const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_sampler_accept(smpl, id, true); common_sampler_accept(smpl, id, true);
result.push_back(id); result.push_back(id);
@ -265,6 +260,11 @@ llama_tokens common_speculative_gen_draft(
break; break;
} }
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
break;
}
common_batch_add(batch, id, n_past + i + 1, { 0 }, true); common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
// evaluate the drafted tokens on the draft model // evaluate the drafted tokens on the draft model

View File

@ -9,7 +9,7 @@ struct common_speculative_params {
int n_draft = 16; // max drafted tokens int n_draft = 16; // max drafted tokens
int n_reuse = 256; int n_reuse = 256;
float p_min = 0.9f; // min probability required to accept a token in the draft float p_min = 0.75f; // min probability required to accept a token in the draft
}; };
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft); struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);

View File

@ -4,7 +4,7 @@
#include "log.h" #include "log.h"
#include "sampling.h" #include "sampling.h"
#include "llama.h" #include "llama.h"
#include "chat-template.hpp" #include "chat.h"
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
@ -158,7 +158,7 @@ int main(int argc, char ** argv) {
} }
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
auto chat_templates = common_chat_templates_from_model(model, params.chat_template); auto chat_templates = common_chat_templates_init(model, params.chat_template);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
@ -201,7 +201,7 @@ int main(int argc, char ** argv) {
} }
// auto enable conversation mode if chat template is available // auto enable conversation mode if chat template is available
const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default; const bool has_chat_template = common_chat_templates_was_explicit(chat_templates.get());
if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) {
if (has_chat_template) { if (has_chat_template) {
LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
@ -219,7 +219,7 @@ int main(int argc, char ** argv) {
// print chat template example in conversation mode // print chat template example in conversation mode
if (params.conversation_mode) { if (params.conversation_mode) {
if (params.enable_chat_template) { if (params.enable_chat_template) {
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
} else { } else {
LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
} }
@ -264,9 +264,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
common_chat_msg new_msg{role, content, {}}; common_chat_msg new_msg;
auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); new_msg.role = role;
chat_msgs.push_back({role, content, {}}); new_msg.content = content;
auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back(new_msg);
LOG_DBG("formatted: '%s'\n", formatted.c_str()); LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted; return formatted;
}; };
@ -755,11 +757,14 @@ int main(int argc, char ** argv) {
// check for reverse prompt using special tokens // check for reverse prompt using special tokens
llama_token last_token = common_sampler_last(smpl); llama_token last_token = common_sampler_last(smpl);
if (std::find(antiprompt_token.begin(), antiprompt_token.end(), last_token) != antiprompt_token.end()) { for (auto token : antiprompt_token) {
if (params.interactive) { if (token == last_token) {
is_interacting = true; if (params.interactive) {
is_interacting = true;
}
is_antiprompt = true;
break;
} }
is_antiprompt = true;
} }
if (is_antiprompt) { if (is_antiprompt) {

View File

@ -24,7 +24,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "chat-template.hpp" #include "chat.h"
#include "common.h" #include "common.h"
#include "json.hpp" #include "json.hpp"
#include "linenoise.cpp/linenoise.h" #include "linenoise.cpp/linenoise.h"
@ -113,6 +113,7 @@ class Opt {
llama_context_params ctx_params; llama_context_params ctx_params;
llama_model_params model_params; llama_model_params model_params;
std::string model_; std::string model_;
std::string chat_template_file;
std::string user; std::string user;
bool use_jinja = false; bool use_jinja = false;
int context_size = -1, ngl = -1; int context_size = -1, ngl = -1;
@ -148,6 +149,16 @@ class Opt {
return 0; return 0;
} }
int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
if (i + 1 >= argc) {
return 1;
}
option_value = argv[++i];
return 0;
}
int parse(int argc, const char ** argv) { int parse(int argc, const char ** argv) {
bool options_parsing = true; bool options_parsing = true;
for (int i = 1, positional_args_i = 0; i < argc; ++i) { for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@ -169,6 +180,11 @@ class Opt {
verbose = true; verbose = true;
} else if (options_parsing && strcmp(argv[i], "--jinja") == 0) { } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
use_jinja = true; use_jinja = true;
} else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) {
if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
return 1;
}
use_jinja = true;
} else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
help = true; help = true;
return 0; return 0;
@ -207,6 +223,11 @@ class Opt {
"Options:\n" "Options:\n"
" -c, --context-size <value>\n" " -c, --context-size <value>\n"
" Context size (default: %d)\n" " Context size (default: %d)\n"
" --chat-template-file <path>\n"
" Path to the file containing the chat template to use with the model.\n"
" Only supports jinja templates and implicitly sets the --jinja flag.\n"
" --jinja\n"
" Use jinja templating for the chat template of the model\n"
" -n, -ngl, --ngl <value>\n" " -n, -ngl, --ngl <value>\n"
" Number of GPU layers (default: %d)\n" " Number of GPU layers (default: %d)\n"
" --temp <value>\n" " --temp <value>\n"
@ -261,13 +282,12 @@ static int get_terminal_width() {
#endif #endif
} }
#ifdef LLAMA_USE_CURL
class File { class File {
public: public:
FILE * file = nullptr; FILE * file = nullptr;
FILE * open(const std::string & filename, const char * mode) { FILE * open(const std::string & filename, const char * mode) {
file = fopen(filename.c_str(), mode); file = ggml_fopen(filename.c_str(), mode);
return file; return file;
} }
@ -303,6 +323,28 @@ class File {
return 0; return 0;
} }
std::string read_all(const std::string & filename) {
open(filename, "r");
lock();
if (!file) {
printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
return "";
}
fseek(file, 0, SEEK_END);
size_t size = ftell(file);
fseek(file, 0, SEEK_SET);
std::string out;
out.resize(size);
size_t read_size = fread(&out[0], 1, size, file);
if (read_size != size) {
printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
return "";
}
return out;
}
~File() { ~File() {
if (fd >= 0) { if (fd >= 0) {
# ifdef _WIN32 # ifdef _WIN32
@ -327,6 +369,7 @@ class File {
# endif # endif
}; };
#ifdef LLAMA_USE_CURL
class HttpClient { class HttpClient {
public: public:
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file, int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@ -557,7 +600,7 @@ class LlamaData {
llama_model_ptr model; llama_model_ptr model;
llama_sampler_ptr sampler; llama_sampler_ptr sampler;
llama_context_ptr context; llama_context_ptr context;
std::vector<llama_chat_message> messages; std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
std::list<std::string> msg_strs; std::list<std::string> msg_strs;
std::vector<char> fmtted; std::vector<char> fmtted;
@ -834,44 +877,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
} }
// Function to apply the chat template and resize `formatted` if needed // Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) { common_chat_templates_inputs inputs;
json messages = json::array(); for (const auto & msg : llama_data.messages) {
for (const auto & msg : llama_data.messages) { common_chat_msg cmsg;
messages.push_back({ cmsg.role = msg.role;
{"role", msg.role}, cmsg.content = msg.content;
{"content", msg.content}, inputs.messages.push_back(cmsg);
});
}
try {
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages;
tmpl_inputs.add_generation_prompt = append;
minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false;
tmpl_opts.use_eos_token = false;
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} catch (const std::exception & e) {
printe("failed to render the chat template: %s\n", e.what());
return -1;
}
}
int result = llama_chat_apply_template(
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
} }
inputs.add_generation_prompt = append;
inputs.use_jinja = use_jinja;
return result; auto chat_params = common_chat_templates_apply(tmpls, inputs);
// TODO: use other params for tool calls.
auto result = chat_params.prompt;
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} }
// Function to tokenize the prompt // Function to tokenize the prompt
@ -1015,8 +1037,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
} }
// Helper function to apply the chat template and handle errors // Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
if (new_len < 0) { if (new_len < 0) {
printe("failed to apply the chat template\n"); printe("failed to apply the chat template\n");
return -1; return -1;
@ -1074,12 +1096,33 @@ static int get_user_input(std::string & user_input, const std::string & user) {
return 0; return 0;
} }
// Reads the chat template to use from the given file; returns an empty string on failure
static std::string read_chat_template_file(const std::string & chat_template_file) {
if (chat_template_file.empty()) {
return "";
}
File file;
std::string chat_template = file.read_all(chat_template_file);
if (chat_template.empty()) {
printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
return "";
}
return chat_template;
}
// Main chat loop function // Main chat loop function
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
int prev_len = 0; int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
GGML_ASSERT(chat_templates.template_default); std::string chat_template = "";
if (!chat_template_file.empty()) {
chat_template = read_chat_template_file(chat_template_file);
}
auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
static const bool stdout_a_terminal = is_stdout_a_terminal(); static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) { while (true) {
// Get user input // Get user input
@ -1090,7 +1133,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
add_message("user", user.empty() ? user_input : user, llama_data); add_message("user", user.empty() ? user_input : user, llama_data);
int new_len; int new_len;
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) { if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
return 1; return 1;
} }
@ -1105,7 +1148,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
} }
add_message("assistant", response, llama_data); add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) { if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
return 1; return 1;
} }
} }
@ -1165,7 +1208,7 @@ int main(int argc, const char ** argv) {
return 1; return 1;
} }
if (chat_loop(llama_data, opt.user, opt.use_jinja)) { if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
return 1; return 1;
} }

Binary file not shown.

View File

@ -274,7 +274,7 @@ struct server_task {
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
params.speculative.n_min = std::max(params.speculative.n_min, 2); params.speculative.n_min = std::max(params.speculative.n_min, 0);
params.speculative.n_max = std::max(params.speculative.n_max, 0); params.speculative.n_max = std::max(params.speculative.n_max, 0);
// Use OpenAI API logprobs only if n_probs wasn't provided // Use OpenAI API logprobs only if n_probs wasn't provided
@ -329,9 +329,6 @@ struct server_task {
} }
// process "json_schema" and "grammar" // process "json_schema" and "grammar"
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
}
if (data.contains("json_schema") && !data.contains("grammar")) { if (data.contains("json_schema") && !data.contains("grammar")) {
try { try {
auto schema = json_value(data, "json_schema", json::object()); auto schema = json_value(data, "json_schema", json::object());
@ -1807,7 +1804,7 @@ struct server_context {
// Necessary similarity of prompt for slot selection // Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f; float slot_prompt_similarity = 0.0f;
common_chat_templates chat_templates; common_chat_templates_ptr chat_templates;
~server_context() { ~server_context() {
// Clear any sampling context // Clear any sampling context
@ -1891,45 +1888,17 @@ struct server_context {
llama_init_dft.context.reset(); llama_init_dft.context.reset();
} }
if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) { chat_templates = common_chat_templates_init(model, params_base.chat_template);
try {
common_chat_format_example(chat_templates.get(), params.use_jinja);
} catch (const std::exception & e) {
SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
chat_templates = common_chat_templates_from_model(model, "chatml"); chat_templates = common_chat_templates_init(model, "chatml");
} else {
chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
} }
GGML_ASSERT(chat_templates.template_default.get() != nullptr);
return true; return true;
} }
bool validate_builtin_chat_template(bool use_jinja) const {
llama_chat_message chat[] = {{"user", "test"}};
if (use_jinja) {
auto templates = common_chat_templates_from_model(model, "");
common_chat_inputs inputs;
inputs.messages = json::array({{
{"role", "user"},
{"content", "test"},
}});
GGML_ASSERT(templates.template_default);
try {
common_chat_params_init(*templates.template_default, inputs);
if (templates.template_tool_use) {
common_chat_params_init(*templates.template_tool_use, inputs);
}
return true;
} catch (const std::exception & e) {
SRV_ERR("failed to apply template: %s\n", e.what());
return false;
}
} else {
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0;
}
}
void init() { void init() {
const int32_t n_ctx_slot = n_ctx / params_base.n_parallel; const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel }, { "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model }, { "model_path", ctx_server.params_base.model },
{ "chat_template", ctx_server.chat_templates.template_default->source() }, { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", ctx_server.chat_templates.template_default->bos_token() }, { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", ctx_server.chat_templates.template_default->eos_token() }, { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
{ "build_info", build_info }, { "build_info", build_info },
}; };
if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { if (ctx_server.params_base.use_jinja) {
data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
data["chat_template_tool_use"] = tool_use_src;
}
} }
res_ok(res, data); res_ok(res, data);
@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
} }
auto body = json::parse(req.body); auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
return handle_completions_impl( return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION, SERVER_TASK_TYPE_COMPLETION,
@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
// same with handle_chat_completions, but without inference part // same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) { const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body); auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates); json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }}); res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
}; };
@ -4263,6 +4234,11 @@ int main(int argc, char ** argv) {
// return; // return;
//} //}
// if true, use TEI API format, otherwise use Jina API format
// Jina: https://jina.ai/reranker/
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
bool is_tei_format = body.contains("texts");
json query; json query;
if (body.count("query") == 1) { if (body.count("query") == 1) {
query = body.at("query"); query = body.at("query");
@ -4275,7 +4251,8 @@ int main(int argc, char ** argv) {
return; return;
} }
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>()); std::vector<std::string> documents = json_value(body, "documents",
json_value(body, "texts", std::vector<std::string>()));
if (documents.empty()) { if (documents.empty()) {
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
return; return;
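To make the two accepted request shapes concrete, a hedged sketch follows (field names are taken from the handler above, values are illustrative):

```cpp
#include "json.hpp"

using json = nlohmann::ordered_json;

// Jina-style request (existing behaviour): documents are passed as "documents".
static const json jina_rerank_req = {
    {"query", "Machine learning is"},
    {"documents", {"doc A", "doc B"}},
};

// TEI-style request (new): the presence of "texts" switches the handler to the
// TEI response format; "return_text" optionally echoes each document back.
static const json tei_rerank_req = {
    {"query", "Machine learning is"},
    {"texts", {"doc A", "doc B"}},
    {"return_text", true},
};
```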
@ -4320,7 +4297,12 @@ int main(int argc, char ** argv) {
} }
// write JSON response // write JSON response
json root = format_response_rerank(body, responses); json root = format_response_rerank(
body,
responses,
is_tei_format,
documents);
res_ok(res, root); res_ok(res, root);
}; };
@ -4482,8 +4464,8 @@ int main(int argc, char ** argv) {
// print sample chat example to make it clear which template is used // print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
ctx_server.chat_templates.template_default->source().c_str(), common_chat_templates_source(ctx_server.chat_templates.get()),
common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) { ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
ctx_server.process_single_task(task); ctx_server.process_single_task(task);

View File

@ -48,7 +48,7 @@ DEBUG=1 ./tests.sh -s -v -x
To run all the tests in a file: To run all the tests in a file:
```shell ```shell
./tests.sh unit/test_chat_completion.py.py -v -x ./tests.sh unit/test_chat_completion.py -v -x
``` ```
To run a single test: To run a single test:

View File

@ -21,6 +21,8 @@ def create_server():
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
] ]
) )
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@ -44,7 +46,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
assert res.body["usage"]["completion_tokens"] == n_predicted assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0] choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"] assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"]) assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
assert choice["finish_reason"] == finish_reason assert choice["finish_reason"] == finish_reason
@ -169,6 +171,47 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int
assert "error" in res.body assert "error" in res.body
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
(False, {"const": "42"}, 6, "\"42\""),
(True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"json_schema": json_schema,
})
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "user", "content": "Does not matter what I say, does it?"},
],
"grammar": grammar,
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
@pytest.mark.parametrize("messages", [ @pytest.mark.parametrize("messages", [
None, None,
"string", "string",

View File

@ -10,17 +10,20 @@ def create_server():
server = ServerPreset.jina_reranker_tiny() server = ServerPreset.jina_reranker_tiny()
TEST_DOCUMENTS = [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
def test_rerank(): def test_rerank():
global server global server
server.start() server.start()
res = server.make_request("POST", "/rerank", data={ res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is", "query": "Machine learning is",
"documents": [ "documents": TEST_DOCUMENTS,
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
}) })
assert res.status_code == 200 assert res.status_code == 200
assert len(res.body["results"]) == 4 assert len(res.body["results"]) == 4
@ -38,6 +41,29 @@ def test_rerank():
assert least_relevant["index"] == 3 assert least_relevant["index"] == 3
def test_rerank_tei_format():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"texts": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body) == 4
most_relevant = res.body[0]
least_relevant = res.body[0]
for doc in res.body:
if doc["score"] > most_relevant["score"]:
most_relevant = doc
if doc["score"] < least_relevant["score"]:
least_relevant = doc
assert most_relevant["score"] > least_relevant["score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3
@pytest.mark.parametrize("documents", [ @pytest.mark.parametrize("documents", [
[], [],
None, None,

View File

@ -356,12 +356,12 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("^> 0.56$", 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
("^The y-coordinate [\\s\\S]*?\\*\\*0.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
("[\\s\\S]*?\\*\\*0\\.5\\*\\*", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
]) ])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server global server
@ -401,7 +401,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
{ {
"role": "tool", "role": "tool",
"name": "calculate", "name": "calculate",
"content": 0.55644242476, "content": "0.55644242476",
"tool_call_id": "call_6789" "tool_call_id": "call_6789"
} }
], ],
@ -444,7 +444,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
(128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (128, None, "^The sum of 102 and 7 is 109.*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'deepseek', "To find the sum of.*", "I need to calculate the sum of 102 and 7.*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "<think>\n?I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'none', "^I need[\\s\\S]*?</think>\n?To find.*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), (1024, 'deepseek', "To find the sum of.*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
]) ])

View File

@ -12,9 +12,7 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
#include "minja.hpp" #include "chat.h"
#include "chat.hpp"
#include "chat-template.hpp"
#include <random> #include <random>
#include <sstream> #include <sstream>
@ -347,41 +345,6 @@ static llama_tokens format_infill(
return embd_inp; return embd_inp;
} }
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const common_chat_template & tmpl, const std::vector<json> & messages) {
std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
const auto & curr_msg = messages[i];
std::string role = json_value(curr_msg, "role", std::string(""));
std::string content;
if (curr_msg.contains("content")) {
if (curr_msg["content"].is_string()) {
content = curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) {
for (const auto & part : curr_msg["content"]) {
if (part.contains("text")) {
content += "\n" + part["text"].get<std::string>();
}
}
} else {
throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
} else {
throw std::runtime_error("Missing 'content' (ref: https://github.com/ggml-org/llama.cpp/issues/8367)");
}
chat.push_back({role, content, /* tool_calls= */ {}});
}
const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false);
LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat;
}
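For context, both of the "content" shapes handled by the removed helper remain valid inputs, judging by the array-of-parts cases added to test_chat_completion.py; the values below are illustrative only:

```cpp
// Illustrative only: plain-string content vs. array-of-text-parts content
// (json = nlohmann::ordered_json, as elsewhere in this file).
static const json content_as_string = "What is the best book";
static const json content_as_parts  = json::array({
    { {"type", "text"}, {"text", "What is"} },
    { {"type", "text"}, {"text", "the best book"} },
});
```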
// //
// base64 utils (TODO: move to common in the future) // base64 utils (TODO: move to common in the future)
// //
@ -579,12 +542,9 @@ static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */ const json & body, /* openai api json semantics */
bool use_jinja, bool use_jinja,
common_reasoning_format reasoning_format, common_reasoning_format reasoning_format,
const common_chat_templates & chat_templates) const struct common_chat_templates * tmpls)
{ {
json llama_params; json llama_params;
const auto & tmpl = body.contains("tools") && chat_templates.template_tool_use
? *chat_templates.template_tool_use
: *chat_templates.template_default;
auto tools = json_value(body, "tools", json()); auto tools = json_value(body, "tools", json());
auto stream = json_value(body, "stream", false); auto stream = json_value(body, "stream", false);
@ -610,62 +570,58 @@ static json oaicompat_completion_params_parse(
llama_params["stop"] = json_value(body, "stop", json::array()); llama_params["stop"] = json_value(body, "stop", json::array());
} }
auto json_schema = json_value(body, "json_schema", json());
auto grammar = json_value(body, "grammar", std::string());
if (!json_schema.is_null() && !grammar.empty()) {
throw std::runtime_error("Cannot use both json_schema and grammar");
}
// Handle "response_format" field // Handle "response_format" field
if (body.contains("response_format")) { if (body.contains("response_format")) {
json response_format = json_value(body, "response_format", json::object()); json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string()); std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") { if (response_type == "json_object") {
llama_params["json_schema"] = json_value(response_format, "schema", json::object()); json_schema = json_value(response_format, "schema", json::object());
} else if (response_type == "json_schema") { } else if (response_type == "json_schema") {
json json_schema = json_value(response_format, "json_schema", json::object()); json json_schema = json_value(response_format, "json_schema", json::object());
llama_params["json_schema"] = json_value(json_schema, "schema", json::object()); json_schema = json_value(json_schema, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") { } else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
} }
} }
// Apply chat template to the list of messages common_chat_templates_inputs inputs;
if (use_jinja) { inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
auto tool_choice = json_value(body, "tool_choice", std::string("auto")); inputs.tools = common_chat_tools_parse_oaicompat(tools);
if (tool_choice != "none" && tool_choice != "auto" && tool_choice != "required") { inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
throw std::runtime_error("Invalid tool_choice: " + tool_choice); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
} inputs.grammar = grammar;
if (tool_choice != "none" && llama_params.contains("grammar")) { inputs.add_generation_prompt = true;
throw std::runtime_error("Cannot use custom grammar constraints with tools."); inputs.use_jinja = use_jinja;
} inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
common_chat_inputs inputs; inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE; if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
inputs.messages = body.at("messages"); throw std::runtime_error("Cannot use custom grammar constraints with tools.");
inputs.tools = tools; }
inputs.tool_choice = tool_choice;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) {
LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n");
inputs.parallel_tool_calls = false;
}
inputs.stream = stream;
// TODO: support mixing schema w/ tools beyond generic format.
inputs.json_schema = json_value(llama_params, "json_schema", json());
auto chat_params = common_chat_params_init(tmpl, inputs);
llama_params["chat_format"] = static_cast<int>(chat_params.format); // Apply chat template to the list of messages
llama_params["prompt"] = chat_params.prompt; auto chat_params = common_chat_templates_apply(tmpls, inputs);
llama_params["grammar"] = chat_params.grammar;
llama_params["grammar_lazy"] = chat_params.grammar_lazy; llama_params["chat_format"] = static_cast<int>(chat_params.format);
auto grammar_triggers = json::array(); llama_params["prompt"] = chat_params.prompt;
for (const auto & trigger : chat_params.grammar_triggers) { llama_params["grammar"] = chat_params.grammar;
grammar_triggers.push_back({ llama_params["grammar_lazy"] = chat_params.grammar_lazy;
{"word", trigger.word}, auto grammar_triggers = json::array();
{"at_start", trigger.at_start}, for (const auto & trigger : chat_params.grammar_triggers) {
}); grammar_triggers.push_back({
} {"word", trigger.word},
llama_params["grammar_triggers"] = grammar_triggers; {"at_start", trigger.at_start},
llama_params["preserved_tokens"] = chat_params.preserved_tokens; });
for (const auto & stop : chat_params.additional_stops) { }
llama_params["stop"].push_back(stop); llama_params["grammar_triggers"] = grammar_triggers;
} llama_params["preserved_tokens"] = chat_params.preserved_tokens;
} else { for (const auto & stop : chat_params.additional_stops) {
llama_params["prompt"] = format_chat(tmpl, body.at("messages")); llama_params["stop"].push_back(stop);
} }
// Handle "n" field // Handle "n" field
@ -737,29 +693,51 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
return res; return res;
} }
static json format_response_rerank(const json & request, const json & ranks) { static json format_response_rerank(
json data = json::array(); const json & request,
int32_t n_tokens = 0; const json & ranks,
int i = 0; bool is_tei_format,
for (const auto & rank : ranks) { std::vector<std::string> & texts) {
data.push_back(json{ json res;
{"index", i++}, if (is_tei_format) {
{"relevance_score", json_value(rank, "score", 0.0)}, // TEI response format
}); res = json::array();
bool return_text = json_value(request, "return_text", false);
for (const auto & rank : ranks) {
int index = json_value(rank, "index", 0);
json elem = json{
{"index", index},
{"score", json_value(rank, "score", 0.0)},
};
if (return_text) {
elem["text"] = std::move(texts[index]);
}
res.push_back(elem);
}
} else {
// Jina response format
json results = json::array();
int32_t n_tokens = 0;
for (const auto & rank : ranks) {
results.push_back(json{
{"index", json_value(rank, "index", 0)},
{"relevance_score", json_value(rank, "score", 0.0)},
});
n_tokens += json_value(rank, "tokens_evaluated", 0); n_tokens += json_value(rank, "tokens_evaluated", 0);
}
res = json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json{
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", results}
};
} }
json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json {
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", data}
};
return res; return res;
} }
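Roughly, the two response bodies built above look as follows (a sketch with illustrative scores; the "text" field only appears when return_text was requested):

```cpp
// TEI format: a bare JSON array, one entry per input document.
//   [ { "index": 0, "score": 0.98, "text": "doc A" },
//     { "index": 1, "score": 0.12, "text": "doc B" } ]
//
// Jina format: an object with usage accounting and "relevance_score" entries.
//   { "model": "...", "object": "list",
//     "usage": { "prompt_tokens": 42, "total_tokens": 42 },
//     "results": [ { "index": 0, "relevance_score": 0.98 },
//                  { "index": 1, "relevance_score": 0.12 } ] }
```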

View File

@ -159,6 +159,35 @@ export default function ChatMessage({
</div> </div>
</details> </details>
)} )}
{msg.extra && msg.extra.length > 0 && (
<details
className={classNames({
'collapse collapse-arrow mb-4 bg-base-200': true,
'bg-opacity-10': msg.role !== 'assistant',
})}
>
<summary className="collapse-title">
Extra content
</summary>
<div className="collapse-content">
{msg.extra.map(
(extra, i) =>
extra.type === 'textFile' ? (
<div key={extra.name}>
<b>{extra.name}</b>
<pre>{extra.content}</pre>
</div>
) : extra.type === 'context' ? (
<div key={i}>
<pre>{extra.content}</pre>
</div>
) : null // TODO: support other extra types
)}
</div>
</details>
)}
<MarkdownDisplay <MarkdownDisplay
content={content} content={content}
isGenerating={isPending} isGenerating={isPending}

View File

@ -1,10 +1,11 @@
import { useEffect, useMemo, useState } from 'react'; import { useEffect, useMemo, useRef, useState } from 'react';
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context'; import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
import ChatMessage from './ChatMessage'; import ChatMessage from './ChatMessage';
import { CanvasType, Message, PendingMessage } from '../utils/types'; import { CanvasType, Message, PendingMessage } from '../utils/types';
import { classNames, throttle } from '../utils/misc'; import { classNames, throttle } from '../utils/misc';
import CanvasPyInterpreter from './CanvasPyInterpreter'; import CanvasPyInterpreter from './CanvasPyInterpreter';
import StorageUtils from '../utils/storage'; import StorageUtils from '../utils/storage';
import { useVSCodeContext } from '../utils/llama-vscode';
/** /**
* A message display is a message node with additional information for rendering. * A message display is a message node with additional information for rendering.
@ -81,6 +82,14 @@ export default function ChatScreen() {
replaceMessageAndGenerate, replaceMessageAndGenerate,
} = useAppContext(); } = useAppContext();
const [inputMsg, setInputMsg] = useState(''); const [inputMsg, setInputMsg] = useState('');
const inputRef = useRef<HTMLTextAreaElement>(null);
const { extraContext, clearExtraContext } = useVSCodeContext(
inputRef,
setInputMsg
);
// TODO: improve this when we have "upload file" feature
const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;
// keep track of leaf node for rendering // keep track of leaf node for rendering
const [currNodeId, setCurrNodeId] = useState<number>(-1); const [currNodeId, setCurrNodeId] = useState<number>(-1);
@ -115,10 +124,20 @@ export default function ChatScreen() {
setCurrNodeId(-1); setCurrNodeId(-1);
// get the last message node // get the last message node
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null; const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
if (!(await sendMessage(currConvId, lastMsgNodeId, inputMsg, onChunk))) { if (
!(await sendMessage(
currConvId,
lastMsgNodeId,
inputMsg,
currExtra,
onChunk
))
) {
// restore the input message if failed // restore the input message if failed
setInputMsg(lastInpMsg); setInputMsg(lastInpMsg);
} }
// OK
clearExtraContext();
}; };
const handleEditMessage = async (msg: Message, content: string) => { const handleEditMessage = async (msg: Message, content: string) => {
@ -129,6 +148,7 @@ export default function ChatScreen() {
viewingChat.conv.id, viewingChat.conv.id,
msg.parent, msg.parent,
content, content,
msg.extra,
onChunk onChunk
); );
setCurrNodeId(-1); setCurrNodeId(-1);
@ -143,6 +163,7 @@ export default function ChatScreen() {
viewingChat.conv.id, viewingChat.conv.id,
msg.parent, msg.parent,
null, null,
msg.extra,
onChunk onChunk
); );
setCurrNodeId(-1); setCurrNodeId(-1);
@ -203,6 +224,7 @@ export default function ChatScreen() {
<textarea <textarea
className="textarea textarea-bordered w-full" className="textarea textarea-bordered w-full"
placeholder="Type a message (Shift+Enter to add a new line)" placeholder="Type a message (Shift+Enter to add a new line)"
ref={inputRef}
value={inputMsg} value={inputMsg}
onChange={(e) => setInputMsg(e.target.value)} onChange={(e) => setInputMsg(e.target.value)}
onKeyDown={(e) => { onKeyDown={(e) => {

View File

@ -25,6 +25,7 @@ interface AppContextValue {
convId: string | null, convId: string | null,
leafNodeId: Message['id'] | null, leafNodeId: Message['id'] | null,
content: string, content: string,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk onChunk: CallbackGeneratedChunk
) => Promise<boolean>; ) => Promise<boolean>;
stopGenerating: (convId: string) => void; stopGenerating: (convId: string) => void;
@ -32,6 +33,7 @@ interface AppContextValue {
convId: string, convId: string,
parentNodeId: Message['id'], // the parent node of the message to be replaced parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null, content: string | null,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk onChunk: CallbackGeneratedChunk
) => Promise<void>; ) => Promise<void>;
@ -274,6 +276,7 @@ export const AppContextProvider = ({
convId: string | null, convId: string | null,
leafNodeId: Message['id'] | null, leafNodeId: Message['id'] | null,
content: string, content: string,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk onChunk: CallbackGeneratedChunk
): Promise<boolean> => { ): Promise<boolean> => {
if (isGenerating(convId ?? '') || content.trim().length === 0) return false; if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
@ -298,6 +301,7 @@ export const AppContextProvider = ({
convId, convId,
role: 'user', role: 'user',
content, content,
extra,
parent: leafNodeId, parent: leafNodeId,
children: [], children: [],
}, },
@ -324,6 +328,7 @@ export const AppContextProvider = ({
convId: string, convId: string,
parentNodeId: Message['id'], // the parent node of the message to be replaced parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null, content: string | null,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk onChunk: CallbackGeneratedChunk
) => { ) => {
if (isGenerating(convId)) return; if (isGenerating(convId)) return;
@ -339,6 +344,7 @@ export const AppContextProvider = ({
convId, convId,
role: 'user', role: 'user',
content, content,
extra,
parent: parentNodeId, parent: parentNodeId,
children: [], children: [],
}, },
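
Editorial note, not part of the diff: a usage sketch of the updated context API with the new extra argument. SendMessageFn and askWithContext are invented names for illustration, the import paths assume the same layout as ChatScreen.tsx, and the sample content is hypothetical.

import { CallbackGeneratedChunk } from '../utils/app.context';
import { Message } from '../utils/types';

// Mirrors the sendMessage signature declared above (invented alias for clarity).
type SendMessageFn = (
  convId: string | null,
  leafNodeId: Message['id'] | null,
  content: string,
  extra: Message['extra'],
  onChunk: CallbackGeneratedChunk
) => Promise<boolean>;

// Hypothetical caller: send a user message together with one piece of extra context.
export async function askWithContext(
  sendMessage: SendMessageFn,
  convId: string | null,
  leafNodeId: Message['id'] | null,
  onChunk: CallbackGeneratedChunk
) {
  const extra: Message['extra'] = [
    { type: 'context', content: 'def test()\n  return 123' },
  ];
  const ok = await sendMessage(convId, leafNodeId, 'Spot the syntax error', extra, onChunk);
  if (!ok) {
    // the caller can restore its input box here, as ChatScreen does on failure
  }
}
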

View File

@ -0,0 +1,62 @@
import { useEffect, useState } from 'react';
import { MessageExtraContext } from './types';
// Extra context when using llama.cpp WebUI from llama-vscode, inside an iframe
// Ref: https://github.com/ggml-org/llama.cpp/pull/11940
interface SetTextEvData {
text: string;
context: string;
}
/**
* To test it:
* window.postMessage({ command: 'setText', text: 'Spot the syntax error', context: 'def test()\n return 123' }, '*');
*/
export const useVSCodeContext = (
inputRef: React.RefObject<HTMLTextAreaElement>,
setInputMsg: (text: string) => void
) => {
const [extraContext, setExtraContext] = useState<MessageExtraContext | null>(
null
);
// Accept setText message from a parent window and set inputMsg and extraContext
useEffect(() => {
const handleMessage = (event: MessageEvent) => {
if (event.data?.command === 'setText') {
const data: SetTextEvData = event.data;
setInputMsg(data?.text);
if (data?.context && data.context.length > 0) {
setExtraContext({
type: 'context',
content: data.context,
});
}
inputRef.current?.focus();
}
};
window.addEventListener('message', handleMessage);
return () => window.removeEventListener('message', handleMessage);
}, []);
// Add a keydown listener that sends the "escapePressed" message to the parent window
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
window.parent.postMessage({ command: 'escapePressed' }, '*');
}
};
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, []);
return {
extraContext,
// call once the user message is sent, to clear the extra context
clearExtraContext: () => setExtraContext(null),
};
};
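
Editorial note, not part of the diff: a sketch of the other side of this handshake, i.e. what a host page embedding the WebUI in an iframe (such as llama-vscode) might run. Only the 'setText' and 'escapePressed' message shapes come from the hook above; the iframe id, the '*' target origin and the focus handling are assumptions.

// Host-side sketch (assumed setup): push text + context into the embedded WebUI
// and react when the user presses Escape inside it.
const frame = document.getElementById('llama-webui') as HTMLIFrameElement | null;

function sendToWebUI(text: string, context: string) {
  // Matches the SetTextEvData shape expected by useVSCodeContext.
  frame?.contentWindow?.postMessage({ command: 'setText', text, context }, '*');
}

window.addEventListener('message', (event: MessageEvent) => {
  if (event.data?.command === 'escapePressed') {
    // e.g. hand focus back to the editor; what exactly happens is up to the host
    console.log('WebUI requested focus back');
  }
});

sendToWebUI('Spot the syntax error', 'def test()\n  return 123');
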

View File

@ -53,12 +53,23 @@ export const copyStr = (textToCopy: string) => {
/** /**
* filter out redundant fields upon sending to API * filter out redundant fields upon sending to API
 * also flatten the extra context (if any) into plain text, prepended to the message content
*/ */
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) { export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
return messages.map((msg) => { return messages.map((msg) => {
let newContent = '';
for (const extra of msg.extra ?? []) {
if (extra.type === 'context') {
newContent += `${extra.content}\n\n`;
}
}
newContent += msg.content;
return { return {
role: msg.role, role: msg.role,
content: msg.content, content: newContent,
}; };
}) as APIMessage[]; }) as APIMessage[];
} }
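
Editorial note, not part of the diff: a worked example of the flattening above, using made-up values and ignoring the node-tracking fields (id, parent, children) for brevity. Each 'context' extra is prepended to the message content, separated by a blank line, before the message is sent to the API.

// Illustrative input (sample values only):
const msg = {
  role: 'user' as const,
  content: 'Spot the syntax error',
  extra: [{ type: 'context' as const, content: 'def test()\n  return 123' }],
};

// For a message with these role/content/extra fields, normalizeMsgsForAPI yields:
// {
//   role: 'user',
//   content: 'def test()\n  return 123\n\nSpot the syntax error'
// }
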

View File

@ -42,11 +42,25 @@ export interface Message {
role: 'user' | 'assistant' | 'system'; role: 'user' | 'assistant' | 'system';
content: string; content: string;
timings?: TimingReport; timings?: TimingReport;
extra?: MessageExtra[];
// node based system for branching // node based system for branching
parent: Message['id']; parent: Message['id'];
children: Message['id'][]; children: Message['id'][];
} }
type MessageExtra = MessageExtraTextFile | MessageExtraContext; // TODO: will add more in the future
export interface MessageExtraTextFile {
type: 'textFile';
name: string;
content: string;
}
export interface MessageExtraContext {
type: 'context';
content: string;
}
export type APIMessage = Pick<Message, 'role' | 'content'>; export type APIMessage = Pick<Message, 'role' | 'content'>;
export interface Conversation { export interface Conversation {

View File

@ -5112,7 +5112,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
const int nb = n / QK_K; const int nb = n / QK_K;
#ifdef __ARM_NEON #if defined(__ARM_FEATURE_SVE)
uint32_t utmp[4];
const int8_t m32 = 32;
const int vector_length = svcntb()*8;
const svuint8_t m3b_sv = svdup_n_u8(0x3);
const svint32_t vzero_sv = svdup_n_s32(0);
const svuint8_t m0_sv = svdup_n_u8(1);
const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
svbool_t pred_s32 = svnot_b_z (svptrue_b32(), svptrue_pat_b32(SV_VL4));
float sum = 0;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
const uint8_t * restrict q3_sv = x[i].qs;
const uint8_t * restrict qh_sv = x[i].hmask;
const int8_t * restrict q8_sv = y[i].qs;
// Set up scales
uint32_t * aux = &x[i].scales;
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
int8_t * scale = (int8_t *)utmp;
for (int j = 0; j < 16; ++j) scale[j] -= m32;
switch (vector_length) {
case 128:
{
svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
svuint8_t q3h_sv;
svint32_t sumi1_1 = svdup_n_s32(0);
svint8_t q3bytes_sv;
for (int j = 0; j < QK_K/128; ++j) {
const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
scale += 4;
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
if (j == 0) {
qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
}
scale += 4;
}
sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
} break;
case 256:
case 512:
{
svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
svuint8_t q3h_sv;
svint32_t sumi1_1 = svdup_n_s32(0);
svint8_t q3bytes_sv;
for (int j = 0; j < QK_K/128; ++j) {
const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
scale += 4;
q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
if (j == 0) {
qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
}
scale += 4;
}
sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
} break;
default:
assert(false && "Unsupported vector length");
break;
}
}
*s = sum;
#elif __ARM_NEON
uint32_t aux[3]; uint32_t aux[3];
uint32_t utmp[4]; uint32_t utmp[4];

View File

@ -21,7 +21,7 @@ def get_chat_template(model_id, variant=None):
# Use huggingface_hub library if available. # Use huggingface_hub library if available.
# Allows access to gated models if the user has access and ran `huggingface-cli login`. # Allows access to gated models if the user has access and ran `huggingface-cli login`.
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f: with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json"), encoding="utf-8") as f:
config_str = f.read() config_str = f.read()
except ImportError: except ImportError:
import requests import requests

View File

@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence(
size_t last_sym_start = rule.size(); size_t last_sym_start = rule.size();
const char * pos = src; const char * pos = src;
auto handle_repetitions = [&](int min_times, int max_times) { auto handle_repetitions = [&](int min_times, int max_times) {
if (last_sym_start == rule.size()) { if (last_sym_start == rule.size()) {
throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
} }
// apply transformation to previous symbol (last_sym_start to end) according to // apply transformation to previous symbol (last_sym_start to end) according to
// the following rewrite rules: // the following rewrite rules:
// S{m,n} --> S S S (m times) S'(n-m) // S{m,n} --> S S S (m times) S'(n-m)
// S'(x) ::= S S'(x-1) | // S'(x) ::= S S'(x-1) |
// (... n-m definitions of these S' rules ...) // (... n-m definitions of these S' rules ...)
// S'(1) ::= S | // S'(1) ::= S |
// S{m,} --> S S S (m times) S' // S{m,} --> S S S (m times) S'
// S' ::= S S' | // S' ::= S S' |
// S* --> S{0,} // S* --> S{0,}
// --> S' ::= S S' | // --> S' ::= S S' |
// S+ --> S{1,} // S+ --> S{1,}
// --> S S' // --> S S'
// S' ::= S S' | // S' ::= S S' |
// S? --> S{0,1} // S? --> S{0,1}
// --> S' // --> S'
// S' ::= S | // S' ::= S |
llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end());
if (min_times == 0) { if (min_times == 0) {
rule.resize(last_sym_start); rule.resize(last_sym_start);
} else { } else {
// Repeat the previous elements (min_times - 1) times // Repeat the previous elements (min_times - 1) times
for (int i = 1; i < min_times; i++) { for (int i = 1; i < min_times; i++) {
rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
}
}
uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
int max_times = -1;
if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
handle_repetitions(min_times, max_times);
} else {
break;
} }
} }
return pos;
uint32_t last_rec_rule_id = 0;
auto n_opt = max_times < 0 ? 1 : max_times - min_times;
llama_grammar_rule rec_rule(prev_rule);
for (int i = 0; i < n_opt; i++) {
rec_rule.resize(prev_rule.size());
uint32_t rec_rule_id = generate_symbol_id( rule_name);
if (i > 0 || max_times < 0) {
rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
}
rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
rec_rule.push_back({LLAMA_GRETYPE_END, 0});
add_rule( rec_rule_id, rec_rule);
last_rec_rule_id = rec_rule_id;
}
if (n_opt > 0) {
rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id});
}
};
while (*pos) {
if (*pos == '"') { // literal string
pos++;
last_sym_start = rule.size();
while (*pos != '"') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '[') { // char range(s)
pos++;
enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
if (*pos == '^') {
pos++;
start_type = LLAMA_GRETYPE_CHAR_NOT;
}
last_sym_start = rule.size();
while (*pos != ']') {
if (!*pos) {
throw std::runtime_error("unexpected end of input");
}
auto char_pair = parse_char(pos);
pos = char_pair.second;
enum llama_gretype type = last_sym_start < rule.size()
? LLAMA_GRETYPE_CHAR_ALT
: start_type;
rule.push_back({type, char_pair.first});
if (pos[0] == '-' && pos[1] != ']') {
if (!pos[1]) {
throw std::runtime_error("unexpected end of input");
}
auto endchar_pair = parse_char(pos + 1);
pos = endchar_pair.second;
rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
}
}
pos = parse_space(pos + 1, is_nested);
} else if (is_word_char(*pos)) { // rule reference
const char * name_end = parse_name(pos);
uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
pos = parse_space(name_end, is_nested);
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
} else if (*pos == '(') { // grouping
// parse nested alternates into synthesized rule
pos = parse_space(pos + 1, true);
uint32_t sub_rule_id = generate_symbol_id(rule_name);
pos = parse_alternates(pos, rule_name, sub_rule_id, true);
last_sym_start = rule.size();
// output reference to synthesized rule
rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
if (*pos != ')') {
throw std::runtime_error(std::string("expecting ')' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '.') { // any char
last_sym_start = rule.size();
rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0});
pos = parse_space(pos + 1, is_nested);
} else if (*pos == '*') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, -1);
} else if (*pos == '+') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(1, -1);
} else if (*pos == '?') {
pos = parse_space(pos + 1, is_nested);
handle_repetitions(0, 1);
} else if (*pos == '{') {
pos = parse_space(pos + 1, is_nested);
if (!is_digit_char(*pos)) {
throw std::runtime_error(std::string("expecting an int at ") + pos);
}
const char * int_end = parse_int(pos);
int min_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
int max_times = -1;
if (*pos == '}') {
max_times = min_times;
pos = parse_space(pos + 1, is_nested);
} else if (*pos == ',') {
pos = parse_space(pos + 1, is_nested);
if (is_digit_char(*pos)) {
const char * int_end = parse_int(pos);
max_times = std::stoul(std::string(pos, int_end - pos));
pos = parse_space(int_end, is_nested);
}
if (*pos != '}') {
throw std::runtime_error(std::string("expecting '}' at ") + pos);
}
pos = parse_space(pos + 1, is_nested);
} else {
throw std::runtime_error(std::string("expecting ',' at ") + pos);
}
handle_repetitions(min_times, max_times);
} else {
break;
}
} }
return pos;
}
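
Editorial aside, not part of the diff: a compact TypeScript transcription of the repetition rewrite documented in the handle_repetitions comment above, with simplified rule structures. Grammar, Rule and expandRepetition are invented names, and an empty alternative stands for the trailing '|' in the rules above.

// Simplified model: a rule is a sequence of symbol names, and a grammar maps a
// rule name to its alternatives (an empty sequence is the empty alternative).
type Rule = string[];
type Grammar = Map<string, Rule[]>;

let nextId = 0;

// Expand sym{minTimes,maxTimes} (maxTimes === -1 meaning "no upper bound") into
// minTimes copies of sym followed by a chain of generated optional rules,
// following the S{m,n} / S{m,} rewrite rules documented above.
function expandRepetition(
  grammar: Grammar,
  sym: string,
  minTimes: number,
  maxTimes: number
): Rule {
  const seq: Rule = [];
  for (let k = 0; k < minTimes; k++) seq.push(sym);

  const nOpt = maxTimes < 0 ? 1 : maxTimes - minTimes;
  let lastRecId = '';
  for (let i = 0; i < nOpt; i++) {
    const recId = `${sym}_rep_${nextId++}`;
    const alt: Rule = [sym];
    if (i > 0 || maxTimes < 0) {
      // unbounded: the generated rule refers to itself; bounded: it chains to the previous one
      alt.push(maxTimes < 0 ? recId : lastRecId);
    }
    grammar.set(recId, [alt, []]); // "sym (ref)? | <empty>"
    lastRecId = recId;
  }
  if (nOpt > 0) {
    seq.push(lastRecId);
  }
  return seq;
}

// expandRepetition(new Map(), 'S', 2, 4) returns ['S', 'S', 'S_rep_1'] and defines
//   S_rep_0 ::= S |
//   S_rep_1 ::= S S_rep_0 |
// which matches the S{2,4} expansion described in the comment above.
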
const char * llama_grammar_parser::parse_rule(const char * src) { const char * llama_grammar_parser::parse_rule(const char * src) {
const char * name_end = parse_name(src); const char * name_end = parse_name(src);
const char * pos = parse_space(name_end, false); const char * pos = parse_space(name_end, false);
size_t name_len = name_end - src; size_t name_len = name_end - src;
uint32_t rule_id = get_symbol_id(src, name_len); uint32_t rule_id = get_symbol_id(src, name_len);
const std::string name(src, name_len); const std::string name(src, name_len);
if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
throw std::runtime_error(std::string("expecting ::= at ") + pos); throw std::runtime_error(std::string("expecting ::= at ") + pos);
}
pos = parse_space(pos + 3, true);
pos = parse_alternates(pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
} }
pos = parse_space(pos + 3, true);
pos = parse_alternates(pos, name, rule_id, false);
if (*pos == '\r') {
pos += pos[1] == '\n' ? 2 : 1;
} else if (*pos == '\n') {
pos++;
} else if (*pos) {
throw std::runtime_error(std::string("expecting newline or end at ") + pos);
}
return parse_space(pos, true);
}
bool llama_grammar_parser::parse(const char * src) { bool llama_grammar_parser::parse(const char * src) {
try { try {

View File

@ -1,13 +1,14 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <regex>
#undef NDEBUG #undef NDEBUG
#include <cassert> #include <cassert>
#include "llama.h" #include "llama.h"
#include "common.h" #include "common.h"
#include "chat-template.hpp" #include "chat.h"
static std::string normalize_newlines(const std::string & s) { static std::string normalize_newlines(const std::string & s) {
#ifdef _WIN32 #ifdef _WIN32
@ -18,6 +19,13 @@ static std::string normalize_newlines(const std::string & s) {
#endif #endif
} }
static common_chat_msg simple_msg(const std::string & role, const std::string & content) {
common_chat_msg msg;
msg.role = role;
msg.content = content;
return msg;
}
int main(void) { int main(void) {
std::vector<llama_chat_message> conversation { std::vector<llama_chat_message> conversation {
{"system", "You are a helpful assistant"}, {"system", "You are a helpful assistant"},
@ -50,7 +58,7 @@ int main(void) {
/* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
/* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]", /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there</s>[INST] Who are you [/INST] I am an assistant </s>[INST] Another question [/INST]",
/* .expected_output_jinja= */ "", /* .expected_output_jinja= */ "",
/* .bos_token= */ "", /* .bos_token= */ "<s>",
/* .eos_token= */ "</s>", /* .eos_token= */ "</s>",
}, },
{ {
@ -72,8 +80,8 @@ int main(void) {
{ {
/* .name= */ "mlabonne/AlphaMonarch-7B", /* .name= */ "mlabonne/AlphaMonarch-7B",
/* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}",
/* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n", /* .expected_output= */ "system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n",
/* .expected_output_jinja= */ "<s>system\nYou are a helpful assistant</s>\n<s>user\nHello</s>\n<s>assistant\nHi there</s>\n<s>user\nWho are you</s>\n<s>assistant\n I am an assistant </s>\n<s>user\nAnother question</s>\n<s>assistant\n", /* .expected_output_jinja= */ "",
/* .bos_token= */ "<s>", /* .bos_token= */ "<s>",
/* .eos_token= */ "</s>", /* .eos_token= */ "</s>",
}, },
@ -87,7 +95,7 @@ int main(void) {
/* .name= */ "OrionStarAI/Orion-14B-Chat", /* .name= */ "OrionStarAI/Orion-14B-Chat",
/* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}",
/* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>", /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>",
/* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: </s>", /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: </s>Hi there</s>Human: Who are you\n\nAssistant: </s> I am an assistant </s>Human: Another question\n\nAssistant: ",
/* .bos_token= */ "", /* .bos_token= */ "",
/* .eos_token= */ "</s>", /* .eos_token= */ "</s>",
}, },
@ -304,12 +312,9 @@ int main(void) {
} }
} }
json messages = json::array(); std::vector<common_chat_msg> messages;
for (const auto & msg : conversation) { for (const auto & msg : conversation) {
messages.push_back({ messages.push_back(simple_msg(msg.role, msg.content));
{"role", msg.role},
{"content", msg.content},
});
} }
for (const auto & test_case : test_cases) { for (const auto & test_case : test_cases) {
if (!test_case.supported_with_jinja) { if (!test_case.supported_with_jinja) {
@ -317,8 +322,13 @@ int main(void) {
} }
printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str());
try { try {
minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token); auto tmpls = common_chat_templates_init(/* model= */ nullptr, test_case.template_str.c_str(), test_case.bos_token, test_case.eos_token);
auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt)); common_chat_templates_inputs inputs;
inputs.use_jinja = true;
inputs.messages = messages;
inputs.add_generation_prompt = add_generation_prompt;
auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt;
output = normalize_newlines(output);
auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja); auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja);
if (output != expected_output) { if (output != expected_output) {
printf("Expected:\n%s\n", expected_output.c_str()); printf("Expected:\n%s\n", expected_output.c_str());
@ -336,11 +346,11 @@ int main(void) {
// test llama_chat_format_single for system message // test llama_chat_format_single for system message
printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
std::vector<common_chat_msg> chat2; std::vector<common_chat_msg> chat2;
common_chat_msg sys_msg{"system", "You are a helpful assistant", {}}; auto sys_msg = simple_msg("system", "You are a helpful assistant");
auto fmt_sys = [&](std::string tmpl_str) { auto fmt_sys = [&](std::string tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", ""); auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str);
auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false);
printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n"); printf("-------------------------\n");
return output; return output;
@ -360,14 +370,14 @@ int main(void) {
// test llama_chat_format_single for user message // test llama_chat_format_single for user message
printf("\n\n=== llama_chat_format_single (user message) ===\n\n"); printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
chat2.push_back({"system", "You are a helpful assistant", {}}); chat2.push_back(simple_msg("system", "You are a helpful assistant"));
chat2.push_back({"user", "Hello", {}}); chat2.push_back(simple_msg("user", "Hello"));
chat2.push_back({"assistant", "I am assistant", {}}); chat2.push_back(simple_msg("assistant", "I am assistant"));
common_chat_msg new_msg{"user", "How are you", {}}; auto new_msg = simple_msg("user", "How are you");
auto fmt_single = [&](std::string tmpl_str) { auto fmt_single = [&](const std::string & tmpl_str) {
minja::chat_template tmpl(tmpl_str, "", ""); auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str());
auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false);
printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str());
printf("-------------------------\n"); printf("-------------------------\n");
return output; return output;

File diff suppressed because it is too large