server: streaming of tool calls and thoughts when --jinja is on (#12379)

* add common_json w/ support for truncated json healing

* add common_chat_msg_diff

* partial common_chat_parse

* refactor parser w/ optionals

* server: wire chat diffs in stream mode

* fix trigger of thinking models (must happen after thoughts are closed)

* fix functionary v3.2 raw python!

* rename: common_chat_syntax (now contains format)

* rm common_regex.at_start

* don't return empty <think></think>

* accommodate yet another deepseek r1 distill fantasy syntax (`<|tool▁calls|>`)

* fix QwQ 32B tool call parsing after thoughts (hermes2)

* better logs for grammar triggers

* consume spaces after parse_json_tool_calls

* fix required tool calls w/ thinking models that have pre-opened thinking tags

* fix thinking model's initial trigger + test qwq's template

* run most test_tool_call tests in stream + non-stream modes

* make functionary v3.2 parsing more strict (differentiate first match from others)

* send final diff from server, to close off raw python arguments

* support partial content streaming in Generic mode

* tool-call: allow content prelude before hermes2 tool calls (for Qwen2.5)

* Update function-calling.md

* Update tool_bench.py

* chat-parser: remove input from exception (llm output may contain PII)

---------

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Olivier Chafik <ochafik@users.noreply.github.com>
This commit is contained in:
Olivier Chafik
2025-05-25 01:48:08 +01:00
committed by GitHub
parent a2d02d5793
commit f5cd27b71d
23 changed files with 3245 additions and 1091 deletions

View File

@ -60,12 +60,16 @@ add_library(${TARGET} STATIC
base64.hpp base64.hpp
chat.cpp chat.cpp
chat.h chat.h
chat-parser.cpp
chat-parser.h
common.cpp common.cpp
common.h common.h
console.cpp console.cpp
console.h console.h
json-schema-to-grammar.cpp json-schema-to-grammar.cpp
json.hpp json.hpp
json-partial.h
json-partial.cpp
llguidance.cpp llguidance.cpp
log.cpp log.cpp
log.h log.h

376
common/chat-parser.cpp Normal file
View File

@ -0,0 +1,376 @@
#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>
using json = nlohmann::ordered_json;
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
result_.role = "assistant";
while (true) {
std::string id = std::to_string(std::rand());
if (input.find(id) == std::string::npos) {
healing_marker_ = id;
break;
}
}
}
std::string common_chat_msg_parser::str(const common_string_range & rng) const {
GGML_ASSERT(rng.begin <= rng.end);
return input_.substr(rng.begin, rng.end - rng.begin);
}
void common_chat_msg_parser::add_content(const std::string &content) {
result_.content += content;
}
void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) {
result_.reasoning_content += reasoning_content;
}
bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
if (name.empty()) {
return false;
}
common_chat_tool_call tool_call;
tool_call.name = name;
tool_call.arguments = arguments;
tool_call.id = id;
// LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
result_.tool_calls.emplace_back(tool_call);
return true;
}
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
return add_tool_call(name, id, arguments);
}
bool common_chat_msg_parser::add_tool_calls(const json & arr) {
for (const auto & item : arr) {
if (!add_tool_call(item)) {
return false;
}
}
return true;
}
void common_chat_msg_parser::finish() {
if (!is_partial_ && pos_ != input_.size()) {
throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_));
}
}
bool common_chat_msg_parser::consume_spaces() {
const auto length = input_.size();
auto consumed = false;
while (pos_ < length && std::isspace(input_[pos_])) {
++pos_;
consumed = true;
}
return consumed;
}
bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
auto pos = pos_;
for (auto i = 0u; i < literal.size(); ++i) {
if (pos >= input_.size()) {
return false;
}
if (input_[pos] != literal[i]) {
return false;
}
++pos;
}
pos_ = pos;
return true;
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
auto idx = input_.find(literal, pos_);
if (idx != std::string::npos) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = idx + literal.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
if (is_partial_) {
idx = string_find_partial_stop(input_, literal);
if (idx != std::string::npos && idx >= pos_) {
find_regex_result res;
res.prelude = input_.substr(pos_, idx - pos_);
auto end = input_.size();
res.groups.emplace_back(common_string_range{idx, end});
move_to(end);
return res;
}
}
return std::nullopt;
}
void common_chat_msg_parser::consume_literal(const std::string & literal) {
if (!try_consume_literal(literal)) {
throw common_chat_msg_partial_exception(literal);
}
}
bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
auto stripped_reasoning = string_strip(reasoning);
if (stripped_reasoning.empty()) {
return;
}
if (syntax_.reasoning_in_content) {
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
add_content(stripped_reasoning);
if (closed) {
add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
}
} else {
add_reasoning_content(stripped_reasoning);
}
};
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
if (auto res = try_find_literal(end_think)) {
handle_reasoning(res->prelude, /* closed */ true);
consume_spaces();
return true;
}
auto rest = consume_rest();
if (!rest.empty()) {
handle_reasoning(rest, /* closed */ !is_partial());
}
if (!syntax_.thinking_forced_open) {
throw common_chat_msg_partial_exception(end_think);
}
return true;
}
}
return false;
}
std::string common_chat_msg_parser::consume_rest() {
auto rest = input_.substr(pos_);
pos_ = input_.size();
return rest;
}
// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
pos_ = m.groups[0].end;
return find_regex_result{prelude, m.groups};
}
common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
if (auto result = try_consume_regex(regex)) {
return *result;
}
throw common_chat_msg_partial_exception(regex.str());
}
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
auto m = regex.search(input_, pos_);
if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
return std::nullopt;
}
if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
if (is_partial()) {
throw common_chat_msg_partial_exception(regex.str());
}
return std::nullopt;
}
if (m.groups[0].begin != pos_) {
// Didn't match at the current position.
return std::nullopt;
}
pos_ = m.groups[0].end;
return find_regex_result {
/* .prelude = */ "",
m.groups,
};
}
std::optional<common_json> common_chat_msg_parser::try_consume_json() {
auto it = input_.cbegin() + pos_;
const auto end = input_.cend();
common_json result;
if (!common_json_parse(it, end, healing_marker_, result)) {
return std::nullopt;
}
pos_ = std::distance(input_.cbegin(), it);
if (result.healing_marker.marker.empty()) {
// No healing marker, just return the parsed json
return result;
}
if (!is_partial()) {
throw common_chat_msg_partial_exception("JSON");
}
return result;
}
common_json common_chat_msg_parser::consume_json() {
if (auto result = try_consume_json()) {
return *result;
}
throw common_chat_msg_partial_exception("JSON");
}
common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths,
const std::vector<std::vector<std::string>> & content_paths
) {
if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
return *result;
}
throw common_chat_msg_partial_exception("JSON");
}
std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths,
const std::vector<std::vector<std::string>> & content_paths
) {
auto partial = try_consume_json();
if (!partial) {
return std::nullopt;
}
auto is_arguments_path = [&](const std::vector<std::string> & path) {
return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
};
auto is_content_path = [&](const std::vector<std::string> & path) {
return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
};
if (partial->healing_marker.marker.empty()) {
if (args_paths.empty()) {
// No arguments to dump, and JSON was parsed fully.
return consume_json_result {
partial->json,
/* .is_partial = */ false,
};
}
if (is_arguments_path({})) {
// Entire JSON is the arguments and was parsed fully.
return consume_json_result {
partial->json.dump(),
/* .is_partial = */ false,
};
}
}
LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
auto found_healing_marker = false;
std::vector<std::string> path;
std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
if (is_arguments_path(path)) {
auto arguments = j.dump();
if (is_partial() && !partial->healing_marker.marker.empty()) {
auto idx = arguments.find(partial->healing_marker.json_dump_marker);
if (idx != std::string::npos) {
arguments.resize(idx);
found_healing_marker = true;
}
if (arguments == "\"") {
// This happens because of completing `:"$magic` after `"arguments"`
arguments = "";
}
}
return arguments;
}
if (is_content_path(path)) {
if (!j.is_string()) {
throw std::runtime_error("Content path must be a string");
}
std::string str = j;
auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
if (idx != std::string::npos) {
str.resize(idx);
found_healing_marker = true;
}
return str;
}
if (j.is_object()) {
auto obj = json::object();
for (const auto & p : j.items()) {
const auto & key = p.key();
const auto & value = p.value();
const std::string key_str = key; // NOLINT
auto idx = key_str.find(healing_marker_);
if (idx != std::string::npos) {
found_healing_marker = true;
break;
}
path.push_back(key_str);
if (value.is_string()) {
const std::string value_str = value;
if (value_str.find(healing_marker_) != std::string::npos) {
found_healing_marker = true;
if (is_content_path(path)) {
if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
// The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
obj[key] = remove_unsupported_healings_and_dump_args(value);
}
}
break;
}
obj[key] = value;
} else {
obj[key] = remove_unsupported_healings_and_dump_args(value);
}
path.pop_back();
}
return obj;
}
if (j.is_array()) {
auto arr = json::array();
for (const auto & value : j) {
if (value.is_string()) {
std::string str = value;
auto idx = str.find(healing_marker_);
if (idx != std::string::npos) {
// Don't heal array values that aren't in the arguments.
found_healing_marker = true;
break;
}
}
arr.push_back(remove_unsupported_healings_and_dump_args(value));
}
return arr;
}
return j;
};
auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
return consume_json_result {
cleaned,
/* .is_partial = */ found_healing_marker,
};
}

