context : move output functionality to base class
ggml-ci
@@ -43,12 +43,12 @@ struct llama_context : public llama_graph_i {

    virtual enum llama_pooling_type pooling_type() const;

    virtual float * get_logits() = 0;
    virtual float * get_logits_ith(int32_t i) = 0;
    virtual float * get_logits();
    virtual float * get_logits_ith(int32_t i);

    virtual float * get_embeddings() = 0;
    virtual float * get_embeddings_ith(int32_t i) = 0;
    virtual float * get_embeddings_seq(llama_seq_id seq_id) = 0;
    virtual float * get_embeddings();
    virtual float * get_embeddings_ith(int32_t i);
    virtual float * get_embeddings_seq(llama_seq_id seq_id);

    virtual int64_t n_pos_per_token() const; // vision
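The pure-virtual accessors become plain virtuals here, so the lookup logic can live in the base class. A minimal, self-contained sketch of what such a shared get_logits_ith() could look like, assuming the logits/output_ids members that this diff adds to the base class further down; the stand-in types and names are simplified, not the code from the commit:

#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

struct context_base {
    std::vector<float>   logits;     // flattened [n_outputs][n_vocab]
    std::vector<int32_t> output_ids; // batch position -> output row, -1 if none
    int32_t n_outputs = 0;
    int32_t n_vocab   = 0;

    // shared lookup that derived contexts no longer have to duplicate
    virtual float * get_logits_ith(int32_t i) {
        if (i < 0 || i >= (int32_t) output_ids.size()) {
            throw std::out_of_range("batch position out of range");
        }
        const int32_t row = output_ids[i];
        if (row < 0 || row >= n_outputs) {
            throw std::out_of_range("no logits were computed for this position");
        }
        return logits.data() + (size_t) row * n_vocab;
    }

    virtual ~context_base() = default;
};

int main() {
    context_base ctx;
    ctx.n_vocab    = 4;
    ctx.n_outputs  = 1;
    ctx.logits     = {0.1f, 0.2f, 0.3f, 0.4f};
    ctx.output_ids = {-1, -1, 0}; // only the last batch position produced logits
    printf("%f\n", ctx.get_logits_ith(2)[1]); // prints 0.200000
    return 0;
}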
@@ -85,6 +85,19 @@ struct llama_context : public llama_graph_i {
            int32_t il_start,
            int32_t il_end);

    // returns the result of ggml_backend_sched_graph_compute_async execution
    virtual enum ggml_status compute_graph(
            ggml_cgraph * graph,
            bool batched);

    // Make sure enough space is available for outputs.
    // Returns max number of outputs for which space was reserved.
    virtual size_t output_reserve(size_t n_outputs);

    // make the outputs have the same order they had in the user-provided batch
    // TODO: maybe remove this
    virtual void output_reorder();

    // graph build API (generic)

    virtual void build_cb(
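The comments above spell out the output_reserve() contract: make sure the output buffers can hold n_outputs rows and report how many rows are actually available. A rough host-memory sketch of that contract (the real code backs these buffers with buf_output on a backend; the struct and member names below are simplified assumptions):

#include <cassert>
#include <cstddef>
#include <vector>

struct output_buffers {
    size_t n_vocab = 0;
    size_t n_embd  = 0;
    size_t output_size = 0;    // rows currently reserved
    std::vector<float> logits; // [output_size][n_vocab]
    std::vector<float> embd;   // [output_size][n_embd]

    // grow (never shrink) the buffers and return the reserved row count
    size_t output_reserve(size_t n_outputs) {
        if (n_outputs > output_size) {
            logits.resize(n_outputs * n_vocab);
            embd  .resize(n_outputs * n_embd);
            output_size = n_outputs;
        }
        return output_size;
    }
};

int main() {
    output_buffers out;
    out.n_vocab = 32000;
    out.n_embd  = 4096;
    assert(out.output_reserve(8) >= 8); // space for 8 output rows
    assert(out.output_reserve(4) >= 4); // capacity is kept, not shrunk
    return 0;
}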
@@ -198,6 +211,7 @@ protected:
    llama_cparams cparams;
    llama_adapter_cvec cvec;
    llama_loras loras;
    llama_sbatch sbatch;

    ggml_threadpool_t threadpool = nullptr;
    ggml_threadpool_t threadpool_batch = nullptr;
@@ -215,6 +229,31 @@ protected:
    // memory buffers used to evaluate the model
    std::vector<uint8_t> buf_compute_meta;

    // host buffer for the model output (logits and embeddings)
    ggml_backend_buffer_ptr buf_output;

    // TODO: remove
    bool logits_all = false;

    // decode output (2-dimensional array: [n_outputs][n_vocab])
    size_t logits_size = 0; // capacity (of floats) for logits
    float * logits = nullptr;

    // embeddings output (2-dimensional array: [n_outputs][n_embd])
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    size_t embd_size = 0; // capacity (of floats) for embeddings
    float * embd = nullptr;

    // sequence embeddings output (map of [n_embd] vectors)
    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
    std::map<llama_seq_id, std::vector<float>> embd_seq;

    size_t output_size = 0; // capacity (of tokens positions) for the output buffers
    int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers

    bool need_reserve = false;
    bool has_evaluated_once = false;

    mutable int64_t t_start_us = 0;
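The member comments above describe two embedding outputs: a flat per-token buffer when pooling is disabled and a per-sequence map when a pooling type is set. A small illustrative example of how those two layouts are indexed (sizes and values are made up, and llama_seq_id is declared locally as a stand-in):

#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

using llama_seq_id = int32_t;

int main() {
    const size_t n_embd = 4;

    // pooling_type == LLAMA_POOLING_TYPE_NONE: one row per output token
    std::vector<float> embd = {
        0.0f, 1.0f, 2.0f, 3.0f, // token 0
        4.0f, 5.0f, 6.0f, 7.0f, // token 1
    };

    // pooling_type != LLAMA_POOLING_TYPE_NONE: one pooled vector per sequence
    std::map<llama_seq_id, std::vector<float>> embd_seq;
    embd_seq[0] = {0.5f, 0.5f, 0.5f, 0.5f};

    printf("token 1, dim 2: %f\n", embd[1 * n_embd + 2]); // 6.000000
    printf("seq 0, dim 0:   %f\n", embd_seq.at(0)[0]);    // 0.500000
    return 0;
}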
@@ -247,69 +286,21 @@ public:

    virtual void kv_self_update() override;

    virtual float * get_logits() override;
    virtual float * get_logits_ith(int32_t i) override;

    virtual float * get_embeddings() override;
    virtual float * get_embeddings_ith(int32_t i) override;
    virtual float * get_embeddings_seq(llama_seq_id seq_id) override;

    virtual ggml_context_ptr init() override;

    virtual int decode(llama_batch & inp_batch) override;
    virtual int encode(llama_batch & inp_batch) override;

    llama_sbatch sbatch;

    // host buffer for the model output (logits and embeddings)
    ggml_backend_buffer_ptr buf_output;

    // decode output (2-dimensional array: [n_outputs][n_vocab])
    size_t logits_size = 0; // capacity (of floats) for logits
    float * logits = nullptr;

    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
    size_t output_size = 0; // capacity (of tokens positions) for the output buffers
    int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

    bool logits_all = false;
    bool need_reserve = false;

    // embeddings output (2-dimensional array: [n_outputs][n_embd])
    // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
    size_t embd_size = 0; // capacity (of floats) for embeddings
    float * embd = nullptr;

    // sequence embeddings output (map of [n_embd] vectors)
    // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
    std::map<llama_seq_id, std::vector<float>> embd_seq;

    virtual std::unique_ptr<batch_manager> prepare_batch(const llama_batch & batch);

    // returns the result of ggml_backend_sched_graph_compute_async execution
    enum ggml_status compute_graph(
            ggml_cgraph * graph,
            bool batched);

    // max token position across all sequences in the current context
    llama_pos pos_max() const;

    // certain implementations could require a padding for the context size
    uint32_t get_ctx_padding(const llama_cparams & cparams) const;

    void prepare_k_shift();
    void prepare_defrag();

    void set_inputs(const llama_ubatch & ubatch);

    // make the outputs have the same order they had in the user-provided batch
    // TODO: maybe remove this
    void reorder_outputs();

    // Make sure enough space is available for outputs.
    // Returns max number of outputs for which space was reserved.
    size_t reserve_outputs(size_t n_outputs);

    // input tensors
    struct ggml_tensor * inp_tokens; // I32 [n_batch]
    struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
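After this hunk, the output buffers, their accessors, and the reorder/reserve helpers no longer need to be declared in the derived context; only the evaluation entry points remain there. A condensed sketch of that split, using heavily simplified stand-in types rather than the real class hierarchy:

#include <cstdint>
#include <vector>

struct context_base {
    virtual ~context_base() = default;

    // shared output state (previously duplicated in the derived class)
    std::vector<float>   logits;
    std::vector<int32_t> output_ids;

    virtual float * get_logits() { return logits.data(); }

    virtual int decode(const std::vector<int32_t> & tokens) = 0;
};

struct context_kv_self : context_base {
    // only the model-evaluation logic remains in the derived class
    int decode(const std::vector<int32_t> & tokens) override {
        logits.assign(8, 0.0f);               // pretend evaluation filled the shared buffer
        output_ids.assign(tokens.size(), -1); // by default no position has an output
        if (!tokens.empty()) {
            output_ids.back() = 0;            // the last token produced output row 0
        }
        return 0;
    }
};

int main() {
    context_kv_self ctx;
    ctx.decode({1, 2, 3});
    return ctx.get_logits() == nullptr; // 0 (success) when the shared buffer exists
}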