graph : add comments

ggml-ci
Georgi Gerganov
2025-02-28 21:13:08 +02:00
parent 0f7daa9d1b
commit 624f7bd03b
3 changed files with 52 additions and 16 deletions


@@ -101,6 +101,7 @@ void llama_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
}
}
// note: this does not depend on the context and can technically be moved to llama-model.cpp
class llama_graph_input_attn_base : public llama_graph_input_attn_i {
public:
llama_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) :


@@ -19,6 +19,14 @@ ggml_tensor * llama_graph_input_attn_i::get_kq_mask_cross() {
llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
llama_graph_input_ptr llama_graph_i::build_inp_cross_embd(
ggml_context * ctx0) const {
GGML_UNUSED(ctx0);
LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
return nullptr;
}
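// [editor's sketch, hypothetical names] a context that does support cross-attention
// overrides the stub above instead of inheriting the error, roughly:
//
//     llama_graph_input_ptr llama_context_enc_dec::build_inp_cross_embd(
//             ggml_context * ctx0) const {
//         auto inp = std::make_shared<llama_graph_input_cross_embd>();
//         // hypothetical shape: one embedding per encoder output token
//         inp->cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_outputs_enc);
//         return inp;
//     }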
ggml_tensor * llama_graph_i::build_attn(
llama_graph_input_attn_i * inp,
ggml_context * ctx0,
@@ -67,14 +75,6 @@ ggml_tensor * llama_graph_i::build_attn_cross(
return nullptr;
}
llama_graph_input_ptr llama_graph_i::build_inp_s_copy (
ggml_context * ctx0) const {
GGML_UNUSED(ctx0);


@@ -10,32 +10,49 @@
struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct ggml_backend_buffer;
struct llama_ubatch;
// certain models (typically multi-modal) can produce different types of graphs
// the llama_context specifies which type of graph it needs through the llama_graph_i::type member
enum llama_graph_type {
LLAMA_GRAPH_TYPE_DEFAULT,
LLAMA_GRAPH_TYPE_ENCODER,
LLAMA_GRAPH_TYPE_DECODER,
};
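// [editor's sketch] an encoder-decoder pipeline would request one graph of each
// type, e.g. (is_encoding is a hypothetical flag):
//
//     const llama_graph_type gtype = is_encoding
//         ? LLAMA_GRAPH_TYPE_ENCODER
//         : LLAMA_GRAPH_TYPE_DECODER;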
//
// llama_graph_input
//
// denotes an input to the graph
// typically, the data of these objects is populated based on the contents of the current llama_ubatch:
//
// - llama_graph_input_pos
// - llama_graph_input_out_ids
// - etc.
//
// some inputs require context-specific data (e.g. KV cache) - such inputs are defined for the specific llama_context:
//
// - llama_graph_input_embd (can apply lora)
// - llama_graph_input_attn_kv_self (requires KV cache instance)
// - etc.
//
class llama_graph_input_i {
public:
virtual ~llama_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
// by default, we produce a single input tensor, but some implementations could produce more
ggml_tensor * cur = nullptr;
};
using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
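// [editor's sketch, modeled on the inputs listed above; exact upstream members may
// differ] a ubatch-driven input copies its data into `cur` during set_input():
//
//     class llama_graph_input_pos : public llama_graph_input_i {
//     public:
//         void set_input(const llama_ubatch * ubatch) override {
//             if (cur && ubatch->pos) {
//                 ggml_backend_tensor_set(cur, ubatch->pos, 0,
//                         ubatch->n_tokens*ggml_element_size(cur));
//             }
//         }
//     };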
class llama_graph_input_attn_i : public llama_graph_input_i {
public:
virtual ~llama_graph_input_attn_i() = default;
@@ -47,10 +64,17 @@ public:
using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
//
// llama_graph_result
//
// these objects deliver the result from the graph build process back to the llama_context
// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
// specific data by calling the set_inputs() method
// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extract the relevant data, based on the compute parameters
class llama_graph_result_i {
public:
virtual ~llama_graph_result_i() = default;
@@ -64,9 +88,9 @@ public:
using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
class llama_graph_result : public llama_graph_result_i {
public:
llama_graph_result() = default;
virtual ~llama_graph_result() = default;
ggml_tensor * get_logits() override { return t_logits; }
@@ -91,10 +115,19 @@ public:
std::vector<llama_graph_input_ptr> inputs;
};
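// [editor's sketch] given the `inputs` vector above, set_inputs() can simply forward
// the ubatch to every stored input object:
//
//     void set_inputs(const llama_ubatch * ubatch) override {
//         for (auto & input : inputs) {
//             input->set_input(ubatch);
//         }
//     }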
//
// llama_graph
//
// this interface defines an API for building graphs by abstracting some high-level concepts such as attention, lora, etc.
// functionality that is trivial and does not rely on the llama_context should be directly implemented in llm_build_context
// other context-specific functionality should be declared here and implemented in the llama_context variations
//
// the main goal of this interface is to separate the llama_context specifics from the graph building logic
// this allows for cleaner model architecture definitions while making it possible to overload certain complex
// functionality in order to fit different use cases and/or explore new implementations and ideas
// note: keep all methods const
// TODO: can become more granular in the future
class llama_graph_i {
@@ -112,6 +145,10 @@ private:
public:
virtual int32_t get_n_outputs() const = 0;
//
// context-specific API
//
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
virtual void build_cb(
ggml_tensor * cur,
@@ -141,8 +178,6 @@ public:
// rope factors based on the current context size
virtual ggml_tensor * build_rope_factors(int il) const = 0;
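// [editor's sketch; surrounding variables are assumptions] model code feeds the
// per-layer factors straight into the rope op:
//
//     ggml_tensor * rope_factors = build_rope_factors(il);
//     Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors,
//             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
//             ext_factor, attn_factor, beta_fast, beta_slow);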
// graph build API (context-specific)
// input embeddings with optional lora
virtual llama_graph_input_ptr build_inp_embd(
ggml_context * ctx0,
@@ -154,6 +189,9 @@ public:
ggml_context * ctx0,
int32_t n_tokens) const = 0;
virtual llama_graph_input_ptr build_inp_cross_embd(
ggml_context * ctx0) const;
//
// attention API
//
@@ -186,9 +224,6 @@ public:
float kq_scale,
int il) const;
//
// recurrent API
//
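// [editor's sketch; parameter lists abbreviated, variable names are assumptions]
// the overall call pattern from a model build function, with all context specifics
// hidden behind this interface:
//
//     auto inp_embd = build_inp_embd(ctx0, /* ... */);
//     for (int il = 0; il < n_layer; ++il) {
//         ggml_tensor * rope_factors = build_rope_factors(il);
//         // ... compute Q, K, V for layer il ...
//         cur = build_attn(inp_attn, ctx0, /* ... */, kq_scale, il);
//         build_cb(cur, /* ... */);
//     }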