apply to the rest
@@ -48,15 +48,11 @@ int main(int argc, char ** argv) {
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
     // prepare the batch
-    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
-    for (size_t i = 0; i < tokens.size(); i++) {
-        common_batch_add(batch, tokens[i], i, {0}, false);
-    }
-    batch.logits[batch.n_tokens - 1] = true; // generate next token
+    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0);
 
     // evaluate prompt
-    llama_decode(ctx, batch);
-    n_past += batch.n_tokens;
+    llama_decode_ext(ctx, batch);
+    n_past += llama_batch_ext_get_n_tokens(batch);
 
     // save state (rng, logits, embedding and kv_cache) to file
     {
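For reference, this is what the prompt-evaluation path looks like after the hunk above. A minimal sketch, assuming the experimental llama_batch_ext API behaves as its use here suggests (llama_batch_ext_init_from_text appears to take the token array, token count, starting position, and sequence id); the eval_prompt_ext wrapper and its error handling are illustrative, not part of the commit:

// sketch: evaluate a prompt with the new llama_batch_ext API
// (names and argument order taken from the hunk above; the decode
// return-value check is added here for illustration)
#include <cstdio>
#include <string>
#include <vector>
#include "llama.h"
#include "common.h"

static bool eval_prompt_ext(llama_context * ctx, const std::string & prompt, int & n_past) {
    std::vector<llama_token> tokens = common_tokenize(ctx, prompt, true);

    // one call replaces llama_batch_init + the per-token common_batch_add loop;
    // pos0 = 0 and seq_id = 0, as in the hunk above
    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0);

    const bool ok = llama_decode_ext(ctx, batch) == 0;
    if (ok) {
        n_past += llama_batch_ext_get_n_tokens(batch);
    }
    llama_batch_ext_free(batch);
    return ok;
}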
@@ -83,12 +79,13 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;
 
-        common_batch_clear(batch);
-        common_batch_add(batch, next_token, n_past, {0}, true);
+        llama_batch_ext_clear(batch);
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
 
-        if (llama_decode(ctx, batch)) {
+        if (llama_decode_ext(ctx, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
-            llama_batch_free(batch);
+            llama_batch_ext_free(batch);
             return 1;
         }
         n_past += 1;
@@ -135,12 +132,13 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;
 
-        common_batch_clear(batch);
-        common_batch_add(batch, next_token, n_past, {0}, true);
+        llama_batch_ext_clear(batch);
+        llama_seq_id seq_id = 1;
+        llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
 
-        if (llama_decode(ctx2, batch)) {
+        if (llama_decode_ext(ctx2, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
-            llama_batch_free(batch);
+            llama_batch_ext_free(batch);
             return 1;
         }
         n_past += 1;
@@ -216,12 +214,13 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;
 
-        common_batch_clear(batch);
-        common_batch_add(batch, next_token, n_past, {1}, true);
+        llama_batch_ext_clear(batch);
+        llama_seq_id seq_id = 1;
+        llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true);
 
-        if (llama_decode(ctx3, batch)) {
+        if (llama_decode_ext(ctx3, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
-            llama_batch_free(batch);
+            llama_batch_ext_free(batch);
             return 1;
         }
         n_past += 1;
@@ -233,7 +232,7 @@ int main(int argc, char ** argv) {
     llama_sampler_free(smpl2);
     llama_sampler_free(smpl3);
 
-    llama_batch_free(batch);
+    llama_batch_ext_free(batch);
 
     if (result0 != result2) {
         fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
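The three generation loops patched above all share the same per-token pattern: clear the reusable batch, add the sampled token, then decode. A minimal sketch of that step follows, assuming the llama_batch_ext signatures exactly as used in this diff; reading the trailing (&seq_id, 1, true) arguments of llama_batch_ext_add_text as (seq_ids, n_seq_ids, output_logits) is an inference from context, not documented API:

// sketch: one generation step with the new API, mirroring the
// clear -> add_text -> decode_ext sequence in the hunks above
#include <cstdio>
#include "llama.h"

static bool decode_token_ext(llama_context * ctx, llama_batch_ext * batch,
                             llama_token next_token, llama_seq_id seq_id, int & n_past) {
    llama_batch_ext_clear(batch);                                     // reuse one batch across steps
    llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); // pos argument is 0 here, as in the diff

    if (llama_decode_ext(ctx, batch)) {
        fprintf(stderr, "%s : failed to evaluate\n", __func__);
        return false; // caller still owns the batch; release it with llama_batch_ext_free
    }
    n_past += 1;
    return true;
}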