From 3a504d9a0bd7d952d22cd2d707446de2316ec955 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 13 Feb 2025 12:18:44 +0200 Subject: [PATCH] llama : introduce llama_io interfaces ggml-ci --- src/CMakeLists.txt | 1 + src/llama-context.cpp | 480 +++++++++++++++-------------------------- src/llama-context.h | 14 +- src/llama-io.cpp | 15 ++ src/llama-io.h | 35 +++ src/llama-kv-cache.cpp | 18 +- src/llama-kv-cache.h | 21 +- 7 files changed, 250 insertions(+), 334 deletions(-) create mode 100644 src/llama-io.cpp create mode 100644 src/llama-io.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f1f5d41d4..7f919c90e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -18,6 +18,7 @@ add_library(llama llama-graph.cpp llama-hparams.cpp llama-impl.cpp + llama-io.cpp llama-kv-cache.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 665a144d7..d6618f143 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2,6 +2,7 @@ #include "llama-impl.h" #include "llama-mmap.h" +#include "llama-io.h" #include #include @@ -3128,214 +3129,29 @@ ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix( // TODO: this needs a big rework -// TODO: replace all non-fatal assertions with returned errors or exceptions -struct llama_data_write { - llama_data_write(llama_context_kv_self * ctx) : ctx(ctx) {} - virtual ~llama_data_write() = default; - - virtual void write(const void * src, size_t size) = 0; - virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0; - virtual size_t get_size_written() = 0; - - void write_string(const std::string & str) { - uint32_t str_size = str.size(); - - write(&str_size, sizeof(str_size)); - write(str.data(), str_size); - } - - void write_model_info() { - const auto & model = ctx->get_model(); - const std::string arch_str = llm_arch_name(model.arch); - write_string(arch_str); - // TODO: add more model-specific info which should prevent loading the session file if not identical - } - - //void write_rng(const std::mt19937 & rng) { - // std::ostringstream rng_ss; - // rng_ss << rng; - - // const std::string & rng_str = rng_ss.str(); - - // write_string(rng_str); - //} - - void write_output_ids() { - ctx->reorder_outputs(); - - const uint32_t n_outputs = ctx->n_outputs; - - std::vector output_pos; - - const size_t n_batch = ctx->n_batch(); - const auto & output_ids = ctx->output_ids; - - GGML_ASSERT(n_outputs <= ctx->output_size); - - output_pos.resize(n_outputs); - - // build a more compact representation of the output ids - for (size_t i = 0; i < n_batch; ++i) { - // map an output id to a position in the batch - int32_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT((uint32_t) pos < n_outputs); - output_pos[pos] = i; - } - } - - write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - write(output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - void write_logits() { - const auto & model = ctx->get_model(); - - const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * model.vocab.n_tokens()); - - write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - write(ctx->logits, logits_size * sizeof(float)); - } - } - - void write_embeddings() { - const auto & model = ctx->get_model(); - - const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * model.hparams.n_embd); - - write(&embeddings_size, sizeof(embeddings_size)); - - if (embeddings_size) { - write(ctx->embd, 
embeddings_size * sizeof(float)); - } - } - - llama_context_kv_self * ctx; -}; - -struct llama_data_read { - llama_data_read(llama_context_kv_self * ctx) : ctx(ctx) {} - virtual ~llama_data_read() = default; - - virtual const uint8_t * read(size_t size) = 0; - virtual void read_to(void * dst, size_t size) = 0; - virtual size_t get_size_read() = 0; - - void read_string(std::string & str) { - uint32_t str_size; - read_to(&str_size, sizeof(str_size)); - - str.assign((const char *) read(str_size), str_size); - } - - // validate model information - void read_model_info() { - const auto & model = ctx->get_model(); - - const std::string cur_arch_str = llm_arch_name(model.arch); - - std::string arch_str; - read_string(arch_str); - if (cur_arch_str != arch_str) { - throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); - } - // TODO: add more info which needs to be identical but which is not verified otherwise - } - - //void read_rng(std::mt19937 & rng) { - // std::string rng_str; - // read_string(rng_str); - - // std::istringstream rng_ss(rng_str); - // rng_ss >> rng; - - // if (rng_ss.fail()) { - // throw std::runtime_error("failed to load RNG state"); - // } - //} - - void read_output_ids() { - std::vector output_pos; - - uint32_t n_outputs; - read_to(&n_outputs, sizeof(n_outputs)); - - if (n_outputs > ctx->reserve_outputs(n_outputs)) { - throw std::runtime_error("could not reserve outputs"); - } - - if (n_outputs) { - output_pos.resize(n_outputs); - read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= ctx->n_batch()) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->n_batch())); - } - ctx->output_ids[id] = i; - } - - ctx->n_outputs = n_outputs; - } - } - - void read_logits() { - uint64_t logits_size; - read_to(&logits_size, sizeof(logits_size)); - - if (ctx->logits_size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - read_to(ctx->logits, logits_size * sizeof(float)); - } - } - - void read_embeddings() { - uint64_t embeddings_size; - read_to(&embeddings_size, sizeof(embeddings_size)); - - if (ctx->embd_size < embeddings_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embeddings_size) { - read_to(ctx->embd, embeddings_size * sizeof(float)); - } - } - - llama_context_kv_self * ctx; -}; - -struct llama_data_write_dummy : llama_data_write { - llama_data_write_dummy(llama_context_kv_self * ctx) : llama_data_write(ctx) {} +class llama_io_write_dummy : public llama_io_write_i { +public: + llama_io_write_dummy() = default; void write(const void * /* src */, size_t size) override { size_written += size; } - void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { size_written += size; } - size_t get_size_written() override { + size_t n_bytes() override { return size_written; } size_t size_written = 0; }; -struct llama_data_write_buffer : llama_data_write { - llama_data_write_buffer( - llama_context_kv_self * ctx, - uint8_t * p, size_t len) : llama_data_write(ctx), ptr(p), buf_size(len) {} +class llama_io_write_buffer : public llama_io_write_i { +public: + llama_io_write_buffer( + uint8_t * p, size_t len) : ptr(p), buf_size(len) {} void 
write(const void * src, size_t size) override { if (size > buf_size) { @@ -3347,7 +3163,7 @@ struct llama_data_write_buffer : llama_data_write { buf_size -= size; } - void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } @@ -3357,7 +3173,7 @@ struct llama_data_write_buffer : llama_data_write { buf_size -= size; } - size_t get_size_written() override { + size_t n_bytes() override { return size_written; } @@ -3366,10 +3182,9 @@ struct llama_data_write_buffer : llama_data_write { size_t size_written = 0; }; -struct llama_data_read_buffer : llama_data_read { - llama_data_read_buffer( - llama_context_kv_self * ctx, - const uint8_t * p, size_t len) : llama_data_read(ctx), ptr(p), buf_size(len) {} +class llama_io_read_buffer : public llama_io_read_i { +public: + llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} const uint8_t * read(size_t size) override { const uint8_t * base_ptr = ptr; @@ -3386,7 +3201,7 @@ struct llama_data_read_buffer : llama_data_read { memcpy(dst, read(size), size); } - size_t get_size_read() override { + size_t n_bytes() override { return size_read; } @@ -3395,23 +3210,22 @@ struct llama_data_read_buffer : llama_data_read { size_t size_read = 0; }; -struct llama_data_write_file : llama_data_write { - llama_data_write_file( - llama_context_kv_self * ctx, - llama_file * f) : llama_data_write(ctx), file(f) {} +class llama_io_write_file : public llama_io_write_i { +public: + llama_io_write_file(llama_file * f) : file(f) {} void write(const void * src, size_t size) override { file->write_raw(src, size); size_written += size; } - void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override { + void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { temp_buffer.resize(size); ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); write(temp_buffer.data(), temp_buffer.size()); } - size_t get_size_written() override { + size_t n_bytes() override { return size_written; } @@ -3420,10 +3234,9 @@ struct llama_data_write_file : llama_data_write { std::vector temp_buffer; }; -struct llama_data_read_file : llama_data_read { - llama_data_read_file( - llama_context_kv_self * ctx, - llama_file * f) : llama_data_read(ctx), file(f) {} +class llama_io_read_file : public llama_io_read_i { +public: + llama_io_read_file(llama_file * f) : file(f) {} void read_to(void * dst, size_t size) override { file->read_raw(dst, size); @@ -3436,7 +3249,7 @@ struct llama_data_read_file : llama_data_read { return temp_buffer.data(); } - size_t get_size_read() override { + size_t n_bytes() override { return size_read; } @@ -3446,9 +3259,9 @@ struct llama_data_read_file : llama_data_read { }; size_t llama_context_kv_self::state_get_size() { - llama_data_write_dummy data_ctx(this); + llama_io_write_dummy io; try { - return state_get_data(data_ctx); + return state_get_data(io); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); return 0; @@ -3456,9 +3269,9 @@ size_t llama_context_kv_self::state_get_size() { } size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) { - llama_data_write_buffer data_ctx(this, dst, size); + llama_io_write_buffer io(dst, size); try { - return state_get_data(data_ctx); + return 
state_get_data(io); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; @@ -3466,9 +3279,9 @@ size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) { } size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) { - llama_data_read_buffer data_ctx(this, src, size); + llama_io_read_buffer io(src, size); try { - return state_set_data(data_ctx); + return state_set_data(io); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -3476,9 +3289,9 @@ size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) { } size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) { - llama_data_write_dummy data_ctx(this); + llama_io_write_dummy io; try { - return state_seq_get_data(data_ctx, seq_id); + return state_seq_get_data(io, seq_id); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); return 0; @@ -3486,9 +3299,9 @@ size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) { } size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) { - llama_data_write_buffer data_ctx(this, dst, size); + llama_io_write_buffer io(dst, size); try { - return state_seq_get_data(data_ctx, seq_id); + return state_seq_get_data(io, seq_id); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; @@ -3496,9 +3309,9 @@ size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * } size_t llama_context_kv_self::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) { - llama_data_read_buffer data_ctx(this, src, size); + llama_io_read_buffer io(src, size); try { - return state_seq_set_data(data_ctx, seq_id); + return state_seq_set_data(io, seq_id); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -3536,8 +3349,8 @@ bool llama_context_kv_self::state_load_file(const char * filepath, llama_token * { const size_t n_state_size_cur = file.size() - file.tell(); - llama_data_read_file data_ctx(this, &file); - const size_t n_read = state_set_data(data_ctx); + llama_io_read_file io( &file); + const size_t n_read = state_set_data(io); if (n_read != n_state_size_cur) { LLAMA_LOG_ERROR("%s: did not read all of the session file data! 
size %zu, got %zu\n", __func__, n_state_size_cur, n_read); @@ -3559,8 +3372,8 @@ bool llama_context_kv_self::state_save_file(const char * filepath, const llama_t file.write_raw(tokens, sizeof(llama_token) * n_token_count); // save the context state using stream saving - llama_data_write_file data_ctx(this, &file); - state_get_data(data_ctx); + llama_io_write_file io(&file); + state_get_data(io); return true; } @@ -3595,8 +3408,8 @@ size_t llama_context_kv_self::state_seq_load_file(llama_seq_id seq_id, const cha // restore the context state { const size_t state_size = file.size() - file.tell(); - llama_data_read_file data_ctx(this, &file); - const size_t nread = state_seq_set_data(data_ctx, seq_id); + llama_io_read_file io(&file); + const size_t nread = state_seq_set_data(io, seq_id); if (!nread) { LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__); return 0; @@ -3619,116 +3432,171 @@ size_t llama_context_kv_self::state_seq_save_file(llama_seq_id seq_id, const cha file.write_raw(tokens, sizeof(llama_token) * n_token_count); // save the context state using stream saving - llama_data_write_file data_ctx(this, &file); - state_seq_get_data(data_ctx, seq_id); + llama_io_write_file io(&file); + state_seq_get_data(io, seq_id); const size_t res = file.tell(); - GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written()); + GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes()); return res; } -/** copy state data into either a buffer or file depending on the passed in context - * - * file context: - * llama_file file("/path", "wb"); - * llama_data_write_file data_ctx(&file); - * llama_state_get_data_internal(ctx, data_ctx); - * - * buffer context: - * std::vector buf(max_size, 0); - * llama_data_write_buffer data_ctx(buf.data(), max_size); - * llama_state_get_data_internal(ctx, data_ctx); - * -*/ -size_t llama_context_kv_self::state_get_data(llama_data_write & data_ctx) { +size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) { synchronize(); - data_ctx.write_model_info(); + // write model info + { + const std::string arch_str = llm_arch_name(model.arch); + io.write_string(arch_str); + // TODO: add more model-specific info which should prevent loading the session file if not identical + } - // copy outputs - data_ctx.write_output_ids(); - data_ctx.write_logits(); - data_ctx.write_embeddings(); + // write output ids + { + reorder_outputs(); - llama_kv_cache::io io = { - /* .write = */ [&](const void * src, size_t size) { - data_ctx.write(src, size); - }, - /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) { - data_ctx.write_tensor_data(tensor, offset, size); - }, - /* .read = */ nullptr, - /* .read_to = */ nullptr, - }; + const uint32_t n_outputs = this->n_outputs; + const auto & output_ids = this->output_ids; + + std::vector w_output_pos; + + GGML_ASSERT(n_outputs <= output_size); + + w_output_pos.resize(n_outputs); + + // build a more compact representation of the output ids + for (size_t i = 0; i < n_batch(); ++i) { + // map an output id to a position in the batch + int32_t pos = output_ids[i]; + if (pos >= 0) { + GGML_ASSERT((uint32_t) pos < n_outputs); + w_output_pos[pos] = i; + } + } + + io.write(&n_outputs, sizeof(n_outputs)); + + if (n_outputs) { + io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); + } + } + + // write logits + { + const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * 
model.vocab.n_tokens()); + + io.write(&logits_size, sizeof(logits_size)); + + if (logits_size) { + io.write(logits, logits_size * sizeof(float)); + } + } + + // write mbeddings + { + const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); + + io.write(&embd_size, sizeof(embd_size)); + + if (embd_size) { + io.write(embd, embd_size * sizeof(float)); + } + } kv_self.state_write(io, model.hparams); - return data_ctx.get_size_written(); + return io.n_bytes(); } -size_t llama_context_kv_self::state_set_data(llama_data_read & data_ctx) { +size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) { synchronize(); - data_ctx.read_model_info(); + // read model info + { + const std::string cur_arch_str = llm_arch_name(model.arch); - // set outputs - data_ctx.read_output_ids(); - data_ctx.read_logits(); - data_ctx.read_embeddings(); + std::string arch_str; + io.read_string(arch_str); + if (cur_arch_str != arch_str) { + throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str())); + } + // TODO: add more info which needs to be identical but which is not verified otherwise + } - llama_kv_cache::io io = { - /* .write = */ nullptr, - /* .write_tensor_data = */ nullptr, - /* .read = */ [&](size_t size) { - return data_ctx.read(size); - }, - /* .read_to = */ [&](void * dst, size_t size) { - data_ctx.read_to(dst, size); - }, - }; + // read output ids + { + std::vector output_pos; + + uint32_t n_outputs; + io.read_to(&n_outputs, sizeof(n_outputs)); + + if (n_outputs > reserve_outputs(n_outputs)) { + throw std::runtime_error("could not reserve outputs"); + } + + if (n_outputs) { + output_pos.resize(n_outputs); + io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); + + for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { + int32_t id = output_pos[i]; + if ((uint32_t) id >= n_batch()) { + throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); + } + this->output_ids[id] = i; + } + + this->n_outputs = n_outputs; + } + } + + // read logits + { + uint64_t logits_size; + io.read_to(&logits_size, sizeof(logits_size)); + + if (this->logits_size < logits_size) { + throw std::runtime_error("logits buffer too small"); + } + + if (logits_size) { + io.read_to(this->logits, logits_size * sizeof(float)); + } + } + + // read embeddings + { + uint64_t embd_size; + io.read_to(&embd_size, sizeof(embd_size)); + + if (this->embd_size < embd_size) { + throw std::runtime_error("embeddings buffer too small"); + } + + if (embd_size) { + io.read_to(this->embd, embd_size * sizeof(float)); + } + } kv_self.state_read(io, model.hparams); - return data_ctx.get_size_read(); + return io.n_bytes(); } -size_t llama_context_kv_self::state_seq_get_data(llama_data_write & data_ctx, llama_seq_id seq_id) { +size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) { synchronize(); - llama_kv_cache::io io = { - /* .write = */ [&](const void * src, size_t size) { - data_ctx.write(src, size); - }, - /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) { - data_ctx.write_tensor_data(tensor, offset, size); - }, - /* .read = */ nullptr, - /* .read_to = */ nullptr, - }; - kv_self.state_write(io, model.hparams, seq_id); - return data_ctx.get_size_written(); + return io.n_bytes(); } -size_t llama_context_kv_self::state_seq_set_data(llama_data_read & data_ctx, llama_seq_id seq_id) { +size_t 
llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) { synchronize(); - llama_kv_cache::io io = { - /* .write = */ nullptr, - /* .write_tensor_data = */ nullptr, - /* .read = */ [&](size_t size) { - return data_ctx.read(size); - }, - /* .read_to = */ [&](void * dst, size_t size) { - data_ctx.read_to(dst, size); - }, - }; - kv_self.state_read(io, model.hparams, seq_id); - return data_ctx.get_size_read(); + return io.n_bytes(); } // diff --git a/src/llama-context.h b/src/llama-context.h index 648a41045..204793d75 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -15,6 +15,9 @@ #include #include +class llama_io_read_i; +class llama_io_write_i; + using llama_loras = std::unordered_map; struct llama_context : public llama_graph_i { @@ -178,9 +181,10 @@ struct llama_context : public llama_graph_i { virtual llama_perf_context_data perf_get_data() const; virtual void perf_reset(); +protected: + // members -protected: const llama_model & model; llama_cparams cparams; @@ -502,11 +506,11 @@ public: size_t n_token_count) override; private: - size_t state_get_data(struct llama_data_write & data_ctx); - size_t state_set_data(struct llama_data_read & data_ctx); + size_t state_get_data(llama_io_write_i & io); + size_t state_set_data(llama_io_read_i & io); - size_t state_seq_get_data(struct llama_data_write & data_ctx, llama_seq_id seq_id); - size_t state_seq_set_data(struct llama_data_read & data_ctx, llama_seq_id seq_id); + size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id); + size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id); }; // For internal test use diff --git a/src/llama-io.cpp b/src/llama-io.cpp new file mode 100644 index 000000000..7ad70d163 --- /dev/null +++ b/src/llama-io.cpp @@ -0,0 +1,15 @@ +#include "llama-io.h" + +void llama_io_write_i::write_string(const std::string & str) { + uint32_t str_size = str.size(); + + write(&str_size, sizeof(str_size)); + write(str.data(), str_size); +} + +void llama_io_read_i::read_string(std::string & str) { + uint32_t str_size; + read_to(&str_size, sizeof(str_size)); + + str.assign((const char *) read(str_size), str_size); +} diff --git a/src/llama-io.h b/src/llama-io.h new file mode 100644 index 000000000..ce9216b83 --- /dev/null +++ b/src/llama-io.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include + +struct ggml_tensor; + +class llama_io_write_i { +public: + llama_io_write_i() = default; + virtual ~llama_io_write_i() = default; + + virtual void write(const void * src, size_t size) = 0; + virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0; + + // bytes written so far + virtual size_t n_bytes() = 0; + + void write_string(const std::string & str); +}; + +class llama_io_read_i { +public: + llama_io_read_i() = default; + virtual ~llama_io_read_i() = default; + + virtual const uint8_t * read(size_t size) = 0; + virtual void read_to(void * dst, size_t size) = 0; + + // bytes read so far + virtual size_t n_bytes() = 0; + + void read_string(std::string & str); +}; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index b79c2ff93..c93410f0a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -698,7 +698,7 @@ size_t llama_kv_cache::size_v_bytes() const { return size_v_bytes; } -void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) const { +void llama_kv_cache::state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id) const { std::vector> 
cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -736,7 +736,7 @@ void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, l state_write_data(io, cell_ranges, hparams); } -void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) { +void llama_kv_cache::state_read(llama_io_read_i & io, const llama_hparams & hparams, llama_seq_id seq_id) { uint32_t cell_count; io.read_to(&cell_count, sizeof(cell_count)); @@ -754,7 +754,7 @@ void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, ll } } -void llama_kv_cache::state_write_meta(const io & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { +void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { for (const auto & range : cell_ranges) { for (uint32_t i = range.first; i < range.second; ++i) { const auto & cell = cells[i]; @@ -773,7 +773,7 @@ void llama_kv_cache::state_write_meta(const io & io, const std::vector> & cell_ranges, const llama_hparams & hparams) const { +void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges, const llama_hparams & hparams) const { const uint32_t v_trans = this->v_trans ? 1 : 0; const uint32_t n_layer = hparams.n_layer; @@ -799,7 +799,7 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector write; - std::function write_tensor_data; - - std::function read; - std::function read_to; - }; - - void state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const; - void state_read (const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1); + void state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const; + void state_read (llama_io_read_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1); private: ggml_type type_k = GGML_TYPE_F16; @@ -132,11 +125,11 @@ private: std::vector ctxs; std::vector bufs; - void state_write_meta(const io & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(const io & io, const std::vector> & cell_ranges, const llama_hparams & hparams) const; + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges, const llama_hparams & hparams) const; - bool state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count); + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count); }; //
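
Usage sketch for reviewers: the new llama_io_write_i / llama_io_read_i interfaces keep the state serialization logic in llama_context independent of the underlying sink or source. Below is a minimal, hypothetical write sink; llama_io_write_vector is not part of this patch, and ggml-backend.h is assumed to provide ggml_backend_tensor_get, mirroring the pattern used by llama_io_write_file in llama-context.cpp.

    // sketch only: a hypothetical sink that serializes state into a std::vector
    #include "llama-io.h"

    #include "ggml-backend.h"   // assumed for ggml_backend_tensor_get

    #include <cstdint>
    #include <vector>

    class llama_io_write_vector : public llama_io_write_i {
    public:
        llama_io_write_vector(std::vector<uint8_t> & buf) : buf(buf) {}

        void write(const void * src, size_t size) override {
            const uint8_t * p = static_cast<const uint8_t *>(src);
            buf.insert(buf.end(), p, p + size);
            size_written += size;
        }

        void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
            // same pattern as llama_io_write_file: stage the tensor data in a
            // temporary buffer, then forward it through write()
            temp.resize(size);
            ggml_backend_tensor_get(tensor, temp.data(), offset, size);
            write(temp.data(), temp.size());
        }

        size_t n_bytes() override {
            return size_written;
        }

    private:
        std::vector<uint8_t> & buf;
        std::vector<uint8_t>   temp;

        size_t size_written = 0;
    };

Because llama_context_kv_self::state_get_data() and llama_kv_cache::state_write() only see the abstract llama_io_write_i, a sink like this (or the existing dummy/buffer/file implementations) can be swapped in without touching the serialization code itself.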