apply to the rest

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions

View File

@@ -132,7 +132,7 @@ int main(int argc, char ** argv) {
struct common_speculative * spec = common_speculative_init(ctx_dft);
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1);
const auto t_enc_end = ggml_time_us();
@@ -151,8 +151,9 @@ int main(int argc, char ** argv) {
//LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
// always have a token to evaluate from before - id_last
common_batch_clear(batch_tgt);
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
llama_batch_ext_clear(batch_tgt);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true);
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
{
@@ -162,12 +163,12 @@ int main(int argc, char ** argv) {
}
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true);
}
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
llama_decode(ctx_tgt, batch_tgt);
llama_decode_ext(ctx_tgt, batch_tgt);
}
// sample from the full target batch and return the accepted tokens based on the target sampler
@@ -253,6 +254,7 @@ int main(int argc, char ** argv) {
common_sampler_free(smpl);
common_speculative_free(spec);
llama_batch_ext_free(batch_tgt);
llama_backend_free();
LOG("\n\n");