mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 12:25:03 +00:00
server : use std::move whenever possible (#12936)
* server : use std::move whenever possible
* use r-value ref
* Apply suggestions from code review

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* make task creation scoped
* restore std::move
* fix task_id not set correctly
* apply changes from suggestion

  Co-authored-by: ggerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
@ -1552,29 +1552,30 @@ struct server_queue {
|
|||||||
std::condition_variable condition_tasks;
|
std::condition_variable condition_tasks;
|
||||||
|
|
||||||
// callback functions
|
// callback functions
|
||||||
std::function<void(server_task)> callback_new_task;
|
std::function<void(server_task &&)> callback_new_task;
|
||||||
std::function<void(void)> callback_update_slots;
|
std::function<void(void)> callback_update_slots;
|
||||||
|
|
||||||
// Add a new task to the end of the queue
|
// Add a new task to the end of the queue
|
||||||
int post(server_task task, bool front = false) {
|
int post(server_task && task, bool front = false) {
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
GGML_ASSERT(task.id != -1);
|
GGML_ASSERT(task.id != -1);
|
||||||
// if this is cancel task make sure to clean up pending tasks
|
// if this is cancel task make sure to clean up pending tasks
|
||||||
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
if (task.type == SERVER_TASK_TYPE_CANCEL) {
|
||||||
cleanup_pending_task(task.id_target);
|
cleanup_pending_task(task.id_target);
|
||||||
}
|
}
|
||||||
QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
|
const int task_id = task.id;
|
||||||
|
QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
|
||||||
if (front) {
|
if (front) {
|
||||||
queue_tasks.push_front(std::move(task));
|
queue_tasks.push_front(std::move(task));
|
||||||
} else {
|
} else {
|
||||||
queue_tasks.push_back(std::move(task));
|
queue_tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
condition_tasks.notify_one();
|
condition_tasks.notify_one();
|
||||||
return task.id;
|
return task_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
// multi-task version of post()
|
// multi-task version of post()
|
||||||
int post(std::vector<server_task> & tasks, bool front = false) {
|
int post(std::vector<server_task> && tasks, bool front = false) {
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
for (auto & task : tasks) {
|
for (auto & task : tasks) {
|
||||||
if (task.id == -1) {
|
if (task.id == -1) {
|
||||||
@ -1596,7 +1597,7 @@ struct server_queue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add a new task, but defer until one slot is available
|
// Add a new task, but defer until one slot is available
|
||||||
void defer(server_task task) {
|
void defer(server_task && task) {
|
||||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||||
QUE_DBG("defer task, id = %d\n", task.id);
|
QUE_DBG("defer task, id = %d\n", task.id);
|
||||||
queue_tasks_deferred.push_back(std::move(task));
|
queue_tasks_deferred.push_back(std::move(task));
|
||||||
@ -1611,7 +1612,7 @@ struct server_queue {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Register function to process a new task
|
// Register function to process a new task
|
||||||
void on_new_task(std::function<void(server_task)> callback) {
|
void on_new_task(std::function<void(server_task &&)> callback) {
|
||||||
callback_new_task = std::move(callback);
|
callback_new_task = std::move(callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1660,7 +1661,7 @@ struct server_queue {
|
|||||||
lock.unlock();
|
lock.unlock();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
server_task task = queue_tasks.front();
|
server_task task = std::move(queue_tasks.front());
|
||||||
queue_tasks.pop_front();
|
queue_tasks.pop_front();
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
|
|
||||||
@ -2004,7 +2005,7 @@ struct server_context {
|
|||||||
|
|
||||||
slot.reset();
|
slot.reset();
|
||||||
|
|
||||||
slots.push_back(slot);
|
slots.push_back(std::move(slot));
|
||||||
}
|
}
|
||||||
|
|
||||||
default_generation_settings_for_props = slots[0].to_json();
|
default_generation_settings_for_props = slots[0].to_json();
|
||||||
@ -2105,7 +2106,7 @@ struct server_context {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
bool launch_slot_with_task(server_slot & slot, server_task && task) {
|
||||||
slot.reset();
|
slot.reset();
|
||||||
slot.id_task = task.id;
|
slot.id_task = task.id;
|
||||||
slot.index = task.index;
|
slot.index = task.index;
|
||||||
@ -2113,10 +2114,10 @@ struct server_context {
|
|||||||
slot.params = std::move(task.params);
|
slot.params = std::move(task.params);
|
||||||
slot.prompt_tokens = std::move(task.prompt_tokens);
|
slot.prompt_tokens = std::move(task.prompt_tokens);
|
||||||
|
|
||||||
if (!are_lora_equal(task.params.lora, slot.lora)) {
|
if (!are_lora_equal(slot.params.lora, slot.lora)) {
|
||||||
// if lora is changed, we cannot reuse cached tokens
|
// if lora is changed, we cannot reuse cached tokens
|
||||||
slot.cache_tokens.clear();
|
slot.cache_tokens.clear();
|
||||||
slot.lora = task.params.lora;
|
slot.lora = slot.params.lora;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
|
bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
|
||||||
@ -2547,10 +2548,10 @@ struct server_context {
|
|||||||
server_task task(SERVER_TASK_TYPE_CANCEL);
|
server_task task(SERVER_TASK_TYPE_CANCEL);
|
||||||
task.id_target = id_task;
|
task.id_target = id_task;
|
||||||
queue_results.remove_waiting_task_id(id_task);
|
queue_results.remove_waiting_task_id(id_task);
|
||||||
cancel_tasks.push_back(task);
|
cancel_tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
// push to beginning of the queue, so it has highest priority
|
// push to beginning of the queue, so it has highest priority
|
||||||
queue_tasks.post(cancel_tasks, true);
|
queue_tasks.post(std::move(cancel_tasks), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// receive the results from task(s)
|
// receive the results from task(s)
|
||||||
@ -2637,7 +2638,7 @@ struct server_context {
|
|||||||
// Functions to process the task
|
// Functions to process the task
|
||||||
//
|
//
|
||||||
|
|
||||||
void process_single_task(server_task task) {
|
void process_single_task(server_task && task) {
|
||||||
switch (task.type) {
|
switch (task.type) {
|
||||||
case SERVER_TASK_TYPE_COMPLETION:
|
case SERVER_TASK_TYPE_COMPLETION:
|
||||||
case SERVER_TASK_TYPE_INFILL:
|
case SERVER_TASK_TYPE_INFILL:
|
||||||
@ -2651,17 +2652,17 @@ struct server_context {
|
|||||||
if (slot == nullptr) {
|
if (slot == nullptr) {
|
||||||
// if no slot is available, we defer this task for processing later
|
// if no slot is available, we defer this task for processing later
|
||||||
SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
|
SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
|
||||||
queue_tasks.defer(task);
|
queue_tasks.defer(std::move(task));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if (slot->is_processing()) {
|
if (slot->is_processing()) {
|
||||||
// if requested slot is unavailable, we defer this task for processing later
|
// if requested slot is unavailable, we defer this task for processing later
|
||||||
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
||||||
queue_tasks.defer(task);
|
queue_tasks.defer(std::move(task));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!launch_slot_with_task(*slot, task)) {
|
if (!launch_slot_with_task(*slot, std::move(task))) {
|
||||||
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -2740,7 +2741,7 @@ struct server_context {
|
|||||||
if (slot->is_processing()) {
|
if (slot->is_processing()) {
|
||||||
// if requested slot is unavailable, we defer this task for processing later
|
// if requested slot is unavailable, we defer this task for processing later
|
||||||
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
||||||
queue_tasks.defer(task);
|
queue_tasks.defer(std::move(task));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2776,7 +2777,7 @@ struct server_context {
|
|||||||
if (slot->is_processing()) {
|
if (slot->is_processing()) {
|
||||||
// if requested slot is unavailable, we defer this task for processing later
|
// if requested slot is unavailable, we defer this task for processing later
|
||||||
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
||||||
queue_tasks.defer(task);
|
queue_tasks.defer(std::move(task));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2819,7 +2820,7 @@ struct server_context {
|
|||||||
if (slot->is_processing()) {
|
if (slot->is_processing()) {
|
||||||
// if requested slot is unavailable, we defer this task for processing later
|
// if requested slot is unavailable, we defer this task for processing later
|
||||||
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
||||||
queue_tasks.defer(task);
|
queue_tasks.defer(std::move(task));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2871,7 +2872,7 @@ struct server_context {
|
|||||||
|
|
||||||
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
|
server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
|
||||||
task.id = queue_tasks.get_new_id();
|
task.id = queue_tasks.get_new_id();
|
||||||
queue_tasks.post(task);
|
queue_tasks.post(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply context-shift if needed
|
// apply context-shift if needed
|
||||||
@ -3633,14 +3634,17 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// request slots data using task queue
|
// request slots data using task queue
|
||||||
|
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||||
|
{
|
||||||
server_task task(SERVER_TASK_TYPE_METRICS);
|
server_task task(SERVER_TASK_TYPE_METRICS);
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = task_id;
|
||||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||||
ctx_server.queue_tasks.post(task, true); // high-priority task
|
ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
|
||||||
|
}
|
||||||
|
|
||||||
// get the result
|
// get the result
|
||||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
res_error(res, result->to_json());
|
res_error(res, result->to_json());
|
||||||
@ -3669,16 +3673,17 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// request slots data using task queue
|
// request slots data using task queue
|
||||||
|
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||||
|
{
|
||||||
server_task task(SERVER_TASK_TYPE_METRICS);
|
server_task task(SERVER_TASK_TYPE_METRICS);
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = task_id;
|
||||||
task.metrics_reset_bucket = true;
|
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||||
|
ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
|
||||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
}
|
||||||
ctx_server.queue_tasks.post(task, true); // high-priority task
|
|
||||||
|
|
||||||
// get the result
|
// get the result
|
||||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
res_error(res, result->to_json());
|
res_error(res, result->to_json());
|
||||||
@ -3775,17 +3780,20 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
std::string filepath = params.slot_save_path + filename;
|
std::string filepath = params.slot_save_path + filename;
|
||||||
|
|
||||||
|
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||||
|
{
|
||||||
server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
|
server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = task_id;
|
||||||
task.slot_action.slot_id = id_slot;
|
task.slot_action.slot_id = id_slot;
|
||||||
task.slot_action.filename = filename;
|
task.slot_action.filename = filename;
|
||||||
task.slot_action.filepath = filepath;
|
task.slot_action.filepath = filepath;
|
||||||
|
|
||||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||||
ctx_server.queue_tasks.post(task);
|
ctx_server.queue_tasks.post(std::move(task));
|
||||||
|
}
|
||||||
|
|
||||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
res_error(res, result->to_json());
|
res_error(res, result->to_json());
|
||||||
@ -3804,17 +3812,20 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
std::string filepath = params.slot_save_path + filename;
|
std::string filepath = params.slot_save_path + filename;
|
||||||
|
|
||||||
|
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||||
|
{
|
||||||
server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
|
server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = task_id;
|
||||||
task.slot_action.slot_id = id_slot;
|
task.slot_action.slot_id = id_slot;
|
||||||
task.slot_action.filename = filename;
|
task.slot_action.filename = filename;
|
||||||
task.slot_action.filepath = filepath;
|
task.slot_action.filepath = filepath;
|
||||||
|
|
||||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||||
ctx_server.queue_tasks.post(task);
|
ctx_server.queue_tasks.post(std::move(task));
|
||||||
|
}
|
||||||
|
|
||||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
res_error(res, result->to_json());
|
res_error(res, result->to_json());
|
||||||
@ -3826,15 +3837,18 @@ int main(int argc, char ** argv) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
|
||||||
|
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||||
|
{
|
||||||
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
|
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = task_id;
|
||||||
task.slot_action.slot_id = id_slot;
|
task.slot_action.slot_id = id_slot;
|
||||||
|
|
||||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||||
ctx_server.queue_tasks.post(task);
|
ctx_server.queue_tasks.post(std::move(task));
|
||||||
|
}
|
||||||
|
|
||||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
res_error(res, result->to_json());
|
res_error(res, result->to_json());
|
||||||
@ -3938,9 +3952,10 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto completion_id = gen_chatcmplid();
|
auto completion_id = gen_chatcmplid();
|
||||||
|
std::unordered_set<int> task_ids;
|
||||||
|
try {
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
|
|
||||||
try {
|
|
||||||
const auto & prompt = data.at("prompt");
|
const auto & prompt = data.at("prompt");
|
||||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||||
@ -3965,18 +3980,18 @@ int main(int argc, char ** argv) {
|
|||||||
task.params.oaicompat_cmpl_id = completion_id;
|
task.params.oaicompat_cmpl_id = completion_id;
|
||||||
// oaicompat_model is already populated by params_from_json_cmpl
|
// oaicompat_model is already populated by params_from_json_cmpl
|
||||||
|
|
||||||
tasks.push_back(task);
|
tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
task_ids = server_task::get_list_id(tasks);
|
||||||
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
|
ctx_server.queue_tasks.post(std::move(tasks));
|
||||||
} catch (const std::exception & e) {
|
} catch (const std::exception & e) {
|
||||||
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
|
||||||
ctx_server.queue_tasks.post(tasks);
|
|
||||||
|
|
||||||
bool stream = json_value(data, "stream", false);
|
bool stream = json_value(data, "stream", false);
|
||||||
const auto task_ids = server_task::get_list_id(tasks);
|
|
||||||
|
|
||||||
if (!stream) {
|
if (!stream) {
|
||||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||||
@ -4268,6 +4283,7 @@ int main(int argc, char ** argv) {
|
|||||||
// create and queue the task
|
// create and queue the task
|
||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
|
std::unordered_set<int> task_ids;
|
||||||
{
|
{
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
||||||
@ -4280,15 +4296,15 @@ int main(int argc, char ** argv) {
|
|||||||
// OAI-compat
|
// OAI-compat
|
||||||
task.params.oaicompat = oaicompat;
|
task.params.oaicompat = oaicompat;
|
||||||
|
|
||||||
tasks.push_back(task);
|
tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
task_ids = server_task::get_list_id(tasks);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(std::move(tasks));
|
||||||
|
}
|
||||||
|
|
||||||
// get the result
|
// get the result
|
||||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
|
||||||
|
|
||||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||||
for (auto & res : results) {
|
for (auto & res : results) {
|
||||||
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
|
||||||
@ -4300,7 +4316,6 @@ int main(int argc, char ** argv) {
|
|||||||
}, req.is_connection_closed);
|
}, req.is_connection_closed);
|
||||||
|
|
||||||
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
||||||
}
|
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
return;
|
return;
|
||||||
@ -4367,6 +4382,7 @@ int main(int argc, char ** argv) {
|
|||||||
// create and queue the task
|
// create and queue the task
|
||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
|
std::unordered_set<int> task_ids;
|
||||||
{
|
{
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
|
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
|
||||||
@ -4376,14 +4392,13 @@ int main(int argc, char ** argv) {
|
|||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = ctx_server.queue_tasks.get_new_id();
|
||||||
task.index = i;
|
task.index = i;
|
||||||
task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
|
task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
|
||||||
tasks.push_back(task);
|
tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
task_ids = server_task::get_list_id(tasks);
|
||||||
ctx_server.queue_results.add_waiting_tasks(tasks);
|
ctx_server.queue_results.add_waiting_tasks(tasks);
|
||||||
ctx_server.queue_tasks.post(tasks);
|
ctx_server.queue_tasks.post(std::move(tasks));
|
||||||
|
}
|
||||||
// get the result
|
|
||||||
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
|
|
||||||
|
|
||||||
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
|
||||||
for (auto & res : results) {
|
for (auto & res : results) {
|
||||||
@ -4394,7 +4409,6 @@ int main(int argc, char ** argv) {
|
|||||||
res_error(res, error_data);
|
res_error(res, error_data);
|
||||||
error = true;
|
error = true;
|
||||||
}, req.is_connection_closed);
|
}, req.is_connection_closed);
|
||||||
}
|
|
||||||
|
|
||||||
if (error) {
|
if (error) {
|
||||||
return;
|
return;
|
||||||
@ -4431,14 +4445,19 @@ int main(int argc, char ** argv) {
|
|||||||
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
|
||||||
task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
|
|
||||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
|
||||||
ctx_server.queue_tasks.post(task);
|
|
||||||
|
|
||||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
int task_id = ctx_server.queue_tasks.get_new_id();
|
||||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
{
|
||||||
|
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
||||||
|
task.id = task_id;
|
||||||
|
task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
|
||||||
|
ctx_server.queue_results.add_waiting_task_id(task_id);
|
||||||
|
ctx_server.queue_tasks.post(std::move(task));
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the result
|
||||||
|
server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
|
||||||
|
ctx_server.queue_results.remove_waiting_task_id(task_id);
|
||||||
|
|
||||||
if (result->is_error()) {
|
if (result->is_error()) {
|
||||||
res_error(res, result->to_json());
|
res_error(res, result->to_json());
|
||||||
@ -4582,8 +4601,8 @@ int main(int argc, char ** argv) {
|
|||||||
common_chat_templates_source(ctx_server.chat_templates.get()),
|
common_chat_templates_source(ctx_server.chat_templates.get()),
|
||||||
common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
|
common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
|
||||||
|
|
||||||
ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
|
ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
|
||||||
ctx_server.process_single_task(task);
|
ctx_server.process_single_task(std::move(task));
|
||||||
});
|
});
|
||||||
|
|
||||||
ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
|
ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
|
||||||
|
Reference in New Issue
Block a user