chat : support Granite model reasoning and tool call (#14864)

This commit is contained in:
Sachin Desai
2025-08-06 11:27:30 -07:00
committed by GitHub
parent 476aa3fd57
commit 3db4da56a5
6 changed files with 252 additions and 1 deletions

View File

@@ -606,6 +606,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default:
throw std::runtime_error("Unknown chat format");
@@ -618,6 +619,7 @@ const char * common_reasoning_format_name(common_reasoning_format format) {
case COMMON_REASONING_FORMAT_AUTO: return "auto";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
case COMMON_REASONING_FORMAT_GRANITE: return "granite";
default:
throw std::runtime_error("Unknown reasoning format");
}
@@ -1734,6 +1736,124 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}
static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Pass thinking context for Granite template
json additional_context = {
{"thinking", inputs.enable_thinking},
};
data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context);
data.format = COMMON_CHAT_FORMAT_GRANITE;
if (string_ends_with(data.prompt, "<think>\n") || string_ends_with(data.prompt, "<think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
if (!inputs.tools.is_null()) {
// Granite uses <|tool_call|> followed by JSON list
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name +
"-args", {
{"type", "object"},
{"properties", {
{"name", {{"const", name}}},
{"arguments", parameters},
}},
{"required", json::array({"name", "arguments"})},
})));
});
auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | "));
auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\"");
if (data.thinking_forced_open) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space [^<]* \"</response>\" space \"<|tool_call|>\" space " + tool_list);
} else {
builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list);
}
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
"<|tool_call|>"
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
"<|tool_call|>",
};
});
} else {
// Handle thinking tags for non-tool responses
if (data.thinking_forced_open && inputs.enable_thinking) {
data.grammar_lazy = false;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
builder.add_rule("root", "\"</think>\" space \"<response>\" space .* \"</response>\" space");
});
data.preserved_tokens = {
"<think>",
"</think>",
"<response>",
"</response>",
};
}
}
return data;
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags using regex
static const common_regex response_regex("<response>([\\s\\S]*?)</response>");
if (auto res = builder.try_find_regex(response_regex)) {
// Extract the content between the tags (capture group 1)
auto content = builder.str(res->groups[1]);
builder.add_content(content);
builder.move_to(res->groups[0].end);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.add_tool_calls(tool_calls_data.json)) {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content("<|tool_call|>" + tool_calls_data.json.dump());
}
} else {
builder.add_content(builder.consume_rest());
}
}
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
@@ -1805,6 +1925,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_command_r7b(tmpl, params);
}
// Granite (IBM) - detects thinking / tools support
if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) {
return common_chat_params_init_granite(tmpl, params);
}
// Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools)
if (src.find("<tool_call>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_hermes_2_pro(tmpl, params);
@@ -1865,6 +1990,7 @@ static common_chat_params common_chat_templates_apply_legacy(
int alloc_size = 0;
std::vector<llama_chat_message> chat;
std::vector<std::string> contents;
for (const auto & msg : inputs.messages) {
auto content = msg.content;
for (const auto & part : msg.content_parts) {
@@ -1966,6 +2092,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;