diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e74c9ff53..1abe3b8fe 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1429,7 +1429,8 @@ ggml_tensor * llm_graph_context::build_recurrent_state( ggml_tensor * state_copy, int32_t state_size, int32_t n_seqs, - bool avoid_copies) const { + const std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)> & get_state_rows) const { + const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate); const auto n_kv = kv_state->get_n_kv(); @@ -1445,17 +1446,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state( ggml_tensor * output_states; - if (!avoid_copies) { - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); - ggml_build_forward_expand(gf, output_states); - } else { - // FIXME: make the gathering operation happen before the copy below - // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) 
- output_states = states; - } + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_kv) ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); diff --git a/src/llama-graph.h b/src/llama-graph.h index 88fb77f1d..1fcf1cde4 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -599,7 +599,8 @@ struct llm_graph_context { ggml_tensor * state_copy, int32_t state_size, int32_t n_seqs, - bool avoid_copies = false) const; + const std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)> + & get_state_rows = ggml_get_rows) const; ggml_tensor * build_rwkv_token_shift_load( ggml_cgraph * gf, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2c0f7d408..2999483ad 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9024,11 +9024,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il); - // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true); - ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size()); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9094,11 +9091,21 @@ struct llm_build_mamba : public llm_graph_context { cur = x; x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); - ggml_tensor * ssm_ids = 
ggml_view_1d(ctx0, state_copy, n_seqs, 0); - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9151,11 +9158,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il); - // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true); - ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size()); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9211,10 +9215,20 @@ struct 
llm_build_mamba : public llm_graph_context { // {n_head, n_seq_tokens, n_seqs} dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf,