diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp
index b3236ea85..e0b647632 100644
--- a/examples/cvector-generator/cvector-generator.cpp
+++ b/examples/cvector-generator/cvector-generator.cpp
@@ -343,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {

 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_self_clear(ctx);
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
+    auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true);
     if (llama_decode_ext(ctx, batch.get())) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;
diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp
index 47dfd94d2..86a36223d 100644
--- a/examples/eval-callback/eval-callback.cpp
+++ b/examples/eval-callback/eval-callback.cpp
@@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {

     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true));
+    auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true);
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return false;
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 2c84ab8e7..574ef644f 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -353,7 +353,7 @@ int main(int argc, char ** argv) {

             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
+            auto batch = llama_batch_ext_ptr::from_text(&embd[i], n_eval, n_past, 0, true);
             if (llama_decode_ext(ctx, batch.get())) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 992df2b51..0ed841f09 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true));
+        auto batch = llama_batch_ext_ptr::from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true);
         llama_decode_ext(ctx, batch.get());
         n_processed += n_tokens;
     }
@@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;

     for (int i = 0; i < n_gen; i++) {
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, n_past + i, 0, true));
+        auto batch = llama_batch_ext_ptr::from_text(&token, 1, n_past + i, 0, true);
         llama_decode_ext(ctx, batch.get());
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index 233480354..ed4326f87 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_tok
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
+        auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true);
         if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index de967e069..901061ca3 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
             n_eval = n_batch;
         }
         float * embd = image_embed->embed+i*n_embd;
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0));
+        auto batch = llama_batch_ext_ptr::from_embd(embd, n_eval, n_embd, 0, 0);
         if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return false;
diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp
index 0740b4b4f..2a725d384 100644
--- a/examples/llava/minicpmv-cli.cpp
+++ b/examples/llava/minicpmv-cli.cpp
@@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_tok
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true));
+        auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true);
         if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp
index a702ab46a..c655fd7a2 100644
--- a/examples/llava/qwen2vl-cli.cpp
+++ b/examples/llava/qwen2vl-cli.cpp
@@ -67,7 +67,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
         memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));

         float * batch_embd = image_embed->embed+i*n_embd;
-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(batch_embd, n_eval, n_embd, 0, 0));
+        auto batch = llama_batch_ext_ptr::from_embd(batch_embd, n_eval, n_embd, 0, 0);
         llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);

         if (llama_decode_ext(ctx_llama, batch.get())) {
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 8caf1ae3b..1ec5a51aa 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -548,7 +548,7 @@ int main(int argc, char ** argv) {
         int enc_input_size = embd_inp.size();
         llama_token * enc_input_buf = embd_inp.data();

-        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0, true));
+        auto batch = llama_batch_ext_ptr::from_text(enc_input_buf, enc_input_size, 0, 0, true);
         if (llama_decode_ext(ctx, batch.get())) {
             LOG_ERR("%s : failed to eval\n", __func__);
             return 1;
@@ -669,7 +669,7 @@ int main(int argc, char ** argv) {

             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-            llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true));
+            auto batch = llama_batch_ext_ptr::from_text(&embd[i], n_eval, n_past, 0, true);
             llama_batch_ext_set_output_last(batch.get());
             if (llama_decode_ext(ctx, batch.get())) {
                 LOG_ERR("%s : failed to eval\n", __func__);
diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 39026813b..aac2f3900 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -947,7 +947,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     }

     // prepare a batch for the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true));
+    auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true);
     llama_token new_token_id;
     while (true) {
         check_context_size(llama_data.context, batch);
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 61b9af2f0..b15593b15 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);

     // eval the prompt
-    llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true));
+    auto batch = llama_batch_ext_ptr::from_text(inp.data(), inp.size() - 1, 0, 0, true);
     llama_decode_ext(ctx_tgt, batch.get());

     // note: keep the last token separate!
diff --git a/include/llama-cpp.h b/include/llama-cpp.h
index 880a6a5fa..dfced7ef9 100644
--- a/include/llama-cpp.h
+++ b/include/llama-cpp.h
@@ -32,4 +32,24 @@ typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
 typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
 typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
 typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
-typedef std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> llama_batch_ext_ptr;
+
+struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
+    llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>(batch) {}
+
+    // convenience function to create a batch from text tokens, without worrying about manually freeing it
+    static llama_batch_ext_ptr from_text(llama_token * tokens,
+                                         int32_t n_tokens,
+                                         int32_t pos0,
+                                         int32_t seq_id,
+                                         bool    output_last) {
+        return llama_batch_ext_ptr(llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last));
+    }
+
+    static llama_batch_ext_ptr from_embd(float * embd,
+                                         size_t  n_tokens,
+                                         size_t  n_embd,
+                                         int32_t pos0,
+                                         int32_t seq_id) {
+        return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(embd, n_tokens, n_embd, pos0, seq_id));
+    }
+};
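
For reference, a minimal usage sketch of the two new helpers, mirroring the call sites updated above. It is illustrative only: the decode_prompt and decode_embd wrappers are hypothetical names, not part of the patch, and it assumes a context and tokenized prompt (or precomputed embeddings) already exist.

// illustrative sketch, not part of the diff above
#include "llama-cpp.h"

#include <cstdio>
#include <vector>

// Decode a tokenized prompt: sequence 0, positions starting at 0, logits
// requested for the last token. The batch is freed automatically when
// `batch` goes out of scope (RAII via llama_batch_ext_ptr).
static bool decode_prompt(llama_context * ctx, std::vector<llama_token> & tokens) {
    auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true);
    if (llama_decode_ext(ctx, batch.get())) {
        fprintf(stderr, "%s : failed to eval\n", __func__);
        return false;
    }
    return true;
}

// Decode a chunk of precomputed embeddings (e.g. image embeddings), again
// with automatic cleanup of the underlying llama_batch_ext.
static bool decode_embd(llama_context * ctx, float * embd, size_t n_tokens, size_t n_embd) {
    auto batch = llama_batch_ext_ptr::from_embd(embd, n_tokens, n_embd, /*pos0=*/0, /*seq_id=*/0);
    return llama_decode_ext(ctx, batch.get()) == 0;
}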