From 2645a7d9a999de249e15ff3dae5eea1866221b57 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 21 Feb 2025 10:28:42 +0200 Subject: [PATCH] context : add save/load for recurrent context ggml-ci --- src/llama-context.cpp | 42 ++++++++++++++++++++++++++++++++++++++---- src/llama-context.h | 6 ++++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 64728e8b5..4ce54b0d6 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -3657,6 +3657,40 @@ ggml_tensor * llama_context_kv_self::build_inp_kq_mask_cross( return inp_kq_mask_cross; } +// state save/load + +size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) { + llama_context::state_get_data(io); + + kv_self.state_write(io); + + return io.n_bytes(); +} + +size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) { + llama_context::state_set_data(io); + + kv_self.state_read(io); + + return io.n_bytes(); +} + +size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { + llama_context::state_seq_get_data(io, seq_id); + + kv_self.state_write(io, seq_id); + + return io.n_bytes(); +} + +size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { + llama_context::state_seq_set_data(io, seq_id); + + kv_self.state_read(io, seq_id); + + return io.n_bytes(); +} + // // llama_context_recurrent // @@ -4527,7 +4561,7 @@ ggml_tensor * llama_context_recurrent::build_rwkv6_time_mix( // state save/load -size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) { +size_t llama_context_recurrent::state_get_data(llama_io_write_i & io) { llama_context::state_get_data(io); kv_self.state_write(io); @@ -4535,7 +4569,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) { return io.n_bytes(); } -size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) { +size_t llama_context_recurrent::state_set_data(llama_io_read_i & io) { llama_context::state_set_data(io); kv_self.state_read(io); @@ -4543,7 +4577,7 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) { return io.n_bytes(); } -size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { +size_t llama_context_recurrent::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { llama_context::state_seq_get_data(io, seq_id); kv_self.state_write(io, seq_id); @@ -4551,7 +4585,7 @@ size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_se return io.n_bytes(); } -size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { +size_t llama_context_recurrent::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { llama_context::state_seq_set_data(io, seq_id); kv_self.state_read(io, seq_id); diff --git a/src/llama-context.h b/src/llama-context.h index df6acb265..9d8b70220 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -525,6 +525,12 @@ public: bool worst_case) override; protected: + virtual size_t state_get_data(llama_io_write_i & io) override; + virtual size_t state_set_data(llama_io_read_i & io) override; + + virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override; + virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override; + virtual void input_set(const llama_ubatch & ubatch) override; // TODO: change name to something more meaningful -- does "KV cache" make sense for recurrent models?