diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp
index a6468482d..ea228a834 100644
--- a/src/llama-kv-cache-hybrid-recurrent.cpp
+++ b/src/llama-kv-cache-hybrid-recurrent.cpp
@@ -49,6 +49,59 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
         n_seq_max
     )) {}
 
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!kv_recurrent->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = kv_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
+}
+
+llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
+        this,
+        static_cast<llama_kv_cache_unified_state   *>( kv_attn    ->init_update(lctx, optimize).release()),
+        static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
+}
+
+bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
+    // Shifting is trivially supported for recurrent
+    return kv_attn->get_can_shift();
+}
 void llama_kv_cache_hybrid_recurrent::clear() {
     kv_attn     ->clear();
     kv_recurrent->clear();
@@ -93,67 +146,6 @@ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) cons
     return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
 }
 
-llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
-
-    // since this includes a recurrent cache, we cannot use split_simple
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
-
-    // follow the recurrent pattern for creating the ubatch splits
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-
-        if (embd_pooled) {
-            // Pooled embeddings cannot be split across ubatches (yet)
-            ubatch = sbatch.split_seq(n_ubatch);
-        } else {
-            ubatch = sbatch.split_equal(n_ubatch);
-        }
-
-        ubatches.push_back(ubatch);
-    }
-
-    // prepare the recurrent batches first
-    if (!kv_recurrent->prepare(ubatches)) {
-        // TODO: will the recurrent cache be in an undefined state at this point?
-        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
-        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    // prepare the attention cache
-    auto heads_attn = kv_attn->prepare(ubatches);
-    if (heads_attn.empty()) {
-        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
-        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
-        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
-}
-
-llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
-}
-
-bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
-    bool res = false;
-
-    res = res | kv_attn     ->update(lctx);
-    res = res | kv_recurrent->update(lctx);
-
-    return res;
-}
-
-void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
-    kv_attn     ->defrag_sched(thold);
-    kv_recurrent->defrag_sched(thold);
-}
-
-bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
-    // Shifting is trivially supported for recurrent
-    return kv_attn->get_can_shift();
-}
-
 void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
     kv_attn     ->state_write(io, seq_id);
     kv_recurrent->state_write(io, seq_id);
@@ -173,13 +165,24 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
 }
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
-    : status(status), state_attn(status), state_recurrent(status) {}
+    : status(status),
+      state_attn(new llama_kv_cache_unified_state(status)),
+      state_recurrent(new llama_kv_cache_recurrent_state(status)) {}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
     : status(LLAMA_MEMORY_STATUS_SUCCESS),
       kv(kv),
-      state_attn(status, kv->get_kv_attn()),
-      state_recurrent(status, kv->get_kv_recurrent()) {}
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
+      state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
+
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
+    llama_kv_cache_hybrid_recurrent * kv,
+    llama_kv_cache_unified_state * state_unified,
+    llama_kv_cache_recurrent_state * state_recurrent)
+    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+      kv(kv),
+      state_attn(state_unified),
+      state_recurrent(state_recurrent) {}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
     llama_kv_cache_hybrid_recurrent * kv,
@@ -194,8 +197,8 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
       // NOTE: these child states are only used as wrapper APIs for the
       // const methods, so we use the "init full" signature since the
      // actual state is not used.
-      state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()),
-      state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {}
+      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
+      state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {}
 
 bool llama_kv_cache_hybrid_recurrent_state::next() {
@@ -232,10 +235,10 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const {
-    return &state_attn;
+const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
+    return state_attn.get();
 }
 
 const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
-    return &state_recurrent;
+    return state_recurrent.get();
 }
diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h
index 692079b65..e504631e4 100644
--- a/src/llama-kv-cache-hybrid-recurrent.h
+++ b/src/llama-kv-cache-hybrid-recurrent.h
@@ -2,9 +2,10 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache.h"
 #include "llama-kv-cache-recurrent.h"
 #include "llama-kv-cache-unified.h"
+#include "llama-kv-cells.h"
+#include "llama-memory.h"
 
 #include <memory>
 #include <vector>
@@ -16,7 +17,7 @@
 // utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
 // support models where each layer may be either attention-based or recurrent
 
-class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
+class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
 public:
     llama_kv_cache_hybrid_recurrent(
         const llama_model & model,
@@ -42,6 +43,18 @@ public:
     // llama_memory_i
     //
 
+    llama_memory_state_ptr init_batch(
+            const llama_batch & batch,
+            uint32_t n_ubatch,
+            bool embd_pooled,
+            bool logits_all) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
     void clear() override;
 
     bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
@@ -53,24 +66,6 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
-    //
-    // llama_kv_cache
-    //
-
-    llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
-            uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
-
-    llama_memory_state_ptr init_full() override;
-
-    bool update(llama_context & lctx) override;
-
-    void defrag_sched(float thold) override;
-
-    bool get_can_shift() const override;
-
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
@@ -92,12 +87,21 @@ private:
 
 class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
 public:
+    using llama_kv_cache_unified_state_ptr   = std::unique_ptr<llama_kv_cache_unified_state>;
+    using llama_kv_cache_recurrent_state_ptr = std::unique_ptr<llama_kv_cache_recurrent_state>;
+
     // init failure
     explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);
 
     // init full
    explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv);
 
+    // init update
+    explicit llama_kv_cache_hybrid_recurrent_state(
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_kv_cache_unified_state * state_unified,
+        llama_kv_cache_recurrent_state * state_recurrent);
+
     // init success
     llama_kv_cache_hybrid_recurrent_state(
         llama_kv_cache_hybrid_recurrent * kv,
@@ -116,7 +120,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_hybrid_recurrent_state_i
+    // llama_kv_cache_hybrid_recurrent_state
     //
 
     const llama_kv_cache_unified_state * get_state_attn () const;
@@ -135,6 +139,6 @@ private:
     std::vector<uint32_t>     heads_attn;
     std::vector<llama_ubatch> ubatches;
 
-    const llama_kv_cache_unified_state   state_attn;
-    const llama_kv_cache_recurrent_state state_recurrent;
+    const llama_kv_cache_unified_state_ptr   state_attn;
+    const llama_kv_cache_recurrent_state_ptr state_recurrent;
 };
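
For reviewers, a rough usage sketch (not part of the patch) of how a caller drives the interface that this change moves from llama_kv_cache onto llama_memory_i. The types and methods come from the hunks above; get_status() and apply() are assumed from the llama_memory_state_i base interface elsewhere in the tree, and decode_ubatch() is a hypothetical stand-in for evaluating one ubatch:

#include "llama-kv-cache-hybrid-recurrent.h"

// hypothetical helper: run the graph for a single ubatch
static void decode_ubatch(llama_context * lctx, const llama_ubatch & ubatch);

static bool process_batch(
        llama_kv_cache_hybrid_recurrent & mem,
        llama_context                   * lctx,
        const llama_batch               & batch,
        uint32_t                          n_ubatch) {
    // schedule pending maintenance (e.g. K-shift on the attention child);
    // applying the returned state is what performs it (assumed contract)
    auto ustate = mem.init_update(lctx, /*optimize =*/ false);
    if (ustate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
        ustate->apply();
    }

    // split the batch; the hybrid cache prepares both children and reports
    // LLAMA_MEMORY_STATUS_FAILED_PREPARE if either child cannot fit it
    auto mstate = mem.init_batch(batch, n_ubatch, /*embd_pooled =*/ false, /*logits_all =*/ false);
    if (mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false;
    }

    // walk the prepared ubatches; next() advances the shared cursor that the
    // attention and recurrent child states consume in lockstep
    do {
        mstate->apply(); // bind the current ubatch into both child caches
        decode_ubatch(lctx, mstate->get_ubatch());
    } while (mstate->next());

    return true;
}

Note the design this sketch exercises: init_update() hands each child's update state to the combined llama_kv_cache_hybrid_recurrent_state via release(), which is why the member states become unique_ptr wrappers rather than by-value members in this patch.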