adapt common

Xuan Son Nguyen
2025-03-01 12:12:52 +01:00
parent a1b1dea33b
commit f0ffd81130
2 changed files with 18 additions and 16 deletions


@@ -13,7 +13,7 @@ struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;
-llama_batch batch;
+llama_batch_ext_ptr batch;
llama_tokens prompt;
};
@@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init(
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
-/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+/* .batch = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
/* .prompt = */ {},
};
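Note: the batch member changes from a plain llama_batch value to a llama_batch_ext_ptr smart pointer. A minimal sketch of how such a wrapper could be defined, assuming it is a std::unique_ptr alias and that a llama_batch_ext_free() counterpart to llama_batch_ext_init() exists (neither definition appears in this diff):

#include <memory>

struct llama_batch_ext;                        // opaque handle from the C API
void llama_batch_ext_free(llama_batch_ext *);  // assumed counterpart to llama_batch_ext_init

struct llama_batch_ext_deleter {
    void operator()(llama_batch_ext * b) const { llama_batch_ext_free(b); }
};

using llama_batch_ext_ptr = std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>;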
@@ -68,8 +68,6 @@ void common_speculative_free(struct common_speculative * spec) {
common_sampler_free(spec->smpl);
-llama_batch_free(spec->batch);
delete spec;
}
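With the RAII wrapper in place, the explicit llama_batch_free() call becomes unnecessary: when delete spec runs, the batch member's destructor releases the underlying llama_batch_ext (via the assumed deleter sketched above), which is why this hunk only removes lines.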
@@ -150,6 +148,8 @@ llama_tokens common_speculative_gen_draft(
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
+const llama_seq_id seq_id = 0;
// reuse as much as possible from the old draft context
// ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
for (int i = 0; i < (int) prompt.size(); ++i) {
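The new seq_id local exists because llama_batch_ext_add_text_token() takes sequence ids as a (pointer, count) pair instead of the old { 0 } initializer list. A hypothetical sketch of how the same call could target two sequences at once (token and pos stand in for whatever is being appended):

llama_seq_id seq_ids[2] = { 0, 1 };
llama_batch_ext_add_text_token(batch.get(), token, pos, seq_ids, 2, false);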
@@ -205,40 +205,40 @@ llama_tokens common_speculative_gen_draft(
}
// prepare a batch to evaluate any new tokens in the prompt
-common_batch_clear(batch);
+llama_batch_ext_clear(batch.get());
for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
-common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
+llama_batch_ext_add_text_token(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false);
prompt.push_back(prompt_tgt[i]);
}
// we should rarely end-up here during normal decoding
-if (batch.n_tokens > 0) {
+if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
-llama_decode(ctx, batch);
+llama_decode_ext(ctx, batch.get());
}
const llama_pos n_past = prompt.size();
LOG_DBG("%s: n_past = %d\n", __func__, n_past);
-common_batch_clear(batch);
-common_batch_add (batch, id_last, n_past, { 0 }, true);
+llama_batch_ext_clear(batch.get());
+llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true);
prompt.push_back(id_last);
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
-llama_decode(ctx, batch);
+llama_decode_ext(ctx, batch.get());
common_sampler_reset(smpl);
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
-common_batch_clear(batch);
+llama_batch_ext_clear(batch.get());
common_sampler_sample(smpl, ctx, 0, true);
@ -265,10 +265,10 @@ llama_tokens common_speculative_gen_draft(
break;
}
-common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
+llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true);
// evaluate the drafted tokens on the draft model
-llama_decode(ctx, batch);
+llama_decode_ext(ctx, batch.get());
prompt.push_back(id);
}
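Taken together, the migration follows one lifecycle: init, clear, add tokens, check count, decode. A self-contained sketch using only the calls visible in this diff (ctx and tokens are assumed to be a valid llama_context * and a llama_tokens vector set up elsewhere):

// create a reusable batch sized to the context's batch capacity, 1 sequence max
llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
const llama_seq_id seq_id = 0;

llama_batch_ext_clear(batch.get());
for (size_t i = 0; i < tokens.size(); ++i) {
    // final argument: request logits only for the last token
    llama_batch_ext_add_text_token(batch.get(), tokens[i], (llama_pos) i, &seq_id, 1, i + 1 == tokens.size());
}

if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
    llama_decode_ext(ctx, batch.get());
}
// no explicit free: the llama_batch_ext_ptr destructor releases the batch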