apply to the rest

This commit is contained in:
Xuan Son Nguyen
2025-03-13 22:36:27 +01:00
parent 4aabf4e8f4
commit 47086fa82d
18 changed files with 242 additions and 323 deletions

View File

@ -48,15 +48,11 @@ int main(int argc, char ** argv) {
auto tokens = common_tokenize(ctx, params.prompt, true);
// prepare the batch
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
}
batch.logits[batch.n_tokens - 1] = true; // generate next token
llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0);
// evaluate prompt
llama_decode(ctx, batch);
n_past += batch.n_tokens;
llama_decode_ext(ctx, batch);
n_past += llama_batch_ext_get_n_tokens(batch);
// save state (rng, logits, embedding and kv_cache) to file
{
@ -83,12 +79,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result0 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@ -135,12 +132,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result1 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {0}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 1;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx2, batch)) {
if (llama_decode_ext(ctx2, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@ -216,12 +214,13 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result2 += next_token_str;
common_batch_clear(batch);
common_batch_add(batch, next_token, n_past, {1}, true);
llama_batch_ext_clear(batch);
llama_seq_id seq_id = 1;
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
if (llama_decode(ctx3, batch)) {
if (llama_decode_ext(ctx3, batch)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_batch_free(batch);
llama_batch_ext_free(batch);
return 1;
}
n_past += 1;
@ -233,7 +232,7 @@ int main(int argc, char ** argv) {
llama_sampler_free(smpl2);
llama_sampler_free(smpl3);
llama_batch_free(batch);
llama_batch_ext_free(batch);
if (result0 != result2) {
fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);