Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-06-26 19:55:04 +00:00)

server: streaming of tool calls and thoughts when --jinja is on (#12379)
* add common_json w/ support for truncated json healing
* add common_chat_msg_diff
* partial common_chat_parse
* refactor parser w/ optionals
* server: wire chat diffs in stream mode
* fix trigger of thinking models (must happen after thoughts are closed)
* fix functionary v3.2 raw python!
* rename: common_chat_syntax (now contains format)
* rm common_regex.at_start
* don't return empty <think></think>
* accommodate yet another deepseek r1 distill fantasy syntax (`<|tool▁calls|>`)
* fix QwQ 32B tool call parsing after thoughts (hermes2)
* better logs for grammar triggers
* consume spaces after parse_json_tool_calls
* fix required tool calls w/ thinking models that have pre-opened thinking tags
* fix thinking model's initial trigger + test qwq's template
* run most test_tool_call tests in stream + non-stream modes
* make functionary v3.2 parsing more strict (differentiate first match from others)
* send final diff from server, to close off raw python arguments
* support partial content streaming in Generic mode
* tool-call: allow content prelude before hermes2 tool calls (for Qwen2.5)
* Update function-calling.md
* Update tool_bench.py
* chat-parser: remove input from exception (llm output may contain PII)

Co-authored-by: ochafik <ochafik@google.com>
Co-authored-by: Olivier Chafik <ochafik@users.noreply.github.com>
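The net effect on the server's streaming path: each time a new token arrives, the accumulated output is re-parsed as a partial message and only the delta against the previous parse is sent to the client. A minimal sketch of that flow, assuming the surrounding loop and variable names (only `common_chat_parse`, `common_chat_msg_diff` and `common_chat_msg_diff_to_json_oaicompat` come from this change):

```cpp
// Sketch only: the accumulation loop below is an illustrative assumption,
// not the actual server code.
#include "chat.h"

#include <string>
#include <vector>

void stream_chat(const std::vector<std::string> & pieces, const common_chat_syntax & syntax) {
    std::string generated;
    common_chat_msg previous;
    for (const auto & piece : pieces) {
        generated += piece;
        // is_partial=true: tolerate truncated tool-call JSON, unclosed <think> tags, etc.
        auto current = common_chat_parse(generated, /* is_partial= */ true, syntax);
        // Only the changes since the last chunk get streamed (content and/or tool-call deltas).
        for (const auto & diff : common_chat_msg_diff::compute_diffs(previous, current)) {
            (void) diff; // ... serialize via common_chat_msg_diff_to_json_oaicompat, emit as SSE ...
        }
        previous = current;
    }
}
```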
common/CMakeLists.txt:

```diff
@@ -60,12 +60,16 @@ add_library(${TARGET} STATIC
     base64.hpp
     chat.cpp
     chat.h
+    chat-parser.cpp
+    chat-parser.h
     common.cpp
     common.h
     console.cpp
     console.h
     json-schema-to-grammar.cpp
     json.hpp
+    json-partial.h
+    json-partial.cpp
     llguidance.cpp
     log.cpp
     log.h
```
common/chat-parser.cpp (new file, 376 lines):

```cpp
#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"

#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
    : input_(input), is_partial_(is_partial), syntax_(syntax)
{
    result_.role = "assistant";

    while (true) {
        std::string id = std::to_string(std::rand());
        if (input.find(id) == std::string::npos) {
            healing_marker_ = id;
            break;
        }
    }
}

std::string common_chat_msg_parser::str(const common_string_range & rng) const {
    GGML_ASSERT(rng.begin <= rng.end);
    return input_.substr(rng.begin, rng.end - rng.begin);
}

void common_chat_msg_parser::add_content(const std::string & content) {
    result_.content += content;
}

void common_chat_msg_parser::add_reasoning_content(const std::string & reasoning_content) {
    result_.reasoning_content += reasoning_content;
}

bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) {
    if (name.empty()) {
        return false;
    }

    common_chat_tool_call tool_call;
    tool_call.name = name;
    tool_call.arguments = arguments;
    tool_call.id = id;

    // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str());
    result_.tool_calls.emplace_back(tool_call);
    return true;
}
bool common_chat_msg_parser::add_tool_call(const json & tool_call) {
    std::string name = tool_call.contains("name") ? tool_call.at("name") : "";
    std::string id = tool_call.contains("id") ? tool_call.at("id") : "";
    std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : "";
    return add_tool_call(name, id, arguments);
}

bool common_chat_msg_parser::add_tool_calls(const json & arr) {
    for (const auto & item : arr) {
        if (!add_tool_call(item)) {
            return false;
        }
    }
    return true;
}
void common_chat_msg_parser::finish() {
    if (!is_partial_ && pos_ != input_.size()) {
        throw std::runtime_error("Unexpected content at end of input"); // + input_.substr(pos_));
    }
}

bool common_chat_msg_parser::consume_spaces() {
    const auto length = input_.size();
    auto consumed = false;
    while (pos_ < length && std::isspace(input_[pos_])) {
        ++pos_;
        consumed = true;
    }
    return consumed;
}

bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
    auto pos = pos_;
    for (auto i = 0u; i < literal.size(); ++i) {
        if (pos >= input_.size()) {
            return false;
        }
        if (input_[pos] != literal[i]) {
            return false;
        }
        ++pos;
    }
    pos_ = pos;
    return true;
}

std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
    auto idx = input_.find(literal, pos_);
    if (idx != std::string::npos) {
        find_regex_result res;
        res.prelude = input_.substr(pos_, idx - pos_);
        auto end = idx + literal.size();
        res.groups.emplace_back(common_string_range{idx, end});
        move_to(end);
        return res;
    }
    if (is_partial_) {
        idx = string_find_partial_stop(input_, literal);
        if (idx != std::string::npos && idx >= pos_) {
            find_regex_result res;
            res.prelude = input_.substr(pos_, idx - pos_);
            auto end = input_.size();
            res.groups.emplace_back(common_string_range{idx, end});
            move_to(end);
            return res;
        }
    }
    return std::nullopt;
}

void common_chat_msg_parser::consume_literal(const std::string & literal) {
    if (!try_consume_literal(literal)) {
        throw common_chat_msg_partial_exception(literal);
    }
}

bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) {
    auto handle_reasoning = [&](const std::string & reasoning, bool closed) {
        auto stripped_reasoning = string_strip(reasoning);
        if (stripped_reasoning.empty()) {
            return;
        }
        if (syntax_.reasoning_in_content) {
            add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "<think>" : start_think);
            add_content(stripped_reasoning);
            if (closed) {
                add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "</think>" : end_think);
            }
        } else {
            add_reasoning_content(stripped_reasoning);
        }
    };
    if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
        if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
            if (auto res = try_find_literal(end_think)) {
                handle_reasoning(res->prelude, /* closed */ true);
                consume_spaces();
                return true;
            }
            auto rest = consume_rest();
            if (!rest.empty()) {
                handle_reasoning(rest, /* closed */ !is_partial());
            }
            if (!syntax_.thinking_forced_open) {
                throw common_chat_msg_partial_exception(end_think);
            }
            return true;
        }
    }
    return false;
}

std::string common_chat_msg_parser::consume_rest() {
    auto rest = input_.substr(pos_);
    pos_ = input_.size();
    return rest;
}

// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback.
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from) {
    auto m = regex.search(input_, from == std::string::npos ? pos_ : from);
    if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
        return std::nullopt;
    }
    if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
        if (is_partial()) {
            throw common_chat_msg_partial_exception(regex.str());
        }
        return std::nullopt;
    }
    auto prelude = input_.substr(pos_, m.groups[0].begin - pos_);
    pos_ = m.groups[0].end;

    return find_regex_result{prelude, m.groups};
}

common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) {
    if (auto result = try_consume_regex(regex)) {
        return *result;
    }
    throw common_chat_msg_partial_exception(regex.str());
}

std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_consume_regex(const common_regex & regex) {
    auto m = regex.search(input_, pos_);
    if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) {
        return std::nullopt;
    }
    if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) {
        if (is_partial()) {
            throw common_chat_msg_partial_exception(regex.str());
        }
        return std::nullopt;
    }
    if (m.groups[0].begin != pos_) {
        // Didn't match at the current position.
        return std::nullopt;
    }
    pos_ = m.groups[0].end;

    return find_regex_result {
        /* .prelude = */ "",
        m.groups,
    };
}

std::optional<common_json> common_chat_msg_parser::try_consume_json() {
    auto it = input_.cbegin() + pos_;
    const auto end = input_.cend();
    common_json result;
    if (!common_json_parse(it, end, healing_marker_, result)) {
        return std::nullopt;
    }
    pos_ = std::distance(input_.cbegin(), it);
    if (result.healing_marker.marker.empty()) {
        // No healing marker, just return the parsed json
        return result;
    }
    if (!is_partial()) {
        throw common_chat_msg_partial_exception("JSON");
    }
    return result;
}

common_json common_chat_msg_parser::consume_json() {
    if (auto result = try_consume_json()) {
        return *result;
    }
    throw common_chat_msg_partial_exception("JSON");
}

common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args(
    const std::vector<std::vector<std::string>> & args_paths,
    const std::vector<std::vector<std::string>> & content_paths
) {
    if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) {
        return *result;
    }
    throw common_chat_msg_partial_exception("JSON");
}

std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parser::try_consume_json_with_dumped_args(
    const std::vector<std::vector<std::string>> & args_paths,
    const std::vector<std::vector<std::string>> & content_paths
) {
    auto partial = try_consume_json();
    if (!partial) {
        return std::nullopt;
    }
    auto is_arguments_path = [&](const std::vector<std::string> & path) {
        return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end();
    };
    auto is_content_path = [&](const std::vector<std::string> & path) {
        return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end();
    };

    if (partial->healing_marker.marker.empty()) {
        if (args_paths.empty()) {
            // No arguments to dump, and JSON was parsed fully.
            return consume_json_result {
                partial->json,
                /* .is_partial = */ false,
            };
        }
        if (is_arguments_path({})) {
            // Entire JSON is the arguments and was parsed fully.
            return consume_json_result {
                partial->json.dump(),
                /* .is_partial = */ false,
            };
        }
    }

    LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());

    auto found_healing_marker = false;
    std::vector<std::string> path;
    std::function<json(const json &)> remove_unsupported_healings_and_dump_args = [&](const json & j) -> json {
        if (is_arguments_path(path)) {
            auto arguments = j.dump();
            if (is_partial() && !partial->healing_marker.marker.empty()) {
                auto idx = arguments.find(partial->healing_marker.json_dump_marker);
                if (idx != std::string::npos) {
                    arguments.resize(idx);
                    found_healing_marker = true;
                }
                if (arguments == "\"") {
                    // This happens because of completing `:"$magic` after `"arguments"`
                    arguments = "";
                }
            }
            return arguments;
        }
        if (is_content_path(path)) {
            if (!j.is_string()) {
                throw std::runtime_error("Content path must be a string");
            }
            std::string str = j;
            auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string
            if (idx != std::string::npos) {
                str.resize(idx);
                found_healing_marker = true;
            }
            return str;
        }
        if (j.is_object()) {
            auto obj = json::object();
            for (const auto & p : j.items()) {
                const auto & key = p.key();
                const auto & value = p.value();
                const std::string key_str = key; // NOLINT
                auto idx = key_str.find(healing_marker_);
                if (idx != std::string::npos) {
                    found_healing_marker = true;
                    break;
                }
                path.push_back(key_str);
                if (value.is_string()) {
                    const std::string value_str = value;
                    if (value_str.find(healing_marker_) != std::string::npos) {
                        found_healing_marker = true;
                        if (is_content_path(path)) {
                            if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) {
                                // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair.
                                obj[key] = remove_unsupported_healings_and_dump_args(value);
                            }
                        }
                        break;
                    }
                    obj[key] = value;
                } else {
                    obj[key] = remove_unsupported_healings_and_dump_args(value);
                }
                path.pop_back();
            }
            return obj;
        }
        if (j.is_array()) {
            auto arr = json::array();
            for (const auto & value : j) {
                if (value.is_string()) {
                    std::string str = value;
                    auto idx = str.find(healing_marker_);
                    if (idx != std::string::npos) {
                        // Don't heal array values that aren't in the arguments.
                        found_healing_marker = true;
                        break;
                    }
                }
                arr.push_back(remove_unsupported_healings_and_dump_args(value));
            }
            return arr;
        }
        return j;
    };

    auto cleaned = remove_unsupported_healings_and_dump_args(partial->json);
    LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str());
    return consume_json_result {
        cleaned,
        /* .is_partial = */ found_healing_marker,
    };
}
```
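For intuition, here is how the reasoning path above behaves on a thinking model's partial output (a sketch; the input string and syntax flags are illustrative):

```cpp
common_chat_syntax syntax;
syntax.format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
syntax.thinking_forced_open = true;  // the template already emitted "<think>"

// Mid-stream, the closing tag hasn't arrived yet: everything so far is reasoning.
common_chat_msg_parser p("Let me think about this", /* is_partial= */ true, syntax);
p.try_parse_reasoning("<think>", "</think>");
// p.result().reasoning_content == "Let me think about this"; content stays empty,
// and no exception is thrown because thinking_forced_open tolerates the missing "</think>".
```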
common/chat-parser.h (new file, 116 lines):

```cpp
#pragma once

#include "chat.h"
#include "json-partial.h"
#include "json.hpp"
#include "regex-partial.h"

#include <optional>
#include <string>
#include <vector>

class common_chat_msg_partial_exception : public std::runtime_error {
  public:
    common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {}
};

class common_chat_msg_parser {
    std::string input_;
    bool is_partial_;
    common_chat_syntax syntax_;
    std::string healing_marker_;

    size_t pos_ = 0;
    common_chat_msg result_;

  public:
    common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
    const std::string & input() const { return input_; }
    size_t pos() const { return pos_; }
    const std::string & healing_marker() const { return healing_marker_; }
    const bool & is_partial() const { return is_partial_; }
    const common_chat_msg & result() const { return result_; }

    void move_to(size_t pos) {
        if (pos > input_.size()) {
            throw std::runtime_error("Invalid position!");
        }
        pos_ = pos;
    }
    void move_back(size_t n) {
        if (pos_ < n) {
            throw std::runtime_error("Can't move back that far!");
        }
        pos_ -= n;
    }

    // Get the substring of the input at the given range
    std::string str(const common_string_range & rng) const;

    // Appends to the result.content field
    void add_content(const std::string & content);

    // Appends to the result.reasoning_content field
    void add_reasoning_content(const std::string & reasoning_content);

    // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything.
    bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments);

    // Adds a tool call using the "name", "id" and "arguments" fields of the json object
    bool add_tool_call(const nlohmann::ordered_json & tool_call);

    // Adds an array of tool calls using their "name", "id" and "arguments" fields.
    bool add_tool_calls(const nlohmann::ordered_json & arr);

    void finish();

    bool consume_spaces();

    void consume_literal(const std::string & literal);

    bool try_parse_reasoning(const std::string & start_think, const std::string & end_think);

    std::string consume_rest();

    struct find_regex_result {
        std::string prelude;
        std::vector<common_string_range> groups;
    };

    std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos);

    bool try_consume_literal(const std::string & literal);

    std::optional<find_regex_result> try_find_literal(const std::string & literal);

    find_regex_result consume_regex(const common_regex & regex);

    std::optional<find_regex_result> try_consume_regex(const common_regex & regex);

    std::optional<common_json> try_consume_json();
    common_json consume_json();

    struct consume_json_result {
        nlohmann::ordered_json value;
        bool is_partial;
    };

    /*
        Consumes (possibly partial) JSON and converts specific subtrees to (possibly truncated) JSON strings.

        By default, object keys can't be truncated, nor can string values (their corresponding key is removed),
        e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}`

        But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings:
        - with `content_paths={{"foo"}}`: `{"foo": "b` -> `{"foo": "b"}`
        - with `args_paths={{"foo"}}`: `{"foo": {"b` -> `{"foo": "{b"}`
    */
    consume_json_result consume_json_with_dumped_args(
        const std::vector<std::vector<std::string>> & args_paths = {},
        const std::vector<std::vector<std::string>> & content_paths = {}
    );
    std::optional<consume_json_result> try_consume_json_with_dumped_args(
        const std::vector<std::vector<std::string>> & args_paths = {},
        const std::vector<std::vector<std::string>> & content_paths = {}
    );
};
```
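A usage sketch of the dumped-args path on a tool call cut mid-arguments (the input and expected output are illustrative; see the stream-mode tests for real coverage):

```cpp
common_chat_syntax syntax;  // defaults: CONTENT_ONLY format, no reasoning
common_chat_msg_parser p(
    "{\"name\": \"python\", \"arguments\": {\"code\": \"print(",
    /* is_partial= */ true, syntax);

if (auto res = p.try_consume_json_with_dumped_args(/* args_paths= */ {{"arguments"}})) {
    // The truncated "arguments" subtree was healed, dumped back to a string and cut
    // at the healing marker, leaving roughly {"name":"python","arguments":"{\"code\":\"print("}.
    // res->is_partial is true because a healing marker was found.
}
```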
common/chat.cpp (1126 lines changed): diff suppressed because it is too large.
common/chat.h:

```diff
@@ -3,6 +3,7 @@
 #pragma once

 #include "common.h"
+#include <functional>
 #include <chrono>
 #include <string>
 #include <vector>
@@ -13,11 +14,19 @@ struct common_chat_tool_call {
     std::string name;
     std::string arguments;
     std::string id;
+
+    bool operator==(const common_chat_tool_call & other) const {
+        return name == other.name && arguments == other.arguments && id == other.id;
+    }
 };

 struct common_chat_msg_content_part {
     std::string type;
     std::string text;
+
+    bool operator==(const common_chat_msg_content_part & other) const {
+        return type == other.type && text == other.text;
+    }
 };

 struct common_chat_msg {
@@ -28,6 +37,51 @@ struct common_chat_msg {
     std::string reasoning_content;
     std::string tool_name;
     std::string tool_call_id;
+
+    template <class T> T to_json_oaicompat() const;
+
+    bool empty() const {
+        return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
+    }
+    void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
+        for (auto i = 0u; i < tool_calls.size(); i++) {
+            if (ids_cache.size() <= i) {
+                auto id = tool_calls[i].id;
+                if (id.empty()) {
+                    id = gen_tool_call_id();
+                }
+                ids_cache.push_back(id);
+            }
+            tool_calls[i].id = ids_cache[i];
+        }
+    }
+    bool operator==(const common_chat_msg & other) const {
+        return role == other.role
+            && content == other.content
+            && content_parts == other.content_parts
+            && tool_calls == other.tool_calls
+            && reasoning_content == other.reasoning_content
+            && tool_name == other.tool_name
+            && tool_call_id == other.tool_call_id;
+    }
+    bool operator!=(const common_chat_msg & other) const {
+        return !(*this == other);
+    }
+};
+
+struct common_chat_msg_diff {
+    // std::string reasoning_content_delta;
+    std::string content_delta;
+    size_t tool_call_index = std::string::npos;
+    common_chat_tool_call tool_call_delta;
+
+    static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
+
+    bool operator==(const common_chat_msg_diff & other) const {
+        return content_delta == other.content_delta
+            && tool_call_index == other.tool_call_index
+            && tool_call_delta == other.tool_call_delta;
+    }
 };

 struct common_chat_tool {
@@ -49,14 +103,11 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_LLAMA_3_X,
     COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
     COMMON_CHAT_FORMAT_DEEPSEEK_R1,
-    COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
     COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
     COMMON_CHAT_FORMAT_HERMES_2_PRO,
-    COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING,
     COMMON_CHAT_FORMAT_COMMAND_R7B,
-    COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING,

     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
 };
@@ -71,7 +122,7 @@ struct common_chat_templates_inputs {
     std::vector<common_chat_tool> tools;
     common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
     bool parallel_tool_calls = false;
-    bool extract_reasoning = true;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
     std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
 };
@@ -80,11 +131,20 @@ struct common_chat_params {
     std::string prompt;
     std::string grammar;
     bool grammar_lazy = false;
+    bool thinking_forced_open = false;
     std::vector<common_grammar_trigger> grammar_triggers;
     std::vector<std::string> preserved_tokens;
     std::vector<std::string> additional_stops;
 };

+struct common_chat_syntax {
+    common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+    common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
+    // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
+    bool reasoning_in_content = false;
+    bool thinking_forced_open = false;
+};
+
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
@@ -122,7 +182,7 @@ std::string common_chat_format_example(
     bool use_jinja);

 std::string common_chat_format_name(common_chat_format format);
-common_chat_msg common_chat_parse( const std::string & input, common_chat_format format);
+common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);

 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
@@ -135,3 +195,5 @@ template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common
 // T can be std::string containing JSON or nlohmann::ordered_json
 template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
 template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
+
+template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
```
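For intuition, the diffs between two successive partial parses of the same streamed message might look as follows (a sketch; the exact grouping of deltas into diff objects is illustrative):

```cpp
common_chat_msg prev;
prev.role = "assistant";
prev.content = "Let me check";

common_chat_msg next = prev;
next.content = "Let me check the weather.";
next.tool_calls.push_back({"get_current_weather", "{\"location\": \"Ist", "call-1"});

// Expected deltas: " the weather." as content_delta, plus a diff with
// tool_call_index == 0 whose tool_call_delta carries the new tool-call fragment.
auto diffs = common_chat_msg_diff::compute_diffs(prev, next);
```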
common/common.h:

```diff
@@ -115,7 +115,7 @@ enum common_grammar_trigger_type {
     COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
     COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
     COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
-    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
+    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
 };

 struct common_grammar_trigger {
```
common/json-partial.cpp (new file, 255 lines):

```cpp
#include <json-partial.h>
#include "ggml.h"
#include "log.h"
#include <string>

#include <json.hpp>

using json = nlohmann::ordered_json;

enum common_json_stack_element_type {
    COMMON_JSON_STACK_ELEMENT_OBJECT,
    COMMON_JSON_STACK_ELEMENT_KEY,
    COMMON_JSON_STACK_ELEMENT_ARRAY,
};

struct common_json_stack_element {
    common_json_stack_element_type type;
    std::string key;
};

bool common_json_parse(
    const std::string & input,
    const std::string & healing_marker,
    common_json & out)
{
    std::string::const_iterator it = input.begin();
    const auto end = input.end();
    return common_json_parse(it, end, healing_marker, out);
}

bool common_json_parse(
    std::string::const_iterator & it,
    const std::string::const_iterator & end,
    const std::string & healing_marker,
    common_json & out)
{
    // https://json.nlohmann.me/features/parsing/sax_interface/
    struct json_error_locator : public nlohmann::json_sax<json> {
        std::size_t position;
        bool found_error;
        std::string last_token;
        std::string exception_message;
        std::vector<common_json_stack_element> stack;

        json_error_locator() : position(0), found_error(false) {}

        bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
            this->position = position - 1;
            this->found_error = true;
            this->last_token = last_token;
            this->exception_message = ex.what();
            return false;
        }
        void close_value() {
            if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
                stack.pop_back();
            }
        }
        bool null() override { // NOLINT
            close_value();
            return true;
        }
        bool boolean(bool) override { // NOLINT
            close_value();
            return true;
        }
        bool number_integer(number_integer_t) override { // NOLINT
            close_value();
            return true;
        }
        bool number_unsigned(number_unsigned_t) override { // NOLINT
            close_value();
            return true;
        }
        bool number_float(number_float_t, const string_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool string(string_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool binary(binary_t &) override { // NOLINT
            close_value();
            return true;
        }
        bool start_object(std::size_t) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
            return true;
        }
        bool end_object() override {
            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
            stack.pop_back();
            close_value();
            return true;
        }
        bool key(string_t & key) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
            return true;
        }
        bool start_array(std::size_t) override { // NOLINT
            stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
            return true;
        }
        bool end_array() override {
            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
            stack.pop_back();
            close_value();
            return true;
        }
    };
    json_error_locator err_loc;
    auto start = it;
    json::sax_parse(it, end, &err_loc);

    if (err_loc.found_error) {
        it = start;
        auto temptative_end = it + err_loc.position;
        // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());

        auto input = std::string(it, temptative_end);
        try {
            out.json = json::parse(input);
            // out.json = json::parse(it, temptative_end);
            it = temptative_end;
            return true;
        } catch (const std::exception & ex) {
            // No, needs healing.
            LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
        }
        auto can_parse = [](const std::string & str) {
            try {
                auto _ = json::parse(str); // NOLINT
                return true;
            } catch (const std::exception &) {
                return false;
            }
        };
        if (!healing_marker.empty() && !err_loc.stack.empty()) {
            std::string str(it, temptative_end);
            auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
            if (last_non_sp_pos == std::string::npos) {
                throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
            }
            auto last_non_sp_char = str[last_non_sp_pos];
            // Used to detect stops on a number, which may not be complete.
            auto was_maybe_number = [&]() {
                if (!str.empty() && std::isspace(str.back())) {
                    return false;
                }
                return std::isdigit(last_non_sp_char) ||
                    last_non_sp_char == '.' ||
                    last_non_sp_char == 'e' ||
                    last_non_sp_char == 'E' ||
                    last_non_sp_char == '-';
            };

            std::string closing;
            for (size_t i = err_loc.stack.size(); i > 0; i--) {
                auto & el = err_loc.stack[i - 1];
                if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
                    closing += "}";
                } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
                    closing += "]";
                } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
                    throw std::runtime_error("Unexpected stack element type");
                }
            }

            const auto & magic_seed = out.healing_marker.marker = healing_marker; //"$llama.cpp.json$";

            if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
                // We're inside an object value
                if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
                    // Was about to create an object value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                } else if (can_parse(str + ": 1" + closing)) {
                    str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
                } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
                    // Was about to create an object
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
                } else if (can_parse(str + "\"" + closing)) {
                    // Was inside an object value string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
                    // Was inside an object value string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
                } else {
                    // find last :
                    auto last_pos = str.find_last_of(':');
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
                    }
                    // Cutting back to opening : for object value
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
                if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
                    // Was about to create an array value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                } else if (can_parse(str + "\"" + closing)) {
                    // Was inside an array value string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
                    // Was inside an array value string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
                } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
                    // Had just finished a value
                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
                } else {
                    auto last_pos = str.find_last_of("[,");
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
                    }
                    // Cutting back to last [ or , for array value
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
                if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
                        (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
                    // Was about to create an object key+value
                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
                } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
                    // Was about to create an object key+value
                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
                } else if (can_parse(str + "\": 1" + closing)) {
                    // Was inside an object key string
                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
                    // Was inside an object key string after an escape
                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
                } else {
                    auto last_pos = str.find_last_of(':');
                    if (last_pos == std::string::npos) {
                        throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
                    }
                    // fprintf(stderr, "Cutting back to last : for object key+value\n");
                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
                }
            } else {
                throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
            }
            // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
            out.json = json::parse(str);
            it = temptative_end;
            return true;
        }
        // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
        // fprintf(stderr, "Closing: TODO\n");
        return false;
    }
    out.json = json::parse(it, end);
    it = end;
    return true;
}
```
common/json-partial.h (new file, 37 lines):

```cpp
#pragma once
#include <json.hpp>

// Healing marker (empty if the JSON was fully parsed / wasn't healed).
struct common_healing_marker {
    // Raw marker.
    std::string marker;

    // Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format).
    std::string json_dump_marker;
};

// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string)
struct common_json {
    nlohmann::ordered_json json;

    common_healing_marker healing_marker;
};

// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty.
//
// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON.
// This makes it possible to parse the resulting healed JSON string, yet still cut it again at the healing marker if needed.
// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format).
//
// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again).
bool common_json_parse(
    const std::string & input,
    const std::string & healing_marker,
    common_json & out);

// Parse the JSON string (see overload above), advancing an iterator to the end of the input when the (potentially partial) parsing succeeds.
bool common_json_parse(
    std::string::const_iterator & it,
    const std::string::const_iterator & end,
    const std::string & healing_marker,
    common_json & out);
```
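A minimal sketch of the healing round-trip described above (the marker value is an arbitrary illustrative choice):

```cpp
#include "json-partial.h"

void example() {
    common_json out;
    // Truncated model output, healed into a complete object such as {"code": "print(MAGIC123"}.
    if (common_json_parse("{\"code\": \"print(", "MAGIC123", out)) {
        auto dump = out.json.dump();
        auto cut  = dump.find(out.healing_marker.json_dump_marker);
        // dump.substr(0, cut) recovers (modulo spacing) the original truncated input.
    }
}
```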
common/sampling.cpp:

```diff
@@ -161,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
-        std::vector<std::string> patterns_at_start;
+        std::vector<std::string> trigger_patterns;
         std::vector<std::string> patterns_anywhere;
         std::vector<llama_token> trigger_tokens;
         for (const auto & trigger : params.grammar_triggers) {
@@ -173,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
-                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
                 {
-                    const auto & pattern = trigger.value;
-                    (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
+                    patterns_anywhere.push_back(trigger.value);
+                    break;
+                }
+                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
+                {
+                    trigger_patterns.push_back(trigger.value);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -190,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             }
         }

-        std::vector<std::string> trigger_patterns;
-        if (!patterns_at_start.empty()) {
-            trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
-        }
         if (!patterns_anywhere.empty()) {
             trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
         }
```
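The consolidation above means every lazy-grammar trigger ends up as one full-match regex whose first capturing group marks where grammar enforcement starts. A standalone sketch of that behavior (the `<tool_call>` trigger word is an assumed example, not tied to a specific model):

```cpp
#include <cassert>
#include <regex>
#include <string>

int main() {
    // A "pattern anywhere" trigger, wrapped the same way as in common_sampler_init.
    std::regex trigger("^[\\s\\S]*?(<tool_call>)[\\s\\S]*");
    std::string buffer = "Some free-form prelude...\n<tool_call>{\"name\":";
    std::smatch match;
    if (std::regex_match(buffer, match, trigger)) {
        // Grammar becomes active from the first capturing group onwards.
        std::string constrained = buffer.substr(match.position(1));
        assert(constrained.rfind("<tool_call>", 0) == 0);
    }
    return 0;
}
```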
docs/function-calling.md (adds a caution about extreme KV cache quantization and a second, weather tool-call example):

> [!TIP]
> If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills)

> [!CAUTION]
> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance.

Test in CLI (or with any library / software that can use OpenAI-compatible API backends):

```bash
curl http://localhost:8080/v1/chat/completions -d '{
  "model": "gpt-3.5-turbo",
  "tools": [
    {
      "type":"function",
      "function":{
        "name":"python",
        "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
        "parameters":{
          "type":"object",
          "properties":{
            "code":{
              "type":"string",
              "description":"The code to run in the ipython interpreter."
            }
          },
          "required":["code"]
        }
      }
    }
  ],
  "messages": [
    {
      "role": "user",
      "content": "Print a hello world message with python."
    }
  ]
}'


curl http://localhost:8080/v1/chat/completions -d '{
  "model": "gpt-3.5-turbo",
  "messages": [
    {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
    {"role": "user", "content": "What is the weather in Istanbul?"}
  ],
  "tools": [{
    "type":"function",
    "function":{
      "name":"get_current_weather",
      "description":"Get the current weather in a given location",
      "parameters":{
        "type":"object",
        "properties":{
          "location":{
            "type":"string",
            "description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`"
          }
        },
        "required":["location"]
      }
    }
  }]
}'
```
models/templates/Qwen-QwQ-32B.jinja (new file, 62 lines):

```jinja
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0]['role'] == 'system' %}
        {{- messages[0]['content'] }}
    {%- else %}
        {{- '' }}
    {%- endif %}
    {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0]['role'] == 'system' %}
        {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- for message in messages %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" and not message.tool_calls %}
        {%- set content = message.content %}
        {%- if not loop.last %}
            {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
        {%- endif %}
        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {%- set content = message.content %}
        {%- if not loop.last %}
            {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
        {%- endif %}
        {{- '<|im_start|>' + message.role }}
        {%- if message.content %}
            {{- '\n' + content }}
        {%- endif %}
        {%- for tool_call in message.tool_calls %}
            {%- if tool_call.function is defined %}
                {%- set tool_call = tool_call.function %}
            {%- endif %}
            {{- '\n<tool_call>\n{"name": "' }}
            {{- tool_call.name }}
            {{- '", "arguments": ' }}
            {{- tool_call.arguments | tojson }}
            {{- '}\n</tool_call>' }}
        {%- endfor %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- message.content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n<think>\n' }}
{%- endif %}
```
models/templates/README.md:

````diff
@@ -19,4 +19,5 @@ These templates can be updated with the following commands:
 ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
 ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
 ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
+./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
 ```
````
scripts/tool_bench.py:

```diff
@@ -12,6 +12,7 @@
     export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
     export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}

+    ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L
     ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M
     ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b
@@ -205,6 +206,7 @@ def run(
     model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
     hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
     chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
+    chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None,
     ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
     llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None,
     n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
@@ -229,6 +231,12 @@ def run(
     # n_ctx = 8192
     n_ctx = 2048

+    if model is None:
+        if hf is not None:
+            model = hf.split("/")[-1]
+        elif ollama is not None:
+            model = ollama
+
     assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite"

     with output.open('a' if append else 'w') as output_file:
@@ -320,6 +328,7 @@ def run(
                 server.model_hf_repo = hf
                 server.model_hf_file = None
                 server.chat_template = chat_template
+                server.chat_template_file = chat_template_file
                 server.server_path = server_path
                 if port is not None:
                     server.server_port = port
@@ -335,6 +344,7 @@ def run(
                         temp=t,
                         output_kwargs=dict(
                             chat_template=chat_template,
+                            chat_template_file=chat_template_file,
                         ),
                         request_kwargs=dict(
                             ignore_chat_grammar=ignore_chat_grammar,
@@ -355,6 +365,7 @@ def run(
                         temp=t,
                         output_kwargs=dict(
                             chat_template=None,
+                            chat_template_file=None,
                         ),
                         request_kwargs=dict(
                             model=ollama,
```
@@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
        for (const auto & trigger_pattern : grammar.trigger_patterns) {
            if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) {
                grammar.awaiting_trigger = false;
-               // get from the first match to the end of the string
-               auto constrained_str = grammar.trigger_buffer.substr(match.position(1));
+               // get from the first matched capturing group to the end of the string
+               size_t start = std::string::npos;
+               for (auto i = 1u; i < match.size(); i++) {
+                   if (match.length(i) > 0) {
+                       start = match.position(i);
+                       break;
+                   }
+               }
+               if (start == std::string::npos) {
+                   start = match.position(0);
+               }
+               auto constrained_str = grammar.trigger_buffer.substr(start);
                // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                grammar.trigger_buffer.clear();
                llama_grammar_accept_str(grammar, constrained_str);
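For context: several alternative trigger patterns get OR-ed into a single regex here, and the capturing group of a branch that did not participate in the match reports length 0 and position -1, so unconditionally taking `match.position(1)` can slice the buffer at a bogus offset. A minimal sketch of the group-scanning idea, with illustrative trigger strings that are not from this commit:

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
    // Two alternative triggers, each wrapped in its own capturing group.
    std::regex re(R"((<tool_call>.*)|(functools\[.*))");
    std::string buf = "functools[{\"name\": \"special_function\"}]";
    std::smatch m;
    if (std::regex_match(buf, m, re)) {
        // Group 1 did not participate: m.length(1) == 0 and m.position(1) == -1.
        for (size_t i = 1; i < m.size(); i++) {
            if (m.length(i) > 0) {
                std::cout << "constrain from group " << i << " at offset " << m.position(i) << "\n";
                break;
            }
        }
    }
    return 0;
}
```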
@@ -142,8 +142,10 @@ if (NOT WIN32)
    # llama_build_and_test(test-double-float.cpp) # SLOW
endif()

-llama_build_and_test(test-log.cpp)
+llama_build_and_test(test-chat-parser.cpp)
llama_build_and_test(test-chat-template.cpp)
+llama_build_and_test(test-json-partial.cpp)
+llama_build_and_test(test-log.cpp)
llama_build_and_test(test-regex-partial.cpp)

# this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
tests/test-chat-parser.cpp (new file, 355 lines)
@@ -0,0 +1,355 @@
//  Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
//
//  Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
//  e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
//
//    cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
//
#include <exception>
#include <iostream>
#include <json.hpp>
#include <string>

#include "chat-parser.h"
#include "common.h"
#include "log.h"
#include "regex-partial.h"

using json = nlohmann::ordered_json;

template <class T>
static void assert_equals(const T & expected, const T & actual) {
    if (expected != actual) {
        std::cerr << "Expected: " << expected << std::endl;
        std::cerr << "Actual: " << actual << std::endl;
        std::cerr << std::flush;
        throw std::runtime_error("Test failed");
    }
}
static void assert_equals(const char * expected, const std::string & actual) {
    return assert_equals<std::string>(expected, actual);
}

static void assert_throws(const std::function<void()> & fn, const std::string & expected_exception_pattern = "") {
    try {
        fn();
    } catch (const std::exception & e) {
        if (expected_exception_pattern.empty()) {
            return;
        }
        std::regex expected_exception_regex(expected_exception_pattern);
        std::string actual_message = e.what();
        if (std::regex_search(actual_message, expected_exception_regex)) {
            return;
        }
        throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
        throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
    }
    throw std::runtime_error("Exception was expected but not thrown");
}

static void test_reasoning() {
    {
        common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
            /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
            /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
            /* .reasoning_in_content = */ false,
            /* .thinking_forced_open = */ false,
        });
        assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
        assert_equals("<tnk>Cogito</tnk>Ergo sum", builder.consume_rest());
    }
    {
        common_chat_msg_parser builder("<tnk>Cogito</tnk>Ergo sum", /* is_partial= */ false, {
            /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
            /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
            /* .reasoning_in_content = */ false,
            /* .thinking_forced_open = */ false,
        });
        assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
        assert_equals(std::string("Cogito"), builder.result().reasoning_content);
        assert_equals("Ergo sum", builder.consume_rest());
    }
    {
        common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
            /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
            /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
            /* .reasoning_in_content = */ false,
            /* .thinking_forced_open = */ false,
        });
        assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
        assert_equals("Cogito</tnk>Ergo sum", builder.consume_rest());
    }
    {
        common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
            /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
            /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
            /* .reasoning_in_content = */ false,
            /* .thinking_forced_open = */ true,
        });
        assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
        assert_equals(std::string("Cogito"), builder.result().reasoning_content);
        assert_equals("Ergo sum", builder.consume_rest());
    }
    {
        common_chat_msg_parser builder("Cogito</tnk>Ergo sum", /* is_partial= */ false, {
            /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
            /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
            /* .reasoning_in_content = */ true,
            /* .thinking_forced_open = */ true,
        });
        assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
        assert_equals("<think>Cogito</think>", builder.result().content);
        assert_equals("Ergo sum", builder.consume_rest());
    }
}

static void test_regex() {
    auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
        common_chat_msg_parser builder(input, /* is_partial= */ false, {});
        assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
    };

    test_throws("Hello, world!", "abc", "^abc$");
    test_throws("Hello, world!", "e", "^e$");

    {
        common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
        builder.consume_regex(common_regex("Hello"));
        assert_equals(", world!", builder.consume_rest());
    }

    {
        // When in non partial mode, we can say whether the regex was consumed or not.
        common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
        assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
    }
    {
        common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
        auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
        assert_equals(true, res.has_value());
        // Verify captures
        assert_equals<size_t>(2, res->groups.size());
        assert_equals("Hell", builder.str(res->groups[0]));
        assert_equals("el", builder.str(res->groups[1]));
        // Verify position is after the match
        assert_equals<size_t>(4, builder.pos());
        assert_equals("o,", builder.consume_rest());
    }
    {
        // But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
        common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
        assert_throws([&]() {
            builder.try_consume_regex(common_regex("Hello, world!"));
        }, "^Hello, world!$");
    }

    // Now regardless of the mode, we can tell these aren't a match.
    for (const auto is_partial : {false, true}) {
        common_chat_msg_parser builder("Hello,", is_partial, {});
        assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
    }
    for (const auto is_partial : {false, true}) {
        common_chat_msg_parser builder("Hello,", is_partial, {});
        assert_equals(false, builder.try_consume_literal("Oh"));
    }
}

const std::vector<std::string> barely_healable_jsons = {
    "{",
    "{\"",
    "{\"\\",
    "{\"n",
    "{\"name\"",
    "{\"name\":",
    "{\"name\":\"",
    "{\"name\":\"\\",
    "{\"name\":\"python",
    "{\"name\":\"python\\",
    "{\",",
    "{\":",
    "{\"[",
    "{\"]",
    "{\"{",
    "{\"}",
    "{\"1",
    "{\"name\":\",",
    "{\"name\":\":",
    "{\"name\":\"[",
    "{\"name\":\"]",
    "{\"name\":\"{",
    "{\"name\":\"}",
    "{\"name\":\"1",
};

static void test(const std::string & input, bool is_partial, const std::vector<std::vector<std::string>> & args_paths, const std::vector<std::vector<std::string>> & content_paths, const std::string & expected) {
    common_chat_msg_parser builder(input, is_partial, {});
    auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
    assert_equals(true, js.has_value());
    assert_equals(is_partial, js->is_partial);
    assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get<std::string>() : js->value.dump());
}
static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
    common_chat_msg_parser builder(input, parse_as_partial, {});
    auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
    assert_equals(true, js.has_value());
    assert_equals(is_partial, js->is_partial);
    assert_equals(expected, js->value.dump());
}

static void test_json_with_dumped_args_no_args() {
    // Normal JSON, nothing to heal, nothing to dump
    test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
    // Full json is args
    test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");

    // If the arguments are further down, don't heal partial content.
    for (const auto & src : barely_healable_jsons) {
        test(src, true, {{"arguments"}}, {}, "{}");
    }
    // But heal content that isn't partial.
    test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
}

static void test_json_with_dumped_args() {

    // Partial content.
    test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
    test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
    test("{\"content\": ", true, {}, {{"content"}}, "{}");

    // If the entire JSON is the arguments, healing it then dumping it produces the same output as the input (just reformatted).
    test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
    for (const auto & src : barely_healable_jsons) {
        test(src, true, {{}}, {}, src);
    }

    // Full JSON w/ args
    for (auto parse_as_partial : {true, false}) {
        test_with_args(
            R"({"name": "python", "args": {"arg1": 1}})",
            R"({"name":"python","args":"{\"arg1\":1}"})",
            parse_as_partial,
            /* is_partial= */ false
        );
    }

    // Partial JSON w/ partial args
    test_with_args(
        R"({"foo": "bar", "args": {")",
        R"({"foo":"bar","args":"{\""})"
    );
    // Partial args broken in object key
    test_with_args(
        R"({"foo": "bar", "args": {"ar)",
        R"({"foo":"bar","args":"{\"ar"})"
    );
    // Partial args broken after object key
    test_with_args(
        R"({"foo": "bar", "args": {"arg1")",
        R"({"foo":"bar","args":"{\"arg1\""})"
    );
    // Partial args broken before object value
    test_with_args(
        R"({"foo": "bar", "args": {"arg1":)",
        R"({"foo":"bar","args":"{\"arg1\":"})"
    );
    // Partial args broken before object value (space)
    test_with_args(
        R"({"foo": "bar", "args": {"arg1": )",
        R"({"foo":"bar","args":"{\"arg1\":"})"
    );
    // Partial args broken in object value that may not be complete (int)
    test_with_args(
        R"({"foo": "bar", "args": {"arg1": 1)",
        R"({"foo":"bar","args":"{\"arg1\":"})"
    );
    // Partial args broken in object value that is complete (int)
    test_with_args(
        R"({"foo": "bar", "args": {"arg1": 1 )",
        R"({"foo":"bar","args":"{\"arg1\":1"})"
    );
    // Partial args broken in object value that is incomplete (string)
    test_with_args(
        R"({"foo": "bar", "args": {"arg1": ")",
        R"({"foo":"bar","args":"{\"arg1\":\""})"
    );
    // Partial args broken in object value that is complete (string)
    test_with_args(
        R"({"foo": "bar", "args": {"arg1": "1")",
        R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
    );
    // Partial args broken on array opening
    test_with_args(
        R"({"foo": "bar", "args": [)",
        R"({"foo":"bar","args":"["})"
    );
    // Partial args broken on array value that is incomplete (int)
    test_with_args(
        R"({"foo": "bar", "args": [1)",
        R"({"foo":"bar","args":"["})"
    );
    // Partial args broken on array value that is complete (int)
    test_with_args(
        R"({"foo": "bar", "args": [1 )",
        R"({"foo":"bar","args":"[1"})"
    );
    // Partial args broken on array value that is complete (string)
    test_with_args(
        R"({"foo": "bar", "args": ["1")",
        R"({"foo":"bar","args":"[\"1\""})"
    );
    // Partial args broken after array value
    test_with_args(
        R"({"foo": "bar", "args": [1,)",
        R"({"foo":"bar","args":"[1,"})"
    );
    // Partial args broken on nested array
    test_with_args(
        R"({"foo": "bar", "args": {"arg1": [)",
        R"({"foo":"bar","args":"{\"arg1\":["})"
    );
}

static void test_positions() {
    {
        common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
        assert_equals<size_t>(0, builder.pos());
        assert_throws([&]() { builder.move_to(100); });
        assert_equals<size_t>(0, builder.pos());
        assert_throws([&]() { builder.move_back(1); });
        assert_equals<size_t>(0, builder.pos());

        builder.move_to(8);
        assert_equals<size_t>(8, builder.pos());
        builder.move_back(1);
        assert_equals<size_t>(7, builder.pos());
        assert_equals("world!", builder.consume_rest());

        builder.move_to(0);
        assert_equals<size_t>(0, builder.pos());

        assert_throws([&]() { builder.finish(); });
        assert_equals<size_t>(0, builder.pos());

        builder.move_to(builder.input().size());
        builder.finish();
    }
    {
        common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});

        builder.move_to(builder.input().size());
        assert_equals<size_t>(builder.input().size(), builder.pos());
        builder.finish();
    }
}

int main() {
    test_positions();
    test_json_with_dumped_args_no_args();
    test_json_with_dumped_args();
    test_reasoning();
    test_regex();
    std::cout << "All tests passed!\n";
    return 0;
}

tests/test-chat.cpp: 1005 lines changed (diff suppressed because it is too large)
tests/test-json-partial.cpp (new file, 237 lines)
@@ -0,0 +1,237 @@
#include "common.h"
#include "json-partial.h"
#include <exception>
#include <iostream>
#include <stdexcept>

template <class T> static void assert_equals(const T & expected, const T & actual) {
    if (expected != actual) {
        std::cerr << "Expected: " << expected << std::endl;
        std::cerr << "Actual: " << actual << std::endl;
        std::cerr << std::flush;
        throw std::runtime_error("Test failed");
    }
}

static void test_json_healing() {
    auto parse = [](const std::string & str) {
        std::cerr << "# Parsing: " << str << '\n';
        std::string::const_iterator it = str.begin();
        const auto end = str.end();
        common_json out;
        std::string healing_marker = "$llama.cpp.json$";
        if (common_json_parse(it, end, healing_marker, out)) {
            auto dump = out.json.dump();
            std::cerr << "Parsed: " << dump << '\n';
            std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
            std::string result;
            if (!out.healing_marker.json_dump_marker.empty()) {
                auto i = dump.find(out.healing_marker.json_dump_marker);
                if (i == std::string::npos) {
                    throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
                }
                result = dump.substr(0, i);
            } else {
                result = dump;
            }
            std::cerr << "Result: " << result << '\n';
            if (string_starts_with(str, result)) {
                std::cerr << "Failure!\n";
            }
            // return dump;
        } else {
            throw std::runtime_error("Failed to parse: " + str);
        }

    };
    auto parse_all = [&](const std::string & str) {
        for (size_t i = 1; i < str.size(); i++) {
            parse(str.substr(0, i));
        }
    };
    parse_all("{\"a\": \"b\"}");
    parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");

    parse_all("[{\"a\": \"b\"}]");

    auto test = [&](const std::vector<std::string> & inputs, const std::string & expected, const std::string & expected_marker) {
        for (const auto & input : inputs) {
            common_json out;
            assert_equals(true, common_json_parse(input, "$foo", out));
            assert_equals<std::string>(expected, out.json.dump());
            assert_equals<std::string>(expected_marker, out.healing_marker.json_dump_marker);
        }
    };
    // No healing needed:
    test(
        {
            R"([{"a":"b"}, "y"])",
        },
        R"([{"a":"b"},"y"])",
        ""
    );
    // Partial literals can't be healed:
    test(
        {
            R"([1)",
            R"([tru)",
            R"([n)",
            R"([nul)",
            R"([23.2)",
        },
        R"(["$foo"])",
        R"("$foo)"
    );
    test(
        {
            R"({"a": 1)",
            R"({"a": tru)",
            R"({"a": n)",
            R"({"a": nul)",
            R"({"a": 23.2)",
        },
        R"({"a":"$foo"})",
        R"("$foo)"
    );
    test(
        {
            R"({)",
        },
        R"({"$foo":1})",
        R"("$foo)"
    );
    test(
        {
            R"([)",
        },
        R"(["$foo"])",
        R"("$foo)"
    );
    // Healing right after a full literal
    test(
        {
            R"(1 )",
        },
        R"(1)",
        ""
    );
    test(
        {
            R"(true)",
            R"(true )",
        },
        R"(true)",
        ""
    );
    test(
        {
            R"(null)",
            R"(null )",
        },
        R"(null)",
        ""
    );
    test(
        {
            R"([1 )",
        },
        R"([1,"$foo"])",
        R"(,"$foo)"
    );
    test(
        {
            R"([{})",
            R"([{} )",
        },
        R"([{},"$foo"])",
        R"(,"$foo)"
    );
    test(
        {
            R"([true)",
        },
        // TODO: detect the true/false/null literal was complete
        R"(["$foo"])",
        R"("$foo)"
    );
    test(
        {
            R"([true )",
        },
        R"([true,"$foo"])",
        R"(,"$foo)"
    );
    test(
        {
            R"([true,)",
        },
        R"([true,"$foo"])",
        R"("$foo)"
    );
    // Test nesting
    test(
        {
            R"([{"a": [{"b": [{)",
        },
        R"([{"a":[{"b":[{"$foo":1}]}]}])",
        R"("$foo)"
    );
    test(
        {
            R"([{"a": [{"b": [)",
        },
        R"([{"a":[{"b":["$foo"]}]}])",
        R"("$foo)"
    );

    test(
        {
            R"([{"a": "b"})",
            R"([{"a": "b"} )",
        },
        R"([{"a":"b"},"$foo"])",
        R"(,"$foo)"
    );
    test(
        {
            R"([{"a": "b"},)",
            R"([{"a": "b"}, )",
        },
        R"([{"a":"b"},"$foo"])",
        R"("$foo)"
    );
    test(
        {
            R"({ "code)",
        },
        R"({"code$foo":1})",
        R"($foo)"
    );
    test(
        {
            R"({ "code\)",
        },
        R"({"code\\$foo":1})",
        R"(\$foo)"
    );
    test(
        {
            R"({ "code")",
        },
        R"({"code":"$foo"})",
        R"(:"$foo)"
    );
    test(
        {
            R"({ "key")",
        },
        R"({"key":"$foo"})",
        R"(:"$foo)"
    );
}

int main() {
    test_json_healing();
    std::cerr << "All tests passed.\n";
    return 0;
}
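For context, the healing trick these tests exercise can be sketched in a few lines: append a marker plus whatever closers make the truncated text parseable, then cut the re-serialized dump at the marker. This is a deliberately naive sketch, not the json-partial API, and it assumes the truncation did not land inside a string or literal (cases the real code handles):

```cpp
#include <iostream>
#include <string>
#include <vector>

#include <json.hpp> // nlohmann's json, vendored in common/

int main() {
    std::string truncated = R"({"name": "python", "args": {"arg1": [1,)";
    std::string marker = "$llama.cpp.json$";

    // Track open containers; assumes the cut is not inside a string or literal.
    std::vector<char> closers;
    bool in_string = false;
    for (size_t i = 0; i < truncated.size(); i++) {
        char c = truncated[i];
        if (in_string) {
            if (c == '\\') i++;
            else if (c == '"') in_string = false;
        } else if (c == '"') in_string = true;
        else if (c == '{') closers.push_back('}');
        else if (c == '[') closers.push_back(']');
        else if (c == '}' || c == ']') closers.pop_back();
    }
    std::string healed = truncated + '"' + marker + '"';
    while (!closers.empty()) { healed += closers.back(); closers.pop_back(); }

    auto dump = nlohmann::ordered_json::parse(healed).dump();
    // Everything before the marker is a trustworthy, consistently formatted prefix:
    std::cout << dump.substr(0, dump.find(marker)) << "\n"; // {"name":"python","args":{"arg1":[1,"
    return 0;
}
```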
@@ -1,3 +1,4 @@
+#include "chat.h"
#include "utils.hpp"

#include "arg.h"
@@ -114,11 +115,11 @@ struct slot_params {
    struct common_params_speculative speculative;

    // OAI-compat fields
    bool verbose = false;
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
-   common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+   common_chat_syntax oaicompat_chat_syntax;

    json to_json() const {
        std::vector<std::string> samplers;
@@ -176,7 +177,10 @@ struct slot_params {
            {"grammar_lazy", sampling.grammar_lazy},
            {"grammar_triggers", grammar_triggers},
            {"preserved_tokens", sampling.preserved_tokens},
-           {"chat_format", common_chat_format_name(oaicompat_chat_format)},
+           {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)},
+           {"reasoning_format", (oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "deepseek" : "none")},
+           {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content},
+           {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open},
            {"samplers", samplers},
            {"speculative.n_max", speculative.n_max},
            {"speculative.n_min", speculative.n_min},
@@ -352,11 +356,14 @@ struct server_task {
        {
            auto it = data.find("chat_format");
            if (it != data.end()) {
-               params.oaicompat_chat_format = static_cast<common_chat_format>(it->get<int>());
-               SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_format).c_str());
+               params.oaicompat_chat_syntax.format = static_cast<common_chat_format>(it->get<int>());
+               SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format).c_str());
            } else {
-               params.oaicompat_chat_format = defaults.oaicompat_chat_format;
+               params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format;
            }
+           params.oaicompat_chat_syntax.reasoning_format = params_base.reasoning_format;
+           params.oaicompat_chat_syntax.reasoning_in_content = params.stream;
+           params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false);
        }

        {
@@ -396,7 +403,14 @@ struct server_task {
                        params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word});
                    }
                } else {
-                   params.sampling.grammar_triggers.push_back(std::move(ct.value));
+                   if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) {
+                       SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str());
+                   } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) {
+                       SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str());
+                   } else {
+                       throw std::runtime_error("Unknown grammar trigger type");
+                   }
+                   params.sampling.grammar_triggers.emplace_back(std::move(ct.value));
                }
            }
        }
@@ -639,11 +653,12 @@ struct server_task_result_cmpl_final : server_task_result {
    slot_params generation_params;

    // OAI-compat fields
    bool verbose = false;
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
-   common_chat_format oaicompat_chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+   common_chat_msg oaicompat_msg;
+   std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

    virtual int get_index() override {
        return index;
@@ -738,47 +753,20 @@ struct server_task_result_cmpl_final : server_task_result {
    json to_json_oaicompat_chat() {
        std::string finish_reason = "length";
        common_chat_msg msg;
-       if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-           SRV_DBG("Parsing chat message: %s\n", content.c_str());
-           msg = common_chat_parse(content, oaicompat_chat_format);
-           finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
+       if (!oaicompat_msg.empty()) {
+           msg = oaicompat_msg;
        } else {
+           msg.role = "assistant";
            msg.content = content;
        }
+       if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
+           finish_reason = msg.tool_calls.empty() ? "stop" : "tool_calls";
+       }
-       json message {
-           {"role", "assistant"},
-       };
-       if (!msg.reasoning_content.empty()) {
-           message["reasoning_content"] = msg.reasoning_content;
-       }
-       if (msg.content.empty() && !msg.tool_calls.empty()) {
-           message["content"] = json();
-       } else {
-           message["content"] = msg.content;
-       }
-       if (!msg.tool_calls.empty()) {
-           auto tool_calls = json::array();
-           for (const auto & tc : msg.tool_calls) {
-               tool_calls.push_back({
-                   {"type", "function"},
-                   {"function", {
-                       {"name", tc.name},
-                       {"arguments", tc.arguments},
-                   }},
-                   // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo).
-                   // We only generate a random id for the ones that don't generate one by themselves
-                   // (they also won't get to see it as their template likely doesn't use it, so it's all for the client)
-                   {"id", tc.id.empty() ? gen_tool_call_id() : tc.id},
-               });
-           }
-           message["tool_calls"] = tool_calls;
-       }

        json choice {
            {"finish_reason", finish_reason},
            {"index", 0},
-           {"message", message},
+           {"message", msg.to_json_oaicompat<json>()},
        };

        if (!stream && probs_output.size() > 0) {
@@ -818,17 +806,35 @@ struct server_task_result_cmpl_final : server_task_result {
        std::time_t t = std::time(0);
        std::string finish_reason = "length";
        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
-           finish_reason = "stop";
+           finish_reason = oaicompat_msg.tool_calls.empty() ? "stop" : "tool_calls";
        }

-       json choice = json {
-           {"finish_reason", finish_reason},
-           {"index", 0},
-           {"delta", json::object()}
-       };
+       json deltas = json::array();
+       for (const auto & diff : oaicompat_msg_diffs) {
+           deltas.push_back({
+               {"choices", json::array({
+                   json {
+                       {"finish_reason", nullptr},
+                       {"index", 0},
+                       {"delta", common_chat_msg_diff_to_json_oaicompat<json>(diff)},
+                   },
+               })},
+               {"created", t},
+               {"id", oaicompat_cmpl_id},
+               {"model", oaicompat_model},
+               {"system_fingerprint", build_info},
+               {"object", "chat.completion.chunk"},
+           });
+       }

-       json ret = json {
-           {"choices", json::array({choice})},
+       deltas.push_back({
+           {"choices", json::array({
+               json {
+                   {"finish_reason", finish_reason},
+                   {"index", 0},
+                   {"delta", json::object()},
+               },
+           })},
            {"created", t},
            {"id", oaicompat_cmpl_id},
            {"model", oaicompat_model},
@@ -839,18 +845,18 @@ struct server_task_result_cmpl_final : server_task_result {
            {"prompt_tokens", n_prompt_tokens},
            {"total_tokens", n_decoded + n_prompt_tokens},
        }},
-       };
+       });

        if (timings.prompt_n >= 0) {
-           ret.push_back({"timings", timings.to_json()});
+           deltas.back().push_back({"timings", timings.to_json()});
        }

        // extra fields for debugging purposes
-       if (verbose) {
-           ret["__verbose"] = to_json_non_oaicompat();
+       if (verbose && !deltas.empty()) {
+           deltas.front()["__verbose"] = to_json_non_oaicompat();
        }

-       return ret;
+       return deltas;
    }
};

@@ -868,10 +874,11 @@ struct server_task_result_cmpl_partial : server_task_result {
    result_timings timings;

    // OAI-compat fields
    bool verbose = false;
    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
    std::string oaicompat_model;
    std::string oaicompat_cmpl_id;
+   std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

    virtual int get_index() override {
        return index;
@@ -955,84 +962,50 @@ struct server_task_result_cmpl_partial : server_task_result {
        std::time_t t = std::time(0);
        json choices;

-       if (first) {
-           if (content.empty()) {
-               choices = json::array({json{{"finish_reason", nullptr},
-                                           {"index", 0},
-                                           {"delta", json{{"role", "assistant"}}}}});
-           } else {
-               // We have to send this as two updates to conform to openai behavior
-               // initial_ret is the role message for stream=True
-               json initial_ret = json{{"choices", json::array({json{
-                                       {"finish_reason", nullptr},
-                                       {"index", 0},
-                                       {"delta", json{
-                                           {"role", "assistant"},
-                                           {"content", ""}
-                                       }}}})},
-                       {"created", t},
-                       {"id", oaicompat_cmpl_id},
-                       {"model", oaicompat_model},
-                       {"system_fingerprint", build_info},
-                       {"object", "chat.completion.chunk"}};
-
-               json second_ret = json{
-                       {"choices", json::array({json{{"finish_reason", nullptr},
-                                                     {"index", 0},
-                                                     {"delta", json {
-                                                         {"content", content}}}
-                                                     }})},
-                       {"created", t},
-                       {"id", oaicompat_cmpl_id},
-                       {"model", oaicompat_model},
-                       {"system_fingerprint", build_info},
-                       {"object", "chat.completion.chunk"}};
-
-               if (prob_output.probs.size() > 0) {
-                   second_ret["choices"][0]["logprobs"] = json{
-                       {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-                   };
-               }
-
-               if (timings.prompt_n >= 0) {
-                   second_ret.push_back({"timings", timings.to_json()});
-               }
-
-               return std::vector<json>({initial_ret, second_ret});
-           }
-       } else {
-           choices = json::array({json{
-               {"finish_reason", nullptr},
-               {"index", 0},
-               {"delta",
-                json {
-                    {"content", content},
-                }},
-           }});
-       }
-
-       GGML_ASSERT(choices.size() >= 1);
-
-       if (prob_output.probs.size() > 0) {
-           choices[0]["logprobs"] = json{
-               {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-           };
-       }
-
-       json ret = json {
-           {"choices", choices},
-           {"created", t},
-           {"id", oaicompat_cmpl_id},
-           {"model", oaicompat_model},
-           {"system_fingerprint", build_info},
-           {"object", "chat.completion.chunk"}
-       };
-
-       if (timings.prompt_n >= 0) {
-           ret.push_back({"timings", timings.to_json()});
-       }
-
-       return std::vector<json>({ret});
+       std::vector<json> deltas;
+       auto add_delta = [&](const json & delta) {
+           deltas.push_back({
+               {"choices", json::array({
+                   json {
+                       {"finish_reason", nullptr},
+                       {"index", 0},
+                       {"delta", delta},
+                   },
+               })},
+               {"created", t},
+               {"id", oaicompat_cmpl_id},
+               {"model", oaicompat_model},
+               {"system_fingerprint", build_info},
+               {"object", "chat.completion.chunk"},
+           });
+       };
+       // We have to send an initial update to conform to openai behavior
+       if (first) {
+           add_delta({
+               {"role", "assistant"},
+               {"content", nullptr},
+           });
+       }
+
+       for (const auto & diff : oaicompat_msg_diffs) {
+           add_delta(common_chat_msg_diff_to_json_oaicompat<json>(diff));
+       }
+
+       if (!deltas.empty()) {
+           GGML_ASSERT(deltas[deltas.size() - 1].at("choices").size() >= 1);
+
+           if (prob_output.probs.size() > 0) {
+               deltas[deltas.size() - 1].at("choices").at(0)["logprobs"] = json {
+                   {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
+               };
+           }
+
+           if (timings.prompt_n >= 0) {
+               deltas[deltas.size() - 1].push_back({"timings", timings.to_json()});
+           }
+       }
+
+       return deltas;
    }
};

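Roughly, the wire effect of this change (shapes approximate and field values illustrative; the exact delta layout comes from common_chat_msg_diff_to_json_oaicompat): the role-only opener now carries a null content, and each subsequent chunk carries just the parsed delta, e.g. a few more characters of a tool call's arguments:

```cpp
#include <iostream>

#include <json.hpp> // nlohmann's json, vendored in common/

using json = nlohmann::ordered_json;

int main() {
    auto chunk = [](const json & delta) {
        return json {
            {"choices", json::array({ json {
                {"finish_reason", nullptr},
                {"index", 0},
                {"delta", delta},
            }})},
            {"object", "chat.completion.chunk"},
        };
    };
    // Initial update: role only, content explicitly null (as the updated tests assert).
    std::cout << chunk({{"role", "assistant"}, {"content", nullptr}}).dump() << "\n";
    // One chunk per message diff, e.g. more argument characters of a streamed tool call.
    std::cout << chunk({{"tool_calls", json::array({ json {
        {"index", 0},
        {"function", json {{"arguments", "{\"code\": \"pri"}}},
    }})}}).dump() << "\n";
    return 0;
}
```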
@@ -1293,6 +1266,7 @@ struct server_slot {

    std::string generated_text;
    llama_tokens generated_tokens;
+   common_chat_msg chat_msg;

    server_tokens cache_tokens;

@@ -1313,6 +1287,7 @@ struct server_slot {
    llama_token sampled;

    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
+   std::vector<std::string> generated_tool_call_ids;

    // stats
    size_t n_sent_text = 0; // number of sent text character
@@ -1342,9 +1317,13 @@ struct server_slot {
        n_past = 0;
        n_sent_text = 0;
        task_type = SERVER_TASK_TYPE_COMPLETION;
+       chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;

        generated_tokens.clear();
        generated_token_probs.clear();
+       chat_msg = {};
+       json_schema = json();
+       generated_tool_call_ids.clear();

        // clear speculative decoding stats
        n_draft_total = 0;
@@ -1424,6 +1403,21 @@ struct server_slot {
        return timings;
    }

+   const common_chat_msg & update_chat_msg(std::vector<common_chat_msg_diff> & diffs) {
+       auto previous_msg = chat_msg;
+       SRV_DBG("Parsing chat message: %s\n", generated_text.c_str());
+       auto new_msg = common_chat_parse(
+           generated_text,
+           /* is_partial= */ stop != STOP_TYPE_EOS,
+           params.oaicompat_chat_syntax);
+       if (!new_msg.empty()) {
+           new_msg.ensure_tool_call_ids_set(generated_tool_call_ids, gen_tool_call_id);
+           chat_msg = new_msg;
+           diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg);
+       }
+       return chat_msg;
+   }
+
    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
        size_t stop_pos = std::string::npos;

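update_chat_msg() re-parses the entire accumulated text after every token and hands back only what changed since the previous parse. A toy illustration of the prefix/suffix intuition behind compute_diffs (assumed semantics, content field only; the real diff also covers reasoning_content and tool calls):

```cpp
#include <cassert>
#include <string>

static std::string content_delta(const std::string & prev, const std::string & curr) {
    // If the previous content is a prefix of the current one, stream only the tail.
    if (curr.compare(0, prev.size(), prev) == 0) {
        return curr.substr(prev.size());
    }
    return curr;  // fallback: resend (the real code handles this more carefully)
}

int main() {
    assert(content_delta("Hello", "Hello, world") == ", world");
    assert(content_delta("", "Hi") == "Hi");
    return 0;
}
```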
@@ -2475,10 +2469,12 @@ struct server_context {
        res->n_prompt_tokens = slot.n_prompt_tokens;
        res->post_sampling_probs = slot.params.post_sampling_probs;

        res->verbose = slot.params.verbose;
        res->oaicompat = slot.params.oaicompat;
        res->oaicompat_model = slot.params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;

+       slot.update_chat_msg(res->oaicompat_msg_diffs);
+
        // populate res.probs_output
        if (slot.params.sampling.n_probs > 0) {
@@ -2499,7 +2495,7 @@ struct server_context {
        res->id_slot = slot.id;

        res->index = slot.index;
-       res->content = std::move(slot.generated_text);
+       res->content = slot.generated_text;
        res->tokens = std::move(slot.generated_tokens);
        res->timings = slot.get_timings();
        res->prompt = slot.prompt_tokens.detokenize(ctx, true);
@@ -2519,7 +2515,8 @@ struct server_context {
        res->oaicompat = slot.params.oaicompat;
        res->oaicompat_model = slot.params.oaicompat_model;
        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
-       res->oaicompat_chat_format = slot.params.oaicompat_chat_format;
+       res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);

        // populate res.probs_output
        if (slot.params.sampling.n_probs > 0) {
            if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
@@ -75,7 +75,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
        choice = data["choices"][0]
        if i == 0:
            # Check first role message for stream=True
-           assert choice["delta"]["content"] == ""
+           assert choice["delta"]["content"] is None
            assert choice["delta"]["role"] == "assistant"
        else:
            assert "role" not in choice["delta"]
@@ -92,7 +92,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
            assert choice["finish_reason"] == finish_reason
        else:
            assert choice["finish_reason"] is None
-           content += choice["delta"]["content"]
+           content += choice["delta"]["content"] or ''


def test_chat_completion_with_openai_library():
@@ -251,8 +251,9 @@ def test_chat_completion_with_timings_per_token():
    for i, data in enumerate(res):
        if i == 0:
            # Check first role message for stream=True
-           assert data["choices"][0]["delta"]["content"] == ""
+           assert data["choices"][0]["delta"]["content"] is None
            assert data["choices"][0]["delta"]["role"] == "assistant"
+           assert "timings" not in data, f'First event should not have timings: {data}'
        else:
            assert "role" not in data["choices"][0]["delta"]
            assert "timings" in data
@@ -311,7 +312,7 @@ def test_logprobs_stream():
        choice = data.choices[0]
        if i == 0:
            # Check first role message for stream=True
-           assert choice.delta.content == ""
+           assert choice.delta.content is None
            assert choice.delta.role == "assistant"
        else:
            assert choice.delta.role is None
@ -8,6 +8,7 @@ path = Path(__file__).resolve().parents[1]
|
|||||||
sys.path.insert(0, str(path))
|
sys.path.insert(0, str(path))
|
||||||
|
|
||||||
from utils import *
|
from utils import *
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
server: ServerProcess
|
server: ServerProcess
|
||||||
|
|
||||||
@ -20,7 +21,11 @@ def create_server():
|
|||||||
server = ServerPreset.tinyllama2()
|
server = ServerPreset.tinyllama2()
|
||||||
server.model_alias = "tinyllama-2-tool-call"
|
server.model_alias = "tinyllama-2-tool-call"
|
||||||
server.server_port = 8081
|
server.server_port = 8081
|
||||||
|
server.n_slots = 1
|
||||||
|
|
||||||
|
class CompletionMode(Enum):
|
||||||
|
NORMAL = "normal"
|
||||||
|
STREAMED = "streamed"
|
||||||
|
|
||||||
TEST_TOOL = {
|
TEST_TOOL = {
|
||||||
"type":"function",
|
"type":"function",
|
||||||
@ -73,9 +78,8 @@ WEATHER_TOOL = {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
|
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
|
||||||
res = server.make_request("POST", "/v1/chat/completions", data={
|
body = server.make_any_request("POST", "/v1/chat/completions", data={
|
||||||
"max_tokens": n_predict,
|
"max_tokens": n_predict,
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are a coding assistant."},
|
{"role": "system", "content": "You are a coding assistant."},
|
||||||
@ -86,13 +90,13 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
|||||||
"parallel_tool_calls": False,
|
"parallel_tool_calls": False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
})
|
})
|
||||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
# assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||||
choice = res.body["choices"][0]
|
choice = body["choices"][0]
|
||||||
tool_calls = choice["message"].get("tool_calls")
|
tool_calls = choice["message"].get("tool_calls")
|
||||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||||
tool_call = tool_calls[0]
|
tool_call = tool_calls[0]
|
||||||
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
|
||||||
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
# assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
|
||||||
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
|
||||||
assert expected_function_name == tool_call["function"]["name"]
|
assert expected_function_name == tool_call["function"]["name"]
|
||||||
actual_arguments = tool_call["function"]["arguments"]
|
actual_arguments = tool_call["function"]["arguments"]
|
||||||
@ -102,12 +106,16 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict
|
|||||||
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
|
||||||
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
@pytest.mark.parametrize("template_name,tool,argument_key", [
|
||||||
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||||
|
("google-gemma-2-2b-it", TEST_TOOL, "success"),
|
||||||
|
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||||
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
|
||||||
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||||
|
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
|
||||||
])
|
])
|
||||||
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
|
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
|
||||||
global server
|
global server
|
||||||
n_predict = 1024
|
n_predict = 1024
|
||||||
# server = ServerPreset.stories15m_moe()
|
# server = ServerPreset.stories15m_moe()
|
||||||
@@ -115,31 +123,43 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
+    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,tool,argument_key", [
     ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
     ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
 
     ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
     ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
 
     ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
-    ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
+    # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own.
+    # ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
 
     ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
     ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
 
     ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
     ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
 
     ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
     ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
 
     ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
     ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
 
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
     ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
 
     ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
     # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
 ])
-def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
+def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode):
     global server
     n_predict = 512
     # server = ServerPreset.stories15m_moe()
@@ -147,10 +167,11 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
+    do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
     (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
@@ -184,9 +205,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
     (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
 
-    (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    # (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
 
     (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
     (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
@@ -203,10 +224,9 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
     (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 ])
-def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -219,7 +239,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -228,12 +248,12 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         "tool_choice": "required",
         "tools": [tool],
         "parallel_tool_calls": False,
+        "stream": stream == CompletionMode.STREAMED,
         "temperature": 0.0,
         "top_k": 1,
         "top_p": 1.0,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
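The pattern above repeats throughout this file: `make_request` plus a status-code assertion is replaced by `make_any_request` (added to the test utilities further down), which transparently aggregates streamed chunks into the same OAI-style body a non-streamed call returns. A hedged sketch of the shared calling pattern:

```python
# Sketch: the shape the updated tests share; server, stream and CompletionMode
# are the fixtures/parameters used above.
body = server.make_any_request("POST", "/v1/chat/completions", data={
    "max_tokens": 256,
    "messages": [{"role": "user", "content": "What's the sum of 102 and 7?"}],
    "stream": stream == CompletionMode.STREAMED,  # the only mode-specific bit
})
choice = body["choices"][0]  # identical shape in both modes
```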
@@ -248,7 +268,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
 
 
 def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a coding assistant."},
@@ -258,26 +278,27 @@ def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int,
         "tool_choice": tool_choice,
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
 
 
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
     ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
     ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
     ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
 ])
-def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
     global server
-    server.jinja = True
     server.n_predict = n_predict
+    server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
+    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
     ("meetkai-functionary-medium-v3.2", 256, [], None),
     ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
@@ -289,16 +310,17 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
     ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
 ])
-def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
+def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode):
     global server
-    server.jinja = True
     server.n_predict = n_predict
+    server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
+    do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("hf_repo,template_override", [
     ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
@@ -321,11 +343,11 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
     ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
 
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
-    ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
+    # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
+    # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
 
-    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
-    ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
+    # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
+    # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
 
     ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
     ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
@@ -339,10 +361,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
 
     # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
 ])
-def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -355,11 +376,11 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_weather(server, max_tokens=n_predict)
+    do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
 def do_test_weather(server: ServerProcess, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "messages": [
             {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
             {"role": "user", "content": "What is the weather in Istanbul?"},
@@ -367,14 +388,13 @@ def do_test_weather(server: ServerProcess, **kwargs):
         "tools": [WEATHER_TOOL],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
-    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
+    # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
     location = actual_arguments["location"]
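Note the relaxed tool-call `id` assertion above (mirrored in `do_test_hello_world` below), presumably because in streamed responses the `id` arrives at most once, on the first delta of a call, so the aggregated call may end up with an empty id. A defensive-reading sketch under that assumption:

```python
# Sketch: streamed tool calls are keyed by delta "index"; the "id" (if any)
# appears only on the first fragment, so treat it as optional when merging.
tool_call = {"id": "", "type": "function", "function": {"name": "", "arguments": ""}}
delta = {"index": 0, "id": "call_abc123", "function": {"arguments": '{"loc'}}  # example fragment
if delta.get("id"):
    tool_call["id"] = delta["id"]
tool_call["function"]["arguments"] += delta["function"].get("arguments", "")
```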
@@ -383,6 +403,7 @@ def do_test_weather(server: ServerProcess, **kwargs):
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
     (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
     (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
@@ -400,9 +421,8 @@ def do_test_weather(server: ServerProcess, **kwargs):
     # (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
     # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 ])
-def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192 * 2
     server.n_predict = n_predict
@@ -415,11 +435,11 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    do_test_calc_result(server, result_override, n_predict)
+    do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
 def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
@@ -466,8 +486,7 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
         ],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
     content = choice["message"].get("content")
@@ -480,18 +499,18 @@ def do_test_calc_result(server: ServerProcess, result_override: str | None, n_pr
 
 
 @pytest.mark.slow
-@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
-    (128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
-    (128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
-    (1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
-    (1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+@pytest.mark.parametrize("n_predict,reasoning_format,stream,expect_reasoning_content,expect_content,hf_repo,template_override", [
+    (128, 'deepseek', CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    (128, None, CompletionMode.NORMAL, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
+    (1024, 'deepseek', CompletionMode.NORMAL, "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>I need to calculate [\\s\\S]*?</think>To find the sum of [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    (1024, 'deepseek', CompletionMode.NORMAL, "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    (1024, 'deepseek', CompletionMode.STREAMED, None, "^<think>First, I [\\s\\S]*?</think>To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
+    # (1024, 'none', CompletionMode.NORMAL, None, "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
+    # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None),
 ])
-def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
-    server.n_slots = 1
     server.reasoning_format = reasoning_format
     server.jinja = True
     server.n_ctx = 8192 * 2
@@ -505,14 +524,14 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
     elif isinstance(template_override, str):
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
             {"role": "user", "content": "What's the sum of 102 and 7?"},
-        ]
+        ],
+        "stream": stream == CompletionMode.STREAMED,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
 
     content = choice["message"].get("content")
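These expectations encode two response shapes: in non-streamed mode with `reasoning_format: 'deepseek'`, thoughts are split out into `message.reasoning_content`; in the streamed cases above, the aggregated content still carries the inline `<think>...</think>` block. A sketch of checking both shapes, assuming that reading of the table:

```python
import re

# Sketch: the two shapes encoded by the parametrize cases above.
def check_thoughts(message: dict, streamed: bool):
    if streamed:
        # thoughts arrive inline, wrapped in <think>...</think>
        assert re.match(r"^<think>[\s\S]*?</think>", message["content"])
    else:
        # thoughts split out into a separate field
        assert message.get("reasoning_content")
        assert "<think>" not in (message.get("content") or "")
```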
@@ -529,6 +548,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
 
 
 @pytest.mark.slow
+@pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED])
 @pytest.mark.parametrize("hf_repo,template_override", [
     ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
 
@@ -562,10 +582,9 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
     ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
     ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
 ])
-def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
+def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode):
     global server
     n_predict = 512 # High because of DeepSeek R1
-    server.n_slots = 1
     server.jinja = True
     server.n_ctx = 8192
     server.n_predict = n_predict
@@ -579,11 +598,11 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         server.chat_template = template_override
     server.start(timeout_seconds=TIMEOUT_SERVER_START)
 
-    do_test_hello_world(server, max_tokens=n_predict)
+    do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
 def do_test_hello_world(server: ServerProcess, **kwargs):
-    res = server.make_request("POST", "/v1/chat/completions", data={
+    body = server.make_any_request("POST", "/v1/chat/completions", data={
         "messages": [
             {"role": "system", "content": "You are a tool-calling agent."},
             {"role": "user", "content": "say hello world with python"},
@@ -591,16 +610,15 @@ def do_test_hello_world(server: ServerProcess, **kwargs):
         "tools": [PYTHON_TOOL],
         **kwargs,
     }, timeout=TIMEOUT_HTTP_REQUEST)
-    assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
-    choice = res.body["choices"][0]
+    choice = body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
     assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
     assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
-    assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
+    # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
     actual_arguments = json.loads(tool_call["function"]["arguments"])
     assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
     code = actual_arguments["code"]
     assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
-    assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
+    assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}'
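The loosened regex strips any `#` comments the model may prepend before matching the `print` call:

```python
import re

code = '# print a greeting\nprint("Hello, World!")'
cleaned = re.sub(r'#.*\n?', '', code)  # drop comment-to-end-of-line runs
assert cleaned == 'print("Hello, World!")'
```

The next hunk is in the tests' `ServerProcess` utility class and adds the `make_any_request` helper the calls above rely on.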
@@ -294,6 +294,77 @@ class ServerProcess:
                 print("Partial response from server", json.dumps(data, indent=2))
             yield data
 
+    def make_any_request(
+        self,
+        method: str,
+        path: str,
+        data: dict | None = None,
+        headers: dict | None = None,
+        timeout: float | None = None,
+    ) -> dict:
+        stream = data.get('stream', False)
+        if stream:
+            content: list[str] = []
+            tool_calls: list[dict] = []
+            finish_reason: Optional[str] = None
+
+            content_parts = 0
+            tool_call_parts = 0
+            arguments_parts = 0
+
+            for chunk in self.make_stream_request(method, path, data, headers):
+                assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
+                choice = chunk['choices'][0]
+                if choice['delta'].get('content') is not None:
+                    assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
+                    content.append(choice['delta']['content'])
+                    content_parts += 1
+                if choice['delta'].get('finish_reason') is not None:
+                    finish_reason = choice['delta']['finish_reason']
+                for tc in choice['delta'].get('tool_calls', []):
+                    if 'function' not in tc:
+                        raise ValueError(f"Expected function type, got {tc['type']}")
+                    if tc['index'] >= len(tool_calls):
+                        tool_calls.append(dict(
+                            id="",
+                            type="function",
+                            function=dict(
+                                name="",
+                                arguments="",
+                            )
+                        ))
+                    tool_call = tool_calls[tc['index']]
+                    if tc.get('id') is not None:
+                        tool_call['id'] = tc['id']
+                    fct = tc['function']
+                    if fct.get('name') is not None:
+                        tool_call['function']['name'] = fct['name']
+                    if fct.get('arguments') is not None:
+                        assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
+                        tool_call['function']['arguments'] += fct['arguments']
+
+            print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
+            result = dict(
+                choices=[
+                    dict(
+                        index=0,
+                        finish_reason=finish_reason,
+                        message=dict(
+                            role='assistant',
+                            content=''.join(content) if content else None,
+                            tool_calls=tool_calls if tool_calls else None,
+                        ),
+                    )
+                ],
+            )
+            print("Final response from server", json.dumps(result, indent=2))
+            return result
+        else:
+            response = self.make_request(method, path, data, headers, timeout=timeout)
+            assert response.status_code == 200, f"Server returned error: {response.status_code}"
+            return response.body
+
 
 
 server_instances: Set[ServerProcess] = set()
 
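For reference, the chunks this helper folds together follow the OpenAI streaming convention: text arrives as `delta.content` fragments, and tool calls as `delta.tool_calls` fragments keyed by `index`, whose `arguments` strings concatenate into one JSON document. Illustrative chunk payloads (values invented):

```python
# Illustrative only: chunks as make_any_request sees them (values invented).
chunk_text = {"choices": [{"index": 0, "delta": {"content": "To find"}}]}
chunk_call = {"choices": [{"index": 0, "delta": {"tool_calls": [
    {"index": 0, "id": "call_0", "type": "function",
     "function": {"name": "python", "arguments": '{"code": "pri'}},
]}}]}
# Later fragments of the same call omit id/name and only extend arguments:
chunk_call2 = {"choices": [{"index": 0, "delta": {"tool_calls": [
    {"index": 0, "function": {"arguments": "nt('hello')\"}"}},
]}}]}
```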
@@ -474,26 +474,6 @@ static std::string gen_tool_call_id() {
 // other common utils
 //
 
-static bool ends_with(const std::string & str, const std::string & suffix) {
-    return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
-}
-
-static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
-    if (!text.empty() && !stop.empty()) {
-        const char text_last_char = text.back();
-        for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
-            if (stop[char_index] == text_last_char) {
-                const std::string current_partial = stop.substr(0, char_index + 1);
-                if (ends_with(text, current_partial)) {
-                    return text.size() - char_index - 1;
-                }
-            }
-        }
-    }
-
-    return std::string::npos;
-}
-
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
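For context, the removed helper computed where a *partial* stop string begins at the tail of the generated text, so streaming can withhold those characters until the stop either completes or fails to match; an equivalent presumably now lives in shared common code. A Python rendering of the same logic:

```python
# Sketch (Python) of the removed C++ logic: find the position in `text` where
# a prefix of `stop` starts at the very end, or None if there is none.
def find_partial_stop(stop: str, text: str) -> int | None:
    for n in range(min(len(stop), len(text)), 0, -1):  # prefer the longest prefix
        if text.endswith(stop[:n]):
            return len(text) - n
    return None

assert find_partial_stop("</think>", "some text</thi") == 9  # holds back "</thi"
```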
@@ -599,19 +579,16 @@ static json oaicompat_chat_params_parse(
     json llama_params;
 
     auto tools = json_value(body, "tools", json());
+    auto has_tools = tools.is_array() && !tools.empty();
     auto stream = json_value(body, "stream", false);
+    auto tool_choice = json_value(body, "tool_choice", std::string("auto"));
 
-    if (tools.is_array() && !tools.empty()) {
-        if (stream) {
-            throw std::runtime_error("Cannot use tools with stream");
-        }
-        if (!opt.use_jinja) {
+    if (!opt.use_jinja) {
+        if (has_tools) {
             throw std::runtime_error("tools param requires --jinja flag");
         }
-    }
-    if (!opt.use_jinja) {
-        if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
-            throw std::runtime_error("Unsupported param: tool_choice");
+        if (tool_choice != "auto") {
+            throw std::runtime_error("tool_choice param requires --jinja flag");
         }
     }
 
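With the old `"Cannot use tools with stream"` rejection gone, OpenAI-compatible clients can stream tool calls from a server started with `--jinja`. A minimal client sketch (base URL, port, model name, and the tool schema are placeholders):

```python
# Sketch: consuming streamed tool-call deltas from llama-server --jinja.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="unused")
stream = client.chat.completions.create(
    model="whatever",  # placeholder; the server serves its loaded model
    messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
    tools=[{"type": "function", "function": {
        "name": "get_weather",
        "parameters": {"type": "object",
                       "properties": {"location": {"type": "string"}},
                       "required": ["location"]},
    }}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    for tc in delta.tool_calls or []:
        print(tc.index, tc.function.name, tc.function.arguments)
```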
@@ -749,14 +726,12 @@ static json oaicompat_chat_params_parse(
     common_chat_templates_inputs inputs;
     inputs.messages = common_chat_msgs_parse_oaicompat(messages);
     inputs.tools = common_chat_tools_parse_oaicompat(tools);
-    inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
+    inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tool_choice);
     inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
     inputs.grammar = grammar;
-    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
     inputs.use_jinja = opt.use_jinja;
     inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
-    inputs.extract_reasoning = opt.reasoning_format != COMMON_REASONING_FORMAT_NONE;
+    inputs.reasoning_format = opt.reasoning_format;
+    inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
     if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
         throw std::runtime_error("Cannot use custom grammar constraints with tools.");
     }
@@ -774,7 +749,8 @@ static json oaicompat_chat_params_parse(
             throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
         }
 
-        inputs.extract_reasoning = false;
+        /* TODO: test this properly */
+        inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE;
         inputs.add_generation_prompt = true;
     }
 
@@ -799,6 +775,7 @@ static json oaicompat_chat_params_parse(
     }
     llama_params["grammar_triggers"] = grammar_triggers;
     llama_params["preserved_tokens"] = chat_params.preserved_tokens;
+    llama_params["thinking_forced_open"] = chat_params.thinking_forced_open;
     for (const auto & stop : chat_params.additional_stops) {
         llama_params["stop"].push_back(stop);
     }
@@ -812,6 +789,9 @@ static json oaicompat_chat_params_parse(
     // Handle "logprobs" field
     // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future
     if (json_value(body, "logprobs", false)) {
+        if (has_tools && stream) {
+            throw std::runtime_error("logprobs is not supported with tools + stream");
+        }
         llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
     } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
         throw std::runtime_error("top_logprobs requires logprobs to be set to true");