116
common/chat-parser.h Normal file
View File

@ -0,0 +1,116 @@
#pragma once
#include "chat.h"
#include "json-partial.h"
#include "json.hpp"
#include "regex-partial.h"
#include <optional>
#include <string>
#include <vector>
class common_chat_msg_partial_exception : public std::runtime_error {
public:
common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
};
class common_chat_msg_parser {
std::string input_;
bool is_partial_;
common_chat_syntax syntax_;
std::string healing_marker_;
size_t pos_ = 0;
common_chat_msg result_;
public:
common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
const std::string & input() const { return input_; }
size_t pos() const { return pos_; }
const std::string & healing_marker() const { return healing_marker_; }
const bool & is_partial() const { return is_partial_; }
const common_chat_msg & result() const { return result_; }
void move_to(size_t pos) {
if (pos > input_.size()) {
throw std::runtime_error("Invalid position!");
}
pos_ = pos;
}
void move_back(size_t n) {
if (pos_ < n) {
throw std::runtime_error("Can't move back that far!");
}
pos_ -= n;
}
// Get the substring of the input at the given range
std::string str(const common_string_range & rng) const;
// Appends to the result.content field
void add_content(const std::string & content);
// Appends to the result.reasoning_content field
void add_reasoning_content(const std::string & reasoning_content);
// Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);
// Adds a tool call using the "name", "id" and "arguments" fields of the json object
bool add_tool_call(const nlohmann::ordered_json & tool_call);
// Adds an array of tool calls using their "name", "id" and "arguments" fields.
bool add_tool_calls(const nlohmann::ordered_json & arr);
void finish();
bool consume_spaces();
void consume_literal(const std::string & literal);
bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);
std::string consume_rest();
struct find_regex_result {
std::string prelude;
std::vector<common_string_range> groups;
};
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);
bool try_consume_literal(const std::string & literal);
std::optional<find_regex_result> try_find_literal(const std::string & literal);
find_regex_result consume_regex(const common_regex & regex);
std::optional<find_regex_result> try_consume_regex(const common_regex & regex);
std::optional<common_json> try_consume_json();
common_json consume_json();
struct consume_json_result {
nlohmann::ordered_json value;
bool is_partial;
};
/*
Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings.
By default, object keys can't be truncated, nor can string values (their corresponding key is removed,
e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`
But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings
- with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}`
- with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}`
*/
consume_json_result consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths = {},
const std::vector<std::vector<std::string>> & content_paths = {}
);
std::optional<consume_json_result> try_consume_json_with_dumped_args(
const std::vector<std::vector<std::string>> & args_paths = {},
const std::vector<std::vector<std::string>> & content_paths = {}
);
};

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
#pragma once #pragma once
#include "common.h" #include "common.h"
#include <functional>
#include <chrono> #include <chrono>
#include <string> #include <string>
#include <vector> #include <vector>
@ -13,11 +14,19 @@ struct common_chat_tool_call {
std::string name; std::string name;
std::string arguments; std::string arguments;
std::string id; std::string id;
bool operator==(const common_chat_tool_call & other) const {
return name == other.name && arguments == other.arguments && id == other.id;
}
}; };
struct common_chat_msg_content_part { struct common_chat_msg_content_part {
std::string type; std::string type;
std::string text; std::string text;
bool operator==(const common_chat_msg_content_part & other) const {
return type == other.type && text == other.text;
}
}; };
struct common_chat_msg { struct common_chat_msg {
@ -28,6 +37,51 @@ struct common_chat_msg {
std::string reasoning_content; std::string reasoning_content;
std::string tool_name; std::string tool_name;
std::string tool_call_id; std::string tool_call_id;
template <class T> T to_json_oaicompat() const;
bool empty() const {
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
}
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
for (auto i = 0u; i < tool_calls.size(); i++) {
if (ids_cache.size() <= i) {
auto id = tool_calls[i].id;
if (id.empty()) {
id = gen_tool_call_id();
}
ids_cache.push_back(id);
}
tool_calls[i].id = ids_cache[i];
}
}
bool operator==(const common_chat_msg & other) const {
return role == other.role
&& content == other.content
&& content_parts == other.content_parts
&& tool_calls == other.tool_calls
&& reasoning_content == other.reasoning_content
&& tool_name == other.tool_name
&& tool_call_id == other.tool_call_id;
}
bool operator!=(const common_chat_msg & other) const {
return !(*this == other);
}
};
struct common_chat_msg_diff {
// std::string reasoning_content_delta;
std::string content_delta;
size_t tool_call_index = std::string::npos;
common_chat_tool_call tool_call_delta;
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
bool operator==(const common_chat_msg_diff & other) const {
return content_delta == other.content_delta
&& tool_call_index == other.tool_call_index
&& tool_call_delta == other.tool_call_delta;
}
}; };
struct common_chat_tool { struct common_chat_tool {
@ -49,14 +103,11 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_LLAMA_3_X, COMMON_CHAT_FORMAT_LLAMA_3_X,
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
COMMON_CHAT_FORMAT_DEEPSEEK_R1, COMMON_CHAT_FORMAT_DEEPSEEK_R1,
COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_FIREFUNCTION_V2, COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO, COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
}; };
@ -71,7 +122,7 @@ struct common_chat_templates_inputs {
std::vector<common_chat_tool> tools; std::vector<common_chat_tool> tools;
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
bool parallel_tool_calls = false; bool parallel_tool_calls = false;
bool extract_reasoning = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
}; };
@ -80,11 +131,20 @@ struct common_chat_params {
std::string prompt; std::string prompt;
std::string grammar; std::string grammar;
bool grammar_lazy = false; bool grammar_lazy = false;
bool thinking_forced_open = false;
std::vector<common_grammar_trigger> grammar_triggers; std::vector<common_grammar_trigger> grammar_triggers;
std::vector<std::string> preserved_tokens; std::vector<std::string> preserved_tokens;
std::vector<std::string> additional_stops; std::vector<std::string> additional_stops;
}; };
struct common_chat_syntax {
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
bool reasoning_in_content = false;
bool thinking_forced_open = false;
};
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
@ -122,7 +182,7 @@ std::string common_chat_format_example(
bool use_jinja); bool use_jinja);
std::string common_chat_format_name(common_chat_format format); std::string common_chat_format_name(common_chat_format format);
common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
@ -135,3 +195,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
// T can be std::string containing JSON or nlohmann::ordered_json // T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools); template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools); template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);

