diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp
index 49a7c35ab..a2afda764 100644
--- a/src/llama-kv-cache-hybrid-recurrent.cpp
+++ b/src/llama-kv-cache-hybrid-recurrent.cpp
@@ -244,9 +244,9 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
 }
 
 const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
-    return state_attn.get();
+    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
 }
 
 const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
-    return state_recurrent.get();
+    return static_cast<const llama_kv_cache_recurrent_state *>(state_recurrent.get());
 }
diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h
index d6678eb21..93bf72ec3 100644
--- a/src/llama-kv-cache-hybrid-recurrent.h
+++ b/src/llama-kv-cache-hybrid-recurrent.h
@@ -145,6 +145,6 @@ private:
     std::vector<uint32_t> heads_attn;
     std::vector<llama_ubatch> ubatches;
 
-    const llama_kv_cache_unified_state_ptr state_attn;
-    const llama_kv_cache_recurrent_state_ptr state_recurrent;
+    const llama_memory_state_ptr state_attn;
+    const llama_memory_state_ptr state_recurrent;
 };