apply to the rest

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions

View File

@ -143,7 +143,8 @@ int main(int argc, char ** argv) {
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
llama_batch_ext_set_output_last(batch);
// main loop
@ -151,14 +152,14 @@ int main(int argc, char ** argv) {
int n_decode = 0;
llama_token new_token_id;
for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) {
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
n_pos += batch.n_tokens;
n_pos += llama_batch_ext_get_n_tokens(batch);
// sample the next token
{
@ -180,7 +181,9 @@ int main(int argc, char ** argv) {
fflush(stdout);
// prepare the next batch with the sampled token
batch = llama_batch_get_one(&new_token_id, 1);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
n_decode += 1;
}
@ -198,6 +201,7 @@ int main(int argc, char ** argv) {
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
llama_batch_ext_free(batch);
llama_sampler_free(smpl);
llama_free(ctx);
llama_model_free(model);