context : add save/load for recurrent context

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-21 10:28:42 +02:00
parent 08011c2ca1
commit 2645a7d9a9
2 changed files with 44 additions and 4 deletions

View File

@ -3657,6 +3657,40 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross(
return inp_kq_mask_cross;
}
// state save/load
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
llama_context::state_get_data(io);
kv_self.state_write(io);
return io.n_bytes();
}
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
llama_context::state_set_data(io);
kv_self.state_read(io);
return io.n_bytes();
}
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
llama_context::state_seq_get_data(io, seq_id);
kv_self.state_write(io, seq_id);
return io.n_bytes();
}
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
llama_context::state_seq_set_data(io, seq_id);
kv_self.state_read(io, seq_id);
return io.n_bytes();
}
//
// llama_context_recurrent
//
@ -4527,7 +4561,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
// state save/load
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
size_t llama_context_recurrent::state_get_data(llama_io_write_i & io) {
llama_context::state_get_data(io);
kv_self.state_write(io);
@ -4535,7 +4569,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
return io.n_bytes();
}
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
size_t llama_context_recurrent::state_set_data(llama_io_read_i & io) {
llama_context::state_set_data(io);
kv_self.state_read(io);
@ -4543,7 +4577,7 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
return io.n_bytes();
}
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
size_t llama_context_recurrent::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
llama_context::state_seq_get_data(io, seq_id);
kv_self.state_write(io, seq_id);
@ -4551,7 +4585,7 @@ size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_se
return io.n_bytes();
}
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
llama_context::state_seq_set_data(io, seq_id);
kv_self.state_read(io, seq_id);

View File

@ -525,6 +525,12 @@ public:
bool worst_case) override;
protected:
virtual size_t state_get_data(llama_io_write_i & io) override;
virtual size_t state_set_data(llama_io_read_i & io) override;
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
virtual void input_set(const llama_ubatch & ubatch) override;
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?