server : fix first message identification (#13634)

* server : fix first message identification

When using the OpenAI SDK (https://github.com/openai/openai-node/blob/master/src/lib/ChatCompletionStream.ts#L623-L626) we noticed that the expected assistant role is missing in the first streaming message. Fix this by correctly checking for the first message.

Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
Signed-off-by: Dorin Geman <dorin.geman@docker.com>

* server : Fix checks for first role message for stream=True

Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
Signed-off-by: Dorin Geman <dorin.geman@docker.com>

---------

Signed-off-by: Dorin Geman <dorin.geman@docker.com>
Co-authored-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
This commit is contained in:
Dorin-Andrei Geman
2025-05-21 16:07:57 +03:00
committed by GitHub
parent 797f2ac062
commit 42158ae2e8
2 changed files with 53 additions and 21 deletions

View File

@ -951,7 +951,7 @@ struct server_task_result_cmpl_partial : server_task_result {
}
json to_json_oaicompat_chat() {
bool first = n_decoded == 0;
bool first = n_decoded == 1;
std::time_t t = std::time(0);
json choices;
@ -962,15 +962,18 @@ struct server_task_result_cmpl_partial : server_task_result {
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
// initial_ret is the role message for stream=True
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
{"role", "assistant"},
{"content", ""}
}}}})},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
json second_ret = json{
@ -982,8 +985,19 @@ struct server_task_result_cmpl_partial : server_task_result {
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"}};
if (prob_output.probs.size() > 0) {
second_ret["choices"][0]["logprobs"] = json{
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
};
}
if (timings.prompt_n >= 0) {
second_ret.push_back({"timings", timings.to_json()});
}
return std::vector<json>({initial_ret, second_ret});
}
} else {

View File

@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
})
content = ""
last_cmpl_id = None
for data in res:
for i, data in enumerate(res):
choice = data["choices"][0]
if i == 0:
# Check first role message for stream=True
assert choice["delta"]["content"] == ""
assert choice["delta"]["role"] == "assistant"
else:
assert "role" not in choice["delta"]
assert data["system_fingerprint"].startswith("b")
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
if last_cmpl_id is None:
@ -242,12 +248,18 @@ def test_chat_completion_with_timings_per_token():
"stream": True,
"timings_per_token": True,
})
for data in res:
assert "timings" in data
assert "prompt_per_second" in data["timings"]
assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10
for i, data in enumerate(res):
if i == 0:
# Check first role message for stream=True
assert data["choices"][0]["delta"]["content"] == ""
assert data["choices"][0]["delta"]["role"] == "assistant"
else:
assert "role" not in data["choices"][0]["delta"]
assert "timings" in data
assert "prompt_per_second" in data["timings"]
assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10
def test_logprobs():
@ -295,17 +307,23 @@ def test_logprobs_stream():
)
output_text = ''
aggregated_text = ''
for data in res:
for i, data in enumerate(res):
choice = data.choices[0]
if choice.finish_reason is None:
if choice.delta.content:
output_text += choice.delta.content
assert choice.logprobs is not None
assert choice.logprobs.content is not None
for token in choice.logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
if i == 0:
# Check first role message for stream=True
assert choice.delta.content == ""
assert choice.delta.role == "assistant"
else:
assert choice.delta.role is None
if choice.finish_reason is None:
if choice.delta.content:
output_text += choice.delta.content
assert choice.logprobs is not None
assert choice.logprobs.content is not None
for token in choice.logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text