View File

@ -115,7 +115,7 @@ enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD, COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
}; };
struct common_grammar_trigger { struct common_grammar_trigger {

255
common/json-partial.cpp Normal file
View File

@ -0,0 +1,255 @@
#include <json-partial.h>
#include "ggml.h"
#include "log.h"
#include <string>
#include <json.hpp>
using json = nlohmann::ordered_json;
enum common_json_stack_element_type {
COMMON_JSON_STACK_ELEMENT_OBJECT,
COMMON_JSON_STACK_ELEMENT_KEY,
COMMON_JSON_STACK_ELEMENT_ARRAY,
};
struct common_json_stack_element {
common_json_stack_element_type type;
std::string key;
};
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out)
{
std::string::const_iterator it = input.begin();
const auto end = input.end();
return common_json_parse(it, end, healing_marker, out);
}
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out)
{
// // https://json.nlohmann.me/features/parsing/sax_interface/
struct json_error_locator : public nlohmann::json_sax<json> {
std::size_t position;
bool found_error;
std::string last_token;
std::string exception_message;
std::vector<common_json_stack_element> stack;
json_error_locator() : position(0), found_error(false) {}
bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
this->position = position - 1;
this->found_error = true;
this->last_token = last_token;
this->exception_message = ex.what();
return false;
}
void close_value() {
if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
stack.pop_back();
}
}
bool null() override { // NOLINT
close_value();
return true;
}
bool boolean(bool) override { // NOLINT
close_value();
return true;
}
bool number_integer(number_integer_t) override { // NOLINT
close_value();
return true;
}
bool number_unsigned(number_unsigned_t) override { // NOLINT
close_value();
return true;
}
bool number_float(number_float_t, const string_t &) override { // NOLINT
close_value();
return true;
}
bool string(string_t &) override { // NOLINT
close_value();
return true;
}
bool binary(binary_t &) override { // NOLINT
close_value();
return true;
}
bool start_object(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
return true;
}
bool end_object() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
stack.pop_back();
close_value();
return true;
}
bool key(string_t & key) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
return true;
}
bool start_array(std::size_t) override { // NOLINT
stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
return true;
}
bool end_array() override {
GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
stack.pop_back();
close_value();
return true;
}
};
json_error_locator err_loc;
auto start = it;
json::sax_parse(it, end, &err_loc);
if (err_loc.found_error) {
it = start;
auto temptative_end = it + err_loc.position;
// LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
auto input = std::string(it, temptative_end);
try {
out.json = json::parse(input);
// out.json = json::parse(it, temptative_end);
it = temptative_end;
return true;
} catch (const std::exception & ex) {
// No, needs healing.
LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
}
auto can_parse = [](const std::string & str) {
try {
auto _ = json::parse(str); // NOLINT
return true;
} catch (const std::exception &) {
return false;
}
};
if (!healing_marker.empty() && !err_loc.stack.empty()) {
std::string str(it, temptative_end);
auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
if (last_non_sp_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
auto last_non_sp_char = str[last_non_sp_pos];
// Used to detect stops on a number, which may not be complete.
auto was_maybe_number = [&]() {
if (!str.empty() && std::isspace(str.back())) {
return false;
}
return std::isdigit(last_non_sp_char) ||
last_non_sp_char == '.' ||
last_non_sp_char == 'e' ||
last_non_sp_char == 'E' ||
last_non_sp_char == '-';
};
std::string closing;
for (size_t i = err_loc.stack.size(); i > 0; i--) {
auto & el = err_loc.stack[i - 1];
if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
closing += "}";
} else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
closing += "]";
} else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
throw std::runtime_error("Unexpected stack element type");
}
}
const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
// We're inside an object value
if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
// Was about to create an object value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + ": 1" + closing)) {
str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
} else if (last_non_sp_char == '{' && can_parse(str + closing)) {
// Was about to create an object
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an object value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an object value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else {
// find last :
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
}
// Cutting back to opening : for object value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
// Was about to create an array value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
} else if (can_parse(str + "\"" + closing)) {
// Was inside an array value string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
// Was inside an array value string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
} else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
// Had just finished a value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
} else {
auto last_pos = str.find_last_of("[,");
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
}
// Cutting back to last [ or , for array value
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
(last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
} else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
// Was about to create an object key+value
str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
} else if (can_parse(str + "\": 1" + closing)) {
// Was inside an object key string
str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
} else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
// Was inside an object key string after an escape
str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
} else {
auto last_pos = str.find_last_of(':');
if (last_pos == std::string::npos) {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "Cutting back to last : for object key+value\n");
str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
}
} else {
throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
}
// fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
out.json = json::parse(str);
it = temptative_end;
return true;
}
// TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
// fprintf(stderr, "Closing: TODO\n");
return false;
}
out.json = json::parse(it, end);
it = end;
return true;
}

37
common/json-partial.h Normal file
View File

@ -0,0 +1,37 @@
#pragma once
#include <json.hpp>
// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {
// Raw marker.
std::string marker;
// Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
std::string json_dump_marker;
};
// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
struct common_json {
nlohmann::ordered_json json;
common_healing_marker healing_marker;
};
// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
bool common_json_parse(
const std::string & input,
const std::string & healing_marker,
common_json & out);
// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
std::string::const_iterator & it,
const std::string::const_iterator & end,
const std::string & healing_marker,
common_json & out);

View File

@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE
} else { } else {
std::vector<std::string> patterns_at_start; std::vector<std::string> trigger_patterns;
std::vector<std::string> patterns_anywhere; std::vector<std::string> patterns_anywhere;
std::vector<llama_token> trigger_tokens; std::vector<llama_token> trigger_tokens;
for (const auto & trigger : params.grammar_triggers) { for (const auto & trigger : params.grammar_triggers) {
@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
break; break;
} }
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{ {
const auto & pattern = trigger.value; patterns_anywhere.push_back(trigger.value);
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern); break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
{
trigger_patterns.push_back(trigger.value);
break; break;
} }
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
} }
} }
std::vector<std::string> trigger_patterns;
if (!patterns_at_start.empty()) {
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
}
if (!patterns_anywhere.empty()) { if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
} }

View File

@ -325,36 +325,65 @@ To get the official template from original HuggingFace repos, you can use [scrip
> [!TIP] > [!TIP]
> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills) > If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)
> [!CAUTION]
> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance.
Test in CLI (or with any library / software that can use OpenAI-compatible API backends): Test in CLI (or with any library / software that can use OpenAI-compatible API backends):
```bash ```bash
curl http://localhost:8080/v1/chat/completions -d '{ curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo", "model": "gpt-3.5-turbo",
"tools": [ "tools": [
{ {
"type":"function", "type":"function",
"function":{ "function":{
"name":"python", "name":"python",
"description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters":{ "parameters":{
"type":"object", "type":"object",
"properties":{ "properties":{
"code":{ "code":{
"type":"string", "type":"string",
"description":"The code to run in the ipython interpreter." "description":"The code to run in the ipython interpreter."
}
},
"required":["code"]
} }
},
"required":["code"]
} }
} }
} ],
], "messages": [
"messages": [ {
{ "role": "user",
"role": "user", "content": "Print a hello world message with python."
"content": "Print a hello world message with python." }
} ]
] }'
curl http://localhost:8080/v1/chat/completions -d '{
"model": "gpt-3.5-turbo",
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"}
],
"tools": [{
"type":"function",
"function":{
"name":"get_current_weather",
"description":"Get the current weather in a given location",
"parameters":{
"type":"object",
"properties":{
"location":{
"type":"string",
"description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`"
}
},
"required":["location"]
}
}
}]
}' }'
``` ```

View File

