mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-26 19:55:04 +00:00
server
: fix format of streamed tool call deltas (diff name, fix id location) (#13800)
* fix deltas of tool_call.function.name * fix tool_call.id (was in tool_call.function.id!) + add function type * add tool_call.type * populate empty tool_call.function.arguments on first delta
This commit is contained in:
@ -106,9 +106,9 @@ std::vector<common_chat_msg_diff> common_chat_msg_diff::compute_diffs(const comm
|
|||||||
if (!args_diff.empty() || pref.id != newf.id) {
|
if (!args_diff.empty() || pref.id != newf.id) {
|
||||||
auto & diff = diffs.emplace_back();
|
auto & diff = diffs.emplace_back();
|
||||||
diff.tool_call_index = idx;
|
diff.tool_call_index = idx;
|
||||||
diff.tool_call_delta.name = newf.name;
|
|
||||||
if (pref.id != newf.id) {
|
if (pref.id != newf.id) {
|
||||||
diff.tool_call_delta.id = newf.id;
|
diff.tool_call_delta.id = newf.id;
|
||||||
|
diff.tool_call_delta.name = newf.name;
|
||||||
}
|
}
|
||||||
diff.tool_call_delta.arguments = args_diff;
|
diff.tool_call_delta.arguments = args_diff;
|
||||||
}
|
}
|
||||||
@ -392,22 +392,19 @@ template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_di
|
|||||||
delta["content"] = diff.content_delta;
|
delta["content"] = diff.content_delta;
|
||||||
}
|
}
|
||||||
if (diff.tool_call_index != std::string::npos) {
|
if (diff.tool_call_index != std::string::npos) {
|
||||||
|
json tool_call;
|
||||||
|
tool_call["index"] = diff.tool_call_index;
|
||||||
|
if (!diff.tool_call_delta.id.empty()) {
|
||||||
|
tool_call["id"] = diff.tool_call_delta.id;
|
||||||
|
tool_call["type"] = "function";
|
||||||
|
}
|
||||||
json function = json::object();
|
json function = json::object();
|
||||||
if (!diff.tool_call_delta.name.empty()) {
|
if (!diff.tool_call_delta.name.empty()) {
|
||||||
function["name"] = diff.tool_call_delta.name;
|
function["name"] = diff.tool_call_delta.name;
|
||||||
}
|
}
|
||||||
if (!diff.tool_call_delta.id.empty()) {
|
function["arguments"] = diff.tool_call_delta.arguments;
|
||||||
function["id"] = diff.tool_call_delta.id;
|
tool_call["function"] = function;
|
||||||
}
|
delta["tool_calls"] = json::array({tool_call});
|
||||||
if (!diff.tool_call_delta.arguments.empty()) {
|
|
||||||
function["arguments"] = diff.tool_call_delta.arguments;
|
|
||||||
}
|
|
||||||
delta["tool_calls"] = json::array({
|
|
||||||
json {
|
|
||||||
{"index", diff.tool_call_index},
|
|
||||||
{"function", function}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
return delta;
|
return delta;
|
||||||
}
|
}
|
||||||
|
@ -1356,8 +1356,7 @@ static void test_msg_diffs_compute() {
|
|||||||
|
|
||||||
common_chat_msg_diff diff12;
|
common_chat_msg_diff diff12;
|
||||||
diff12.tool_call_index = 0;
|
diff12.tool_call_index = 0;
|
||||||
diff12.tool_call_delta.name = "special_function";
|
// Note: neither id nor name change here.
|
||||||
// Note: id doesnt change here.
|
|
||||||
diff12.tool_call_delta.arguments = "g1\": 1}";
|
diff12.tool_call_delta.arguments = "g1\": 1}";
|
||||||
|
|
||||||
assert_equals(
|
assert_equals(
|
||||||
|
@ -328,6 +328,10 @@ class ServerProcess:
|
|||||||
if 'function' not in tc:
|
if 'function' not in tc:
|
||||||
raise ValueError(f"Expected function type, got {tc['type']}")
|
raise ValueError(f"Expected function type, got {tc['type']}")
|
||||||
if tc['index'] >= len(tool_calls):
|
if tc['index'] >= len(tool_calls):
|
||||||
|
assert 'id' in tc
|
||||||
|
assert tc.get('type') == 'function'
|
||||||
|
assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
|
||||||
|
f"Expected function call with name, got {tc.get('function')}"
|
||||||
tool_calls.append(dict(
|
tool_calls.append(dict(
|
||||||
id="",
|
id="",
|
||||||
type="function",
|
type="function",
|
||||||
@ -340,10 +344,10 @@ class ServerProcess:
|
|||||||
if tc.get('id') is not None:
|
if tc.get('id') is not None:
|
||||||
tool_call['id'] = tc['id']
|
tool_call['id'] = tc['id']
|
||||||
fct = tc['function']
|
fct = tc['function']
|
||||||
|
assert 'id' not in fct, f"Function call should not have id: {fct}"
|
||||||
if fct.get('name') is not None:
|
if fct.get('name') is not None:
|
||||||
tool_call['function']['name'] = fct['name']
|
tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
|
||||||
if fct.get('arguments') is not None:
|
if fct.get('arguments') is not None:
|
||||||
assert len(fct['arguments']) > 0, f'Expected non empty arguments delta!'
|
|
||||||
tool_call['function']['arguments'] += fct['arguments']
|
tool_call['function']['arguments'] += fct['arguments']
|
||||||
|
|
||||||
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
|
print(f'Streamed response had {content_parts} content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
|
||||||
|
Reference in New Issue
Block a user