memory : rename interface to llama_memory_context_i (#14296)

* memory : rename interface to llama_memory_context_i

ggml-ci

* cont : fix comments

* cont : use "mctx" for referencing a memory context

ggml-ci
Author: Georgi Gerganov
Date: 2025-06-21 08:03:46 +03:00
Committed by: GitHub
Parent: b23fa0b3f4
Commit: 692e3cdd0a
14 changed files with 339 additions and 341 deletions
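For orientation, the shape of the renamed interface, as it can be inferred from the hunks below, is roughly the following. This is an illustrative sketch only, not the actual header (which is not part of this diff); the forward declarations and the exact member list are assumptions.

#include <memory>

struct llama_ubatch;            // assumed forward declaration; defined elsewhere in llama.cpp
enum llama_memory_status : int; // assumed; values such as LLAMA_MEMORY_STATUS_SUCCESS come from the public header

// sketch: the renamed interface, previously llama_memory_state_i
struct llama_memory_context_i {
    virtual ~llama_memory_context_i() = default;

    virtual bool next()  = 0; // advance to the next ubatch in the batch
    virtual bool apply() = 0; // apply the pending memory operations

    virtual llama_memory_status  get_status() const = 0;
    virtual const llama_ubatch & get_ubatch() const = 0;
};

// sketch: the accompanying pointer alias, previously llama_memory_state_ptr
using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;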


@@ -280,8 +280,8 @@ llama_context::llama_context(
 // simulate full KV cache
-const auto mstate = memory->init_full();
-if (!mstate) {
+const auto mctx = memory->init_full();
+if (!mctx) {
 throw std::runtime_error("failed to initialize KV cache");
 }
@@ -289,7 +289,7 @@ llama_context::llama_context(
 // reserve pp graph first so that buffers are only allocated once
 {
-auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
 if (!gf) {
 throw std::runtime_error("failed to allocate compute pp buffers");
 }
@@ -300,7 +300,7 @@ llama_context::llama_context(
 // reserve with tg graph to get the number of splits and nodes
 {
-auto * gf = graph_reserve(1, 1, 1, mstate.get());
+auto * gf = graph_reserve(1, 1, 1, mctx.get());
 if (!gf) {
 throw std::runtime_error("failed to allocate compute tg buffers");
 }
@@ -311,7 +311,7 @@ llama_context::llama_context(
 // reserve again with pp graph to avoid ggml-alloc reallocations during inference
 {
-auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
 if (!gf) {
 throw std::runtime_error("failed to allocate compute pp buffers");
 }
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
 optimize |= memory_force_optimize;
 memory_force_optimize = false;
-const auto mstate = memory->init_update(this, optimize);
-switch (mstate->get_status()) {
+const auto mctx = memory->init_update(this, optimize);
+switch (mctx->get_status()) {
 case LLAMA_MEMORY_STATUS_SUCCESS:
 {
 // noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
 }
 }
-if (!mstate->apply()) {
+if (!mctx->apply()) {
 LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
 }
 }
 // if the memory module did any computation, we have to reserve a new worst-case graph
 {
-const auto mstate = memory->init_full();
-if (!mstate) {
-throw std::runtime_error("failed to initialize memory state");
+const auto mctx = memory->init_full();
+if (!mctx) {
+throw std::runtime_error("failed to initialize memory context");
 }
 const uint32_t n_seqs = cparams.n_seq_max;
 const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
+auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
 if (!gf) {
 LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
 }
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
 return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
-llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) {
-if (mstate && !mstate->apply()) {
-LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__);
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+if (mctx && !mctx->apply()) {
+LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
 ret = GGML_STATUS_FAILED;
 return nullptr;
 }
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 return nullptr;
 }
-auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate);
+auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
 if (!res) {
 LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
 ret = GGML_STATUS_FAILED;
@@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
 // handle any pending defrags/shifts
 kv_self_update(false);
-llama_memory_state_ptr mstate;
+llama_memory_context_ptr mctx;
 while (true) {
-mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
-if (!mstate) {
+mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+if (!mctx) {
 return -2;
 }
-switch (mstate->get_status()) {
+switch (mctx->get_status()) {
 case LLAMA_MEMORY_STATUS_SUCCESS:
 {
 } break;
 case LLAMA_MEMORY_STATUS_NO_UPDATE:
 {
-LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
+LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
 return -2;
 }
@@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 int64_t n_outputs_prev = 0;
 do {
-const auto & ubatch = mstate->get_ubatch();
+const auto & ubatch = mctx->get_ubatch();
 // count the outputs in this ubatch
 {
@@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 ggml_status status;
-const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
+const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 if (!res) {
 // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1126,7 +1126,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 }
 n_outputs_prev += n_outputs;
-} while (mstate->next());
+} while (mctx->next());
 // set to total number of outputs in the batch, for use in llama_get_logits_ith
 n_outputs = n_outputs_all;
@@ -1292,7 +1292,7 @@ ggml_cgraph * llama_context::graph_init() {
 return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
 LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 if (n_tokens % n_seqs != 0) {
@@ -1312,7 +1312,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 auto * gf = graph_init();
-auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
+auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
 this->n_outputs = save_n_outputs;
@@ -1333,11 +1333,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 llm_graph_result_ptr llama_context::graph_build(
 ggml_context * ctx,
 ggml_cgraph * gf,
 const llama_ubatch & ubatch,
 llm_graph_type gtype,
-const llama_memory_state_i * mstate) {
+const llama_memory_context_i * mctx) {
 return model.build_graph(
 {
 /*.ctx =*/ ctx,
@@ -1349,7 +1349,7 @@ llm_graph_result_ptr llama_context::graph_build(
 /*.backend_cpu =*/ backend_cpu,
 /*.cvec =*/ &cvec,
 /*.loras =*/ &loras,
-/*.mstate =*/ mstate,
+/*.mctx =*/ mctx,
 /*.cross =*/ &cross,
 /*.n_outputs =*/ n_outputs,
 /*.cb =*/ graph_get_cb(),
@@ -2042,8 +2042,8 @@ void llama_context::opt_epoch_iter(
 uint32_t n_outputs_all = n_tokens_all;
-auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
-if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
+auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
 LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
 break;
 }
@@ -2056,17 +2056,17 @@ void llama_context::opt_epoch_iter(
 uint32_t pos_batch = 0;
 do {
-const auto & ubatch = mstate->get_ubatch();
+const auto & ubatch = mctx->get_ubatch();
 n_outputs = ubatch.n_tokens;
-if (!mstate->apply()) {
-LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
+if (!mctx->apply()) {
+LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
 break;
 }
 auto * gf = graph_init();
-auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
+auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 struct ggml_context * ctx_compute_opt;
 {
@@ -2101,7 +2101,7 @@ void llama_context::opt_epoch_iter(
 ggml_free(ctx_compute_opt);
 pos_batch += ubatch.n_tokens;
-} while (mstate->next());
+} while (mctx->next());
 }
 }
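The hunks above also show the calling pattern the rename standardizes on: decode() and opt_epoch_iter() obtain a memory context from memory->init_batch(...), then loop over its ubatches with get_ubatch()/apply()/next(). The self-contained toy program below illustrates that driver pattern in isolation; all types in it are simplified stand-ins, not llama.cpp's.

#include <cstdio>
#include <vector>

// stand-in for llama_ubatch
struct toy_ubatch { int n_tokens; };

// stand-in for a llama_memory_context_i implementation: a fixed list of ubatches
struct toy_memory_context {
    std::vector<toy_ubatch> ubatches;
    size_t i_next = 0;

    const toy_ubatch & get_ubatch() const { return ubatches[i_next]; }
    bool apply() { return true; }                      // real code would commit KV-cache/recurrent-state changes here
    bool next() { return ++i_next < ubatches.size(); } // advance; false once all ubatches are consumed
};

int main() {
    toy_memory_context mctx{ { {8}, {8}, {4} }, 0 };

    // same do/while shape as llama_context::decode() after the rename
    do {
        const auto & ubatch = mctx.get_ubatch();
        if (!mctx.apply()) {
            break; // failed to apply the memory context
        }
        std::printf("processed ubatch with %d tokens\n", ubatch.n_tokens);
    } while (mctx.next());

    return 0;
}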


@@ -18,7 +18,7 @@ class llama_io_read_i;
 class llama_io_write_i;
 struct llama_memory_i;
-struct llama_memory_state_i;
+struct llama_memory_context_i;
 struct llama_context {
 // init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
 int32_t il_end);
 // process a single ubatch with a specific graph type
-// if memory_state is provided, it will be applied first to the context's memory
+// if memory_context is provided, it will be applied first to the context's memory
 // ret contains the status of the graph computation
 // returns nullptr only if ret != GGML_STATUS_SUCCESS
 llm_graph_result_ptr process_ubatch(
 const llama_ubatch & ubatch,
 llm_graph_type gtype,
-llama_memory_state_i * mstate,
+llama_memory_context_i * mctx,
 ggml_status & ret);
 int encode(const llama_batch & batch_inp);
 int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ public:
 ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 // reserve a graph with a dummy ubatch of the specified size
-ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate);
+ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 private:
 llm_graph_result_ptr graph_build(
 ggml_context * ctx,
 ggml_cgraph * gf,
 const llama_ubatch & ubatch,
 llm_graph_type gtype,
-const llama_memory_state_i * mstate);
+const llama_memory_context_i * mctx);
 llm_graph_cb graph_get_cb() const;


@@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
 if (pos_bucket) {
-kv_state->set_input_pos_bucket(pos_bucket, ubatch);
+mctx->set_input_pos_bucket(pos_bucket, ubatch);
 }
 }
@@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
 GGML_UNUSED(ubatch);
-const int64_t n_rs = mem_state->get_n_rs();
+const int64_t n_rs = mctx->get_n_rs();
 if (s_copy) {
 GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
 // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
 for (uint32_t i = 0; i < n_rs; ++i) {
-data[i] = mem_state->s_copy(i);
+data[i] = mctx->s_copy(i);
 }
 }
 }
@@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 if (self_kq_mask) {
-kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 }
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
 if (self_kq_mask) {
-kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 if (self_kq_mask_swa) {
-kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
+mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 }
@@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 if (self_kq_mask) {
-mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
-const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+const int64_t n_rs = mctx->get_recr()->get_n_rs();
 if (s_copy) {
 GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
@@ -345,7 +345,7 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
 for (uint32_t i = 0; i < n_rs; ++i) {
-data[i] = mem_state->get_state_recr()->s_copy(i);
+data[i] = mctx->get_recr()->s_copy(i);
 }
 }
 }
@@ -389,7 +389,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
 backend_cpu (params.backend_cpu),
 cvec (params.cvec),
 loras (params.loras),
-mstate (params.mstate),
+mctx (params.mctx),
 cross (params.cross),
 cb_func (params.cb),
 res (std::make_unique<llm_graph_result>()) {
@@ -950,11 +950,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_state);
+auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
-const auto n_kv = kv_state->get_n_kv();
+const auto n_kv = mctx_cur->get_n_kv();
 auto & cur = inp->pos_bucket;
@@ -982,14 +982,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
-auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
 {
 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
-const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
 //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -999,7 +999,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 }
 {
-const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+const auto n_rs = mctx_cur->get_recr()->get_n_rs();
 inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
 ggml_set_input(inp->s_copy);
@@ -1183,14 +1183,14 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state);
+auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 {
 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
-const auto n_kv = kv_state->get_n_kv();
+const auto n_kv = mctx_cur->get_n_kv();
 inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
 //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1220,19 +1220,19 @@ ggml_tensor * llm_graph_context::build_attn(
 ggml_build_forward_expand(gf, k_cur);
 ggml_build_forward_expand(gf, v_cur);
-const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
 // store to KV cache
 {
-ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
 }
 const auto & kq_mask = inp->get_kq_mask();
 ggml_tensor * q = q_cur;
-ggml_tensor * k = kv_state->get_k(ctx0, il);
-ggml_tensor * v = kv_state->get_v(ctx0, il);
+ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
 cb(cur, "kqv_out", il);
@@ -1270,23 +1270,23 @@ ggml_tensor * llm_graph_context::build_attn(
 ggml_build_forward_expand(gf, k_cur);
 ggml_build_forward_expand(gf, v_cur);
-const auto * kv_state_iswa = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 const bool is_swa = hparams.is_swa(il);
-const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base();
+const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 // store to KV cache
 {
-ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
 }
 const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 ggml_tensor * q = q_cur;
-ggml_tensor * k = kv_state->get_k(ctx0, il);
-ggml_tensor * v = kv_state->get_v(ctx0, il);
+ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
 cb(cur, "kqv_out", il);
@@ -1379,19 +1379,19 @@ ggml_tensor * llm_graph_context::build_attn(
 ggml_build_forward_expand(gf, k_cur);
 ggml_build_forward_expand(gf, v_cur);
-const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
 // store to KV cache
 {
-ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
 }
 const auto & kq_mask = inp->get_kq_mask();
 ggml_tensor * q = q_cur;
-ggml_tensor * k = kv_state->get_k(ctx0, il);
-ggml_tensor * v = kv_state->get_v(ctx0, il);
+ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
 cb(cur, "kqv_out", il);
@@ -1412,12 +1412,12 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 {
-const auto n_kv = kv_state->get_base()->get_n_kv();
+const auto n_kv = mctx_cur->get_base()->get_n_kv();
 inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
 //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1429,7 +1429,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 {
 GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-const auto n_kv = kv_state->get_swa()->get_n_kv();
+const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
 //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ -1485,11 +1485,11 @@ ggml_tensor * llm_graph_context::build_rs(
 }
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
-const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
-const auto n_rs = kv_state->get_n_rs();
+const auto n_rs = mctx_cur->get_n_rs();
 inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
 ggml_set_input(inp->s_copy);
@@ -1504,9 +1504,9 @@ ggml_tensor * llm_graph_context::build_rs(
 int32_t state_size,
 int32_t n_seqs,
 bool avoid_copies) const {
-const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 ggml_tensor * llm_graph_context::build_rs(
@@ -1516,9 +1516,9 @@ ggml_tensor * llm_graph_context::build_rs(
 int32_t state_size,
 int32_t n_seqs,
 bool avoid_copies) const {
-const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
-return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1526,13 +1526,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 ggml_cgraph * gf,
 const llama_ubatch & ubatch,
 int il) const {
-const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 const auto token_shift_count = hparams.token_shift_count;
 const int64_t n_seqs = ubatch.n_seqs;
-ggml_tensor * token_shift_all = kv_state->get_r_l(il);
+ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 ggml_tensor * token_shift = build_rs(
 inp, gf, token_shift_all,
@@ -1547,19 +1547,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 ggml_tensor * token_shift,
 const llama_ubatch & ubatch,
 int il) const {
-const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 const auto token_shift_count = hparams.token_shift_count;
 const auto n_embd = hparams.n_embd;
 const int64_t n_seqs = ubatch.n_seqs;
-const auto kv_head = kv_state->get_head();
+const auto kv_head = mctx_cur->get_head();
 return ggml_cpy(
 ctx0,
 ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
+ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
 );
 }


@@ -17,12 +17,12 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
-struct llama_memory_state_i;
+struct llama_memory_context_i;
-class llama_kv_cache_unified_state;
-class llama_kv_cache_unified_iswa_state;
-class llama_memory_recurrent_state;
-class llama_memory_hybrid_state;
+class llama_kv_cache_unified_context;
+class llama_kv_cache_unified_iswa_context;
+class llama_memory_recurrent_context;
+class llama_memory_hybrid_context;
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -136,7 +136,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
 llm_graph_input_pos_bucket_kv(
 const llama_hparams & hparams,
-const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {}
+const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
 virtual ~llm_graph_input_pos_bucket_kv() = default;
 void set_input(const llama_ubatch * ubatch) override;
@@ -144,7 +144,8 @@ public:
 ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 const llama_hparams & hparams;
-const llama_kv_cache_unified_state * kv_state;
+
+const llama_kv_cache_unified_context * mctx;
 };
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,14 +192,14 @@ public:
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
 virtual ~llm_graph_input_rs() = default;
 void set_input(const llama_ubatch * ubatch) override;
 ggml_tensor * s_copy; // I32 [kv_size]
-const llama_memory_recurrent_state * mem_state;
+const llama_memory_recurrent_context * mctx;
 };
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -238,10 +239,10 @@ public:
 llm_graph_input_attn_kv_unified(
 const llama_hparams & hparams,
 const llama_cparams & cparams,
-const llama_kv_cache_unified_state * kv_state) :
+const llama_kv_cache_unified_context * mctx) :
 hparams(hparams),
 cparams(cparams),
-kv_state(kv_state) {
+mctx(mctx) {
 }
 ~llm_graph_input_attn_kv_unified() = default;
@@ -255,7 +256,7 @@ public:
 const llama_hparams & hparams;
 const llama_cparams & cparams;
-const llama_kv_cache_unified_state * kv_state;
+const llama_kv_cache_unified_context * mctx;
 };
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -263,10 +264,10 @@ public:
 llm_graph_input_attn_kv_unified_iswa(
 const llama_hparams & hparams,
 const llama_cparams & cparams,
-const llama_kv_cache_unified_iswa_state * kv_state) :
+const llama_kv_cache_unified_iswa_context * mctx) :
 hparams(hparams),
 cparams(cparams),
-kv_state(kv_state) {
+mctx(mctx) {
 }
 ~llm_graph_input_attn_kv_unified_iswa() = default;
@@ -283,7 +284,7 @@ public:
 const llama_hparams & hparams;
 const llama_cparams & cparams;
-const llama_kv_cache_unified_iswa_state * kv_state;
+const llama_kv_cache_unified_iswa_context * mctx;
 };
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -306,10 +307,10 @@ public:
 llm_graph_input_mem_hybrid(
 const llama_hparams & hparams,
 const llama_cparams & cparams,
-const llama_memory_hybrid_state * mem_state) :
+const llama_memory_hybrid_context * mctx) :
 hparams(hparams),
 cparams(cparams),
-mem_state(mem_state) {
+mctx(mctx) {
 }
 virtual ~llm_graph_input_mem_hybrid() = default;
@@ -325,7 +326,7 @@ public:
 const llama_hparams & hparams;
 const llama_cparams & cparams;
-const llama_memory_hybrid_state * mem_state;
+const llama_memory_hybrid_context * mctx;
 };
 //
@@ -401,10 +402,10 @@ struct llm_graph_params {
 ggml_backend_sched_t sched;
 ggml_backend_t backend_cpu;
 const llama_adapter_cvec * cvec;
 const llama_adapter_loras * loras;
-const llama_memory_state_i * mstate;
+const llama_memory_context_i * mctx;
 const llama_cross * cross;
 uint32_t n_outputs;
@@ -453,10 +454,10 @@ struct llm_graph_context {
 ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 const llama_adapter_cvec * cvec;
 const llama_adapter_loras * loras;
-const llama_memory_state_i * mstate;
+const llama_memory_context_i * mctx;
 const llama_cross * cross;
 const llm_graph_cb & cb_func;


@@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 return kv_swa->seq_pos_max(seq_id);
 }
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
 GGML_UNUSED(embd_all);
 // first try simple split
@@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 assert(heads_base.size() == heads_swa.size());
-return std::make_unique<llama_kv_cache_unified_iswa_state>(
+return std::make_unique<llama_kv_cache_unified_iswa_context>(
 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 } while (false);
@@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 assert(heads_base.size() == heads_swa.size());
-return std::make_unique<llama_kv_cache_unified_iswa_state>(
+return std::make_unique<llama_kv_cache_unified_iswa_context>(
 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
 } while (false);
 // TODO: if we fail again, we should attempt different splitting strategies
 // but to do that properly, we first have to refactor the batches to be more flexible
-return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
-return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
+return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
 }
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
 }
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 }
 //
-// llama_kv_cache_unified_iswa_state
+// llama_kv_cache_unified_iswa_context
 //
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 llama_kv_cache_unified_iswa * kv) :
-state_base(kv->get_base()->init_full()),
-state_swa (kv->get_swa ()->init_full()),
-status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+ctx_base(kv->get_base()->init_full()),
+ctx_swa (kv->get_swa ()->init_full()),
+status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 llama_kv_cache_unified_iswa * kv,
 llama_context * lctx,
 bool optimize) :
-state_base(kv->get_base()->init_update(lctx, optimize)),
-state_swa (kv->get_swa ()->init_update(lctx, optimize)),
-status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+ctx_base(kv->get_base()->init_update(lctx, optimize)),
+ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
-llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
 llama_kv_cache_unified_iswa * kv,
 std::vector<uint32_t> heads_base,
 std::vector<uint32_t> heads_swa,
 std::vector<llama_ubatch> ubatches) :
 ubatches(std::move(ubatches)),
 // note: here we copy the ubatches. not sure if this is ideal
-state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
-state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)),
-status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
+ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
+ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
-llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
+llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
-bool llama_kv_cache_unified_iswa_state::next() {
+bool llama_kv_cache_unified_iswa_context::next() {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-state_base->next();
-state_swa ->next();
+ctx_base->next();
+ctx_swa ->next();
 if (++i_next >= ubatches.size()) {
 return false;
@@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
 return true;
 }
-bool llama_kv_cache_unified_iswa_state::apply() {
+bool llama_kv_cache_unified_iswa_context::apply() {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 bool res = true;
-res = res & state_base->apply();
-res = res & state_swa ->apply();
+res = res & ctx_base->apply();
+res = res & ctx_swa ->apply();
 return res;
 }
-llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
+llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
 return status;
 }
-const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 return ubatches[i_next];
 }
-const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
+return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
 }
-const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
+return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
 }


@@ -31,14 +31,14 @@ public:
 // llama_memory_i
 //
-llama_memory_state_ptr init_batch(
+llama_memory_context_ptr init_batch(
 llama_batch_allocr & balloc,
 uint32_t n_ubatch,
 bool embd_all) override;
-llama_memory_state_ptr init_full() override;
-llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+llama_memory_context_ptr init_full() override;
+llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 bool get_can_shift() const override;
@@ -72,32 +72,32 @@ private:
 std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
-class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
+class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
 // used for errors
-llama_kv_cache_unified_iswa_state(llama_memory_status status);
+llama_kv_cache_unified_iswa_context(llama_memory_status status);
-// used to create a full-cache state
-llama_kv_cache_unified_iswa_state(
+// used to create a full-cache context
+llama_kv_cache_unified_iswa_context(
 llama_kv_cache_unified_iswa * kv);
-// used to create an update state
-llama_kv_cache_unified_iswa_state(
+// used to create an update context
+llama_kv_cache_unified_iswa_context(
 llama_kv_cache_unified_iswa * kv,
 llama_context * lctx,
 bool optimize);
-// used to create a state from a batch
-llama_kv_cache_unified_iswa_state(
+// used to create a batch processing context from a batch
+llama_kv_cache_unified_iswa_context(
 llama_kv_cache_unified_iswa * kv,
 std::vector<uint32_t> heads_base,
 std::vector<uint32_t> heads_swa,
 std::vector<llama_ubatch> ubatches);
-virtual ~llama_kv_cache_unified_iswa_state();
+virtual ~llama_kv_cache_unified_iswa_context();
 //
-// llama_memory_state_i
+// llama_memory_context_i
 //
 bool next() override;
@@ -107,11 +107,11 @@ public:
 const llama_ubatch & get_ubatch() const override;
 //
-// llama_kv_cache_unified_iswa_state specific API
+// llama_kv_cache_unified_iswa_context specific API
 //
-const llama_kv_cache_unified_state * get_base() const;
-const llama_kv_cache_unified_state * get_swa() const;
+const llama_kv_cache_unified_context * get_base() const;
+const llama_kv_cache_unified_context * get_swa() const;
 private:
 //llama_kv_cache_unified_iswa * kv;
@@ -121,8 +121,8 @@ private:
 std::vector<llama_ubatch> ubatches;
-const llama_memory_state_ptr state_base;
-const llama_memory_state_ptr state_swa;
+const llama_memory_context_ptr ctx_base;
+const llama_memory_context_ptr ctx_swa;
 const llama_memory_status status;
 };


@@ -307,7 +307,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 return cells.seq_pos_max(seq_id);
 }
-llama_memory_state_ptr llama_kv_cache_unified::init_batch(
+llama_memory_context_ptr llama_kv_cache_unified::init_batch(
 llama_batch_allocr & balloc,
 uint32_t n_ubatch,
 bool embd_all) {
@@ -332,18 +332,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
 break;
 }
-return std::make_unique<llama_kv_cache_unified_state>(
+return std::make_unique<llama_kv_cache_unified_context>(
 this, std::move(heads), std::move(ubatches));
 } while (false);
-return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
-llama_memory_state_ptr llama_kv_cache_unified::init_full() {
-return std::make_unique<llama_kv_cache_unified_state>(this);
+llama_memory_context_ptr llama_kv_cache_unified::init_full() {
+return std::make_unique<llama_kv_cache_unified_context>(this);
 }
-llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
 bool do_shift = get_has_shift();
 defrag_info dinfo;
@@ -373,7 +373,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
 }
 }
-return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
+return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1710,18 +1710,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 }
 //
-// llama_kv_cache_unified_state
+// llama_kv_cache_unified_context
 //
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
 n_kv = kv->get_size();
 head = 0;
 }
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified * kv,
 llama_context * lctx,
 bool do_shift,
@@ -1731,15 +1731,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
 }
 }
-llama_kv_cache_unified_state::llama_kv_cache_unified_state(
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 llama_kv_cache_unified * kv,
 llama_kv_cache_unified::ubatch_heads heads,
 std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
 }
-llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
+llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
-bool llama_kv_cache_unified_state::next() {
+bool llama_kv_cache_unified_context::next() {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 if (++i_next >= ubatches.size()) {
@@ -1749,7 +1749,7 @@ bool llama_kv_cache_unified_state::next() {
 return true;
 }
-bool llama_kv_cache_unified_state::apply() {
+bool llama_kv_cache_unified_context::apply() {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 // no ubatches -> this is a KV cache update
@@ -1767,45 +1767,45 @@ bool llama_kv_cache_unified_state::apply() {
 return true;
 }
-llama_memory_status llama_kv_cache_unified_state::get_status() const {
+llama_memory_status llama_kv_cache_unified_context::get_status() const {
 return status;
 }
-const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
 assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 return ubatches[i_next];
 }
-uint32_t llama_kv_cache_unified_state::get_n_kv() const {
+uint32_t llama_kv_cache_unified_context::get_n_kv() const {
 return n_kv;
 }
-ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
 return kv->get_k(ctx, il, n_kv);
 }
-ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
 return kv->get_v(ctx, il, n_kv);
 }
-ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
 return kv->cpy_k(ctx, k_cur, il, head);
 }
-ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
 return kv->cpy_v(ctx, v_cur, il, head);
 }
-void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const {
+void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
 kv->set_input_k_shift(dst);
 }
-void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
 kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
-void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
 kv->set_input_pos_bucket(dst, ubatch);
 }


@ -56,14 +56,14 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch( llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) override; bool embd_all) override;
llama_memory_state_ptr init_full() override; llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override; bool get_can_shift() const override;
@ -208,36 +208,36 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
}; };
class llama_kv_cache_unified_state : public llama_memory_state_i { class llama_kv_cache_unified_context : public llama_memory_context_i {
public: public:
// some shorthands // some shorthands
using ubatch_heads = llama_kv_cache_unified::ubatch_heads; using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
using defrag_info = llama_kv_cache_unified::defrag_info; using defrag_info = llama_kv_cache_unified::defrag_info;
// used for errors // used for errors
llama_kv_cache_unified_state(llama_memory_status status); llama_kv_cache_unified_context(llama_memory_status status);
// used to create a full-cache state // used to create a full-cache context
llama_kv_cache_unified_state( llama_kv_cache_unified_context(
llama_kv_cache_unified * kv); llama_kv_cache_unified * kv);
// used to create an update state // used to create an update context
llama_kv_cache_unified_state( llama_kv_cache_unified_context(
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
llama_context * lctx, llama_context * lctx,
bool do_shift, bool do_shift,
defrag_info dinfo); defrag_info dinfo);
// used to create a decode state from a batch // used to create a batch processing context from a batch
llama_kv_cache_unified_state( llama_kv_cache_unified_context(
llama_kv_cache_unified * kv, llama_kv_cache_unified * kv,
ubatch_heads heads, ubatch_heads heads,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_unified_state(); virtual ~llama_kv_cache_unified_context();
// //
// llama_memory_state_i // llama_memory_context_i
// //
bool next() override; bool next() override;
@ -247,7 +247,7 @@ public:
const llama_ubatch & get_ubatch() const override; const llama_ubatch & get_ubatch() const override;
// //
// llama_kv_cache_unified_state specific API // llama_kv_cache_unified_context specific API
// //
uint32_t get_n_kv() const; uint32_t get_n_kv() const;
@ -272,7 +272,7 @@ private:
llama_context * lctx; llama_context * lctx;
// //
// update state // update context
// //
bool do_shift = false; bool do_shift = false;
@ -280,7 +280,7 @@ private:
defrag_info dinfo; defrag_info dinfo;
// //
// batch processing state // batch processing context
// //
// the index of the next ubatch to process // the index of the next ubatch to process
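
The four constructors above map one-to-one onto the ways a context gets created: an error status, init_full(), init_update() and init_batch(). A hedged sketch of how llama_kv_cache_unified might hand them out (the do_shift/dinfo computation is assumed here and not shown in this hunk):

    llama_memory_context_ptr llama_kv_cache_unified::init_full() {
        return std::make_unique<llama_kv_cache_unified_context>(this);
    }

    llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
        bool        do_shift = false; // assumed: whether a K-shift is pending
        defrag_info dinfo;            // assumed: result of the fragmentation analysis (empty == no defrag),
                                      //          which is where `optimize` would factor in
        return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
    }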


@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max n_seq_max
)) {} )) {}
llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
do { do {
balloc.split_reset(); balloc.split_reset();
@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball
// prepare the recurrent batches first // prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) { if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined state at this point? // TODO: will the recurrent cache be in an undefined state at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
// prepare the attention cache // prepare the attention cache
auto heads_attn = mem_attn->prepare(ubatches); auto heads_attn = mem_attn->prepare(ubatches);
if (heads_attn.empty()) { if (heads_attn.empty()) {
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__); LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
return std::make_unique<llama_memory_hybrid_state>( return std::make_unique<llama_memory_hybrid_context>(
this, std::move(heads_attn), std::move(ubatches)); this, std::move(heads_attn), std::move(ubatches));
} while(false); } while(false);
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
llama_memory_state_ptr llama_memory_hybrid::init_full() { llama_memory_context_ptr llama_memory_hybrid::init_full() {
return std::make_unique<llama_memory_hybrid_state>(this); return std::make_unique<llama_memory_hybrid_context>(this);
} }
llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) { llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize); return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
} }
bool llama_memory_hybrid::get_can_shift() const { bool llama_memory_hybrid::get_can_shift() const {
@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
return mem_recr.get(); return mem_recr.get();
} }
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {} llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) : llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
state_attn(mem->get_mem_attn()->init_full()), ctx_attn(mem->get_mem_attn()->init_full()),
state_recr(mem->get_mem_recr()->init_full()), ctx_recr(mem->get_mem_recr()->init_full()),
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }
llama_memory_hybrid_state::llama_memory_hybrid_state( llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
llama_context * lctx, llama_context * lctx,
bool optimize) : bool optimize) :
state_attn(mem->get_mem_attn()->init_update(lctx, optimize)), ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
state_recr(mem->get_mem_recr()->init_update(lctx, optimize)), ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }
llama_memory_hybrid_state::llama_memory_hybrid_state( llama_memory_hybrid_context::llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn, std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches) : std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)), ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal // note: here we copy the ubatches. not sure if this is ideal
state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)), ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)), ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) { status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
} }
bool llama_memory_hybrid_state::next() { bool llama_memory_hybrid_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
state_attn->next(); ctx_attn->next();
state_recr->next(); ctx_recr->next();
if (++i_next >= ubatches.size()) { if (++i_next >= ubatches.size()) {
return false; return false;
@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
return true; return true;
} }
bool llama_memory_hybrid_state::apply() { bool llama_memory_hybrid_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
bool res = true; bool res = true;
res = res & state_attn->apply(); res = res & ctx_attn->apply();
res = res & state_recr->apply(); res = res & ctx_recr->apply();
return res; return res;
} }
llama_memory_status llama_memory_hybrid_state::get_status() const { llama_memory_status llama_memory_hybrid_context::get_status() const {
return status; return status;
} }
const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const { const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_next];
} }
const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const { const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
return static_cast<const llama_kv_cache_unified_state *>(state_attn.get()); return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
} }
const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const { const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
return static_cast<const llama_memory_recurrent_state *>(state_recr.get()); return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
} }
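
get_attn() and get_recr() are how graph code for a hybrid model reaches the per-layer sub-contexts. A hedged sketch of the dispatch (hparams.is_recurrent() and the build_* helpers stand in for the real builder code and are assumptions here):

    const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);

    if (hparams.is_recurrent(il)) {
        // SSM/RWKV layers read the recurrent slots
        const auto * mctx_rec = mctx_hyb->get_recr();
        // ... build_rs(...) using mctx_rec->get_r_l(il) and mctx_rec->get_s_l(il) ...
    } else {
        // attention layers read the unified KV cells
        const auto * mctx_att = mctx_hyb->get_attn();
        // ... build_attn(...) using mctx_att->get_k(...) and mctx_att->get_v(...) ...
    }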


@ -49,14 +49,14 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch( llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) override; bool embd_all) override;
llama_memory_state_ptr init_full() override; llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override; bool get_can_shift() const override;
@ -90,27 +90,27 @@ private:
const std::unique_ptr<llama_memory_recurrent> mem_recr; const std::unique_ptr<llama_memory_recurrent> mem_recr;
}; };
class llama_memory_hybrid_state : public llama_memory_state_i { class llama_memory_hybrid_context : public llama_memory_context_i {
public: public:
// init failure // init failure
explicit llama_memory_hybrid_state(llama_memory_status status); explicit llama_memory_hybrid_context(llama_memory_status status);
// init full // init full
explicit llama_memory_hybrid_state(llama_memory_hybrid * mem); explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);
// init update // init update
explicit llama_memory_hybrid_state( explicit llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
llama_context * lctx, llama_context * lctx,
bool optimize); bool optimize);
// init success // init success
llama_memory_hybrid_state( llama_memory_hybrid_context(
llama_memory_hybrid * mem, llama_memory_hybrid * mem,
std::vector<uint32_t> heads_attn, std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
~llama_memory_hybrid_state() = default; ~llama_memory_hybrid_context() = default;
bool next() override; bool next() override;
bool apply() override; bool apply() override;
@ -119,11 +119,11 @@ public:
const llama_ubatch & get_ubatch() const override; const llama_ubatch & get_ubatch() const override;
// //
// llama_memory_hybrid_state // llama_memory_hybrid_context
// //
const llama_kv_cache_unified_state * get_state_attn() const; const llama_kv_cache_unified_context * get_attn() const;
const llama_memory_recurrent_state * get_state_recr() const; const llama_memory_recurrent_context * get_recr() const;
private: private:
// the index of the next ubatch to process // the index of the next ubatch to process
@ -131,8 +131,8 @@ private:
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
const llama_memory_state_ptr state_attn; const llama_memory_context_ptr ctx_attn;
const llama_memory_state_ptr state_recr; const llama_memory_context_ptr ctx_recr;
const llama_memory_status status; const llama_memory_status status;
}; };


@ -362,7 +362,7 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result; return result;
} }
llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (true) { while (true) {
@ -383,21 +383,21 @@ llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & b
} }
if (!prepare(ubatches)) { if (!prepare(ubatches)) {
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
} }
return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches)); return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
} }
llama_memory_state_ptr llama_memory_recurrent::init_full() { llama_memory_context_ptr llama_memory_recurrent::init_full() {
return std::make_unique<llama_memory_recurrent_state>(this); return std::make_unique<llama_memory_recurrent_context>(this);
} }
llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) { llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
GGML_UNUSED(lctx); GGML_UNUSED(lctx);
GGML_UNUSED(optimize); GGML_UNUSED(optimize);
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE); return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
} }
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) { bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
@ -1040,22 +1040,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell
} }
// //
// llama_memory_recurrent_state // llama_memory_recurrent_context
// //
llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {} llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}
llama_memory_recurrent_state::llama_memory_recurrent_state( llama_memory_recurrent_context::llama_memory_recurrent_context(
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) { llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
} }
llama_memory_recurrent_state::llama_memory_recurrent_state( llama_memory_recurrent_context::llama_memory_recurrent_context(
llama_memory_recurrent * mem, llama_memory_recurrent * mem,
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {} std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}
llama_memory_recurrent_state::~llama_memory_recurrent_state() = default; llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;
bool llama_memory_recurrent_state::next() { bool llama_memory_recurrent_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
if (++i_next >= ubatches.size()) { if (++i_next >= ubatches.size()) {
@ -1065,7 +1065,7 @@ bool llama_memory_recurrent_state::next() {
return true; return true;
} }
bool llama_memory_recurrent_state::apply() { bool llama_memory_recurrent_context::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
mem->find_slot(ubatches[i_next]); mem->find_slot(ubatches[i_next]);
@ -1073,40 +1073,40 @@ bool llama_memory_recurrent_state::apply() {
return true; return true;
} }
llama_memory_status llama_memory_recurrent_state::get_status() const { llama_memory_status llama_memory_recurrent_context::get_status() const {
return status; return status;
} }
const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const { const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS); assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next]; return ubatches[i_next];
} }
uint32_t llama_memory_recurrent_state::get_n_rs() const { uint32_t llama_memory_recurrent_context::get_n_rs() const {
return is_full ? mem->size : mem->n; return is_full ? mem->size : mem->n;
} }
uint32_t llama_memory_recurrent_state::get_head() const { uint32_t llama_memory_recurrent_context::get_head() const {
return is_full ? 0 : mem->head; return is_full ? 0 : mem->head;
} }
int32_t llama_memory_recurrent_state::get_rs_z() const { int32_t llama_memory_recurrent_context::get_rs_z() const {
return is_full ? 0 : mem->rs_z; return is_full ? 0 : mem->rs_z;
} }
uint32_t llama_memory_recurrent_state::get_size() const { uint32_t llama_memory_recurrent_context::get_size() const {
return mem->size; return mem->size;
} }
ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const { ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
return mem->r_l[il]; return mem->r_l[il];
} }
ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const { ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
return mem->s_l[il]; return mem->s_l[il];
} }
int32_t llama_memory_recurrent_state::s_copy(int i) const { int32_t llama_memory_recurrent_context::s_copy(int i) const {
return mem->cells[i + mem->head].src0; return mem->cells[i + mem->head].src0;
} }


@ -11,8 +11,8 @@
// llama_memory_recurrent // llama_memory_recurrent
// //
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it // see the implementation of llama_kv_cache_unified_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i { class llama_memory_recurrent : public llama_memory_i {
public: public:
@ -34,14 +34,14 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch( llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) override; bool embd_all) override;
llama_memory_state_ptr init_full() override; llama_memory_context_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
void clear(bool data) override; void clear(bool data) override;
@ -125,24 +125,24 @@ private:
bool state_read_data(llama_io_read_i & io, uint32_t cell_count); bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
}; };
class llama_memory_recurrent_state : public llama_memory_state_i { class llama_memory_recurrent_context : public llama_memory_context_i {
public: public:
// used for errors // used for errors
llama_memory_recurrent_state(llama_memory_status status); llama_memory_recurrent_context(llama_memory_status status);
// used to create a full-cache state // used to create a full-cache or update context
llama_memory_recurrent_state( llama_memory_recurrent_context(
llama_memory_recurrent * mem); llama_memory_recurrent * mem);
// used to create a state from a batch // used to create a batch processing context from a batch
llama_memory_recurrent_state( llama_memory_recurrent_context(
llama_memory_recurrent * mem, llama_memory_recurrent * mem,
std::vector<llama_ubatch> ubatches); std::vector<llama_ubatch> ubatches);
virtual ~llama_memory_recurrent_state(); virtual ~llama_memory_recurrent_context();
// //
// llama_memory_state_i // llama_memory_context_i
// //
bool next() override; bool next() override;
@ -152,7 +152,7 @@ public:
const llama_ubatch & get_ubatch() const override; const llama_ubatch & get_ubatch() const override;
// //
// llama_memory_recurrent_state specific API // llama_memory_recurrent_context specific API
// //
uint32_t get_n_rs() const; uint32_t get_n_rs() const;


@ -3,7 +3,6 @@
#include "llama.h" #include "llama.h"
#include <memory> #include <memory>
#include <vector>
struct llama_ubatch; struct llama_ubatch;
@ -28,23 +27,21 @@ enum llama_memory_status {
LLAMA_MEMORY_STATUS_FAILED_COMPUTE, LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
}; };
// helper function for combining the status of two memory states // helper function for combining the status of two memory contexts
// useful for implementing hybrid memory types (e.g. iSWA) // useful for implementing hybrid memory types (e.g. iSWA)
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
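
A hedged sketch of the contract such a helper has to implement: failures dominate, and the combined result only reports SUCCESS if at least one side has work to apply. This sketches the semantics, not necessarily the verbatim implementation:

    llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
        // any failure dominates
        if (s0 == LLAMA_MEMORY_STATUS_FAILED_PREPARE || s1 == LLAMA_MEMORY_STATUS_FAILED_PREPARE) {
            return LLAMA_MEMORY_STATUS_FAILED_PREPARE;
        }
        if (s0 == LLAMA_MEMORY_STATUS_FAILED_COMPUTE || s1 == LLAMA_MEMORY_STATUS_FAILED_COMPUTE) {
            return LLAMA_MEMORY_STATUS_FAILED_COMPUTE;
        }

        // SUCCESS only if at least one side has an update to apply
        const bool has_update = s0 == LLAMA_MEMORY_STATUS_SUCCESS || s1 == LLAMA_MEMORY_STATUS_SUCCESS;

        return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
    }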
// the interface for managing the memory state during batch processing // the interface for managing the memory context during batch processing
// this interface is implemented per memory type. see: // this interface is implemented per memory type. see:
// - llama_kv_cache_unified_state // - llama_kv_cache_unified_context
// - llama_kv_cache_unified_iswa_state // - llama_kv_cache_unified_iswa_context
// ... // ...
// //
// the only method that can mutate the memory and the memory state is llama_memory_i::apply() // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
//
// TODO: rename to llama_memory_context_i ?
struct llama_memory_state_i { struct llama_memory_context_i {
virtual ~llama_memory_state_i() = default; virtual ~llama_memory_context_i() = default;
// consume the current ubatch from the state and proceed to the next one // consume the current ubatch from the context and proceed to the next one
// return false if we are done // return false if we are done
virtual bool next() = 0; virtual bool next() = 0;
@ -55,11 +52,11 @@ struct llama_memory_state_i {
// get the current ubatch // get the current ubatch
virtual const llama_ubatch & get_ubatch() const = 0; virtual const llama_ubatch & get_ubatch() const = 0;
// get the status of the memory state - used for error handling and checking if any updates would be applied // get the status of the memory context - used for error handling and checking if any updates would be applied
virtual llama_memory_status get_status() const = 0; virtual llama_memory_status get_status() const = 0;
}; };
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>; using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
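
The intended call pattern for a llama_memory_context_ptr is a short drive loop over the prepared ubatches. A hedged sketch (memory, balloc, n_ubatch and embd_all are assumed from the surrounding caller; graph build/compute is elided):

    auto mctx = memory->init_batch(balloc, n_ubatch, embd_all);

    if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        // preparation failed - nothing has been applied to the memory
        return;
    }

    do {
        if (!mctx->apply()) { // the only call that mutates the memory
            break;
        }

        const llama_ubatch & ubatch = mctx->get_ubatch();

        // ... build and compute the graph for this ubatch ...
    } while (mctx->next()); // returns false once all ubatches are consumed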
// general concept of LLM memory // general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types // the KV cache is a type of LLM memory, but there can be other types
@ -67,19 +64,19 @@ struct llama_memory_i {
virtual ~llama_memory_i() = default; virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache // split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them // return a context object containing the ubatches and memory state required to process them
// check the llama_memory_state_i::get_status() for the result // check the llama_memory_context_i::get_status() for the result
virtual llama_memory_state_ptr init_batch( virtual llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc, llama_batch_allocr & balloc,
uint32_t n_ubatch, uint32_t n_ubatch,
bool embd_all) = 0; bool embd_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers // simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0; virtual llama_memory_context_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc. // prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters // getters
virtual bool get_can_shift() const = 0; virtual bool get_can_shift() const = 0;


@ -9171,9 +9171,9 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * cur, ggml_tensor * cur,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate); const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto kv_head = kv_state->get_head(); const auto kv_head = mctx_cur->get_head();
const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_inner = hparams.ssm_d_inner;
@ -9191,8 +9191,8 @@ struct llm_build_mamba : public llm_graph_context {
GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
ggml_tensor * conv_states_all = kv_state->get_r_l(il); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
ggml_tensor * ssm_states_all = kv_state->get_s_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
// (ab)using the KV cache to store the states // (ab)using the KV cache to store the states
ggml_tensor * conv = build_rs( ggml_tensor * conv = build_rs(
@ -11916,7 +11916,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
ggml_tensor * x_prev, ggml_tensor * x_prev,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate); const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto n_tokens = ubatch.n_tokens; const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs; const auto n_seqs = ubatch.n_seqs;
@ -11926,7 +11926,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
const auto n_head = n_embd / head_size; const auto n_head = n_embd / head_size;
const auto n_head_kv = hparams.n_head_kv(il); const auto n_head_kv = hparams.n_head_kv(il);
const auto kv_head = kv_state->get_head(); const auto kv_head = mctx_cur->get_head();
const auto & layer = model.layers[il]; const auto & layer = model.layers[il];
@ -12038,7 +12038,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
} }
ggml_tensor * wkv_state = build_rs( ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_s_l(il), inp, gf, mctx_cur->get_s_l(il),
hparams.n_embd_s(), n_seqs); hparams.n_embd_s(), n_seqs);
ggml_tensor * wkv_output; ggml_tensor * wkv_output;
@ -12057,9 +12057,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
wkv_state, wkv_state,
ggml_view_1d( ggml_view_1d(
ctx0, ctx0,
kv_state->get_s_l(il), mctx_cur->get_s_l(il),
hparams.n_embd_s() * n_seqs, hparams.n_embd_s() * n_seqs,
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il)) hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
) )
) )
); );
@ -12313,7 +12313,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
ggml_tensor *& first_layer_value, ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch, const llama_ubatch & ubatch,
int il) const { int il) const {
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate); const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
const auto n_tokens = ubatch.n_tokens; const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs; const auto n_seqs = ubatch.n_seqs;
@ -12322,7 +12322,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
const auto head_count = n_embd / head_size; const auto head_count = n_embd / head_size;
const auto n_seq_tokens = ubatch.n_seq_tokens; const auto n_seq_tokens = ubatch.n_seq_tokens;
const auto kv_head = kv_state->get_head(); const auto kv_head = mctx_cur->get_head();
const auto & layer = model.layers[il]; const auto & layer = model.layers[il];
@ -12393,7 +12393,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
ggml_tensor * wkv_state = build_rs( ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_s_l(il), inp, gf, mctx_cur->get_s_l(il),
hparams.n_embd_s(), n_seqs); hparams.n_embd_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@ -12407,9 +12407,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
wkv_state, wkv_state,
ggml_view_1d( ggml_view_1d(
ctx0, ctx0,
kv_state->get_s_l(il), mctx_cur->get_s_l(il),
hparams.n_embd_s() * n_seqs, hparams.n_embd_s() * n_seqs,
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il)) hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
) )
) )
); );