From 5046d412ef1745e6d7c3891152a876b34710993d Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Mon, 16 Jun 2025 13:48:20 -0600
Subject: [PATCH] fix: Fix initialization of child states

Since initially writing this PR, the logic in the child state types
changed such that using the "init full" signature and keeping the
ubatches on the parent struct no longer worked.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart
---
 src/llama-kv-cache-hybrid-recurrent.cpp | 25 +++++++++++--------------
 src/llama-kv-cache-hybrid-recurrent.h   |  4 ----
 2 files changed, 11 insertions(+), 18 deletions(-)

diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp
index a2afda764..9ec205c9d 100644
--- a/src/llama-kv-cache-hybrid-recurrent.cpp
+++ b/src/llama-kv-cache-hybrid-recurrent.cpp
@@ -100,7 +100,6 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
 
 llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
     return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
-        this,
         static_cast<llama_kv_cache_unified_state *>(  kv_attn     ->init_update(lctx, optimize).release()),
         static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
 }
@@ -179,16 +178,13 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(lla
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
-      kv(kv),
       state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
       state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
-        llama_kv_cache_hybrid_recurrent * kv,
         llama_kv_cache_unified_state * state_unified,
         llama_kv_cache_recurrent_state * state_recurrent)
     : status(LLAMA_MEMORY_STATUS_NO_UPDATE),
-      kv(kv),
       state_attn(state_unified),
       state_recurrent(state_recurrent) {}
 
@@ -198,20 +194,19 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
         std::vector<uint32_t> heads_attn,
         std::vector<llama_ubatch> ubatches)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
-      kv(kv),
       sbatch(std::move(sbatch)),
-      heads_attn(std::move(heads_attn)),
       ubatches(std::move(ubatches)),
-      // NOTE: these child states are only used as wrapper APIs for the
-      //  const methods, so we use the "init full" signature since the
-      //  actual state is not used.
-      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
-      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}
+      // note: here we copy the ubatches. not sure if this is ideal
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
+      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
 
 bool llama_kv_cache_hybrid_recurrent_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    state_attn     ->next();
+    state_recurrent->next();
+
     if (++i_next >= ubatches.size()) {
         return false;
     }
@@ -222,10 +217,12 @@ bool llama_kv_cache_hybrid_recurrent_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    kv->get_kv_attn()     ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
-    kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+    bool res = true;
 
-    return true;
+    res = res & state_attn     ->apply();
+    res = res & state_recurrent->apply();
+
+    return res;
 }
 
 std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {
diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h
index 93bf72ec3..17a72c613 100644
--- a/src/llama-kv-cache-hybrid-recurrent.h
+++ b/src/llama-kv-cache-hybrid-recurrent.h
@@ -104,7 +104,6 @@ public:
 
     // init update
     explicit llama_kv_cache_hybrid_recurrent_state(
-        llama_kv_cache_hybrid_recurrent * kv,
         llama_kv_cache_unified_state * state_unified,
         llama_kv_cache_recurrent_state * state_recurrent);
 
@@ -135,14 +134,11 @@ public:
 private:
     const llama_memory_status status;
 
-    llama_kv_cache_hybrid_recurrent * kv;
-
     llama_sbatch sbatch;
 
     // the index of the next ubatch to process
     size_t i_next = 0;
 
-    std::vector<uint32_t> heads_attn;
     std::vector<llama_ubatch> ubatches;
 
     const llama_memory_state_ptr state_attn;
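
Note on the resulting design: after this patch the hybrid state no longer holds a back-pointer to the parent cache; it becomes a thin composite that copies the ubatches into each child state at construction and fans next()/apply() out to both children. Below is a minimal self-contained sketch of that pattern, not part of the patch, using hypothetical stand-in types (child_state, hybrid_state); in llama.cpp the real children are llama_kv_cache_unified_state and llama_kv_cache_recurrent_state, and apply() writes the current ubatch into the underlying KV cache.

// Minimal sketch of the composite-state pattern this change settles on.
#include <cstddef>
#include <memory>

struct child_state {
    size_t i_next = 0;   // cursor over this child's own copy of the ubatches
    size_t n_ubatches;

    explicit child_state(size_t n) : n_ubatches(n) {}

    bool next()  { return ++i_next < n_ubatches; }  // advance the cursor
    bool apply() { return i_next < n_ubatches; }    // stand-in for a real apply
};

struct hybrid_state {
    std::unique_ptr<child_state> state_attn;
    std::unique_ptr<child_state> state_recurrent;
    size_t i_next = 0;
    size_t n_ubatches;

    explicit hybrid_state(size_t n)
        : state_attn(new child_state(n)),
          state_recurrent(new child_state(n)),
          n_ubatches(n) {}

    // advance both children in lockstep with the parent's cursor
    bool next() {
        state_attn->next();
        state_recurrent->next();
        return ++i_next < n_ubatches;
    }

    // bitwise & (not &&) so the second child's apply() still runs when the
    // first one fails, mirroring the patch
    bool apply() {
        bool res = true;
        res = res & state_attn->apply();
        res = res & state_recurrent->apply();
        return res;
    }
};

int main() {
    hybrid_state st(4);  // four ubatches to process
    do {
        if (!st.apply()) return 1;
    } while (st.next());
    return 0;
}

The bitwise & in apply() is taken straight from the patch: unlike &&, it does not short-circuit, so the recurrent child's apply() runs even when the attention child's fails and the two children stay in lockstep.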