diff --git a/common/common.cpp b/common/common.cpp index e4e71ad13..262b67998 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1005,15 +1005,21 @@ struct common_init_result common_init_from_params(common_params & params) { params.sampling.ignore_eos = false; } - if (params.sampling.ignore_eos) { - for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { - if (llama_vocab_is_eog(vocab, i)) { - LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); - params.sampling.logit_bias.push_back({i, -INFINITY}); - } + // initialize once + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); + params.sampling.logit_bias_eog.push_back({i, -INFINITY}); } } + if (params.sampling.ignore_eos) { + // add EOG biases to the active set of logit biases + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); + } + if (params.sampling.penalty_last_n == -1) { LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); params.sampling.penalty_last_n = llama_n_ctx(lctx); diff --git a/common/common.h b/common/common.h index a5abe3285..248e82d87 100644 --- a/common/common.h +++ b/common/common.h @@ -177,7 +177,8 @@ struct common_params_sampling { std::vector grammar_triggers; // optional triggers (for lazy grammars) std::set preserved_tokens; - std::vector logit_bias; // logit biases to apply + std::vector logit_bias; // logit biases to apply + std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens // print the parameters into a string std::string print() const; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 1e7d64a28..0afe213af 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -473,12 +473,9 @@ struct server_task { params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); if (params.sampling.ignore_eos) { - for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { - if (llama_vocab_is_eog(vocab, i)) { - //SRV_DBG("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(ctx, i).c_str(), -INFINITY); - params.sampling.logit_bias.push_back({i, -INFINITY}); - } - } + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); } } @@ -1906,7 +1903,6 @@ struct server_context { bool clean_kv_cache = true; bool add_bos_token = true; - bool has_eos_token = false; int32_t n_ctx; // total context for all clients / slots @@ -1965,7 +1961,6 @@ struct server_context { n_ctx = llama_n_ctx(ctx); add_bos_token = llama_vocab_get_add_bos(vocab); - has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) { SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());