context : abstract input

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-13 15:53:15 +02:00
parent 107d1e2c32
commit ed3cb55abe
2 changed files with 323 additions and 316 deletions

View File

@@ -90,6 +90,8 @@ struct llama_context : public llama_graph_i {
ggml_cgraph * graph,
bool batched);
virtual void input_set(const llama_ubatch & ubatch);
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
virtual size_t output_reserve(size_t n_outputs);
@@ -204,6 +206,15 @@ protected:
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
// members
const llama_model & model;
@@ -288,6 +299,8 @@ public:
virtual ggml_context_ptr init() override;
virtual void input_set(const llama_ubatch & ubatch) override;
virtual int decode(llama_batch & inp_batch) override;
virtual int encode(llama_batch & inp_batch) override;
@@ -299,16 +312,6 @@ public:
// certain implementations could require a padding for the context size
uint32_t get_ctx_padding(const llama_cparams & cparams) const;
void set_inputs(const llama_ubatch & ubatch);
// input tensors
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
// === unified KV cache ===
llama_kv_cache kv_self;