fix: Fix initialization of child states
Since initially writing this PR, the logic in the child state types changed such that using the "init full" signature and keeping the ubatches on the parent struct no longer worked.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
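To make the shape of the fix concrete, here is a minimal, self-contained sketch with toy types (`ubatch`, `child_state`, and `parent_state` are invented for illustration, not the actual llama.cpp classes): the parent state hands each child its own copy of the ubatches at construction time and then delegates `next()`/`apply()` to the children, mirroring the diff below.

```cpp
#include <cstdio>
#include <memory>
#include <vector>

// toy stand-in for llama_ubatch
struct ubatch { int n_tokens; };

// toy child state that owns the ubatches it will apply (the post-fix shape);
// before the fix the children were thin wrappers with no ubatches of their own
struct child_state {
    std::vector<ubatch> ubatches;
    size_t i_next = 0;

    explicit child_state(std::vector<ubatch> ubs) : ubatches(std::move(ubs)) {}

    void next() { ++i_next; }

    bool apply() {
        std::printf("  child applies ubatch %zu (%d tokens)\n", i_next, ubatches[i_next].n_tokens);
        return true;
    }
};

// toy parent that fans next()/apply() out to both children
struct parent_state {
    std::unique_ptr<child_state> attn;
    std::unique_ptr<child_state> recurrent;
    std::vector<ubatch> ubatches;
    size_t i_next = 0;

    explicit parent_state(std::vector<ubatch> ubs)
        : attn(new child_state(ubs)),      // each child gets its own copy,
          recurrent(new child_state(ubs)), // per the "copy the ubatches" note
          ubatches(std::move(ubs)) {}

    bool next() {
        attn->next();
        recurrent->next();
        return ++i_next < ubatches.size();
    }

    bool apply() {
        bool res = true;
        res = res & attn->apply();
        res = res & recurrent->apply();
        return res;
    }
};

int main() {
    parent_state st({{32}, {32}, {16}});
    do {
        st.apply();
    } while (st.next());
}
```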
@@ -100,7 +100,6 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {

 llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
     return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
         this,
         static_cast<llama_kv_cache_unified_state *>( kv_attn ->init_update(lctx, optimize).release()),
         static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
 }
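The `release()` + `static_cast` idiom in `init_update()` transfers ownership of a concrete child state out of the base-class `unique_ptr` that the child's `init_update()` returns. A small sketch of the pattern under toy types (`base_state`, `concrete_state`, `make_state`, and `parent` are invented for illustration):

```cpp
#include <cassert>
#include <memory>

struct base_state {
    virtual ~base_state() = default;
};

struct concrete_state : base_state {
    int payload = 42;
};

// factory returning the type-erased base, like the child init_update()
static std::unique_ptr<base_state> make_state() {
    return std::make_unique<concrete_state>();
}

// consumer that stores and owns the concrete child, like the hybrid state
struct parent {
    std::unique_ptr<concrete_state> child;
    explicit parent(concrete_state * c) : child(c) {} // takes ownership back
};

int main() {
    // release() + static_cast hands the pointer over without a copy; the
    // downcast is safe only because we know make_state() built a concrete_state
    parent p(static_cast<concrete_state *>(make_state().release()));
    assert(p.child->payload == 42);
}
```

The raw-pointer hand-off exists so the parent constructor can own the concrete derived type rather than the erased base returned by the factory.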
@@ -179,16 +178,13 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(lla

 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
       kv(kv),
       state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
       state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}

 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
     llama_kv_cache_hybrid_recurrent * kv,
     llama_kv_cache_unified_state * state_unified,
     llama_kv_cache_recurrent_state * state_recurrent)
     : status(LLAMA_MEMORY_STATUS_NO_UPDATE),
       kv(kv),
       state_attn(state_unified),
       state_recurrent(state_recurrent) {}
@@ -198,20 +194,19 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
     std::vector<uint32_t> heads_attn,
     std::vector<llama_ubatch> ubatches)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
       kv(kv),
       sbatch(std::move(sbatch)),
-      heads_attn(std::move(heads_attn)),
       ubatches(std::move(ubatches)),
-      // NOTE: these child states are only used as wrapper APIs for the
-      // const methods, so we use the "init full" signature since the
-      // actual state is not used.
-      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
-      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}
+      // note: here we copy the ubatches. not sure if this is ideal
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
+      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}

 bool llama_kv_cache_hybrid_recurrent_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

+    state_attn ->next();
+    state_recurrent->next();
+
     if (++i_next >= ubatches.size()) {
         return false;
     }
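Note that the new initializer list moves the `heads_attn` parameter straight into the attention child rather than also initializing a copy on the parent. This is presumably forced by move semantics: in a mem-initializer list a by-value parameter can only be moved from once, as this toy sketch shows (`holder` is invented for illustration):

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct holder {
    std::vector<int> mine;
    std::vector<int> childs;

    // moving the same parameter twice: `mine` steals the buffer, so by the
    // time `childs` is initialized, `v` is already empty
    explicit holder(std::vector<int> v)
        : mine(std::move(v)),
          childs(std::move(v)) {}
};

int main() {
    holder h({1, 2, 3});
    assert(h.mine.size() == 3);
    assert(h.childs.empty()); // the second move received nothing
}
```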
@@ -222,10 +217,12 @@ bool llama_kv_cache_hybrid_recurrent_state::next() {

 bool llama_kv_cache_hybrid_recurrent_state::apply() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-    kv->get_kv_attn() ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
-    kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+    bool res = true;

-    return true;
+    res = res & state_attn ->apply();
+    res = res & state_recurrent->apply();
+
+    return res;
 }

 std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {
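One detail in the new `apply()` body: the child results are combined with bitwise `&` rather than logical `&&`, so the second child's `apply()` still runs even when the first one fails. A tiny sketch of the difference (the `apply_child` helper is invented for illustration):

```cpp
#include <cstdio>

static int calls = 0;

// invented helper standing in for a child state's apply()
static bool apply_child(bool ok) { ++calls; return ok; }

int main() {
    // bitwise &: both operands are always evaluated, so both children apply
    calls = 0;
    bool res = true;
    res = res & apply_child(false);
    res = res & apply_child(true);
    std::printf("&  -> res=%d, calls=%d\n", res, calls); // res=0, calls=2

    // logical && short-circuits: the second child would be skipped
    calls = 0;
    bool res2 = apply_child(false) && apply_child(true);
    std::printf("&& -> res=%d, calls=%d\n", res2, calls); // res=0, calls=1
}
```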
@@ -104,7 +104,6 @@ public:

     // init update
     explicit llama_kv_cache_hybrid_recurrent_state(
         llama_kv_cache_hybrid_recurrent * kv,
         llama_kv_cache_unified_state * state_unified,
         llama_kv_cache_recurrent_state * state_recurrent);
@@ -135,14 +134,11 @@ public:
 private:
     const llama_memory_status status;

     llama_kv_cache_hybrid_recurrent * kv;

     llama_sbatch sbatch;

     // the index of the next ubatch to process
     size_t i_next = 0;

-    std::vector<uint32_t> heads_attn;
     std::vector<llama_ubatch> ubatches;

     const llama_memory_state_ptr state_attn;