From 62a9f34baefc657212dea8f1bad14d4fb1657da8 Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin <git@compilade.net>
Date: Tue, 10 Jun 2025 00:19:13 -0400
Subject: [PATCH] llama-graph : fix recurrent state copy

The `state_copy` shuffle assumes everything is moved at once, which is
not true when `states_extra` is copied back to the cache before copying
the range of states between `head` and `head + n_seqs`. This is only a
problem if any of the cells in [`head`, `head + n_seqs`) have an `src`
in [`head + n_seqs`, `head + n_kv`), which does happen when
`n_ubatch > 1` in the `llama-parallel` example.

Changing the order of the operations avoids the potential overwrite
before use, although when copies are avoided (like with Mamba2), this
will require further changes.

* llama-graph : rename n_state to state_size in build_recurrent_state

This naming should reduce confusion between the state size and the
number of states.
---
 src/llama-graph.cpp | 31 +++++++++++++++++++------------
 src/llama-graph.h   |  2 +-
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index a41e8d4f0..56203d740 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1426,7 +1426,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
-            int32_t   n_state,
+            int32_t   state_size,
             int32_t   n_seqs,
                bool   avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1435,28 +1435,35 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     const auto kv_head = kv_state->get_head();
     const auto rs_zero = kv_state->get_rs_z();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
-    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
+    ggml_tensor * output_states;
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
+
     // copy extra states which won't be changed further (between n_seqs and n_kv)
     ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
-            ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
 
-    if (!avoid_copies) {
-        // copy states
-        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-        // this shrinks the tensors's ne[1] to n_seqs
-        states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
-    }
-
-    return states;
+    return output_states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
diff --git a/src/llama-graph.h b/src/llama-graph.h
index a59687b71..88fb77f1d 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -597,7 +597,7 @@ struct llm_graph_context {
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
-                int32_t   n_state,
+                int32_t   state_size,
                 int32_t   n_seqs,
                    bool   avoid_copies = false) const;
 