diff --git a/common/common.cpp b/common/common.cpp
index 92f2c57cc..f8498f01d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1576,35 +1576,6 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
 
 #endif // LLAMA_USE_CURL
 
-//
-// Batch utils
-//
-
-// DEPRECATED
-void common_batch_clear(struct llama_batch & batch) {
-    batch.n_tokens = 0;
-}
-
-// DEPRECATED
-void common_batch_add(
-                 struct llama_batch & batch,
-                        llama_token   id,
-                          llama_pos   pos,
-    const std::vector<llama_seq_id> & seq_ids,
-                               bool   logits) {
-    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
-
-    batch.token   [batch.n_tokens] = id;
-    batch.pos     [batch.n_tokens] = pos;
-    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
-    for (size_t i = 0; i < seq_ids.size(); ++i) {
-        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
-    }
-    batch.logits  [batch.n_tokens] = logits;
-
-    batch.n_tokens++;
-}
-
 //
 // Token utils
 //
diff --git a/common/common.h b/common/common.h
index c223685f2..5fe149ff8 100644
--- a/common/common.h
+++ b/common/common.h
@@ -569,17 +569,6 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
 // Batch utils
 
-// DEPRECATED
-void common_batch_clear(struct llama_batch & batch);
-
-// DEPRECATED
-void common_batch_add(
-                 struct llama_batch & batch,
-                        llama_token   id,
-                          llama_pos   pos,
-    const std::vector<llama_seq_id> & seq_ids,
-                               bool   logits);
-
 // convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
 // this is meant to be temporary
 struct common_batch {
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 9654cd53c..9bf7db399 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -125,7 +125,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
     ctx_params.n_threads       = n_threads;
     ctx_params.n_threads_batch = n_threads;
 
-    llama_context * context = llama_new_context_with_model(model, ctx_params);
+    llama_context * context = llama_init_from_model(model, ctx_params);
 
     if (!context) {
         LOGe("llama_new_context_with_model() returned null)");
@@ -175,7 +175,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
     const auto model = reinterpret_cast<llama_model *>(model_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
 
     const int n_ctx = llama_n_ctx(context);
 
@@ -186,19 +186,20 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
     for (nri = 0; nri < nr; nri++) {
         LOGi("Benchmark prompt processing (pp)");
 
-        common_batch_clear(*batch);
+        llama_batch_ext_clear(batch);
 
         const int n_tokens = pp;
         for (i = 0; i < n_tokens; i++) {
-            common_batch_add(*batch, 0, i, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
         }
 
-        batch->logits[batch->n_tokens - 1] = true;
+        llama_batch_ext_set_output_last(batch);
         llama_kv_self_clear(context);
 
         const auto t_pp_start = ggml_time_us();
 
-        if (llama_decode(context, *batch) != 0) {
-            LOGi("llama_decode() failed during prompt processing");
+        if (llama_decode_ext(context, batch) != 0) {
+            LOGi("llama_decode_ext() failed during prompt processing");
         }
 
         const auto t_pp_end = ggml_time_us();
@@ -210,14 +211,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         const auto t_tg_start = ggml_time_us();
 
         for (i = 0; i < tg; i++) {
-            common_batch_clear(*batch);
+            llama_batch_ext_clear(batch);
             for (j = 0; j < pl; j++) {
-                common_batch_add(*batch, 0, i, { j }, true);
+                llama_seq_id seq_id = j;
+                llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, true);
             }
 
-            LOGi("llama_decode() text generation: %d", i);
-            if (llama_decode(context, *batch) != 0) {
-                LOGi("llama_decode() failed during text generation");
+            LOGi("llama_decode_ext() text generation: %d", i);
+            if (llama_decode_ext(context, batch) != 0) {
+                LOGi("llama_decode_ext() failed during text generation");
             }
         }
 
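Note: the bench hunks above are the whole migration pattern in miniature:
direct field writes into llama_batch become calls through the opaque
llama_batch_ext handle. A minimal sketch of the new decode cycle, using only
the signatures that appear in this diff (llama_batch_ext_clear / _add_text /
_set_output_last, llama_decode_ext); it is an illustration, not a copy of any
file touched here:

    #include <vector>
    #include "llama.h"

    // Decode a single-sequence prompt with the llama_batch_ext API.
    static bool decode_prompt(llama_context * ctx, llama_batch_ext * batch,
                              const std::vector<llama_token> & tokens) {
        llama_batch_ext_clear(batch);               // drop any previously queued tokens

        llama_seq_id seq_id = 0;                    // everything goes to sequence 0
        for (size_t i = 0; i < tokens.size(); ++i) {
            // last argument: whether to produce logits for this token
            llama_batch_ext_add_text(batch, tokens[i], (llama_pos) i, &seq_id, 1, false);
        }
        llama_batch_ext_set_output_last(batch);     // logits for the final token only

        return llama_decode_ext(ctx, batch) == 0;   // nonzero return means failure
    }

llama_batch_ext_set_output_last() is the replacement for the old
batch->logits[batch->n_tokens - 1] = true idiom; the struct fields are no
longer reachable through the opaque handle.
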
@@ -272,32 +274,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 extern "C"
 JNIEXPORT jlong JNICALL
 Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-
-    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-
-    llama_batch *batch = new llama_batch {
-        0,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-    };
-
-    if (embd) {
-        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
-    } else {
-        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
-    }
-
-    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
-    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
-    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
-        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
+    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max);
 
     return reinterpret_cast<jlong>(batch);
 }
@@ -305,9 +282,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
-    //llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    delete batch;
+    llama_batch_ext_free(reinterpret_cast<llama_batch_ext *>(batch_pointer));
 }
 
 extern "C"
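Note: the two hunks above delete the hand-rolled heap copy of
llama_batch_init in favor of a library-owned handle, so the JNI ownership
round-trip reduces to the sketch below (same signatures as in the diff;
n_tokens and n_seq_max are the jint arguments of new_batch):

    // Allocation side: the opaque handle crosses the JNI boundary as a jlong.
    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max);
    jlong handle = reinterpret_cast<jlong>(batch);      // returned to Kotlin/Java

    // Release side: the library frees everything it allocated.
    llama_batch_ext_free(reinterpret_cast<llama_batch_ext *>(handle));

One behavioral difference worth flagging: the deleted code honored the embd
parameter (allocating an embeddings buffer instead of token ids), while the
llama_batch_ext_init call as written here takes only n_tokens and n_seq_max,
so the embd argument of new_batch is now unused.
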
@@ -355,7 +330,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
 
     const auto text = env->GetStringUTFChars(jtext, 0);
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
 
     bool parse_special = (format_chat == JNI_TRUE);
     const auto tokens_list = common_tokenize(context, text, true, parse_special);
@@ -363,7 +338,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     auto n_ctx = llama_n_ctx(context);
     auto n_kv_req = tokens_list.size() + n_len;
 
-    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
+    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", (int) n_len, (int) n_ctx, (int) n_kv_req);
 
     if (n_kv_req > n_ctx) {
         LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
@@ -373,23 +348,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
         LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
     }
 
-    common_batch_clear(*batch);
+    llama_batch_ext_clear(batch);
 
     // evaluate the initial prompt
    for (auto i = 0; i < tokens_list.size(); i++) {
-        common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(batch, tokens_list[i], i, &seq_id, 1, false);
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    llama_batch_ext_set_output_last(batch);
 
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() failed");
+    if (llama_decode_ext(context, batch) != 0) {
+        LOGe("llama_decode_ext() failed");
     }
 
     env->ReleaseStringUTFChars(jtext, text);
 
-    return batch->n_tokens;
+    return llama_batch_ext_get_n_tokens(batch);
 }
 
 extern "C"
@@ -404,7 +380,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
     const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
     const auto model = llama_get_model(context);
     const auto vocab = llama_model_get_vocab(model);
@@ -433,13 +409,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         new_token = env->NewStringUTF("");
     }
 
-    common_batch_clear(*batch);
-    common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+    llama_batch_ext_clear(batch);
+    llama_seq_id seq_id = 0;
+    llama_batch_ext_add_text(batch, new_token_id, n_cur, &seq_id, 1, true);
 
     env->CallVoidMethod(intvar_ncur, la_int_var_inc);
 
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() returned null");
+    if (llama_decode_ext(context, batch) != 0) {
+        LOGe("llama_decode_ext() failed");
     }
 
     return new_token;
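
Note: putting the migrated pieces together, the completion_init/completion_loop
pair above reduces to the following flow (a sketch reusing decode_prompt from
the earlier note; sampling and detokenization are elided, and n_len is the
generation limit passed to completion_init):

    llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1);

    if (decode_prompt(context, batch, tokens_list)) {
        int n_cur = llama_batch_ext_get_n_tokens(batch);    // prompt length
        while (n_cur < n_len) {
            llama_token new_token_id = 0;                   // placeholder: sample from the last logits here
            llama_batch_ext_clear(batch);
            llama_seq_id seq_id = 0;
            llama_batch_ext_add_text(batch, new_token_id, n_cur, &seq_id, 1, true);
            if (llama_decode_ext(context, batch) != 0) {
                break;                                      // decode failed
            }
            n_cur++;
        }
    }

    llama_batch_ext_free(batch);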