diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index d7faa1472..39026813b 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -595,6 +595,7 @@ class LlamaData {
     std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
     std::list<std::string>          msg_strs;
     std::vector<char>               fmtted;
+    llama_pos                       n_past = 0;

     int init(Opt & opt) {
         model = initialize_model(opt);
@@ -946,7 +947,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     }

     // prepare a batch for the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true));
     llama_token new_token_id;
     while (true) {
         check_context_size(llama_data.context, batch);
@@ -955,6 +956,8 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
             return 1;
         }

+        llama_data.n_past += llama_batch_ext_get_n_tokens(batch.get());
+
         // sample the next token, check is it an end of generation?
         new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
         if (llama_vocab_is_eog(vocab, new_token_id)) {
@@ -969,7 +972,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
         print_word_and_concatenate_to_response(piece, response);

         // prepare the next batch with the sampled token
-        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0, true));
+        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, llama_data.n_past, 0, true));
     }

     printf(LOG_COL_DEFAULT);
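
For context on the pattern this patch establishes: each batch is now created at position `n_past` instead of a hard-coded `0`, and the counter advances by the batch's token count once decoding succeeds, so KV-cache positions stay contiguous from the prompt through every sampled token and across chat turns. Below is a minimal sketch of that bookkeeping, an assumption-laden illustration rather than code from the branch: it uses only the `llama_batch_ext` calls visible in the hunks, the decode step (which the hunks do not show) is abstracted behind a caller-supplied `decode_fn` callback, the trailing `0, true` arguments are forwarded exactly as the patch passes them, and the `llama-cpp.h` include for `llama_batch_ext_ptr` is assumed.

```cpp
// Sketch only: illustrates the n_past bookkeeping introduced by this patch.
// Assumes the in-progress llama_batch_ext API from the same branch
// (llama_batch_ext_init_from_text, llama_batch_ext_get_n_tokens) and that
// llama_batch_ext_ptr is provided by the C++ wrapper header.
#include <functional>
#include <vector>

#include "llama-cpp.h" // assumed to provide llama_token, llama_pos, llama_batch_ext_ptr

// Builds a batch starting at *n_past, lets the caller decode it, and on
// success advances *n_past past the batch. Returns 0 on success, 1 on failure.
static int decode_at_position(std::vector<llama_token> &                    tokens,
                              llama_pos *                                   n_past,
                              const std::function<int(llama_batch_ext *)> & decode_fn) {
    // The last two arguments (presumably a sequence id and an "output last
    // token" flag) are forwarded exactly as the patch does: 0 and true.
    llama_batch_ext_ptr batch(
        llama_batch_ext_init_from_text(tokens.data(), tokens.size(), *n_past, 0, true));

    if (decode_fn(batch.get()) != 0) {
        return 1; // leave n_past untouched so the failed batch can be retried
    }

    // Advance by the number of tokens actually placed in the batch, mirroring
    // the `llama_data.n_past += llama_batch_ext_get_n_tokens(...)` line above.
    *n_past += llama_batch_ext_get_n_tokens(batch.get());
    return 0;
}
```

Keeping the counter on `LlamaData` rather than local to `generate()` is what makes follow-up prompts in the same chat session start at the right position instead of overwriting the cache from position 0.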