mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 12:35:16 +00:00
kv-cache : avoid modifying recurrent cells when setting inputs (#13834)
* kv-cache : avoid modifying recurrent cells when setting inputs * kv-cache : remove inp_s_mask It was replaced with equivalent and simpler functionality with rs_z (the first zeroed state) and the already-existing inp_s_copy. * kv-cache : fix non-consecutive token pos warning for recurrent models The problem was apparently caused by how the tail cells were swapped. * graph : simplify logic for recurrent state copies * kv-cache : use cell without src refs for rs_z in recurrent cache * llama-graph : fix recurrent state copy The `state_copy` shuffle assumes everything is moved at once, which is not true when `states_extra` is copied back to the cache before copying the range of states between `head` and `head + n_seqs`. This is only a problem if any of the cells in [`head`, `head + n_seqs`) have an `src` in [`head + n_seqs`, `head + n_kv`), which does happen when `n_ubatch > 1` in the `llama-parallel` example. Changing the order of the operations avoids the potential overwrite before use, although when copies are avoided (like with Mamba2), this will require further changes. * llama-graph : rename n_state to state_size in build_recurrent_state This naming should reduce confusion between the state size and the number of states.
This commit is contained in:
@ -200,18 +200,6 @@ public:
|
||||
const llama_kv_cache_recurrent_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_s_mask : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_mask(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||
virtual ~llm_graph_input_s_mask() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
ggml_tensor * s_mask; // F32 [1, n_kv]
|
||||
|
||||
const llama_kv_cache_recurrent_state * kv_state;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_cross_embd(
|
||||
@ -521,7 +509,6 @@ struct llm_graph_context {
|
||||
ggml_tensor * build_inp_mean() const;
|
||||
ggml_tensor * build_inp_cls() const;
|
||||
ggml_tensor * build_inp_s_copy() const;
|
||||
ggml_tensor * build_inp_s_mask() const;
|
||||
|
||||
ggml_tensor * build_inp_cross_embd() const;
|
||||
ggml_tensor * build_inp_pos_bucket_enc() const;
|
||||
@ -606,18 +593,17 @@ struct llm_graph_context {
|
||||
// recurrent
|
||||
//
|
||||
|
||||
ggml_tensor * build_copy_mask_state(
|
||||
ggml_tensor * build_recurrent_state(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
int32_t n_state,
|
||||
int32_t n_seqs) const;
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
ggml_tensor * build_rwkv_token_shift_load(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor * state_mask,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const;
|
||||
|
||||
|
Reference in New Issue
Block a user