apply to the rest

commit 47086fa82d
parent 4aabf4e8f4
Author: Xuan Son Nguyen
Date:   2025-03-13 22:36:27 +01:00

18 changed files with 242 additions and 323 deletions
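The change applied to each file follows one pattern: the plain llama_batch plus the common_batch_* helpers are replaced by the llama_batch_ext API, with the batch held in a llama_batch_ext_ptr smart pointer from llama-cpp.h. A minimal sketch of that pattern, using only the calls visible in the hunks below; the helper function and its parameters are illustrative, not part of the commit:

    #include <vector>

    #include "llama.h"
    #include "llama-cpp.h"   // assumed to declare the llama_batch_ext_ptr smart pointer

    // Illustrative helper (not from the commit): decode `tokens` in chunks of
    // n_batch, mirroring the migrated loops below.
    static int decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, int n_batch) {
        // before: llama_batch batch = llama_batch_init(n_batch, 0, 1);
        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));

        int n_past = 0;
        for (int i = 0; i < (int) tokens.size(); i += n_batch) {
            // before: common_batch_clear(batch);
            llama_batch_ext_clear(batch.get());

            for (int j = 0; j < n_batch && i + j < (int) tokens.size(); j++) {
                // before: common_batch_add(batch, tokens[i + j], n_past++, { 0 }, false);
                llama_seq_id seq_id = 0;
                llama_batch_ext_add_text(batch.get(), tokens[i + j], n_past++, &seq_id, 1, false);
            }

            if (i + n_batch >= (int) tokens.size()) {
                // request logits only for the last token of the last chunk
                // before: batch.logits[batch.n_tokens - 1] = true;
                llama_batch_ext_set_output_last(batch.get());
            }

            // before: llama_decode(ctx, batch) != 0
            if (llama_decode_ext(ctx, batch.get()) != 0) {
                return 1;
            }
        }
        // no llama_batch_free(): the smart pointer frees the batch on scope exit
        return 0;
    }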


@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "llama.h"
+#include "llama-cpp.h"
 #include <cmath>
 #include <cstdio>
@@ -122,7 +123,7 @@ int main(int argc, char ** argv) {
     LOG_INF("prompt tokens: %d\n", n_tokens_all);
     //LOG_INF("prompt: %s\n", params.prompt.c_str());
 
-    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
 
     int n_past = 0;
@@ -140,17 +141,18 @@ int main(int argc, char ** argv) {
             n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
 
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            llama_batch_ext_set_output_last(batch.get());
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch.get()) != 0) {
             LOG_INF("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -174,17 +176,18 @@ int main(int argc, char ** argv) {
         n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
 
-        common_batch_clear(batch);
+        llama_batch_ext_clear(batch.get());
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
-            batch.logits[batch.n_tokens - 1] = true;
+            llama_batch_ext_set_output_last(batch.get());
         }
 
-        if (llama_decode(ctx, batch) != 0) {
+        if (llama_decode_ext(ctx, batch.get()) != 0) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return 1;
         }
@@ -223,7 +226,7 @@ int main(int argc, char ** argv) {
     while (n_cur <= n_len) {
         // sample the next token
         {
-            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
 
             // is it an end of generation?
             if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
@@ -237,16 +240,17 @@ int main(int argc, char ** argv) {
             n_decode += 1;
 
             // prepare the next batch
-            common_batch_clear(batch);
+            llama_batch_ext_clear(batch.get());
 
             // push this new token for next evaluation
-            common_batch_add(batch, new_token_id, n_past++, { 0 }, true);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true);
         }
 
         n_cur += 1;
 
         // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch.get())) {
             LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
@@ -266,8 +270,6 @@ int main(int argc, char ** argv) {
     llama_sampler_free(smpl);
-    llama_batch_free(batch);
 
     llama_free(ctx);
     llama_model_free(model);
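Note on the final hunk: because `batch` is now owned by llama_batch_ext_ptr, the explicit llama_batch_free(batch); in the cleanup section is dropped; the wrapper, declared in the newly included llama-cpp.h, presumably frees the batch when it goes out of scope.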