fix merge errors

This commit is contained in:
Xuan Son Nguyen
2025-03-13 17:32:36 +01:00
parent 17f954c8e2
commit 86973cb14a
3 changed files with 43 additions and 26 deletions

View File

@@ -994,6 +994,7 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_encode( DEPRECATED(LLAMA_API int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch), "use llama_batch_ext API instead"); struct llama_batch batch), "use llama_batch_ext API instead");
LLAMA_API int32_t llama_encode_ext( LLAMA_API int32_t llama_encode_ext(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch_ext * batch); struct llama_batch_ext * batch);
@@ -1005,6 +1006,7 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_decode( DEPRECATED(LLAMA_API int32_t llama_decode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch), "use llama_batch_ext API instead"); struct llama_batch batch), "use llama_batch_ext API instead");
LLAMA_API int32_t llama_decode_ext( LLAMA_API int32_t llama_decode_ext(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch_ext * batch); struct llama_batch_ext * batch);

View File

@@ -4,6 +4,7 @@
#include "llama-io.h" #include "llama-io.h"
#include "llama-mmap.h" #include "llama-mmap.h"
#include "llama-model.h" #include "llama-model.h"
#include "llama-batch.h"
#include "llama-kv-cache.h" #include "llama-kv-cache.h"
#include <cassert> #include <cassert>
@@ -980,16 +981,26 @@ bool llama_context::apply_adapter_cvec(
} }
int llama_context::encode(llama_batch & inp_batch) { int llama_context::encode(llama_batch & inp_batch) {
// temporary allocate memory and convert llama_batch to llama_batch_ext
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
return encode(*batch_allocr.batch);
}
int llama_context::decode(llama_batch & inp_batch) {
// temporary allocate memory and convert llama_batch to llama_batch_ext
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
return decode(*batch_allocr.batch);
}
int llama_context::encode(llama_batch_ext & inp_batch) {
if (inp_batch.n_tokens == 0) { if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1; return -1;
} }
// temporary allocate memory for the input batch if needed llama_batch_ext & batch = inp_batch;
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const int32_t n_tokens = batch.n_tokens; const int32_t n_tokens = batch.n_tokens;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@@ -1132,17 +1143,13 @@ int llama_context::encode(llama_batch & inp_batch) {
return 0; return 0;
} }
int llama_context::decode(llama_batch & inp_batch) { int llama_context::decode(llama_batch_ext & inp_batch) {
if (inp_batch.n_tokens == 0) { if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1; return -1;
} }
// temporary allocate memory for the input batch if needed llama_batch_ext & batch = inp_batch;
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const auto & vocab = model.vocab; const auto & vocab = model.vocab;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
@@ -2714,26 +2721,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla
/// ///
// deprecated
int32_t llama_encode( int32_t llama_encode(
llama_context * ctx, struct llama_context * ctx,
llama_batch batch) { struct llama_batch inp_batch) {
const int ret = ctx->encode(batch); return ctx->encode(inp_batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}
return ret;
} }
// deprecated
int32_t llama_decode( int32_t llama_decode(
llama_context * ctx, struct llama_context * ctx,
llama_batch batch) { struct llama_batch inp_batch) {
const int ret = ctx->decode(batch); return ctx->decode(inp_batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
} }
return ret; int32_t llama_encode_ext(
struct llama_context * ctx,
struct llama_batch_ext * inp_batch) {
return ctx->encode(*inp_batch);
}
int32_t llama_decode_ext(
struct llama_context * ctx,
struct llama_batch_ext * inp_batch) {
return ctx->decode(*inp_batch);
} }
// //

View File

@@ -81,9 +81,13 @@ struct llama_context {
int32_t il_start, int32_t il_start,
int32_t il_end); int32_t il_end);
// deprecated
int encode(llama_batch & inp_batch); int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch); int decode(llama_batch & inp_batch);
int encode(llama_batch_ext & inp_batch);
int decode(llama_batch_ext & inp_batch);
// //
// state save/load // state save/load
// //