@ -0,0 +1,62 @@
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
{{- '' }}
{%- endif %}
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0]['role'] == 'system' %}
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" and not message.tool_calls %}
{%- set content = message.content %}
{%- if not loop.last %}
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{%- set content = message.content %}
{%- if not loop.last %}
{%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{{- '<|im_start|>' + message.role }}
{%- if message.content %}
{{- '\n' + content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\n<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{{- tool_call.arguments | tojson }}
{{- '}\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- message.content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n<think>\n' }}
{%- endif %}

View File

@ -19,4 +19,5 @@ These templates can be updated with the following commands:
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
``` ```

View File

@ -12,6 +12,7 @@
export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp} export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}
./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b
@ -205,6 +206,7 @@ def run(
model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None, model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None, hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None, chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None, ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None, llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10, n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
@ -229,6 +231,12 @@ def run(
# n_ctx = 8192 # n_ctx = 8192
n_ctx = 2048 n_ctx = 2048
if model is None:
if hf is not None:
model = hf.split("/")[-1]
elif ollama is not None:
model = ollama
assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite" assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"
with output.open('a' if append else 'w') as output_file: with output.open('a' if append else 'w') as output_file:
@ -320,6 +328,7 @@ def run(
server.model_hf_repo = hf server.model_hf_repo = hf
server.model_hf_file = None server.model_hf_file = None
server.chat_template = chat_template server.chat_template = chat_template
server.chat_template_file = chat_template_file
server.server_path = server_path server.server_path = server_path
if port is not None: if port is not None:
server.server_port = port server.server_port = port
@ -335,6 +344,7 @@ def run(
temp=t, temp=t,
output_kwargs=dict( output_kwargs=dict(
chat_template=chat_template, chat_template=chat_template,
chat_template_file=chat_template_file,
), ),
request_kwargs=dict( request_kwargs=dict(
ignore_chat_grammar=ignore_chat_grammar, ignore_chat_grammar=ignore_chat_grammar,
@ -355,6 +365,7 @@ def run(
temp=t, temp=t,
output_kwargs=dict( output_kwargs=dict(
chat_template=None, chat_template=None,
chat_template_file=None,
), ),
request_kwargs=dict( request_kwargs=dict(
model=ollama, model=ollama,

View File

@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
for (const auto & trigger_pattern : grammar.trigger_patterns) { for (const auto & trigger_pattern : grammar.trigger_patterns) {
if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
grammar.awaiting_trigger = false; grammar.awaiting_trigger = false;
// get from the first match to the end of the string // get from the first matched capturing group to the end of the string
auto constrained_str = grammar.trigger_buffer.substr(match.position(1)); size_t start = std::string::npos;
for (auto i = 1u; i < match.size(); i++) {
if (match.length(i) > 0) {
start = match.position(i);
break;
}
}
if (start == std::string::npos) {
start = match.position(0);
}
auto constrained_str = grammar.trigger_buffer.substr(start);
// std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
grammar.trigger_buffer.clear(); grammar.trigger_buffer.clear();
llama_grammar_accept_str(grammar, constrained_str); llama_grammar_accept_str(grammar, constrained_str);

View File

@ -142,8 +142,10 @@ if (NOT WIN32)
# llama_build_and_test(test-double-float.cpp) # SLOW # llama_build_and_test(test-double-float.cpp) # SLOW
endif() endif()
llama_build_and_test(test-log.cpp) llama_build_and_test(test-chat-parser.cpp)
llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-chat-template.cpp)
llama_build_and_test(test-json-partial.cpp)
llama_build_and_test(test-log.cpp)
llama_build_and_test(test-regex-partial.cpp) llama_build_and_test(test-regex-partial.cpp)
# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135) # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)

355
tests/test-chat-parser.cpp Normal file
View File

@ -0,0 +1,355 @@
// Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
// e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <exception>
#include <iostream>
#include <json.hpp>
#include <string>
#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"
using json = nlohmann::ordered_json;
template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void assert_equals(const char * expected, const std::string & actual) {
return assert_equals<std::string>(expected, actual);
}
static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
try {
fn();
} catch (const std::exception & e) {
if (expected_exception_pattern.empty()) {
return;
}
std::regex expected_exception_regex(expected_exception_pattern);
std::string actual_message = e.what();
if (std::regex_search(actual_message, expected_exception_regex)) {
return;
}
throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
}
throw std::runtime_error("Exception was expected but not thrown");
}
static void test_reasoning() {
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ false,
});
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ false,
/* .thinking_forced_open = */ true,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals(std::string("Cogito"), builder.result().reasoning_content);
assert_equals("Ergo sum", builder.consume_rest());
}
{
common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
/* .reasoning_in_content = */ true,
/* .thinking_forced_open = */ true,
});
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
assert_equals("<think>Cogito</think>", builder.result().content);
assert_equals("Ergo sum", builder.consume_rest());
}
}
static void test_regex() {
auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
common_chat_msg_parser builder(input, /* is_partial= */ false, {});
assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
};
test_throws("Hello, world!", "abc", "^abc$");
test_throws("Hello, world!", "e", "^e$");
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
builder.consume_regex(common_regex("Hello"));
assert_equals(", world!", builder.consume_rest());
}
{
// When in non partial mode, we can say whether the regex was consumed or not.
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
}
{
common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
assert_equals(true, res.has_value());
// Verify captures
assert_equals<size_t>(2, res->groups.size());
assert_equals("Hell", builder.str(res->groups[0]));
assert_equals("el", builder.str(res->groups[1]));
// Verify position is after the match
assert_equals<size_t>(4, builder.pos());
assert_equals("o,", builder.consume_rest());
}
{
// But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
assert_throws([&]() {
builder.try_consume_regex(common_regex("Hello, world!"));
}, "^Hello, world!$");
}
// Now regardless of the mode, we can tell these aren't a match.
for (const auto is_partial : {false, true}) {
common_chat_msg_parser builder("Hello,", is_partial, {});
assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
}
for (const auto is_partial : {false, true}) {
common_chat_msg_parser builder("Hello,", is_partial, {});
assert_equals(false, builder.try_consume_literal("Oh"));
}
}
const std::vector<std::string> barely_healable_jsons = {
"{",
"{\"",
"{\"\\",
"{\"n",
"{\"name\"",
"{\"name\":",
"{\"name\":\"",
"{\"name\":\"\\",
"{\"name\":\"python",
"{\"name\":\"python\\",
"{\",",
"{\":",
"{\"[",
"{\"]",
"{\"{",
"{\"}",
"{\"1",
"{\"name\":\",",
"{\"name\":\":",
"{\"name\":\"[",
"{\"name\":\"]",
"{\"name\":\"{",
"{\"name\":\"}",
"{\"name\":\"1",
};
static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
common_chat_msg_parser builder(input, is_partial, {});
auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
assert_equals(true, js.has_value());
assert_equals(is_partial, js->is_partial);
assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
}
static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
common_chat_msg_parser builder(input, parse_as_partial, {});
auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
assert_equals(true, js.has_value());
assert_equals(is_partial, js->is_partial);
assert_equals(expected, js->value.dump());
}
static void test_json_with_dumped_args_no_args() {
// Normal JSON, nothing to heal, nothing to dump
test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
// Full json is args
test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
// If the arguments are further down, don't heal partial content.
for (const auto & src : barely_healable_jsons) {
test(src, true, {{"arguments"}}, {}, "{}");
}
// But heal content that isn't partial.
test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
}
static void test_json_with_dumped_args() {
// Partial content.
test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
test("{\"content\": ", true, {}, {{"content"}}, "{}");
// If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
for (const auto & src : barely_healable_jsons) {
test(src, true, {{}}, {}, src);
}
// Full JSON w/ args
for (auto parse_as_partial : {true, false}) {
test_with_args(
R"({"name": "python", "args": {"arg1": 1}})",
R"({"name":"python","args":"{\"arg1\":1}"})",
parse_as_partial,
/* is_partial= */ false
);
}
// Partial JSON w/ partial args
test_with_args(
R"({"foo": "bar", "args": {")",
R"({"foo":"bar","args":"{\""})"
);
// Partial args broken in object key
test_with_args(
R"({"foo": "bar", "args": {"ar)",
R"({"foo":"bar","args":"{\"ar"})"
);
// Partial args broken after object key
test_with_args(
R"({"foo": "bar", "args": {"arg1")",
R"({"foo":"bar","args":"{\"arg1\""})"
);
// Partial args broken before object value
test_with_args(
R"({"foo": "bar", "args": {"arg1":)",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken before object value (space)
test_with_args(
R"({"foo": "bar", "args": {"arg1": )",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken in object value that may not be complete (int)
test_with_args(
R"({"foo": "bar", "args": {"arg1": 1)",
R"({"foo":"bar","args":"{\"arg1\":"})"
);
// Partial args broken in object value that is complete (int)
test_with_args(
R"({"foo": "bar", "args": {"arg1": 1 )",
R"({"foo":"bar","args":"{\"arg1\":1"})"
);
// Partial args broken in object value that is incomplete (string)
test_with_args(
R"({"foo": "bar", "args": {"arg1": ")",
R"({"foo":"bar","args":"{\"arg1\":\""})"
);
// Partial args broken in object value that is complete (string)
test_with_args(
R"({"foo": "bar", "args": {"arg1": "1")",
R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
);
// Partial args broken on array opening
test_with_args(
R"({"foo": "bar", "args": [)",
R"({"foo":"bar","args":"["})"
);
// Partial args broken on array value that is incomplete (int)
test_with_args(
R"({"foo": "bar", "args": [1)",
R"({"foo":"bar","args":"["})"
);
// Partial args broken on array value that is complete (int)
test_with_args(
R"({"foo": "bar", "args": [1 )",
R"({"foo":"bar","args":"[1"})"
);
// Partial args broken on array value that is complete (string)
test_with_args(
R"({"foo": "bar", "args": ["1")",
R"({"foo":"bar","args":"[\"1\""})"
);
// Partial args broken after array value
test_with_args(
R"({"foo": "bar", "args": [1,)",
R"({"foo":"bar","args":"[1,"})"
);
// Partial args broken on nested array
test_with_args(
R"({"foo": "bar", "args": {"arg1": [)",
R"({"foo":"bar","args":"{\"arg1\":["})"
);
}
static void test_positions() {
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.move_to(100); });
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.move_back(1); });
assert_equals<size_t>(0, builder.pos());
builder.move_to(8);
assert_equals<size_t>(8, builder.pos());
builder.move_back(1);
assert_equals<size_t>(7, builder.pos());
assert_equals("world!", builder.consume_rest());
builder.move_to(0);
assert_equals<size_t>(0, builder.pos());
assert_throws([&]() { builder.finish(); });
assert_equals<size_t>(0, builder.pos());
builder.move_to(builder.input().size());
builder.finish();
}
{
common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
builder.move_to(builder.input().size());
assert_equals<size_t>(builder.input().size(), builder.pos());
builder.finish();
}
}
int main() {
test_positions();
test_json_with_dumped_args_no_args();
test_json_with_dumped_args();
test_reasoning();
test_regex();
std::cout << "All tests passed!\n";
return 0;
}

