diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d6618f143..bde665953 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -326,6 +326,361 @@ ggml_tensor * llama_context::build_rope_factors(int il) {
     return model.layers[il].rope_short;
 }
 
+//
+// state
+//
+
+class llama_io_write_dummy : public llama_io_write_i {
+public:
+    llama_io_write_dummy() = default;
+
+    void write(const void * /* src */, size_t size) override {
+        size_written += size;
+    }
+
+    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
+        size_written += size;
+    }
+
+    size_t n_bytes() override {
+        return size_written;
+    }
+
+    size_t size_written = 0;
+};
+
+class llama_io_write_buffer : public llama_io_write_i {
+public:
+    llama_io_write_buffer(
+            uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+    void write(const void * src, size_t size) override {
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        memcpy(ptr, src, size);
+        ptr += size;
+        size_written += size;
+        buf_size -= size;
+    }
+
+    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        ggml_backend_tensor_get(tensor, ptr, offset, size);
+        ptr += size;
+        size_written += size;
+        buf_size -= size;
+    }
+
+    size_t n_bytes() override {
+        return size_written;
+    }
+
+    uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_written = 0;
+};
+
+class llama_io_read_buffer : public llama_io_read_i {
+public:
+    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
+
+    const uint8_t * read(size_t size) override {
+        const uint8_t * base_ptr = ptr;
+        if (size > buf_size) {
+            throw std::runtime_error("unexpectedly reached end of buffer");
+        }
+        ptr += size;
+        size_read += size;
+        buf_size -= size;
+        return base_ptr;
+    }
+
+    void read_to(void * dst, size_t size) override {
+        memcpy(dst, read(size), size);
+    }
+
+    size_t n_bytes() override {
+        return size_read;
+    }
+
+    const uint8_t * ptr;
+    size_t buf_size = 0;
+    size_t size_read = 0;
+};
+
+class llama_io_write_file : public llama_io_write_i {
+public:
+    llama_io_write_file(llama_file * f) : file(f) {}
+
+    void write(const void * src, size_t size) override {
+        file->write_raw(src, size);
+        size_written += size;
+    }
+
+    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
+        temp_buffer.resize(size);
+        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
+        write(temp_buffer.data(), temp_buffer.size());
+    }
+
+    size_t n_bytes() override {
+        return size_written;
+    }
+
+    llama_file * file;
+    size_t size_written = 0;
+    std::vector<uint8_t> temp_buffer;
+};
+
+class llama_io_read_file : public llama_io_read_i {
+public:
+    llama_io_read_file(llama_file * f) : file(f) {}
+
+    void read_to(void * dst, size_t size) override {
+        file->read_raw(dst, size);
+        size_read += size;
+    }
+
+    const uint8_t * read(size_t size) override {
+        temp_buffer.resize(size);
+        read_to(temp_buffer.data(), size);
+        return temp_buffer.data();
+    }
+
+    size_t n_bytes() override {
+        return size_read;
+    }
+
+    llama_file * file;
+    size_t size_read = 0;
+    std::vector<uint8_t> temp_buffer;
+};
+
+size_t llama_context::state_get_size() {
+    llama_io_write_dummy io;
+    try {
+        return state_get_data(io);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
+    llama_io_write_buffer io(dst, size);
+    try {
+        return state_get_data(io);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
+    llama_io_read_buffer io(src, size);
+    try {
+        return state_set_data(io);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
+    llama_io_write_dummy io;
+    try {
+        return state_seq_get_data(io, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
+    llama_io_write_buffer io(dst, size);
+    try {
+        return state_seq_get_data(io, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
+    llama_io_read_buffer io(src, size);
+    try {
+        return state_seq_set_data(io, seq_id);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
+        return 0;
+    }
+}
+
+bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");
+
+    // sanity checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t n_state_size_cur = file.size() - file.tell();
+
+        llama_io_read_file io(&file);
+        const size_t n_read = state_set_data(io);
+
+        if (n_read != n_state_size_cur) {
+            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
+            return false;
+        }
+    }
+
+    return true;
+}
+
+bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");
+
+    file.write_u32(LLAMA_SESSION_MAGIC);
+    file.write_u32(LLAMA_SESSION_VERSION);
+
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_io_write_file io(&file);
+    state_get_data(io);
+
+    return true;
+}
+
+size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");
+
+    // version checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
+            return 0;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return 0;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t state_size = file.size() - file.tell();
+        llama_io_read_file io(&file);
+        const size_t nread = state_seq_set_data(io, seq_id);
+        if (!nread) {
+            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
+            return 0;
+        }
+        GGML_ASSERT(nread <= state_size);
+        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
+    }
+
+    return file.tell();
+}
+
+size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");
+
+    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
+    file.write_u32(LLAMA_STATE_SEQ_VERSION);
+
+    // save the prompt
+    file.write_u32((uint32_t) n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_io_write_file io(&file);
+    state_seq_get_data(io, seq_id);
+
+    const size_t res = file.tell();
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
+
+    return res;
+}
+
+size_t llama_context::state_get_data(llama_io_write_i & io) {
+    // write model info
+    {
+        const std::string arch_str = llm_arch_name(model.arch);
+        io.write_string(arch_str);
+        // TODO: add more model-specific info which should prevent loading the session file if not identical
+    }
+
+    return io.n_bytes();
+}
+
+size_t llama_context::state_set_data(llama_io_read_i & io) {
+    // read model info
+    {
+        const std::string cur_arch_str = llm_arch_name(model.arch);
+
+        std::string arch_str;
+        io.read_string(arch_str);
+        if (cur_arch_str != arch_str) {
+            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+        }
+        // TODO: add more info which needs to be identical but which is not verified otherwise
+    }
+
+    return io.n_bytes();
+}
+
+size_t llama_context::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
+    GGML_UNUSED(seq_id);
+
+    return io.n_bytes();
+}
+
+size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
+    GGML_UNUSED(seq_id);
+
+    return io.n_bytes();
+}
+
 void llama_context::perf_reset() {
     t_start_us = ggml_time_us();
     t_eval_us = n_eval = 0;
@@ -3123,333 +3478,10 @@ ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
     return cur;
 }
 
-//
-// state
-//
-
-// TODO: this needs a big rework
-
-class llama_io_write_dummy : public llama_io_write_i {
-public:
-    llama_io_write_dummy() = default;
-
-    void write(const void * /* src */, size_t size) override {
-        size_written += size;
-    }
-
-    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
-        size_written += size;
-    }
-
-    size_t n_bytes() override {
-        return size_written;
-    }
-
-    size_t size_written = 0;
-};
-
-class llama_io_write_buffer : public llama_io_write_i {
-public:
-    llama_io_write_buffer(
-            uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
-
-    void write(const void * src, size_t size) override {
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        memcpy(ptr, src, size);
-        ptr += size;
-        size_written += size;
-        buf_size -= size;
-    }
-
-    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        ggml_backend_tensor_get(tensor, ptr, offset, size);
-        ptr += size;
-        size_written += size;
-        buf_size -= size;
-    }
-
-    size_t n_bytes() override {
-        return size_written;
-    }
-
-    uint8_t * ptr;
-    size_t buf_size = 0;
-    size_t size_written = 0;
-};
-
-class llama_io_read_buffer : public llama_io_read_i {
-public:
-    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
-
-    const uint8_t * read(size_t size) override {
-        const uint8_t * base_ptr = ptr;
-        if (size > buf_size) {
-            throw std::runtime_error("unexpectedly reached end of buffer");
-        }
-        ptr += size;
-        size_read += size;
-        buf_size -= size;
-        return base_ptr;
-    }
-
-    void read_to(void * dst, size_t size) override {
-        memcpy(dst, read(size), size);
-    }
-
-    size_t n_bytes() override {
-        return size_read;
-    }
-
-    const uint8_t * ptr;
-    size_t buf_size = 0;
-    size_t size_read = 0;
-};
-
-class llama_io_write_file : public llama_io_write_i {
-public:
-    llama_io_write_file(llama_file * f) : file(f) {}
-
-    void write(const void * src, size_t size) override {
-        file->write_raw(src, size);
-        size_written += size;
-    }
-
-    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
-        temp_buffer.resize(size);
-        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
-        write(temp_buffer.data(), temp_buffer.size());
-    }
-
-    size_t n_bytes() override {
-        return size_written;
-    }
-
-    llama_file * file;
-    size_t size_written = 0;
-    std::vector<uint8_t> temp_buffer;
-};
-
-class llama_io_read_file : public llama_io_read_i {
-public:
-    llama_io_read_file(llama_file * f) : file(f) {}
-
-    void read_to(void * dst, size_t size) override {
-        file->read_raw(dst, size);
-        size_read += size;
-    }
-
-    const uint8_t * read(size_t size) override {
-        temp_buffer.resize(size);
-        read_to(temp_buffer.data(), size);
-        return temp_buffer.data();
-    }
-
-    size_t n_bytes() override {
-        return size_read;
-    }
-
-    llama_file * file;
-    size_t size_read = 0;
-    std::vector<uint8_t> temp_buffer;
-};
-
-size_t llama_context_kv_self::state_get_size() {
-    llama_io_write_dummy io;
-    try {
-        return state_get_data(io);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) {
-    llama_io_write_buffer io(dst, size);
-    try {
-        return state_get_data(io);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) {
-    llama_io_read_buffer io(src, size);
-    try {
-        return state_set_data(io);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) {
-    llama_io_write_dummy io;
-    try {
-        return state_seq_get_data(io, seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
-    llama_io_write_buffer io(dst, size);
-    try {
-        return state_seq_get_data(io, seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-size_t llama_context_kv_self::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
-    llama_io_read_buffer io(src, size);
-    try {
-        return state_seq_set_data(io, seq_id);
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
-        return 0;
-    }
-}
-
-bool llama_context_kv_self::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(filepath, "rb");
-
-    // sanity checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
-
-        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-            return false;
-        }
-    }
-
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
-
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return false;
-        }
-
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
-    }
-
-    // restore the context state
-    {
-        const size_t n_state_size_cur = file.size() - file.tell();
-
-        llama_io_read_file io(&file);
-        const size_t n_read = state_set_data(io);
-
-        if (n_read != n_state_size_cur) {
-            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
-            return false;
-        }
-    }
-
-    return true;
-}
-
-bool llama_context_kv_self::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(filepath, "wb");
-
-    file.write_u32(LLAMA_SESSION_MAGIC);
-    file.write_u32(LLAMA_SESSION_VERSION);
-
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_io_write_file io(&file);
-    state_get_data(io);
-
-    return true;
-}
-
-size_t llama_context_kv_self::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
-    llama_file file(filepath, "rb");
-
-    // version checks
-    {
-        const uint32_t magic   = file.read_u32();
-        const uint32_t version = file.read_u32();
-
-        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
-            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
-            return 0;
-        }
-    }
-
-    // load the prompt
-    {
-        const uint32_t n_token_count = file.read_u32();
-
-        if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return 0;
-        }
-
-        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
-        *n_token_count_out = n_token_count;
-    }
-
-    // restore the context state
-    {
-        const size_t state_size = file.size() - file.tell();
-        llama_io_read_file io(&file);
-        const size_t nread = state_seq_set_data(io, seq_id);
-        if (!nread) {
-            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
-            return 0;
-        }
-        GGML_ASSERT(nread <= state_size);
-        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
-    }
-
-    return file.tell();
-}
-
-size_t llama_context_kv_self::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
-    llama_file file(filepath, "wb");
-
-    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
-    file.write_u32(LLAMA_STATE_SEQ_VERSION);
-
-    // save the prompt
-    file.write_u32((uint32_t) n_token_count);
-    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
-
-    // save the context state using stream saving
-    llama_io_write_file io(&file);
-    state_seq_get_data(io, seq_id);
-
-    const size_t res = file.tell();
-    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
-
-    return res;
-}
+// state save/load
 
 size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
-    synchronize();
-
-    // write model info
-    {
-        const std::string arch_str = llm_arch_name(model.arch);
-        io.write_string(arch_str);
-        // TODO: add more model-specific info which should prevent loading the session file if not identical
-    }
+    llama_context::state_get_data(io);
 
     // write output ids
     {
@@ -3492,7 +3524,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
         }
     }
 
-    // write mbeddings
+    // write embeddings
     {
         const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
 
@@ -3509,19 +3541,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
 }
 
 size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
-    synchronize();
-
-    // read model info
-    {
-        const std::string cur_arch_str = llm_arch_name(model.arch);
-
-        std::string arch_str;
-        io.read_string(arch_str);
-        if (cur_arch_str != arch_str) {
-            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
-        }
-        // TODO: add more info which needs to be identical but which is not verified otherwise
-    }
+    llama_context::state_set_data(io);
 
     // read output ids
     {
@@ -3584,7 +3604,7 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
 }
 
 size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
-    synchronize();
+    llama_context::state_seq_get_data(io, seq_id);
 
     kv_self.state_write(io, model.hparams, seq_id);
 
@@ -3592,7 +3612,7 @@ size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_se
 }
 
 size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
-    synchronize();
+    llama_context::state_seq_set_data(io, seq_id);
 
     kv_self.state_read(io, model.hparams, seq_id);
 
@@ -3937,15 +3957,21 @@ size_t llama_state_get_size(struct llama_context * ctx) {
 }
 
 size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
+    ctx->synchronize();
+
     return ctx->state_get_data(dst, size);
 }
 
 // Sets the state reading from the specified source address
 size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
+    ctx->synchronize();
+
     return ctx->state_set_data(src, size);
 }
 
 bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    ctx->synchronize();
+
     try {
         return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
@@ -3955,6 +3981,8 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session
 }
 
 bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+    ctx->synchronize();
+
     try {
         return ctx->state_save_file(path_session, tokens, n_token_count);
     } catch (const std::exception & err) {
@@ -3968,14 +3996,20 @@ size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id)
 }
 
 size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
+    ctx->synchronize();
+
     return ctx->state_seq_get_data(seq_id, dst, size);
 }
 
 size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
+    ctx->synchronize();
+
     return ctx->state_seq_set_data(seq_id, src, size);
 }
 
 size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+    ctx->synchronize();
+
     try {
         return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
     } catch (const std::exception & err) {
@@ -3985,6 +4019,8 @@ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepa
 }
 
 size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    ctx->synchronize();
+
     try {
         return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
diff --git a/src/llama-context.h b/src/llama-context.h
index 204793d75..235fcfee4 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -144,37 +144,37 @@ struct llama_context : public llama_graph_i {
 
     // state save/load
 
-    virtual size_t state_get_size() = 0;
-    virtual size_t state_get_data(      uint8_t * dst, size_t size) = 0;
-    virtual size_t state_set_data(const uint8_t * src, size_t size) = 0;
+    virtual size_t state_get_size();
+    virtual size_t state_get_data(      uint8_t * dst, size_t size);
+    virtual size_t state_set_data(const uint8_t * src, size_t size);
 
-    virtual size_t state_seq_get_size(llama_seq_id seq_id) = 0;
-    virtual size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size) = 0;
-    virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) = 0;
+    virtual size_t state_seq_get_size(llama_seq_id seq_id);
+    virtual size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size);
+    virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
 
     virtual bool state_load_file(
             const char * filepath,
             llama_token * tokens_out,
             size_t n_token_capacity,
-            size_t * n_token_count_out) = 0;
+            size_t * n_token_count_out);
 
     virtual bool state_save_file(
             const char * filepath,
             const llama_token * tokens,
-            size_t n_token_count) = 0;
+            size_t n_token_count);
 
     virtual size_t state_seq_load_file(
             llama_seq_id seq_id,
             const char * filepath,
             llama_token * tokens_out,
             size_t n_token_capacity,
-            size_t * n_token_count_out) = 0;
+            size_t * n_token_count_out);
 
     virtual size_t state_seq_save_file(
             llama_seq_id seq_id,
             const char * filepath,
             const llama_token * tokens,
-            size_t n_token_count) = 0;
+            size_t n_token_count);
 
     // perf
 
@@ -183,6 +183,14 @@ struct llama_context : public llama_graph_i {
 
 protected:
 
+    // state save/load
+
+    virtual size_t state_get_data(llama_io_write_i & io);
+    virtual size_t state_set_data(llama_io_read_i  & io);
+
+    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
+    virtual size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id);
+
     // members
 
     const llama_model & model;
@@ -471,46 +479,12 @@ public:
             int       il,
             bool      worst_case) override;
 
-    // state save/load
+protected:
+    virtual size_t state_get_data(llama_io_write_i & io) override;
+    virtual size_t state_set_data(llama_io_read_i  & io) override;
-
-    virtual size_t state_get_size() override;
-    virtual size_t state_get_data(      uint8_t * dst, size_t size) override;
-    virtual size_t state_set_data(const uint8_t * src, size_t size) override;
-
-    virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
-    virtual size_t state_seq_get_data(llama_seq_id seq_id,       uint8_t * dst, size_t size) override;
-    virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
-
-    virtual bool state_load_file(
-            const char * filepath,
-            llama_token * tokens_out,
-            size_t n_token_capacity,
-            size_t * n_token_count_out) override;
-
-    virtual bool state_save_file(
-            const char * filepath,
-            const llama_token * tokens,
-            size_t n_token_count) override;
-
-    virtual size_t state_seq_load_file(
-            llama_seq_id seq_id,
-            const char * filepath,
-            llama_token * tokens_out,
-            size_t n_token_capacity,
-            size_t * n_token_count_out) override;
-
-    virtual size_t state_seq_save_file(
-            llama_seq_id seq_id,
-            const char * filepath,
-            const llama_token * tokens,
-            size_t n_token_count) override;
-
-private:
-    size_t state_get_data(llama_io_write_i & io);
-    size_t state_set_data(llama_io_read_i  & io);
-
-    size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
-    size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id);
+    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
+    virtual size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id) override;
 };
 
 // For internal test use
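Note (not part of the patch): a minimal sketch of how the public session API touched by this change is typically used, now that the `llama_state_*` wrappers call `ctx->synchronize()` before delegating to `llama_context::state_*`. The file name `session.bin` and the helper `roundtrip_session` are hypothetical; the `llama_state_save_file` / `llama_state_load_file` signatures are the ones shown in the diff above.

#include "llama.h"
#include <vector>

// Save the prompt tokens plus the full context state, then restore them.
// Assumes `ctx` is an initialized llama_context and `prompt` holds the
// tokens that were already evaluated in it.
static bool roundtrip_session(llama_context * ctx, const std::vector<llama_token> & prompt) {
    // write magic/version, the prompt tokens, and the streamed context state
    if (!llama_state_save_file(ctx, "session.bin", prompt.data(), prompt.size())) {
        return false;
    }

    // read the session back; the token buffer capacity must cover the saved prompt
    std::vector<llama_token> tokens(prompt.size());
    size_t n_loaded = 0;
    if (!llama_state_load_file(ctx, "session.bin", tokens.data(), tokens.size(), &n_loaded)) {
        return false;
    }

    return n_loaded == prompt.size();
}

The per-sequence variants (`llama_state_seq_save_file` / `llama_state_seq_load_file`) follow the same pattern with an extra `llama_seq_id` argument.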