diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 4d987332a..ff5eceb64 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -45,7 +45,6 @@ int main(int argc, char ** argv) { } common_init(); -#if 0 if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; @@ -169,9 +168,9 @@ int main(int argc, char ** argv) { llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true)); llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true)); - llama_decode_ext(ctx_tgt, batch0); - llama_decode_ext(ctx_tgt, batch1); - llama_decode_ext(ctx_dft, batch2); + llama_decode_ext(ctx_tgt, batch0.get()); + llama_decode_ext(ctx_tgt, batch1.get()); + llama_decode_ext(ctx_dft, batch2.get()); const auto t_enc_end = ggml_time_us(); @@ -338,7 +337,7 @@ int main(int argc, char ** argv) { if (i == s) { continue; } - if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { + if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { // synchronize active status for sequences with the same drafted token drafts[i].active = drafts[i].active && accept; if (!drafts[i].active) { @@ -446,7 +445,7 @@ int main(int argc, char ** argv) { llama_batch_ext_clear(batch_dft); llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true); + llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); @@ -475,13 +474,19 @@ int main(int argc, char ** argv) { drafts[0].drafting = true; drafts[0].i_batch_dft = 0; - llama_batch_ext_clear(batch_tgt); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch_tgt, 
drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true); + struct batch_info { + llama_token id; + llama_pos pos; + std::vector<llama_seq_id> seq_id; + }; + + std::vector<batch_info> batch_tgt_data; + + batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} }); // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { - batch_dft.n_tokens = 0; + llama_batch_ext_clear(batch_dft); for (int s = 0; s < n_seq_dft; ++s) { drafts[s].skip = false; @@ -512,11 +517,10 @@ int main(int argc, char ** argv) { llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch - for (int t = 0; t < batch_tgt.n_tokens; ++t) { - for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) { - if (batch_tgt.seq_id[t][p] == s) { - batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur; - batch_tgt.n_seq_id[t]++; + for (int t = 0; t < (int) batch_tgt_data.size(); ++t) { + for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) { + if (batch_tgt_data[t].seq_id[p] == s) { + batch_tgt_data[t].seq_id.push_back(n_seq_cur); break; } } @@ -558,32 +562,30 @@ int main(int argc, char ** argv) { drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size}); // add unique drafted tokens to the target batch - drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); + drafts[s].i_batch_tgt.push_back(batch_tgt_data.size()); - common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); + batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }}); // add the token to the batch for batched decoding with the draft model - drafts[s].i_batch_dft = batch_dft.n_tokens; + drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true); - common_batch_add(batch_dft, id, n_past_cur, { s }, true); - - if (batch_tgt.n_tokens > n_draft) { + if (batch_tgt_data.size() > (size_t) n_draft) { drafts[s].drafting = false; } } } // no sequence is drafting anymore - if (batch_dft.n_tokens == 0) { + if 
(llama_batch_ext_get_n_tokens(batch_dft) == 0) { break; } // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch_dft); + llama_decode_ext(ctx_dft, batch_dft); ++n_past_cur; ++n_drafted; - if (batch_tgt.n_tokens > n_draft) { + if (batch_tgt_data.size() > (size_t) n_draft) { break; } } @@ -595,8 +597,15 @@ int main(int argc, char ** argv) { llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1); } + llama_batch_ext_clear(batch_tgt); + for (int i = 0; i < (int) batch_tgt_data.size(); ++i) { + const auto & data = batch_tgt_data[i]; + + llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true); + } + // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); - llama_decode(ctx_tgt, batch_tgt); + llama_decode_ext(ctx_tgt, batch_tgt); ++n_past_tgt; } @@ -639,12 +648,12 @@ int main(int argc, char ** argv) { common_sampler_free(drafts[s].smpl); } - llama_batch_free(batch_dft); + llama_batch_ext_free(batch_dft); + llama_batch_ext_free(batch_tgt); llama_backend_free(); LOG("\n\n"); -#endif return 0; }