apply to the rest

Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions

@@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
     // seq_id == 0           : the current input token
     // seq_id [1, W]         : tokens from the past N - 1 Jacobi iterations
     // seq_id [W + 1, W + G] : verification n-grams
-    llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
+    llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);

     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model, params.sampling);
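The hunk above swaps the value-type llama_batch for an opaque llama_batch_ext * handle, and the old init's middle argument (embd, the 0 here) is gone from the ext call. A minimal lifecycle sketch, assuming only the llama_batch_ext_* calls that appear in this commit (the API is experimental, so signatures are inferred from the diff, not from a stable header):

    #include "llama.h"

    // Sketch: allocate an ext batch sized like the lookahead example,
    // use it, then release it through the matching ext call.
    // W = lookahead window, G = max verification n-grams (as in lookahead.cpp).
    llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1);

    // ... clear / fill / decode, as the following hunks show ...

    llama_batch_ext_free(batch);   // replaces llama_batch_free(batch)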
@@ -204,10 +204,10 @@ int main(int argc, char ** argv) {
         //  V  V  V  V  V  V
         // id
         {
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch);

             // current token - first token of the first level
-            common_batch_add(batch, id, n_past, seq_id_all, true);
+            llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true);

             // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
             {
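common_batch_add accepted a std::vector<llama_seq_id> by reference; llama_batch_ext_add_text instead takes a raw pointer plus an explicit count, so call sites now pass .data() and .size(). A sketch of broadcasting the current token to every sequence, assuming seq_id_all is the vector of ids 0..W+G that the example builds earlier (not shown in this hunk):

    // one token, shared by all W + G + 1 sequences, logits requested
    std::vector<llama_seq_id> seq_id_all(W + G + 1);
    for (int i = 0; i < W + G + 1; i++) {
        seq_id_all[i] = i;
    }
    llama_batch_ext_add_text(batch, id, n_past,
            seq_id_all.data(), seq_id_all.size(), true);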
@@ -230,9 +230,10 @@ int main(int argc, char ** argv) {
                     const llama_token t = ngrams_observed.tokens[idx + j];

                     ngrams_cur[g].tokens [j + 1] = t;
-                    ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
+                    ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);

-                    common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
+                    llama_seq_id seq_id = W + 1 + g;
+                    llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);
                 }
             }
         }
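Two consequences of the opaque handle show up in this hunk: the field read batch.n_tokens becomes the accessor llama_batch_ext_get_n_tokens(batch), and a single-sequence add passes the address of a local llama_seq_id with a count of 1 where common_batch_add took a braced one-element list. The same lines with explanatory comments (a restatement of the hunk, not new API):

    // record the batch position of this n-gram token so its logits
    // can be located after decoding
    ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch);

    // verification n-grams occupy seq ids W + 1 .. W + G
    llama_seq_id seq_id = W + 1 + g;
    llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true);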
@@ -244,18 +245,20 @@ int main(int argc, char ** argv) {
                     seq_id_look[j] = i + j + 1;
                 }

-                common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
+                llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i,
+                        seq_id_look.data(), seq_id_look.size(), false);
             }

             // fill the rest of the levels
             for (int j = 1; j < N - 1; j++) {
                 for (int i = 0; i < W; i++) {
-                    common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
+                    llama_seq_id seq_id = i + 1;
+                    llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2);
                 }
             }
         }

-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch) != 0) {
             LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__);
             return 1;
         }
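Submitting the batch follows the same rename pattern: llama_decode becomes llama_decode_ext and takes the opaque handle directly. Assuming it keeps llama_decode's convention of returning 0 on success (the error path here is unchanged by the commit), the decode step looks like:

    // non-zero return usually means the KV cache could not fit the batch
    if (llama_decode_ext(ctx, batch) != 0) {
        LOG_ERR("%s: llama_decode_ext failed - increase KV cache size\n", __func__);
        return 1;
    }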
@@ -475,7 +478,7 @@ int main(int argc, char ** argv) {
     llama_kv_cache_view_free(&kvc_view);

-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);

     llama_backend_free();
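Teardown is the one remaining touch point: the batch handle must now be released with llama_batch_ext_free instead of llama_batch_free. A short cleanup sketch, keeping the order this example already uses:

    llama_kv_cache_view_free(&kvc_view);   // debug view of the KV cache
    llama_batch_ext_free(batch);           // was: llama_batch_free(batch)
    llama_backend_free();                  // global backend cleanup, last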