From 59fee24c7236857e981a0579e095265357d0ee5c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Wed, 18 Jun 2025 09:29:51 +0300
Subject: [PATCH] recurrent : rework graph inputs + add TODOs

ggml-ci
---
 src/llama-graph.cpp                     | 280 ++++++++++++------------
 src/llama-graph.h                       |  67 ++++--
 src/llama-kv-cache-hybrid-recurrent.cpp |  51 ++---
 src/llama-kv-cache-hybrid-recurrent.h   |  16 +-
 src/llama-kv-cache-recurrent.cpp        |  10 +-
 src/llama-kv-cache-recurrent.h          |   6 +-
 src/llama-model.cpp                     |  10 +-
 7 files changed, 227 insertions(+), 213 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index beebcd755..f7cb3eb2e 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -255,11 +255,6 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_rs(kv_state->get_state_recurrent()) {
-}
-
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -365,13 +360,6 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
-}
-
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
         kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@@ -416,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_kv = kv_state->get_state_recurrent()->get_n_kv();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_kv; ++i) {
+            data[i] = kv_state->get_state_recurrent()->s_copy(i);
+        }
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -1043,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, kv_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = inp->kv_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        const auto n_kv = kv_state->get_state_recurrent()->get_n_kv();
+
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+        ggml_set_input(inp->s_copy);
+    }
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
@@ -1287,105 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
-    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
-        hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
-
-    {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
-
-        const auto n_kv = inp->kv_state->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_hybrid_recurrent * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    // these nodes are added to the graph together so that they are not reordered
-    // by doing so, the number of splits in the graph is reduced
-    ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
-    ggml_build_forward_expand(gf, v_cur);
-
-    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
-
-    // store to KV cache
-    {
-        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
-    }
-
-    const auto & kq_mask = inp->get_kq_mask();
-
-    ggml_tensor * q = q_cur;
-    ggml_tensor * k = kv_state->get_k(ctx0, il);
-    ggml_tensor * v = kv_state->get_v(ctx0, il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
-    cb(cur, "kqv_out", il);
-
-    if (wo) {
-        cur = build_lora_mm(wo, cur);
-        if (arch == LLM_ARCH_GLM4) {
-            // GLM4 seems to have numerical issues with half-precision accumulators
-            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
-        }
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
-    return cur;
-}
-
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
-
-    {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
@@ -1494,20 +1428,100 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_recurrent_state(
-        const llama_kv_cache_recurrent_state * kv_state,
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+
+    {
+        const auto n_kv = kv_state->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         int32_t   state_size,
         int32_t   n_seqs,
+        uint32_t  n_kv,
+        uint32_t  kv_head,
+        uint32_t  kv_size,
+        int32_t   rs_zero,
         bool avoid_copies) const {
-    const auto n_kv    = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
-
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1538,17 +1552,15 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
 
     const auto n_kv = kv_state->get_n_kv();
 
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    ggml_set_input(inp->s_copy);
 
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
@@ -1560,35 +1572,21 @@ ggml_tensor * llm_graph_context::build_rs(
         int32_t   state_size,
         int32_t   n_seqs,
         bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
-    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
-}
-
-llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
-    auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
-        static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
-
-    const auto n_kv = inp->kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
-        llm_graph_input_rs_hybrid_recurrent * inp,
+        llm_graph_input_mem_hybrid * inp,
         ggml_cgraph * gf,
         ggml_tensor * s,
         int32_t   state_size,
         int32_t   n_seqs,
         bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
-    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
diff --git a/src/llama-graph.h b/src/llama-graph.h
index f705ea81d..6598bd15f 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -201,12 +201,6 @@ public:
     const llama_kv_cache_recurrent_state * kv_state;
 };
 
-class llm_graph_input_rs_hybrid_recurrent : public llm_graph_input_rs {
-public:
-    llm_graph_input_rs_hybrid_recurrent(const llama_kv_cache_hybrid_recurrent_state * kv_state);
-    virtual ~llm_graph_input_rs_hybrid_recurrent() = default;
-};
-
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -264,15 +258,6 @@ public:
     const llama_kv_cache_unified_state * kv_state;
 };
 
-class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
-public:
-    llm_graph_input_attn_kv_hybrid_recurrent(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_kv_cache_hybrid_recurrent_state * kv_state);
-    virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
-};
-
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_attn_kv_unified_iswa(
@@ -316,6 +301,33 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_state(kv_state) {
+    }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_hybrid_recurrent_state * kv_state;
+};
+
 //
 // llm_graph_result
 //
@@ -530,6 +542,8 @@ struct llm_graph_context {
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // attention
     //
@@ -604,10 +618,8 @@ struct llm_graph_context {
             float   kq_scale,
             int     il) const;
 
-    llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
-
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_hybrid_recurrent * inp,
+            llm_graph_input_mem_hybrid * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
@@ -622,16 +634,25 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_recurrent_state(
-            const llama_kv_cache_recurrent_state * kv_state,
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_kv_cache_recurrent.
+    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    //       `llama_kv_cache_recurrent`
+    ggml_tensor * build_rs(
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
             int32_t   state_size,
             int32_t   n_seqs,
+            uint32_t  n_kv,
+            uint32_t  kv_head,
+            uint32_t  kv_size,
+            int32_t   rs_zero,
             bool avoid_copies = false) const;
 
-    llm_graph_input_rs * build_rs_inp_recurrent() const;
+    llm_graph_input_rs * build_rs_inp() const;
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
@@ -641,10 +662,8 @@ struct llm_graph_context {
             int32_t   n_seqs,
             bool avoid_copies = false) const;
 
-    llm_graph_input_rs_hybrid_recurrent * build_rs_inp_hybrid_recurrent() const;
-
     ggml_tensor * build_rs(
-            llm_graph_input_rs_hybrid_recurrent * inp,
+            llm_graph_input_mem_hybrid * inp,
             ggml_cgraph * gf,
             ggml_tensor * s,
             int32_t   state_size,
diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp
index 9ec205c9d..66ccc3200 100644
--- a/src/llama-kv-cache-hybrid-recurrent.cpp
+++ b/src/llama-kv-cache-hybrid-recurrent.cpp
@@ -99,9 +99,7 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
 }
 
 llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
-        static_cast<llama_kv_cache_unified_state *>  ( kv_attn     ->init_update(lctx, optimize).release()),
-        static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
@@ -171,35 +169,38 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
     return kv_recurrent.get();
 }
 
-llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
-    : status(status),
-      state_attn(new llama_kv_cache_unified_state(status)),
-      state_recurrent(new llama_kv_cache_recurrent_state(status)) {}
+llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
-    : status(LLAMA_MEMORY_STATUS_SUCCESS),
-      state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
-      state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
+    : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_attn      = kv->get_kv_attn     ()->init_full();
+    state_recurrent = kv->get_kv_recurrent()->init_full();
+
+    status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
+}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
-    llama_kv_cache_unified_state * state_unified,
-    llama_kv_cache_recurrent_state * state_recurrent)
-    : status(LLAMA_MEMORY_STATUS_NO_UPDATE),
-      state_attn(state_unified),
-      state_recurrent(state_recurrent) {}
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_context * lctx,
+        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
+    state_attn      = kv->get_kv_attn     ()->init_update(lctx, optimize);
+    state_recurrent = kv->get_kv_recurrent()->init_update(lctx, optimize);
+
+    status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
+}
 
 llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
-    llama_kv_cache_hybrid_recurrent * kv,
-    llama_sbatch sbatch,
-    std::vector<uint32_t> heads_attn,
-    std::vector<llama_ubatch> ubatches)
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_attn,
+        std::vector<llama_ubatch> ubatches) :
     status(LLAMA_MEMORY_STATUS_SUCCESS),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)),
-    // note: here we copy the ubatches. not sure if this is ideal
-    state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
-    state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
-
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)) {
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_attn     .reset(new llama_kv_cache_unified_state  (kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches));
+    state_recurrent.reset(new llama_kv_cache_recurrent_state(kv->get_kv_recurrent(), {}, this->ubatches));
+}
 
 bool llama_kv_cache_hybrid_recurrent_state::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h
index 17a72c613..780029bda 100644
--- a/src/llama-kv-cache-hybrid-recurrent.h
+++ b/src/llama-kv-cache-hybrid-recurrent.h
@@ -4,7 +4,6 @@
 #include "llama-graph.h"
 #include "llama-kv-cache-recurrent.h"
 #include "llama-kv-cache-unified.h"
-#include "llama-kv-cells.h"
 #include "llama-memory.h"
 
 #include <memory>
@@ -12,6 +11,7 @@
 
 //
 // llama_kv_cache_hybrid_recurrent
+// TODO: rename to llama_memory_hybrid
 //
 
 // utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
@@ -93,9 +93,6 @@ private:
 
 class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
 public:
-    using llama_kv_cache_unified_state_ptr   = std::unique_ptr<llama_kv_cache_unified_state>;
-    using llama_kv_cache_recurrent_state_ptr = std::unique_ptr<llama_kv_cache_recurrent_state>;
-
    // init failure
    explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);

@@ -104,8 +101,9 @@ public:
 
     // init update
     explicit llama_kv_cache_hybrid_recurrent_state(
-        llama_kv_cache_unified_state * state_unified,
-        llama_kv_cache_recurrent_state * state_recurrent);
+        llama_kv_cache_hybrid_recurrent * kv,
+        llama_context * lctx,
+        bool optimize);
 
     // init success
     llama_kv_cache_hybrid_recurrent_state(
@@ -132,7 +130,7 @@ public:
     const llama_kv_cache_recurrent_state * get_state_recurrent() const;
 
 private:
-    const llama_memory_status status;
+    llama_memory_status status;
 
     llama_sbatch sbatch;
 
@@ -141,6 +139,6 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
-    const llama_memory_state_ptr state_attn;
-    const llama_memory_state_ptr state_recurrent;
+    llama_memory_state_ptr state_attn;
+    llama_memory_state_ptr state_recurrent;
 };
diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp
index 802025e22..3d0608eee 100644
--- a/src/llama-kv-cache-recurrent.cpp
+++ b/src/llama-kv-cache-recurrent.cpp
@@ -384,11 +384,11 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
         return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
     }
 
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
+    return std::make_unique<llama_kv_cache_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
 }
 
 llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
-    return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
+    return std::make_unique<llama_kv_cache_recurrent_state>(this);
 }
 
 llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
@@ -1043,15 +1043,13 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
-        llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
+        llama_kv_cache_recurrent * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), is_full(true) {
 }
 
 llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
-        llama_memory_status status,
         llama_kv_cache_recurrent * kv,
         llama_sbatch sbatch,
-        std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
 
 llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
 
diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h
index b89590b7b..bb6879054 100644
--- a/src/llama-kv-cache-recurrent.h
+++ b/src/llama-kv-cache-recurrent.h
@@ -11,8 +11,10 @@
 // llama_kv_cache_recurrent
 //
 
-// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
+// TODO: extract the cache state used for graph computation into llama_kv_cache_recurrent_state_i
 //       see the implementation of llama_kv_cache_unified_state_i for an example how to do it
+// TODO: avoid the notion of "KV cache" / "KV cells", etc.
+// TODO: rename to llama_recurrent_state / llama_recurrent_cache
 
 class llama_kv_cache_recurrent : public llama_memory_i {
 public:
@@ -131,12 +133,10 @@ public:
 
     // used to create a full-cache state
     llama_kv_cache_recurrent_state(
-            llama_memory_status status,
             llama_kv_cache_recurrent * kv);
 
     // used to create a state from a batch
     llama_kv_cache_recurrent_state(
-            llama_memory_status status,
             llama_kv_cache_recurrent * kv,
             llama_sbatch sbatch,
             std::vector<llama_ubatch> ubatches);
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0be083979..b5428ae72 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -9116,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context {
         // {n_embd, n_tokens}
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -12092,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
         inpL = build_inp_embd(model.tok_embd);
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12187,7 +12187,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12441,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
         inpL = build_inp_embd(model.tok_embd);
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12532,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * rs_inp = build_rs_inp_recurrent();
+        auto * rs_inp = build_rs_inp();
 
         const auto n_embd = hparams.n_embd;
         const auto n_seq_tokens = ubatch.n_seq_tokens;
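
Note (illustrative only, not part of the patch): a minimal sketch of how a builder derived from llm_graph_context could consume the reworked hybrid inputs after this change. Only build_inp_mem_hybrid(), build_attn() and build_rs() (with the signatures declared above) come from the patch; the struct name, the helper and every tensor argument below are hypothetical placeholders.

// Hypothetical sketch - the struct, the helper and all tensor arguments are
// placeholders; only the build_* calls use signatures declared in this patch.
struct llm_build_hybrid_sketch : public llm_graph_context {
    void build_layers_sketch(
            ggml_cgraph * gf,
            ggml_tensor * wo, ggml_tensor * wo_b,                            // placeholder attention output proj/bias
            ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur,   // placeholder per-layer Q/K/V
            ggml_tensor * ssm_state, int32_t state_size, int32_t n_seqs,     // placeholder recurrent state
            float kq_scale) const {
        // created once per graph: carries both the KQ mask of the attention
        // sub-state and the s_copy indices of the recurrent sub-state
        auto * inp = build_inp_mem_hybrid();

        for (int il = 0; il < n_layer; ++il) {
            // attention layers are routed to the unified KV sub-state
            ggml_tensor * attn_out = build_attn(inp, gf, wo, wo_b,
                    q_cur, k_cur, v_cur, nullptr, nullptr, kq_scale, il);

            // recurrent layers are routed to the recurrent sub-state
            ggml_tensor * rs_out = build_rs(inp, gf, ssm_state, state_size, n_seqs);

            GGML_UNUSED(attn_out);
            GGML_UNUSED(rs_out);
        }
    }
};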