recurrent : rework graph inputs + add TODOs

ggml-ci
Georgi Gerganov
2025-06-18 09:29:51 +03:00
parent faf41199c0
commit 59fee24c72
7 changed files with 227 additions and 213 deletions


@@ -255,11 +255,6 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
}
}
llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
llm_graph_input_rs(kv_state->get_state_recurrent()) {
}
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
@@ -365,13 +360,6 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
}
}
llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@@ -416,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
const int64_t n_kv = kv_state->get_state_recurrent()->get_n_kv();
if (s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
int32_t * data = (int32_t *) s_copy->data;
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
data[i] = kv_state->get_state_recurrent()->s_copy(i);
}
}
}
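For context: s_copy encodes a gather over recurrent-state cells — slot i of the state view used by the graph is loaded from cell s_copy[i] of the cache (in the graph this ends up, roughly, as a row gather on the reshaped state tensor). A minimal stand-alone sketch of that semantics, using toy one-float "states" rather than the actual ggml tensors:
#include <cstdint>
#include <vector>
// toy stand-in: one float per cell instead of a full recurrent state vector
static std::vector<float> apply_s_copy(const std::vector<float> & cells,
                                       const std::vector<int32_t> & s_copy) {
    std::vector<float> slots(s_copy.size());
    for (size_t i = 0; i < s_copy.size(); ++i) {
        slots[i] = cells[s_copy[i]]; // slot i is sourced from cell s_copy[i]
    }
    return slots;
}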
//
// llm_graph_context
//
@@ -1043,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
return pos_bias;
}
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, kv_state);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = inp->kv_state->get_state_attn()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
const auto n_kv = kv_state->get_state_recurrent()->get_n_kv();
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(inp->s_copy);
}
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}
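Both mask allocations pad the batch dimension with GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), i.e. n_tokens is rounded up to a multiple of the pad constant (the constant itself is defined outside this diff). A small self-contained sketch of that round-up, assuming the usual integer formula; the pad value below is illustrative only:
#include <cstdint>
#include <cstdio>
// assumed behaviour of GGML_PAD(x, n): round x up to the next multiple of n
static int64_t pad_up(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}
int main() {
    const int64_t pad = 32; // illustrative; the real GGML_KQ_MASK_PAD is defined elsewhere in the tree
    std::printf("%lld %lld\n",
                (long long) pad_up(1,   pad),   // -> 32
                (long long) pad_up(100, pad));  // -> 128
    return 0;
}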
ggml_tensor * llm_graph_context::build_attn_mha(
ggml_cgraph * gf,
ggml_tensor * q,
@@ -1287,105 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = inp->kv_state->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_hybrid_recurrent * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
{
const auto n_kv = kv_state->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_state->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
ggml_cgraph * gf,
@@ -1494,20 +1428,100 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
ggml_tensor * llm_graph_context::build_recurrent_state(
const llama_kv_cache_recurrent_state * kv_state,
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
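On the "added to the graph together so that they are not reordered" comment: ggml_build_forward_expand appends a tensor and its not-yet-visited dependencies to the graph's node list in call order, so expanding q/k/v back-to-back pins their relative position before the KV-store and attention nodes are added. A stand-alone sketch of just that mechanic (CPU-only, no backend scheduler, so it shows node ordering rather than the split reduction itself):
#include "ggml.h"
int main(void) {
    struct ggml_init_params params = { 16*1024*1024, nullptr, false };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * x = ggml_add(ctx, a, b); // stand-ins for q/k/v-style roots
    struct ggml_tensor * y = ggml_mul(ctx, a, b);
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, x); // x's subgraph is appended first
    ggml_build_forward_expand(gf, y); // then y's remaining nodes, in this order
    ggml_graph_print(gf);
    ggml_free(ctx);
    return 0;
}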
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
{
const auto n_kv = kv_state->get_base()->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
const auto n_kv = kv_state->get_swa()->get_n_kv();
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
ggml_set_input(inp->self_kq_mask_swa);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
uint32_t n_kv,
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
bool avoid_copies) const {
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
const auto rs_zero = kv_state->get_rs_z();
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
@@ -1538,17 +1552,15 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
return output_states;
}
llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
const auto n_kv = kv_state->get_n_kv();
auto & cur = inp->s_copy;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(cur);
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(inp->s_copy);
return (llm_graph_input_rs *) res->add_input(std::move(inp));
}
@@ -1560,35 +1572,21 @@ ggml_tensor * llm_graph_context::build_rs(
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
}
llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
const auto n_kv = inp->kv_state->get_n_kv();
auto & cur = inp->s_copy;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(cur);
return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
}
ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_rs_hybrid_recurrent * inp,
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_kv(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(


@@ -201,12 +201,6 @@ public:
const llama_kv_cache_recurrent_state * kv_state;
};
class llm_graph_input_rs_hybrid_recurrent : public llm_graph_input_rs {
public:
llm_graph_input_rs_hybrid_recurrent(const llama_kv_cache_hybrid_recurrent_state * kv_state);
virtual ~llm_graph_input_rs_hybrid_recurrent() = default;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
llm_graph_input_cross_embd(
@@ -264,15 +258,6 @@ public:
const llama_kv_cache_unified_state * kv_state;
};
class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
public:
llm_graph_input_attn_kv_hybrid_recurrent(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_hybrid_recurrent_state * kv_state);
virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
};
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified_iswa(
@@ -316,6 +301,33 @@ public:
const llama_cross * cross = nullptr;
};
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
hparams(hparams),
cparams(cparams),
kv_state(kv_state) {
}
virtual ~llm_graph_input_mem_hybrid() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_copy; // I32 [kv_size]
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
const llama_hparams & hparams;
const llama_cparams & cparams;
const llama_kv_cache_hybrid_recurrent_state * kv_state;
};
//
// llm_graph_result
//
@@ -530,6 +542,8 @@ struct llm_graph_context {
ggml_tensor * build_inp_pos_bucket_dec() const;
ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
//
// attention
//
@@ -604,10 +618,8 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_hybrid_recurrent * inp,
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
@@ -622,16 +634,25 @@ struct llm_graph_context {
// recurrent
//
ggml_tensor * build_recurrent_state(
const llama_kv_cache_recurrent_state * kv_state,
// TODO: avoid notion of "kv"
// TODO: move this implementation to llama_kv_cache_recurrent.
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_kv_cache_recurrent`
ggml_tensor * build_rs(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
uint32_t n_kv,
uint32_t kv_head,
uint32_t kv_size,
int32_t rs_zero,
bool avoid_copies = false) const;
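One way to read the TODO above: the recurrent cache would expose something analogous to cpy_k/cpy_v that only constructs the copy/gather ops from a ggml_context, while the graph builder keeps every ggml_build_forward_expand call on its side. A toy sketch of that split in responsibilities (hypothetical names, not part of this commit):
#include <vector>
// toy op type standing in for a ggml_tensor node
struct toy_op { int src_cell; };
struct toy_recurrent_cache {
    // analogous to the proposed cache-side builder: creates ops, never expands them
    toy_op make_state_copy(int cell) const { return toy_op{ cell }; }
};
struct toy_graph {
    std::vector<toy_op> nodes;
    void expand(const toy_op & op) { nodes.push_back(op); } // analogous to ggml_build_forward_expand
};
static void build_layer(toy_graph & gf, const toy_recurrent_cache & cache, int il) {
    gf.expand(cache.make_state_copy(il)); // node ordering stays a graph-builder concern
}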
llm_graph_input_rs * build_rs_inp_recurrent() const;
llm_graph_input_rs * build_rs_inp() const;
ggml_tensor * build_rs(
llm_graph_input_rs * inp,
@@ -641,10 +662,8 @@ struct llm_graph_context {
int32_t n_seqs,
bool avoid_copies = false) const;
llm_graph_input_rs_hybrid_recurrent * build_rs_inp_hybrid_recurrent() const;
ggml_tensor * build_rs(
llm_graph_input_rs_hybrid_recurrent * inp,
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,


@@ -99,9 +99,7 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
}
llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) {
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
static_cast<llama_kv_cache_unified_state *>( kv_attn ->init_update(lctx, optimize).release()),
static_cast<llama_kv_cache_recurrent_state *>(kv_recurrent->init_update(lctx, optimize).release()));
return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this, lctx, optimize);
}
bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
@@ -171,22 +169,25 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c
return kv_recurrent.get();
}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
: status(status),
state_attn(new llama_kv_cache_unified_state(status)),
state_recurrent(new llama_kv_cache_recurrent_state(status)) {}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) : status(status) {}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
: status(LLAMA_MEMORY_STATUS_SUCCESS),
state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())),
state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {}
: status(LLAMA_MEMORY_STATUS_SUCCESS) {
state_attn = kv->get_kv_attn ()->init_full();
state_recurrent = kv->get_kv_recurrent()->init_full();
status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
}
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_unified_state * state_unified,
llama_kv_cache_recurrent_state * state_recurrent)
: status(LLAMA_MEMORY_STATUS_NO_UPDATE),
state_attn(state_unified),
state_recurrent(state_recurrent) {}
llama_kv_cache_hybrid_recurrent * kv,
llama_context * lctx,
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
state_attn = kv->get_kv_attn ()->init_update(lctx, optimize);
state_recurrent = kv->get_kv_recurrent()->init_update(lctx, optimize);
status = llama_memory_status_combine(state_attn->get_status(), state_recurrent->get_status());
}
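Both reworked constructors follow the same pattern: build the two child states via the child caches' own init_full()/init_update(), then derive the hybrid status from the children (the real combination rules live in llama_memory_status_combine). A self-contained toy of that composition, with a deliberately simplified "both must succeed" combine:
#include <memory>
#include <utility>
enum class toy_status { success, failed };
struct toy_state { toy_status status; };
// simplified stand-in for llama_memory_status_combine (real rules may differ)
static toy_status toy_combine(toy_status a, toy_status b) {
    return (a == toy_status::success && b == toy_status::success)
         ? toy_status::success : toy_status::failed;
}
struct toy_hybrid_state {
    std::unique_ptr<toy_state> attn;
    std::unique_ptr<toy_state> recurrent;
    toy_status                 status;
    toy_hybrid_state(std::unique_ptr<toy_state> a, std::unique_ptr<toy_state> r)
        : attn(std::move(a)), recurrent(std::move(r)),
          status(toy_combine(attn->status, recurrent->status)) {}
};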
llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_hybrid_recurrent * kv,
@@ -195,11 +196,11 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
std::vector<llama_ubatch> ubatches)
: status(LLAMA_MEMORY_STATUS_SUCCESS),
sbatch(std::move(sbatch)),
ubatches(std::move(ubatches)),
ubatches(std::move(ubatches)) {
// note: here we copy the ubatches. not sure if this is ideal
state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)),
state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {}
state_attn .reset(new llama_kv_cache_unified_state (kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches));
state_recurrent.reset(new llama_kv_cache_recurrent_state(kv->get_kv_recurrent(), {}, this->ubatches));
}
bool llama_kv_cache_hybrid_recurrent_state::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);


@@ -4,7 +4,6 @@
#include "llama-graph.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-kv-cache-unified.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"
#include <memory>
@@ -12,6 +11,7 @@
//
// llama_kv_cache_hybrid_recurrent
// TODO: rename to llama_memory_hybrid
//
// utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
@@ -93,9 +93,6 @@ private:
class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
public:
using llama_kv_cache_unified_state_ptr = std::unique_ptr<llama_kv_cache_unified_state>;
using llama_kv_cache_recurrent_state_ptr = std::unique_ptr<llama_kv_cache_recurrent_state>;
// init failure
explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);
@@ -104,8 +101,9 @@ public:
// init update
explicit llama_kv_cache_hybrid_recurrent_state(
llama_kv_cache_unified_state * state_unified,
llama_kv_cache_recurrent_state * state_recurrent);
llama_kv_cache_hybrid_recurrent * kv,
llama_context * lctx,
bool optimize);
// init success
llama_kv_cache_hybrid_recurrent_state(
@@ -132,7 +130,7 @@ public:
const llama_kv_cache_recurrent_state * get_state_recurrent() const;
private:
const llama_memory_status status;
llama_memory_status status;
llama_sbatch sbatch;
@@ -141,6 +139,6 @@ private:
std::vector<llama_ubatch> ubatches;
const llama_memory_state_ptr state_attn;
const llama_memory_state_ptr state_recurrent;
llama_memory_state_ptr state_attn;
llama_memory_state_ptr state_recurrent;
};
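The status and state members lose their const here because the new constructors assign them in the constructor body, after the child caches' init_* calls, rather than in the member-initializer list; a const member can only be set in the initializer list. A minimal self-contained illustration of that constraint:
struct with_const_member {
    const int status;
    with_const_member() : status(0) {
        // status = 1; // would not compile: a const member can only be set in the init list
    }
};
struct with_mutable_member {
    int status = 0;
    with_mutable_member() {
        status = 1; // fine: can be (re)assigned after the children have been initialized
    }
};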


@@ -384,11 +384,11 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
return std::make_unique<llama_kv_cache_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
return std::make_unique<llama_kv_cache_recurrent_state>(this);
}
llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
@@ -1043,15 +1043,13 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
llama_kv_cache_recurrent * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), is_full(true) {
}
llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv,
llama_sbatch sbatch,
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;


@@ -11,8 +11,10 @@
// llama_kv_cache_recurrent
//
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// TODO: extract the cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example of how to do it
// TODO: avoid the notion of "KV cache" / "KV cells", etc.
// TODO: rename to llama_recurrent_state / llama_recurrent_cache
class llama_kv_cache_recurrent : public llama_memory_i {
public:
@@ -131,12 +133,10 @@ public:
// used to create a full-cache state
llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv);
// used to create a state from a batch
llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv,
llama_sbatch sbatch,
std::vector<llama_ubatch> ubatches);


@@ -9116,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context {
// {n_embd, n_tokens}
inpL = build_inp_embd(model.tok_embd);
auto * rs_inp = build_rs_inp_recurrent();
auto * rs_inp = build_rs_inp();
for (int il = 0; il < n_layer; ++il) {
// norm
@@ -12092,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
inpL = build_inp_embd(model.tok_embd);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
auto * rs_inp = build_rs_inp_recurrent();
auto * rs_inp = build_rs_inp();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12187,7 +12187,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
inpL = build_inp_embd(model.tok_embd);
auto * rs_inp = build_rs_inp_recurrent();
auto * rs_inp = build_rs_inp();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12441,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
inpL = build_inp_embd(model.tok_embd);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
auto * rs_inp = build_rs_inp_recurrent();
auto * rs_inp = build_rs_inp();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@@ -12532,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
inpL = build_inp_embd(model.tok_embd);
auto * rs_inp = build_rs_inp_recurrent();
auto * rs_inp = build_rs_inp();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
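All of the call sites above change the same way: the renamed build_rs_inp() is called once, before the layer loop, and the resulting input is reused by every layer, matching the pattern of the other build_*_inp helpers in this diff. A toy rendering of that control flow (hypothetical names, just to show the shape):
struct toy_rs_input {};
static toy_rs_input build_rs_inp_toy() { return {}; }
static void build_layer_toy(const toy_rs_input & /*rs_inp*/, int /*il*/) {
    // per-layer graph construction would consume the shared input here
}
static void build_model_toy(int n_layer) {
    const toy_rs_input rs_inp = build_rs_inp_toy(); // created once per graph
    for (int il = 0; il < n_layer; ++il) {
        build_layer_toy(rs_inp, il);                // shared by all layers
    }
}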