fix merge errors

This commit is contained in:
Xuan Son Nguyen
2025-03-13 17:32:36 +01:00
parent 17f954c8e2
commit 86973cb14a
3 changed files with 43 additions and 26 deletions

View File

@@ -4,6 +4,7 @@
#include "llama-io.h"
#include "llama-mmap.h"
#include "llama-model.h"
#include "llama-batch.h"
#include "llama-kv-cache.h"
#include <cassert>
@@ -980,16 +981,26 @@ bool llama_context::apply_adapter_cvec(
}
int llama_context::encode(llama_batch & inp_batch) {
// temporary allocate memory and convert llama_batch to llama_batch_ext
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
return encode(*batch_allocr.batch);
}
int llama_context::decode(llama_batch & inp_batch) {
// temporary allocate memory and convert llama_batch to llama_batch_ext
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
return decode(*batch_allocr.batch);
}
int llama_context::encode(llama_batch_ext & inp_batch) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
llama_batch_ext & batch = inp_batch;
const int32_t n_tokens = batch.n_tokens;
const auto & hparams = model.hparams;
@@ -1132,17 +1143,13 @@ int llama_context::encode(llama_batch & inp_batch) {
return 0;
}
int llama_context::decode(llama_batch & inp_batch) {
int llama_context::decode(llama_batch_ext & inp_batch) {
if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
llama_batch_ext & batch = inp_batch;
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
@@ -2714,26 +2721,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla
///
// deprecated
int32_t llama_encode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->encode(batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}
return ret;
struct llama_context * ctx,
struct llama_batch inp_batch) {
return ctx->encode(inp_batch);
}
// deprecated
int32_t llama_decode(
llama_context * ctx,
llama_batch batch) {
const int ret = ctx->decode(batch);
if (ret != 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
struct llama_context * ctx,
struct llama_batch inp_batch) {
return ctx->decode(inp_batch);
}
return ret;
int32_t llama_encode_ext(
struct llama_context * ctx,
struct llama_batch_ext * inp_batch) {
return ctx->encode(*inp_batch);
}
int32_t llama_decode_ext(
struct llama_context * ctx,
struct llama_batch_ext * inp_batch) {
return ctx->decode(*inp_batch);
}
//

View File

@@ -81,9 +81,13 @@ struct llama_context {
int32_t il_start,
int32_t il_end);
// deprecated
int encode(llama_batch & inp_batch);
int decode(llama_batch & inp_batch);
int encode(llama_batch_ext & inp_batch);
int decode(llama_batch_ext & inp_batch);
//
// state save/load
//