tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900)

* tool-call refactoring: moved common_chat_* to chat.h, common_chat_templates_init return a unique_ptr to opaque type

* addressed clang-tidy lints in [test-]chat.*

* rm minja deps from util & common & move it to common/minja/

* add name & tool_call_id to common_chat_msg

* add common_chat_tool

* added json <-> tools, msgs conversions to chat.h

* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)

* fix deepseek r1 slow test (no longer <think> opening w/ new template)

* allow empty tools w/ auto + grammar

* fix & test server grammar & json_schema params w/ & w/o --jinja
This commit is contained in:
Olivier Chafik
2025-02-18 18:03:23 +00:00
committed by GitHub
parent 63ac128563
commit 63e489c025
18 changed files with 1385 additions and 993 deletions

View File

@ -24,7 +24,7 @@
#include <string>
#include <vector>
#include "chat-template.hpp"
#include "chat.h"
#include "common.h"
#include "json.hpp"
#include "linenoise.cpp/linenoise.h"
@ -557,7 +557,7 @@ class LlamaData {
llama_model_ptr model;
llama_sampler_ptr sampler;
llama_context_ptr context;
std::vector<llama_chat_message> messages;
std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
std::list<std::string> msg_strs;
std::vector<char> fmtted;
@ -834,44 +834,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
}
// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
if (use_jinja) {
json messages = json::array();
for (const auto & msg : llama_data.messages) {
messages.push_back({
{"role", msg.role},
{"content", msg.content},
});
}
try {
minja::chat_template_inputs tmpl_inputs;
tmpl_inputs.messages = messages;
tmpl_inputs.add_generation_prompt = append;
minja::chat_template_options tmpl_opts;
tmpl_opts.use_bos_token = false;
tmpl_opts.use_eos_token = false;
auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
} catch (const std::exception & e) {
printe("failed to render the chat template: %s\n", e.what());
return -1;
}
}
int result = llama_chat_apply_template(
tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result);
result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size());
static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
common_chat_templates_inputs inputs;
for (const auto & msg : llama_data.messages) {
common_chat_msg cmsg;
cmsg.role = msg.role;
cmsg.content = msg.content;
inputs.messages.push_back(cmsg);
}
inputs.add_generation_prompt = append;
inputs.use_jinja = use_jinja;
return result;
auto chat_params = common_chat_templates_apply(tmpls, inputs);
// TODO: use other params for tool calls.
auto result = chat_params.prompt;
llama_data.fmtted.resize(result.size() + 1);
memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
return result.size();
}
// Function to tokenize the prompt
@ -1015,8 +994,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
}
// Helper function to apply the chat template and handle errors
static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
if (new_len < 0) {
printe("failed to apply the chat template\n");
return -1;
@ -1078,8 +1057,7 @@ static int get_user_input(std::string & user_input, const std::string & user) {
static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
int prev_len = 0;
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
GGML_ASSERT(chat_templates.template_default);
auto chat_templates = common_chat_templates_init(llama_data.model.get(), "");
static const bool stdout_a_terminal = is_stdout_a_terminal();
while (true) {
// Get user input
@ -1090,7 +1068,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
add_message("user", user.empty() ? user_input : user, llama_data);
int new_len;
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
return 1;
}
@ -1105,7 +1083,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
}
add_message("assistant", response, llama_data);
if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
return 1;
}
}