From a9f77a8be348deeb11fb3d54d412bf583003c90d Mon Sep 17 00:00:00 2001
From: Lukas Straub
Date: Thu, 31 Jul 2025 14:08:23 +0200
Subject: [PATCH] server : add openai-style logit_bias support (#14946)

Signed-off-by: Lukas Straub
---
 tools/server/README.md                         |  2 +-
 tools/server/server.cpp                        | 27 +++++++++++++++
 .../server/tests/unit/test_chat_completion.py | 29 ++++++++++++++++
 tools/server/tests/unit/test_completion.py    | 33 +++++++++++++++++++
 4 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/tools/server/README.md b/tools/server/README.md
index f3f4caed8..87cef7573 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -469,7 +469,7 @@ These words will not be included in the completion, so make sure to add them to
 
 `ignore_eos`: Ignore end of stream token and continue generating. Default: `false`
 
-`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`
+`logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. For compatibility with the OpenAI API, a JSON object {"<token>": bias, ...} can also be passed. Default: `[]`
 
 `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`
 
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 2e4c40af7..9a9b04447 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -473,6 +473,33 @@ struct server_task {
                     }
                 }
             }
+        } else if (logit_bias != data.end() && logit_bias->is_object()) {
+            const int n_vocab = llama_vocab_n_tokens(vocab);
+            for (const auto & el : logit_bias->items()) {
+                float bias;
+                const auto & key = el.key();
+                const auto & value = el.value();
+                if (value.is_number()) {
+                    bias = value.get<float>();
+                } else if (value.is_boolean() && !value.get<bool>()) {
+                    bias = -INFINITY;
+                } else {
+                    continue;
+                }
+
+                char *end;
+                llama_token tok = strtol(key.c_str(), &end, 10);
+                if (*end == 0) {
+                    if (tok >= 0 && tok < n_vocab) {
+                        params.sampling.logit_bias.push_back({tok, bias});
+                    }
+                } else {
+                    auto toks = common_tokenize(vocab, key, false);
+                    for (auto tok : toks) {
+                        params.sampling.logit_bias.push_back({tok, bias});
+                    }
+                }
+            }
         }
 
         params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index 7ee9a1651..6c6f64f5e 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -351,3 +351,32 @@ def test_logprobs_stream():
                 assert token.top_logprobs is not None
                 assert len(token.top_logprobs) > 0
     assert aggregated_text == output_text
+
+
+def test_logit_bias():
+    global server
+    server.start()
+
+    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
+
+    res = server.make_request("POST", "/tokenize", data={
+        "content": " " + " ".join(exclude) + " ",
+    })
+    assert res.status_code == 200
+    tokens = res.body["tokens"]
+    logit_bias = {tok: -100 for tok in tokens}
+
+    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
+    res = client.chat.completions.create(
+        model="gpt-3.5-turbo-instruct",
+        temperature=0.0,
+        messages=[
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        max_tokens=64,
+        logit_bias=logit_bias
+    )
+    output_text = res.choices[0].message.content
+    assert output_text
+    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py
index f6909e9ae..be3a0052c 100644
--- a/tools/server/tests/unit/test_completion.py
+++ b/tools/server/tests/unit/test_completion.py
@@ -444,6 +444,39 @@ def test_n_probs_post_sampling():
         assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
 
 
+@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])
+def test_logit_bias(tokenize, openai_style):
+    global server
+    server.start()
+
+    exclude = ["i", "I", "the", "The", "to", "a", "an", "be", "is", "was", "but", "But", "and", "And", "so", "So", "you", "You", "he", "He", "she", "She", "we", "We", "they", "They", "it", "It", "his", "His", "her", "Her", "book", "Book"]
+
+    logit_bias = []
+    if tokenize:
+        res = server.make_request("POST", "/tokenize", data={
+            "content": " " + " ".join(exclude) + " ",
+        })
+        assert res.status_code == 200
+        tokens = res.body["tokens"]
+        logit_bias = [[tok, -100] for tok in tokens]
+
+    else:
+        logit_bias = [[" " + tok + " ", -100] for tok in exclude]
+
+    if openai_style:
+        logit_bias = {el[0]: -100 for el in logit_bias}
+
+    res = server.make_request("POST", "/completion", data={
+        "n_predict": 64,
+        "prompt": "What is the best book",
+        "logit_bias": logit_bias,
+        "temperature": 0.0
+    })
+    assert res.status_code == 200
+    output_text = res.body["content"]
+    assert all(output_text.find(" " + tok + " ") == -1 for tok in exclude)
+
+
 def test_cancel_request():
     global server
     server.n_ctx = 4096
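
Usage sketch of the two request shapes the server accepts after this change. This assumes a llama-server instance is already listening on http://localhost:8080 (adjust host/port as needed) and that the `requests` package is installed; the token id 15043 comes from the README's 'Hello' example and depends on the loaded model's vocabulary.

import requests

BASE = "http://localhost:8080"  # assumed local llama-server instance

# Native form: a list of [token-id-or-string, bias] pairs.
res = requests.post(f"{BASE}/completion", json={
    "prompt": "Say hello to the",
    "n_predict": 16,
    "logit_bias": [[15043, -100], ["Hello, World!", -0.5]],
})
print(res.json()["content"])

# OpenAI-style form: a JSON object mapping token ids (or token strings) to biases.
res = requests.post(f"{BASE}/completion", json={
    "prompt": "Say hello to the",
    "n_predict": 16,
    "logit_bias": {"15043": -100, "Hello, World!": -0.5},
})
print(res.json()["content"])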