llama : reuse compute graphs (#14482)

* llama : reuse compute graphs

ggml-ci

* llama-bench : add graph reuse parameter

ggml-ci

* cont : remove the parameter and the sched resets

ggml-ci

* graph : rename update() to can_reuse()

ggml-ci

* params : remove is_same()

ggml-ci

* graph : set res->params in llm_graph_context constructor

ggml-ci

* graph : avoid set_max_nodes in llm_graph_result

ggml-ci

* kv-cache : reuse llama_context's graph result instance

ggml-ci

* context : reset the previous graph result upon memory updates

ggml-ci

* batch : llama_ubatch now carries its data instead of pointing to balloc

ggml-ci

* merge : fix build

ggml-ci

* graph : fix can_reuse() checks when flash-attention is disabled

* graph : move llm_graph_result impl in source file + debug env

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-07-17 19:08:33 +03:00
committed by GitHub
parent 086cf81e88
commit 01612b7409
12 changed files with 548 additions and 289 deletions

View File

@@ -1,6 +1,7 @@
#pragma once
#include "llama-arch.h"
#include "llama-batch.h"
#include "llama-hparams.h"
#include "llama-adapter.h"
@@ -14,7 +15,6 @@ struct ggml_cgraph;
struct ggml_context;
struct ggml_tensor;
struct llama_ubatch;
struct llama_cparams;
struct llama_memory_context_i;
@@ -69,6 +69,8 @@ struct llama_cross {
std::vector<std::set<llama_seq_id>> seq_ids_enc;
};
struct llm_graph_params;
//
// llm_graph_input
//
@@ -78,11 +80,19 @@ public:
virtual ~llm_graph_input_i() = default;
virtual void set_input(const llama_ubatch * ubatch) = 0;
// return true if the resulting input tensors using the provided graph parameters would be
// the same as the previous input tensors that we have currently stored in the object
virtual bool can_reuse(const llm_graph_params & params) {
// returning false here by default will prevent from reusing the graph if the check
// for the input type has not been implemented yet
GGML_UNUSED(params);
return false;
}
};
using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
class llm_graph_input_embd : public llm_graph_input_i {
public:
llm_graph_input_embd() = default;
@@ -90,6 +100,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * tokens = nullptr; // I32 [n_batch]
ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch]
};
@@ -101,6 +113,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * pos = nullptr; // I32 [n_batch]
const uint32_t n_pos_per_embd = 1;
@@ -154,17 +168,19 @@ public:
llm_graph_input_out_ids(
const llama_hparams & hparams,
const llama_cparams & cparams,
int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {}
virtual ~llm_graph_input_out_ids() = default;
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * out_ids; // I32 [n_outputs]
const llama_hparams & hparams;
const llama_cparams & cparams;
const int32_t n_outputs;
const uint32_t n_outputs;
};
class llm_graph_input_mean : public llm_graph_input_i {
@@ -249,6 +265,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
@@ -280,6 +298,8 @@ public:
void set_input(const llama_ubatch * ubatch) override;
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_v_idxs() const { return self_v_idxs; }
ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
@@ -351,65 +371,40 @@ public:
// along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc.
// these are used by the llama_context to extact the relevant data, based on the compute parameters
// TODO: this interface seems redundant - remove it
class llm_graph_result_i {
public:
virtual ~llm_graph_result_i() = default;
virtual ggml_tensor * get_tokens() = 0;
virtual ggml_tensor * get_logits() = 0;
virtual ggml_tensor * get_embd() = 0;
virtual ggml_tensor * get_embd_pooled() = 0;
virtual ggml_tensor * get_tokens() const = 0;
virtual ggml_tensor * get_logits() const = 0;
virtual ggml_tensor * get_embd() const = 0;
virtual ggml_tensor * get_embd_pooled() const = 0;
virtual ggml_cgraph * get_gf() = 0;
virtual ggml_context * get_ctx() = 0;
virtual void reset() = 0;
virtual void set_inputs(const llama_ubatch * ubatch) = 0;
virtual bool can_reuse(const llm_graph_params & params) = 0;
};
using llm_graph_result_ptr = std::unique_ptr<llm_graph_result_i>;
class llm_graph_result : public llm_graph_result_i {
public:
virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() override { return t_tokens; }
ggml_tensor * get_logits() override { return t_logits; }
ggml_tensor * get_embd() override { return t_embd; }
ggml_tensor * get_embd_pooled() override { return t_embd_pooled; }
void set_inputs(const llama_ubatch * ubatch) override {
for (auto & input : inputs) {
input->set_input(ubatch);
}
}
llm_graph_input_i * add_input(llm_graph_input_ptr input) {
inputs.emplace_back(std::move(input));
return inputs.back().get();
}
// important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llm_graph_input_ptr> inputs;
};
//
// llm_graph_context
//
// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
using llm_graph_cb = std::function<void(const llama_ubatch & ubatch, ggml_tensor * cur, const char * name, int il)>;
struct llm_graph_params {
ggml_context * ctx;
llm_arch arch = LLM_ARCH_UNKNOWN;
const llm_arch arch;
llama_hparams hparams;
llama_cparams cparams;
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_ubatch & ubatch;
llama_ubatch ubatch; // note: intentionally make a copy
llm_graph_type gtype;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu;
@@ -421,9 +416,113 @@ struct llm_graph_params {
uint32_t n_outputs;
const llm_graph_cb & cb;
llm_graph_cb cb;
// TODO: temporary
llm_graph_result_i * res;
// return true if the "other" params would result in a graph with the same topology as with the current params
// having the same topology allows us to reuse the graph in some cases
bool allow_reuse(const llm_graph_params & other) const {
// first check the ubatch
bool can_reuse_ubatch =
ubatch.equal_seqs() == other.ubatch.equal_seqs() &&
ubatch.n_tokens == other.ubatch.n_tokens &&
ubatch.n_seq_tokens == other.ubatch.n_seq_tokens &&
ubatch.n_seqs == other.ubatch.n_seqs &&
ubatch.n_seqs_unq == other.ubatch.n_seqs_unq &&
(
(!ubatch.token && !other.ubatch.token) ||
(!ubatch.embd && !other.ubatch.embd)
);
if (can_reuse_ubatch && !ubatch.equal_seqs()) {
if (!ubatch.data) {
// if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and
// therefore we cannot perform the sequence id check. normally should never happen
can_reuse_ubatch = false;
} else {
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
}
}
}
if (!can_reuse_ubatch) {
return false;
}
return
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&
loras == other.loras &&
cross == other.cross &&
n_outputs == other.n_outputs;
}
};
class llm_graph_result : public llm_graph_result_i {
public:
llm_graph_result(int64_t max_nodes);
virtual ~llm_graph_result() = default;
ggml_tensor * get_tokens() const override { return t_tokens; }
ggml_tensor * get_logits() const override { return t_logits; }
ggml_tensor * get_embd() const override { return t_embd; }
ggml_tensor * get_embd_pooled() const override { return t_embd_pooled; }
ggml_cgraph * get_gf() override { return gf; }
ggml_context * get_ctx() override { return ctx_compute.get(); }
int64_t get_max_nodes() const;
void reset() override;
void set_inputs(const llama_ubatch * ubatch) override;
// try to update the existing graph result using the new graph parameters in order to reuse it
// this can only be done if we determine that the resulting graph using the new graph parameters
// would be identical to the existing graph. in that case, we simply have to update the memory
// contexts of the input tensors of the graph and we can reuse it for another computation
// return true if the graph was updated and can be reused
bool can_reuse(const llm_graph_params & params) override;
llm_graph_input_i * add_input(llm_graph_input_ptr input);
// important graph nodes
ggml_tensor * t_tokens = nullptr;
ggml_tensor * t_logits = nullptr;
ggml_tensor * t_embd = nullptr;
ggml_tensor * t_embd_pooled = nullptr;
std::vector<llm_graph_input_ptr> inputs;
ggml_context_ptr ctx_compute;
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_cgraph * gf;
int64_t max_nodes;
// keep a copy of the previous graph parameters
// we will use this to determine whether the graph can be reused by comparing them with the new parameters
// note: these are updated after constructing the new graph
llm_graph_params params;
// env: LLAMA_GRAPH_RESULT_DEBUG
int debug = 0;
};
//
// llm_graph_context
//
// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
@@ -463,8 +562,6 @@ struct llm_graph_context {
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
ggml_context * ctx0 = nullptr;
ggml_backend_sched_t sched;
ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
@@ -476,7 +573,9 @@ struct llm_graph_context {
const llm_graph_cb & cb_func;
std::unique_ptr<llm_graph_result> res;
llm_graph_result * res;
ggml_context * ctx0 = nullptr;
llm_graph_context(const llm_graph_params & params);
virtual ~llm_graph_context() = default;