apply to the rest

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions

View File

@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
}
common_init();
#ifdef 0
if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
@ -199,8 +199,8 @@ int main(int argc, char ** argv) {
drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
}
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft);
const auto t_dec_start = ggml_time_us();
@ -441,12 +441,13 @@ int main(int argc, char ** argv) {
drafts[0].dists.push_back(std::vector<llama_token_data>());
drafts[0].i_batch_tgt.push_back(0);
common_batch_clear(batch_dft);
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
llama_batch_ext_clear(batch_dft);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true);
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
llama_decode(ctx_dft, batch_dft);
llama_decode_ext(ctx_dft, batch_dft);
++n_past_dft;
}
@ -471,8 +472,9 @@ int main(int argc, char ** argv) {
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
llama_batch_ext_clear(batch_tgt);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
@ -640,5 +642,6 @@ int main(int argc, char ** argv) {
LOG("\n\n");
#endif
return 0;
}