File diff suppressed because it is too large Load Diff

237
tests/test-json-partial.cpp Normal file
View File

@ -0,0 +1,237 @@
#include "common.h"
#include "json-partial.h"
#include <exception>
#include <iostream>
#include <stdexcept>
template <class T> static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
std::cerr << std::flush;
throw std::runtime_error("Test failed");
}
}
static void test_json_healing() {
auto parse = [](const std::string & str) {
std::cerr << "# Parsing: " << str << '\n';
std::string::const_iterator it = str.begin();
const auto end = str.end();
common_json out;
std::string healing_marker = "$llama.cpp.json$";
if (common_json_parse(it, end, healing_marker, out)) {
auto dump = out.json.dump();
std::cerr << "Parsed: " << dump << '\n';
std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
std::string result;
if (!out.healing_marker.json_dump_marker.empty()) {
auto i = dump.find(out.healing_marker.json_dump_marker);
if (i == std::string::npos) {
throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
}
result = dump.substr(0, i);
} else {
result = dump;
}
std::cerr << "Result: " << result << '\n';
if (string_starts_with(str, result)) {
std::cerr << "Failure!\n";
}
// return dump;
} else {
throw std::runtime_error("Failed to parse: " + str);
}
};
auto parse_all = [&](const std::string & str) {
for (size_t i = 1; i < str.size(); i++) {
parse(str.substr(0, i));
}
};
parse_all("{\"a\": \"b\"}");
parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
parse_all("[{\"a\": \"b\"}]");
auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
for (const auto & input : inputs) {
common_json out;
assert_equals(true, common_json_parse(input, "$foo", out));
assert_equals<std::string>(expected, out.json.dump());
assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
}
};
// No healing needed:
test(
{
R"([{"a":"b"}, "y"])",
},
R"([{"a":"b"},"y"])",
""
);
// Partial literals can't be healed:
test(
{
R"([1)",
R"([tru)",
R"([n)",
R"([nul)",
R"([23.2)",
},
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"({"a": 1)",
R"({"a": tru)",
R"({"a": n)",
R"({"a": nul)",
R"({"a": 23.2)",
},
R"({"a":"$foo"})",
R"("$foo)"
);
test(
{
R"({)",
},
R"({"$foo":1})",
R"("$foo)"
);
test(
{
R"([)",
},
R"(["$foo"])",
R"("$foo)"
);
// Healing right after a full literal
test(
{
R"(1 )",
},
R"(1)",
""
);
test(
{
R"(true)",
R"(true )",
},
R"(true)",
""
);
test(
{
R"(null)",
R"(null )",
},
R"(null)",
""
);
test(
{
R"([1 )",
},
R"([1,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{})",
R"([{} )",
},
R"([{},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true)",
},
// TODO: detect the true/false/null literal was complete
R"(["$foo"])",
R"("$foo)"
);
test(
{
R"([true )",
},
R"([true,"$foo"])",
R"(,"$foo)"
);
test(
{
R"([true,)",
},
R"([true,"$foo"])",
R"("$foo)"
);
// Test nesting
test(
{
R"([{"a": [{"b": [{)",
},
R"([{"a":[{"b":[{"$foo":1}]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": [{"b": [)",
},
R"([{"a":[{"b":["$foo"]}]}])",
R"("$foo)"
);
test(
{
R"([{"a": "b"})",
R"([{"a": "b"} )",
},
R"([{"a":"b"},"$foo"])",
R"(,"$foo)"
);
test(
{
R"([{"a": "b"},)",
R"([{"a": "b"}, )",
},
R"([{"a":"b"},"$foo"])",
R"("$foo)"
);
test(
{
R"({ "code)",
},
R"({"code$foo":1})",
R"($foo)"
);
test(
{
R"({ "code\)",
},
R"({"code\\$foo":1})",
R"(\$foo)"
);
test(
{
R"({ "code")",
},
R"({"code":"$foo"})",
R"(:"$foo)"
);
test(
{
R"({ "key")",
},
R"({"key":"$foo"})",
R"(:"$foo)"
);
}
int main() {
test_json_healing();
std::cerr << "All tests passed.\n";
return 0;
}

View File

