graph : fix recurrent state copies when avoiding copies

Works, but using lambda functions might not be that clean.
This commit is contained in:
Francis Couture-Harpin
2025-06-10 20:00:41 -04:00
parent 9864bfcd01
commit 2fa5f2ceb8
3 changed files with 38 additions and 28 deletions

View File

@ -1429,7 +1429,8 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
ggml_tensor * state_copy, ggml_tensor * state_copy,
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
bool avoid_copies) const { const std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)> & get_state_rows) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate); const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto n_kv = kv_state->get_n_kv(); const auto n_kv = kv_state->get_n_kv();
@ -1445,17 +1446,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
ggml_tensor * output_states; ggml_tensor * output_states;
if (!avoid_copies) { // copy states
// copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv // {state_size, kv_size} -> {state_size, n_seqs}
// {state_size, kv_size} -> {state_size, n_seqs} output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); ggml_build_forward_expand(gf, output_states);
ggml_build_forward_expand(gf, output_states);
} else {
// FIXME: make the gathering operation happen before the copy below
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
output_states = states;
}
// copy extra states which won't be changed further (between n_seqs and n_kv) // copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));

View File

@ -599,7 +599,8 @@ struct llm_graph_context {
ggml_tensor * state_copy, ggml_tensor * state_copy,
int32_t state_size, int32_t state_size,
int32_t n_seqs, int32_t n_seqs,
bool avoid_copies = false) const; const std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>
& get_state_rows = ggml_get_rows) const;
ggml_tensor * build_rwkv_token_shift_load( ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf, ggml_cgraph * gf,

View File

@ -9024,11 +9024,8 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * conv_states_all = kv_state->get_k_l(il);
ggml_tensor * ssm_states_all = kv_state->get_v_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
// (ab)using the KV cache to store the states
ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@ -9094,11 +9091,21 @@ struct llm_build_mamba : public llm_graph_context {
cur = x; cur = x;
x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs);
ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); ggml_tensor * A = model.layers[il].ssm_a;
// Custom operator to optimize the parallel associative scan
// as described in the Annex D of the Mamba paper. // use the states and the indices provided by build_recurrent_state
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} // (this is necessary in order to properly use the states before they are overwritten,
ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // while avoiding to make unnecessary copies of the states)
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
// Custom operator to optimize the parallel associative scan
// as described in the Annex D of the Mamba paper.
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
};
ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);
// store last states // store last states
ggml_build_forward_expand(gf, ggml_build_forward_expand(gf,
@ -9151,11 +9158,8 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * conv_states_all = kv_state->get_k_l(il);
ggml_tensor * ssm_states_all = kv_state->get_v_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
// (ab)using the KV cache to store the states
ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true);
ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size());
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
@ -9211,10 +9215,20 @@ struct llm_build_mamba : public llm_graph_context {
// {n_head, n_seq_tokens, n_seqs} // {n_head, n_seq_tokens, n_seqs}
dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); ggml_tensor * A = model.layers[il].ssm_a;
// TODO: use semistructured matrices to implement state-space duality
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} // use the states and the indices provided by build_recurrent_state
ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // (this is necessary in order to properly use the states before they are overwritten,
// while avoiding to make unnecessary copies of the states)
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
// TODO: use semistructured matrices to implement state-space duality
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
};
ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);
// store last states // store last states
ggml_build_forward_expand(gf, ggml_build_forward_expand(gf,