|
|
|
@ -332,10 +332,10 @@ void llama_context::perf_reset() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
|
// llama_context_unified
|
|
|
|
|
// llama_context_kv_self
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
llama_context_unified::llama_context_unified(
|
|
|
|
|
llama_context_kv_self::llama_context_kv_self(
|
|
|
|
|
const llama_model & model,
|
|
|
|
|
const llama_context_params & params) : llama_context(model) {
|
|
|
|
|
const auto & hparams = model.hparams;
|
|
|
|
@ -636,29 +636,29 @@ llama_context_unified::llama_context_unified(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llama_context_unified::~llama_context_unified() = default;
|
|
|
|
|
llama_context_kv_self::~llama_context_kv_self() = default;
|
|
|
|
|
|
|
|
|
|
uint32_t llama_context_unified::n_seq_max() const {
|
|
|
|
|
uint32_t llama_context_kv_self::n_seq_max() const {
|
|
|
|
|
// TODO: add notion of n_seq_max to llama_kv_cache and use it here
|
|
|
|
|
return kv_self.size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llama_kv_cache * llama_context_unified::get_kv_self() {
|
|
|
|
|
llama_kv_cache * llama_context_kv_self::get_kv_self() {
|
|
|
|
|
return &kv_self;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const llama_kv_cache * llama_context_unified::get_kv_self() const {
|
|
|
|
|
const llama_kv_cache * llama_context_kv_self::get_kv_self() const {
|
|
|
|
|
return &kv_self;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float * llama_context_unified::get_logits() {
|
|
|
|
|
float * llama_context_kv_self::get_logits() {
|
|
|
|
|
// reorder logits for backward compatibility
|
|
|
|
|
reorder_outputs();
|
|
|
|
|
|
|
|
|
|
return logits;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float * llama_context_unified::get_logits_ith(int32_t i) {
|
|
|
|
|
float * llama_context_kv_self::get_logits_ith(int32_t i) {
|
|
|
|
|
int32_t j = -1;
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
@ -696,14 +696,14 @@ float * llama_context_unified::get_logits_ith(int32_t i) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float * llama_context_unified::get_embeddings() {
|
|
|
|
|
float * llama_context_kv_self::get_embeddings() {
|
|
|
|
|
// reorder embeddings for backward compatibility
|
|
|
|
|
reorder_outputs();
|
|
|
|
|
|
|
|
|
|
return embd;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float * llama_context_unified::get_embeddings_ith(int32_t i) {
|
|
|
|
|
float * llama_context_kv_self::get_embeddings_ith(int32_t i) {
|
|
|
|
|
int32_t j = -1;
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
@ -741,7 +741,7 @@ float * llama_context_unified::get_embeddings_ith(int32_t i) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float * llama_context_unified::get_embeddings_seq(llama_seq_id seq_id) {
|
|
|
|
|
float * llama_context_kv_self::get_embeddings_seq(llama_seq_id seq_id) {
|
|
|
|
|
auto it = embd_seq.find(seq_id);
|
|
|
|
|
if (it == embd_seq.end()) {
|
|
|
|
|
return nullptr;
|
|
|
|
@ -750,7 +750,7 @@ float * llama_context_unified::get_embeddings_seq(llama_seq_id seq_id) {
|
|
|
|
|
return it->second.data();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_context_ptr llama_context_unified::init() {
|
|
|
|
|
ggml_context_ptr llama_context_kv_self::init() {
|
|
|
|
|
inp_tokens = nullptr;
|
|
|
|
|
inp_embd = nullptr;
|
|
|
|
|
inp_pos = nullptr;
|
|
|
|
@ -771,8 +771,8 @@ ggml_context_ptr llama_context_unified::init() {
|
|
|
|
|
return llama_context::init();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct llama_context_unified::batch_manager {
|
|
|
|
|
batch_manager(llama_context_unified & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
|
|
|
|
|
struct llama_context_kv_self::batch_manager {
|
|
|
|
|
batch_manager(llama_context_kv_self & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
|
|
|
|
|
const auto & model = lctx.model;
|
|
|
|
|
const auto & cparams = lctx.cparams;
|
|
|
|
|
const auto & hparams = lctx.model.hparams;
|
|
|
|
@ -982,18 +982,18 @@ struct llama_context_unified::batch_manager {
|
|
|
|
|
|
|
|
|
|
int64_t n_outputs_all = 0;
|
|
|
|
|
|
|
|
|
|
llama_context_unified & lctx;
|
|
|
|
|
llama_context_kv_self & lctx;
|
|
|
|
|
|
|
|
|
|
const llama_batch & batch;
|
|
|
|
|
|
|
|
|
|
llama_kv_slot_restorer kv_slot_restorer;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<llama_context_unified::batch_manager> llama_context_unified::prepare_batch(const llama_batch & batch) {
|
|
|
|
|
std::unique_ptr<llama_context_kv_self::batch_manager> llama_context_kv_self::prepare_batch(const llama_batch & batch) {
|
|
|
|
|
return std::make_unique<batch_manager>(*this, batch);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int llama_context_unified::decode(llama_batch & inp_batch) {
|
|
|
|
|
int llama_context_kv_self::decode(llama_batch & inp_batch) {
|
|
|
|
|
is_encoding = false;
|
|
|
|
|
|
|
|
|
|
if (inp_batch.n_tokens == 0) {
|
|
|
|
@ -1198,7 +1198,7 @@ int llama_context_unified::decode(llama_batch & inp_batch) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int llama_context_unified::encode(llama_batch & inp_batch) {
|
|
|
|
|
int llama_context_kv_self::encode(llama_batch & inp_batch) {
|
|
|
|
|
is_encoding = true;
|
|
|
|
|
|
|
|
|
|
if (inp_batch.n_tokens == 0) {
|
|
|
|
@ -1375,7 +1375,7 @@ int llama_context_unified::encode(llama_batch & inp_batch) {
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
enum ggml_status llama_context_unified::compute_graph(
|
|
|
|
|
enum ggml_status llama_context_kv_self::compute_graph(
|
|
|
|
|
ggml_cgraph * graph,
|
|
|
|
|
bool batched) {
|
|
|
|
|
int n_threads = batched ? cparams.n_threads_batch : cparams.n_threads;
|
|
|
|
@ -1402,23 +1402,23 @@ enum ggml_status llama_context_unified::compute_graph(
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llama_pos llama_context_unified::pos_max() const {
|
|
|
|
|
llama_pos llama_context_kv_self::pos_max() const {
|
|
|
|
|
return kv_self.pos_max();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint32_t llama_context_unified::get_ctx_padding(const llama_cparams & cparams) const {
|
|
|
|
|
uint32_t llama_context_kv_self::get_ctx_padding(const llama_cparams & cparams) const {
|
|
|
|
|
return kv_self.get_padding(cparams);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::prepare_k_shift() {
|
|
|
|
|
void llama_context_kv_self::prepare_k_shift() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::prepare_defrag() {
|
|
|
|
|
void llama_context_kv_self::prepare_defrag() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// llama input
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::set_inputs(const llama_ubatch & ubatch) {
|
|
|
|
|
void llama_context_kv_self::set_inputs(const llama_ubatch & ubatch) {
|
|
|
|
|
const llama_hparams & hparams = model.hparams;
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
@ -1837,7 +1837,7 @@ void llama_context_unified::set_inputs(const llama_ubatch & ubatch) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::reorder_outputs() {
|
|
|
|
|
void llama_context_kv_self::reorder_outputs() {
|
|
|
|
|
std::vector<size_t> & out_ids = sbatch.out_ids;
|
|
|
|
|
if (!out_ids.empty()) {
|
|
|
|
|
const uint32_t n_vocab = model.vocab.n_tokens();
|
|
|
|
@ -1875,7 +1875,7 @@ void llama_context_unified::reorder_outputs() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::reserve_outputs(size_t n_outputs) {
|
|
|
|
|
size_t llama_context_kv_self::reserve_outputs(size_t n_outputs) {
|
|
|
|
|
const auto & hparams = model.hparams;
|
|
|
|
|
const auto & vocab = model.vocab;
|
|
|
|
|
|
|
|
|
@ -1944,7 +1944,7 @@ size_t llama_context_unified::reserve_outputs(size_t n_outputs) {
|
|
|
|
|
return n_outputs_max;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::kv_self_update() {
|
|
|
|
|
void llama_context_kv_self::kv_self_update() {
|
|
|
|
|
auto & kv = kv_self;
|
|
|
|
|
|
|
|
|
|
if (kv.has_shift) {
|
|
|
|
@ -2009,7 +2009,7 @@ void llama_context_unified::kv_self_update() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::build_attn_inp(
|
|
|
|
|
void llama_context_kv_self::build_attn_inp(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
int32_t n_tokens,
|
|
|
|
|
bool causal,
|
|
|
|
@ -2040,7 +2040,7 @@ void llama_context_unified::build_attn_inp(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::build_attn_kv_store(
|
|
|
|
|
void llama_context_kv_self::build_attn_kv_store(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_cgraph * graph,
|
|
|
|
|
ggml_tensor * k_cur,
|
|
|
|
@ -2084,7 +2084,7 @@ void llama_context_unified::build_attn_kv_store(
|
|
|
|
|
ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_attn_qkv(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_attn_qkv(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_cgraph * graph,
|
|
|
|
|
ggml_tensor * wo,
|
|
|
|
@ -2236,7 +2236,7 @@ ggml_tensor * llama_context_unified::build_attn_qkv(
|
|
|
|
|
return cur;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_soft_max_ext(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_soft_max_ext(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_tensor * kq,
|
|
|
|
|
float kq_scale) {
|
|
|
|
@ -2245,7 +2245,7 @@ ggml_tensor * llama_context_unified::build_soft_max_ext(
|
|
|
|
|
return ggml_soft_max_ext(ctx0, kq, inp_KQ_mask_cnv, kq_scale, hparams.f_max_alibi_bias);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_embd(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_embd(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_tensor * tok_embd,
|
|
|
|
|
const llama_ubatch & ubatch) {
|
|
|
|
@ -2295,7 +2295,7 @@ ggml_tensor * llama_context_unified::build_inp_embd(
|
|
|
|
|
return inpL;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_pos(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_pos(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
int32_t n_tokens) {
|
|
|
|
|
inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
|
|
|
|
@ -2304,7 +2304,7 @@ ggml_tensor * llama_context_unified::build_inp_pos(
|
|
|
|
|
return inp_pos;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_out_ids(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_out_ids(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
int32_t n_tokens,
|
|
|
|
|
bool worst_case) {
|
|
|
|
@ -2316,7 +2316,7 @@ ggml_tensor * llama_context_unified::build_inp_out_ids(
|
|
|
|
|
return inp_out_ids;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_mean(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_mean(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
int32_t n_tokens) {
|
|
|
|
|
inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
|
|
|
|
@ -2325,7 +2325,7 @@ ggml_tensor * llama_context_unified::build_inp_mean(
|
|
|
|
|
return inp_mean;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_cls(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_cls(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
int32_t n_tokens) {
|
|
|
|
|
inp_cls = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
|
|
|
|
@ -2334,7 +2334,7 @@ ggml_tensor * llama_context_unified::build_inp_cls(
|
|
|
|
|
return inp_cls;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::build_k_shift(
|
|
|
|
|
void llama_context_kv_self::build_k_shift(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_cgraph * graph) {
|
|
|
|
|
const auto & n_ctx = cparams.n_ctx;
|
|
|
|
@ -2406,7 +2406,7 @@ void llama_context_unified::build_k_shift(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void llama_context_unified::build_defrag(
|
|
|
|
|
void llama_context_kv_self::build_defrag(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_cgraph * graph) {
|
|
|
|
|
const auto & hparams = model.hparams;
|
|
|
|
@ -2676,7 +2676,7 @@ void llama_context_unified::build_defrag(
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_embd_enc(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_embd_enc(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
int32_t n_tokens,
|
|
|
|
|
bool worst_case) {
|
|
|
|
@ -2692,7 +2692,7 @@ ggml_tensor * llama_context_unified::build_inp_embd_enc(
|
|
|
|
|
return inp_embd_enc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_KQ_mask_cross(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
int32_t n_tokens,
|
|
|
|
|
bool worst_case) {
|
|
|
|
@ -2708,7 +2708,7 @@ ggml_tensor * llama_context_unified::build_inp_KQ_mask_cross(
|
|
|
|
|
return inp_KQ_mask_cross;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_s_copy(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_s_copy(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
bool worst_case) {
|
|
|
|
|
const auto n_kv = worst_case ? kv_self.size : kv_self.n;
|
|
|
|
@ -2719,7 +2719,7 @@ ggml_tensor * llama_context_unified::build_inp_s_copy(
|
|
|
|
|
return inp_s_copy;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_inp_s_mask(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_inp_s_mask(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
bool worst_case) {
|
|
|
|
|
const auto n_kv = worst_case ? kv_self.size : kv_self.n;
|
|
|
|
@ -2729,7 +2729,7 @@ ggml_tensor * llama_context_unified::build_inp_s_mask(
|
|
|
|
|
return inp_s_mask;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_copy_mask_state(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_copy_mask_state(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_cgraph * graph,
|
|
|
|
|
ggml_tensor * s,
|
|
|
|
@ -2764,7 +2764,7 @@ ggml_tensor * llama_context_unified::build_copy_mask_state(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: split
|
|
|
|
|
ggml_tensor * llama_context_unified::build_mamba_layer(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_mamba_layer(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_cgraph * graph,
|
|
|
|
|
ggml_tensor * cur,
|
|
|
|
@ -2900,7 +2900,7 @@ ggml_tensor * llama_context_unified::build_mamba_layer(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_rwkv_token_shift_load(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_cgraph * graph,
|
|
|
|
|
ggml_tensor * state_copy,
|
|
|
|
@ -2927,7 +2927,7 @@ ggml_tensor * llama_context_unified::build_rwkv_token_shift_load(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_rwkv_token_shift_store(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_tensor * token_shift,
|
|
|
|
|
const llama_ubatch & ubatch,
|
|
|
|
@ -2951,7 +2951,7 @@ ggml_tensor * llama_context_unified::build_rwkv_token_shift_store(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ggml_tensor * llama_context_unified::build_rwkv6_time_mix(
|
|
|
|
|
ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
|
|
|
|
|
ggml_context * ctx0,
|
|
|
|
|
ggml_cgraph * graph,
|
|
|
|
|
ggml_tensor * cur,
|
|
|
|
@ -3130,7 +3130,7 @@ ggml_tensor * llama_context_unified::build_rwkv6_time_mix(
|
|
|
|
|
|
|
|
|
|
// TODO: replace all non-fatal assertions with returned errors or exceptions
|
|
|
|
|
struct llama_data_write {
|
|
|
|
|
llama_data_write(llama_context_unified * ctx) : ctx(ctx) {}
|
|
|
|
|
llama_data_write(llama_context_kv_self * ctx) : ctx(ctx) {}
|
|
|
|
|
virtual ~llama_data_write() = default;
|
|
|
|
|
|
|
|
|
|
virtual void write(const void * src, size_t size) = 0;
|
|
|
|
@ -3215,11 +3215,11 @@ struct llama_data_write {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llama_context_unified * ctx;
|
|
|
|
|
llama_context_kv_self * ctx;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct llama_data_read {
|
|
|
|
|
llama_data_read(llama_context_unified * ctx) : ctx(ctx) {}
|
|
|
|
|
llama_data_read(llama_context_kv_self * ctx) : ctx(ctx) {}
|
|
|
|
|
virtual ~llama_data_read() = default;
|
|
|
|
|
|
|
|
|
|
virtual const uint8_t * read(size_t size) = 0;
|
|
|
|
@ -3311,11 +3311,11 @@ struct llama_data_read {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
llama_context_unified * ctx;
|
|
|
|
|
llama_context_kv_self * ctx;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct llama_data_write_dummy : llama_data_write {
|
|
|
|
|
llama_data_write_dummy(llama_context_unified * ctx) : llama_data_write(ctx) {}
|
|
|
|
|
llama_data_write_dummy(llama_context_kv_self * ctx) : llama_data_write(ctx) {}
|
|
|
|
|
|
|
|
|
|
void write(const void * /* src */, size_t size) override {
|
|
|
|
|
size_written += size;
|
|
|
|
@ -3334,7 +3334,7 @@ struct llama_data_write_dummy : llama_data_write {
|
|
|
|
|
|
|
|
|
|
struct llama_data_write_buffer : llama_data_write {
|
|
|
|
|
llama_data_write_buffer(
|
|
|
|
|
llama_context_unified * ctx,
|
|
|
|
|
llama_context_kv_self * ctx,
|
|
|
|
|
uint8_t * p, size_t len) : llama_data_write(ctx), ptr(p), buf_size(len) {}
|
|
|
|
|
|
|
|
|
|
void write(const void * src, size_t size) override {
|
|
|
|
@ -3368,7 +3368,7 @@ struct llama_data_write_buffer : llama_data_write {
|
|
|
|
|
|
|
|
|
|
struct llama_data_read_buffer : llama_data_read {
|
|
|
|
|
llama_data_read_buffer(
|
|
|
|
|
llama_context_unified * ctx,
|
|
|
|
|
llama_context_kv_self * ctx,
|
|
|
|
|
const uint8_t * p, size_t len) : llama_data_read(ctx), ptr(p), buf_size(len) {}
|
|
|
|
|
|
|
|
|
|
const uint8_t * read(size_t size) override {
|
|
|
|
@ -3397,7 +3397,7 @@ struct llama_data_read_buffer : llama_data_read {
|
|
|
|
|
|
|
|
|
|
struct llama_data_write_file : llama_data_write {
|
|
|
|
|
llama_data_write_file(
|
|
|
|
|
llama_context_unified * ctx,
|
|
|
|
|
llama_context_kv_self * ctx,
|
|
|
|
|
llama_file * f) : llama_data_write(ctx), file(f) {}
|
|
|
|
|
|
|
|
|
|
void write(const void * src, size_t size) override {
|
|
|
|
@ -3422,7 +3422,7 @@ struct llama_data_write_file : llama_data_write {
|
|
|
|
|
|
|
|
|
|
struct llama_data_read_file : llama_data_read {
|
|
|
|
|
llama_data_read_file(
|
|
|
|
|
llama_context_unified * ctx,
|
|
|
|
|
llama_context_kv_self * ctx,
|
|
|
|
|
llama_file * f) : llama_data_read(ctx), file(f) {}
|
|
|
|
|
|
|
|
|
|
void read_to(void * dst, size_t size) override {
|
|
|
|
@ -3445,7 +3445,7 @@ struct llama_data_read_file : llama_data_read {
|
|
|
|
|
std::vector<uint8_t> temp_buffer;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_get_size() {
|
|
|
|
|
size_t llama_context_kv_self::state_get_size() {
|
|
|
|
|
llama_data_write_dummy data_ctx(this);
|
|
|
|
|
try {
|
|
|
|
|
return state_get_data(data_ctx);
|
|
|
|
@ -3455,7 +3455,7 @@ size_t llama_context_unified::state_get_size() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_get_data(uint8_t * dst, size_t size) {
|
|
|
|
|
size_t llama_context_kv_self::state_get_data(uint8_t * dst, size_t size) {
|
|
|
|
|
llama_data_write_buffer data_ctx(this, dst, size);
|
|
|
|
|
try {
|
|
|
|
|
return state_get_data(data_ctx);
|
|
|
|
@ -3465,7 +3465,7 @@ size_t llama_context_unified::state_get_data(uint8_t * dst, size_t size) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_set_data(const uint8_t * src, size_t size) {
|
|
|
|
|
size_t llama_context_kv_self::state_set_data(const uint8_t * src, size_t size) {
|
|
|
|
|
llama_data_read_buffer data_ctx(this, src, size);
|
|
|
|
|
try {
|
|
|
|
|
return state_set_data(data_ctx);
|
|
|
|
@ -3475,7 +3475,7 @@ size_t llama_context_unified::state_set_data(const uint8_t * src, size_t size) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_seq_get_size(llama_seq_id seq_id) {
|
|
|
|
|
size_t llama_context_kv_self::state_seq_get_size(llama_seq_id seq_id) {
|
|
|
|
|
llama_data_write_dummy data_ctx(this);
|
|
|
|
|
try {
|
|
|
|
|
return state_seq_get_data(data_ctx, seq_id);
|
|
|
|
@ -3485,7 +3485,7 @@ size_t llama_context_unified::state_seq_get_size(llama_seq_id seq_id) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
|
|
|
|
|
size_t llama_context_kv_self::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size) {
|
|
|
|
|
llama_data_write_buffer data_ctx(this, dst, size);
|
|
|
|
|
try {
|
|
|
|
|
return state_seq_get_data(data_ctx, seq_id);
|
|
|
|
@ -3495,7 +3495,7 @@ size_t llama_context_unified::state_seq_get_data(llama_seq_id seq_id, uint8_t *
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
|
|
|
|
|
size_t llama_context_kv_self::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size) {
|
|
|
|
|
llama_data_read_buffer data_ctx(this, src, size);
|
|
|
|
|
try {
|
|
|
|
|
return state_seq_set_data(data_ctx, seq_id);
|
|
|
|
@ -3505,7 +3505,7 @@ size_t llama_context_unified::state_seq_set_data(llama_seq_id seq_id, const uint
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool llama_context_unified::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
|
|
|
|
bool llama_context_kv_self::state_load_file(const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
|
|
|
|
llama_file file(filepath, "rb");
|
|
|
|
|
|
|
|
|
|
// sanity checks
|
|
|
|
@ -3548,7 +3548,7 @@ bool llama_context_unified::state_load_file(const char * filepath, llama_token *
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool llama_context_unified::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
|
|
|
|
|
bool llama_context_kv_self::state_save_file(const char * filepath, const llama_token * tokens, size_t n_token_count) {
|
|
|
|
|
llama_file file(filepath, "wb");
|
|
|
|
|
|
|
|
|
|
file.write_u32(LLAMA_SESSION_MAGIC);
|
|
|
|
@ -3565,7 +3565,7 @@ bool llama_context_unified::state_save_file(const char * filepath, const llama_t
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
|
|
|
|
size_t llama_context_kv_self::state_seq_load_file(llama_seq_id seq_id, const char * filepath, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
|
|
|
|
llama_file file(filepath, "rb");
|
|
|
|
|
|
|
|
|
|
// version checks
|
|
|
|
@ -3608,7 +3608,7 @@ size_t llama_context_unified::state_seq_load_file(llama_seq_id seq_id, const cha
|
|
|
|
|
return file.tell();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
|
|
|
|
|
size_t llama_context_kv_self::state_seq_save_file(llama_seq_id seq_id, const char * filepath, const llama_token * tokens, size_t n_token_count) {
|
|
|
|
|
llama_file file(filepath, "wb");
|
|
|
|
|
|
|
|
|
|
file.write_u32(LLAMA_STATE_SEQ_MAGIC);
|
|
|
|
@ -3641,7 +3641,7 @@ size_t llama_context_unified::state_seq_save_file(llama_seq_id seq_id, const cha
|
|
|
|
|
* llama_state_get_data_internal(ctx, data_ctx);
|
|
|
|
|
*
|
|
|
|
|
*/
|
|
|
|
|
size_t llama_context_unified::state_get_data(llama_data_write & data_ctx) {
|
|
|
|
|
size_t llama_context_kv_self::state_get_data(llama_data_write & data_ctx) {
|
|
|
|
|
synchronize();
|
|
|
|
|
|
|
|
|
|
data_ctx.write_model_info();
|
|
|
|
@ -3667,7 +3667,7 @@ size_t llama_context_unified::state_get_data(llama_data_write & data_ctx) {
|
|
|
|
|
return data_ctx.get_size_written();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_set_data(llama_data_read & data_ctx) {
|
|
|
|
|
size_t llama_context_kv_self::state_set_data(llama_data_read & data_ctx) {
|
|
|
|
|
synchronize();
|
|
|
|
|
|
|
|
|
|
data_ctx.read_model_info();
|
|
|
|
@ -3693,7 +3693,7 @@ size_t llama_context_unified::state_set_data(llama_data_read & data_ctx) {
|
|
|
|
|
return data_ctx.get_size_read();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_seq_get_data(llama_data_write & data_ctx, llama_seq_id seq_id) {
|
|
|
|
|
size_t llama_context_kv_self::state_seq_get_data(llama_data_write & data_ctx, llama_seq_id seq_id) {
|
|
|
|
|
synchronize();
|
|
|
|
|
|
|
|
|
|
llama_kv_cache::io io = {
|
|
|
|
@ -3712,7 +3712,7 @@ size_t llama_context_unified::state_seq_get_data(llama_data_write & data_ctx, ll
|
|
|
|
|
return data_ctx.get_size_written();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t llama_context_unified::state_seq_set_data(llama_data_read & data_ctx, llama_seq_id seq_id) {
|
|
|
|
|
size_t llama_context_kv_self::state_seq_set_data(llama_data_read & data_ctx, llama_seq_id seq_id) {
|
|
|
|
|
synchronize();
|
|
|
|
|
|
|
|
|
|
llama_kv_cache::io io = {
|
|
|
|
|