mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-14 20:29:41 -04:00
server : add openai-style logit_bias support (#14946)
Signed-off-by: Lukas Straub <lukasstraub2@web.de>
This commit is contained in:
@@ -444,6 +444,39 @@ def test_n_probs_post_sampling():
|
||||
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
|
||||
def test_logit_bias(tokenize, openai_style):
|
||||
global server
|
||||
server.start()
|
||||
|
||||
exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
|
||||
|
||||
logit_bias = []
|
||||
if tokenize:
|
||||
res = server.make_request("POST", "/tokenize", data={
|
||||
"content": " " + " ".join(exclude) + " ",
|
||||
})
|
||||
assert res.status_code == 200
|
||||
tokens = res.body["tokens"]
|
||||
logit_bias = [[tok, -100] for tok in tokens]
|
||||
|
||||
else:
|
||||
logit_bias = [[" " + tok + " ", -100] for tok in exclude]
|
||||
|
||||
if openai_style:
|
||||
logit_bias = {el[0]: -100 for el in logit_bias}
|
||||
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 64,
|
||||
"prompt": "What is the best book",
|
||||
"logit_bias": logit_bias,
|
||||
"temperature": 0.0
|
||||
})
|
||||
assert res.status_code == 200
|
||||
output_text = res.body["content"]
|
||||
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
|
||||
|
||||
|
||||
def test_cancel_request():
|
||||
global server
|
||||
server.n_ctx = 4096
|
||||
|
Reference in New Issue
Block a user