mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
Server: format error to json (#5961)
* server: format error to json
* server: do not crash on grammar error
* fix api key test case
* revert limit max n_predict
* small fix
* correct coding style
* update completion.js
* launch_slot_with_task
* update docs
* update_slots
* update webui
* update readme
@@ -396,7 +396,7 @@ struct server_queue {
    // callback functions
    std::function<void(server_task &)> callback_new_task;
    std::function<void(server_task_multi &)> callback_finish_multitask;
-   std::function<void(void)> callback_run_slots;
+   std::function<void(void)> callback_update_slots;

    // Add a new task to the end of the queue
    int post(server_task task) {
@@ -435,8 +435,8 @@ struct server_queue {
    }

    // Register the function to be called when all slots data is ready to be processed
-   void on_run_slots(std::function<void(void)> callback) {
-       callback_run_slots = std::move(callback);
+   void on_update_slots(std::function<void(void)> callback) {
+       callback_update_slots = std::move(callback);
    }

    // Call when the state of one slot is changed
@@ -461,7 +461,7 @@ struct server_queue {
     * - Wait until a new task arrives
     * - Process the task (i.e. maybe copy data into slot)
     * - Check if multitask is finished
-    * - Run all slots
+    * - Update all slots
     */
    void start_loop() {
        running = true;
@@ -499,9 +499,9 @@ struct server_queue {
            }

            // all tasks in the current loop is processed, slots data is now ready
-           LOG_VERBOSE("callback_run_slots", {});
+           LOG_VERBOSE("callback_update_slots", {});

-           callback_run_slots();
+           callback_update_slots();

            LOG_VERBOSE("wait for new task", {});
            {
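For context, the renamed callback is invoked once per iteration of the queue's main loop, after every pending task has been handled. Below is a minimal, self-contained sketch of that loop shape, assuming a mutex and condition variable guard the task list (those internals are not shown in this diff); everything other than post, start_loop, callback_new_task and callback_update_slots is illustrative:

#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>

// Sketch only: "int" stands in for server_task, and shutdown handling is
// omitted. This is not the real server_queue, just the loop pattern.
struct queue_sketch {
    bool running = false;
    std::deque<int> tasks;
    std::mutex mutex_tasks;
    std::condition_variable condition_tasks;

    std::function<void(int)>  callback_new_task;
    std::function<void(void)> callback_update_slots;

    void post(int task) {
        std::unique_lock<std::mutex> lock(mutex_tasks);
        tasks.push_back(task);
        condition_tasks.notify_one();
    }

    void start_loop() {
        running = true;
        while (running) {
            // drain every queued task before touching the slots
            while (true) {
                std::unique_lock<std::mutex> lock(mutex_tasks);
                if (tasks.empty()) {
                    break;
                }
                int task = tasks.front();
                tasks.pop_front();
                lock.unlock();
                callback_new_task(task);
            }

            // all tasks in the current loop are processed, slots data is now ready
            callback_update_slots();

            // sleep until a new task is posted
            std::unique_lock<std::mutex> lock(mutex_tasks);
            condition_tasks.wait(lock, [this] { return !tasks.empty() || !running; });
        }
    }
};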
@@ -805,9 +805,10 @@ struct server_context {
        return last_used;
    }

-   bool launch_slot_with_data(server_slot & slot, json data) const {
+   bool launch_slot_with_task(server_slot & slot, const server_task & task) {
        slot_params default_params;
        llama_sampling_params default_sparams;
+       auto & data = task.data;

        if (data.count("__oaicompat") != 0) {
            slot.oaicompat = true;
@@ -864,10 +865,15 @@ struct server_context {
        {
            const auto & prompt = data.find("prompt");
            if (prompt == data.end()) {
-               slot.prompt = "";
+               send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
+               return false;
            } else {
                slot.prompt = *prompt;
            }
+           if (slot.prompt.is_array() && slot.prompt.size() == 0) {
+               send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
+               return false;
+           }
        }

        // penalize user-provided tokens
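A request that omits the prompt, or sends an empty array, is now rejected before a slot is launched. Assuming the code/message/type envelope sketched after the send_error hunk below, a body such as {"prompt": []} would receive an HTTP 400 along the lines of:

{"error": {"code": 400, "message": "\"prompt\" cannot be an empty array", "type": "invalid_request_error"}}

(only the message text and the presence of a "code" field are visible in this diff; the exact "type" string is an assumption).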
@@ -926,6 +932,7 @@ struct server_context {
            if (logit_bias != data.end() && logit_bias->is_array()) {
                const int n_vocab = llama_n_vocab(model);
                for (const auto & el : *logit_bias) {
+                   // TODO: we may want to throw errors here, in case "el" is incorrect
                    if (el.is_array() && el.size() == 2) {
                        float bias;
                        if (el[1].is_number()) {
@@ -985,6 +992,11 @@ struct server_context {
                llama_sampling_free(slot.ctx_sampling);
            }
            slot.ctx_sampling = llama_sampling_init(slot.sparams);
+           if (slot.ctx_sampling == nullptr) {
+               // for now, the only error that may happen here is invalid grammar
+               send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
+               return false;
+           }
            llama_set_rng_seed(ctx, slot.params.seed);
        }

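This is the "do not crash on grammar error" fix from the commit message: when llama_sampling_init cannot parse the requested grammar it returns nullptr, and the slot launch now fails with an invalid-request error instead of proceeding with a null sampling context. A hypothetical request body that would trigger it, e.g. with an unterminated GBNF rule:

{"prompt": "Hello", "grammar": "root ::= ("}

now gets the "Failed to parse grammar" error back instead of taking the server down.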
@@ -1226,15 +1238,23 @@ struct server_context {
        };
    }

-   void send_error(const server_task & task, const std::string & error) {
-       LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
+   void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+       send_error(task.id, task.id_multi, error, type);
+   }
+
+   void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+       send_error(slot.id_task, slot.id_multi, error, type);
+   }
+
+   void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+       LOG_TEE("task %i - error: %s\n", id_task, error.c_str());

        server_task_result res;
-       res.id = task.id;
-       res.id_multi = task.id_multi;
+       res.id = id_task;
+       res.id_multi = id_multi;
        res.stop = false;
        res.error = true;
-       res.data = { { "content", error } };
+       res.data = format_error_response(error, type);

        queue_results.send(res);
    }
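format_error_response itself is defined elsewhere in this change and is not shown in any hunk here. A plausible sketch, consistent with its call sites in this diff and with res_error (further down) reading an HTTP status out of the payload's "code" field; the type strings and the exact mapping are assumptions:

// Sketch only, not the committed implementation. Uses the surrounding file's
// json (nlohmann) alias and error_type enum.
static json format_error_response(const std::string & message, const enum error_type type) {
    std::string type_str;
    int code = 500;
    switch (type) {
        case ERROR_TYPE_INVALID_REQUEST: type_str = "invalid_request_error"; code = 400; break; // HTTP 400
        case ERROR_TYPE_AUTHENTICATION:  type_str = "authentication_error";  code = 401; break; // HTTP 401
        case ERROR_TYPE_NOT_FOUND:       type_str = "not_found_error";       code = 404; break; // HTTP 404
        case ERROR_TYPE_NOT_SUPPORTED:   type_str = "not_supported_error";   code = 501; break; // HTTP 501
        case ERROR_TYPE_UNAVAILABLE:     type_str = "unavailable_error";     code = 503; break; // HTTP 503
        case ERROR_TYPE_SERVER:          type_str = "server_error";          code = 500; break; // HTTP 500
    }
    return json {
        {"code",    code},
        {"message", message},
        {"type",    type_str},
    };
}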
@@ -1468,9 +1488,8 @@ struct server_context {
                    slot->infill = task.infill;
                    slot->embedding = task.embedding;

-                   if (!launch_slot_with_data(*slot, task.data)) {
-                       // send error result
-                       send_error(task, "internal_error");
+                   if (!launch_slot_with_task(*slot, task)) {
+                       LOG_ERROR("error while launching slot", task.data);
                        break;
                    }
                } break;
@@ -1587,7 +1606,7 @@ struct server_context {
        queue_results.send(result);
    }

-   bool update_slots() {
+   void update_slots() {
        if (system_need_update) {
            system_prompt_update();
        }
@@ -1630,7 +1649,7 @@ struct server_context {
                kv_cache_clear();
            }

-           return true;
+           return;
        }
    }

@@ -1975,8 +1994,7 @@ struct server_context {

        if (batch.n_tokens == 0) {
            LOG_VERBOSE("no tokens to decode", {});
-
-           return true;
+           return;
        }

        LOG_VERBOSE("decoding batch", {
@@ -2033,7 +2051,13 @@ struct server_context {
                if (n_batch == 1 || ret < 0) {
                    // if you get here, it means the KV cache is full - try increasing it via the context size
                    LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
-                   return false;
+                   for (auto & slot : slots) {
+                       slot.state = SLOT_STATE_PROCESSING;
+                       slot.command = SLOT_COMMAND_NONE;
+                       slot.release();
+                       send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                   }
+                   break; // break loop of n_batch
                }

                LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);
@@ -2042,12 +2066,12 @@ struct server_context {
                n_batch /= 2;
                i -= n_batch;

-               continue;
+               continue; // continue loop of n_batch
            }

            for (auto & slot : slots) {
                if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
-                   continue;
+                   continue; // continue loop of slots
                }

                // prompt evaluated for embedding
@@ -2055,7 +2079,7 @@ struct server_context {
                    send_embedding(slot, batch_view);
                    slot.release();
                    slot.i_batch = -1;
-                   continue;
+                   continue; // continue loop of slots
                }

                completion_token_output result;
@@ -2097,9 +2121,7 @@ struct server_context {
            }
        }

-       LOG_VERBOSE("slots updated", {});
-
-       return true;
+       LOG_VERBOSE("run slots completed", {});
    }

    json model_meta() const {
@@ -2745,32 +2767,32 @@ int main(int argc, char ** argv) {

    svr->set_logger(log_server_request);

-   svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
-       const char fmt[] = "500 Internal Server Error\n%s";
-
-       char buf[BUFSIZ];
+   auto res_error = [](httplib::Response & res, json error_data) {
+       json final_response {{"error", error_data}};
+       res.set_content(final_response.dump(), "application/json; charset=utf-8");
+       res.status = json_value(error_data, "code", 500);
+   };
+
+   svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+       std::string message;
        try {
            std::rethrow_exception(std::move(ep));
-       } catch (std::exception &e) {
-           snprintf(buf, sizeof(buf), fmt, e.what());
+       } catch (std::exception & e) {
+           message = e.what();
        } catch (...) {
-           snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
+           message = "Unknown Exception";
        }

-       res.set_content(buf, "text/plain; charset=utf-8");
-       res.status = 500;
+       json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
+       LOG_VERBOSE("Got exception", formatted_error);
+       res_error(res, formatted_error);
    });

-   svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
-       if (res.status == 401) {
-           res.set_content("Unauthorized", "text/plain; charset=utf-8");
-       }
-       if (res.status == 400) {
-           res.set_content("Invalid request", "text/plain; charset=utf-8");
-       }
+   svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
        if (res.status == 404) {
-           res.set_content("File Not Found", "text/plain; charset=utf-8");
+           res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
        }
+       // for other error codes, we skip processing here because it's already done by res_error()
    });

    // set timeouts and change hostname and port
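Net effect on the wire: every failure path now funnels through res_error, so clients get a single JSON envelope instead of ad-hoc plain-text bodies. Given the {{"error", error_data}} wrapping above, an uncaught exception would produce a response shaped roughly like this (message and type string illustrative):

HTTP/1.1 500 Internal Server Error
Content-Type: application/json; charset=utf-8

{"error": {"code": 500, "message": "<exception what()>", "type": "server_error"}}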
@@ -2835,7 +2857,7 @@ int main(int argc, char ** argv) {
    // Middlewares
    //

-   auto middleware_validate_api_key = [&sparams](const httplib::Request & req, httplib::Response & res) {
+   auto middleware_validate_api_key = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
        static const std::set<std::string> protected_endpoints = {
            "/props",
@@ -2876,8 +2898,7 @@ int main(int argc, char ** argv) {
        // API key is invalid or not provided
        // TODO: make another middleware for CORS related logic
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-       res.set_content("Unauthorized: Invalid API Key", "text/plain; charset=utf-8");
-       res.status = 401; // Unauthorized
+       res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));

        LOG_WARNING("Unauthorized: Invalid API Key", {});

@@ -2940,21 +2961,18 @@ int main(int argc, char ** argv) {
                }
            case SERVER_STATE_LOADING_MODEL:
                {
-                   res.set_content(R"({"status": "loading model"})", "application/json");
-                   res.status = 503; // HTTP Service Unavailable
+                   res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
                } break;
            case SERVER_STATE_ERROR:
                {
-                   res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
-                   res.status = 500; // HTTP Internal Server Error
+                   res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
                } break;
        }
    };

    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
        if (!sparams.slots_endpoint) {
-           res.status = 501;
-           res.set_content("This server does not support slots endpoint.", "text/plain; charset=utf-8");
+           res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

@@ -2978,8 +2996,7 @@ int main(int argc, char ** argv) {

    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
        if (!sparams.metrics_endpoint) {
-           res.status = 501;
-           res.set_content("This server does not support metrics endpoint.", "text/plain; charset=utf-8");
+           res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

@@ -3090,7 +3107,7 @@ int main(int argc, char ** argv) {
        res.set_content(data.dump(), "application/json; charset=utf-8");
    };

-   const auto handle_completions = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+   const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

        json data = json::parse(req.body);
@@ -3105,8 +3122,7 @@ int main(int argc, char ** argv) {
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
            } else {
-               res.status = 500;
-               res.set_content(result.data["content"], "text/plain; charset=utf-8");
+               res_error(res, result.data);
            }

            ctx_server.queue_results.remove_waiting_task_id(id_task);
@@ -3186,7 +3202,7 @@ int main(int argc, char ** argv) {
        res.set_content(models.dump(), "application/json; charset=utf-8");
    };

-   const auto handle_chat_completions = [&ctx_server, &sparams](const httplib::Request & req, httplib::Response & res) {
+   const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);

@@ -3204,8 +3220,7 @@ int main(int argc, char ** argv) {

                res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
            } else {
-               res.status = 500;
-               res.set_content(result.data["content"], "text/plain; charset=utf-8");
+               res_error(res, result.data);
            }
            ctx_server.queue_results.remove_waiting_task_id(id_task);
        } else {
@@ -3259,7 +3274,7 @@ int main(int argc, char ** argv) {
        }
    };

-   const auto handle_infill = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+   const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

        json data = json::parse(req.body);
@@ -3274,8 +3289,7 @@ int main(int argc, char ** argv) {
            if (!result.error && result.stop) {
                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
            } else {
-               res.status = 404;
-               res.set_content(result.data["content"], "text/plain; charset=utf-8");
+               res_error(res, result.data);
            }

            ctx_server.queue_results.remove_waiting_task_id(id_task);
@@ -3346,7 +3360,7 @@ int main(int argc, char ** argv) {
        return res.set_content(data.dump(), "application/json; charset=utf-8");
    };

-   const auto handle_embeddings = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+   const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        if (!params.embedding) {
            res.status = 501;
@@ -3375,8 +3389,8 @@ int main(int argc, char ** argv) {
                std::string content = body["content"];
                prompts.push_back(content);
            } else {
-               // TODO @ngxson : should return an error here
-               prompts.push_back("");
+               res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+               return;
            }

        // process all prompts
@@ -3392,9 +3406,14 @@ int main(int argc, char ** argv) {
            // get the result
            server_task_result result = ctx_server.queue_results.recv(id_task);
            ctx_server.queue_results.remove_waiting_task_id(id_task);
-
-           // append to the responses
-           responses.push_back(result.data);
+           if (!result.error) {
+               // append to the responses
+               responses.push_back(result.data);
+           } else {
+               // error received, ignore everything else
+               res_error(res, result.data);
+               return;
+           }
        }

        // write JSON response
@@ -3488,7 +3507,7 @@ int main(int argc, char ** argv) {
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
    ctx_server.queue_tasks.on_finish_multitask(std::bind(
        &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
-   ctx_server.queue_tasks.on_run_slots(std::bind(
+   ctx_server.queue_tasks.on_update_slots(std::bind(
        &server_context::update_slots, &ctx_server));
    ctx_server.queue_results.on_multitask_update(std::bind(
        &server_queue::update_multitask,