server : add "/chat/completions" alias for "/v1/...` (#5722)

* Add "/chat/completions" as alias for "/v1/chat/completions"

* merge to upstream master

* minor : fix trailing whitespace

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Jorge A
2024-02-28 01:39:15 -07:00
committed by GitHub
parent 7c4263d426
commit efc72253f7
3 changed files with 115 additions and 68 deletions

View File

@ -3211,87 +3211,88 @@ int main(int argc, char **argv)
res.set_content(models.dump(), "application/json; charset=utf-8");
});
const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
// TODO: add mount point without "/v1" prefix -- how?
svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {
json oaicompat_result = format_final_response_oaicompat(data, result);
if (!result.error && result.stop) {
json oaicompat_result = format_final_response_oaicompat(data, result);
res.set_content(oaicompat_result.dump(-1, ' ', false,
json::error_handler_t::replace),
"application/json; charset=utf-8");
} else {
res.status = 500;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
}
llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
while (true) {
task_result llama_result = llama.queue_results.recv(task_id);
if (!llama_result.error) {
std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
res.set_content(oaicompat_result.dump(-1, ' ', false,
json::error_handler_t::replace),
"application/json; charset=utf-8");
} else {
res.status = 500;
res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
}
llama.queue_results.remove_waiting_task_id(task_id);
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
while (true) {
task_result llama_result = llama.queue_results.recv(task_id);
if (!llama_result.error) {
std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
for (auto it = result_array.begin(); it != result_array.end(); ++it)
{
if (!it->empty()) {
const std::string str =
"data: " +
it->dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
}
}
if (llama_result.stop) {
break;
}
} else {
for (auto it = result_array.begin(); it != result_array.end(); ++it)
{
if (!it->empty()) {
const std::string str =
"error: " +
llama_result.result_json.dump(-1, ' ', false,
json::error_handler_t::replace) +
"data: " +
it->dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
break;
}
}
sink.done();
llama.queue_results.remove_waiting_task_id(task_id);
return true;
};
auto on_complete = [task_id, &llama](bool) {
// cancel request
llama.request_cancel(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
if (llama_result.stop) {
break;
}
} else {
const std::string str =
"error: " +
llama_result.result_json.dump(-1, ' ', false,
json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
llama.queue_results.remove_waiting_task_id(task_id);
return false;
}
break;
}
}
});
sink.done();
llama.queue_results.remove_waiting_task_id(task_id);
return true;
};
auto on_complete = [task_id, &llama](bool) {
// cancel request
llama.request_cancel(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
};
svr.Post("/chat/completions", chat_completions);
svr.Post("/v1/chat/completions", chat_completions);
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{