kv-cache : avoid modifying recurrent cells when setting inputs (#13834)

* kv-cache : avoid modifying recurrent cells when setting inputs

* kv-cache : remove inp_s_mask

It was replaced with equivalent and simpler functionality
with rs_z (the first zeroed state) and the already-existing inp_s_copy.

* kv-cache : fix non-consecutive token pos warning for recurrent models

The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache

* llama-graph : fix recurrent state copy

The `state_copy` shuffle assumes everything is moved at once,
which is not true when `states_extra` is copied back to the cache
before copying the range of states between `head` and `head + n_seqs`.
This is only a problem if any of the cells in [`head`, `head + n_seqs`)
have an `src` in [`head + n_seqs`, `head + n_kv`),
which does happen when `n_ubatch > 1` in the `llama-parallel` example.

Changing the order of the operations avoids the potential overwrite
before use, although when copies are avoided (like with Mamba2),
this will require further changes.

* llama-graph : rename n_state to state_size in build_recurrent_state

This naming should reduce confusion between the state size
and the number of states.
This commit is contained in:
compilade
2025-06-10 18:20:14 -04:00
committed by GitHub
parent 55f6b9fa65
commit dad5c44398
6 changed files with 117 additions and 180 deletions

View File

@ -200,18 +200,6 @@ public:
const llama_kv_cache_recurrent_state * kv_state;
};
class llm_graph_input_s_mask : public llm_graph_input_i {
public:
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
virtual ~llm_graph_input_s_mask() = default;
void set_input(const llama_ubatch * ubatch) override;
ggml_tensor * s_mask; // F32 [1, n_kv]
const llama_kv_cache_recurrent_state * kv_state;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
llm_graph_input_cross_embd(
@ -521,7 +509,6 @@ struct llm_graph_context {
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy() const;
ggml_tensor * build_inp_s_mask() const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
@ -606,18 +593,17 @@ struct llm_graph_context {
// recurrent
//
ggml_tensor * build_copy_mask_state(
ggml_tensor * build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_state,
int32_t n_seqs) const;
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il) const;