mirror of https://github.com/ggml-org/llama.cpp.git
speculative : adapt to new llama API
ggml-ci
examples/speculative/speculative.cpp

@@ -45,7 +45,6 @@ int main(int argc, char ** argv) {
     }
 
     common_init();
-#if 0
     if (params.speculative.model.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
@@ -169,9 +168,9 @@ int main(int argc, char ** argv) {
     llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
     llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
     llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
-    llama_decode_ext(ctx_tgt, batch0);
-    llama_decode_ext(ctx_tgt, batch1);
-    llama_decode_ext(ctx_dft, batch2);
+    llama_decode_ext(ctx_tgt, batch0.get());
+    llama_decode_ext(ctx_tgt, batch1.get());
+    llama_decode_ext(ctx_dft, batch2.get());
 
     const auto t_enc_end = ggml_time_us();
 
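
Note: the `.get()` change above suggests that `llama_batch_ext_ptr` is a unique_ptr-style owning wrapper around the C handle, so the raw pointer must be unwrapped before it reaches the C-style `llama_decode_ext`. A minimal sketch of that pattern, with `my_batch` and friends as hypothetical stand-ins for the llama.cpp types:

    #include <cstdio>
    #include <memory>

    // hypothetical stand-ins for the C API handle and its functions
    struct my_batch { int n_tokens = 0; };
    static my_batch * my_batch_init()             { return new my_batch(); }
    static void       my_batch_free(my_batch * b) { delete b; }
    static void       my_decode(my_batch * b)     { std::printf("decode %d token(s)\n", b->n_tokens); }

    // custom deleter so unique_ptr releases the handle through the C API
    struct my_batch_deleter {
        void operator()(my_batch * b) const { my_batch_free(b); }
    };
    using my_batch_ptr = std::unique_ptr<my_batch, my_batch_deleter>;

    int main() {
        my_batch_ptr batch(my_batch_init());
        my_decode(batch.get()); // C functions take the raw pointer, hence .get()
        return 0;               // batch is freed automatically when the wrapper dies
    }
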
@@ -338,7 +337,7 @@ int main(int argc, char ** argv) {
                 if (i == s) {
                     continue;
                 }
-                if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
+                if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
                     // synchronize active status for sequences with the same drafted token
                     drafts[i].active = drafts[i].active && accept;
                     if (!drafts[i].active) {
@@ -446,7 +445,7 @@ int main(int argc, char ** argv) {
 
         llama_batch_ext_clear(batch_dft);
         llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true);
+        llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true);
 
         llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
         // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
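
Note: judging by the surrounding calls (both the clear and the KV-cache removal target the draft side), this one-line change appears to fix a slip from an earlier commit in the branch: the accepted token was being queued into batch_tgt at the target position instead of into batch_dft at the draft position.
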
@@ -475,13 +474,19 @@ int main(int argc, char ** argv) {
         drafts[0].drafting = true;
         drafts[0].i_batch_dft = 0;
 
-        llama_batch_ext_clear(batch_tgt);
-        llama_seq_id seq_id = 0;
-        llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
+        struct batch_info {
+            llama_token id;
+            llama_pos   pos;
+            std::vector<llama_seq_id> seq_id;
+        };
+
+        std::vector<batch_info> batch_tgt_data;
+
+        batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} });
 
         // sample n_draft tokens from the draft model using tree-based sampling
         for (int i = 0; i < n_draft; ++i) {
-            batch_dft.n_tokens = 0;
+            llama_batch_ext_clear(batch_dft);
 
             for (int s = 0; s < n_seq_dft; ++s) {
                 drafts[s].skip = false;
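
Note: this hunk swaps direct mutation of the target batch for a plain record per drafted token; the real llama_batch_ext is only filled in right before the target decode (see the hunk at old line 595 below). A self-contained sketch of that accumulate-then-build idea, where the type aliases are stand-ins for llama_token, llama_pos and llama_seq_id:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    using llama_token_t = int32_t; // stand-in for llama_token
    using llama_pos_t   = int32_t; // stand-in for llama_pos
    using llama_seq_t   = int32_t; // stand-in for llama_seq_id

    // mirrors the batch_info record from the diff: one drafted token,
    // its position, and every sequence that contains it
    struct batch_info {
        llama_token_t id;
        llama_pos_t   pos;
        std::vector<llama_seq_t> seq_id;
    };

    int main() {
        std::vector<batch_info> batch_tgt_data;

        // root token of the draft tree, owned by sequence 0
        batch_tgt_data.push_back({ 42, 10, {0} });

        // when a branch splits, the shared prefix simply gains another seq id
        batch_tgt_data[0].seq_id.push_back(1);

        for (const auto & data : batch_tgt_data) {
            std::printf("token %d at pos %d in %zu sequence(s)\n",
                        data.id, data.pos, data.seq_id.size());
        }
        return 0;
    }
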
@@ -512,11 +517,10 @@ int main(int argc, char ** argv) {
                     llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
 
                     // all previous tokens from this branch are now also part of the new branch
-                    for (int t = 0; t < batch_tgt.n_tokens; ++t) {
-                        for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
-                            if (batch_tgt.seq_id[t][p] == s) {
-                                batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
-                                batch_tgt.n_seq_id[t]++;
+                    for (int t = 0; t < (int) batch_tgt_data.size(); ++t) {
+                        for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) {
+                            if (batch_tgt_data[t].seq_id[p] == s) {
+                                batch_tgt_data[t].seq_id.push_back(n_seq_cur);
                                 break;
                             }
                         }
@@ -558,32 +562,30 @@ int main(int argc, char ** argv) {
                     drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
 
                     // add unique drafted tokens to the target batch
-                    drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
+                    drafts[s].i_batch_tgt.push_back(batch_tgt_data.size());
 
-                    common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
+                    batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }});
 
                     // add the token to the batch for batched decoding with the draft model
-                    drafts[s].i_batch_dft = batch_dft.n_tokens;
-
-                    common_batch_add(batch_dft, id, n_past_cur, { s }, true);
+                    drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true);
 
-                    if (batch_tgt.n_tokens > n_draft) {
+                    if (batch_tgt_data.size() > (size_t) n_draft) {
                         drafts[s].drafting = false;
                     }
                 }
             }
 
             // no sequence is drafting anymore
-            if (batch_dft.n_tokens == 0) {
+            if (llama_batch_ext_get_n_tokens(batch_dft) == 0) {
                 break;
             }
 
             // evaluate the drafted tokens on the draft model
-            llama_decode(ctx_dft, batch_dft);
+            llama_decode_ext(ctx_dft, batch_dft);
             ++n_past_cur;
             ++n_drafted;
 
-            if (batch_tgt.n_tokens > n_draft) {
+            if (batch_tgt_data.size() > (size_t) n_draft) {
                 break;
             }
         }
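
Note: two API details carry this hunk. The diff uses the return value of llama_batch_ext_add_text as the token's index within the batch, collapsing the old two-step "record n_tokens, then add", and the token count is now read through llama_batch_ext_get_n_tokens rather than a public n_tokens field. A sketch of that same contract with a hypothetical opaque batch type:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // hypothetical opaque batch: no public n_tokens field to poke at
    class ext_batch {
    public:
        // returns the index the token was stored at, as the diff relies on
        int32_t add_text(int32_t token, int32_t pos) {
            ids.push_back(token);
            positions.push_back(pos);
            return (int32_t) ids.size() - 1;
        }
        int32_t get_n_tokens() const { return (int32_t) ids.size(); }
        void    clear() { ids.clear(); positions.clear(); }

    private:
        std::vector<int32_t> ids;
        std::vector<int32_t> positions;
    };

    int main() {
        ext_batch batch_dft;

        // one call both queues the token and reports where it landed
        const int32_t i_batch_dft = batch_dft.add_text(/*token=*/7, /*pos=*/0);
        std::printf("stored at index %d, batch now holds %d token(s)\n",
                    i_batch_dft, batch_dft.get_n_tokens());

        // the "no sequence is drafting anymore" check from the diff
        if (batch_dft.get_n_tokens() == 0) {
            return 1;
        }
        return 0;
    }
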
@@ -595,8 +597,15 @@ int main(int argc, char ** argv) {
             llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1);
         }
 
+        llama_batch_ext_clear(batch_tgt);
+        for (int i = 0; i < (int) batch_tgt_data.size(); ++i) {
+            const auto & data = batch_tgt_data[i];
+
+            llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true);
+        }
+
         // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
-        llama_decode(ctx_tgt, batch_tgt);
+        llama_decode_ext(ctx_tgt, batch_tgt);
         ++n_past_tgt;
     }
 
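
Note: this is where the deferred build happens: batch_tgt is cleared and refilled from batch_tgt_data, passing each record's full seq-id list to llama_batch_ext_add_text in one call. A sketch of the rebuild loop, with add_text as a hypothetical stand-in taking the same pointer-plus-count pair:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct batch_info {
        int32_t id;
        int32_t pos;
        std::vector<int32_t> seq_id;
    };

    // stand-in for llama_batch_ext_add_text: token, position, and a
    // pointer+count pair naming every sequence that shares the token
    static void add_text(int32_t id, int32_t pos, const int32_t * seq, size_t n_seq) {
        std::printf("token %d at pos %d -> %zu sequence(s), first seq %d\n",
                    id, pos, n_seq, seq[0]);
    }

    int main() {
        const std::vector<batch_info> batch_tgt_data = {
            { 42, 10, {0}    },
            { 43, 11, {0, 1} }, // prefix shared by two draft branches
        };

        // the rebuild loop from the hunk above, minus the llama-specific types
        for (const auto & data : batch_tgt_data) {
            add_text(data.id, data.pos, data.seq_id.data(), data.seq_id.size());
        }
        return 0;
    }
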
@@ -639,12 +648,12 @@ int main(int argc, char ** argv) {
         common_sampler_free(drafts[s].smpl);
     }
 
-    llama_batch_free(batch_dft);
+    llama_batch_ext_free(batch_dft);
+    llama_batch_ext_free(batch_tgt);
 
     llama_backend_free();
 
     LOG("\n\n");
 
-#endif
     return 0;
 }
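
Note: both working batches are now released through llama_batch_ext_free, where the old code freed only batch_dft via llama_batch_free; the prompt batches batch0/batch1/batch2 appear to need no explicit free, since the llama_batch_ext_ptr wrappers from the first hunk own them. The removed #endif pairs with the #if 0 removed at the top of the file, re-enabling the example now that it compiles against the new API.
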