tool-call: refactor common chat / tool-call api (+ tests / fixes) (#11900)

* tool-call refactoring: moved common_chat_* to chat.h; common_chat_templates_init now returns a unique_ptr to an opaque type

* addressed clang-tidy lints in [test-]chat.*

* rm minja deps from util & common & move it to common/minja/

* add name & tool_call_id to common_chat_msg

* add common_chat_tool

* added json <-> tools, msgs conversions to chat.h

* fix double bos/eos jinja avoidance hack (was preventing inner bos/eos tokens)

* fix deepseek r1 slow test (no longer <think> opening w/ new template)

* allow empty tools w/ auto + grammar

* fix & test server grammar & json_schema params w/ & w/o --jinja
Authored by Olivier Chafik on 2025-02-18 18:03:23 +00:00; committed by GitHub
parent 63ac128563
commit 63e489c025
18 changed files with 1385 additions and 993 deletions
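
The server diff below migrates from the old common_chat_templates value type to the new opaque-handle API. The following sketch of that surface is inferred purely from the call sites in this commit; the authoritative declarations live in common/chat.h and may differ in detail:

    // Sketch only: signatures inferred from the call sites in this diff,
    // not copied from common/chat.h.
    #include <memory>
    #include <string>

    struct llama_model;
    struct common_chat_templates;   // opaque to callers

    struct common_chat_templates_deleter {
        void operator()(common_chat_templates * tmpls);
    };
    typedef std::unique_ptr<common_chat_templates, common_chat_templates_deleter> common_chat_templates_ptr;

    // Load the model's built-in template(s), or the override when non-empty.
    common_chat_templates_ptr common_chat_templates_init(const llama_model * model, const std::string & chat_template_override);

    // Source of the named variant ("tool_use", ...), or of the default template
    // when name is omitted; nullptr if that variant does not exist.
    const char * common_chat_templates_source(const common_chat_templates * tmpls, const char * name = nullptr);

    // Render a small example conversation; throws if the template is unsupported.
    std::string common_chat_format_example(const common_chat_templates * tmpls, bool use_jinja);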


@@ -329,9 +329,6 @@ struct server_task {
         }
 
         // process "json_schema" and "grammar"
-        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
-            throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
-        }
         if (data.contains("json_schema") && !data.contains("grammar")) {
             try {
                 auto schema = json_value(data, "json_schema", json::object());
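
The mutual-exclusion throw removed above is assumed to have moved into the shared OpenAI-compat parameter parsing, which per the commit message is now tested both with and without --jinja. For context, a minimal sketch of what the surviving branch does with "json_schema" (json_schema_to_grammar is the existing helper in common/json-schema-to-grammar.h; the wrapping function here is hypothetical):

    #include "json-schema-to-grammar.h"   // existing helper in common/
    #include <nlohmann/json.hpp>
    using json = nlohmann::ordered_json;

    // Hypothetical helper illustrating the flow: a "json_schema" request field
    // is compiled into a GBNF grammar for constrained sampling.
    static std::string grammar_from_request(const json & data) {
        if (data.contains("json_schema") && !data.contains("grammar")) {
            return json_schema_to_grammar(data.at("json_schema"));
        }
        return data.value("grammar", std::string());
    }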
@@ -1807,7 +1804,7 @@ struct server_context {
     // Necessary similarity of prompt for slot selection
     float slot_prompt_similarity = 0.0f;
 
-    common_chat_templates chat_templates;
+    common_chat_templates_ptr chat_templates;
 
     ~server_context() {
         // Clear any sampling context
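
Because the member is now a smart pointer, call sites pass the raw handle with .get() and cleanup is automatic. A minimal usage sketch under the assumed API above:

    // Assumed usage, mirroring the call sites in this diff: the struct member
    // owns the templates and raw access goes through .get().
    common_chat_templates_ptr tmpls = common_chat_templates_init(model, /* override */ "");
    LOG_INF("chat template source: %s\n", common_chat_templates_source(tmpls.get()));
    // no explicit free: the unique_ptr's deleter releases the opaque object at end of scope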
@@ -1891,45 +1888,17 @@ struct server_context {
             llama_init_dft.context.reset();
         }
 
-        if (params_base.chat_template.empty() && !validate_builtin_chat_template(params.use_jinja)) {
+        chat_templates = common_chat_templates_init(model, params_base.chat_template);
+        try {
+            common_chat_format_example(chat_templates.get(), params.use_jinja);
+        } catch (const std::exception & e) {
             SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
-            chat_templates = common_chat_templates_from_model(model, "chatml");
-        } else {
-            chat_templates = common_chat_templates_from_model(model, params_base.chat_template);
+            chat_templates = common_chat_templates_init(model, "chatml");
         }
-        GGML_ASSERT(chat_templates.template_default.get() != nullptr);
 
         return true;
     }
 
-    bool validate_builtin_chat_template(bool use_jinja) const {
-        llama_chat_message chat[] = {{"user", "test"}};
-
-        if (use_jinja) {
-            auto templates = common_chat_templates_from_model(model, "");
-            common_chat_inputs inputs;
-            inputs.messages = json::array({{
-                {"role", "user"},
-                {"content", "test"},
-            }});
-            GGML_ASSERT(templates.template_default);
-            try {
-                common_chat_params_init(*templates.template_default, inputs);
-                if (templates.template_tool_use) {
-                    common_chat_params_init(*templates.template_tool_use, inputs);
-                }
-                return true;
-            } catch (const std::exception & e) {
-                SRV_ERR("failed to apply template: %s\n", e.what());
-                return false;
-            }
-        } else {
-            const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
-            const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
-            return chat_res > 0;
-        }
-    }
-
     void init() {
         const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
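
Note that validate_builtin_chat_template() is deleted outright: validation is now implicit, since common_chat_format_example() throws when a template cannot be applied. A sketch of the replacement pattern (function name hypothetical, API names as assumed above):

    // Validation by rendering: try an example conversation and report failure.
    static bool chat_template_is_usable(const common_chat_templates * tmpls, bool use_jinja) {
        try {
            common_chat_format_example(tmpls, use_jinja);   // throws on unsupported templates
            return true;
        } catch (const std::exception &) {
            return false;
        }
    }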
@@ -3822,13 +3791,15 @@ int main(int argc, char ** argv) {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params_base.n_parallel },
             { "model_path",                  ctx_server.params_base.model },
-            { "chat_template",               ctx_server.chat_templates.template_default->source() },
-            { "bos_token",                   ctx_server.chat_templates.template_default->bos_token() },
-            { "eos_token",                   ctx_server.chat_templates.template_default->eos_token() },
+            { "chat_template",               common_chat_templates_source(ctx_server.chat_templates.get()) },
+            { "bos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
+            { "eos_token",                   common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
             { "build_info",                  build_info },
         };
 
-        if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) {
-            data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source();
+        if (ctx_server.params_base.use_jinja) {
+            if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
+                data["chat_template_tool_use"] = tool_use_src;
+            }
         }
 
         res_ok(res, data);
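
Two behavioral details in the /props change above: bos_token and eos_token are now detokenized from the vocab (common_token_to_piece with special=true renders the literal token text) rather than read off the template object, and the "tool_use" template variant is looked up by name, returning null when the model does not ship one. Roughly (the wrapping function is illustrative; llama_vocab_bos/eos and common_token_to_piece are existing llama.cpp APIs):

    // Sketch: how /props now derives these fields.
    static void fill_props(json & data, llama_context * ctx, const llama_vocab * vocab, const common_chat_templates_ptr & tmpls) {
        data["bos_token"] = common_token_to_piece(ctx, llama_vocab_bos(vocab), /* special= */ true);
        data["eos_token"] = common_token_to_piece(ctx, llama_vocab_eos(vocab), /* special= */ true);
        if (const char * src = common_chat_templates_source(tmpls.get(), "tool_use")) {
            data["chat_template_tool_use"] = src;   // only set when a tool_use variant exists
        }
    }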
@@ -4063,7 +4034,7 @@ int main(int argc, char ** argv) {
         }
 
         auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
 
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
@@ -4076,7 +4047,7 @@ int main(int argc, char ** argv) {
     // same with handle_chat_completions, but without inference part
     const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
         auto body = json::parse(req.body);
-        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates);
+        json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
         res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
     };
 
@@ -4493,8 +4464,8 @@ int main(int argc, char ** argv) {
     // print sample chat example to make it clear which template is used
     LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        ctx_server.chat_templates.template_default->source().c_str(),
-        common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str());
+        common_chat_templates_source(ctx_server.chat_templates.get()),
+        common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
 
     ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
         ctx_server.process_single_task(task);