diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 13beb097c..4e02f155b 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -3980,6 +3980,31 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
     }
 }
 
+///
+
+int32_t llama_encode(
+        struct llama_context * ctx,
+          struct llama_batch   batch) {
+    const int ret = ctx->encode(batch);
+    if (ret != 0) {
+        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
+int32_t llama_decode(
+        struct llama_context * ctx,
+          struct llama_batch   batch) {
+    const int ret = ctx->decode(batch);
+    if (ret != 0) {
+        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
+    }
+
+    return ret;
+}
+
+
 const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
     struct llama_context * ctx
 ) {
diff --git a/src/llama-context.h b/src/llama-context.h
index f7e007f32..ac842dc8b 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -45,7 +45,30 @@ struct llama_context {
 
     virtual ggml_context_ptr init();
 
+    // decode a batch of tokens by evaluating the transformer
+    // in case of unsuccessful decoding (error or warning),
+    // the kv_cache state will be returned to its original state
+    // (for non-recurrent models) or cleaned (for recurrent models)
+    //
+    //   - lctx:      llama context
+    //   - inp_batch: batch to evaluate
+    //
+    // return 0 on success
+    // return positive int on warning
+    // return negative int on error
+    //
     virtual int decode(llama_batch & inp_batch) = 0;
+
+
+    // encode a batch of tokens by evaluating the encoder part of the transformer
+    //
+    //   - lctx:  llama context
+    //   - batch: batch to evaluate
+    //
+    // return 0 on success
+    // return positive int on warning
+    // return negative int on error
+    //
     virtual int encode(llama_batch & inp_batch) = 0;
 
     // graph build API (generic)
diff --git a/src/llama.cpp b/src/llama.cpp
index 7c002f9bf..f623dd385 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7401,39 +7401,6 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }
 
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - inp_batch: batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_impl(
-        llama_context & lctx,
-          llama_batch   inp_batch) {
-    return lctx.decode(inp_batch);
-}
-
-// encode a batch of tokens by evaluating the encoder part of the transformer
-//
-//   - lctx:  llama context
-//   - batch: batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_encode_impl(
-        llama_context & lctx,
-          llama_batch   inp_batch) {
-    return lctx.encode(inp_batch);
-}
-
 //
 // interface implementation
 //
@@ -7759,30 +7726,6 @@ struct llama_context * llama_new_context_with_model(
     return llama_init_from_model(model, params);
 }
 
-///
-
-int32_t llama_encode(
-        struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_encode_impl(*ctx, batch);
-    if (ret != 0) {
-        LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
-int32_t llama_decode(
-        struct llama_context * ctx,
-          struct llama_batch   batch) {
-    const int ret = llama_decode_impl(*ctx, batch);
-    if (ret != 0) {
-        LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
-    }
-
-    return ret;
-}
-
 //
 // chat templates
 //
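
Not part of the diff: a minimal caller-side sketch of the return convention that the relocated comments document (0 on success, positive int on warning, negative int on error), using only the public llama_decode wrapper moved above. The helper name decode_or_report is hypothetical and not part of the llama.cpp API.

#include <cstdio>

#include "llama.h"

// hypothetical helper, not part of the library
static bool decode_or_report(llama_context * ctx, llama_batch batch) {
    const int32_t ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true; // batch evaluated, outputs are ready
    }
    if (ret > 0) {
        // warning: per the comment in llama-context.h the kv_cache has been
        // returned to its original state (non-recurrent models) or cleaned
        // (recurrent models), so the caller may retry with a smaller batch
        fprintf(stderr, "%s: decode warning, ret = %d\n", __func__, ret);
        return false;
    }
    // negative return: hard error for this batch
    fprintf(stderr, "%s: decode error, ret = %d\n", __func__, ret);
    return false;
}

A batch built with llama_batch_get_one or llama_batch_init can be passed straight through, since llama_decode still takes llama_batch by value exactly as it did before the move into llama-context.cpp.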