context : add decode/encode

ggml-ci
Author: Georgi Gerganov
Date:   2025-02-10 16:11:17 +02:00
parent 879ba82777
commit ef358ee78f
3 changed files with 526 additions and 522 deletions


@@ -16,22 +16,7 @@
using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;
// TODO: this is very WIP - improve
struct llama_batch_manager_i {
virtual ~llama_batch_manager_i() = default;
//bool is_done() const;
virtual llama_ubatch next() = 0;
virtual bool prepare() = 0;
virtual void restore() = 0;
virtual void update() = 0;
virtual void finalize() = 0;
// TODO: might be temporary
int64_t n_outputs_all = 0;
};
struct llama_batch_manager_i;
// TODO: make implementation details private
// TODO: become abstract base class, split the current implementation into different child classes
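For orientation, a minimal sketch (not part of this commit) of how the llama_batch_manager_i lifecycle above could be driven from a decode path; the loop condition, the empty-ubatch check, and the helper name decode_with_manager are assumptions:

    // Hypothetical helper illustrating the manager lifecycle; the real logic
    // lives in the decode()/encode() implementations added by this commit.
    static int decode_with_manager(llama_context & ctx, llama_batch & batch) {
        auto bman = ctx.prepare_batch(batch);      // virtual, may be overridden
        while (true) {
            llama_ubatch ubatch = bman->next();    // split off the next micro-batch
            if (ubatch.n_tokens == 0) {
                break;                             // assumption: empty ubatch ends the loop
            }
            if (!bman->prepare()) {                // e.g. reserve KV-cache space
                bman->restore();                   // roll back partial state on failure
                return -1;
            }
            // ... build and compute the graph for `ubatch` here ...
            bman->update();                        // commit per-ubatch state
        }
        bman->finalize();                          // e.g. output bookkeeping
        return 0;
    }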
@@ -44,6 +29,8 @@ struct llama_context {
const llama_context_params & params,
build_graph_callback && cb_build_graph);
virtual ~llama_context() = default;
const struct llama_model & model;
llama_cparams cparams;
@@ -104,8 +91,10 @@ struct llama_context {
ggml_abort_callback abort_callback = nullptr;
void * abort_callback_data = nullptr;
// TODO: do not pass logits_all explicitly
std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch);
virtual std::unique_ptr<llama_batch_manager_i> prepare_batch(const llama_batch & batch);
virtual int decode(llama_batch & inp_batch);
virtual int encode(llama_batch & inp_batch);
// returns the result of ggml_backend_sched_graph_compute_async execution
enum ggml_status compute_graph(
@@ -286,13 +275,6 @@ struct llama_context {
int n_pos_per_token = 1;
};
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
void llama_output_reorder(struct llama_context & ctx);
// For internal test use
// TODO: remove
const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(struct llama_context * ctx);
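The new virtual decode()/encode() members presumably back the public C API. A minimal sketch of that wiring, assuming the llama.h wrappers simply forward to the context (the actual definitions are on the .cpp side of this commit, not shown here):

    int32_t llama_decode(struct llama_context * ctx, struct llama_batch batch) {
        return ctx->decode(batch);    // dispatches to the virtual implementation
    }

    int32_t llama_encode(struct llama_context * ctx, struct llama_batch batch) {
        return ctx->encode(batch);    // same pattern for encoder-only evaluation
    }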