mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 20:05:20 +00:00
server : add TEI API format for /rerank endpoint (#11942)
* server : add TEI API format for /rerank endpoint * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * fix * also gitignore examples/server/*.gz.hpp --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -98,6 +98,7 @@ examples/server/*.css.hpp
|
|||||||
examples/server/*.html.hpp
|
examples/server/*.html.hpp
|
||||||
examples/server/*.js.hpp
|
examples/server/*.js.hpp
|
||||||
examples/server/*.mjs.hpp
|
examples/server/*.mjs.hpp
|
||||||
|
examples/server/*.gz.hpp
|
||||||
!build_64.sh
|
!build_64.sh
|
||||||
!examples/*.bat
|
!examples/*.bat
|
||||||
!examples/*/*.kts
|
!examples/*/*.kts
|
||||||
|
@ -4263,6 +4263,11 @@ int main(int argc, char ** argv) {
|
|||||||
// return;
|
// return;
|
||||||
//}
|
//}
|
||||||
|
|
||||||
|
// if true, use TEI API format, otherwise use Jina API format
|
||||||
|
// Jina: https://jina.ai/reranker/
|
||||||
|
// TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
|
||||||
|
bool is_tei_format = body.contains("texts");
|
||||||
|
|
||||||
json query;
|
json query;
|
||||||
if (body.count("query") == 1) {
|
if (body.count("query") == 1) {
|
||||||
query = body.at("query");
|
query = body.at("query");
|
||||||
@ -4275,7 +4280,8 @@ int main(int argc, char ** argv) {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
|
std::vector<std::string> documents = json_value(body, "documents",
|
||||||
|
json_value(body, "texts", std::vector<std::string>()));
|
||||||
if (documents.empty()) {
|
if (documents.empty()) {
|
||||||
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
|
||||||
return;
|
return;
|
||||||
@ -4320,7 +4326,12 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// write JSON response
|
// write JSON response
|
||||||
json root = format_response_rerank(body, responses);
|
json root = format_response_rerank(
|
||||||
|
body,
|
||||||
|
responses,
|
||||||
|
is_tei_format,
|
||||||
|
documents);
|
||||||
|
|
||||||
res_ok(res, root);
|
res_ok(res, root);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -10,17 +10,20 @@ def create_server():
|
|||||||
server = ServerPreset.jina_reranker_tiny()
|
server = ServerPreset.jina_reranker_tiny()
|
||||||
|
|
||||||
|
|
||||||
|
TEST_DOCUMENTS = [
|
||||||
|
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
|
||||||
|
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
|
||||||
|
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
|
||||||
|
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_rerank():
|
def test_rerank():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/rerank", data={
|
res = server.make_request("POST", "/rerank", data={
|
||||||
"query": "Machine learning is",
|
"query": "Machine learning is",
|
||||||
"documents": [
|
"documents": TEST_DOCUMENTS,
|
||||||
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
|
|
||||||
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
|
|
||||||
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
|
|
||||||
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
|
|
||||||
]
|
|
||||||
})
|
})
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert len(res.body["results"]) == 4
|
assert len(res.body["results"]) == 4
|
||||||
@ -38,6 +41,29 @@ def test_rerank():
|
|||||||
assert least_relevant["index"] == 3
|
assert least_relevant["index"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_rerank_tei_format():
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/rerank", data={
|
||||||
|
"query": "Machine learning is",
|
||||||
|
"texts": TEST_DOCUMENTS,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert len(res.body) == 4
|
||||||
|
|
||||||
|
most_relevant = res.body[0]
|
||||||
|
least_relevant = res.body[0]
|
||||||
|
for doc in res.body:
|
||||||
|
if doc["score"] > most_relevant["score"]:
|
||||||
|
most_relevant = doc
|
||||||
|
if doc["score"] < least_relevant["score"]:
|
||||||
|
least_relevant = doc
|
||||||
|
|
||||||
|
assert most_relevant["score"] > least_relevant["score"]
|
||||||
|
assert most_relevant["index"] == 2
|
||||||
|
assert least_relevant["index"] == 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("documents", [
|
@pytest.mark.parametrize("documents", [
|
||||||
[],
|
[],
|
||||||
None,
|
None,
|
||||||
|
@ -737,29 +737,51 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static json format_response_rerank(const json & request, const json & ranks) {
|
static json format_response_rerank(
|
||||||
json data = json::array();
|
const json & request,
|
||||||
int32_t n_tokens = 0;
|
const json & ranks,
|
||||||
int i = 0;
|
bool is_tei_format,
|
||||||
for (const auto & rank : ranks) {
|
std::vector<std::string> & texts) {
|
||||||
data.push_back(json{
|
json res;
|
||||||
{"index", i++},
|
if (is_tei_format) {
|
||||||
{"relevance_score", json_value(rank, "score", 0.0)},
|
// TEI response format
|
||||||
});
|
res = json::array();
|
||||||
|
bool return_text = json_value(request, "return_text", false);
|
||||||
|
for (const auto & rank : ranks) {
|
||||||
|
int index = json_value(rank, "index", 0);
|
||||||
|
json elem = json{
|
||||||
|
{"index", index},
|
||||||
|
{"score", json_value(rank, "score", 0.0)},
|
||||||
|
};
|
||||||
|
if (return_text) {
|
||||||
|
elem["text"] = std::move(texts[index]);
|
||||||
|
}
|
||||||
|
res.push_back(elem);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Jina response format
|
||||||
|
json results = json::array();
|
||||||
|
int32_t n_tokens = 0;
|
||||||
|
for (const auto & rank : ranks) {
|
||||||
|
results.push_back(json{
|
||||||
|
{"index", json_value(rank, "index", 0)},
|
||||||
|
{"relevance_score", json_value(rank, "score", 0.0)},
|
||||||
|
});
|
||||||
|
|
||||||
n_tokens += json_value(rank, "tokens_evaluated", 0);
|
n_tokens += json_value(rank, "tokens_evaluated", 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
res = json{
|
||||||
|
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
||||||
|
{"object", "list"},
|
||||||
|
{"usage", json{
|
||||||
|
{"prompt_tokens", n_tokens},
|
||||||
|
{"total_tokens", n_tokens}
|
||||||
|
}},
|
||||||
|
{"results", results}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
json res = json {
|
|
||||||
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
|
|
||||||
{"object", "list"},
|
|
||||||
{"usage", json {
|
|
||||||
{"prompt_tokens", n_tokens},
|
|
||||||
{"total_tokens", n_tokens}
|
|
||||||
}},
|
|
||||||
{"results", data}
|
|
||||||
};
|
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user