mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-15 12:42:40 -04:00
chat : support Granite model reasoning and tool call (#14864)
This commit is contained in:
129
common/chat.cpp
129
common/chat.cpp
@@ -606,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
|
||||
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
|
||||
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
|
||||
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
|
||||
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
|
||||
default:
|
||||
throw std::runtime_error("Unknown chat format");
|
||||
@@ -618,6 +619,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
|
||||
case COMMON_REASONING_FORMAT_AUTO: return "auto";
|
||||
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
|
||||
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
|
||||
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
|
||||
default:
|
||||
throw std::runtime_error("Unknown reasoning format");
|
||||
}
|
||||
@@ -1734,6 +1736,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
|
||||
// Pass thinking context for Granite template
|
||||
json additional_context = {
|
||||
{"thinking", inputs.enable_thinking},
|
||||
};
|
||||
|
||||
data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
|
||||
data.format = COMMON_CHAT_FORMAT_GRANITE;
|
||||
|
||||
if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
|
||||
if (!inputs.enable_thinking) {
|
||||
data.prompt += "</think>";
|
||||
} else {
|
||||
data.thinking_forced_open = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!inputs.tools.is_null()) {
|
||||
// Granite uses <|tool_call|> followed by JSON list
|
||||
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
std::vector<std::string> tool_rules;
|
||||
foreach_function(inputs.tools, [&](const json & tool) {
|
||||
const auto & function = tool.at("function");
|
||||
std::string name = function.at("name");
|
||||
auto parameters = function.at("parameters");
|
||||
builder.resolve_refs(parameters);
|
||||
tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
|
||||
"-args", {
|
||||
{"type", "object"},
|
||||
{"properties", {
|
||||
{"name", {{"const", name}}},
|
||||
{"arguments", parameters},
|
||||
}},
|
||||
{"required", json::array({"name", "arguments"})},
|
||||
})));
|
||||
});
|
||||
|
||||
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
|
||||
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");
|
||||
|
||||
if (data.thinking_forced_open) {
|
||||
builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
|
||||
} else {
|
||||
builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
|
||||
}
|
||||
|
||||
data.grammar_triggers.push_back({
|
||||
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
|
||||
"<|tool_call|>"
|
||||
});
|
||||
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<response>",
|
||||
"</response>",
|
||||
"<|tool_call|>",
|
||||
};
|
||||
});
|
||||
} else {
|
||||
// Handle thinking tags for non-tool responses
|
||||
if (data.thinking_forced_open && inputs.enable_thinking) {
|
||||
data.grammar_lazy = false;
|
||||
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
|
||||
builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
|
||||
});
|
||||
data.preserved_tokens = {
|
||||
"<think>",
|
||||
"</think>",
|
||||
"<response>",
|
||||
"</response>",
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
|
||||
// Parse thinking tags
|
||||
builder.try_parse_reasoning("<think>", "</think>");
|
||||
|
||||
// Parse response tags using regex
|
||||
static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
|
||||
if (auto res = builder.try_find_regex(response_regex)) {
|
||||
// Extract the content between the tags (capture group 1)
|
||||
auto content = builder.str(res->groups[1]);
|
||||
builder.add_content(content);
|
||||
builder.move_to(res->groups[0].end);
|
||||
}
|
||||
|
||||
if (!builder.syntax().parse_tool_calls) {
|
||||
builder.add_content(builder.consume_rest());
|
||||
return;
|
||||
}
|
||||
|
||||
// Look for tool calls
|
||||
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
|
||||
if (auto res = builder.try_find_regex(tool_call_regex)) {
|
||||
builder.move_to(res->groups[0].end);
|
||||
|
||||
// Expect JSON array of tool calls
|
||||
auto tool_calls_data = builder.consume_json();
|
||||
if (tool_calls_data.json.is_array()) {
|
||||
if (!builder.add_tool_calls(tool_calls_data.json)) {
|
||||
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
|
||||
}
|
||||
} else {
|
||||
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
|
||||
}
|
||||
} else {
|
||||
builder.add_content(builder.consume_rest());
|
||||
}
|
||||
}
|
||||
|
||||
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
|
||||
common_chat_params data;
|
||||
data.prompt = apply(tmpl, inputs);
|
||||
@@ -1805,6 +1925,11 @@ static common_chat_params common_chat_templates_apply_jinja(
|
||||
return common_chat_params_init_command_r7b(tmpl, params);
|
||||
}
|
||||
|
||||
// Granite (IBM) - detects thinking / tools support
|
||||
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
|
||||
return common_chat_params_init_granite(tmpl, params);
|
||||
}
|
||||
|
||||
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
|
||||
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
|
||||
return common_chat_params_init_hermes_2_pro(tmpl, params);
|
||||
@@ -1865,6 +1990,7 @@ static common_chat_params common_chat_templates_apply_legacy(
|
||||
int alloc_size = 0;
|
||||
std::vector<llama_chat_message> chat;
|
||||
std::vector<std::string> contents;
|
||||
|
||||
for (const auto & msg : inputs.messages) {
|
||||
auto content = msg.content;
|
||||
for (const auto & part : msg.content_parts) {
|
||||
@@ -1966,6 +2092,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
|
||||
case COMMON_CHAT_FORMAT_COMMAND_R7B:
|
||||
common_chat_parse_command_r7b(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GRANITE:
|
||||
common_chat_parse_granite(builder);
|
||||
break;
|
||||
case COMMON_CHAT_FORMAT_GPT_OSS:
|
||||
common_chat_parse_gpt_oss(builder);
|
||||
break;
|
||||
|
Reference in New Issue
Block a user