mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-07-18 08:37:43 +00:00
llama : introduce llama_io interfaces
ggml-ci
src/CMakeLists.txt:

@@ -18,6 +18,7 @@ add_library(llama
     llama-graph.cpp
     llama-hparams.cpp
     llama-impl.cpp
+    llama-io.cpp
     llama-kv-cache.cpp
     llama-mmap.cpp
     llama-model-loader.cpp
src/llama-context.cpp:

@@ -2,6 +2,7 @@
 #include "llama-impl.h"
 #include "llama-mmap.h"
+#include "llama-io.h"

 #include <cassert>
 #include <cmath>
@@ -3128,214 +3129,29 @@ ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(

-// TODO: this needs a big rework
-
-// TODO: replace all non-fatal assertions with returned errors or exceptions
-struct llama_data_write {
-    llama_data_write(llama_context_kv_self * ctx) : ctx(ctx) {}
-    virtual ~llama_data_write() = default;
-
-    virtual void write(const void * src, size_t size) = 0;
-    virtual void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) = 0;
-    virtual size_t get_size_written() = 0;
-
-    void write_string(const std::string & str) {
-        uint32_t str_size = str.size();
-
-        write(&str_size, sizeof(str_size));
-        write(str.data(), str_size);
-    }
-
-    void write_model_info() {
-        const auto & model = ctx->get_model();
-
-        const std::string arch_str = llm_arch_name(model.arch);
-        write_string(arch_str);
-        // TODO: add more model-specific info which should prevent loading the session file if not identical
-    }
-
-    //void write_rng(const std::mt19937 & rng) {
-    //    std::ostringstream rng_ss;
-    //    rng_ss << rng;
-
-    //    const std::string & rng_str = rng_ss.str();
-
-    //    write_string(rng_str);
-    //}
-
-    void write_output_ids() {
-        ctx->reorder_outputs();
-
-        const uint32_t n_outputs = ctx->n_outputs;
-
-        std::vector<int32_t> output_pos;
-
-        const size_t n_batch = ctx->n_batch();
-        const auto & output_ids = ctx->output_ids;
-
-        GGML_ASSERT(n_outputs <= ctx->output_size);
-
-        output_pos.resize(n_outputs);
-
-        // build a more compact representation of the output ids
-        for (size_t i = 0; i < n_batch; ++i) {
-            // map an output id to a position in the batch
-            int32_t pos = output_ids[i];
-            if (pos >= 0) {
-                GGML_ASSERT((uint32_t) pos < n_outputs);
-                output_pos[pos] = i;
-            }
-        }
-
-        write(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs) {
-            write(output_pos.data(), n_outputs * sizeof(int32_t));
-        }
-    }
-
-    void write_logits() {
-        const auto & model = ctx->get_model();
-
-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * model.vocab.n_tokens());
-
-        write(&logits_size, sizeof(logits_size));
-
-        if (logits_size) {
-            write(ctx->logits, logits_size * sizeof(float));
-        }
-    }
-
-    void write_embeddings() {
-        const auto & model = ctx->get_model();
-
-        const uint64_t embeddings_size = std::min((uint64_t) ctx->embd_size, (uint64_t) ctx->n_outputs * model.hparams.n_embd);
-
-        write(&embeddings_size, sizeof(embeddings_size));
-
-        if (embeddings_size) {
-            write(ctx->embd, embeddings_size * sizeof(float));
-        }
-    }
-
-    llama_context_kv_self * ctx;
-};
-
-struct llama_data_read {
-    llama_data_read(llama_context_kv_self * ctx) : ctx(ctx) {}
-    virtual ~llama_data_read() = default;
-
-    virtual const uint8_t * read(size_t size) = 0;
-    virtual void read_to(void * dst, size_t size) = 0;
-    virtual size_t get_size_read() = 0;
-
-    void read_string(std::string & str) {
-        uint32_t str_size;
-        read_to(&str_size, sizeof(str_size));
-
-        str.assign((const char *) read(str_size), str_size);
-    }
-
-    // validate model information
-    void read_model_info() {
-        const auto & model = ctx->get_model();
-
-        const std::string cur_arch_str = llm_arch_name(model.arch);
-
-        std::string arch_str;
-        read_string(arch_str);
-        if (cur_arch_str != arch_str) {
-            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
-        }
-        // TODO: add more info which needs to be identical but which is not verified otherwise
-    }
-
-    //void read_rng(std::mt19937 & rng) {
-    //    std::string rng_str;
-    //    read_string(rng_str);
-
-    //    std::istringstream rng_ss(rng_str);
-    //    rng_ss >> rng;
-
-    //    if (rng_ss.fail()) {
-    //        throw std::runtime_error("failed to load RNG state");
-    //    }
-    //}
-
-    void read_output_ids() {
-        std::vector<int32_t> output_pos;
-
-        uint32_t n_outputs;
-        read_to(&n_outputs, sizeof(n_outputs));
-
-        if (n_outputs > ctx->reserve_outputs(n_outputs)) {
-            throw std::runtime_error("could not reserve outputs");
-        }
-
-        if (n_outputs) {
-            output_pos.resize(n_outputs);
-            read_to(output_pos.data(), n_outputs * sizeof(int32_t));
-
-            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
-                int32_t id = output_pos[i];
-                if ((uint32_t) id >= ctx->n_batch()) {
-                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->n_batch()));
-                }
-                ctx->output_ids[id] = i;
-            }
-
-            ctx->n_outputs = n_outputs;
-        }
-    }
-
-    void read_logits() {
-        uint64_t logits_size;
-        read_to(&logits_size, sizeof(logits_size));
-
-        if (ctx->logits_size < logits_size) {
-            throw std::runtime_error("logits buffer too small");
-        }
-
-        if (logits_size) {
-            read_to(ctx->logits, logits_size * sizeof(float));
-        }
-    }
-
-    void read_embeddings() {
-        uint64_t embeddings_size;
-        read_to(&embeddings_size, sizeof(embeddings_size));
-
-        if (ctx->embd_size < embeddings_size) {
-            throw std::runtime_error("embeddings buffer too small");
-        }
-
-        if (embeddings_size) {
-            read_to(ctx->embd, embeddings_size * sizeof(float));
-        }
-    }
-
-    llama_context_kv_self * ctx;
-};
-
-struct llama_data_write_dummy : llama_data_write {
-    llama_data_write_dummy(llama_context_kv_self * ctx) : llama_data_write(ctx) {}
+class llama_io_write_dummy : public llama_io_write_i {
+public:
+    llama_io_write_dummy() = default;

     void write(const void * /* src */, size_t size) override {
         size_written += size;
     }

-    void write_tensor_data(const struct ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
+    void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override {
         size_written += size;
     }

-    size_t get_size_written() override {
+    size_t n_bytes() override {
         return size_written;
     }

     size_t size_written = 0;
 };

-struct llama_data_write_buffer : llama_data_write {
-    llama_data_write_buffer(
-            llama_context_kv_self * ctx,
-            uint8_t * p, size_t len) : llama_data_write(ctx), ptr(p), buf_size(len) {}
+class llama_io_write_buffer : public llama_io_write_i {
+public:
+    llama_io_write_buffer(
+            uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

     void write(const void * src, size_t size) override {
         if (size > buf_size) {
@@ -3347,7 +3163,7 @@ struct llama_data_write_buffer : llama_data_write {
         buf_size -= size;
     }

-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
+    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
         if (size > buf_size) {
             throw std::runtime_error("unexpectedly reached end of buffer");
         }
@@ -3357,7 +3173,7 @@ struct llama_data_write_buffer : llama_data_write {
         buf_size -= size;
     }

-    size_t get_size_written() override {
+    size_t n_bytes() override {
         return size_written;
     }

@@ -3366,10 +3182,9 @@ struct llama_data_write_buffer : llama_data_write {
     size_t size_written = 0;
 };

-struct llama_data_read_buffer : llama_data_read {
-    llama_data_read_buffer(
-            llama_context_kv_self * ctx,
-            const uint8_t * p, size_t len) : llama_data_read(ctx), ptr(p), buf_size(len) {}
+class llama_io_read_buffer : public llama_io_read_i {
+public:
+    llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {}

     const uint8_t * read(size_t size) override {
         const uint8_t * base_ptr = ptr;
@@ -3386,7 +3201,7 @@ struct llama_data_read_buffer : llama_data_read {
         memcpy(dst, read(size), size);
     }

-    size_t get_size_read() override {
+    size_t n_bytes() override {
         return size_read;
     }

@@ -3395,23 +3210,22 @@ struct llama_data_read_buffer : llama_data_read {
     size_t size_read = 0;
 };

-struct llama_data_write_file : llama_data_write {
-    llama_data_write_file(
-            llama_context_kv_self * ctx,
-            llama_file * f) : llama_data_write(ctx), file(f) {}
+class llama_io_write_file : public llama_io_write_i {
+public:
+    llama_io_write_file(llama_file * f) : file(f) {}

     void write(const void * src, size_t size) override {
         file->write_raw(src, size);
         size_written += size;
     }

-    void write_tensor_data(const struct ggml_tensor * tensor, size_t offset, size_t size) override {
+    void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
         temp_buffer.resize(size);
         ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size);
         write(temp_buffer.data(), temp_buffer.size());
     }

-    size_t get_size_written() override {
+    size_t n_bytes() override {
         return size_written;
     }

@@ -3420,10 +3234,9 @@ struct llama_data_write_file : llama_data_write {
     std::vector<uint8_t> temp_buffer;
 };

-struct llama_data_read_file : llama_data_read {
-    llama_data_read_file(
-            llama_context_kv_self * ctx,
-            llama_file * f) : llama_data_read(ctx), file(f) {}
+class llama_io_read_file : public llama_io_read_i {
+public:
+    llama_io_read_file(llama_file * f) : file(f) {}

     void read_to(void * dst, size_t size) override {
         file->read_raw(dst, size);

@@ -3436,7 +3249,7 @@ struct llama_data_read_file : llama_data_read {
         return temp_buffer.data();
     }

-    size_t get_size_read() override {
+    size_t n_bytes() override {
         return size_read;
     }

@@ -3446,9 +3259,9 @@ struct llama_data_read_file : llama_data_read {
 };

 size_t llama_context_kv_self::state_get_size() {
-    llama_data_write_dummy data_ctx(this);
+    llama_io_write_dummy io;
     try {
-        return state_get_data(data_ctx);
+        return state_get_data(io);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
         return 0;

@@ -3456,9 +3269,9 @@ size_t llama_context_kv_self::state_get_size() {
 }

 size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) {
-    llama_data_write_buffer data_ctx(this, dst, size);
+    llama_io_write_buffer io(dst, size);
     try {
-        return state_get_data(data_ctx);
+        return state_get_data(io);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
         return 0;

@@ -3466,9 +3279,9 @@ size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) {
 }

 size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) {
-    llama_data_read_buffer data_ctx(this, src, size);
+    llama_io_read_buffer io(src, size);
     try {
-        return state_set_data(data_ctx);
+        return state_set_data(io);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;

@@ -3476,9 +3289,9 @@ size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) {
 }

 size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) {
-    llama_data_write_dummy data_ctx(this);
+    llama_io_write_dummy io;
     try {
-        return state_seq_get_data(data_ctx, seq_id);
+        return state_seq_get_data(io, seq_id);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
         return 0;

@@ -3486,9 +3299,9 @@ size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) {
 }

 size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
-    llama_data_write_buffer data_ctx(this, dst, size);
+    llama_io_write_buffer io(dst, size);
     try {
-        return state_seq_get_data(data_ctx, seq_id);
+        return state_seq_get_data(io, seq_id);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
         return 0;

@@ -3496,9 +3309,9 @@ size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t *
 }

 size_t llama_context_kv_self::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
-    llama_data_read_buffer data_ctx(this, src, size);
+    llama_io_read_buffer io(src, size);
     try {
-        return state_seq_set_data(data_ctx, seq_id);
+        return state_seq_set_data(io, seq_id);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
         return 0;
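Worth spelling out: the wrappers above share one serialization path for sizing and saving, since state_get_size() streams the full state through llama_io_write_dummy, which discards the bytes and only counts them. A minimal usage sketch, assuming a llama_context_kv_self * ctx and the methods shown in this diff:

    // pass 1: "write" the state into a counting sink to learn its size
    std::vector<uint8_t> buf(ctx->state_get_size());

    // pass 2: serialize for real into the buffer
    const size_t n_written = ctx->state_get_data(buf.data(), buf.size());
    GGML_ASSERT(n_written == buf.size()); // both passes walk the same code path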
@@ -3536,8 +3349,8 @@ bool llama_context_kv_self::state_load_file(const char * filepath, llama_token *
     {
         const size_t n_state_size_cur = file.size() - file.tell();

-        llama_data_read_file data_ctx(this, &file);
-        const size_t n_read = state_set_data(data_ctx);
+        llama_io_read_file io(&file);
+        const size_t n_read = state_set_data(io);

         if (n_read != n_state_size_cur) {
             LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);

@@ -3559,8 +3372,8 @@ bool llama_context_kv_self::state_save_file(const char * filepath, const llama_t
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);

     // save the context state using stream saving
-    llama_data_write_file data_ctx(this, &file);
-    state_get_data(data_ctx);
+    llama_io_write_file io(&file);
+    state_get_data(io);

     return true;
 }

@@ -3595,8 +3408,8 @@ size_t llama_context_kv_self::state_seq_load_file(llama_seq_id seq_id, const cha
     // restore the context state
     {
         const size_t state_size = file.size() - file.tell();
-        llama_data_read_file data_ctx(this, &file);
-        const size_t nread = state_seq_set_data(data_ctx, seq_id);
+        llama_io_read_file io(&file);
+        const size_t nread = state_seq_set_data(io, seq_id);
         if (!nread) {
             LLAMA_LOG_ERROR("%s: failed to restore sequence state\n", __func__);
             return 0;
@@ -3619,116 +3432,171 @@ size_t llama_context_kv_self::state_seq_save_file(llama_seq_id seq_id, const cha
     file.write_raw(tokens, sizeof(llama_token) * n_token_count);

     // save the context state using stream saving
-    llama_data_write_file data_ctx(this, &file);
-    state_seq_get_data(data_ctx, seq_id);
+    llama_io_write_file io(&file);
+    state_seq_get_data(io, seq_id);

     const size_t res = file.tell();
-    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + io.n_bytes());

     return res;
 }

-/** copy state data into either a buffer or file depending on the passed in context
- *
- * file context:
- * llama_file file("/path", "wb");
- * llama_data_write_file data_ctx(&file);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- * buffer context:
- * std::vector<uint8_t> buf(max_size, 0);
- * llama_data_write_buffer data_ctx(buf.data(), max_size);
- * llama_state_get_data_internal(ctx, data_ctx);
- *
- */
-size_t llama_context_kv_self::state_get_data(llama_data_write & data_ctx) {
+size_t llama_context_kv_self::state_get_data(llama_io_write_i & io) {
     synchronize();

-    data_ctx.write_model_info();
+    // write model info
+    {
+        const std::string arch_str = llm_arch_name(model.arch);
+        io.write_string(arch_str);
+        // TODO: add more model-specific info which should prevent loading the session file if not identical
+    }

-    // copy outputs
-    data_ctx.write_output_ids();
-    data_ctx.write_logits();
-    data_ctx.write_embeddings();
+    // write output ids
+    {
+        reorder_outputs();

-    llama_kv_cache::io io = {
-        /* .write = */ [&](const void * src, size_t size) {
-            data_ctx.write(src, size);
-        },
-        /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
-            data_ctx.write_tensor_data(tensor, offset, size);
-        },
-        /* .read = */ nullptr,
-        /* .read_to = */ nullptr,
-    };
+        const uint32_t n_outputs = this->n_outputs;
+        const auto & output_ids = this->output_ids;
+
+        std::vector<int32_t> w_output_pos;
+
+        GGML_ASSERT(n_outputs <= output_size);
+
+        w_output_pos.resize(n_outputs);
+
+        // build a more compact representation of the output ids
+        for (size_t i = 0; i < n_batch(); ++i) {
+            // map an output id to a position in the batch
+            int32_t pos = output_ids[i];
+            if (pos >= 0) {
+                GGML_ASSERT((uint32_t) pos < n_outputs);
+                w_output_pos[pos] = i;
+            }
+        }
+
+        io.write(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs) {
+            io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
+        }
+    }
+
+    // write logits
+    {
+        const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
+
+        io.write(&logits_size, sizeof(logits_size));
+
+        if (logits_size) {
+            io.write(logits, logits_size * sizeof(float));
+        }
+    }
+
+    // write embeddings
+    {
+        const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
+
+        io.write(&embd_size, sizeof(embd_size));
+
+        if (embd_size) {
+            io.write(embd, embd_size * sizeof(float));
+        }
+    }

     kv_self.state_write(io, model.hparams);

-    return data_ctx.get_size_written();
+    return io.n_bytes();
 }

-size_t llama_context_kv_self::state_set_data(llama_data_read & data_ctx) {
+size_t llama_context_kv_self::state_set_data(llama_io_read_i & io) {
     synchronize();

-    data_ctx.read_model_info();
+    // read model info
+    {
+        const std::string cur_arch_str = llm_arch_name(model.arch);

-    // set outputs
-    data_ctx.read_output_ids();
-    data_ctx.read_logits();
-    data_ctx.read_embeddings();
+        std::string arch_str;
+        io.read_string(arch_str);
+        if (cur_arch_str != arch_str) {
+            throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
+        }
+        // TODO: add more info which needs to be identical but which is not verified otherwise
+    }

-    llama_kv_cache::io io = {
-        /* .write = */ nullptr,
-        /* .write_tensor_data = */ nullptr,
-        /* .read = */ [&](size_t size) {
-            return data_ctx.read(size);
-        },
-        /* .read_to = */ [&](void * dst, size_t size) {
-            data_ctx.read_to(dst, size);
-        },
-    };
+    // read output ids
+    {
+        std::vector<int32_t> output_pos;
+
+        uint32_t n_outputs;
+        io.read_to(&n_outputs, sizeof(n_outputs));
+
+        if (n_outputs > reserve_outputs(n_outputs)) {
+            throw std::runtime_error("could not reserve outputs");
+        }
+
+        if (n_outputs) {
+            output_pos.resize(n_outputs);
+            io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
+
+            for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+                int32_t id = output_pos[i];
+                if ((uint32_t) id >= n_batch()) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
+                }
+                this->output_ids[id] = i;
+            }
+
+            this->n_outputs = n_outputs;
+        }
+    }
+
+    // read logits
+    {
+        uint64_t logits_size;
+        io.read_to(&logits_size, sizeof(logits_size));
+
+        if (this->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }
+
+        if (logits_size) {
+            io.read_to(this->logits, logits_size * sizeof(float));
+        }
+    }
+
+    // read embeddings
+    {
+        uint64_t embd_size;
+        io.read_to(&embd_size, sizeof(embd_size));
+
+        if (this->embd_size < embd_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }
+
+        if (embd_size) {
+            io.read_to(this->embd, embd_size * sizeof(float));
+        }
+    }

     kv_self.state_read(io, model.hparams);

-    return data_ctx.get_size_read();
+    return io.n_bytes();
 }

-size_t llama_context_kv_self::state_seq_get_data(llama_data_write & data_ctx, llama_seq_id seq_id) {
+size_t llama_context_kv_self::state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) {
     synchronize();

-    llama_kv_cache::io io = {
-        /* .write = */ [&](const void * src, size_t size) {
-            data_ctx.write(src, size);
-        },
-        /* .write_tensor_data = */ [&](const struct ggml_tensor * tensor, size_t offset, size_t size) {
-            data_ctx.write_tensor_data(tensor, offset, size);
-        },
-        /* .read = */ nullptr,
-        /* .read_to = */ nullptr,
-    };
-
     kv_self.state_write(io, model.hparams, seq_id);

-    return data_ctx.get_size_written();
+    return io.n_bytes();
 }

-size_t llama_context_kv_self::state_seq_set_data(llama_data_read & data_ctx, llama_seq_id seq_id) {
+size_t llama_context_kv_self::state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) {
     synchronize();

-    llama_kv_cache::io io = {
-        /* .write = */ nullptr,
-        /* .write_tensor_data = */ nullptr,
-        /* .read = */ [&](size_t size) {
-            return data_ctx.read(size);
-        },
-        /* .read_to = */ [&](void * dst, size_t size) {
-            data_ctx.read_to(dst, size);
-        },
-    };
-
     kv_self.state_read(io, model.hparams, seq_id);

-    return data_ctx.get_size_read();
+    return io.n_bytes();
 }

 //
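A worked sketch of the compact output-id representation built in state_get_data() above, with hypothetical values: output_ids maps batch position to output row, and the session file stores the inverse mapping, which state_set_data() inverts back on load:

    // batch of 4 tokens; only tokens 1 and 3 produced outputs (rows 0 and 1)
    std::vector<int32_t> output_ids = { -1, 0, -1, 1 };

    const uint32_t n_outputs = 2;
    std::vector<int32_t> w_output_pos(n_outputs);
    for (size_t i = 0; i < output_ids.size(); ++i) {
        if (output_ids[i] >= 0) {
            w_output_pos[output_ids[i]] = (int32_t) i;
        }
    }
    // w_output_pos == { 1, 3 }: for each output row, the batch position it came from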
src/llama-context.h:

@@ -15,6 +15,9 @@
 #include <vector>
 #include <set>

+class llama_io_read_i;
+class llama_io_write_i;
+
 using llama_loras = std::unordered_map<struct llama_adapter_lora *, float>;

 struct llama_context : public llama_graph_i {
@@ -178,9 +181,10 @@ struct llama_context : public llama_graph_i {
     virtual llama_perf_context_data perf_get_data() const;
     virtual void perf_reset();

-protected:
     // members
+
+protected:
     const llama_model & model;

     llama_cparams cparams;
@@ -502,11 +506,11 @@ public:
             size_t n_token_count) override;

 private:
-    size_t state_get_data(struct llama_data_write & data_ctx);
-    size_t state_set_data(struct llama_data_read  & data_ctx);
+    size_t state_get_data(llama_io_write_i & io);
+    size_t state_set_data(llama_io_read_i  & io);

-    size_t state_seq_get_data(struct llama_data_write & data_ctx, llama_seq_id seq_id);
-    size_t state_seq_set_data(struct llama_data_read  & data_ctx, llama_seq_id seq_id);
+    size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id);
+    size_t state_seq_set_data(llama_io_read_i  & io, llama_seq_id seq_id);
 };

 // For internal test use
src/llama-io.cpp (new file, 15 lines):

@@ -0,0 +1,15 @@
+#include "llama-io.h"
+
+void llama_io_write_i::write_string(const std::string & str) {
+    uint32_t str_size = str.size();
+
+    write(&str_size, sizeof(str_size));
+    write(str.data(), str_size);
+}
+
+void llama_io_read_i::read_string(std::string & str) {
+    uint32_t str_size;
+    read_to(&str_size, sizeof(str_size));
+
+    str.assign((const char *) read(str_size), str_size);
+}
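The string helpers define a simple length-prefixed wire format: a raw uint32_t byte count followed by the bytes, with no terminator. A sketch of what this produces, assuming a little-endian host (the length is written as raw memory) and using the buffer sink from llama-context.cpp above:

    std::vector<uint8_t> buf(64);
    llama_io_write_buffer io(buf.data(), buf.size());
    io.write_string("llama");
    // buf now starts with: 05 00 00 00 6c 6c 61 6d 61
    //                      ^ uint32_t length prefix    ^ raw bytes, no terminator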
src/llama-io.h (new file, 35 lines):

@@ -0,0 +1,35 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+struct ggml_tensor;
+
+class llama_io_write_i {
+public:
+    llama_io_write_i() = default;
+    virtual ~llama_io_write_i() = default;
+
+    virtual void write(const void * src, size_t size) = 0;
+    virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0;
+
+    // bytes written so far
+    virtual size_t n_bytes() = 0;
+
+    void write_string(const std::string & str);
+};
+
+class llama_io_read_i {
+public:
+    llama_io_read_i() = default;
+    virtual ~llama_io_read_i() = default;
+
+    virtual const uint8_t * read(size_t size) = 0;
+    virtual void read_to(void * dst, size_t size) = 0;
+
+    // bytes read so far
+    virtual size_t n_bytes() = 0;
+
+    void read_string(std::string & str);
+};
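Any sink or source can now plug into state (de)serialization by implementing one of these interfaces. A minimal sketch of a hypothetical vector-backed writer, assuming only this header plus ggml_backend_tensor_get() (used the same way by llama_io_write_file in this commit):

    #include "llama-io.h"
    #include "ggml-backend.h"

    #include <vector>

    // hypothetical: accumulates the serialized state in memory
    class llama_io_write_vector : public llama_io_write_i {
    public:
        void write(const void * src, size_t size) override {
            const uint8_t * p = (const uint8_t *) src;
            buf.insert(buf.end(), p, p + size);
        }

        void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override {
            temp.resize(size);
            ggml_backend_tensor_get(tensor, temp.data(), offset, size); // copy out of backend memory
            write(temp.data(), temp.size());
        }

        size_t n_bytes() override {
            return buf.size();
        }

        std::vector<uint8_t> buf;

    private:
        std::vector<uint8_t> temp;
    };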
src/llama-kv-cache.cpp:

@@ -698,7 +698,7 @@ size_t llama_kv_cache::size_v_bytes() const {
     return size_v_bytes;
 }

-void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) const {
+void llama_kv_cache::state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id) const {
     std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
     uint32_t cell_count = 0;

@@ -736,7 +736,7 @@ void llama_kv_cache::state_write(const io & io, const llama_hparams & hparams, l
     state_write_data(io, cell_ranges, hparams);
 }

-void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, llama_seq_id seq_id) {
+void llama_kv_cache::state_read(llama_io_read_i & io, const llama_hparams & hparams, llama_seq_id seq_id) {
     uint32_t cell_count;
     io.read_to(&cell_count, sizeof(cell_count));

@@ -754,7 +754,7 @@ void llama_kv_cache::state_read(const io & io, const llama_hparams & hparams, ll
     }
 }

-void llama_kv_cache::state_write_meta(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+void llama_kv_cache::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
             const auto & cell = cells[i];

@@ -773,7 +773,7 @@ void llama_kv_cache::state_write_meta(const io & io, const std::vector<std::pair
     }
 }

-void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const {
+void llama_kv_cache::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const {
     const uint32_t v_trans = this->v_trans ? 1 : 0;
     const uint32_t n_layer = hparams.n_layer;

@@ -799,7 +799,7 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
         for (const auto & range : cell_ranges) {
             const size_t range_size = range.second - range.first;
             const size_t buf_size = range_size * k_size_row;
-            io.write_tensor_data(k_l[il], range.first * k_size_row, buf_size);
+            io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
         }
     }

@@ -819,7 +819,7 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
         for (const auto & range : cell_ranges) {
             const size_t range_size = range.second - range.first;
             const size_t buf_size = range_size * v_size_row;
-            io.write_tensor_data(v_l[il], range.first * v_size_row, buf_size);
+            io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
         }
     }
 } else {

@@ -846,14 +846,14 @@ void llama_kv_cache::state_write_data(const io & io, const std::vector<std::pair
                 const size_t range_size = range.second - range.first;
                 const size_t src_offset = (range.first + j * kv_size) * v_size_el;
                 const size_t buf_size = range_size * v_size_el;
-                io.write_tensor_data(v_l[il], src_offset, buf_size);
+                io.write_tensor(v_l[il], src_offset, buf_size);
             }
         }
     }
 }

-bool llama_kv_cache::state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
     if (dest_seq_id != -1) {
         // single sequence

@@ -955,7 +955,7 @@ bool llama_kv_cache::state_read_meta(const io & io, uint32_t cell_count, llama_s
     return true;
 }

-bool llama_kv_cache::state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count) {
     uint32_t v_trans;
     uint32_t n_layer;
     io.read_to(&v_trans, sizeof(v_trans));
src/llama-kv-cache.h:

@@ -1,6 +1,7 @@
 #pragma once

 #include "llama.h"
+#include "llama-io.h"

 #include "ggml-cpp.h"

@@ -114,16 +115,8 @@ struct llama_kv_cache {
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;

-    struct io {
-        std::function<void(const void * src, size_t size)> write;
-        std::function<void(const struct ggml_tensor * tensor, size_t offset, size_t size)> write_tensor_data;
-
-        std::function<const uint8_t * (size_t size)> read;
-        std::function<void(void * dst, size_t size)> read_to;
-    };
-
-    void state_write(const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
-    void state_read (const io & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);
+    void state_write(llama_io_write_i & io, const llama_hparams & hparams, llama_seq_id seq_id = -1) const;
+    void state_read (llama_io_read_i  & io, const llama_hparams & hparams, llama_seq_id seq_id = -1);

 private:
     ggml_type type_k = GGML_TYPE_F16;

@@ -132,11 +125,11 @@ private:
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;

-    void state_write_meta(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
-    void state_write_data(const io & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;
+    void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
+    void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, const llama_hparams & hparams) const;

-    bool state_read_meta(const io & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(const io & io, const llama_hparams & hparams, uint32_t cell_count);
+    bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, const llama_hparams & hparams, uint32_t cell_count);
 };

 //
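Design note: the deleted llama_kv_cache::io struct was a table of four std::function slots that every caller had to populate with adapter lambdas, leaving the unused direction as nullptr (see the blocks removed from state_seq_get_data() and state_seq_set_data() above). Passing llama_io_write_i / llama_io_read_i references instead encodes the transfer direction in the type, removes the per-call closure setup, and lets the caller query io.n_bytes() once the transfer finishes.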