From 86973cb14a51bbd5268871c23fe7ab1ddfa75830 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 17:32:36 +0100 Subject: [PATCH] fix merge errors --- include/llama.h | 2 ++ src/llama-context.cpp | 63 +++++++++++++++++++++++++------------------ src/llama-context.h | 4 +++ 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/include/llama.h b/include/llama.h index ac0481393..564ffe1aa 100644 --- a/include/llama.h +++ b/include/llama.h @@ -994,6 +994,7 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch), "use llama_batch_ext API instead"); + LLAMA_API int32_t llama_encode_ext( struct llama_context * ctx, struct llama_batch_ext * batch); @@ -1005,6 +1006,7 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch), "use llama_batch_ext API instead"); + LLAMA_API int32_t llama_decode_ext( struct llama_context * ctx, struct llama_batch_ext * batch); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 0a43a3af8..d89e1ac2c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -4,6 +4,7 @@ #include "llama-io.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-batch.h" #include "llama-kv-cache.h" #include @@ -980,16 +981,26 @@ bool llama_context::apply_adapter_cvec( } int llama_context::encode(llama_batch & inp_batch) { + // temporary allocate memory and convert llama_batch to llama_batch_ext + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + return encode(*batch_allocr.batch); +} + +int llama_context::decode(llama_batch & inp_batch) { + // temporary allocate memory and convert llama_batch to llama_batch_ext + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + return decode(*batch_allocr.batch); +} + +int llama_context::encode(llama_batch_ext & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); - - const llama_batch & batch = batch_allocr.batch; + llama_batch_ext & batch = inp_batch; const int32_t n_tokens = batch.n_tokens; const auto & hparams = model.hparams; @@ -1132,17 +1143,13 @@ int llama_context::encode(llama_batch & inp_batch) { return 0; } -int llama_context::decode(llama_batch & inp_batch) { +int llama_context::decode(llama_batch_ext & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); - - const llama_batch & batch = batch_allocr.batch; + llama_batch_ext & batch = inp_batch; const auto & vocab = model.vocab; const auto & hparams = model.hparams; @@ -2714,26 +2721,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla /// +// deprecated int32_t llama_encode( - llama_context * ctx, - llama_batch batch) { - const int ret = ctx->encode(batch); - if (ret != 0) { - LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); - } - - return ret; + struct llama_context * ctx, + struct llama_batch inp_batch) { + return ctx->encode(inp_batch); } +// deprecated int32_t llama_decode( - llama_context * ctx, - llama_batch batch) { - const int ret = ctx->decode(batch); - if (ret != 0) { - LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); - } + struct llama_context * ctx, + struct llama_batch inp_batch) { + return ctx->decode(inp_batch); +} - return ret; +int32_t llama_encode_ext( + struct llama_context * ctx, + struct llama_batch_ext * inp_batch) { + return ctx->encode(*inp_batch); +} + +int32_t llama_decode_ext( + struct llama_context * ctx, + struct llama_batch_ext * inp_batch) { + return ctx->decode(*inp_batch); } // diff --git a/src/llama-context.h b/src/llama-context.h index 71d702e8b..29bb230f1 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -81,9 +81,13 @@ struct llama_context { int32_t il_start, int32_t il_end); + // deprecated int encode(llama_batch & inp_batch); int decode(llama_batch & inp_batch); + int encode(llama_batch_ext & inp_batch); + int decode(llama_batch_ext & inp_batch); + // // state save/load //