2025-02-18 18:03:23 +00:00
|
|
|
// Chat support (incl. tool call grammar constraining & output parsing) w/ generic & custom template handlers.
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include "common.h"
|
2025-05-25 01:48:08 +01:00
|
|
|
#include <functional>
|
2025-05-15 02:39:51 +01:00
|
|
|
#include <chrono>
|
2025-02-18 18:03:23 +00:00
|
|
|
#include <string>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
struct common_chat_templates;
|
|
|
|
|
|
|
|
// A single tool call attached to a chat message.
struct common_chat_tool_call {
    std::string name;      // name of the tool being called
    std::string arguments; // call arguments, stored as a string (presumably serialized JSON -- confirm against the parser)
    std::string id;        // call identifier; may be empty

    // Two tool calls are equal only when name, arguments and id all match.
    bool operator==(const common_chat_tool_call & other) const {
        if (name != other.name) {
            return false;
        }
        if (arguments != other.arguments) {
            return false;
        }
        return id == other.id;
    }
};
|
|
|
|
|
|
|
|
// One typed piece of a message's content.
struct common_chat_msg_content_part {
    std::string type; // tag describing what kind of part this is
    std::string text; // textual payload of the part

    // Parts are equal when both the type tag and the text match.
    bool operator==(const common_chat_msg_content_part & other) const {
        if (type != other.type) {
            return false;
        }
        return text == other.text;
    }
};
|
|
|
|
|
|
|
|
struct common_chat_msg {
|
|
|
|
std::string role;
|
|
|
|
std::string content;
|
|
|
|
std::vector<common_chat_msg_content_part> content_parts = {};
|
|
|
|
std::vector<common_chat_tool_call> tool_calls = {};
|
|
|
|
std::string reasoning_content;
|
|
|
|
std::string tool_name;
|
|
|
|
std::string tool_call_id;
|
2025-05-25 01:48:08 +01:00
|
|
|
|
|
|
|
template <class T> T to_json_oaicompat() const;
|
|
|
|
|
|
|
|
bool empty() const {
|
|
|
|
return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
|
|
|
|
}
|
|
|
|
void ensure_tool_call_ids_set(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
|
|
|
|
for (auto i = 0u; i < tool_calls.size(); i++) {
|
|
|
|
if (ids_cache.size() <= i) {
|
|
|
|
auto id = tool_calls[i].id;
|
|
|
|
if (id.empty()) {
|
|
|
|
id = gen_tool_call_id();
|
|
|
|
}
|
|
|
|
ids_cache.push_back(id);
|
|
|
|
}
|
|
|
|
tool_calls[i].id = ids_cache[i];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
bool operator==(const common_chat_msg & other) const {
|
|
|
|
return role == other.role
|
|
|
|
&& content == other.content
|
|
|
|
&& content_parts == other.content_parts
|
|
|
|
&& tool_calls == other.tool_calls
|
|
|
|
&& reasoning_content == other.reasoning_content
|
|
|
|
&& tool_name == other.tool_name
|
|
|
|
&& tool_call_id == other.tool_call_id;
|
|
|
|
}
|
|
|
|
bool operator!=(const common_chat_msg & other) const {
|
|
|
|
return !(*this == other);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
struct common_chat_msg_diff {
|
2025-06-02 10:15:44 -07:00
|
|
|
std::string reasoning_content_delta;
|
2025-05-25 01:48:08 +01:00
|
|
|
std::string content_delta;
|
|
|
|
size_t tool_call_index = std::string::npos;
|
|
|
|
common_chat_tool_call tool_call_delta;
|
|
|
|
|
|
|
|
static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
|
|
|
|
|
|
|
|
bool operator==(const common_chat_msg_diff & other) const {
|
|
|
|
return content_delta == other.content_delta
|
|
|
|
&& tool_call_index == other.tool_call_index
|
|
|
|
&& tool_call_delta == other.tool_call_delta;
|
|
|
|
}
|
2025-02-18 18:03:23 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
// Description of a tool that can be offered to the model.
struct common_chat_tool {
    std::string name;        // tool name
    std::string description; // free-form description of the tool
    std::string parameters;  // parameter specification kept as a string (presumably a serialized JSON schema -- confirm)
};
|
|
|
|
|
|
|
|
// Tool selection policy for a request.
// NOTE(review): the names mirror the usual "auto" / "required" / "none" tool_choice
// values -- confirm the exact mapping in common_chat_tool_choice_parse_oaicompat.
enum common_chat_tool_choice {
    COMMON_CHAT_TOOL_CHOICE_AUTO,
    COMMON_CHAT_TOOL_CHOICE_REQUIRED,
    COMMON_CHAT_TOOL_CHOICE_NONE,
};
|
|
|
|
|
|
|
|
// Identifies which chat / tool-call format is in use.
// Enumerators are implicitly numbered from 0, so COMMON_CHAT_FORMAT_COUNT equals the
// number of real formats; new formats must be inserted before it.
enum common_chat_format {
    COMMON_CHAT_FORMAT_CONTENT_ONLY,
    COMMON_CHAT_FORMAT_GENERIC,
    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
    COMMON_CHAT_FORMAT_LLAMA_3_X,
    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
    COMMON_CHAT_FORMAT_HERMES_2_PRO,
    COMMON_CHAT_FORMAT_COMMAND_R7B,

    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
|
|
|
|
|
|
|
|
struct common_chat_templates_inputs {
|
|
|
|
std::vector<common_chat_msg> messages;
|
|
|
|
std::string grammar;
|
|
|
|
std::string json_schema;
|
|
|
|
bool add_generation_prompt = true;
|
|
|
|
bool use_jinja = true;
|
|
|
|
// Parameters below only supported when use_jinja is true
|
|
|
|
std::vector<common_chat_tool> tools;
|
|
|
|
common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
|
|
|
|
bool parallel_tool_calls = false;
|
2025-05-25 01:48:08 +01:00
|
|
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
2025-05-26 00:30:51 +01:00
|
|
|
bool enable_thinking = true;
|
2025-05-15 02:39:51 +01:00
|
|
|
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
|
2025-02-18 18:03:23 +00:00
|
|
|
};
|
|
|
|
|
|
|
|
struct common_chat_params {
|
|
|
|
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
|
|
std::string prompt;
|
|
|
|
std::string grammar;
|
|
|
|
bool grammar_lazy = false;
|
2025-05-25 01:48:08 +01:00
|
|
|
bool thinking_forced_open = false;
|
2025-02-18 18:03:23 +00:00
|
|
|
std::vector<common_grammar_trigger> grammar_triggers;
|
|
|
|
std::vector<std::string> preserved_tokens;
|
|
|
|
std::vector<std::string> additional_stops;
|
|
|
|
};
|
|
|
|
|
2025-05-25 01:48:08 +01:00
|
|
|
struct common_chat_syntax {
|
|
|
|
common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
|
|
|
|
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
|
|
|
|
// Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
|
|
|
|
bool reasoning_in_content = false;
|
|
|
|
bool thinking_forced_open = false;
|
2025-05-26 08:03:57 -07:00
|
|
|
bool parse_tool_calls = true;
|
2025-05-25 01:48:08 +01:00
|
|
|
};
|
|
|
|
|
2025-02-18 18:03:23 +00:00
|
|
|
// Validates a chat template such as the one supplied via "--chat-template".
// Returns true if the template is supported (valid), false otherwise.
// NOTE(review): use_jinja presumably selects Jinja-based validation -- confirm at the definition.
bool common_chat_verify_template(
    const std::string & tmpl,
    bool use_jinja);
|
|
|
|
|
|
|
|
// Frees a common_chat_templates object (defined elsewhere).
void common_chat_templates_free(struct common_chat_templates * tmpls);

// Deleter functor that forwards to common_chat_templates_free(); used by common_chat_templates_ptr.
struct common_chat_templates_deleter {
    void operator()(common_chat_templates * tmpls) {
        common_chat_templates_free(tmpls);
    }
};

// std::unique_ptr over common_chat_templates that uses the deleter above.
using common_chat_templates_ptr = std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter>;
|
|
|
|
|
|
|
|
// Creates a common_chat_templates object and returns it wrapped in an owning pointer.
//   model                  - model handle (NOTE(review): presumably the source of the built-in template -- confirm)
//   chat_template_override - template override string (NOTE(review): presumably empty means "no override" -- confirm)
//   bos_token_override     - optional BOS token override, defaults to ""
//   eos_token_override     - optional EOS token override, defaults to ""
common_chat_templates_ptr common_chat_templates_init(
    const struct llama_model * model,
    const std::string        & chat_template_override,
    const std::string        & bos_token_override = "",
    const std::string        & eos_token_override = "");
|
|
|
|
|
|
|
|
// Queries whether the templates were "explicit".
// NOTE(review): presumably true when a template was explicitly provided rather than defaulted -- confirm at the definition.
bool common_chat_templates_was_explicit(
    const struct common_chat_templates * tmpls);

// Returns a C string with the template source for the given variant.
// `variant` is optional and defaults to nullptr.
// NOTE(review): meaning of a null variant and the lifetime of the returned pointer are not visible here -- confirm.
const char * common_chat_templates_source(
    const struct common_chat_templates * tmpls,
    const char * variant = nullptr);
|
|
|
|
|
|
|
|
|
|
|
|
// Applies the chat templates to the given inputs and returns the resulting
// common_chat_params by value.
struct common_chat_params common_chat_templates_apply(
    const struct common_chat_templates        * tmpls,
    const struct common_chat_templates_inputs & inputs);
|
|
|
|
|
|
|
|
// Format single message, while taking into account the position of that message in chat history
|
|
|
|
std::string common_chat_format_single(
|
|
|
|
const struct common_chat_templates * tmpls,
|
|
|
|
const std::vector<common_chat_msg> & past_msg,
|
|
|
|
const common_chat_msg & new_msg,
|
|
|
|
bool add_ass,
|
|
|
|
bool use_jinja);
|
|
|
|
|
|
|
|
// Produces an example of a formatted chat for the given templates and returns it as a string.
std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja);
|
|
|
|
|
2025-05-26 00:30:51 +01:00
|
|
|
const char* common_chat_format_name(common_chat_format format);
|
|
|
|
const char* common_reasoning_format_name(common_reasoning_format format);
|
2025-05-25 01:48:08 +01:00
|
|
|
// Parses `input` into a common_chat_msg according to `syntax`.
// NOTE(review): is_partial presumably marks input that may still be incomplete (e.g. while streaming) -- confirm.
common_chat_msg common_chat_parse(
    const std::string        & input,
    bool                       is_partial,
    const common_chat_syntax & syntax);
|
2025-02-18 18:03:23 +00:00
|
|
|
|
|
|
|
// Converts a tool_choice string into the corresponding common_chat_tool_choice value.
// NOTE(review): accepted strings and the handling of unknown values are not visible here -- confirm at the definition.
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(
    const std::string & tool_choice);
|
|
|
|
|
|
|
|
// Parses a JSON array of messages in OpenAI's chat completion API format.
|
|
|
|
// T can be std::string containing JSON or nlohmann::ordered_json
|
|
|
|
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
|
|
|
|
template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
|
|
|
|
|
|
|
|
// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
|
|
|
|
// T can be std::string containing JSON or nlohmann::ordered_json
|
|
|
|
template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
|
|
|
|
template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
|
2025-05-25 01:48:08 +01:00
|
|
|
|
|
|
|
// Converts a single message diff to a value of type T.
template <class T>
T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
|