graph : reduce splits for recurrent and hybrid models (#14825)

* graph : avoid creating redundant s_copy views

* graph : comment the s_copy views
compilade
2025-07-31 01:02:46 -04:00
committed by GitHub
parent 6e6725459a
commit 66625a59a5
2 changed files with 34 additions and 21 deletions

src/llama-graph.h

@@ -214,7 +214,12 @@ public:
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * s_copy; // I32 [kv_size]
+    ggml_tensor * s_copy; // I32 [n_rs]
+
+    // views of s_copy, computed once per graph
+    // and shared across layers which use build_rs
+    ggml_tensor * s_copy_main;  // I32 [n_seqs]
+    ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
 };
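
The two new members are views into `s_copy` rather than independent tensors. As a rough sketch of the idea (a hypothetical helper, not code from this commit; it only assumes ggml's `ggml_view_1d` and `ggml_element_size`), the views could be built once per graph like this:

```cpp
#include "ggml.h"

// Hypothetical helper (not part of this commit): build the two shared views
// of s_copy once per graph, so that every layer calling build_rs reuses them
// instead of creating its own per-layer views.
static void make_s_copy_views(
        ggml_context  * ctx,
        ggml_tensor   * s_copy,        // I32 [n_rs]
        int64_t         n_seqs,
        ggml_tensor  ** s_copy_main,   // out: I32 [n_seqs]
        ggml_tensor  ** s_copy_extra)  // out: I32 [n_rs - n_seqs]
{
    const int64_t n_rs = s_copy->ne[0];

    // first n_seqs entries: the states used by the current ubatch
    *s_copy_main  = ggml_view_1d(ctx, s_copy, n_seqs, 0);

    // remaining n_rs - n_seqs entries
    *s_copy_extra = ggml_view_1d(ctx, s_copy, n_rs - n_seqs,
                                 n_seqs*ggml_element_size(s_copy));
}
```

Because every layer that calls `build_rs` then receives the same two views, there is no fresh view node per layer, which is how the change avoids the redundant `s_copy` views and reduces graph splits.
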
@@ -730,7 +735,6 @@ struct llm_graph_context {
     // recurrent
     //
 
-    // TODO: avoid notion of "kv"
     // TODO: move this implementation to llama_memory_recurrent.
     //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
@@ -738,12 +742,13 @@ struct llm_graph_context {
     //       `llama_memory_recurrent`
 
     ggml_tensor * build_rs(
             ggml_tensor * s,
-            ggml_tensor * state_copy,
+            ggml_tensor * state_copy_main,
+            ggml_tensor * state_copy_extra,
             int32_t       state_size,
             int32_t       n_seqs,
-            uint32_t      n_kv,
-            uint32_t      kv_head,
-            uint32_t      kv_size,
+            uint32_t      n_rs,
+            uint32_t      rs_head,
+            uint32_t      rs_size,
             int32_t       rs_zero,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
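
To illustrate how the renamed `rs_*` parameters and the two views fit together, here is a hedged sketch (not the actual `build_rs` body; the function name, `ctx0`, and the omitted copy-back of the extra states are assumptions, and `ggml_get_rows` stands in for the `get_state_rows` default) of the gather step:

```cpp
#include "ggml.h"

// Rough sketch, not the actual implementation: gather the per-sequence states
// through the shared state_copy_main view and the remaining states through
// state_copy_extra. Reusing the same two views across all layers avoids the
// redundant per-layer view nodes, which is what reduces graph splits.
static ggml_tensor * build_rs_sketch(
        ggml_context * ctx0,
        ggml_tensor  * s,                // recurrent states, state_size * rs_size elements
        ggml_tensor  * state_copy_main,  // I32 [n_seqs]
        ggml_tensor  * state_copy_extra, // I32 [n_rs - n_seqs]
        int32_t        state_size,
        uint32_t       rs_size) {
    // view all recurrent states as a {state_size, rs_size} matrix
    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size);

    // {state_size, rs_size} -> {state_size, n_seqs}
    ggml_tensor * output_states = ggml_get_rows(ctx0, states, state_copy_main);

    // {state_size, rs_size} -> {state_size, n_rs - n_seqs}
    // (in the real code these are also copied back into the state buffer)
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, state_copy_extra);
    (void) states_extra;

    return output_states;
}
```
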