mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-30 20:58:45 +00:00
android : adapt to new API
This commit is contained in:
@ -1576,35 +1576,6 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons
|
|||||||
|
|
||||||
#endif // LLAMA_USE_CURL
|
#endif // LLAMA_USE_CURL
|
||||||
|
|
||||||
//
|
|
||||||
// Batch utils
|
|
||||||
//
|
|
||||||
|
|
||||||
// DEPRECATED
|
|
||||||
void common_batch_clear(struct llama_batch & batch) {
|
|
||||||
batch.n_tokens = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// DEPRECATED
|
|
||||||
void common_batch_add(
|
|
||||||
struct llama_batch & batch,
|
|
||||||
llama_token id,
|
|
||||||
llama_pos pos,
|
|
||||||
const std::vector<llama_seq_id> & seq_ids,
|
|
||||||
bool logits) {
|
|
||||||
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");
|
|
||||||
|
|
||||||
batch.token [batch.n_tokens] = id;
|
|
||||||
batch.pos [batch.n_tokens] = pos;
|
|
||||||
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
|
|
||||||
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
|
||||||
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
|
||||||
}
|
|
||||||
batch.logits [batch.n_tokens] = logits;
|
|
||||||
|
|
||||||
batch.n_tokens++;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// Token utils
|
// Token utils
|
||||||
//
|
//
|
||||||
|
@ -569,17 +569,6 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adap
|
|||||||
// Batch utils
|
// Batch utils
|
||||||
//
|
//
|
||||||
|
|
||||||
// DEPRECATED
|
|
||||||
void common_batch_clear(struct llama_batch & batch);
|
|
||||||
|
|
||||||
// DEPRECATED
|
|
||||||
void common_batch_add(
|
|
||||||
struct llama_batch & batch,
|
|
||||||
llama_token id,
|
|
||||||
llama_pos pos,
|
|
||||||
const std::vector<llama_seq_id> & seq_ids,
|
|
||||||
bool logits);
|
|
||||||
|
|
||||||
// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
|
// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
|
||||||
// this is meant to be temporary
|
// this is meant to be temporary
|
||||||
struct common_batch {
|
struct common_batch {
|
||||||
|
@ -125,7 +125,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
|
|||||||
ctx_params.n_threads = n_threads;
|
ctx_params.n_threads = n_threads;
|
||||||
ctx_params.n_threads_batch = n_threads;
|
ctx_params.n_threads_batch = n_threads;
|
||||||
|
|
||||||
llama_context * context = llama_new_context_with_model(model, ctx_params);
|
llama_context * context = llama_init_from_model(model, ctx_params);
|
||||||
|
|
||||||
if (!context) {
|
if (!context) {
|
||||||
LOGe("llama_new_context_with_model() returned null)");
|
LOGe("llama_new_context_with_model() returned null)");
|
||||||
@ -175,7 +175,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||||||
|
|
||||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
||||||
const auto model = reinterpret_cast<llama_model *>(model_pointer);
|
const auto model = reinterpret_cast<llama_model *>(model_pointer);
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
|
||||||
|
|
||||||
const int n_ctx = llama_n_ctx(context);
|
const int n_ctx = llama_n_ctx(context);
|
||||||
|
|
||||||
@ -186,19 +186,20 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||||||
for (nri = 0; nri < nr; nri++) {
|
for (nri = 0; nri < nr; nri++) {
|
||||||
LOGi("Benchmark prompt processing (pp)");
|
LOGi("Benchmark prompt processing (pp)");
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
llama_batch_ext_clear(batch);
|
||||||
|
|
||||||
const int n_tokens = pp;
|
const int n_tokens = pp;
|
||||||
for (i = 0; i < n_tokens; i++) {
|
for (i = 0; i < n_tokens; i++) {
|
||||||
common_batch_add(*batch, 0, i, { 0 }, false);
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
batch->logits[batch->n_tokens - 1] = true;
|
llama_batch_ext_set_output_last(batch);
|
||||||
llama_kv_self_clear(context);
|
llama_kv_self_clear(context);
|
||||||
|
|
||||||
const auto t_pp_start = ggml_time_us();
|
const auto t_pp_start = ggml_time_us();
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode_ext(context, batch) != 0) {
|
||||||
LOGi("llama_decode() failed during prompt processing");
|
LOGi("llama_decode_ext() failed during prompt processing");
|
||||||
}
|
}
|
||||||
const auto t_pp_end = ggml_time_us();
|
const auto t_pp_end = ggml_time_us();
|
||||||
|
|
||||||
@ -210,14 +211,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||||||
const auto t_tg_start = ggml_time_us();
|
const auto t_tg_start = ggml_time_us();
|
||||||
for (i = 0; i < tg; i++) {
|
for (i = 0; i < tg; i++) {
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
llama_batch_ext_clear(batch);
|
||||||
for (j = 0; j < pl; j++) {
|
for (j = 0; j < pl; j++) {
|
||||||
common_batch_add(*batch, 0, i, { j }, true);
|
llama_seq_id seq_id = j;
|
||||||
|
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGi("llama_decode() text generation: %d", i);
|
LOGi("llama_decode_ext() text generation: %d", i);
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode_ext(context, batch) != 0) {
|
||||||
LOGi("llama_decode() failed during text generation");
|
LOGi("llama_decode_ext() failed during text generation");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,32 +274,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT jlong JNICALL
|
JNIEXPORT jlong JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
|
Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) {
|
||||||
|
llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max);
|
||||||
// Source: Copy of llama.cpp:llama_batch_init but heap-allocated.
|
|
||||||
|
|
||||||
llama_batch *batch = new llama_batch {
|
|
||||||
0,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
nullptr,
|
|
||||||
};
|
|
||||||
|
|
||||||
if (embd) {
|
|
||||||
batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
|
||||||
} else {
|
|
||||||
batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
|
|
||||||
}
|
|
||||||
|
|
||||||
batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
|
|
||||||
batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens);
|
|
||||||
batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
|
|
||||||
for (int i = 0; i < n_tokens; ++i) {
|
|
||||||
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
|
|
||||||
}
|
|
||||||
batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
|
||||||
|
|
||||||
return reinterpret_cast<jlong>(batch);
|
return reinterpret_cast<jlong>(batch);
|
||||||
}
|
}
|
||||||
@ -305,9 +282,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
|
|||||||
extern "C"
|
extern "C"
|
||||||
JNIEXPORT void JNICALL
|
JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
|
Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) {
|
||||||
//llama_batch_free(*reinterpret_cast<llama_batch *>(batch_pointer));
|
llama_batch_ext_free(reinterpret_cast<llama_batch_ext *>(batch_pointer));
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
|
||||||
delete batch;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
@ -355,7 +330,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
|||||||
|
|
||||||
const auto text = env->GetStringUTFChars(jtext, 0);
|
const auto text = env->GetStringUTFChars(jtext, 0);
|
||||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
|
||||||
|
|
||||||
bool parse_special = (format_chat == JNI_TRUE);
|
bool parse_special = (format_chat == JNI_TRUE);
|
||||||
const auto tokens_list = common_tokenize(context, text, true, parse_special);
|
const auto tokens_list = common_tokenize(context, text, true, parse_special);
|
||||||
@ -363,7 +338,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
|||||||
auto n_ctx = llama_n_ctx(context);
|
auto n_ctx = llama_n_ctx(context);
|
||||||
auto n_kv_req = tokens_list.size() + n_len;
|
auto n_kv_req = tokens_list.size() + n_len;
|
||||||
|
|
||||||
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req);
|
LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", (int) n_len, (int) n_ctx, (int) n_kv_req);
|
||||||
|
|
||||||
if (n_kv_req > n_ctx) {
|
if (n_kv_req > n_ctx) {
|
||||||
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
|
LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough");
|
||||||
@ -373,23 +348,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
|
|||||||
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
|
LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id);
|
||||||
}
|
}
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
llama_batch_ext_clear(batch);
|
||||||
|
|
||||||
// evaluate the initial prompt
|
// evaluate the initial prompt
|
||||||
for (auto i = 0; i < tokens_list.size(); i++) {
|
for (auto i = 0; i < tokens_list.size(); i++) {
|
||||||
common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch, tokens_list[i], i, &seq_id, 1, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
// llama_decode will output logits only for the last token of the prompt
|
// llama_decode will output logits only for the last token of the prompt
|
||||||
batch->logits[batch->n_tokens - 1] = true;
|
llama_batch_ext_set_output_last(batch);
|
||||||
|
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode_ext(context, batch) != 0) {
|
||||||
LOGe("llama_decode() failed");
|
LOGe("llama_decode_ext() failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
env->ReleaseStringUTFChars(jtext, text);
|
env->ReleaseStringUTFChars(jtext, text);
|
||||||
|
|
||||||
return batch->n_tokens;
|
return llama_batch_ext_get_n_tokens(batch);
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C"
|
extern "C"
|
||||||
@ -404,7 +380,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||||||
jobject intvar_ncur
|
jobject intvar_ncur
|
||||||
) {
|
) {
|
||||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
||||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
const auto batch = reinterpret_cast<llama_batch_ext *>(batch_pointer);
|
||||||
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
|
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
|
||||||
const auto model = llama_get_model(context);
|
const auto model = llama_get_model(context);
|
||||||
const auto vocab = llama_model_get_vocab(model);
|
const auto vocab = llama_model_get_vocab(model);
|
||||||
@ -433,13 +409,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||||||
new_token = env->NewStringUTF("");
|
new_token = env->NewStringUTF("");
|
||||||
}
|
}
|
||||||
|
|
||||||
common_batch_clear(*batch);
|
llama_batch_ext_clear(batch);
|
||||||
common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
|
llama_seq_id seq_id = 0;
|
||||||
|
llama_batch_ext_add_text(batch, new_token_id, n_cur, &seq_id, 1, true);
|
||||||
|
|
||||||
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
|
env->CallVoidMethod(intvar_ncur, la_int_var_inc);
|
||||||
|
|
||||||
if (llama_decode(context, *batch) != 0) {
|
if (llama_decode_ext(context, batch) != 0) {
|
||||||
LOGe("llama_decode() returned null");
|
LOGe("llama_decode_ext() returned null");
|
||||||
}
|
}
|
||||||
|
|
||||||
return new_token;
|
return new_token;
|
||||||
|
Reference in New Issue
Block a user