fix: Fix initialization of child states

Since initially writing this PR, the logic in the child state types changed
such that using the "init full" signature and keeping the ubatches on the
parent struct no longer worked.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-06-16 13:48:20 -06:00
parent 9db44a2a63
commit 5046d412ef
2 changed files with 11 additions and 18 deletions

View File

@@ -100,7 +100,6 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
this,
static_cast<llama_kv_cache_unified_state *>( kv_attn ->init_update(lctx, optimize).release()),
static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
}
@@ -179,16 +178,13 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(lla
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
: status(LLAMA_MEMORY_STATUS_SUCCESS),
kv(kv),
state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_hybrid_recurrent * kv,
llama_kv_cache_unified_state * state_unified,
llama_kv_cache_recurrent_state * state_recurrent)
: status(LLAMA_MEMORY_STATUS_NO_UPDATE),
kv(kv),
state_attn(state_unified),
state_recurrent(state_recurrent) {}
@@ -198,20 +194,19 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
std::vector<uint32_t> heads_attn,
std::vector<llama_ubatch> ubatches)
: status(LLAMA_MEMORY_STATUS_SUCCESS),
kv(kv),
sbatch(std::move(sbatch)),
heads_attn(std::move(heads_attn)),
ubatches(std::move(ubatches)),
// NOTE: these child states are only used as wrapper APIs for the
// const methods, so we use the "init full" signature since the
// actual state is not used.
state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}
// note: here we copy the ubatches. not sure if this is ideal
state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
bool llama_kv_cache_hybrid_recurrent_state::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
state_attn ->next();
state_recurrent->next();
if (++i_next >= ubatches.size()) {
return false;
}
@@ -222,10 +217,12 @@ bool llama_kv_cache_hybrid_recurrent_state::next() {
bool llama_kv_cache_hybrid_recurrent_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
kv->get_kv_attn() ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
bool res = true;
return true;
res = res & state_attn ->apply();
res = res & state_recurrent->apply();
return res;
}
std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {

View File

@@ -104,7 +104,6 @@ public:
// init update
explicit llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_hybrid_recurrent * kv,
llama_kv_cache_unified_state * state_unified,
llama_kv_cache_recurrent_state * state_recurrent);
@@ -135,14 +134,11 @@ public:
private:
const llama_memory_status status;
llama_kv_cache_hybrid_recurrent * kv;
llama_sbatch sbatch;
// the index of the next ubatch to process
size_t i_next = 0;
std::vector<uint32_t> heads_attn;
std::vector<llama_ubatch> ubatches;
const llama_memory_state_ptr state_attn;