mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-19 00:57:41 +00:00
context : add save/load for recurrent context
ggml-ci
This commit is contained in:
@ -3657,6 +3657,40 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross(
|
||||
return inp_kq_mask_cross;
|
||||
}
|
||||
|
||||
// state save/load
|
||||
|
||||
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
|
||||
llama_context::state_get_data(io);
|
||||
|
||||
kv_self.state_write(io);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
|
||||
llama_context::state_set_data(io);
|
||||
|
||||
kv_self.state_read(io);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||
llama_context::state_seq_get_data(io, seq_id);
|
||||
|
||||
kv_self.state_write(io, seq_id);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
llama_context::state_seq_set_data(io, seq_id);
|
||||
|
||||
kv_self.state_read(io, seq_id);
|
||||
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
//
|
||||
// llama_context_recurrent
|
||||
//
|
||||
@ -4527,7 +4561,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix(
|
||||
|
||||
// state save/load
|
||||
|
||||
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
|
||||
size_t llama_context_recurrent::state_get_data(llama_io_write_i & io) {
|
||||
llama_context::state_get_data(io);
|
||||
|
||||
kv_self.state_write(io);
|
||||
@ -4535,7 +4569,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
|
||||
size_t llama_context_recurrent::state_set_data(llama_io_read_i & io) {
|
||||
llama_context::state_set_data(io);
|
||||
|
||||
kv_self.state_read(io);
|
||||
@ -4543,7 +4577,7 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||
size_t llama_context_recurrent::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
|
||||
llama_context::state_seq_get_data(io, seq_id);
|
||||
|
||||
kv_self.state_write(io, seq_id);
|
||||
@ -4551,7 +4585,7 @@ size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_se
|
||||
return io.n_bytes();
|
||||
}
|
||||
|
||||
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
llama_context::state_seq_set_data(io, seq_id);
|
||||
|
||||
kv_self.state_read(io, seq_id);
|
||||
|
@ -525,6 +525,12 @@ public:
|
||||
bool worst_case) override;
|
||||
|
||||
protected:
|
||||
virtual size_t state_get_data(llama_io_write_i & io) override;
|
||||
virtual size_t state_set_data(llama_io_read_i & io) override;
|
||||
|
||||
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
|
||||
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
|
||||
|
||||
virtual void input_set(const llama_ubatch & ubatch) override;
|
||||
|
||||
// TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?
|
||||
|
Reference in New Issue
Block a user