diff --git a/common/speculative.cpp b/common/speculative.cpp
index 318e96ea3..b1fff27a5 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -252,11 +252,6 @@ llama_tokens common_speculative_gen_draft(
         // add drafted token for each sequence
         const llama_token id = cur_p->data[0].id;
 
-        // only collect very high-confidence draft tokens
-        if (cur_p->data[0].p < params.p_min) {
-            break;
-        }
-
         common_sampler_accept(smpl, id, true);
 
         result.push_back(id);
@@ -265,6 +260,11 @@ llama_tokens common_speculative_gen_draft(
             break;
         }
 
+        // only collect very high-confidence draft tokens
+        if (cur_p->data[0].p < params.p_min) {
+            break;
+        }
+
         common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
 
         // evaluate the drafted tokens on the draft model
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 809bfe0e3..2306dc26f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -274,7 +274,7 @@ struct server_task {
             params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
 
             params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
-            params.speculative.n_min = std::max(params.speculative.n_min, 2);
+            params.speculative.n_min = std::max(params.speculative.n_min, 0);
             params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
             // Use OpenAI API logprobs only if n_probs wasn't provided
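
Illustrative sketch, not part of the patch: moving the p_min check below common_sampler_accept / result.push_back means the first drafted token whose probability falls below p_min is now still kept in the draft before the loop stops, instead of being dropped. The standalone program below (the Candidate struct, token ids, and probabilities are made up for illustration; the real code reads cur_p->data[0] from the draft sampler) compares the two orderings.

    // sketch.cpp — simulate the effect of reordering the p_min check
    #include <cstdio>
    #include <vector>

    struct Candidate { int id; float p; };   // stand-in for cur_p->data[0]

    // old order: reject a low-confidence token before accepting it
    static std::vector<int> draft_old(const std::vector<Candidate> & cand, float p_min) {
        std::vector<int> result;
        for (const auto & c : cand) {
            if (c.p < p_min) {
                break;              // token is dropped
            }
            result.push_back(c.id); // accept
        }
        return result;
    }

    // new order: accept the token first, then stop drafting further tokens
    static std::vector<int> draft_new(const std::vector<Candidate> & cand, float p_min) {
        std::vector<int> result;
        for (const auto & c : cand) {
            result.push_back(c.id); // accept
            if (c.p < p_min) {
                break;              // token is kept, but drafting stops here
            }
        }
        return result;
    }

    int main() {
        const std::vector<Candidate> cand = { {11, 0.98f}, {42, 0.40f}, {7, 0.95f} };
        const float p_min = 0.75f;

        std::printf("old: %zu tokens, new: %zu tokens\n",
                    draft_old(cand, p_min).size(),  // 1: token 42 is dropped
                    draft_new(cand, p_min).size()); // 2: token 42 is still drafted
        return 0;
    }

In the patched loop the kept sub-threshold token is not fed back via common_batch_add, so it ends the draft; the server.cpp change is consistent with this, relaxing the lower clamp on speculative.n_min from 2 to 0 so shorter (including single-token) drafts are permitted.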