mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 20:25:20 +00:00
apply to the rest
This commit is contained in:
@ -45,7 +45,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
#ifdef 0
|
||||
if (params.speculative.model.empty()) {
|
||||
LOG_ERR("%s: --model-draft is required\n", __func__);
|
||||
return 1;
|
||||
@ -199,8 +199,8 @@ int main(int argc, char ** argv) {
|
||||
drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
|
||||
}
|
||||
|
||||
llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
|
||||
llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1);
|
||||
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft);
|
||||
|
||||
const auto t_dec_start = ggml_time_us();
|
||||
|
||||
@ -441,12 +441,13 @@ int main(int argc, char ** argv) {
|
||||
drafts[0].dists.push_back(std::vector<llama_token_data>());
|
||||
drafts[0].i_batch_tgt.push_back(0);
|
||||
|
||||
common_batch_clear(batch_dft);
|
||||
common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
||||
llama_batch_ext_clear(batch_dft);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true);
|
||||
|
||||
llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
// LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
||||
llama_decode(ctx_dft, batch_dft);
|
||||
llama_decode_ext(ctx_dft, batch_dft);
|
||||
|
||||
++n_past_dft;
|
||||
}
|
||||
@ -471,8 +472,9 @@ int main(int argc, char ** argv) {
|
||||
drafts[0].drafting = true;
|
||||
drafts[0].i_batch_dft = 0;
|
||||
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
|
||||
llama_batch_ext_clear(batch_tgt);
|
||||
llama_seq_id seq_id = 0;
|
||||
llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true);
|
||||
|
||||
// sample n_draft tokens from the draft model using tree-based sampling
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
@ -640,5 +642,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
LOG("\n\n");
|
||||
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
Reference in New Issue
Block a user