diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 039718c04..beebcd755 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1494,6 +1494,49 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+ggml_tensor * llm_graph_context::build_recurrent_state(
+        const llama_kv_cache_recurrent_state * kv_state,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        int32_t   state_size,
+        int32_t   n_seqs,
+        bool avoid_copies) const {
+
+    const auto n_kv = kv_state->get_n_kv();
+    const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
+
+    ggml_tensor * output_states;
+
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
+
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0,
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+
+    return output_states;
+}
 llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
@@ -1519,40 +1562,7 @@ ggml_tensor * llm_graph_context::build_rs(
         bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    const auto n_kv = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
-
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
-
-    // Clear a single state which will then be copied to the other cleared states.
-    // Note that this is a no-op when the view is zero-sized.
-    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
-    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
-
-    ggml_tensor * output_states;
-
-    if (!avoid_copies) {
-        // copy states
-        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
-        ggml_build_forward_expand(gf, output_states);
-    } else {
-        // FIXME: make the gathering operation happen before the copy below
-        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-        output_states = states;
-    }
-
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0,
-            states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
-
-    return output_states;
+    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
 }
 
 llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
@@ -1578,40 +1588,7 @@ ggml_tensor * llm_graph_context::build_rs(
         bool avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
-
-    const auto n_kv = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
-
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
-
-    // Clear a single state which will then be copied to the other cleared states.
-    // Note that this is a no-op when the view is zero-sized.
-    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
-    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
-
-    ggml_tensor * output_states;
-
-    if (!avoid_copies) {
-        // copy states
-        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-        // {state_size, kv_size} -> {state_size, n_seqs}
-        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
-        ggml_build_forward_expand(gf, output_states);
-    } else {
-        // FIXME: make the gathering operation happen before the copy below
-        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
-        output_states = states;
-    }
-
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
-    ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0,
-            states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
-
-    return output_states;
+    return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 77f19a673..f705ea81d 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -622,6 +622,15 @@ struct llm_graph_context {
     // recurrent
     //
 
+    ggml_tensor * build_recurrent_state(
+            const llama_kv_cache_recurrent_state * kv_state,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            int32_t   state_size,
+            int32_t   n_seqs,
+            bool avoid_copies = false) const;
+
     llm_graph_input_rs * build_rs_inp_recurrent() const;
 
     ggml_tensor * build_rs(
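For readers less familiar with the ggml idioms the new helper relies on, here is a small standalone sketch (not part of the patch): it clears one cache row in place through a `ggml_view_1d`, then gathers a prefix of the copy indices with `ggml_get_rows` on a view of an I32 index tensor, which are the two tricks `build_recurrent_state` combines. The sizes, the `order` index values, and the printed check are made up for illustration, and the example assumes a ggml tree where `ggml_graph_compute_with_ctx` is declared in `ggml-cpu.h`.

```cpp
// standalone sketch of the clear-one-row + gather-by-index-view pattern used by
// build_recurrent_state() above; all sizes and index values here are hypothetical
#include "ggml.h"
#include "ggml-cpu.h"

#include <stdio.h>
#include <string.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t state_size = 4; // elements per recurrent state
    const int64_t kv_size    = 6; // number of cells in the recurrent cache
    const int64_t n_seqs     = 2; // states used by the current ubatch
    const int64_t rs_zero    = 3; // cell whose state must be cleared

    // {state_size, kv_size}: one recurrent state per cache cell
    struct ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, state_size, kv_size);
    // copy-source indices; only the first n_seqs entries are gathered here
    struct ggml_tensor * s_copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, kv_size);

    // fill cell r with r.0, r.1, r.2, r.3 so the rows are recognizable
    float * sd = (float *) states->data;
    for (int64_t r = 0; r < kv_size; ++r) {
        for (int64_t c = 0; c < state_size; ++c) {
            sd[r*state_size + c] = (float) r + 0.1f*(float) c;
        }
    }

    // hypothetical copy sources: sequence 0 reads cell 5, sequence 1 reads cell 3, ...
    const int32_t order[6] = { 5, 3, 0, 1, 2, 4 };
    memcpy(s_copy->data, order, sizeof(order));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);

    // clear cell rs_zero in place; build_recurrent_state() scales the view size and
    // offset by (rs_zero >= 0) so this degenerates to a zero-sized no-op view when
    // there is nothing to clear
    struct ggml_tensor * state_zero = ggml_view_1d(ctx, states, state_size, rs_zero*states->nb[1]);
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx, state_zero, 0.0f));

    // gather only the first n_seqs indices: {state_size, kv_size} -> {state_size, n_seqs};
    // expanded after the clear above, so the cleared row is what gets gathered
    struct ggml_tensor * gathered = ggml_get_rows(ctx, states, ggml_view_1d(ctx, s_copy, n_seqs, 0));
    ggml_build_forward_expand(gf, gathered);

    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // expect row 5 of `states`, then the cleared row 3 (all zeros)
    const float * gd = (const float *) gathered->data;
    for (int64_t r = 0; r < n_seqs; ++r) {
        for (int64_t c = 0; c < state_size; ++c) {
            printf("%5.1f ", gd[r*state_size + c]);
        }
        printf("\n");
    }

    ggml_free(ctx);
    return 0;
}
```

In the real helper the same `state_copy` tensor is viewed twice at different offsets: the first `n_seqs` entries select the states the rest of the graph will update, while the remaining `n_kv - n_seqs` entries are only shuffled into place and therefore written straight back into the cache buffer with `ggml_cpy`, as the "copy extra states" comment in the diff describes.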