graph : add comments

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-28 21:13:08 +02:00
parent 0f7daa9d1b
commit 624f7bd03b
3 changed files with 52 additions and 16 deletions

View File

@ -101,6 +101,7 @@ void llama_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
} }
} }
// note: this does not depend on the context and can technically be moved to llama-model.cpp
class llama_graph_input_attn_base : public llama_graph_input_attn_i { class llama_graph_input_attn_base : public llama_graph_input_attn_i {
public: public:
llama_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) : llama_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) :

View File

@ -19,6 +19,14 @@ ggml_tensor * llama_graph_input_attn_i::get_kq_mask_cross() {
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {} llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
llama_graph_input_ptr llama_graph_i::build_inp_cross_embd(
ggml_context * ctx0) const {
GGML_UNUSED(ctx0);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
ggml_tensor * llama_graph_i::build_attn( ggml_tensor * llama_graph_i::build_attn(
llama_graph_input_attn_i * inp, llama_graph_input_attn_i * inp,
ggml_context * ctx0, ggml_context * ctx0,
@ -67,14 +75,6 @@ ggml_tensor * llama_graph_i::build_attn_cross(
return nullptr; return nullptr;
} }
llama_graph_input_ptr llama_graph_i::build_inp_cross_embd(
ggml_context * ctx0) const {
GGML_UNUSED(ctx0);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
llama_graph_input_ptr llama_graph_i::build_inp_s_copy ( llama_graph_input_ptr llama_graph_i::build_inp_s_copy (
ggml_context * ctx0) const { ggml_context * ctx0) const {
GGML_UNUSED(ctx0); GGML_UNUSED(ctx0);

View File

@ -10,32 +10,49 @@
struct ggml_cgraph; struct ggml_cgraph;
struct ggml_context; struct ggml_context;
struct ggml_tensor; struct ggml_tensor;
struct ggml_backend_buffer;
struct llama_ubatch; struct llama_ubatch;
// certain models (typically multi-modal) can produce different types of graphs
// the llama_context specifies which type of graph it needs through the llama_graph_i::type member
enum llama_graph_type { enum llama_graph_type {
LLAMA_GRAPH_TYPE_DEFAULT, LLAMA_GRAPH_TYPE_DEFAULT,
LLAMA_GRAPH_TYPE_ENCODER, LLAMA_GRAPH_TYPE_ENCODER,
LLAMA_GRAPH_TYPE_DECODER, LLAMA_GRAPH_TYPE_DECODER,
}; };
// //
// llama_graph_input // llama_graph_input
// //
// denotes an input to the graph
// typically, the data of these objects is populated based on the contents of the current llama_ubatch:
//
// - llama_graph_input_pos
// - llama_graph_input_out_ids
// - etc.
//
// some inputs require context-specific data (e.g. KV cache) - such inputs are defined for the specific llama_context:
//
// - llama_graph_input_embd (can apply lora)
// - llama_graph_input_attn_kv_self (requires KV cache instance)
// - etc.
//
class llama_graph_input_i { class llama_graph_input_i {
public: public:
virtual ~llama_graph_input_i() = default; virtual ~llama_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0; virtual void set_input(const llama_ubatch * ubatch) = 0;
// by default, we produce a single input tensor, but some children could produce more // by default, we produce a single input tensor, but some implementations could produce more
ggml_tensor * cur = nullptr; ggml_tensor * cur = nullptr;
}; };
using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>; using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
class llama_graph_input_attn_i : public llama_graph_input_i { class llama_graph_input_attn_i : public llama_graph_input_i {
public: public:
virtual ~llama_graph_input_attn_i() = default; virtual ~llama_graph_input_attn_i() = default;
@ -47,10 +64,17 @@ public:
using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>; using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
// //
// llama_graph_result // llama_graph_result
// //
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data, by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extact the relevant data, based on the compute parameters
class llama_graph_result_i { class llama_graph_result_i {
public: public:
virtual ~llama_graph_result_i() = default; virtual ~llama_graph_result_i() = default;
@ -64,9 +88,9 @@ public:
using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>; using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
class llama_graph_result : public llama_graph_result_i { class llama_graph_result : public llama_graph_result_i {
public: public:
llama_graph_result() = default;
virtual ~llama_graph_result() = default; virtual ~llama_graph_result() = default;
ggml_tensor * get_logits() override { return t_logits; } ggml_tensor * get_logits() override { return t_logits; }
@ -91,10 +115,19 @@ public:
std::vector<llama_graph_input_ptr> inputs; std::vector<llama_graph_input_ptr> inputs;
}; };
// //
// llama_graph // llama_graph
// //
// this interface defines an API for building graphs by abstracting some high-level concepts such as attention, lora, etc.
// functionality that is trivial and does not rely on the llama_context should be directly implemented in llm_build_context
// other context-specific functionality should be declared here and implemented in the llama_context variations
//
// the main goal of this interface is to separate the llama_context specifics from the graph building logic
// this allows to have cleaner model architecture definitions while being able to overload certain complex
// functionality in order to fit different use cases and/or explore new implementations and ideas
// note: keep all methods const // note: keep all methods const
// TODO: can become more granular in the future // TODO: can become more granular in the future
class llama_graph_i { class llama_graph_i {
@ -112,6 +145,10 @@ private:
public: public:
virtual int32_t get_n_outputs() const = 0; virtual int32_t get_n_outputs() const = 0;
//
// context-specific API
//
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
virtual void build_cb( virtual void build_cb(
ggml_tensor * cur, ggml_tensor * cur,
@ -141,8 +178,6 @@ public:
// rope factors based on the current context size // rope factors based on the current context size
virtual ggml_tensor * build_rope_factors(int il) const = 0; virtual ggml_tensor * build_rope_factors(int il) const = 0;
// graph build API (context-specific)
// input embeddings with optional lora // input embeddings with optional lora
virtual llama_graph_input_ptr build_inp_embd( virtual llama_graph_input_ptr build_inp_embd(
ggml_context * ctx0, ggml_context * ctx0,
@ -154,6 +189,9 @@ public:
ggml_context * ctx0, ggml_context * ctx0,
int32_t n_tokens) const = 0; int32_t n_tokens) const = 0;
virtual llama_graph_input_ptr build_inp_cross_embd(
ggml_context * ctx0) const;
// //
// attention API // attention API
// //
@ -186,9 +224,6 @@ public:
float kq_scale, float kq_scale,
int il) const; int il) const;
virtual llama_graph_input_ptr build_inp_cross_embd(
ggml_context * ctx0) const;
// //
// recurrent API // recurrent API
// //