server: streaming of tool calls and thoughts when --jinja is on (#12379)

* add common_json w/ support for truncated json healing

* add common_chat_msg_diff

* partial common_chat_parse

* refactor parser w/ optionals

* server: wire chat diffs in stream mode

* fix trigger of thinking models (must happen after thoughts are closed)

* fix functionary v3.2 raw python!

* rename: common_chat_syntax (now contains format)

* rm common_regex.at_start

* don't return empty <think></think>

* accommodate yet another deepseek r1 distill fantasy syntax (`<|tool▁calls|>`)

* fix QwQ 32B tool call parsing after thoughts (hermes2)

* better logs for grammar triggers

* consume spaces after parse_json_tool_calls

* fix required tool calls w/ thinking models that have pre-opened thinking tags

* fix thinking model's initial trigger + test qwq's template

* run most test_tool_call tests in stream + non-stream modes

* make functionary v3.2 parsing more strict (differentiate first match from others)

* send final diff from server, to close off raw python arguments

* support partial content streaming in Generic mode

* tool-call: allow content prelude before hermes2 tool calls (for Qwen2.5)

* Update function-calling.md

* Update tool_bench.py

* chat-parser: remove input from exception (llm output may contain PII)

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Olivier Chafik <ochafik@users.noreply.github.com>
This commit is contained in:
Olivier Chafik
2025-05-25 01:48:08 +01:00
committed by GitHub
parent a2d02d5793
commit f5cd27b71d
23 changed files with 3245 additions and 1091 deletions

View File

@ -1,3 +1,4 @@
#include "chat.h"
#include "utils.hpp"
#include "arg.h"
@ -114,11 +115,11 @@ struct slot_params {
struct common_params_speculative speculative;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
json to_json() const {
std::vector<std::string> samplers;
@ -176,7 +177,10 @@ struct slot_params {
{"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)},
{"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
{"reasoning_format", (oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "deepseek" : "none")},
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
{"samplers", samplers},
{"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min},
@ -352,11 +356,14 @@ struct server_task {
{
auto it = data.find("chat_format");
if (it != data.end()) {
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format).c_str());
} else {
params.oaicompat_chat_format = defaults.oaicompat_chat_format;
params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
}
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
}
{
@ -396,7 +403,14 @@ struct server_task {
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
}
} else {
params.sampling.grammar_triggers.push_back(std::move(ct.value));
if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
} else {
throw std::runtime_error("Unknown grammar trigger type");
}
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
}
}
}
@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
slot_params generation_params;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;
@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() {
std::string finish_reason = "length";
common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
SRV_DBG("Parsing chat message: %s\n", content.c_str());
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
if (!oaicompat_msg.empty()) {
msg = oaicompat_msg;
} else {
msg.role = "assistant";
msg.content = content;
}
json message {
{"role", "assistant"},
};
if (!msg.reasoning_content.empty()) {
message["reasoning_content"] = msg.reasoning_content;
}
if (msg.content.empty() && !msg.tool_calls.empty()) {
message["content"] = json();
} else {
message["content"] = msg.content;
}
if (!msg.tool_calls.empty()) {
auto tool_calls = json::array();
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
{"name", tc.name},
{"arguments", tc.arguments},
}},
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
// We only generate a random id for the ones that don't generate one by themselves
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
});
}
message["tool_calls"] = tool_calls;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
}
json choice {
{"finish_reason", finish_reason},
{"index", 0},
{"message", message},
{"message", msg.to_json_oaicompat<json>()},
};
if (!stream && probs_output.size() > 0) {
@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
std::time_t t = std::time(0);
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
}
json choice = json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}
};
json deltas = json::array();
for (const auto & diff : oaicompat_msg_diffs) {
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
}
json ret = json {
{"choices", json::array({choice})},
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
};
});
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
deltas.back().push_back({"timings", timings.to_json()});
}
// extra fields for debugging purposes
if (verbose) {
ret["__verbose"] = to_json_non_oaicompat();
if (verbose && !deltas.empty()) {
deltas.front()["__verbose"] = to_json_non_oaicompat();
}
return ret;
return deltas;
}
};
@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
return index;
@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
std::time_t t = std::time(0);
json choices;
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
// initial_ret is the role message for stream=True
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"},
{"content", ""}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json {
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
if (prob_output.probs.size() > 0) {
second_ret["choices"][0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
second_ret.push_back({"timings", timings.to_json()});
}
return std::vector<json>({initial_ret, second_ret});
}
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json {
{"content", content},
}},
}});
}
GGML_ASSERT(choices.size() >= 1);
if (prob_output.probs.size() > 0) {
choices[0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}
std::vector<json> deltas;
auto add_delta = [&](const json & delta) {
deltas.push_back({
{"choices", json::array({
json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", delta},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
};
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
// We have to send an initial update to conform to openai behavior
if (first) {
add_delta({
{"role", "assistant"},
{"content", nullptr},
});
}
return std::vector<json>({ret});
for (const auto & diff : oaicompat_msg_diffs) {
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
}
if (!deltas.empty()) {
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
if (prob_output.probs.size() > 0) {
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
}
}
return deltas;
}
};
@ -1293,6 +1266,7 @@ struct server_slot {
std::string generated_text;
llama_tokens generated_tokens;
common_chat_msg chat_msg;
server_tokens cache_tokens;
@ -1313,6 +1287,7 @@ struct server_slot {
llama_token sampled;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::vector<std::string> generated_tool_call_ids;
// stats
size_t n_sent_text = 0; // number of sent text character
@ -1342,9 +1317,13 @@ struct server_slot {
n_past = 0;
n_sent_text = 0;
task_type = SERVER_TASK_TYPE_COMPLETION;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
generated_tokens.clear();
generated_token_probs.clear();
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();
// clear speculative decoding stats
n_draft_total = 0;
@ -1424,6 +1403,21 @@ struct server_slot {
return timings;
}
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
auto previous_msg = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
/* is_partial= */ stop != STOP_TYPE_EOS,
params.oaicompat_chat_syntax);
if (!new_msg.empty()) {
new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
}
return chat_msg;
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos;
@ -2475,10 +2469,12 @@ struct server_context {
res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
@ -2499,7 +2495,7 @@ struct server_context {
res->id_slot = slot.id;
res->index = slot.index;
res->content = std::move(slot.generated_text);
res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
res->prompt = slot.prompt_tokens.detokenize(ctx, true);
@ -2519,7 +2515,8 @@ struct server_context {
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output
if (slot.params.sampling.n_probs > 0) {
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {