memory : rename interface to llama_memory_context_i (#14296)
* memory : rename interface to llama_memory_context_i

ggml-ci

* cont : fix comments

* cont : use "mctx" for referencing a memory context

ggml-ci
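In short, the interface type llama_memory_state_i becomes llama_memory_context_i (and llama_memory_state_ptr becomes llama_memory_context_ptr), and the call sites switch from "mstate" to "mctx". For orientation, below is a condensed sketch of the consumer-side lifecycle after the rename, assembled only from calls that appear in this diff (init_batch, get_status, apply, get_ubatch, next); it trims the surrounding llama.cpp internals and error handling, so treat it as an illustration rather than a drop-in snippet:

    // Sketch only: mirrors the renamed call pattern in llama_context::decode() below.
    llama_memory_context_ptr mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
    if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return -2; // could not prepare a memory context for this batch
    }

    do {
        const auto & ubatch = mctx->get_ubatch();   // current ubatch

        if (!mctx->apply()) {                       // apply pending memory changes before building the graph
            LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
            break;
        }

        // ... build and compute the graph for this ubatch (process_ubatch / graph_build) ...
    } while (mctx->next());                         // advance to the next ubatch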
@@ -280,8 +280,8 @@ llama_context::llama_context(
 
         // simulate full KV cache
-        const auto mstate = memory->init_full();
-        if (!mstate) {
+        const auto mctx = memory->init_full();
+        if (!mctx) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 
@@ -289,7 +289,7 @@ llama_context::llama_context(
 
         // reserve pp graph first so that buffers are only allocated once
         {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -300,7 +300,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
-            auto * gf = graph_reserve(1, 1, 1, mstate.get());
+            auto * gf = graph_reserve(1, 1, 1, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,7 +311,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
         optimize |= memory_force_optimize;
         memory_force_optimize = false;
 
-        const auto mstate = memory->init_update(this, optimize);
-        switch (mstate->get_status()) {
+        const auto mctx = memory->init_update(this, optimize);
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                     // noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
-        if (!mstate->apply()) {
+        if (!mctx->apply()) {
             LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
     }
 
     // if the memory module did any computation, we have to reserve a new worst-case graph
     {
-        const auto mstate = memory->init_full();
-        if (!mstate) {
-            throw std::runtime_error("failed to initialize memory state");
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory context");
         }
 
         const uint32_t n_seqs = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
-        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
         }
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
-    if (mstate && !mstate->apply()) {
-        LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
-    llama_memory_state_ptr mstate;
+    llama_memory_context_ptr mctx;
 
     while (true) {
-        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-        if (!mstate) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
             return -2;
         }
 
-        switch (mstate->get_status()) {
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
-                    LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
+                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
 
                     return -2;
                 }
@@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     int64_t n_outputs_prev = 0;
 
     do {
-        const auto & ubatch = mstate->get_ubatch();
+        const auto & ubatch = mctx->get_ubatch();
 
         // count the outputs in this ubatch
         {
@@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
-        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         n_outputs_prev += n_outputs;
-    } while (mstate->next());
+    } while (mctx->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1292,7 +1292,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
     auto * gf = graph_init();
-    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
 
     this->n_outputs = save_n_outputs;
 
@@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
-                  ggml_context * ctx,
-                   ggml_cgraph * gf,
-            const llama_ubatch & ubatch,
-                llm_graph_type   gtype,
-    const llama_memory_state_i * mstate) {
+                    ggml_context * ctx,
+                     ggml_cgraph * gf,
+              const llama_ubatch & ubatch,
+                  llm_graph_type   gtype,
+    const llama_memory_context_i * mctx) {
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
@@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
-            /*.mstate      =*/ mstate,
+            /*.mctx        =*/ mctx,
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
@@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter(
 
         uint32_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
-        if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+        auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+        if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
         }
@@ -2056,17 +2056,17 @@ void llama_context::opt_epoch_iter(
 
         uint32_t pos_batch = 0;
         do {
-            const auto & ubatch = mstate->get_ubatch();
+            const auto & ubatch = mctx->get_ubatch();
 
             n_outputs = ubatch.n_tokens;
 
-            if (!mstate->apply()) {
-                LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+            if (!mctx->apply()) {
+                LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
                 break;
             }
 
             auto * gf = graph_init();
-            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
             struct ggml_context * ctx_compute_opt;
             {
@@ -2101,7 +2101,7 @@ void llama_context::opt_epoch_iter(
             ggml_free(ctx_compute_opt);
 
             pos_batch += ubatch.n_tokens;
-        } while (mstate->next());
+        } while (mctx->next());
     }
 }