android : adapt to new API

Georgi Gerganov
2025-03-19 10:16:55 +02:00
parent 23d7407314
commit b0db7fc2c6
3 changed files with 31 additions and 94 deletions
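
This commit switches the Android example from the deprecated common_batch_clear/common_batch_add helpers (removed from common below) to the opaque llama_batch_ext API: fields such as batch->logits and batch->n_tokens are no longer poked directly but go through accessor functions. As a reading aid, here is a minimal sketch of the new batch lifecycle, assembled only from calls that appear in this diff; the decode_prompt wrapper and its name are illustrative assumptions, not part of the commit.

#include <vector>

#include "llama.h"

// Illustrative only: decode a prompt with the llama_batch_ext calls used in
// this commit. Every llama_* call below appears in the diff; the wrapper
// function itself is an assumption for demonstration.
static int decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens) {
    // opaque batch handle; replaces the hand-rolled, heap-allocated llama_batch
    llama_batch_ext * batch = llama_batch_ext_init(tokens.size(), 1);

    llama_batch_ext_clear(batch);

    // common_batch_add(*batch, id, pos, { 0 }, logits) becomes add_text with
    // an explicit seq_id array and its length
    llama_seq_id seq_id = 0;
    for (size_t i = 0; i < tokens.size(); i++) {
        llama_batch_ext_add_text(batch, tokens[i], (llama_pos) i, &seq_id, 1, false);
    }

    // replaces batch->logits[batch->n_tokens - 1] = true
    llama_batch_ext_set_output_last(batch);

    // replaces llama_decode(ctx, *batch)
    const int ret = llama_decode_ext(ctx, batch);

    // replaces reading batch->n_tokens directly
    const int n_tokens = llama_batch_ext_get_n_tokens(batch);

    // replaces the manual per-field free / delete
    llama_batch_ext_free(batch);

    return ret == 0 ? n_tokens : -1;
}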

common/common.cpp

@@ -1576,35 +1576,6 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
 
 #endif // LLAMA_USE_CURL
 
-//
-// Batch utils
-//
-
-// DEPRECATED
-void common_batch_clear(struct llama_batch & batch) {
-    batch.n_tokens = 0;
-}
-
-// DEPRECATED
-void common_batch_add(
-                 struct llama_batch & batch,
-                        llama_token   id,
-                          llama_pos   pos,
-    const std::vector<llama_seq_id> & seq_ids,
-                               bool   logits) {
-    GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
-
-    batch.token   [batch.n_tokens] = id;
-    batch.pos     [batch.n_tokens] = pos;
-    batch.n_seq_id[batch.n_tokens] = seq_ids.size();
-    for (size_t i = 0; i < seq_ids.size(); ++i) {
-        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
-    }
-    batch.logits  [batch.n_tokens] = logits;
-
-    batch.n_tokens++;
-}
-
 //
 // Token utils
 //

common/common.h

@@ -569,17 +569,6 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
 // Batch utils
 //
 
-// DEPRECATED
-void common_batch_clear(struct llama_batch & batch);
-
-// DEPRECATED
-void common_batch_add(
-                 struct llama_batch & batch,
-                        llama_token   id,
-                          llama_pos   pos,
-    const std::vector<llama_seq_id> & seq_ids,
-                               bool   logits);
-
 // convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
 // this is meant to be temporary
 struct common_batch {

examples/llama.android/llama/src/main/cpp/llama-android.cpp

@@ -125,7 +125,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
     ctx_params.n_threads       = n_threads;
     ctx_params.n_threads_batch = n_threads;
 
-    llama_context * context = llama_new_context_with_model(model, ctx_params);
+    llama_context * context = llama_init_from_model(model, ctx_params);
 
     if (!context) {
         LOGe("llama_new_context_with_model() returned null)");
@@ -175,7 +175,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
     const auto model = reinterpret_cast<llama_model *>(model_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
 
     const int n_ctx = llama_n_ctx(context);
@@ -186,19 +186,20 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
     for (nri = 0; nri < nr; nri++) {
         LOGi("Benchmark prompt processing (pp)");
 
-        common_batch_clear(*batch);
+        llama_batch_ext_clear(batch);
 
         const int n_tokens = pp;
 
         for (i = 0; i < n_tokens; i++) {
-            common_batch_add(*batch, 0, i, { 0 }, false);
+            llama_seq_id seq_id = 0;
+            llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
         }
 
-        batch->logits[batch->n_tokens - 1] = true;
+        llama_batch_ext_set_output_last(batch);
         llama_kv_self_clear(context);
 
         const auto t_pp_start = ggml_time_us();
-        if (llama_decode(context, *batch) != 0) {
-            LOGi("llama_decode() failed during prompt processing");
+        if (llama_decode_ext(context, batch) != 0) {
+            LOGi("llama_decode_ext() failed during prompt processing");
         }
         const auto t_pp_end = ggml_time_us();
@@ -210,14 +211,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         const auto t_tg_start = ggml_time_us();
 
         for (i = 0; i < tg; i++) {
-            common_batch_clear(*batch);
+            llama_batch_ext_clear(batch);
 
             for (j = 0; j < pl; j++) {
-                common_batch_add(*batch, 0, i, { j }, true);
+                llama_seq_id seq_id = j;
+                llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, true);
             }
 
-            LOGi("llama_decode() text generation: %d", i);
-            if (llama_decode(context, *batch) != 0) {
-                LOGi("llama_decode() failed during text generation");
+            LOGi("llama_decode_ext() text generation: %d", i);
+            if (llama_decode_ext(context, batch) != 0) {
+                LOGi("llama_decode_ext() failed during text generation");
             }
         }
@@ -272,32 +274,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
 extern "C"
 JNIEXPORT jlong JNICALL
 Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
-
-    // Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
-
-    llama_batch *batch = new llama_batch {
-        0,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-        nullptr,
-    };
-
-    if (embd) {
-        batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
-    } else {
-        batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
-    }
-
-    batch->pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens);
-    batch->n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens);
-    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
-    for (int i = 0; i < n_tokens; ++i) {
-        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
-    }
-    batch->logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens);
+    llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max);
 
     return reinterpret_cast<jlong>(batch);
 }
@@ -305,9 +282,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
 extern "C"
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
-    //llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    delete batch;
+    llama_batch_ext_free(reinterpret_cast<llama_batch_ext *>(batch_pointer));
 }
 
 extern "C"
@@ -355,7 +330,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
 
     const auto text = env->GetStringUTFChars(jtext, 0);
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
 
     bool parse_special = (format_chat == JNI_TRUE);
     const auto tokens_list = common_tokenize(context, text, true, parse_special);
@@ -363,7 +338,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     auto n_ctx = llama_n_ctx(context);
     auto n_kv_req = tokens_list.size() + n_len;
 
-    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
+    LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", (int) n_len, (int) n_ctx, (int) n_kv_req);
 
     if (n_kv_req > n_ctx) {
         LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
@@ -373,23 +348,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
         LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
     }
 
-    common_batch_clear(*batch);
+    llama_batch_ext_clear(batch);
 
     // evaluate the initial prompt
     for (auto i = 0; i < tokens_list.size(); i++) {
-        common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
+        llama_seq_id seq_id = 0;
+        llama_batch_ext_add_text(batch, tokens_list[i], i, &seq_id, 1, false);
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    llama_batch_ext_set_output_last(batch);
 
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() failed");
+    if (llama_decode_ext(context, batch) != 0) {
+        LOGe("llama_decode_ext() failed");
     }
 
     env->ReleaseStringUTFChars(jtext, text);
 
-    return batch->n_tokens;
+    return llama_batch_ext_get_n_tokens(batch);
 }
 
 extern "C"
@@ -404,7 +380,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
+    const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
     const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
     const auto model = llama_get_model(context);
     const auto vocab = llama_model_get_vocab(model);
@@ -433,13 +409,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         new_token = env->NewStringUTF("");
     }
 
-    common_batch_clear(*batch);
-    common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+    llama_batch_ext_clear(batch);
+    llama_seq_id seq_id = 0;
+    llama_batch_ext_add_text(batch, new_token_id, n_cur, &seq_id, 1, true);
 
     env->CallVoidMethod(intvar_ncur, la_int_var_inc);
 
-    if (llama_decode(context, *batch) != 0) {
-        LOGe("llama_decode() returned null");
+    if (llama_decode_ext(context, batch) != 0) {
+        LOGe("llama_decode_ext() returned null");
     }
 
     return new_token;