refactor: Use a common build_recurrent_state method that is cache-agnostic

This reduces the code duplication between the different build_rs impls and
also retains a similar signature to the previous build_recurrent_state
method while standardizing on the input-dispatched build_rs implementation.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart
2025-06-16 15:17:28 -06:00
parent 5046d412ef
commit faf41199c0
2 changed files with 54 additions and 68 deletions

View File

@@ -1494,6 +1494,49 @@ ggml_tensor * llm_graph_context::build_attn(
return cur; return cur;
} }
ggml_tensor * llm_graph_context::build_recurrent_state(
const llama_kv_cache_recurrent_state * kv_state,
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
const auto rs_zero = kv_state->get_rs_z();
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
ggml_tensor * output_states;
if (!avoid_copies) {
// copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// {state_size, kv_size} -> {state_size, n_seqs}
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
ggml_build_forward_expand(gf, output_states);
} else {
// FIXME: make the gathering operation happen before the copy below
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
output_states = states;
}
// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
states_extra,
ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
return output_states;
}
llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const { llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate); const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1519,40 +1562,7 @@ ggml_tensor * llm_graph_context::build_rs(
bool avoid_copies) const { bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate); const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
const auto rs_zero = kv_state->get_rs_z();
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
ggml_tensor * output_states;
if (!avoid_copies) {
// copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// {state_size, kv_size} -> {state_size, n_seqs}
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
ggml_build_forward_expand(gf, output_states);
} else {
// FIXME: make the gathering operation happen before the copy below
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
output_states = states;
}
// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
states_extra,
ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
return output_states;
} }
llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const { llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
@@ -1578,40 +1588,7 @@ ggml_tensor * llm_graph_context::build_rs(
bool avoid_copies) const { bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent(); const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
const auto rs_zero = kv_state->get_rs_z();
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
ggml_tensor * output_states;
if (!avoid_copies) {
// copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// {state_size, kv_size} -> {state_size, n_seqs}
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
ggml_build_forward_expand(gf, output_states);
} else {
// FIXME: make the gathering operation happen before the copy below
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
output_states = states;
}
// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
states_extra,
ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
return output_states;
} }
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(

View File

@@ -622,6 +622,15 @@ struct llm_graph_context {
// recurrent // recurrent
// //
ggml_tensor * build_recurrent_state(
const llama_kv_cache_recurrent_state * kv_state,
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
llm_graph_input_rs * build_rs_inp_recurrent() const; llm_graph_input_rs * build_rs_inp_recurrent() const;
ggml_tensor * build_rs( ggml_tensor * build_rs(