diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index bf39134d0..992df2b51 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1427,7 +1427,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
+static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model = llama_get_model(ctx);
@@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true));
         llama_decode_ext(ctx, batch.get());
         n_processed += n_tokens;
     }
@@ -1452,7 +1452,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
     llama_synchronize(ctx);
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
+static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model = llama_get_model(ctx);
@@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, n_past + i, 0, true));
         llama_decode_ext(ctx, batch.get());
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
@@ -1610,13 +1610,13 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
             }
             //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
         }
         if (t.n_gen > 0) {
             if (params.progress) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
             }
-            test_gen(ctx, 1, t.n_threads);
+            test_gen(ctx, 1, 0, t.n_threads);
         }
 
         for (int i = 0; i < params.reps; i++) {
@@ -1629,14 +1629,14 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
             }
             if (t.n_gen > 0) {
                 if (params.progress) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_gen(ctx, t.n_gen, t.n_threads);
+                test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
             }
 
             uint64_t t_ns = get_time_ns() - t_start;
diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp
index 88d0b1606..827755968 100644
--- a/examples/lookahead/lookahead.cpp
+++ b/examples/lookahead/lookahead.cpp
@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0, true));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, n_input - 1, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());
 
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index 0e885fa41..07e57afcb 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -91,8 +91,8 @@ int main(int argc, char ** argv){
     const auto t_enc_start = ggml_time_us();
 
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0, true));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, n_input - 1, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());
 
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 6ab35133b..2ff4e24c1 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
         result1 += next_token_str;
 
         llama_batch_ext_clear(batch);
-        llama_seq_id seq_id = 1;
+        llama_seq_id seq_id = 0;
         llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
 
         if (llama_decode_ext(ctx2, batch)) {
@@ -215,7 +215,7 @@ int main(int argc, char ** argv) {
         result2 += next_token_str;
 
         llama_batch_ext_clear(batch);
-        llama_seq_id seq_id = 1;
+        llama_seq_id seq_id = 1; // seq 1 instead of 0
         llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
 
         if (llama_decode_ext(ctx3, batch)) {
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 4aea9dbdc..26009a5ae 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
             // prepare the next batch with the sampled token
            llama_batch_ext_clear(batch);
             llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
+            llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true);
 
             n_decode += 1;
         }
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 2812846d1..4d987332a 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -166,9 +166,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, 0, 0, true));
-    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input    , 0, 0, true));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(),           1, n_input - 1, 0, true));
+    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input    , 0, 0, true));
     llama_decode_ext(ctx_tgt, batch0);
     llama_decode_ext(ctx_tgt, batch1);
     llama_decode_ext(ctx_dft, batch2);
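
The common thread in these hunks is that a batch must carry the absolute position of its first token in the sequence (its KV-cache position) instead of a hard-coded 0: llama-bench threads an explicit n_past through test_prompt/test_gen, the lookahead/lookup/speculative examples place the final prompt token at n_input - 1, and simple.cpp passes n_pos to llama_batch_ext_add_text. The following is a minimal sketch, not part of the patch, of that position bookkeeping in a caller. It assumes the llama_batch_ext_* signatures exactly as used in the hunks above (llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last)) plus the llama_batch_ext_ptr RAII wrapper; the helper name eval_prompt_then_generate is hypothetical.

    #include "llama.h"

    #include <vector>

    // Sketch only: thread the absolute sequence position through every batch.
    static void eval_prompt_then_generate(llama_context * ctx, std::vector<llama_token> & prompt, int n_gen) {
        const int n_prompt = (int) prompt.size();

        // the prompt occupies positions [0, n_prompt)
        llama_batch_ext_ptr batch_prompt(llama_batch_ext_init_from_text(prompt.data(), n_prompt, /*pos0=*/0, /*seq_id=*/0, true));
        llama_decode_ext(ctx, batch_prompt.get());

        llama_token token = prompt.back(); // placeholder; a real caller samples the next token here

        // each generated token continues the sequence at position n_prompt + i, never at 0
        for (int i = 0; i < n_gen; i++) {
            llama_batch_ext_ptr batch_gen(llama_batch_ext_init_from_text(&token, 1, /*pos0=*/n_prompt + i, /*seq_id=*/0, true));
            llama_decode_ext(ctx, batch_gen.get());
            // ... sample into `token` before the next iteration ...
        }
    }

This mirrors what test_prompt/test_gen now do with n_past + n_processed and n_past + i, and why the benchmark's generation runs start at n_past = t.n_prompt rather than 0.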