context : abstract state read/write
ggml-ci
@ -326,6 +326,361 @@ ggml_tensor * llama_context::build_rope_factors(int il) {
    return model.layers[il].rope_short;
}

//
// state
//

class llama_io_write_dummy : public llama_io_write_i {
public:
    llama_io_write_dummy() = default;

    void write(const void * /* src */, size_t size) override {
        size_written += size;
    }

    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
        size_written += size;
    }

    size_t n_bytes() override {
        return size_written;
    }

    size_t size_written = 0;
};
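The dummy writer is what makes the size query further down cheap: it implements the write interface but only accumulates byte counts. A minimal illustration of that behavior (call sequence hypothetical, interface as in this diff):

    llama_io_write_dummy io;
    io.write(nullptr, sizeof(uint32_t)); // src is ignored, so nullptr is fine here
    io.write(nullptr, 1024);
    // io.n_bytes() == sizeof(uint32_t) + 1024 -- no memory was read or written
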
class llama_io_write_buffer : public llama_io_write_i {
public:
    llama_io_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

    void write(const void * src, size_t size) override {
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        memcpy(ptr, src, size);
        ptr += size;
        size_written += size;
        buf_size -= size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        ggml_backend_tensor_get(tensor, ptr, offset, size);
        ptr += size;
        size_written += size;
        buf_size -= size;
    }

    size_t n_bytes() override {
        return size_written;
    }

    uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_written = 0;
};
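The buffer adapters signal overflow by throwing; the only catch sites are the public state_* wrappers further down, which turn any exception into a 0 return. A sketch of the failure path (buffer size illustrative):

    uint8_t small[4];
    llama_io_write_buffer io(small, sizeof(small));
    io.write(small, 4);    // fills the buffer exactly
    // io.write(small, 1); // would throw "unexpectedly reached end of buffer"
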
class llama_io_read_buffer : public llama_io_read_i {
public:
    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

    const uint8_t * read(size_t size) override {
        const uint8_t * base_ptr = ptr;
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        ptr += size;
        size_read += size;
        buf_size -= size;
        return base_ptr;
    }

    void read_to(void * dst, size_t size) override {
        memcpy(dst, read(size), size);
    }

    size_t n_bytes() override {
        return size_read;
    }

    const uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_read = 0;
};
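Note that llama_io_read_buffer::read() is zero-copy: it returns a pointer into the caller's buffer and only advances the cursor, while read_to() copies through that same path. An illustration (data values arbitrary):

    uint8_t data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    llama_io_read_buffer io(data, sizeof(data));
    const uint8_t * a = io.read(4); // points at data[0], no memcpy
    const uint8_t * b = io.read(4); // points at data[4]
    // io.n_bytes() == 8; reading a 9th byte would throw
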
class llama_io_write_file : public llama_io_write_i {
public:
    llama_io_write_file(llama_file * f) : file(f) {}

    void write(const void * src, size_t size) override {
        file->write_raw(src, size);
        size_written += size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        temp_buffer.resize(size);
        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
        write(temp_buffer.data(), temp_buffer.size());
    }

    size_t n_bytes() override {
        return size_written;
    }

    llama_file * file;
    size_t size_written = 0;
    std::vector<uint8_t> temp_buffer;
};

class llama_io_read_file : public llama_io_read_i {
public:
    llama_io_read_file(llama_file * f) : file(f) {}

    void read_to(void * dst, size_t size) override {
        file->read_raw(dst, size);
        size_read += size;
    }

    const uint8_t * read(size_t size) override {
        temp_buffer.resize(size);
        read_to(temp_buffer.data(), size);
        return temp_buffer.data();
    }

    size_t n_bytes() override {
        return size_read;
    }

    llama_file * file;
    size_t size_read = 0;
    std::vector<uint8_t> temp_buffer;
};
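Unlike the buffer reader, llama_io_read_file::read() has no backing memory to point into, so it stages each read through temp_buffer; write_tensor on the write side mirrors this staging. One consequence worth noting: the pointer returned by read() is only valid until the next read(), because temp_buffer is reused and may reallocate:

    llama_io_read_file io(&file);
    const uint8_t * p1 = io.read(16);
    const uint8_t * p2 = io.read(32); // p1 must not be dereferenced after this
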
size_t llama_context::state_get_size() {
    llama_io_write_dummy io;
    try {
        return state_get_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}
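Through the public C API this becomes the usual size-then-fill pattern; a sketch using the existing llama.h entry points:

    // first pass runs the serializer against the dummy writer, second pass fills the buffer
    std::vector<uint8_t> buf(llama_state_get_size(ctx));
    const size_t n = llama_state_get_data(ctx, buf.data(), buf.size());
    // n == buf.size() on success, 0 if serialization failed
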
size_t llama_context::state_get_data(uint8_t * dst, size_t size) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_get_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_set_data(const uint8_t * src, size_t size) {
    llama_io_read_buffer io(src, size);
    try {
        return state_set_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_get_size(llama_seq_id seq_id) {
    llama_io_write_dummy io;
    try {
        return state_seq_get_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_seq_get_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
    llama_io_read_buffer io(src, size);
    try {
        return state_seq_set_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;
    }
}
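The per-sequence variants make it possible to move cached state between sequences of the same context; a sketch with the public wrappers (sequence ids illustrative):

    // snapshot sequence 0 and replay it into sequence 1
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, 0));
    llama_state_seq_get_data(ctx, buf.data(), buf.size(), 0);
    llama_state_seq_set_data(ctx, buf.data(), buf.size(), 1);
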
bool llama_context::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    llama_file file(filepath, "rb");

    // sanity checks
    {
        const uint32_t magic = file.read_u32();
        const uint32_t version = file.read_u32();

        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
            return false;
        }
    }

    // load the prompt
    {
        const uint32_t n_token_count = file.read_u32();

        if (n_token_count > n_token_capacity) {
            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
            return false;
        }

        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
        *n_token_count_out = n_token_count;
    }

    // restore the context state
    {
        const size_t n_state_size_cur = file.size() - file.tell();

        llama_io_read_file io(&file);
        const size_t n_read = state_set_data(io);

        if (n_read != n_state_size_cur) {
            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
            return false;
        }
    }

    return true;
}
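The loader fixes the session file layout; written out as a comment sketch, in the order the fields are read back:

    // session file layout:
    //   uint32                      LLAMA_SESSION_MAGIC
    //   uint32                      LLAMA_SESSION_VERSION
    //   uint32                      n_token_count
    //   llama_token[n_token_count]  prompt tokens
    //   uint8[...]                  context state blob (the state_get_data stream)
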
bool llama_context::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
    llama_file file(filepath, "wb");

    file.write_u32(LLAMA_SESSION_MAGIC);
    file.write_u32(LLAMA_SESSION_VERSION);

    // save the prompt
    file.write_u32((uint32_t) n_token_count);
    file.write_raw(tokens, sizeof(llama_token) * n_token_count);

    // save the context state using stream saving
    llama_io_write_file io(&file);
    state_get_data(io);

    return true;
}
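A typical round trip through the public wrappers, sketched with illustrative names:

    // persist the current session, then restore it into the same (or a compatible) context
    llama_state_save_file(ctx, "session.bin", tokens.data(), tokens.size());

    std::vector<llama_token> restored(n_ctx);
    size_t n_restored = 0;
    llama_state_load_file(ctx, "session.bin", restored.data(), restored.size(), &n_restored);
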
size_t llama_context::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    llama_file file(filepath, "rb");

    // version checks
    {
        const uint32_t magic = file.read_u32();
        const uint32_t version = file.read_u32();

        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
            return 0;
        }
    }

    // load the prompt
    {
        const uint32_t n_token_count = file.read_u32();

        if (n_token_count > n_token_capacity) {
            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
            return 0;
        }

        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
        *n_token_count_out = n_token_count;
    }

    // restore the context state
    {
        const size_t state_size = file.size() - file.tell();
        llama_io_read_file io(&file);
        const size_t nread = state_seq_set_data(io, seq_id);
        if (!nread) {
            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
            return 0;
        }
        GGML_ASSERT(nread <= state_size);
        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
    }

    return file.tell();
}

size_t llama_context::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
    llama_file file(filepath, "wb");

    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
    file.write_u32(LLAMA_STATE_SEQ_VERSION);

    // save the prompt
    file.write_u32((uint32_t) n_token_count);
    file.write_raw(tokens, sizeof(llama_token) * n_token_count);

    // save the context state using stream saving
    llama_io_write_file io(&file);
    state_seq_get_data(io, seq_id);

    const size_t res = file.tell();
    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());

    return res;
}
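The size assertions in both functions account for every byte that precedes the state blob; spelled out:

    // file.tell() after a save (and after a successful load):
    //   sizeof(uint32_t) * 3             -> magic + version + n_token_count
    // + sizeof(llama_token) * n_tokens   -> the prompt tokens
    // + io.n_bytes()                     -> the streamed sequence state
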
size_t llama_context::state_get_data(llama_io_write_i & io) {
    // write model info
    {
        const std::string arch_str = llm_arch_name(model.arch);
        io.write_string(arch_str);
        // TODO: add more model-specific info which should prevent loading the session file if not identical
    }

    return io.n_bytes();
}

size_t llama_context::state_set_data(llama_io_read_i & io) {
    // read model info
    {
        const std::string cur_arch_str = llm_arch_name(model.arch);

        std::string arch_str;
        io.read_string(arch_str);
        if (cur_arch_str != arch_str) {
            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
        }
        // TODO: add more info which needs to be identical but which is not verified otherwise
    }

    return io.n_bytes();
}

size_t llama_context::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
    GGML_UNUSED(seq_id);

    return io.n_bytes();
}

size_t llama_context::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
    GGML_UNUSED(seq_id);

    return io.n_bytes();
}
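These base implementations serialize only the architecture name; subclasses are expected to call into the base first and then append their own payload, which is what llama_context_kv_self does further down in this diff. A minimal sketch of the pattern (class and member names hypothetical):

    // hypothetical derived context extending the stream with its own state
    struct my_context : llama_context {
        size_t state_get_data(llama_io_write_i & io) override {
            llama_context::state_get_data(io);     // arch info first
            io.write(&my_extra, sizeof(my_extra)); // then subclass payload
            return io.n_bytes();
        }
        uint32_t my_extra = 0;
    };
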
void llama_context::perf_reset() {
    t_start_us = ggml_time_us();
    t_eval_us = n_eval = 0;
@ -3123,333 +3478,10 @@ ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
    return cur;
}

//
// state
//

// TODO: this needs a big rework

class llama_io_write_dummy : public llama_io_write_i {
public:
    llama_io_write_dummy() = default;

    void write(const void * /* src */, size_t size) override {
        size_written += size;
    }

    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
        size_written += size;
    }

    size_t n_bytes() override {
        return size_written;
    }

    size_t size_written = 0;
};

class llama_io_write_buffer : public llama_io_write_i {
public:
    llama_io_write_buffer(uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

    void write(const void * src, size_t size) override {
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        memcpy(ptr, src, size);
        ptr += size;
        size_written += size;
        buf_size -= size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        ggml_backend_tensor_get(tensor, ptr, offset, size);
        ptr += size;
        size_written += size;
        buf_size -= size;
    }

    size_t n_bytes() override {
        return size_written;
    }

    uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_written = 0;
};

class llama_io_read_buffer : public llama_io_read_i {
public:
    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

    const uint8_t * read(size_t size) override {
        const uint8_t * base_ptr = ptr;
        if (size > buf_size) {
            throw std::runtime_error("unexpectedly reached end of buffer");
        }
        ptr += size;
        size_read += size;
        buf_size -= size;
        return base_ptr;
    }

    void read_to(void * dst, size_t size) override {
        memcpy(dst, read(size), size);
    }

    size_t n_bytes() override {
        return size_read;
    }

    const uint8_t * ptr;
    size_t buf_size = 0;
    size_t size_read = 0;
};

class llama_io_write_file : public llama_io_write_i {
public:
    llama_io_write_file(llama_file * f) : file(f) {}

    void write(const void * src, size_t size) override {
        file->write_raw(src, size);
        size_written += size;
    }

    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
        temp_buffer.resize(size);
        ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
        write(temp_buffer.data(), temp_buffer.size());
    }

    size_t n_bytes() override {
        return size_written;
    }

    llama_file * file;
    size_t size_written = 0;
    std::vector<uint8_t> temp_buffer;
};

class llama_io_read_file : public llama_io_read_i {
public:
    llama_io_read_file(llama_file * f) : file(f) {}

    void read_to(void * dst, size_t size) override {
        file->read_raw(dst, size);
        size_read += size;
    }

    const uint8_t * read(size_t size) override {
        temp_buffer.resize(size);
        read_to(temp_buffer.data(), size);
        return temp_buffer.data();
    }

    size_t n_bytes() override {
        return size_read;
    }

    llama_file * file;
    size_t size_read = 0;
    std::vector<uint8_t> temp_buffer;
};

size_t llama_context_kv_self::state_get_size() {
    llama_io_write_dummy io;
    try {
        return state_get_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_get_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) {
    llama_io_read_buffer io(src, size);
    try {
        return state_set_data(io);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) {
    llama_io_write_dummy io;
    try {
        return state_seq_get_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
    llama_io_write_buffer io(dst, size);
    try {
        return state_seq_get_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
        return 0;
    }
}

size_t llama_context_kv_self::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
    llama_io_read_buffer io(src, size);
    try {
        return state_seq_set_data(io, seq_id);
    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
        return 0;
    }
}

bool llama_context_kv_self::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    llama_file file(filepath, "rb");

    // sanity checks
    {
        const uint32_t magic = file.read_u32();
        const uint32_t version = file.read_u32();

        if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
            return false;
        }
    }

    // load the prompt
    {
        const uint32_t n_token_count = file.read_u32();

        if (n_token_count > n_token_capacity) {
            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
            return false;
        }

        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
        *n_token_count_out = n_token_count;
    }

    // restore the context state
    {
        const size_t n_state_size_cur = file.size() - file.tell();

        llama_io_read_file io(&file);
        const size_t n_read = state_set_data(io);

        if (n_read != n_state_size_cur) {
            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
            return false;
        }
    }

    return true;
}

bool llama_context_kv_self::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
    llama_file file(filepath, "wb");

    file.write_u32(LLAMA_SESSION_MAGIC);
    file.write_u32(LLAMA_SESSION_VERSION);

    // save the prompt
    file.write_u32((uint32_t) n_token_count);
    file.write_raw(tokens, sizeof(llama_token) * n_token_count);

    // save the context state using stream saving
    llama_io_write_file io(&file);
    state_get_data(io);

    return true;
}

size_t llama_context_kv_self::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    llama_file file(filepath, "rb");

    // version checks
    {
        const uint32_t magic = file.read_u32();
        const uint32_t version = file.read_u32();

        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
            LLAMA_LOG_ERROR("%s: unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
            return 0;
        }
    }

    // load the prompt
    {
        const uint32_t n_token_count = file.read_u32();

        if (n_token_count > n_token_capacity) {
            LLAMA_LOG_ERROR("%s: token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
            return 0;
        }

        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
        *n_token_count_out = n_token_count;
    }

    // restore the context state
    {
        const size_t state_size = file.size() - file.tell();
        llama_io_read_file io(&file);
        const size_t nread = state_seq_set_data(io, seq_id);
        if (!nread) {
            LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
            return 0;
        }
        GGML_ASSERT(nread <= state_size);
        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
    }

    return file.tell();
}

size_t llama_context_kv_self::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
    llama_file file(filepath, "wb");

    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
    file.write_u32(LLAMA_STATE_SEQ_VERSION);

    // save the prompt
    file.write_u32((uint32_t) n_token_count);
    file.write_raw(tokens, sizeof(llama_token) * n_token_count);

    // save the context state using stream saving
    llama_io_write_file io(&file);
    state_seq_get_data(io, seq_id);

    const size_t res = file.tell();
    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());

    return res;
}
// state save/load

size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
    synchronize();

    // write model info
    {
        const std::string arch_str = llm_arch_name(model.arch);
        io.write_string(arch_str);
        // TODO: add more model-specific info which should prevent loading the session file if not identical
    }
    llama_context::state_get_data(io);

    // write output ids
    {
@ -3492,7 +3524,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
        }
    }

    // write mbeddings
    // write embeddings
    {
        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);

@ -3509,19 +3541,7 @@ size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
}

size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
    synchronize();

    // read model info
    {
        const std::string cur_arch_str = llm_arch_name(model.arch);

        std::string arch_str;
        io.read_string(arch_str);
        if (cur_arch_str != arch_str) {
            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
        }
        // TODO: add more info which needs to be identical but which is not verified otherwise
    }
    llama_context::state_set_data(io);

    // read output ids
    {
@ -3584,7 +3604,7 @@ size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
}

size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
    synchronize();
    llama_context::state_seq_get_data(io, seq_id);

    kv_self.state_write(io, model.hparams, seq_id);

@ -3592,7 +3612,7 @@ size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_se
}

size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
    synchronize();
    llama_context::state_seq_set_data(io, seq_id);

    kv_self.state_read(io, model.hparams, seq_id);

@ -3937,15 +3957,21 @@ size_t llama_state_get_size(struct llama_context * ctx) {
}

size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
    ctx->synchronize();

    return ctx->state_get_data(dst, size);
}

// Sets the state reading from the specified source address
size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
    ctx->synchronize();

    return ctx->state_set_data(src, size);
}

bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    ctx->synchronize();

    try {
        return ctx->state_load_file(path_session, tokens_out, n_token_capacity, n_token_count_out);
    } catch (const std::exception & err) {
@ -3955,6 +3981,8 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session
}

bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
    ctx->synchronize();

    try {
        return ctx->state_save_file(path_session, tokens, n_token_count);
    } catch (const std::exception & err) {
@ -3968,14 +3996,20 @@ size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id)
}

size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_t size, llama_seq_id seq_id) {
    ctx->synchronize();

    return ctx->state_seq_get_data(seq_id, dst, size);
}

size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id) {
    ctx->synchronize();

    return ctx->state_seq_set_data(seq_id, src, size);
}

size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
    ctx->synchronize();

    try {
        return ctx->state_seq_save_file(seq_id, filepath, tokens, n_token_count);
    } catch (const std::exception & err) {
@ -3985,6 +4019,8 @@ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepa
}

size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
    ctx->synchronize();

    try {
        return ctx->state_seq_load_file(dest_seq_id, filepath, tokens_out, n_token_capacity, n_token_count_out);
    } catch (const std::exception & err) {

@ -144,37 +144,37 @@ struct llama_context : public llama_graph_i {

    // state save/load

    virtual size_t state_get_size() = 0;
    virtual size_t state_get_data( uint8_t * dst, size_t size) = 0;
    virtual size_t state_set_data(const uint8_t * src, size_t size) = 0;
    virtual size_t state_get_size();
    virtual size_t state_get_data( uint8_t * dst, size_t size);
    virtual size_t state_set_data(const uint8_t * src, size_t size);

    virtual size_t state_seq_get_size(llama_seq_id seq_id) = 0;
    virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) = 0;
    virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) = 0;
    virtual size_t state_seq_get_size(llama_seq_id seq_id);
    virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size);
    virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size);

    virtual bool state_load_file(
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out) = 0;
            size_t * n_token_count_out);

    virtual bool state_save_file(
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count) = 0;
            size_t n_token_count);

    virtual size_t state_seq_load_file(
            llama_seq_id seq_id,
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out) = 0;
            size_t * n_token_count_out);

    virtual size_t state_seq_save_file(
            llama_seq_id seq_id,
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count) = 0;
            size_t n_token_count);

    // perf

@ -183,6 +183,14 @@ struct llama_context : public llama_graph_i {

protected:

    // state save/load

    virtual size_t state_get_data(llama_io_write_i & io);
    virtual size_t state_set_data(llama_io_read_i & io);

    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);

    // members

    const llama_model & model;
@ -471,46 +479,12 @@ public:
            int il,
            bool worst_case) override;

    // state save/load
protected:
    virtual size_t state_get_data(llama_io_write_i & io) override;
    virtual size_t state_set_data(llama_io_read_i & io) override;

    virtual size_t state_get_size() override;
    virtual size_t state_get_data( uint8_t * dst, size_t size) override;
    virtual size_t state_set_data(const uint8_t * src, size_t size) override;

    virtual size_t state_seq_get_size(llama_seq_id seq_id) override;
    virtual size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) override;
    virtual size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) override;

    virtual bool state_load_file(
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out) override;

    virtual bool state_save_file(
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count) override;

    virtual size_t state_seq_load_file(
            llama_seq_id seq_id,
            const char * filepath,
            llama_token * tokens_out,
            size_t n_token_capacity,
            size_t * n_token_count_out) override;

    virtual size_t state_seq_save_file(
            llama_seq_id seq_id,
            const char * filepath,
            const llama_token * tokens,
            size_t n_token_count) override;

private:
    size_t state_get_data(llama_io_write_i & io);
    size_t state_set_data(llama_io_read_i & io);

    size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
    size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id);
    virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
    virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
};

// For internal test use
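Net effect of the header changes: the buffer- and file-based entry points are now implemented once in llama_context, and a derived context only needs to extend the protected io hooks. A hedged sketch of the minimal override surface after this refactor (class name hypothetical):

    // hypothetical derived context: only the stream-level hooks need overriding
    struct llama_context_custom : llama_context {
    protected:
        size_t state_get_data(llama_io_write_i & io) override;
        size_t state_set_data(llama_io_read_i & io) override;
        size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
        size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
    };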