context : abstract state read/write

ggml-ci
Georgi Gerganov
2025-02-13 12:37:28 +02:00
parent 3a504d9a0b
commit f7c7757bab
2 changed files with 400 additions and 390 deletions
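
This commit hoists the generic state (de)serialization machinery out of llama_context_kv_self into the llama_context base class, leaving only the KV-specific parts in the derived class. For orientation, here is a minimal sketch of how the public API layered on these hooks is typically used; it assumes only the llama.h entry points that appear, unchanged in shape, near the end of this diff (llama_state_get_size, llama_state_get_data, llama_state_set_data):

// sketch: round-trip the full context state through a caller-owned buffer
#include <vector>
#include "llama.h"

bool save_and_restore(llama_context * ctx) {
    const size_t n_state = llama_state_get_size(ctx); // sized via the dummy writer internally

    std::vector<uint8_t> buf(n_state);
    if (llama_state_get_data(ctx, buf.data(), buf.size()) != n_state) {
        return false; // serialization failed (error already logged by the context)
    }

    // ... later, restore into the same (or an identically configured) context
    return llama_state_set_data(ctx, buf.data(), buf.size()) == n_state;
}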


@@ -326,6 +326,361 @@ ggml_tensor * llama_context::build_rope_factors(int il) {
return model.layers[il].rope_short;
}
//
// state
//
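// writes nothing and only counts bytes; lets state_get_size() reuse the state_get_data() path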
class llama_io_write_dummy : public llama_io_write_i {
public:
llama_io_write_dummy() = default;
void write(const void * /* src */, size_t size) override {
size_written += size;
}
void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
size_written += size;
}
size_t n_bytes() override {
return size_written;
}
size_t size_written = 0;
};
class llama_io_write_buffer : public llama_io_write_i {
public:
llama_io_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
void write(const void * src, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
memcpy(ptr, src, size);
ptr += size;
size_written += size;
buf_size -= size;
}
void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
ggml_backend_tensor_get(tensor, ptr, offset, size);
ptr += size;
size_written += size;
buf_size -= size;
}
size_t n_bytes() override {
return size_written;
}
uint8_t * ptr;
size_t buf_size = 0;
size_t size_written = 0;
};
class llama_io_read_buffer : public llama_io_read_i {
public:
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
const uint8_t * read(size_t size) override {
const uint8_t * base_ptr = ptr;
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
ptr += size;
size_read += size;
buf_size -= size;
return base_ptr;
}
void read_to(void * dst, size_t size) override {
memcpy(dst, read(size), size);
}
size_t n_bytes() override {
return size_read;
}
const uint8_t * ptr;
size_t buf_size = 0;
size_t size_read = 0;
};
class llama_io_write_file : public llama_io_write_i {
public:
llama_io_write_file(llama_file * f) : file(f) {}
void write(const void * src, size_t size) override {
file->write_raw(src, size);
size_written += size;
}
void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
temp_buffer.resize(size);
ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
write(temp_buffer.data(), temp_buffer.size());
}
size_t n_bytes() override {
return size_written;
}
llama_file * file;
size_t size_written = 0;
std::vector<uint8_t> temp_buffer;
};
class llama_io_read_file : public llama_io_read_i {
public:
llama_io_read_file(llama_file * f) : file(f) {}
void read_to(void * dst, size_t size) override {
file->read_raw(dst, size);
size_read += size;
}
const uint8_t * read(size_t size) override {
temp_buffer.resize(size);
read_to(temp_buffer.data(), size);
return temp_buffer.data();
}
size_t n_bytes() override {
return size_read;
}
llama_file * file;
size_t size_read = 0;
std::vector<uint8_t> temp_buffer;
};
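The adapters above share a narrow write/read interface. A minimal sketch of composing the buffer pair directly, using only the members defined above:

// sketch: round-trip a small payload through the buffer adapters
uint8_t buf[16] = {};

llama_io_write_buffer wr(buf, sizeof(buf));
const uint32_t value = 42;
wr.write(&value, sizeof(value));   // copies into buf, advances wr.ptr

llama_io_read_buffer rd(buf, sizeof(buf));
uint32_t out = 0;
rd.read_to(&out, sizeof(out));     // out == 42

// both sides account for the same byte count:
// wr.n_bytes() == rd.n_bytes() == sizeof(uint32_t)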
size_t llama_context::state_get_size() {
llama_io_write_dummy io;
try {
return state_get_data(io);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
llama_io_write_buffer io(dst, size);
try {
return state_get_data(io);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
llama_io_read_buffer io(src, size);
try {
return state_set_data(io);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
llama_io_write_dummy io;
try {
return state_seq_get_data(io, seq_id);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
llama_io_write_buffer io(dst, size);
try {
return state_seq_get_data(io, seq_id);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
llama_io_read_buffer io(src, size);
try {
return state_seq_set_data(io, seq_id);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
return 0;
}
}
bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
llama_file file(filepath, "rb");
// sanity checks
{
const uint32_t magic = file.read_u32();
const uint32_t version = file.read_u32();
if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
return false;
}
}
// load the prompt
{
const uint32_t n_token_count = file.read_u32();
if (n_token_count > n_token_capacity) {
LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
return false;
}
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
*n_token_count_out = n_token_count;
}
// restore the context state
{
const size_t n_state_size_cur = file.size() - file.tell();
llama_io_read_file io(&file);
const size_t n_read = state_set_data(io);
if (n_read != n_state_size_cur) {
LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
return false;
}
}
return true;
}
bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
llama_file file(filepath, "wb");
file.write_u32(LLAMA_SESSION_MAGIC);
file.write_u32(LLAMA_SESSION_VERSION);
// save the prompt
file.write_u32((uint32_t) n_token_count);
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
// save the context state using stream saving
llama_io_write_file io(&file);
state_get_data(io);
return true;
}
size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
llama_file file(filepath, "rb");
// version checks
{
const uint32_t magic = file.read_u32();
const uint32_t version = file.read_u32();
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
return 0;
}
}
// load the prompt
{
const uint32_t n_token_count = file.read_u32();
if (n_token_count > n_token_capacity) {
LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
return 0;
}
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
*n_token_count_out = n_token_count;
}
// restore the context state
{
const size_t state_size = file.size() - file.tell();
llama_io_read_file io(&file);
const size_t nread = state_seq_set_data(io, seq_id);
if (!nread) {
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
return 0;
}
GGML_ASSERT(nread <= state_size);
GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
}
return file.tell();
}
size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
llama_file file(filepath, "wb");
file.write_u32(LLAMA_STATE_SEQ_MAGIC);
file.write_u32(LLAMA_STATE_SEQ_VERSION);
// save the prompt
file.write_u32((uint32_t) n_token_count);
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
// save the context state using stream saving
llama_io_write_file io(&file);
state_seq_get_data(io, seq_id);
const size_t res = file.tell();
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
return res;
}
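
The asserts above pin down the on-disk layout that both save paths produce. Session files use LLAMA_SESSION_MAGIC/LLAMA_SESSION_VERSION and sequence files use LLAMA_STATE_SEQ_MAGIC/LLAMA_STATE_SEQ_VERSION, but the shape is the same:

// file layout, as written by state_save_file / state_seq_save_file:
//   uint32_t    magic
//   uint32_t    version
//   uint32_t    n_token_count
//   llama_token tokens[n_token_count]
//   <state blob as produced by state_get_data / state_seq_get_data>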
size_t llama_context::state_get_data(llama_io_write_i & io) {
// write model info
{
const std::string arch_str = llm_arch_name(model.arch);
io.write_string(arch_str);
// TODO: add more model-specific info which should prevent loading the session file if not identical
}
return io.n_bytes();
}
size_t llama_context::state_set_data(llama_io_read_i & io) {
// read model info
{
const std::string cur_arch_str = llm_arch_name(model.arch);
std::string arch_str;
io.read_string(arch_str);
if (cur_arch_str != arch_str) {
throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
}
// TODO: add more info which needs to be identical but which is not verified otherwise
}
return io.n_bytes();
}
size_t llama_context::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
GGML_UNUSED(seq_id);
return io.n_bytes();
}
size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
GGML_UNUSED(seq_id);
return io.n_bytes();
}
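
The base implementations above only handle the shared header (the arch string); derived contexts are expected to chain to them and then append their own payload, as llama_context_kv_self does further down in this diff. A hypothetical minimal override (my_context and my_field are illustrative names, not part of this commit) would look like:

// sketch: a derived context layers its own state on top of the base payload
size_t my_context::state_get_data(llama_io_write_i & io) {
    llama_context::state_get_data(io);      // write the shared header first
    io.write(&my_field, sizeof(my_field));  // then any subclass-specific state
    return io.n_bytes();
}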
void llama_context::perf_reset() {
t_start_us = ggml_time_us();
t_eval_us = n_eval = 0;
@@ -3123,333 +3478,10 @@ ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
return cur;
}
//
// state
//
// TODO: this needs a big rework
class llama_io_write_dummy : public llama_io_write_i {
public:
llama_io_write_dummy() = default;
void write(const void * /* src */, size_t size) override {
size_written += size;
}
void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
size_written += size;
}
size_t n_bytes() override {
return size_written;
}
size_t size_written = 0;
};
class llama_io_write_buffer : public llama_io_write_i {
public:
llama_io_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
void write(const void * src, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
memcpy(ptr, src, size);
ptr += size;
size_written += size;
buf_size -= size;
}
void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
ggml_backend_tensor_get(tensor, ptr, offset, size);
ptr += size;
size_written += size;
buf_size -= size;
}
size_t n_bytes() override {
return size_written;
}
uint8_t * ptr;
size_t buf_size = 0;
size_t size_written = 0;
};
class llama_io_read_buffer : public llama_io_read_i {
public:
llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}
const uint8_t * read(size_t size) override {
const uint8_t * base_ptr = ptr;
if (size > buf_size) {
throw std::runtime_error("unexpectedly reached end of buffer");
}
ptr += size;
size_read += size;
buf_size -= size;
return base_ptr;
}
void read_to(void * dst, size_t size) override {
memcpy(dst, read(size), size);
}
size_t n_bytes() override {
return size_read;
}
const uint8_t * ptr;
size_t buf_size = 0;
size_t size_read = 0;
};
class llama_io_write_file : public llama_io_write_i {
public:
llama_io_write_file(llama_file * f) : file(f) {}
void write(const void * src, size_t size) override {
file->write_raw(src, size);
size_written += size;
}
void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
temp_buffer.resize(size);
ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
write(temp_buffer.data(), temp_buffer.size());
}
size_t n_bytes() override {
return size_written;
}
llama_file * file;
size_t size_written = 0;
std::vector<uint8_t> temp_buffer;
};
class llama_io_read_file : public llama_io_read_i {
public:
llama_io_read_file(llama_file * f) : file(f) {}
void read_to(void * dst, size_t size) override {
file->read_raw(dst, size);
size_read += size;
}
const uint8_t * read(size_t size) override {
temp_buffer.resize(size);
read_to(temp_buffer.data(), size);
return temp_buffer.data();
}
size_t n_bytes() override {
return size_read;
}
llama_file * file;
size_t size_read = 0;
std::vector<uint8_t> temp_buffer;
};
size_t llama_context_kv_self::state_get_size() {
llama_io_write_dummy io;
try {
return state_get_data(io);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) {
llama_io_write_buffer io(dst, size);
try {
return state_get_data(io);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) {
llama_io_read_buffer io(src, size);
try {
return state_set_data(io);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) {
llama_io_write_dummy io;
try {
return state_seq_get_data(io, seq_id);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
llama_io_write_buffer io(dst, size);
try {
return state_seq_get_data(io, seq_id);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
return 0;
}
}
size_t llama_context_kv_self::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
llama_io_read_buffer io(src, size);
try {
return state_seq_set_data(io, seq_id);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
return 0;
}
}
bool llama_context_kv_self::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
llama_file file(filepath, "rb");
// sanity checks
{
const uint32_t magic = file.read_u32();
const uint32_t version = file.read_u32();
if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
return false;
}
}
// load the prompt
{
const uint32_t n_token_count = file.read_u32();
if (n_token_count > n_token_capacity) {
LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
return false;
}
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
*n_token_count_out = n_token_count;
}
// restore the context state
{
const size_t n_state_size_cur = file.size() - file.tell();
llama_io_read_file io(&file);
const size_t n_read = state_set_data(io);
if (n_read != n_state_size_cur) {
LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
return false;
}
}
return true;
}
bool llama_context_kv_self::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
llama_file file(filepath, "wb");
file.write_u32(LLAMA_SESSION_MAGIC);
file.write_u32(LLAMA_SESSION_VERSION);
// save the prompt
file.write_u32((uint32_t) n_token_count);
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
// save the context state using stream saving
llama_io_write_file io(&file);
state_get_data(io);
return true;
}
size_t llama_context_kv_self::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
llama_file file(filepath, "rb");
// version checks
{
const uint32_t magic = file.read_u32();
const uint32_t version = file.read_u32();
if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
return 0;
}
}
// load the prompt
{
const uint32_t n_token_count = file.read_u32();
if (n_token_count > n_token_capacity) {
LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
return 0;
}
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
*n_token_count_out = n_token_count;
}
// restore the context state
{
const size_t state_size = file.size() - file.tell();
llama_io_read_file io(&file);
const size_t nread = state_seq_set_data(io, seq_id);
if (!nread) {
LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
return 0;
}
GGML_ASSERT(nread <= state_size);
GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
}
return file.tell();
}
size_t llama_context_kv_self::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
llama_file file(filepath, "wb");
file.write_u32(LLAMA_STATE_SEQ_MAGIC);
file.write_u32(LLAMA_STATE_SEQ_VERSION);
// save the prompt
file.write_u32((uint32_t) n_token_count);
file.write_raw(tokens, sizeof(llama_token) * n_token_count);
// save the context state using stream saving
llama_io_write_file io(&file);
state_seq_get_data(io, seq_id);
const size_t res = file.tell();
GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());
return res;
}
// state save/load
size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
synchronize();
// write model info
{
const std::string arch_str = llm_arch_name(model.arch);
io.write_string(arch_str);
// TODO: add more model-specific info which should prevent loading the session file if not identical
}
llama_context::state_get_data(io);
// write output ids
{
@@ -3492,7 +3524,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
}
}
// write embeddings
{
const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
@@ -3509,19 +3541,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
}
size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
synchronize();
// read model info
{
const std::string cur_arch_str = llm_arch_name(model.arch);
std::string arch_str;
io.read_string(arch_str);
if (cur_arch_str != arch_str) {
throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
}
// TODO: add more info which needs to be identical but which is not verified otherwise
}
llama_context::state_set_data(io);
// read output ids
{
@@ -3584,7 +3604,7 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
}
size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
synchronize();
llama_context::state_seq_get_data(io, seq_id);
kv_self.state_write(io, model.hparams, seq_id);
@@ -3592,7 +3612,7 @@ size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_se
}
size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
synchronize();
llama_context::state_seq_set_data(io, seq_id);
kv_self.state_read(io, model.hparams, seq_id);
@@ -3937,15 +3957,21 @@ size_t llama_state_get_size(struct llama_context * ctx) {
}
size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
ctx->synchronize();
return ctx->state_get_data(dst, size);
}
// Sets the state reading from the specified source address
size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
ctx->synchronize();
return ctx->state_set_data(src, size);
}
bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
ctx->synchronize();
try {
return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
} catch (const std::exception & err) {
@@ -3955,6 +3981,8 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session
}
bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
ctx->synchronize();
try {
return ctx->state_save_file(path_session, tokens, n_token_count);
} catch (const std::exception & err) {
@@ -3968,14 +3996,20 @@ size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id)
}
size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
ctx->synchronize();
return ctx->state_seq_get_data(seq_id, dst, size);
}
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
ctx->synchronize();
return ctx->state_seq_set_data(seq_id, src, size);
}
size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
ctx->synchronize();
try {
return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
} catch (const std::exception & err) {
@@ -3985,6 +4019,8 @@ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepa
}
size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
ctx->synchronize();
try {
return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
} catch (const std::exception & err) {


@@ -144,37 +144,37 @@ struct llama_context : public llama_graph_i {
// state save/load
virtual size_t state_get_size() = 0;
virtual size_t state_get_data( uint8_t * dst, size_t size) = 0;
virtual size_t state_set_data(const uint8_t * src, size_t size) = 0;
virtual size_t state_get_size();
virtual size_t state_get_data( uint8_t * dst, size_t size);
virtual size_t state_set_data(const uint8_t * src, size_t size);
virtual size_t state_seq_get_size(llama_seq_id seq_id) = 0;
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) = 0;
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) = 0;
virtual size_t state_seq_get_size(llama_seq_id seq_id);
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);
virtual bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) = 0;
size_t * n_token_count_out);
virtual bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count) = 0;
size_t n_token_count);
virtual size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) = 0;
size_t * n_token_count_out);
virtual size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count) = 0;
size_t n_token_count);
// perf
@@ -183,6 +183,14 @@ struct llama_context : public llama_graph_i {
protected:
// state save/load
virtual size_t state_get_data(llama_io_write_i & io);
virtual size_t state_set_data(llama_io_read_i & io);
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
// members
const llama_model & model;
@@ -471,46 +479,12 @@ public:
int il,
bool worst_case) override;
// state save/load
protected:
virtual size_t state_get_data(llama_io_write_i & io) override;
virtual size_t state_set_data(llama_io_read_i & io) override;
virtual size_t state_get_size() override;
virtual size_t state_get_data( uint8_t * dst, size_t size) override;
virtual size_t state_set_data(const uint8_t * src, size_t size) override;
virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;
virtual bool state_load_file(
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
virtual bool state_save_file(
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
virtual size_t state_seq_load_file(
llama_seq_id seq_id,
const char * filepath,
llama_token * tokens_out,
size_t n_token_capacity,
size_t * n_token_count_out) override;
virtual size_t state_seq_save_file(
llama_seq_id seq_id,
const char * filepath,
const llama_token * tokens,
size_t n_token_count) override;
private:
size_t state_get_data(llama_io_write_i & io);
size_t state_set_data(llama_io_read_i & io);
size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
};
// For internal test use