mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-29 05:33:37 -04:00
llama : reuse compute graphs (#14482)
* llama : reuse compute graphs ggml-ci * llama-bench : add graph reuse parameter ggml-ci * cont : remove the parameter and the sched resets ggml-ci * graph : rename update() to can_reuse() ggml-ci * params : remove is_same() ggml-ci * graph : set res->params in llm_graph_context constructor ggml-ci * graph : avoid set_max_nodes in llm_graph_result ggml-ci * kv-cache : reuse llama_context's graph result instance ggml-ci * context : reset the previous graph result upon memory updates ggml-ci * batch : llama_ubatch now carries its data instead of pointing to balloc ggml-ci * merge : fix build ggml-ci * graph : fix can_reuse() checks when flash-attention is disabled * graph : move llm_graph_result impl in source file + debug env ggml-ci
This commit is contained in:
@@ -105,7 +105,7 @@ llama_context::llama_context(
|
||||
|
||||
{
|
||||
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
|
||||
const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
|
||||
const bool supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
|
||||
|
||||
if (!supports_set_rows && !cparams.kv_unified) {
|
||||
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
|
||||
@@ -238,8 +238,8 @@ llama_context::llama_context(
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
|
||||
|
||||
// buffer used to store the computation graph and the tensor meta data
|
||||
buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
|
||||
gf_res_prev.reset(new llm_graph_result(max_nodes));
|
||||
gf_res_reserve.reset(new llm_graph_result(max_nodes));
|
||||
|
||||
// TODO: move these checks to ggml_backend_sched
|
||||
// enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
|
||||
@@ -403,10 +403,6 @@ ggml_backend_sched_t llama_context::get_sched() const {
|
||||
return sched.get();
|
||||
}
|
||||
|
||||
ggml_context * llama_context::get_ctx_compute() const {
|
||||
return ctx_compute.get();
|
||||
}
|
||||
|
||||
uint32_t llama_context::n_ctx() const {
|
||||
return cparams.n_ctx;
|
||||
}
|
||||
@@ -478,6 +474,11 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||
}
|
||||
}
|
||||
|
||||
// reset the previous graph result to make sure that it won't be reused
|
||||
// TODO: change the mctx->apply() to return information if a graph reserve is needed
|
||||
// reset the graph result only if the memory module did reset the scheduler
|
||||
gf_res_prev->reset();
|
||||
|
||||
if (!mctx->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
||||
}
|
||||
@@ -693,38 +694,59 @@ bool llama_context::apply_adapter_cvec(
|
||||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||
llm_graph_result_i * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||
if (mctx && !mctx->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
auto * res = gf_res_prev.get();
|
||||
auto * gf = res->get_gf();
|
||||
|
||||
// the new graph parameters
|
||||
// in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
|
||||
const auto gparams = graph_params(res, ubatch, mctx, gtype);
|
||||
|
||||
if (res->can_reuse(gparams)) {
|
||||
//LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
|
||||
|
||||
n_reused++;
|
||||
} else {
|
||||
res->reset();
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
//const auto t_start_us = ggml_time_us();
|
||||
|
||||
gf = model.build_graph(gparams);
|
||||
|
||||
//LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
|
||||
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
|
||||
ret = GGML_STATUS_ALLOC_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
|
||||
ret = GGML_STATUS_FAILED;
|
||||
return nullptr;
|
||||
// set the input data for the input tensors
|
||||
{
|
||||
//const auto t_start_us = ggml_time_us();
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
|
||||
//LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
|
||||
}
|
||||
|
||||
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
|
||||
|
||||
if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
|
||||
ret = GGML_STATUS_ALLOC_FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
res->set_inputs(&ubatch);
|
||||
|
||||
const auto status = graph_compute(gf, ubatch.n_tokens > 1);
|
||||
const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
|
||||
if (status != GGML_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
|
||||
ret = status;
|
||||
@@ -785,9 +807,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
|
||||
n_outputs = n_tokens;
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
const auto causal_attn_org = cparams.causal_attn;
|
||||
|
||||
// always use non-causal attention for encoder graphs
|
||||
@@ -796,7 +815,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
cparams.causal_attn = false;
|
||||
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
|
||||
|
||||
cparams.causal_attn = causal_attn_org;
|
||||
|
||||
@@ -872,10 +891,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||
}
|
||||
}
|
||||
|
||||
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
||||
// overlap with device computation.
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
// TODO: hacky solution
|
||||
if (model.arch == LLM_ARCH_T5 && t_embd) {
|
||||
//cross.t_embd = t_embd;
|
||||
@@ -1033,11 +1048,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||
@@ -1218,10 +1230,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
// wait for the computation to finish (automatically done when obtaining the model output)
|
||||
//synchronize();
|
||||
|
||||
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
||||
// overlap with device computation.
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -1303,20 +1311,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||
// graph
|
||||
//
|
||||
|
||||
int32_t llama_context::graph_max_nodes() const {
|
||||
return std::max<int32_t>(65536, 5*model.n_tensors());
|
||||
uint32_t llama_context::graph_max_nodes() const {
|
||||
return std::max<uint32_t>(65536u, 5u*model.n_tensors());
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_context::graph_init() {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ buf_compute_meta.size(),
|
||||
/*.mem_buffer =*/ buf_compute_meta.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
ctx_compute.reset(ggml_init(params));
|
||||
|
||||
return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
|
||||
llm_graph_result * llama_context::get_gf_res_reserve() const {
|
||||
return static_cast<llm_graph_result *>(gf_res_reserve.get());
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
|
||||
@@ -1329,6 +1329,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||
LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
// when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that
|
||||
gf_res_prev->reset();
|
||||
|
||||
// store the n_outputs as it is, and restore it afterwards
|
||||
// TODO: not sure if needed, might simplify in the future by removing this
|
||||
const auto save_n_outputs = this->n_outputs;
|
||||
@@ -1338,18 +1343,16 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
|
||||
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
|
||||
auto * res = gf_res_reserve.get();
|
||||
|
||||
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
res->reset();
|
||||
|
||||
auto * gf = model.build_graph(gparams);
|
||||
|
||||
this->n_outputs = save_n_outputs;
|
||||
|
||||
if (!res) {
|
||||
LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
|
||||
// initialize scheduler with the specified graph
|
||||
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
|
||||
@@ -1359,28 +1362,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
|
||||
return gf;
|
||||
}
|
||||
|
||||
llm_graph_result_ptr llama_context::graph_build(
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_type gtype,
|
||||
const llama_memory_context_i * mctx) {
|
||||
return model.build_graph(
|
||||
{
|
||||
/*.ctx =*/ ctx,
|
||||
/*.arch =*/ model.arch,
|
||||
/*.hparams =*/ model.hparams,
|
||||
/*.cparams =*/ cparams,
|
||||
/*.ubatch =*/ ubatch,
|
||||
/*.sched =*/ sched.get(),
|
||||
/*.backend_cpu =*/ backend_cpu,
|
||||
/*.cvec =*/ &cvec,
|
||||
/*.loras =*/ &loras,
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
/*.cb =*/ graph_get_cb(),
|
||||
}, gf, gtype);
|
||||
llm_graph_params llama_context::graph_params(
|
||||
llm_graph_result_i * res,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_memory_context_i * mctx,
|
||||
llm_graph_type gtype) const {
|
||||
return {
|
||||
/*.arch =*/ model.arch,
|
||||
/*.hparams =*/ model.hparams,
|
||||
/*.cparams =*/ cparams,
|
||||
/*.ubatch =*/ ubatch,
|
||||
/*.gtype =*/ gtype,
|
||||
/*.sched =*/ sched.get(),
|
||||
/*.backend_cpu =*/ backend_cpu,
|
||||
/*.cvec =*/ &cvec,
|
||||
/*.loras =*/ &loras,
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
/*.cb =*/ graph_get_cb(),
|
||||
/*.res =*/ res,
|
||||
};
|
||||
}
|
||||
|
||||
ggml_status llama_context::graph_compute(
|
||||
@@ -1958,6 +1960,7 @@ llama_perf_context_data llama_context::perf_get_data() const {
|
||||
data.t_eval_ms = 1e-3 * t_eval_us;
|
||||
data.n_p_eval = std::max(1, n_p_eval);
|
||||
data.n_eval = std::max(1, n_eval);
|
||||
data.n_reused = std::max(0, n_reused);
|
||||
|
||||
return data;
|
||||
}
|
||||
@@ -1966,6 +1969,7 @@ void llama_context::perf_reset() {
|
||||
t_start_us = ggml_time_us();
|
||||
t_eval_us = n_eval = 0;
|
||||
t_p_eval_us = n_p_eval = 0;
|
||||
n_reused = 0;
|
||||
}
|
||||
|
||||
//
|
||||
@@ -2092,8 +2096,13 @@ void llama_context::opt_epoch_iter(
|
||||
break;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
|
||||
auto * res = gf_res_prev.get();
|
||||
|
||||
const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
|
||||
|
||||
res->reset();
|
||||
|
||||
auto * gf = model.build_graph(gparams);
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
@@ -2836,6 +2845,7 @@ void llama_perf_context_print(const llama_context * ctx) {
|
||||
LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
|
||||
LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
|
||||
LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused);
|
||||
}
|
||||
|
||||
void llama_perf_context_reset(llama_context * ctx) {
|
||||
|
Reference in New Issue
Block a user