diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 7ba86a2a7..8963b85ca 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -101,6 +101,7 @@ void llama_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+// note: this does not depend on the context and can technically be moved to llama-model.cpp
 class llama_graph_input_attn_base : public llama_graph_input_attn_i {
 public:
     llama_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) :
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 89e311a91..119f1a56f 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -19,6 +19,14 @@ ggml_tensor * llama_graph_input_attn_i::get_kq_mask_cross() {
 
 llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
 
+llama_graph_input_ptr llama_graph_i::build_inp_cross_embd(
+        ggml_context * ctx0) const {
+    GGML_UNUSED(ctx0);
+
+    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
+    return nullptr;
+}
+
 ggml_tensor * llama_graph_i::build_attn(
         llama_graph_input_attn_i * inp,
         ggml_context * ctx0,
@@ -67,14 +75,6 @@ ggml_tensor * llama_graph_i::build_attn_cross(
     return nullptr;
 }
 
-llama_graph_input_ptr llama_graph_i::build_inp_cross_embd(
-        ggml_context * ctx0) const {
-    GGML_UNUSED(ctx0);
-
-    LLAMA_LOG_ERROR("%s: not implemented\n", __func__);
-    return nullptr;
-}
-
 llama_graph_input_ptr llama_graph_i::build_inp_s_copy (
         ggml_context * ctx0) const {
     GGML_UNUSED(ctx0);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 343d4a077..2d62c674f 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -10,32 +10,49 @@
 struct ggml_cgraph;
 struct ggml_context;
 struct ggml_tensor;
-struct ggml_backend_buffer;
 
 struct llama_ubatch;
 
+// certain models (typically multi-modal) can produce different types of graphs
+// the llama_context specifies which type of graph it needs through the llama_graph_i::type member
 enum llama_graph_type {
     LLAMA_GRAPH_TYPE_DEFAULT,
     LLAMA_GRAPH_TYPE_ENCODER,
     LLAMA_GRAPH_TYPE_DECODER,
 };
 
+
 //
 // llama_graph_input
 //
 
+// denotes an input to the graph
+// typically, the data of these objects is populated based on the contents of the current llama_ubatch:
+//
+//   - llama_graph_input_pos
+//   - llama_graph_input_out_ids
+//   - etc.
+//
+// some inputs require context-specific data (e.g. KV cache) - such inputs are defined for the specific llama_context:
+//
+//   - llama_graph_input_embd         (can apply lora)
+//   - llama_graph_input_attn_kv_self (requires KV cache instance)
+//   - etc.
+//
+
 class llama_graph_input_i {
 public:
     virtual ~llama_graph_input_i() = default;
 
     virtual void set_input(const llama_ubatch * ubatch) = 0;
 
-    // by default, we produce a single input tensor, but some children could produce more
+    // by default, we produce a single input tensor, but some implementations could produce more
     ggml_tensor * cur = nullptr;
 };
 
 using llama_graph_input_ptr = std::shared_ptr<llama_graph_input_i>;
 
+
 class llama_graph_input_attn_i : public llama_graph_input_i {
 public:
     virtual ~llama_graph_input_attn_i() = default;
@@ -47,10 +64,17 @@ public:
 
 using llama_graph_input_attn_ptr = std::shared_ptr<llama_graph_input_attn_i>;
 
+
 //
 // llama_graph_result
 //
 
+// these objects deliver the result from the graph build process back to the llama_context
+// note that the input tensors created for the graph are referenced here - the goal is to be able to populate their
+// specific data, by calling the set_inputs() method
+// along with the input tensors, the object also provides commonly used output tensors, such as logits, embeddings, etc.
+// these are used by the llama_context to extract the relevant data, based on the compute parameters
+
 class llama_graph_result_i {
 public:
     virtual ~llama_graph_result_i() = default;
@@ -64,9 +88,9 @@ public:
 
 using llama_graph_result_ptr = std::unique_ptr<llama_graph_result_i>;
 
+
 class llama_graph_result : public llama_graph_result_i {
 public:
-    llama_graph_result() = default;
     virtual ~llama_graph_result() = default;
 
     ggml_tensor * get_logits() override { return t_logits; }
@@ -91,10 +115,19 @@ public:
     std::vector<llama_graph_input_ptr> inputs;
 };
 
+
 //
 // llama_graph
 //
 
+// this interface defines an API for building graphs by abstracting some high-level concepts such as attention, lora, etc.
+// functionality that is trivial and does not rely on the llama_context should be directly implemented in llm_build_context
+// other context-specific functionality should be declared here and implemented in the llama_context variations
+//
+// the main goal of this interface is to separate the llama_context specifics from the graph building logic
+// this allows for cleaner model architecture definitions while being able to overload certain complex
+// functionality in order to fit different use cases and/or explore new implementations and ideas
+
 // note: keep all methods const
 // TODO: can become more granular in the future
 class llama_graph_i {
@@ -112,6 +145,10 @@ private:
 public:
     virtual int32_t get_n_outputs() const = 0;
 
+    //
+    // context-specific API
+    //
+
     // callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
     virtual void build_cb(
             ggml_tensor * cur,
@@ -141,8 +178,6 @@ public:
     // rope factors based on the current context size
     virtual ggml_tensor * build_rope_factors(int il) const = 0;
 
-    // graph build API (context-specific)
-
     // input embeddings with optional lora
     virtual llama_graph_input_ptr build_inp_embd(
             ggml_context * ctx0,
@@ -154,6 +189,9 @@ public:
             ggml_context * ctx0,
            int32_t n_tokens) const = 0;
 
+    virtual llama_graph_input_ptr build_inp_cross_embd(
+            ggml_context * ctx0) const;
+
     //
     // attention API
     //
@@ -186,9 +224,6 @@ public:
             float kq_scale,
             int il) const;
 
-    virtual llama_graph_input_ptr build_inp_cross_embd(
-            ggml_context * ctx0) const;
-
     //
     // recurrent API
     //
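
For reference, here is a minimal sketch of how a ubatch-driven input (such as the llama_graph_input_pos mentioned in the header comments) could implement llama_graph_input_i. It assumes the internal llama.cpp headers from this PR (llama-graph.h, llama-batch.h) plus ggml-backend.h are included; the class name, constructor argument and body are illustrative, not the PR's exact code.

// illustrative sketch only - not part of the patch
class llama_graph_input_pos_example : public llama_graph_input_i {
public:
    explicit llama_graph_input_pos_example(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
    virtual ~llama_graph_input_pos_example() = default;

    // invoked via llama_graph_result_i::set_inputs() once the graph has been built:
    // copies the token positions of the current ubatch into the input tensor `cur`
    void set_input(const llama_ubatch * ubatch) override {
        if (ubatch->pos && cur) {
            const int64_t n_tokens = ubatch->n_tokens;

            ggml_backend_tensor_set(cur, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(cur));
        }
    }

    const int64_t n_pos_per_token = 1;
};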