llama : reuse compute graphs (#14482)

* llama : reuse compute graphs

ggml-ci

* llama-bench : add graph reuse parameter

ggml-ci

* cont : remove the parameter and the sched resets

ggml-ci

* graph : rename update() to can_reuse()

ggml-ci

* params : remove is_same()

ggml-ci

* graph : set res->params in llm_graph_context constructor

ggml-ci

* graph : avoid set_max_nodes in llm_graph_result

ggml-ci

* kv-cache : reuse llama_context's graph result instance

ggml-ci

* context : reset the previous graph result upon memory updates

ggml-ci

* batch : llama_ubatch now carries its data instead of pointing to balloc

ggml-ci

* merge : fix build

ggml-ci

* graph : fix can_reuse() checks when flash-attention is disabled

* graph : move llm_graph_result impl in source file + debug env

ggml-ci
Georgi Gerganov
2025-07-17 19:08:33 +03:00
committed by GitHub
parent 086cf81e88
commit 01612b7409
12 changed files with 548 additions and 289 deletions


@@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
    }
}
+
+bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens);
+    res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens);
+
+    return res;
+}

void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && pos) {
        const int64_t n_tokens = ubatch->n_tokens;
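Each graph input type now answers can_reuse() against the parameters of the next ubatch: a previously built graph may be refilled in place only if the tensors that were allocated for it still fit the new data. As a rough, purely illustrative sketch of that contract (the type below is hypothetical and not part of this patch; it only mirrors the pattern of the real inputs):

// Hypothetical input type illustrating the per-input reuse contract (not from this patch).
struct llm_graph_input_example : public llm_graph_input_i {
    ggml_tensor * inp = nullptr; // [n_tokens], allocated when the graph was built

    void set_input(const llama_ubatch * ubatch) override {
        // refill the already-allocated tensor with data from the new ubatch
    }

    bool can_reuse(const llm_graph_params & params) override {
        bool res = true;

        // reuse only if the cached tensor still matches the incoming token count
        res &= inp && inp->ne[0] == params.ubatch.n_tokens;

        return res;
    }
};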
@@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
    }
}
+
+bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= pos->ne[0] == params.ubatch.n_tokens;
+
+    return res;
+}

void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
    if (ubatch->pos && attn_scale) {
        const int64_t n_tokens = ubatch->n_tokens;
@@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
        const int64_t n_tokens = ubatch->n_tokens;

        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+        GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

        int32_t * data = (int32_t *) pos_bucket->data;
@@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
    }
}
+
+bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) {
+    bool res = true;
+
+    res &= n_outputs == params.n_outputs;
+
+    return res;
+}

void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
        const int64_t n_tokens = ubatch->n_tokens;
@@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
+
+bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= mctx->get_supports_set_rows(); // TODO: tmp
+
+    return res;
+}

void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
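The kq_mask checks compare the padded token dimension rather than n_tokens itself, because the mask was allocated with GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) when the graph was built; ubatches whose token counts pad to the same value keep the mask shape compatible. A standalone sketch of the rounding (the pad value 64 is only an illustrative number, not necessarily the value of GGML_KQ_MASK_PAD):

#include <cstdio>

// same formula as GGML_PAD(x, n) in ggml.h: round x up to a multiple of n (n a power of two)
#define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    // with a pad of 64, token counts 1..64 all map to 64, 65..128 map to 128, etc.
    printf("%d %d %d\n", PAD(1, 64), PAD(64, 64), PAD(65, 64)); // prints: 64 64 128
    return 0;
}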
@@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}
+
+bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+    //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+    res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
+    res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
+    res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD);
+
+    res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp
+
+    return res;
+}

void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
    GGML_ASSERT(cross_kq_mask);
@@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
    const int64_t n_tokens = ubatch->n_tokens;

    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
-    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+    GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing

    float * data = (float *) cross_kq_mask->data;
@@ -340,6 +407,83 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
    inp_rs->set_input(ubatch);
}
+
+//
+// llm_graph_result
+//
+
+llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) {
+    reset();
+
+    const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG");
+    debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0;
+}
+
+int64_t llm_graph_result::get_max_nodes() const {
+    return max_nodes;
+}
+
+void llm_graph_result::reset() {
+    t_tokens      = nullptr;
+    t_logits      = nullptr;
+    t_embd        = nullptr;
+    t_embd_pooled = nullptr;
+
+    inputs.clear();
+
+    buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ buf_compute_meta.size(),
+        /*.mem_buffer =*/ buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ctx_compute.reset(ggml_init(params));
+
+    gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false);
+}
+
+void llm_graph_result::set_inputs(const llama_ubatch * ubatch) {
+    for (auto & input : inputs) {
+        input->set_input(ubatch);
+    }
+}
+
+bool llm_graph_result::can_reuse(const llm_graph_params & params) {
+    if (!this->params.allow_reuse(params)) {
+        if (debug > 1) {
+            LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__);
+        }
+
+        return false;
+    }
+
+    if (debug > 1) {
+        LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size());
+    }
+
+    bool res = true;
+
+    for (auto & input : inputs) {
+        const bool cur = input->can_reuse(params);
+
+        LLAMA_LOG_DEBUG(" %s: can_reuse = %d\n", "placeholder", cur);
+
+        res = res && cur;
+    }
+
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res);
+    }
+
+    return res;
+}
+
+llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) {
+    inputs.emplace_back(std::move(input));
+
+    return inputs.back().get();
+}
//
// llm_graph_context
//
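With this block, llm_graph_result owns the compute-meta buffer, the ggml context, the graph and the registered inputs across decode calls, so the caller can keep the whole object alive and rebuild only when can_reuse() fails. Roughly, the intended caller-side flow looks like the sketch below; it is a simplified illustration (assuming get_gf() is the accessor for the stored gf), not the actual llama_context code:

// Simplified sketch of the intended usage (not the actual llama_context code).
ggml_cgraph * build_or_reuse(llm_graph_result * res, const llm_graph_params & params, const llama_ubatch & ubatch) {
    if (res->can_reuse(params)) {
        // same topology, compatible shapes: keep the previous graph and only
        // refresh the input tensors with the data of the new ubatch
        res->set_inputs(&ubatch);
    } else {
        // otherwise drop the cached graph and rebuild it inside res->get_ctx(),
        // which registers a fresh set of inputs via res->add_input(...)
        res->reset();
        // ... run the model's graph build code here (llm_graph_context) ...
        res->set_inputs(&ubatch);
    }

    return res->get_gf(); // assumed accessor for the cached gf
}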
@@ -374,7 +518,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    n_ctx_orig (cparams.n_ctx_orig_yarn),
    pooling_type (cparams.pooling_type),
    rope_type (hparams.rope_type),
-    ctx0 (params.ctx),
    sched (params.sched),
    backend_cpu (params.backend_cpu),
    cvec (params.cvec),
@@ -382,7 +525,9 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
    mctx (params.mctx),
    cross (params.cross),
    cb_func (params.cb),
-    res (std::make_unique<llm_graph_result>()) {
+    res (static_cast<llm_graph_result *>(params.res)),
+    ctx0 (res->get_ctx()) {
+    res->params = params;
}
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
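Since ctx0 is now taken from the result object, every tensor a build function creates lives in the result's context, and the matching input object is registered on the same result, which is exactly what set_inputs()/can_reuse() iterate over later. A hedged sketch of that build-time pattern, reusing the hypothetical llm_graph_input_example from the earlier sketch (the real build_inp_* helpers in llama-graph.cpp follow this shape, but this particular function is illustrative, not part of the patch):

// Illustrative build helper (not part of this patch): allocate in ctx0 == res->get_ctx()
// and hand ownership of the input object to the result for later refills.
ggml_tensor * llm_graph_context::build_inp_example() const {
    auto inp = std::make_unique<llm_graph_input_example>(); // hypothetical input type from the earlier sketch

    inp->inp = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
    ggml_set_input(inp->inp);

    ggml_tensor * cur = inp->inp;

    res->add_input(std::move(inp)); // kept alive by llm_graph_result for set_inputs()/can_reuse()

    return cur;
}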
@@ -1127,8 +1272,8 @@ ggml_tensor * llm_graph_context::build_attn(
    const auto & kq_mask = inp->get_kq_mask();

    // [TAG_NO_CACHE_PAD]
-    // TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
-    assert(ubatch.equal_seqs == false);
+    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
+    assert(!ubatch.equal_seqs());

    ggml_tensor * q = q_cur;
    ggml_tensor * k = k_cur;
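The equal_seqs changes in this and the earlier hunks go together with the "batch : llama_ubatch now carries its data instead of pointing to balloc" item in the commit message: equal_seqs turns from a public data member into an accessor, so every call site gains parentheses. Purely as an illustration of that member-to-accessor move (the field names and logic below are assumptions, not the actual llama_ubatch definition):

// Before: callers read a public flag directly.
struct ubatch_before {
    bool equal_seqs;
};

// After: the flag is stored internally and exposed through an accessor, so the
// struct can own its data and keep the invariant in one place.
struct ubatch_after {
    bool equal_seqs() const { return b_equal_seqs != 0; }

    int32_t b_equal_seqs; // hypothetical storage; the real layout may differ
};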