fix llama_batch_ext_init_from_text

This commit is contained in:
Xuan Son Nguyen
2025-03-13 23:09:27 +01:00
parent 65f0184517
commit c3dd79007b
18 changed files with 40 additions and 27 deletions

View File

@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0));
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true));
llama_batch_ext_set_output_last(batch.get());
llama_decode_ext(ctx, batch.get());
n_processed += n_tokens;
}
@ -1462,7 +1463,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) {
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0));
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true));
llama_decode_ext(ctx, batch.get());
llama_synchronize(ctx);
token = std::rand() % n_vocab;