refactor: Use llama_memory_state_ptr for child states in hybrid memory state

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-06-12 14:30:21 -06:00
parent 7ba463b38c
commit 4ec4e6a801
2 changed files with 4 additions and 4 deletions

View File

@@ -244,9 +244,9 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
}
// Returns a non-owning pointer to the attention child state, viewed as the
// concrete unified-cache state type.
//
// The explicit static_cast is the only return statement on purpose: it is valid
// whether `state_attn` is declared with the concrete pointer alias or with the
// generic llama_memory_state_ptr, whereas a bare `return state_attn.get();`
// relies on an implicit conversion that only exists for the concrete alias.
// NOTE(review): the cast is unchecked — presumably state_attn always holds a
// llama_kv_cache_unified_state; confirm where the member is initialized.
const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
}
// Returns a non-owning pointer to the recurrent child state, viewed as the
// concrete recurrent-cache state type.
//
// The explicit static_cast is the only return statement on purpose: it is valid
// whether `state_recurrent` is declared with the concrete pointer alias or with
// the generic llama_memory_state_ptr, whereas a bare
// `return state_recurrent.get();` relies on an implicit conversion that only
// exists for the concrete alias.
// NOTE(review): the cast is unchecked — presumably state_recurrent always holds
// a llama_kv_cache_recurrent_state; confirm where the member is initialized.
const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
    return static_cast<const llama_kv_cache_recurrent_state *>(state_recurrent.get());
}

View File

@@ -145,6 +145,6 @@ private:
std::vector<uint32_t> heads_attn;
std::vector<llama_ubatch> ubatches;
const llama_kv_cache_unified_state_ptr state_attn;
const llama_kv_cache_recurrent_state_ptr state_recurrent;
const llama_memory_state_ptr state_attn;
const llama_memory_state_ptr state_recurrent;
};