diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp
index a2afda764..9ec205c9d 100644
--- a/src/llama-kv-cache-hybrid-recurrent.cpp
+++ b/src/llama-kv-cache-hybrid-recurrent.cpp
@@ -100,7 +100,6 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
 
 llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
     return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
-        this,
         static_cast<llama_kv_cache_unified_state *>(  kv_attn     ->init_update(lctx, optimize).release()),
         static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
 }
@@ -179,16 +178,13 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(lla
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv) :
     status(LLAMA_MEMORY_STATUS_SUCCESS),
-    kv(kv),
     state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
     state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
-        llama_kv_cache_hybrid_recurrent * kv,
         llama_kv_cache_unified_state * state_unified,
         llama_kv_cache_recurrent_state * state_recurrent) :
     status(LLAMA_MEMORY_STATUS_NO_UPDATE),
-    kv(kv),
     state_attn(state_unified),
     state_recurrent(state_recurrent) {}
 
@@ -198,20 +194,19 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
         std::vector<uint32_t> heads_attn,
         std::vector<llama_ubatch> ubatches) :
     status(LLAMA_MEMORY_STATUS_SUCCESS),
-    kv(kv),
     sbatch(std::move(sbatch)),
-    heads_attn(std::move(heads_attn)),
     ubatches(std::move(ubatches)),
-    // NOTE: these child states are only used as wrapper APIs for the
-    // const methods, so we use the "init full" signature since the
-    // actual state is not used.
-    state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
-    state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
+    state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
 
 bool llama_kv_cache_hybrid_recurrent_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    state_attn     ->next();
+    state_recurrent->next();
+
     if (++i_next >= ubatches.size()) {
         return false;
     }
@@ -222,10 +217,12 @@ bool llama_kv_cache_hybrid_recurrent_state::next() {
 bool llama_kv_cache_hybrid_recurrent_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    kv->get_kv_attn()     ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
-    kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+    bool res = true;
 
-    return true;
+    res = res & state_attn     ->apply();
+    res = res & state_recurrent->apply();
+
+    return res;
 }
 
 std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {
diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h
index 93bf72ec3..17a72c613 100644
--- a/src/llama-kv-cache-hybrid-recurrent.h
+++ b/src/llama-kv-cache-hybrid-recurrent.h
@@ -104,7 +104,6 @@ public:
 
     // init update
     explicit llama_kv_cache_hybrid_recurrent_state(
-        llama_kv_cache_hybrid_recurrent * kv,
         llama_kv_cache_unified_state * state_unified,
         llama_kv_cache_recurrent_state * state_recurrent);
 
@@ -135,14 +134,11 @@ public:
 private:
     const llama_memory_status status;
 
-    llama_kv_cache_hybrid_recurrent * kv;
-
    llama_sbatch sbatch;
 
     // the index of the next ubatch to process
     size_t i_next = 0;
 
-    std::vector<uint32_t> heads_attn;
     std::vector<llama_ubatch> ubatches;
 
     const llama_memory_state_ptr state_attn;
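
Illustration only, not part of the patch: a minimal, self-contained sketch of the delegation pattern the change adopts, where the hybrid state owns its child states and forwards next()/apply() to them instead of reaching back through a parent-cache pointer. All names below (memory_state, child_state, hybrid_state) are simplified stand-ins and assumptions for this sketch, not the actual llama.cpp API.

// Illustrative stand-in types only, not the llama.cpp classes.
#include <cassert>
#include <cstdio>
#include <memory>

// Minimal state interface, loosely analogous to a memory-state interface.
struct memory_state {
    virtual ~memory_state() = default;
    virtual bool next()  = 0;
    virtual bool apply() = 0;
};

// Child state that just steps through its own ubatch counter.
struct child_state : memory_state {
    explicit child_state(size_t n) : n_ubatches(n) {}
    bool next()  override { return ++i < n_ubatches; }
    bool apply() override { return i < n_ubatches; }
    size_t i = 0;
    size_t n_ubatches;
};

// Hybrid state: owns two child states and forwards to both,
// mirroring how the patched next()/apply() delegate instead of
// calling back into the parent cache.
struct hybrid_state : memory_state {
    hybrid_state(std::unique_ptr<memory_state> attn,
                 std::unique_ptr<memory_state> recr,
                 size_t n)
        : state_attn(std::move(attn)), state_recurrent(std::move(recr)), n_ubatches(n) {}

    bool next() override {
        state_attn     ->next();
        state_recurrent->next();
        return ++i_next < n_ubatches;
    }

    bool apply() override {
        bool res = true;
        // non-short-circuiting '&' so both children always run, as in the patch
        res = res & state_attn     ->apply();
        res = res & state_recurrent->apply();
        return res;
    }

    std::unique_ptr<memory_state> state_attn;
    std::unique_ptr<memory_state> state_recurrent;
    size_t i_next = 0;
    size_t n_ubatches;
};

int main() {
    hybrid_state st(std::make_unique<child_state>(3),
                    std::make_unique<child_state>(3), 3);
    do {
        assert(st.apply());
    } while (st.next());
    std::printf("processed all ubatches\n");
    return 0;
}

The '&' instead of '&&' in apply() mirrors the patch: both child apply() calls run even if the first one reports failure, and the combined result is returned.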