llama : rework embeddings logic (#14208)

* llama : rework embeddings logic

ggml-ci

* cont : fix rerank

ggml-ci

* cont : engrish [no ci]

* cont : fix rerank

ggml-ci

* server : support both embeddings and completions with single model

ggml-ci

* cont : avoid embeddings_org

ggml-ci

Author: Georgi Gerganov
Date:   2025-06-16 14:14:00 +03:00 (committed by GitHub)
Commit: d3e64b9f49 (parent: 3ba0d843c6)

16 changed files with 159 additions and 114 deletions


@ -988,10 +988,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
params.tensor_buft_overrides.push_back({nullptr, nullptr});
}
if (params.reranking && params.embedding) {
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}
if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
@ -2747,9 +2743,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
add_opt(common_arg(
{"--reranking", "--rerank"},
string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"),
string_format("enable reranking endpoint on server (default: %s)", "disabled"),
[](common_params & params) {
params.reranking = true;
params.embedding = true;
params.pooling_type = LLAMA_POOLING_TYPE_RANK;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
add_opt(common_arg(
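
For reference, the reworked flag is now a thin alias over the embeddings path. A minimal sketch of the equivalent programmatic setup, using only names that appear in this diff (the model path is a placeholder):

#include "common.h"
#include "llama.h"

int main() {
    llama_backend_init();

    common_params params;
    params.model.path   = "/path/to/reranker.gguf";  // placeholder
    params.embedding    = true;                      // what --embeddings sets
    params.pooling_type = LLAMA_POOLING_TYPE_RANK;   // what --reranking now adds on top

    // model + context init; the rerank vocab checks now run here, keyed on the pooling type
    common_init_result llama_init = common_init_from_params(params);
    const bool ok = llama_init.model && llama_init.context;

    llama_backend_free();
    return ok ? 0 : 1;
}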


@ -897,34 +897,6 @@ struct common_init_result common_init_from_params(common_params & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.reranking) {
bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
ok = false;
}
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
if (!has_eos && !has_sep) {
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
} else if (!has_sep) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}
if (!ok) {
llama_model_free(model);
return iparams;
}
}
auto cparams = common_context_params_to_llama(params);
llama_context * lctx = llama_init_from_model(model, cparams);
@ -966,6 +938,35 @@ struct common_init_result common_init_from_params(common_params & params) {
}
}
if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) {
bool ok = true;
if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
ok = false;
}
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
if (!has_eos && !has_sep) {
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
ok = false;
} else if (!has_eos) {
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
} else if (!has_sep) {
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false;
}
if (!ok) {
llama_free(lctx);
llama_model_free(model);
return iparams;
}
}
// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
llama_adapter_lora_ptr lora;
@ -1143,11 +1144,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full;
if (params.reranking) {
cparams.embeddings = true;
cparams.pooling_type = LLAMA_POOLING_TYPE_RANK;
}
cparams.type_k = params.cache_type_k;
cparams.type_v = params.cache_type_v;


@ -355,7 +355,6 @@ struct common_params {
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
bool reranking = false; // enable reranking support on server
// server params
int32_t port = 8080; // server listens on this network port


@ -41,12 +41,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// add input to batch (this increments n_tokens)
for (int32_t j = 0; j < n_toks; j++) {
common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
common_batch_add(batch, inputs[j], j, { 0 }, true);
}
// clear previous kv_cache values (irrelevant for embeddings)
llama_memory_clear(llama_get_memory(ctx), true);
llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);
// run model
@ -103,7 +102,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_token eos_token = llama_vocab_eos(vocab);
llama_memory_clear(llama_get_memory(ctx), true);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
@ -166,6 +164,8 @@ int main(int argc, char * argv[]) {
llama_model_params mparams = common_model_params_to_llama(params);
llama_context_params cparams = common_context_params_to_llama(params);
cparams.embeddings = true;
llama_backend_init();
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
@ -213,6 +213,8 @@ int main(int argc, char * argv[]) {
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
}
llama_set_embeddings(ctx, false);
// ### Generation ###
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
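
Put together, the pattern the example now uses — one context created in embeddings mode and switched per phase — looks roughly like this (a sketch assembled from the calls above, using the plain llama.h entry points rather than the common_* helpers, and assuming `model` was already loaded):

llama_context_params cparams = llama_context_default_params();
cparams.embeddings = true;                 // context starts in embeddings mode
llama_context * ctx = llama_init_from_model(model, cparams);

// ... encode: non-causal attention, read the embeddings ...

llama_set_embeddings(ctx, false);          // same context, now produce logits
llama_set_causal_attn(ctx, true);
// ... run generation ...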


@ -254,7 +254,10 @@ extern "C" {
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for last token will be returned)
// (if set to NULL:
// - if embeddings: all tokens are output
// - if not: only the last token is output
// )
//
typedef struct llama_batch {
int32_t n_tokens;
@ -262,8 +265,8 @@ extern "C" {
llama_token * token;
float * embd;
llama_pos * pos;
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
} llama_batch;
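
A minimal sketch of the new NULL-logits default (assumes a model and context already exist; llama_batch_get_one() leaves the logits pointer unset):

std::vector<llama_token> tokens = { 1, 2, 3 };   // e.g. a tokenized prompt
llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size()); // batch.logits == NULL

llama_set_embeddings(ctx, true);     // embeddings on  -> every token is output
// llama_set_embeddings(ctx, false); // embeddings off -> only the last token is output

if (llama_decode(ctx, batch) != 0) {
    // handle failure
}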
@ -961,8 +964,7 @@ extern "C" {
// Get the number of threads used for prompt and batch processing (multiple token).
LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
// Set whether the model is in embeddings mode or not
// If true, embeddings will be returned but logits will not
// Set whether the context outputs embeddings or not
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
// Set whether to use causal attention or not


@ -299,7 +299,8 @@ llama_batch_allocr::llama_batch_allocr() {
bool llama_batch_allocr::init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory) {
const llama_memory_i * memory,
bool embd_all) {
clear();
batch = batch_inp;
@ -378,10 +379,31 @@ bool llama_batch_allocr::init(
}
if (!batch.logits) {
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
if (embd_all) {
// return the output for all tokens
output.resize(batch.n_tokens, true);
} else {
// return the output only for the last token
output.resize(batch.n_tokens, false);
output[output.size() - 1] = true;
}
batch.logits = output.data();
} else if (embd_all) {
bool warn = false;
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.logits[i] == 0) {
warn = true;
}
}
if (warn) {
LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
output.resize(batch.n_tokens, true);
batch.logits = output.data();
}
}
//
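
Conversely, a batch that sets per-token output flags while embeddings are required now hits the warning/override path above instead of silently returning partial results. A small hypothetical example (common_batch_add() comes from common.h):

llama_batch batch = llama_batch_init(/*n_tokens*/ 4, /*embd*/ 0, /*n_seq_max*/ 1);
for (int i = 0; i < 4; ++i) {
    // mark only the last token as output, as a logits-oriented caller would
    common_batch_add(batch, /*id*/ 1, /*pos*/ i, { 0 }, /*output*/ i == 3);
}
// with embeddings enabled on the context, llama_batch_allocr::init() receives embd_all == true,
// logs the warning above and overrides the flags so all 4 tokens are output
llama_batch_free(batch);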


@ -88,7 +88,8 @@ public:
bool init(
const llama_batch & batch_inp,
const llama_vocab & vocab,
const llama_memory_i * memory);
const llama_memory_i * memory,
bool embd_all);
const llama_batch & get_batch() const;


@ -728,7 +728,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
// note: during encode, we always pass the full sequence starting from pos = 0
if (!batch_allocr->init(batch_inp, model.vocab, nullptr)) {
if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@ -894,7 +894,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
return -1;
}
if (!batch_allocr->init(batch_inp, model.vocab, memory.get())) {
// when computing embeddings, all tokens are output
const bool embd_all = cparams.embeddings;
if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
@ -911,12 +914,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
if (embd_pooled) {
if (embd_all) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@ -945,7 +945,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
llama_memory_state_ptr mstate;
while (true) {
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
if (!mstate) {
return -2;
}
@ -1058,7 +1058,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
//}
auto * t_logits = cparams.embeddings ? nullptr : res->get_logits();
auto * t_logits = res->get_logits();
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
if (t_embd && res->get_embd_pooled()) {
@ -1222,9 +1222,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead
bool has_logits = !cparams.embeddings;
bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
bool has_logits = true;
bool has_embd = cparams.embeddings;
// TODO: hacky enc-dec support
if (model.arch == LLM_ARCH_T5) {
@ -2044,14 +2043,11 @@ void llama_context::opt_epoch_iter(
n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
embd_seq.clear();
uint32_t n_outputs_all = n_tokens_all;
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
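
The upshot of has_logits no longer being tied to cparams.embeddings: after a decode with embeddings enabled, both output buffers can be read from the same context. A rough sketch (assumes a generative model with an output head and pooling type NONE):

llama_set_embeddings(ctx, true);
if (llama_decode(ctx, batch) == 0) {
    const int32_t i_last = batch.n_tokens - 1;
    const float * logits = llama_get_logits_ith(ctx, i_last);      // no longer forced to NULL
    const float * embd   = llama_get_embeddings_ith(ctx, i_last);  // per-token embedding
    // ... use logits and embd ...
}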


@ -359,9 +359,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
std::vector<llama_ubatch> ubatches;
@ -369,8 +367,8 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch;
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
if (embd_all) {
// if all tokens are output, split by sequence
ubatch = sbatch.split_seq(n_ubatch);
} else {
ubatch = sbatch.split_equal(n_ubatch);


@ -32,7 +32,7 @@ public:
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;
llama_memory_state_ptr init_full() override;


@ -95,8 +95,8 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
return kv_swa->seq_pos_max(seq_id);
}
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
GGML_UNUSED(embd_pooled);
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
GGML_UNUSED(embd_all);
// first try simple split
do {


@ -34,7 +34,7 @@ public:
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;
llama_memory_state_ptr init_full() override;


@ -310,8 +310,8 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_memory_state_ptr llama_kv_cache_unified::init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) {
GGML_UNUSED(embd_pooled);
bool embd_all) {
GGML_UNUSED(embd_all);
do {
auto sbatch = llama_sbatch(batch, hparams.n_embd, true);


@ -59,7 +59,7 @@ public:
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) override;
bool embd_all) override;
llama_memory_state_ptr init_full() override;


@ -73,7 +73,7 @@ struct llama_memory_i {
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled) = 0;
bool embd_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;


@ -88,6 +88,26 @@ enum error_type {
ERROR_TYPE_NOT_SUPPORTED, // custom error
};
static bool server_task_type_need_embd(server_task_type task_type) {
switch (task_type) {
case SERVER_TASK_TYPE_EMBEDDING:
case SERVER_TASK_TYPE_RERANK:
return true;
default:
return false;
}
}
static bool server_task_type_need_logits(server_task_type task_type) {
switch (task_type) {
case SERVER_TASK_TYPE_COMPLETION:
case SERVER_TASK_TYPE_INFILL:
return true;
default:
return false;
}
}
struct slot_params {
bool stream = true;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
@ -1330,13 +1350,16 @@ struct server_slot {
n_draft_accepted = 0;
}
bool is_non_causal() const {
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
bool need_embd() const {
return server_task_type_need_embd(task_type);
}
bool need_logits() const {
return server_task_type_need_logits(task_type);
}
bool can_batch_with(server_slot & other_slot) const {
return is_non_causal() == other_slot.is_non_causal()
&& are_lora_equal(lora, other_slot.lora);
return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora);
}
bool has_budget(const common_params & global_params) {
@ -1480,7 +1503,6 @@ struct server_slot {
{"n_ctx", n_ctx},
{"speculative", can_speculate()},
{"is_processing", is_processing()},
{"non_causal", is_non_causal()},
{"params", params.to_json()},
{"prompt", prompt_tokens.detokenize(ctx, true)},
{"next_token",
@ -1907,6 +1929,14 @@ struct server_context {
llama_batch_free(batch);
}
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
// also we cannot split if the pooling would require any past tokens
bool can_split() const {
return
!llama_get_embeddings(ctx) ||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
}
bool load_model(const common_params & params) {
SRV_INF("loading model '%s'\n", params.model.path.c_str());
@ -2730,6 +2760,7 @@ struct server_context {
queue_tasks.defer(std::move(task));
break;
}
if (slot->is_processing()) {
// if requested slot is unavailable, we defer this task for processing later
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
@ -3092,7 +3123,14 @@ struct server_context {
continue;
}
if (slot.is_non_causal()) {
// TODO: support memory-less logits computation
if (slot.need_logits() && !llama_get_memory(ctx)) {
slot.release();
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
continue;
}
if (!can_split()) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@ -3227,8 +3265,7 @@ struct server_context {
}
if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
// we have to evaluate at least 1 token to generate logits.
SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
SLT_WRN(slot, "need to evaluate at least 1 token for each active slot, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
slot.n_past--;
}
@ -3236,8 +3273,7 @@ struct server_context {
slot.n_prompt_tokens_processed = 0;
}
// non-causal tasks require to fit the entire prompt in the physical batch
if (slot.is_non_causal()) {
if (!can_split()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
@ -3259,8 +3295,7 @@ struct server_context {
slot.cache_tokens.keep_first(slot.n_past);
// check if we should process the image
if (slot.n_past < slot.n_prompt_tokens
&& slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
if (slot.n_past < slot.n_prompt_tokens && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
// process the image
int32_t new_n_past;
int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
@ -3291,8 +3326,8 @@ struct server_context {
break; // end of text chunk
}
// without pooling, we want to output the embeddings for all the tokens in the batch
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
// embedding requires all tokens in the batch to be output
const bool need_embd = server_task_type_need_embd(slot.task_type);
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
slot.cache_tokens.push_back(cur_tok);
@ -3346,17 +3381,15 @@ struct server_context {
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
if (slot_batched) {
// make sure we're in the right embedding mode
llama_set_embeddings(ctx, slot_batched->is_non_causal());
// apply lora, only need to do it once per batch
common_set_adapter_lora(ctx, slot_batched->lora);
}
const bool do_encode = (params_base.embedding || params_base.reranking);
llama_set_embeddings(ctx, slot_batched->need_embd());
}
// pad the batch so that batch.n_tokens >= n_slots
// TODO: temporary workaround for https://github.com/ggml-org/llama.cpp/issues/13689
if (do_encode) {
if (slot_batched->need_embd()) {
const int n_slots = slots.size();
if (batch.n_tokens < n_slots) {
@ -3378,8 +3411,11 @@ struct server_context {
SRV_WRN("adding %d dummy tokens to the batch, seq_id = %d\n", n_add, seq_id);
for (int j = 0; j < n_add; ++j) {
common_batch_add(batch, 0, j, { seq_id }, false);
common_batch_add(batch, 0, j, { seq_id }, true);
}
slots[seq_id].cache_tokens.clear();
llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1);
}
}
@ -4174,11 +4210,6 @@ int main(int argc, char ** argv) {
oaicompat_type oaicompat) -> void {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
auto completion_id = gen_chatcmplid();
std::unordered_set<int> task_ids;
try {
@ -4433,12 +4464,8 @@ int main(int argc, char ** argv) {
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
};
const auto handle_chat_completions = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
LOG_DBG("request: %s\n", req.body.c_str());
if (ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
auto body = json::parse(req.body);
std::vector<raw_buffer> files;
@ -4566,13 +4593,18 @@ int main(int argc, char ** argv) {
};
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
const json body = json::parse(req.body);
if (!ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
return;
}
const json body = json::parse(req.body);
// for the shape of input/content, see tokenize_input_prompts()
json prompt;
if (body.count("input") != 0) {
@ -4662,8 +4694,8 @@ int main(int argc, char ** argv) {
};
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}