context : move encode/decode to llama-context.cpp

Georgi Gerganov
2025-02-12 11:23:38 +02:00
parent 02ef4be975
commit b52b79b048
3 changed files with 48 additions and 57 deletions

@@ -3980,6 +3980,31 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
    }
}

///

int32_t llama_encode(
        struct llama_context * ctx,
        struct llama_batch     batch) {
    const int ret = ctx->encode(batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
    }

    return ret;
}

int32_t llama_decode(
        struct llama_context * ctx,
        struct llama_batch     batch) {
    const int ret = ctx->decode(batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }

    return ret;
}

const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
        struct llama_context * ctx
) {
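For orientation, a minimal caller-side sketch of the new public wrappers follows (an illustration, not part of this commit; it assumes the two-argument llama_batch_get_one helper from the public API, and decode_one is a hypothetical helper name):

#include "llama.h"

// decode a single token through the wrapper above and map the documented
// return-code convention (0 / positive / negative) onto a success flag
static bool decode_one(struct llama_context * ctx, llama_token tok) {
    llama_batch batch = llama_batch_get_one(&tok, 1);

    const int32_t ret = llama_decode(ctx, batch);
    if (ret < 0) {
        return false; // error: the context has already logged it via LLAMA_LOG_ERROR
    }
    if (ret > 0) {
        return false; // warning (e.g. no free slot in the KV cache); the caller may retry
    }
    return true;
}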

@@ -45,7 +45,30 @@ struct llama_context {
    virtual ggml_context_ptr init();

    // decode a batch of tokens by evaluating the transformer
    // in case of unsuccessful decoding (error or warning),
    // the kv_cache state will be returned to its original state
    // (for non-recurrent models) or cleaned (for recurrent models)
    //
    //   - lctx:      llama context
    //   - inp_batch: batch to evaluate
    //
    // return 0 on success
    // return positive int on warning
    // return negative int on error
    //
    virtual int decode(llama_batch & inp_batch) = 0;

    // encode a batch of tokens by evaluating the encoder part of the transformer
    //
    //   - lctx:  llama context
    //   - batch: batch to evaluate
    //
    // return 0 on success
    // return positive int on warning
    // return negative int on error
    //
    virtual int encode(llama_batch & inp_batch) = 0;

    // graph build API (generic)
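To make the encode/decode split above concrete, here is a rough caller-side sketch of how the encoder path is typically driven before decoding (an assumption for illustration only, not part of this diff; llama_model_has_encoder, llama_model_decoder_start_token and llama_batch_get_one are taken from the public API, and prefill is a hypothetical helper name):

#include "llama.h"

#include <vector>

// encoder-decoder models: run the prompt through llama_encode first, then start
// decoding from the decoder start token; decoder-only models: feed the prompt
// straight into llama_decode
static int32_t prefill(struct llama_context * ctx, const struct llama_model * model, std::vector<llama_token> & prompt) {
    if (llama_model_has_encoder(model)) {
        const int32_t ret = llama_encode(ctx, llama_batch_get_one(prompt.data(), (int32_t) prompt.size()));
        if (ret != 0) {
            return ret; // negative = error, positive = warning; nothing to decode yet
        }

        llama_token dec_start = llama_model_decoder_start_token(model);
        return llama_decode(ctx, llama_batch_get_one(&dec_start, 1));
    }

    return llama_decode(ctx, llama_batch_get_one(prompt.data(), (int32_t) prompt.size()));
}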

@@ -7401,39 +7401,6 @@ static struct ggml_cgraph * llama_build_graph(
    return result;
}

// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
//   - lctx:      llama context
//   - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
static int llama_decode_impl(
        llama_context & lctx,
        llama_batch     inp_batch) {
    return lctx.decode(inp_batch);
}

// encode a batch of tokens by evaluating the encoder part of the transformer
//
//   - lctx:  llama context
//   - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
static int llama_encode_impl(
        llama_context & lctx,
        llama_batch     inp_batch) {
    return lctx.encode(inp_batch);
}
//
// interface implementation
//
@@ -7759,30 +7726,6 @@ struct llama_context * llama_new_context_with_model(
    return llama_init_from_model(model, params);
}

///

int32_t llama_encode(
        struct llama_context * ctx,
        struct llama_batch     batch) {
    const int ret = llama_encode_impl(*ctx, batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
    }

    return ret;
}

int32_t llama_decode(
        struct llama_context * ctx,
        struct llama_batch     batch) {
    const int ret = llama_decode_impl(*ctx, batch);
    if (ret != 0) {
        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
    }

    return ret;
}
//
// chat templates
//