speculative : adapt to new llama API

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-03-18 16:10:26 +02:00
parent dc4bb64290
commit 7a3c178d78

View File

@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
}
common_init();
#if 0
if (params.speculative.model.empty()) {
LOG_ERR("%s: --model-draft is required\n", __func__);
return 1;
@ -169,9 +168,9 @@ int main(int argc, char ** argv) {
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
llama_decode_ext(ctx_tgt, batch0);
llama_decode_ext(ctx_tgt, batch1);
llama_decode_ext(ctx_dft, batch2);
llama_decode_ext(ctx_tgt, batch0.get());
llama_decode_ext(ctx_tgt, batch1.get());
llama_decode_ext(ctx_dft, batch2.get());
const auto t_enc_end = ggml_time_us();
@ -338,7 +337,7 @@ int main(int argc, char ** argv) {
if (i == s) {
continue;
}
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
// synchronize active status for sequences with the same drafted token
drafts[i].active = drafts[i].active && accept;
if (!drafts[i].active) {
@ -446,7 +445,7 @@ int main(int argc, char ** argv) {
llama_batch_ext_clear(batch_dft);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true);
llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
@ -475,13 +474,19 @@ int main(int argc, char ** argv) {
drafts[0].drafting = true;
drafts[0].i_batch_dft = 0;
llama_batch_ext_clear(batch_tgt);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
struct batch_info {
llama_token id;
llama_pos pos;
std::vector<llama_seq_id> seq_id;
};
std::vector<batch_info> batch_tgt_data;
batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} });
// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
batch_dft.n_tokens = 0;
llama_batch_ext_clear(batch_dft);
for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].skip = false;
@ -512,11 +517,10 @@ int main(int argc, char ** argv) {
llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
// all previous tokens from this branch are now also part of the new branch
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
if (batch_tgt.seq_id[t][p] == s) {
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
batch_tgt.n_seq_id[t]++;
for (int t = 0; t < (int) batch_tgt_data.size(); ++t) {
for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) {
if (batch_tgt_data[t].seq_id[p] == s) {
batch_tgt_data[t].seq_id.push_back(n_seq_cur);
break;
}
}
@ -558,32 +562,30 @@ int main(int argc, char ** argv) {
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
// add unique drafted tokens to the target batch
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
drafts[s].i_batch_tgt.push_back(batch_tgt_data.size());
common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }});
// add the token to the batch for batched decoding with the draft model
drafts[s].i_batch_dft = batch_dft.n_tokens;
drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true);
common_batch_add(batch_dft, id, n_past_cur, { s }, true);
if (batch_tgt.n_tokens > n_draft) {
if (batch_tgt_data.size() > (size_t) n_draft) {
drafts[s].drafting = false;
}
}
}
// no sequence is drafting anymore
if (batch_dft.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch_dft) == 0) {
break;
}
// evaluate the drafted tokens on the draft model
llama_decode(ctx_dft, batch_dft);
llama_decode_ext(ctx_dft, batch_dft);
++n_past_cur;
++n_drafted;
if (batch_tgt.n_tokens > n_draft) {
if (batch_tgt_data.size() > (size_t) n_draft) {
break;
}
}
@ -595,8 +597,15 @@ int main(int argc, char ** argv) {
llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
}
llama_batch_ext_clear(batch_tgt);
for (int i = 0; i < (int) batch_tgt_data.size(); ++i) {
const auto & data = batch_tgt_data[i];
llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true);
}
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
llama_decode_ext(ctx_tgt, batch_tgt);
++n_past_tgt;
}
@ -639,12 +648,12 @@ int main(int argc, char ** argv) {
common_sampler_free(drafts[s].smpl);
}
llama_batch_free(batch_dft);
llama_batch_ext_free(batch_dft);
llama_batch_ext_free(batch_tgt);
llama_backend_free();
LOG("\n\n");
#endif
return 0;
}