@ -1,3 +1,4 @@
#include "chat.h"
#include "utils.hpp" #include "utils.hpp"
#include "arg.h" #include "arg.h"
@ -114,11 +115,11 @@ struct slot_params {
struct common_params_speculative speculative; struct common_params_speculative speculative;
// OAI-compat fields // OAI-compat fields
bool verbose = false; bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model; std::string oaicompat_model;
std::string oaicompat_cmpl_id; std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; common_chat_syntax oaicompat_chat_syntax;
json to_json() const { json to_json() const {
std::vector<std::string> samplers; std::vector<std::string> samplers;
@ -176,7 +177,10 @@ struct slot_params {
{"grammar_lazy", sampling.grammar_lazy}, {"grammar_lazy", sampling.grammar_lazy},
{"grammar_triggers", grammar_triggers}, {"grammar_triggers", grammar_triggers},
{"preserved_tokens", sampling.preserved_tokens}, {"preserved_tokens", sampling.preserved_tokens},
{"chat_format", common_chat_format_name(oaicompat_chat_format)}, {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
{"reasoning_format", (oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "deepseek" : "none")},
{"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
{"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
{"samplers", samplers}, {"samplers", samplers},
{"speculative.n_max", speculative.n_max}, {"speculative.n_max", speculative.n_max},
{"speculative.n_min", speculative.n_min}, {"speculative.n_min", speculative.n_min},
@ -352,11 +356,14 @@ struct server_task {
{ {
auto it = data.find("chat_format"); auto it = data.find("chat_format");
if (it != data.end()) { if (it != data.end()) {
params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>()); params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str()); SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format).c_str());
} else { } else {
params.oaicompat_chat_format = defaults.oaicompat_chat_format; params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
} }
params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
} }
{ {
@ -396,7 +403,14 @@ struct server_task {
params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
} }
} else { } else {
params.sampling.grammar_triggers.push_back(std::move(ct.value)); if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
} else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
} else {
throw std::runtime_error("Unknown grammar trigger type");
}
params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
} }
} }
} }
@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
slot_params generation_params; slot_params generation_params;
// OAI-compat fields // OAI-compat fields
bool verbose = false; bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model; std::string oaicompat_model;
std::string oaicompat_cmpl_id; std::string oaicompat_cmpl_id;
common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; common_chat_msg oaicompat_msg;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override { virtual int get_index() override {
return index; return index;
@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat() { json to_json_oaicompat_chat() {
std::string finish_reason = "length"; std::string finish_reason = "length";
common_chat_msg msg; common_chat_msg msg;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { if (!oaicompat_msg.empty()) {
SRV_DBG("Parsing chat message: %s\n", content.c_str()); msg = oaicompat_msg;
msg = common_chat_parse(content, oaicompat_chat_format);
finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
} else { } else {
msg.role = "assistant";
msg.content = content; msg.content = content;
} }
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
json message { finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
{"role", "assistant"},
};
if (!msg.reasoning_content.empty()) {
message["reasoning_content"] = msg.reasoning_content;
}
if (msg.content.empty() && !msg.tool_calls.empty()) {
message["content"] = json();
} else {
message["content"] = msg.content;
}
if (!msg.tool_calls.empty()) {
auto tool_calls = json::array();
for (const auto & tc : msg.tool_calls) {
tool_calls.push_back({
{"type", "function"},
{"function", {
{"name", tc.name},
{"arguments", tc.arguments},
}},
// Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
// We only generate a random id for the ones that don't generate one by themselves
// (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
{"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
});
}
message["tool_calls"] = tool_calls;
} }
json choice { json choice {
{"finish_reason", finish_reason}, {"finish_reason", finish_reason},
{"index", 0}, {"index", 0},
{"message", message}, {"message", msg.to_json_oaicompat<json>()},
}; };
if (!stream && probs_output.size() > 0) { if (!stream && probs_output.size() > 0) {
@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
std::time_t t = std::time(0); std::time_t t = std::time(0);
std::string finish_reason = "length"; std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop"; finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
} }
json choice = json { json deltas = json::array();
{"finish_reason", finish_reason}, for (const auto & diff : oaicompat_msg_diffs) {
{"index", 0}, deltas.push_back({
{"delta", json::object()} {"choices", json::array({
}; json {
{"finish_reason", nullptr},
{"index", 0},
{"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
},
})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
});
}
json ret = json { deltas.push_back({
{"choices", json::array({choice})}, {"choices", json::array({
json {
{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()},
},
})},
{"created", t}, {"created", t},
{"id", oaicompat_cmpl_id}, {"id", oaicompat_cmpl_id},
{"model", oaicompat_model}, {"model", oaicompat_model},
@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
{"prompt_tokens", n_prompt_tokens}, {"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens}, {"total_tokens", n_decoded + n_prompt_tokens},
}}, }},
}; });
if (timings.prompt_n >= 0) { if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()}); deltas.back().push_back({"timings", timings.to_json()});
} }
// extra fields for debugging purposes // extra fields for debugging purposes
if (verbose) { if (verbose && !deltas.empty()) {
ret["__verbose"] = to_json_non_oaicompat(); deltas.front()["__verbose"] = to_json_non_oaicompat();
} }
return ret; return deltas;
} }
}; };
@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings; result_timings timings;
// OAI-compat fields // OAI-compat fields
bool verbose = false; bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model; std::string oaicompat_model;
std::string oaicompat_cmpl_id; std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override { virtual int get_index() override {
return index; return index;
@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
std::time_t t = std::time(0); std::time_t t = std::time(0);
json choices; json choices;
if (first) { std::vector<json> deltas;
if (content.empty()) { auto add_delta = [&](const json & delta) {
choices = json::array({json{{"finish_reason", nullptr}, deltas.push_back({
{"index", 0}, {"choices", json::array({
{"delta", json{{"role", "assistant"}}}}}); json {
} else { {"finish_reason", nullptr},
// We have to send this as two updates to conform to openai behavior {"index", 0},
// initial_ret is the role message for stream=True {"delta", delta},
json initial_ret = json{{"choices", json::array({json{ },
{"finish_reason", nullptr}, })},
{"index", 0}, {"created", t},
{"delta", json{ {"id", oaicompat_cmpl_id},
{"role", "assistant"}, {"model", oaicompat_model},
{"content", ""} {"system_fingerprint", build_info},
}}}})}, {"object", "chat.completion.chunk"},
{"created", t}, });
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json {
{"content", content}}}
}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
if (prob_output.probs.size() > 0) {
second_ret["choices"][0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
second_ret.push_back({"timings", timings.to_json()});
}
return std::vector<json>({initial_ret, second_ret});
}
} else {
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json {
{"content", content},
}},
}});
}
GGML_ASSERT(choices.size() >= 1);
if (prob_output.probs.size() > 0) {
choices[0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}
}; };
// We have to send an initial update to conform to openai behavior
if (timings.prompt_n >= 0) { if (first) {
ret.push_back({"timings", timings.to_json()}); add_delta({
{"role", "assistant"},
{"content", nullptr},
});
} }
return std::vector<json>({ret}); for (const auto & diff : oaicompat_msg_diffs) {
add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
}
if (!deltas.empty()) {
GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
if (prob_output.probs.size() > 0) {
deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
}
}
return deltas;
} }
}; };
@ -1293,6 +1266,7 @@ struct server_slot {
std::string generated_text; std::string generated_text;
llama_tokens generated_tokens; llama_tokens generated_tokens;
common_chat_msg chat_msg;
server_tokens cache_tokens; server_tokens cache_tokens;
@ -1313,6 +1287,7 @@ struct server_slot {
llama_token sampled; llama_token sampled;
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
std::vector<std::string> generated_tool_call_ids;
// stats // stats
size_t n_sent_text = 0; // number of sent text character size_t n_sent_text = 0; // number of sent text character
@ -1342,9 +1317,13 @@ struct server_slot {
n_past = 0; n_past = 0;
n_sent_text = 0; n_sent_text = 0;
task_type = SERVER_TASK_TYPE_COMPLETION; task_type = SERVER_TASK_TYPE_COMPLETION;
chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
generated_tokens.clear(); generated_tokens.clear();
generated_token_probs.clear(); generated_token_probs.clear();
chat_msg = {};
json_schema = json();
generated_tool_call_ids.clear();
// clear speculative decoding stats // clear speculative decoding stats
n_draft_total = 0; n_draft_total = 0;
@ -1424,6 +1403,21 @@ struct server_slot {
return timings; return timings;
} }
const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
auto previous_msg = chat_msg;
SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
auto new_msg = common_chat_parse(
generated_text,
/* is_partial= */ stop != STOP_TYPE_EOS,
params.oaicompat_chat_syntax);
if (!new_msg.empty()) {
new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
chat_msg = new_msg;
diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
}
return chat_msg;
}
size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
size_t stop_pos = std::string::npos; size_t stop_pos = std::string::npos;
@ -2475,10 +2469,12 @@ struct server_context {
res->n_prompt_tokens = slot.n_prompt_tokens; res->n_prompt_tokens = slot.n_prompt_tokens;
res->post_sampling_probs = slot.params.post_sampling_probs; res->post_sampling_probs = slot.params.post_sampling_probs;
res->verbose = slot.params.verbose; res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat; res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output // populate res.probs_output
if (slot.params.sampling.n_probs > 0) { if (slot.params.sampling.n_probs > 0) {
@ -2499,7 +2495,7 @@ struct server_context {
res->id_slot = slot.id; res->id_slot = slot.id;
res->index = slot.index; res->index = slot.index;
res->content = std::move(slot.generated_text); res->content = slot.generated_text;
res->tokens = std::move(slot.generated_tokens); res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings(); res->timings = slot.get_timings();
res->prompt = slot.prompt_tokens.detokenize(ctx, true); res->prompt = slot.prompt_tokens.detokenize(ctx, true);
@ -2519,7 +2515,8 @@ struct server_context {
res->oaicompat = slot.params.oaicompat; res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model; res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id; res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
res->oaicompat_chat_format = slot.params.oaicompat_chat_format; res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
// populate res.probs_output // populate res.probs_output
if (slot.params.sampling.n_probs > 0) { if (slot.params.sampling.n_probs > 0) {
if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) { if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {

View File

@ -75,7 +75,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
choice = data["choices"][0] choice = data["choices"][0]
if i == 0: if i == 0:
# Check first role message for stream=True # Check first role message for stream=True
assert choice["delta"]["content"] == "" assert choice["delta"]["content"] is None
assert choice["delta"]["role"] == "assistant" assert choice["delta"]["role"] == "assistant"
else: else:
assert "role" not in choice["delta"] assert "role" not in choice["delta"]
@ -92,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
assert choice["finish_reason"] == finish_reason assert choice["finish_reason"] == finish_reason
else: else:
assert choice["finish_reason"] is None assert choice["finish_reason"] is None
content += choice["delta"]["content"] content += choice["delta"]["content"] or ''
def test_chat_completion_with_openai_library(): def test_chat_completion_with_openai_library():
@ -251,8 +251,9 @@ def test_chat_completion_with_timings_per_token():
for i, data in enumerate(res): for i, data in enumerate(res):
if i == 0: if i == 0:
# Check first role message for stream=True # Check first role message for stream=True
assert data["choices"][0]["delta"]["content"] == "" assert data["choices"][0]["delta"]["content"] is None
assert data["choices"][0]["delta"]["role"] == "assistant" assert data["choices"][0]["delta"]["role"] == "assistant"
assert "timings" not in data, f'First event should not have timings: {data}'
else: else:
assert "role" not in data["choices"][0]["delta"] assert "role" not in data["choices"][0]["delta"]
assert "timings" in data assert "timings" in data
@ -311,7 +312,7 @@ def test_logprobs_stream():
choice = data.choices[0] choice = data.choices[0]
if i == 0: if i == 0:
# Check first role message for stream=True # Check first role message for stream=True
assert choice.delta.content == "" assert choice.delta.content is None
assert choice.delta.role == "assistant" assert choice.delta.role == "assistant"
else: else:
assert choice.delta.role is None assert choice.delta.role is None

View File

@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path)) sys.path.insert(0, str(path))
from utils import * from utils import *
from enum import Enum
server: ServerProcess server: ServerProcess
@ -20,7 +21,11 @@ def create_server():
server = ServerPreset.tinyllama2() server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2-tool-call" server.model_alias = "tinyllama-2-tool-call"
server.server_port = 8081 server.server_port = 8081
server.n_slots = 1
class CompletionMode(Enum):
NORMAL = "normal"
STREAMED = "streamed"
TEST_TOOL = { TEST_TOOL = {
"type":"function", "type":"function",
@ -73,9 +78,8 @@ WEATHER_TOOL = {
} }
} }
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs): def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={ body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict, "max_tokens": n_predict,
"messages": [ "messages": [
{"role": "system", "content": "You are a coding assistant."}, {"role": "system", "content": "You are a coding assistant."},
@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
"parallel_tool_calls": False, "parallel_tool_calls": False,
**kwargs, **kwargs,
}) })
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" # assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0] choice = body["choices"][0]
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0] tool_call = tool_calls[0]
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"] assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"] actual_arguments = tool_call["function"]["arguments"]
@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [ @pytest.mark.parametrize("template_name,tool,argument_key", [
("google-gemma-2-2b-it", TEST_TOOL, "success"), ("google-gemma-2-2b-it", TEST_TOOL, "success"),
("google-gemma-2-2b-it", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
]) ])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None): def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
global server global server
n_predict = 1024 n_predict = 1024
# server = ServerPreset.stories15m_moe() # server = ServerPreset.stories15m_moe()
@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
server.n_predict = n_predict server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja' server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0) do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,tool,argument_key", [ @pytest.mark.parametrize("template_name,tool,argument_key", [
("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
# ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"), ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"), ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"), ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"), ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"), # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
]) ])
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None): def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
global server global server
n_predict = 512 n_predict = 512
# server = ServerPreset.stories15m_moe() # server = ServerPreset.stories15m_moe()
@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
server.n_predict = n_predict server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja' server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict) do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [ @pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), # (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
]) ])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server global server
n_predict = 512 n_predict = 512
server.n_slots = 1
server.jinja = True server.jinja = True
server.n_ctx = 8192 server.n_ctx = 8192
server.n_predict = n_predict server.n_predict = n_predict
@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
elif isinstance(template_override, str): elif isinstance(template_override, str):
server.chat_template = template_override server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={ body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict, "max_tokens": n_predict,
"messages": [ "messages": [
{"role": "system", "content": "You are a coding assistant."}, {"role": "system", "content": "You are a coding assistant."},
@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
"tool_choice": "required", "tool_choice": "required",
"tools": [tool], "tools": [tool],
"parallel_tool_calls": False, "parallel_tool_calls": False,
"stream": stream == CompletionMode.STREAMED,
"temperature": 0.0, "temperature": 0.0,
"top_k": 1, "top_k": 1,
"top_p": 1.0, "top_p": 1.0,
}, timeout=TIMEOUT_HTTP_REQUEST) }, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = body["choices"][0]
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0] tool_call = tool_calls[0]
@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs): def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={ body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict, "max_tokens": n_predict,
"messages": [ "messages": [
{"role": "system", "content": "You are a coding assistant."}, {"role": "system", "content": "You are a coding assistant."},
@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
"tool_choice": tool_choice, "tool_choice": tool_choice,
**kwargs, **kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST) }, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = body["choices"][0]
choice = res.body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None), ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None), ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'), ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
]) ])
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server global server
server.jinja = True
server.n_predict = n_predict server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja' server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice) do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meetkai-functionary-medium-v3.2", 256, [], None), ("meetkai-functionary-medium-v3.2", 256, [], None),
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None), ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None), ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'), ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
]) ])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
global server global server
server.jinja = True
server.n_predict = n_predict server.n_predict = n_predict
server.jinja = True
server.chat_template_file = f'../../../models/templates/{template_name}.jinja' server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice) do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [ @pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
]) ])
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None): def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server global server
n_predict = 512 n_predict = 512
server.n_slots = 1
server.jinja = True server.jinja = True
server.n_ctx = 8192 server.n_ctx = 8192
server.n_predict = n_predict server.n_predict = n_predict
@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
elif isinstance(template_override, str): elif isinstance(template_override, str):
server.chat_template = template_override server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_weather(server, max_tokens=n_predict) do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_weather(server: ServerProcess, **kwargs): def do_test_weather(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={ body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [ "messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"}, {"role": "user", "content": "What is the weather in Istanbul?"},
@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
"tools": [WEATHER_TOOL], "tools": [WEATHER_TOOL],
**kwargs, **kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST) }, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = body["choices"][0]
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0] tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}' assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"]) actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
location = actual_arguments["location"] location = actual_arguments["location"]
@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [ @pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), # (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
]) ])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None): def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server global server
server.n_slots = 1
server.jinja = True server.jinja = True
server.n_ctx = 8192 * 2 server.n_ctx = 8192 * 2
server.n_predict = n_predict server.n_predict = n_predict
@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
elif isinstance(template_override, str): elif isinstance(template_override, str):
server.chat_template = template_override server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_calc_result(server, result_override, n_predict) do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs): def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={ body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict, "max_tokens": n_predict,
"messages": [ "messages": [
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."}, {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
], ],
**kwargs, **kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST) }, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = body["choices"][0]
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls is None, f'Expected no tool call in {choice["message"]}' assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content") content = choice["message"].get("content")
@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [ @pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
(128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), (1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
(1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), # (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
]) ])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server global server
server.n_slots = 1
server.reasoning_format = reasoning_format server.reasoning_format = reasoning_format
server.jinja = True server.jinja = True
server.n_ctx = 8192 * 2 server.n_ctx = 8192 * 2
@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
elif isinstance(template_override, str): elif isinstance(template_override, str):
server.chat_template = template_override server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={ body = server.make_any_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict, "max_tokens": n_predict,
"messages": [ "messages": [
{"role": "user", "content": "What's the sum of 102 and 7?"}, {"role": "user", "content": "What's the sum of 102 and 7?"},
] ],
"stream": stream == CompletionMode.STREAMED,
}, timeout=TIMEOUT_HTTP_REQUEST) }, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = body["choices"][0]
choice = res.body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content") content = choice["message"].get("content")
@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
@pytest.mark.slow @pytest.mark.slow
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
@pytest.mark.parametrize("hf_repo,template_override", [ @pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
]) ])
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None): def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
global server global server
n_predict = 512 # High because of DeepSeek R1 n_predict = 512 # High because of DeepSeek R1
server.n_slots = 1
server.jinja = True server.jinja = True
server.n_ctx = 8192 server.n_ctx = 8192
server.n_predict = n_predict server.n_predict = n_predict
@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
server.chat_template = template_override server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START) server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_hello_world(server, max_tokens=n_predict) do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
def do_test_hello_world(server: ServerProcess, **kwargs): def do_test_hello_world(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={ body = server.make_any_request("POST", "/v1/chat/completions", data={
"messages": [ "messages": [
{"role": "system", "content": "You are a tool-calling agent."}, {"role": "system", "content": "You are a tool-calling agent."},
{"role": "user", "content": "say hello world with python"}, {"role": "user", "content": "say hello world with python"},
@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
"tools": [PYTHON_TOOL], "tools": [PYTHON_TOOL],
**kwargs, **kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST) }, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = body["choices"][0]
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls") tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0] tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"]) actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
code = actual_arguments["code"] code = actual_arguments["code"]
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'

View File

@ -294,6 +294,77 @@ class ServerProcess:
print("Partial response from server", json.dumps(data, indent=2)) print("Partial response from server", json.dumps(data, indent=2))
yield data yield data
def make_any_request(
self,
method: str,
path: str,
data: dict | None = None,
headers: dict | None = None,
timeout: float | None = None,
) -> dict:
stream = data.get('stream', False)
if stream:
content: list[str] = []
tool_calls: list[dict] = []
finish_reason: Optional[str] = None
content_parts = 0
tool_call_parts = 0
arguments_parts = 0
for chunk in self.make_stream_request(method, path, data, headers):
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
choice = chunk['choices'][0]
if choice['delta'].get('content') is not None:
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
content.append(choice['delta']['content'])
content_parts += 1
if choice['delta'].get('finish_reason') is not None:
finish_reason = choice['delta']['finish_reason']
for tc in choice['delta'].get('tool_calls', []):
if 'function' not in tc:
raise ValueError(f"Expected function type, got {tc['type']}")
if tc['index'] >= len(tool_calls):
tool_calls.append(dict(
id="",
type="function",
function=dict(
name="",
arguments="",
)
))
tool_call = tool_calls[tc['index']]
if tc.get('id') is not None:
tool_call['id'] = tc['id']
fct = tc['function']
if fct.get('name') is not None:
tool_call['function']['name'] = fct['name']
if fct.get('arguments') is not None:
assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
tool_call['function']['arguments'] += fct['arguments']
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
result = dict(
choices=[
dict(
index=0,
finish_reason=finish_reason,
message=dict(
role='assistant',
content=''.join(content) if content else None,
tool_calls=tool_calls if tool_calls else None,
),
)
],
)
print("Final response from server", json.dumps(result, indent=2))
return result
else:
response = self.make_request(method, path, data, headers, timeout=timeout)
assert response.status_code == 200, f"Server returned error: {response.status_code}"
return response.body
server_instances: Set[ServerProcess] = set() server_instances: Set[ServerProcess] = set()

View File

@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
// other common utils // other common utils
// //
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
if (!text.empty() && !stop.empty()) {
const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial)) {
return text.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
// TODO: reuse llama_detokenize // TODO: reuse llama_detokenize
template <class Iter> template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@ -599,19 +579,16 @@ static json oaicompat_chat_params_parse(
json llama_params; json llama_params;
auto tools = json_value(body, "tools", json()); auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
auto stream = json_value(body, "stream", false); auto stream = json_value(body, "stream", false);
auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
if (tools.is_array() && !tools.empty()) { if (!opt.use_jinja) {
if (stream) { if (has_tools) {
throw std::runtime_error("Cannot use tools with stream");
}
if (!opt.use_jinja) {
throw std::runtime_error("tools param requires --jinja flag"); throw std::runtime_error("tools param requires --jinja flag");
} }
} if (tool_choice != "auto") {
if (!opt.use_jinja) { throw std::runtime_error("tool_choice param requires --jinja flag");
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
throw std::runtime_error("Unsupported param: tool_choice");
} }
} }
@ -749,14 +726,12 @@ static json oaicompat_chat_params_parse(
common_chat_templates_inputs inputs; common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(messages); inputs.messages = common_chat_msgs_parse_oaicompat(messages);
inputs.tools = common_chat_tools_parse_oaicompat(tools); inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto"))); inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar; inputs.grammar = grammar;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.use_jinja = opt.use_jinja; inputs.use_jinja = opt.use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.extract_reasoning = opt.reasoning_format != COMMON_REASONING_FORMAT_NONE; inputs.reasoning_format = opt.reasoning_format;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) { if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools."); throw std::runtime_error("Cannot use custom grammar constraints with tools.");
} }
@ -774,7 +749,8 @@ static json oaicompat_chat_params_parse(
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list."); throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
} }
inputs.extract_reasoning = false; /* TODO: test this properly */
inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
inputs.add_generation_prompt = true; inputs.add_generation_prompt = true;
} }
@ -799,6 +775,7 @@ static json oaicompat_chat_params_parse(
} }
llama_params["grammar_triggers"] = grammar_triggers; llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens; llama_params["preserved_tokens"] = chat_params.preserved_tokens;
llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
for (const auto & stop : chat_params.additional_stops) { for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop); llama_params["stop"].push_back(stop);
} }
@ -812,6 +789,9 @@ static json oaicompat_chat_params_parse(
// Handle "logprobs" field // Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
if (json_value(body, "logprobs", false)) { if (json_value(body, "logprobs", false)) {
if (has_tools && stream) {
throw std::runtime_error("logprobs is not supported with tools + stream");
}
llama_params["n_probs"] = json_value(body, "top_logprobs", 20); llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
throw std::runtime_error("top_logprobs requires logprobs to be set to true"); throw std::runtime_error("top_logprobs requires logprobs to be set to true");