server : add openai-style logit_bias support (#14946)

Signed-off-by: Lukas Straub <lukasstraub2@web.de>
This commit is contained in:
Lukas Straub
2025-07-31 14:08:23 +02:00
committed by GitHub
parent 8a4a856277
commit a9f77a8be3
4 changed files with 90 additions and 1 deletions

View File

@@ -351,3 +351,32 @@ def test_logprobs_stream():
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text
def test_logit_bias():
global server
server.start()
exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
res = server.make_request("POST", "/tokenize", data={
"content": " " + " ".join(exclude) + " ",
})
assert res.status_code == 200
tokens = res.body["tokens"]
logit_bias = {tok: -100 for tok in tokens}
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=64,
logit_bias=logit_bias
)
output_text = res.choices[0].message.content
assert output_text
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)

View File

@@ -444,6 +444,39 @@ def test_n_probs_post_sampling():
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
def test_logit_bias(tokenize, openai_style):
global server
server.start()
exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
logit_bias = []
if tokenize:
res = server.make_request("POST", "/tokenize", data={
"content": " " + " ".join(exclude) + " ",
})
assert res.status_code == 200
tokens = res.body["tokens"]
logit_bias = [[tok, -100] for tok in tokens]
else:
logit_bias = [[" " + tok + " ", -100] for tok in exclude]
if openai_style:
logit_bias = {el[0]: -100 for el in logit_bias}
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": "What is the best book",
"logit_bias": logit_bias,
"temperature": 0.0
})
assert res.status_code == 200
output_text = res.body["content"]
assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
def test_cancel_request():
global server
server.n_ctx = 4096