graph : reduce splits for recurrent and hybrid models (#14825)

* graph : avoid creating redundant s_copy views

* graph : comment the s_copy views

Author:       compilade
Date:         2025-07-31 01:02:46 -04:00
Committed by: GitHub
Parent:       6e6725459a
Commit:       66625a59a5
2 changed files with 34 additions and 21 deletions

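In short: the I32 s_copy views are now built once per graph, right next to s_copy itself, and every layer's build_rs call reuses them instead of creating fresh ggml_view_1d nodes. A minimal stand-alone sketch of that pattern (the ggml calls are real; the context size and the n_rs/n_seqs values are made up for illustration, and only the view names s_copy_main/s_copy_extra come from the diff below):

#include "ggml.h"

int main() {
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_rs   = 8; // recurrent states tracked for this ubatch (illustrative)
    const int64_t n_seqs = 3; // sequences present in the ubatch (illustrative)

    // single I32 input holding the copy indices for all states
    ggml_tensor * s_copy = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_rs);
    ggml_set_input(s_copy);

    // the two views are created once here and shared by every layer's build_rs call:
    // rows that feed the per-layer state update ...
    ggml_tensor * s_copy_main  = ggml_view_1d(ctx, s_copy, n_seqs, 0);
    // ... and rows that are only copied back untouched
    ggml_tensor * s_copy_extra = ggml_view_1d(ctx, s_copy, n_rs - n_seqs, n_seqs * s_copy->nb[0]);

    (void) s_copy_main;
    (void) s_copy_extra;

    ggml_free(ctx);
    return 0;
}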
src/llama-graph.cpp

@@ -1644,16 +1644,17 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 ggml_tensor * llm_graph_context::build_rs(
          ggml_tensor * s,
-         ggml_tensor * state_copy,
+         ggml_tensor * state_copy_main,
+         ggml_tensor * state_copy_extra,
             int32_t   state_size,
             int32_t   n_seqs,
-           uint32_t   n_kv,
-           uint32_t   kv_head,
-           uint32_t   kv_size,
+           uint32_t   n_rs,
+           uint32_t   rs_head,
+           uint32_t   rs_size,
             int32_t   rs_zero,
     const llm_graph_get_rows_fn & get_state_rows) const {
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1661,39 +1662,44 @@ ggml_tensor * llm_graph_context::build_rs(
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
     // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // {state_size, kv_size} -> {state_size, n_seqs}
-    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
+    // {state_size, rs_size} -> {state_size, n_seqs}
+    ggml_tensor * output_states = get_state_rows(ctx0, states, state_copy_main);
     ggml_build_forward_expand(gf, output_states);
-    // copy extra states which won't be changed further (between n_seqs and n_kv)
-    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
+    // copy extra states which won't be changed further (between n_seqs and n_rs)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
             states_extra,
-            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
+            ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s))));
     return output_states;
 }
 static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
          ggml_context * ctx0,
+   const llama_ubatch & ubatch,
     const llama_memory_recurrent_context * mctx_cur) {
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
-    const auto n_rs = mctx_cur->get_n_rs();
+    const int64_t n_rs   = mctx_cur->get_n_rs();
+    const int64_t n_seqs = ubatch.n_seqs;
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
+
+    inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
+    inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
     return inp;
 }
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
     const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+    auto inp = build_rs_inp_impl(ctx0, ubatch, mctx_cur);
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
@@ -1706,7 +1712,9 @@ ggml_tensor * llm_graph_context::build_rs(
     const llm_graph_get_rows_fn & get_state_rows) const {
     const auto * kv_state = inp->mctx;
-    return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+    return build_rs(s, inp->s_copy_main, inp->s_copy_extra, state_size, n_seqs,
+            kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(),
+            get_state_rows);
 }
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1753,7 +1761,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
-    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);

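For the copy-back of the untouched extra states in build_rs above, the destination view spans state_size*(n_rs - n_seqs) elements starting at slot rs_head + n_seqs. A tiny stand-alone sketch of that offset arithmetic, with made-up numbers (none of these values come from the commit):

#include <cstdio>
#include <cstdint>

int main() {
    // illustrative values only
    const int64_t state_size = 4096; // elements per recurrent state
    const int64_t n_rs       = 8;    // states tracked for this ubatch
    const int64_t n_seqs     = 3;    // states actually updated by the layers
    const int64_t rs_head    = 16;   // first slot of this ubatch in the state buffer
    const size_t  elem_size  = 4;    // bytes per element, e.g. F32

    // size of the 1d destination view (in elements) ...
    const int64_t n_extra = state_size * (n_rs - n_seqs);
    // ... and its byte offset inside the full state tensor s
    const size_t  offset  = (size_t)(rs_head + n_seqs) * state_size * elem_size;

    printf("extra view: %lld elements at byte offset %zu\n", (long long) n_extra, offset);
    return 0;
}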
src/llama-graph.h

@@ -214,7 +214,12 @@ public:
     void set_input(const llama_ubatch * ubatch) override;
-    ggml_tensor * s_copy; // I32 [kv_size]
+    ggml_tensor * s_copy; // I32 [n_rs]
+
+    // views of s_copy, computed once per graph
+    // and shared across layers which use build_rs
+    ggml_tensor * s_copy_main;  // I32 [n_seqs]
+    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
     const llama_memory_recurrent_context * mctx;
 };
@@ -730,7 +735,6 @@ struct llm_graph_context {
     // recurrent
     //
-    // TODO: avoid notion of "kv"
     // TODO: move this implementation to llama_memory_recurrent.
     // this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
     // when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
@@ -738,12 +742,13 @@ struct llm_graph_context {
     // `llama_memory_recurrent`
     ggml_tensor * build_rs(
              ggml_tensor * s,
-             ggml_tensor * state_copy,
+             ggml_tensor * state_copy_main,
+             ggml_tensor * state_copy_extra,
                 int32_t   state_size,
                 int32_t   n_seqs,
-               uint32_t   n_kv,
-               uint32_t   kv_head,
-               uint32_t   kv_size,
+               uint32_t   n_rs,
+               uint32_t   rs_head,
+               uint32_t   rs_size,
                 int32_t   rs_zero,
         const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;