fix missing n_past in various places

this is actually a revert of cda0e4b648
This commit is contained in:
Xuan Son Nguyen
2025-03-14 10:47:08 +01:00
parent 32940369d3
commit 07d84fa3c2
6 changed files with 18 additions and 18 deletions

View File

@ -1427,7 +1427,7 @@ struct sql_printer : public printer {
} }
}; };
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) { static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads); llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
for (int i = 1; i < n_tokens; i++) { for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab; tokens[i] = std::rand() % n_vocab;
} }
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true)); llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true));
llama_decode_ext(ctx, batch.get()); llama_decode_ext(ctx, batch.get());
n_processed += n_tokens; n_processed += n_tokens;
} }
@ -1452,7 +1452,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
llama_synchronize(ctx); llama_synchronize(ctx);
} }
static void test_gen(llama_context * ctx, int n_gen, int n_threads) { static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads); llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) { for (int i = 0; i < n_gen; i++) {
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true)); llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, n_past + i, 0, true));
llama_decode_ext(ctx, batch.get()); llama_decode_ext(ctx, batch.get());
llama_synchronize(ctx); llama_synchronize(ctx);
token = std::rand() % n_vocab; token = std::rand() % n_vocab;
@ -1610,13 +1610,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count); fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
} }
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
} }
if (t.n_gen > 0) { if (t.n_gen > 0) {
if (params.progress) { if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count); fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
} }
test_gen(ctx, 1, t.n_threads); test_gen(ctx, 1, 0, t.n_threads);
} }
for (int i = 0; i < params.reps; i++) { for (int i = 0; i < params.reps; i++) {
@ -1629,14 +1629,14 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count, fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
i + 1, params.reps); i + 1, params.reps);
} }
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
} }
if (t.n_gen > 0) { if (t.n_gen > 0) {
if (params.progress) { if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count, fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
i + 1, params.reps); i + 1, params.reps);
} }
test_gen(ctx, t.n_gen, t.n_threads); test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
} }
uint64_t t_ns = get_time_ns() - t_start; uint64_t t_ns = get_time_ns() - t_start;

View File

@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us(); const auto t_enc_start = ggml_time_us();
// eval the prompt // eval the prompt
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_decode_ext(ctx, batch0.get()); llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get()); llama_decode_ext(ctx, batch1.get());

View File

@ -91,8 +91,8 @@ int main(int argc, char ** argv){
const auto t_enc_start = ggml_time_us(); const auto t_enc_start = ggml_time_us();
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_decode_ext(ctx, batch0.get()); llama_decode_ext(ctx, batch0.get());
llama_decode_ext(ctx, batch1.get()); llama_decode_ext(ctx, batch1.get());

View File

@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
result1 += next_token_str; result1 += next_token_str;
llama_batch_ext_clear(batch); llama_batch_ext_clear(batch);
llama_seq_id seq_id = 1; llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode_ext(ctx2, batch)) { if (llama_decode_ext(ctx2, batch)) {
@ -215,7 +215,7 @@ int main(int argc, char ** argv) {
result2 += next_token_str; result2 += next_token_str;
llama_batch_ext_clear(batch); llama_batch_ext_clear(batch);
llama_seq_id seq_id = 1; llama_seq_id seq_id = 1; // seq 1 instead of 0
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode_ext(ctx3, batch)) { if (llama_decode_ext(ctx3, batch)) {

View File

@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
// prepare the next batch with the sampled token // prepare the next batch with the sampled token
llama_batch_ext_clear(batch); llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0; llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true);
n_decode += 1; n_decode += 1;
} }

View File

@ -166,9 +166,9 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us(); const auto t_enc_start = ggml_time_us();
// eval the prompt with both models // eval the prompt with both models
llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true));
llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true)); llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
llama_decode_ext(ctx_tgt, batch0); llama_decode_ext(ctx_tgt, batch0);
llama_decode_ext(ctx_tgt, batch1); llama_decode_ext(ctx_tgt, batch1);
llama_decode_ext(ctx_dft, batch2); llama_decode_ext(ctx_dft, batch2);