server : vision support via libmtmd (#12898)

* server : (experimental) vision support via libmtmd

* mtmd : add more api around mtmd_image_tokens

* mtmd : add more api around mtmd_image_tokens

* mtmd : ability to calc image hash

* shared_ptr for mtmd_image_tokens

* move hash to user-defined ID (fixed)

* abstract out the batch management

* small fix

* refactor logic adding tokens to batch

* implement hashing image

* use FNV hash; hash the decoded bitmap instead of the file data (see the sketch after this list)

* allow decoding image embedding to be split into batches

* rm whitespace

* disable some features when mtmd is on

* fix --no-mmproj-offload

* mtmd_context_params no timings

* refactor server_inp to server_tokens

* fix the failing test case

* init

* wip

* working version

* add mtmd::bitmaps

* add test target

* rm redundant define

* test: mtmd_input_chunks_free

* rm outdated comment

* fix merging issue

* explicitly create mtmd::input_chunks

* mtmd_input_chunk_copy

* add clone()

* improve server_input struct

* clip : fix confused naming of ffn_up and ffn_down

* rm ffn_i/o/g naming

* rename n_embd, n_ff

* small fix

* no check n_ff

* fix detokenize

* add const to various places

* add warning about breaking changes

* add c api

* helper: use mtmd_image_tokens_get_n_pos

* fix ctx_shift

* fix name shadowing

* more strict condition

* support remote image_url

* remote image_url log

* add CI test

* do not log base64

* add "has_multimodal" to /props

* remove dangling image

* speculative: use slot.cache_tokens.insert

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* rm can_be_detokenized

* on prompt processing done, assert cache_tokens.size

* handle_completions_impl returns void

* adapt the new web ui

* update docs and hot topics

* rm assert

* small fix (2)
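
For reference, the image ID used for prompt caching is an FNV-1a style hash over the decoded RGB pixels (nx * ny * 3 bytes), not over the encoded file, as in the handler's fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3) call further down. A minimal sketch of that scheme; the function name and return type here are illustrative, not the exact server helper:

    #include <cstdint>
    #include <cstddef>
    #include <string>

    // 64-bit FNV-1a over a raw byte buffer. Hashing the decoded bitmap rather
    // than the file means a re-encoded copy of the same image still maps to
    // the same cache ID.
    static std::string fnv1a_hash(const uint8_t * data, size_t len) {
        uint64_t hash = 0xcbf29ce484222325ULL;   // FNV offset basis
        const uint64_t prime = 0x100000001b3ULL; // FNV prime
        for (size_t i = 0; i < len; ++i) {
            hash ^= data[i];
            hash *= prime;
        }
        return std::to_string(hash);
    }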

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: Xuan-Son Nguyen
Date: 2025-05-09 19:29:37 +02:00
Commit: 33eff40240 (parent: 17512a94d6), committed by GitHub
10 changed files with 774 additions and 101 deletions

@ -7,6 +7,7 @@
#include "log.h"
#include "sampling.h"
#include "speculative.h"
#include "mtmd.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
@ -197,8 +198,8 @@ struct server_task {
int id_target = -1;
// used by SERVER_TASK_TYPE_INFERENCE
slot_params params;
llama_tokens prompt_tokens;
slot_params params;
server_tokens prompt_tokens;
int id_selected_slot = -1;
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
@ -1248,6 +1249,9 @@ struct server_slot {
llama_context * ctx = nullptr;
llama_context * ctx_dft = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
common_speculative * spec = nullptr;
std::vector<common_adapter_lora_info> lora;
@ -1275,14 +1279,14 @@ struct server_slot {
int32_t n_prompt_tokens_processed = 0;
// input prompt tokens
llama_tokens prompt_tokens;
server_tokens prompt_tokens;
size_t last_nl_pos = 0;
std::string generated_text;
llama_tokens generated_tokens;
llama_tokens cache_tokens;
server_tokens cache_tokens;
std::vector<completion_token_output> generated_token_probs;
@ -1476,7 +1480,7 @@ struct server_slot {
{"is_processing", is_processing()},
{"non_causal", is_non_causal()},
{"params", params.to_json()},
{"prompt", common_detokenize(ctx, prompt_tokens)},
{"prompt", prompt_tokens.detokenize(ctx, true)},
{"next_token",
{
{"has_next_token", has_next_token},
@ -1849,13 +1853,16 @@ struct server_context {
llama_model * model = nullptr;
llama_context * ctx = nullptr;
// multimodal
mtmd_context * mctx = nullptr;
const llama_vocab * vocab = nullptr;
llama_model * model_dft = nullptr;
llama_context_params cparams_dft;
llama_batch batch = {};
llama_batch batch;
bool clean_kv_cache = true;
bool add_bos_token = true;
@ -1878,6 +1885,8 @@ struct server_context {
common_chat_templates_ptr chat_templates;
~server_context() {
mtmd_free(mctx);
// Clear any sampling context
for (server_slot & slot : slots) {
common_sampler_free(slot.smpl);
@ -1965,6 +1974,36 @@ struct server_context {
chat_templates = common_chat_templates_init(model, "chatml");
}
std::string & mmproj_path = params_base.mmproj.path;
if (!mmproj_path.empty()) {
mtmd_context_params mparams = mtmd_context_params_default();
mparams.use_gpu = params_base.mmproj_use_gpu;
mparams.print_timings = false;
mparams.n_threads = params_base.cpuparams.n_threads;
mparams.verbosity = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
if (mctx == nullptr) {
SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());
return false;
}
SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str());
if (params_base.ctx_shift) {
params_base.ctx_shift = false;
SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled");
}
if (params_base.n_cache_reuse) {
params_base.n_cache_reuse = 0;
SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled");
}
if (!params_base.speculative.model.path.empty()) {
SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal");
return false;
}
}
return true;
}
@ -1980,6 +2019,8 @@ struct server_context {
slot.ctx = ctx;
slot.n_ctx = n_ctx_slot;
slot.n_predict = params_base.n_predict;
slot.mctx = mctx;
slot.cache_tokens.has_mtmd = mctx != nullptr;
if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
@ -2016,8 +2057,6 @@ struct server_context {
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
{
const int32_t n_batch = llama_n_batch(ctx);
// only a single seq_id per token is needed
batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
}
@ -2054,7 +2093,7 @@ struct server_context {
}
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
// fraction of the common subsequence length compared to the current slot's prompt length
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
@ -2096,18 +2135,6 @@ struct server_context {
return ret;
}
bool can_be_detokenized(const struct llama_context * ctx, const std::vector<llama_token> & tokens) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_vocab = llama_vocab_n_tokens(vocab);
for (const auto & token : tokens) {
if (token < 0 || token >= n_vocab) {
return false;
}
}
return true;
}
bool launch_slot_with_task(server_slot & slot, server_task && task) {
slot.reset();
slot.id_task = task.id;
@ -2122,8 +2149,7 @@ struct server_context {
slot.lora = slot.params.lora;
}
bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
if (!can_detokenize) {
if (!slot.prompt_tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
}
@ -2385,6 +2411,15 @@ struct server_context {
queue_results.send(std::move(res));
}
// if multimodal is enabled, send an error and return false
bool ensure_no_mtmd(const int id_task) {
if (mctx) {
send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
return false;
}
return true;
}
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
auto res = std::make_unique<server_task_result_cmpl_partial>();
@ -2424,7 +2459,7 @@ struct server_context {
res->content = std::move(slot.generated_text);
res->tokens = std::move(slot.generated_tokens);
res->timings = slot.get_timings();
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
res->prompt = slot.prompt_tokens.detokenize(ctx, true);
res->response_fields = std::move(slot.params.response_fields);
res->truncated = slot.truncated;
@ -2734,6 +2769,10 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
if (!ensure_no_mtmd(task.id)) {
break;
}
int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
@ -2753,7 +2792,8 @@ struct server_context {
std::string filename = task.slot_action.filename;
std::string filepath = task.slot_action.filepath;
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
const llama_tokens & tokens = slot->cache_tokens.get_text_tokens();
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count);
const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
@ -2770,6 +2810,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
if (!ensure_no_mtmd(task.id)) break;
int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
@ -2788,15 +2829,18 @@ struct server_context {
std::string filename = task.slot_action.filename;
std::string filepath = task.slot_action.filepath;
slot->cache_tokens.resize(slot->n_ctx);
llama_tokens tokens;
tokens.resize(slot->n_ctx);
size_t token_count = 0;
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count);
if (nread == 0) {
slot->cache_tokens.resize(0);
slot->cache_tokens.clear(); // KV may have already been invalidated?
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
break;
}
slot->cache_tokens.resize(token_count);
tokens.resize(token_count);
slot->cache_tokens.clear();
slot->cache_tokens.insert(tokens);
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
@ -2813,6 +2857,7 @@ struct server_context {
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
if (!ensure_no_mtmd(task.id)) break;
int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
@ -2844,6 +2889,7 @@ struct server_context {
res->id = task.id;
queue_results.send(std::move(res));
} break;
}
}
@ -2889,6 +2935,12 @@ struct server_context {
continue;
}
if (mctx) {
// we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded
// we don't support ctx_shift because an image chunk may contain multiple tokens
GGML_ABORT("not supported by multimodal");
}
// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = slot.n_past - n_keep;
@ -2900,11 +2952,14 @@ struct server_context {
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
new_tokens[i - n_discard] = new_tokens[i];
}
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
new_tokens.resize(slot.cache_tokens.size() - n_discard);
slot.cache_tokens.clear();
slot.cache_tokens.insert(new_tokens);
}
slot.n_past -= n_discard;
@ -2982,7 +3037,7 @@ struct server_context {
SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
// print prompt tokens (for debugging)
if (1) {
/*if (1) {
// first 16 tokens (avoid flooding logs)
for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
@ -2992,7 +3047,7 @@ struct server_context {
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
}
}*/
// empty prompt passed -> release the slot and send empty response
if (prompt_tokens.empty()) {
@ -3034,21 +3089,27 @@ struct server_context {
// if input prompt is too big, truncate it
if (slot.n_prompt_tokens >= slot.n_ctx) {
if (mctx) {
// we should never reach this
GGML_ABORT("not supported by multimodal");
}
const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens();
llama_tokens new_tokens(
prompt_tokens.begin(),
prompt_tokens.begin() + slot.params.n_keep);
curr_tokens.begin(),
curr_tokens.begin() + slot.params.n_keep);
new_tokens.insert(
new_tokens.end(),
prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
prompt_tokens.end());
curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
curr_tokens.end());
prompt_tokens = std::move(new_tokens);
prompt_tokens.clear();
prompt_tokens.insert(new_tokens);
slot.truncated = true;
slot.n_prompt_tokens = prompt_tokens.size();
@ -3060,13 +3121,18 @@ struct server_context {
if (slot.params.cache_prompt) {
// reuse any previously computed tokens that are common with the new prompt
slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
// reuse chunks from the cached prompt by shifting their KV cache in the new position
if (params_base.n_cache_reuse > 0) {
size_t head_c = slot.n_past; // cache
size_t head_p = slot.n_past; // current prompt
if (mctx) {
// we should never reach this
GGML_ABORT("not supported by multimodal");
}
SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
while (head_c < slot.cache_tokens.size() &&
@ -3092,7 +3158,7 @@ struct server_context {
llama_kv_self_seq_add(ctx, slot.id, head_c, head_c + n_match, kv_shift);
for (size_t i = 0; i < n_match; i++) {
slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
slot.cache_tokens.set_token(head_p + i, slot.cache_tokens[head_c + i]);
slot.n_past++;
}
@ -3140,21 +3206,52 @@ struct server_context {
// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);
// check if we should process the image
if (slot.n_past < slot.n_prompt_tokens
&& slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
// process the image
int32_t new_n_past;
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
int32_t n_pos = new_n_past - slot.n_past;
if (res != 0) {
SLT_ERR(slot, "failed to process image, res = %d\n", res);
slot.release();
send_error(slot, "failed to process image", ERROR_TYPE_SERVER);
continue;
}
if (slot.params.cache_prompt) {
const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
slot.cache_tokens.push_back(chunk.get()); // copy
}
slot.n_past += n_pos;
slot.n_prompt_tokens_processed += n_pos;
}
// add prompt tokens for processing in the current batch
while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
// get next token to process
llama_token cur_tok = slot.prompt_tokens[slot.n_past];
if (cur_tok == LLAMA_TOKEN_NULL) {
break; // end of text chunk
}
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
slot.cache_tokens.push_back(cur_tok);
}
slot.n_prompt_tokens_processed++;
slot.n_past++;
}
// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
// entire prompt has been processed
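
The loop above walks a server_tokens sequence in which positions covered by an image chunk are marked with LLAMA_TOKEN_NULL: plain text tokens are batched as usual, and when a placeholder is reached the image embedding is decoded through process_chunk. A deliberately simplified, hypothetical sketch of that mixed layout (not the actual server_tokens class introduced by this commit):

    #include <cstdint>
    #include <cstddef>
    #include <map>
    #include <utility>
    #include <vector>

    using llama_token = int32_t;
    constexpr llama_token LLAMA_TOKEN_NULL = -1; // matches llama.h's null token

    struct image_chunk { size_t n_pos; /* embedding data, id, ... */ };

    struct mixed_tokens {
        std::vector<llama_token>      flat; // text tokens; LLAMA_TOKEN_NULL marks image positions
        std::map<size_t, image_chunk> imgs; // start position -> image chunk

        void push_text(llama_token tok) { flat.push_back(tok); }

        void push_image(image_chunk img) {
            const size_t start = flat.size();
            flat.insert(flat.end(), img.n_pos, LLAMA_TOKEN_NULL); // one placeholder per position
            imgs[start] = std::move(img);
        }

        // true if the position belongs to an image chunk rather than text
        bool is_image(size_t pos) const { return flat[pos] == LLAMA_TOKEN_NULL; }

        size_t size() const { return flat.size(); }
    };

Keeping one placeholder per image position makes n_past arithmetic, prompt caching, and common-prefix matching uniform across text and image content.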
@ -3162,12 +3259,16 @@ struct server_context {
slot.state = SLOT_STATE_DONE_PROMPT;
GGML_ASSERT(batch.n_tokens > 0);
GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size());
common_sampler_reset(slot.smpl);
// Process all prompt tokens through sampler system
for (int i = 0; i < slot.n_prompt_tokens; ++i) {
common_sampler_accept(slot.smpl, prompt_tokens[i], false);
llama_token id = slot.prompt_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl, id, false);
}
}
// extract the logits only for the last token
@ -3320,6 +3421,11 @@ struct server_context {
continue;
}
if (mctx) {
// we should never reach this, as speculative is automatically disabled if mmproj is loaded
GGML_ABORT("not supported by multimodal");
}
// determine the max draft that fits the current slot state
int n_draft_max = slot.params.speculative.n_max;
@ -3346,7 +3452,8 @@ struct server_context {
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
params_spec.p_min = slot.params.speculative.p_min;
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
const llama_tokens & cached_text_tokens = slot.cache_tokens.get_text_tokens();
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id);
// keep track of total number of tokens generated in the draft
slot.n_draft_total += draft.size();
@ -3380,7 +3487,7 @@ struct server_context {
slot.n_draft_accepted += ids.size() - 1;
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
@ -3903,6 +4010,7 @@ int main(int argc, char ** argv) {
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
{ "total_slots", ctx_server.params_base.n_parallel },
{ "model_path", ctx_server.params_base.model.path },
{ "modalities", json{{"vision", ctx_server.mctx != nullptr}} }, // TODO: add more in the future
{ "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
{ "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
{ "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
@ -3950,9 +4058,10 @@ int main(int argc, char ** argv) {
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
server_task_type type,
json & data,
const std::vector<raw_buffer> & files,
const std::function<bool()> & is_connection_closed,
httplib::Response & res,
oaicompat_type oaicompat) {
oaicompat_type oaicompat) -> void {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
if (ctx_server.params_base.embedding) {
@ -3969,15 +4078,69 @@ int main(int argc, char ** argv) {
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
// process files
mtmd::bitmaps bitmaps;
const bool has_mtmd = ctx_server.mctx != nullptr;
{
if (!has_mtmd && !files.empty()) {
throw std::runtime_error("This server does not support multimodal");
}
for (auto & file : files) {
mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(file.data(), file.size()));
if (!bmp.ptr) {
throw std::runtime_error("Failed to load image");
}
// calculate bitmap hash (for KV caching)
std::string hash = fnv_hash(bmp.data(), bmp.nx()*bmp.ny()*3);
bmp.set_id(hash.c_str());
bitmaps.entries.push_back(std::move(bmp));
}
}
// process prompt
std::vector<server_tokens> inputs;
if (oaicompat && !prompt.is_string()) {
throw std::runtime_error("prompt must be a string");
}
if (oaicompat && has_mtmd) {
// multimodal
std::string prompt_str = prompt.get<std::string>();
mtmd_input_text inp_txt = {
prompt_str.c_str(),
/* add_special */ true,
/* parse_special */ true,
};
mtmd::input_chunks chunks(mtmd_input_chunks_init());
auto bitmaps_c_ptr = bitmaps.c_ptr();
int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
chunks.ptr.get(),
&inp_txt,
bitmaps_c_ptr.data(),
bitmaps_c_ptr.size());
if (tokenized != 0) {
throw std::runtime_error("Failed to tokenize prompt");
}
server_tokens tmp(chunks, true);
inputs.push_back(std::move(tmp));
} else {
// non-multimodal version
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
for (auto & p : tokenized_prompts) {
auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
inputs.push_back(std::move(tmp));
}
}
tasks.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
task.prompt_tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(
ctx_server.ctx,
ctx_server.params_base,
@ -4059,9 +4222,11 @@ int main(int argc, char ** argv) {
const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = json::parse(req.body);
return handle_completions_impl(
std::vector<raw_buffer> files; // dummy
handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_NONE);
@ -4069,9 +4234,11 @@ int main(int argc, char ** argv) {
const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
json data = oaicompat_completion_params_parse(json::parse(req.body));
return handle_completions_impl(
std::vector<raw_buffer> files; // dummy
handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_COMPLETION);
@ -4146,9 +4313,11 @@ int main(int argc, char ** argv) {
tokenized_prompts[0]
);
return handle_completions_impl(
std::vector<raw_buffer> files; // dummy
handle_completions_impl(
SERVER_TASK_TYPE_INFILL,
data,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
@ -4162,11 +4331,19 @@ int main(int argc, char ** argv) {
}
auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
std::vector<raw_buffer> files;
json data = oaicompat_completion_params_parse(
body,
params.use_jinja,
params.reasoning_format,
ctx_server.chat_templates.get(),
ctx_server.mctx,
files);
return handle_completions_impl(
handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
data,
files,
req.is_connection_closed,
res,
OAICOMPAT_TYPE_CHAT);
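
With an mmproj model loaded, the OAI-compatible chat endpoint accepts image parts (remote image_url or base64 data URI); oaicompat_completion_params_parse extracts them into files before tokenization. Below is a client-side sketch of such a request body using nlohmann::json; the content-part layout mirrors the common OpenAI chat format and the URL is a placeholder, both assumptions rather than values taken from this commit:

    #include <nlohmann/json.hpp>
    #include <string>

    using json = nlohmann::ordered_json;

    // Build a /v1/chat/completions body with one text part and one image part.
    static json make_vision_chat_request(const std::string & image_url) {
        json image_part = {
            {"type", "image_url"},
            {"image_url", {{"url", image_url}}},
        };
        json text_part = {
            {"type", "text"},
            {"text", "What is in this image?"},
        };
        json message = {
            {"role", "user"},
            {"content", json::array({text_part, image_part})},
        };
        return json{
            {"messages", json::array({message})},
            {"max_tokens", 128},
        };
    }

Posted to /v1/chat/completions, a body like this should take the multimodal path when the server was started with an --mmproj file; without one, requests carrying image parts are rejected and /props reports no vision modality.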
@ -4175,7 +4352,14 @@ int main(int argc, char ** argv) {
// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &params, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
json data = oaicompat_completion_params_parse(body, params.use_jinja, params.reasoning_format, ctx_server.chat_templates.get());
std::vector<raw_buffer> files; // dummy, unused
json data = oaicompat_completion_params_parse(
body,
params.use_jinja,
params.reasoning_format,
ctx_server.chat_templates.get(),
ctx_server.mctx,
files);
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
@ -4280,7 +4464,7 @@ int main(int argc, char ** argv) {
}
}
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
if (tokens.empty()) {
@ -4300,7 +4484,7 @@ int main(int argc, char ** argv) {
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(tokenized_prompts[i]);
task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
// OAI-compat
task.params.oaicompat = oaicompat;
@ -4394,13 +4578,14 @@ int main(int argc, char ** argv) {
std::unordered_set<int> task_ids;
{
std::vector<server_task> tasks;
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) {
auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
tasks.push_back(std::move(task));
}