diff --git a/common/common.cpp b/common/common.cpp
index ec4bf699a..c7cf66545 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1014,7 +1014,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (llama_model_has_encoder(model)) {
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true));
             llama_encode_ext(lctx, batch.get());
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -1024,7 +1024,7 @@ struct common_init_result common_init_from_params(common_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true));
             llama_decode_ext(lctx, batch.get());
         }
         llama_kv_self_clear(lctx);
diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp
index f5ca61c31..13fa2c442 100644
--- a/examples/cvector-generator/cvector-generator.cpp
+++ b/examples/cvector-generator/cvector-generator.cpp
@@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_self_clear(ctx);
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
+    llama_batch_ext_set_output_last(batch.get());
     if (llama_decode_ext(ctx, batch.get())) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;
diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index 7e600440d..47dfd94d2 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {
 
     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return false;
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 631d0b07d..2c84ab8e7 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -353,7 +353,7 @@ int main(int argc, char ** argv) {
 
             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
             if (llama_decode_ext(ctx, batch.get())) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 730e994b2..6a6ab4ab2 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true));
+        llama_batch_ext_set_output_last(batch.get());
         llama_decode_ext(ctx, batch.get());
         n_processed += n_tokens;
     }
@@ -1462,7 +1463,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true));
         llama_decode_ext(ctx, batch.get());
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index b4b6e63c7..233480354 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
         if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp
index adc3a615f..0740b4b4f 100644
--- a/examples/llava/minicpmv-cli.cpp
+++ b/examples/llava/minicpmv-cli.cpp
@@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0));
+        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
         if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp
index 1e8de9673..88d0b1606 100644
--- a/examples/lookahead/lookahead.cpp
+++ b/examples/lookahead/lookahead.cpp
@@ -92,8 +92,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index a6bf80fdf..0e885fa41 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -91,8 +91,8 @@ int main(int argc, char ** argv){
 
     const auto t_enc_start = ggml_time_us();
 
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
     llama_decode_ext(ctx, batch0.get());
     llama_decode_ext(ctx, batch1.get());
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 2f735c420..8caf1ae3b 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -548,7 +548,7 @@ int main(int argc, char ** argv) {
             int enc_input_size = embd_inp.size();
             llama_token * enc_input_buf = embd_inp.data();
 
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0, true));
             if (llama_decode_ext(ctx, batch.get())) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
@@ -669,7 +669,8 @@ int main(int argc, char ** argv) {
 
             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0));
+            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
+            llama_batch_ext_set_output_last(batch.get());
             if (llama_decode_ext(ctx, batch.get())) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 02cafa9da..d7faa1472 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -946,7 +946,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
     }
 
     // prepare a batch for the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
     llama_token new_token_id;
     while (true) {
         check_context_size(llama_data.context, batch);
@@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
         print_word_and_concatenate_to_response(piece, response);
 
         // prepare the next batch with the sampled token
-        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0));
+        batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0, true));
     }
 
     printf(LOG_COL_DEFAULT);
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index d1cf599b1..6ab35133b 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -48,7 +48,7 @@ int main(int argc, char ** argv) {
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
     // prepare the batch
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0);
+    llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true);
 
     // evaluate prompt
     llama_decode_ext(ctx, batch);
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index cee00ea82..0c2d34d56 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -108,8 +108,11 @@ int main(int argc, char ** argv) {
         }
 
         // prepare a batch for the prompt
-        llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
+        llama_pos n_past = 0;
+        llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
         llama_batch_ext_set_output_last(batch);
+        n_past += llama_batch_ext_get_n_tokens(batch);
+
         llama_token new_token_id;
         while (true) {
             // check if we have enough space in the context to evaluate this batch
@@ -147,7 +150,8 @@ int main(int argc, char ** argv) {
             // prepare the next batch with the sampled token
             llama_batch_ext_clear(batch);
             llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true);
+            llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true);
+            n_past++;
         }
 
     llama_batch_ext_free(batch);
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 7b3ba8d81..9101cc6bb 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
 
     // prepare a batch for the prompt
 
-    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0);
+    llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
     llama_batch_ext_set_output_last(batch);
 
     // main loop
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 2f4a85abd..61b9af2f0 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
 
     // eval the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0));
+    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true));
     llama_decode_ext(ctx_tgt, batch.get());
 
     // note: keep the last token separate!
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 2d44dc82c..2812846d1 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -166,9 +166,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0));
-    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0));
-    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0));
+    llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true));
+    llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true));
+    llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true));
     llama_decode_ext(ctx_tgt, batch0);
     llama_decode_ext(ctx_tgt, batch1);
     llama_decode_ext(ctx_dft, batch2);
diff --git a/include/llama.h b/include/llama.h
index 4521b3a41..5864519fd 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -928,12 +928,14 @@ extern "C" {
     // Same with llama_batch_init, but initializes the batch with the provided text tokens
     // First token will be at position pos0
    // The sequence ID will be fixed to seq_id
+    // If output_last is true, the last token will have output set
     // The batch has to be freed with llama_batch_ext_free()
     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text(
             llama_token * tokens,
             int32_t n_tokens,
             int32_t pos0,
-            int32_t seq_id);
+            int32_t seq_id,
+            bool output_last);
 
     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
     // First token will be at position pos0
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index bae8b37b3..80f1592e9 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -341,11 +341,15 @@ struct llama_batch_ext * llama_batch_ext_init_from_text(
         llama_token * tokens,
         int32_t n_tokens,
         int32_t pos0,
-        int32_t seq_id) {
+        int32_t seq_id,
+        bool output_last) {
     llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
     for (int32_t i = 0; i < n_tokens; i++) {
         llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
     }
+    if (output_last) {
+        llama_batch_ext_set_output_last(batch);
+    }
     return batch;
 }
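For reference, below is a minimal sketch of the call pattern after this change. It is not part of the patch: the helper name `eval_prompt_then_token` and the variables `ctx`, `prompt_tokens`, and `new_token_id` are placeholders assumed to be set up by the caller. It mirrors the updated `simple-chat` example, where `output_last = true` marks the last prompt token for output and the running position is tracked in `n_past`.

```cpp
#include <vector>

#include "llama.h"

// Sketch only (not part of the patch): evaluate a tokenized prompt with the new
// output_last flag, then feed one sampled token at the running n_past position.
// ctx, prompt_tokens and new_token_id are assumed to be prepared by the caller.
static bool eval_prompt_then_token(llama_context * ctx,
                                   std::vector<llama_token> & prompt_tokens,
                                   llama_token new_token_id) {
    llama_pos n_past = 0;

    // output_last = true marks the last prompt token for logits output,
    // replacing a separate llama_batch_ext_set_output_last() call
    llama_batch_ext * batch = llama_batch_ext_init_from_text(
            prompt_tokens.data(), prompt_tokens.size(), n_past, /*seq_id=*/0, /*output_last=*/true);

    if (llama_decode_ext(ctx, batch)) {
        llama_batch_ext_free(batch);
        return false;
    }
    n_past += llama_batch_ext_get_n_tokens(batch);

    // the sampled token is appended at position n_past, with output enabled
    llama_seq_id seq_id = 0;
    llama_batch_ext_clear(batch);
    llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, /*output=*/true);
    n_past++;

    if (llama_decode_ext(ctx, batch)) {
        llama_batch_ext_free(batch);
        return false;
    }

    llama_batch_ext_free(batch);
    return true;
}
```

In short, callers that previously passed `(tokens, n_tokens, pos0, seq_id)` and then called `llama_batch_ext_set_output_last()` can now pass `true` as the final argument instead; passing `false` preserves the old behavior of no output tokens.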