kv-cache : avoid modifying recurrent cells when setting inputs (#13834)

* kv-cache : avoid modifying recurrent cells when setting inputs

* kv-cache : remove inp_s_mask

It was replaced with equivalent, simpler functionality based on
rs_z (the first zeroed state) and the already-existing inp_s_copy.
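
For intuition, the replacement works roughly like this: the recurrent cache guarantees that at most
one cell (rs_z) needs clearing, and every sequence that starts fresh in the batch has its s_copy
index pointing at that cell, so the existing gather does the work the mask used to do. A toy sketch
in plain C++ (simplified names, not code from this commit):

#include <cstdio>
#include <vector>

int main() {
    // four cached recurrent states of size 3; sequences 1 and 2 start fresh in this batch
    std::vector<std::vector<float>> states = {
        {1, 1, 1}, {2, 2, 2}, {3, 3, 3}, {4, 4, 4}
    };

    // old approach: multiply every state by a 0/1 mask (s_mask)
    // new approach: zero exactly one cell (rs_z) and point the copy index of
    //               every sequence that starts fresh at that cell
    const int rs_z = 1;                        // the first zeroed state
    for (float & x : states[rs_z]) x = 0.0f;

    const int s_copy[4] = {0, rs_z, rs_z, 3};  // cells 1 and 2 both read zeros

    for (int i = 0; i < 4; ++i) {
        const std::vector<float> & src = states[s_copy[i]];
        std::printf("cell %d reads: %g %g %g\n", i, src[0], src[1], src[2]);
    }
    return 0;
}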

* kv-cache : fix non-consecutive token pos warning for recurrent models

The problem was apparently caused by how the tail cells were swapped.

* graph : simplify logic for recurrent state copies

* kv-cache : use cell without src refs for rs_z in recurrent cache

* llama-graph : fix recurrent state copy

The `state_copy` shuffle assumes everything is moved at once,
which is not true when `states_extra` is copied back to the cache
before copying the range of states between `head` and `head + n_seqs`.
This is only a problem if any of the cells in [`head`, `head + n_seqs`)
have an `src` in [`head + n_seqs`, `head + n_kv`),
which does happen when `n_ubatch > 1` in the `llama-parallel` example.

Changing the order of the operations avoids the potential overwrite
before use, although when copies are avoided (like with Mamba2),
this will require further changes.
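
The ordering bug is the classic overwrite-before-use hazard of an in-place shuffle. A toy sketch in
plain C++ (not the ggml graph code) of why all gathers have to happen before anything is written
back into the same buffer:

#include <cstdio>
#include <vector>

int main() {
    // copy map: after the shuffle, cell i should hold cache[src[i]]
    const std::vector<int> src   = {2, 3, 0, 1};  // cells 0,1 read from cells 2,3
    std::vector<float>     cache = {10, 11, 12, 13};

    // unsafe: writing one range back into `cache` before the other range has been
    // gathered can clobber a source that is still needed -- the overwrite-before-use
    // described above

    // safe: gather into a temporary first, then write everything back
    std::vector<float> gathered(cache.size());
    for (size_t i = 0; i < cache.size(); ++i) {
        gathered[i] = cache[src[i]];   // all reads happen first
    }
    cache = gathered;                  // all writes happen after

    for (float x : cache) std::printf("%g ", x);  // prints: 12 13 10 11
    std::printf("\n");
    return 0;
}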

* llama-graph : rename n_state to state_size in build_recurrent_state

This naming should reduce confusion between the state size
and the number of states.
Author: compilade
Date: 2025-06-10 18:20:14 -04:00
Committed by: GitHub
Parent: 55f6b9fa65
Commit: dad5c44398
6 changed files with 117 additions and 180 deletions

src/llama-graph.cpp

@@ -250,22 +250,6 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-
-    const int64_t n_kv = kv_state->get_n_kv();
-
-    if (s_mask) {
-        GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer));
-        float * data = (float *) s_mask->data;
-
-        // clear unused states
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_mask(i);
-        }
-    }
-}
-
 void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
@@ -987,23 +971,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_mask>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_mask;
-
-    cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
@@ -1456,43 +1423,53 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_copy_mask_state(
+ggml_tensor * llm_graph_context::build_recurrent_state(
          ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
-             int32_t   n_state,
-             int32_t   n_seqs) const {
+             int32_t   state_size,
+             int32_t   n_seqs,
+                bool   avoid_copies) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
 
     const auto n_kv    = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
+    const auto rs_zero = kv_state->get_rs_z();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
 
-    // copy states
-    // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
-    // this shrinks the tensors's ne[1] to n_kv
-    states = ggml_get_rows(ctx0, states, state_copy);
+    // Clear a single state which will then be copied to the other cleared states.
+    // Note that this is a no-op when the view is zero-sized.
+    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
+    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
 
-    // clear states of sequences which are starting at the beginning of this batch
-    // FIXME: zero-out NANs?
-    states = ggml_mul(ctx0, states, state_mask);
+    ggml_tensor * output_states;
 
-    // copy states which won't be changed further (between n_seqs and n_kv)
+    if (!avoid_copies) {
+        // copy states
+        // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+        // {state_size, kv_size} -> {state_size, n_seqs}
+        output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+        ggml_build_forward_expand(gf, output_states);
+    } else {
+        // FIXME: make the gathering operation happen before the copy below
+        //        (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
+        output_states = states;
+    }
+
+    // copy extra states which won't be changed further (between n_seqs and n_kv)
+    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0,
-            ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs          )*n_state*ggml_element_size(states)),
-            ggml_view_1d(ctx0, s,      n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s))));
+            states_extra,
+            ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
 
-    // the part of the states that will be used and modified
-    return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0);
+    return output_states;
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_cgraph * gf,
          ggml_tensor * state_copy,
-         ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                    int   il) const {
     const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@@ -1503,8 +1480,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
     ggml_tensor * token_shift_all = kv_state->get_k_l(il);
 
-    ggml_tensor * token_shift = build_copy_mask_state(
-            gf, token_shift_all, state_copy, state_mask,
+    ggml_tensor * token_shift = build_recurrent_state(
+            gf, token_shift_all, state_copy,
             hparams.n_embd_k_s(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);