mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-29 05:33:37 -04:00
memory : migrate from llama_kv_cache to more generic llama_memory (#14006)
* memory : merge llama_kv_cache into llama_memory + new `llama_memory` API ggml-ci * context : fix casts ggml-ci
This commit is contained in:
@@ -2,9 +2,9 @@
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-io.h"
|
||||
#include "llama-memory.h"
|
||||
#include "llama-mmap.h"
|
||||
#include "llama-model.h"
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstring>
|
||||
@@ -277,10 +277,9 @@ llama_context::llama_context(
|
||||
int n_nodes_tg = -1;
|
||||
|
||||
// simulate full KV cache
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
const auto kv_state = kv_self->init_full();
|
||||
if (!kv_state) {
|
||||
const auto mstate = memory->init_full();
|
||||
if (!mstate) {
|
||||
throw std::runtime_error("failed to initialize KV cache");
|
||||
}
|
||||
|
||||
@@ -288,7 +287,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve pp graph first so that buffers are only allocated once
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -299,7 +298,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
auto * gf = graph_reserve(1, 1, 1, kv_state.get());
|
||||
auto * gf = graph_reserve(1, 1, 1, mstate.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
@@ -310,7 +309,7 @@ llama_context::llama_context(
|
||||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
}
|
||||
@@ -419,14 +418,8 @@ uint32_t llama_context::n_threads_batch() const {
|
||||
return cparams.n_threads_batch;
|
||||
}
|
||||
|
||||
llama_kv_cache * llama_context::get_kv_self() {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
return kv_self;
|
||||
}
|
||||
|
||||
const llama_kv_cache * llama_context::get_kv_self() const {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
return kv_self;
|
||||
llama_memory_t llama_context::get_memory() const {
|
||||
return memory.get();
|
||||
}
|
||||
|
||||
void llama_context::kv_self_defrag_sched() {
|
||||
@@ -442,15 +435,13 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
{
|
||||
// TODO: remove in the future
|
||||
optimize |= memory_force_optimize;
|
||||
memory_force_optimize = false;
|
||||
|
||||
const auto kv_state = kv_self->init_update(this, optimize);
|
||||
switch (kv_state->get_status()) {
|
||||
const auto mstate = memory->init_update(this, optimize);
|
||||
switch (mstate->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
// noop
|
||||
@@ -468,23 +459,25 @@ bool llama_context::kv_self_update(bool optimize) {
|
||||
}
|
||||
}
|
||||
|
||||
if (!kv_state->apply()) {
|
||||
if (!mstate->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// if the KV cache did any computation, we have to reserve a new worst-case graph
|
||||
const auto kv_state = kv_self->init_full();
|
||||
if (!kv_state) {
|
||||
throw std::runtime_error("failed to initialize memory state");
|
||||
}
|
||||
// if the memory module did any computation, we have to reserve a new worst-case graph
|
||||
{
|
||||
const auto mstate = memory->init_full();
|
||||
if (!mstate) {
|
||||
throw std::runtime_error("failed to initialize memory state");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
|
||||
if (!gf) {
|
||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -912,10 +905,8 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
}
|
||||
}
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
// temporary allocate memory for the input batch if needed
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);
|
||||
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
|
||||
|
||||
const llama_batch & batch = batch_allocr.batch;
|
||||
|
||||
@@ -977,21 +968,21 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update(false);
|
||||
|
||||
llama_memory_state_ptr kv_state;
|
||||
llama_memory_state_ptr mstate;
|
||||
|
||||
while (true) {
|
||||
kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||
if (!kv_state) {
|
||||
mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||
if (!mstate) {
|
||||
return -2;
|
||||
}
|
||||
|
||||
switch (kv_state->get_status()) {
|
||||
switch (mstate->get_status()) {
|
||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||
{
|
||||
} break;
|
||||
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||
{
|
||||
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
|
||||
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
|
||||
|
||||
return -2;
|
||||
}
|
||||
@@ -1031,7 +1022,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
int64_t n_outputs_prev = 0;
|
||||
|
||||
do {
|
||||
const auto & ubatch = kv_state->get_ubatch();
|
||||
const auto & ubatch = mstate->get_ubatch();
|
||||
|
||||
// count the outputs in this u_batch
|
||||
{
|
||||
@@ -1054,7 +1045,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
ggml_status status;
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status);
|
||||
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
|
||||
|
||||
if (!res) {
|
||||
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
|
||||
@@ -1076,7 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
|
||||
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
|
||||
|
||||
llama_kv_self_seq_rm(this, s, pos_min[s], -1);
|
||||
memory->seq_rm(s, pos_min[s], -1);
|
||||
}
|
||||
|
||||
switch (status) {
|
||||
@@ -1170,7 +1161,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
}
|
||||
|
||||
n_outputs_prev += n_outputs;
|
||||
} while (kv_state->next());
|
||||
} while (mstate->next());
|
||||
|
||||
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
||||
n_outputs = n_outputs_all;
|
||||
@@ -1179,7 +1170,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
{
|
||||
bool sorted_output = true;
|
||||
|
||||
auto & out_ids = kv_state->out_ids();
|
||||
auto & out_ids = mstate->out_ids();
|
||||
|
||||
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
|
||||
|
||||
@@ -1847,11 +1838,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
|
||||
}
|
||||
}
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
if (kv_self != nullptr) {
|
||||
if (memory != nullptr) {
|
||||
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
|
||||
kv_self->state_write(io);
|
||||
memory->state_write(io);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -1938,9 +1927,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
|
||||
if (memory) {
|
||||
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_read(io);
|
||||
memory->state_read(io);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -1950,9 +1937,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
|
||||
GGML_UNUSED(seq_id);
|
||||
|
||||
if (memory) {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_write(io, seq_id);
|
||||
memory->state_write(io, seq_id);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -1962,9 +1947,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
|
||||
GGML_UNUSED(seq_id);
|
||||
|
||||
if (memory) {
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->state_read(io, seq_id);
|
||||
memory->state_read(io, seq_id);
|
||||
}
|
||||
|
||||
return io.n_bytes();
|
||||
@@ -2069,9 +2052,7 @@ void llama_context::opt_epoch_iter(
|
||||
const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
|
||||
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
|
||||
|
||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||
|
||||
kv_self->clear();
|
||||
memory->clear();
|
||||
|
||||
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
|
||||
batch.n_tokens = n_batch;
|
||||
@@ -2094,8 +2075,8 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
int64_t n_outputs_all = n_tokens_all;
|
||||
|
||||
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
|
||||
if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
|
||||
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
|
||||
break;
|
||||
}
|
||||
@@ -2108,17 +2089,17 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
uint32_t pos_batch = 0;
|
||||
do {
|
||||
const auto & ubatch = kv_state->get_ubatch();
|
||||
const auto & ubatch = mstate->get_ubatch();
|
||||
|
||||
n_outputs = ubatch.n_tokens;
|
||||
|
||||
if (!kv_state->apply()) {
|
||||
if (!mstate->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
|
||||
break;
|
||||
}
|
||||
|
||||
auto * gf = graph_init();
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get());
|
||||
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
|
||||
|
||||
struct ggml_context * ctx_compute_opt;
|
||||
{
|
||||
@@ -2153,7 +2134,7 @@ void llama_context::opt_epoch_iter(
|
||||
ggml_free(ctx_compute_opt);
|
||||
|
||||
pos_batch += ubatch.n_tokens;
|
||||
} while (kv_state->next());
|
||||
} while (mstate->next());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2314,8 +2295,9 @@ const llama_model * llama_get_model(const llama_context * ctx) {
|
||||
return &ctx->get_model();
|
||||
}
|
||||
|
||||
// deprecated
|
||||
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
||||
return ctx->get_kv_self();
|
||||
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
|
||||
}
|
||||
|
||||
// deprecated
|
||||
@@ -2435,13 +2417,82 @@ int32_t llama_apply_adapter_cvec(
|
||||
return res ? 0 : -1;
|
||||
}
|
||||
|
||||
//
|
||||
// memory
|
||||
//
|
||||
|
||||
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
|
||||
return ctx->get_memory();
|
||||
}
|
||||
|
||||
void llama_memory_clear(llama_memory_t mem) {
|
||||
mem->clear();
|
||||
}
|
||||
|
||||
bool llama_memory_seq_rm(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
return mem->seq_rm(seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_seq_cp(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_memory_seq_keep(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id) {
|
||||
mem->seq_keep(seq_id);
|
||||
}
|
||||
|
||||
void llama_memory_seq_add(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
mem->seq_add(seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_memory_seq_div(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
mem->seq_div(seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_memory_seq_pos_min(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id) {
|
||||
return mem->seq_pos_min(seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_memory_seq_pos_max(
|
||||
llama_memory_t mem,
|
||||
llama_seq_id seq_id) {
|
||||
return mem->seq_pos_max(seq_id);
|
||||
}
|
||||
|
||||
bool llama_memory_can_shift(llama_memory_t mem) {
|
||||
return mem->get_can_shift();
|
||||
}
|
||||
|
||||
//
|
||||
// kv cache
|
||||
//
|
||||
|
||||
// deprecated
|
||||
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
const auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return 0;
|
||||
}
|
||||
@@ -2463,7 +2514,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
|
||||
// deprecated
|
||||
// note: this is the same as above - will be removed anyway, so it's ok
|
||||
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
const auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return 0;
|
||||
}
|
||||
@@ -2483,12 +2534,12 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
|
||||
}
|
||||
|
||||
void llama_kv_self_clear(llama_context * ctx) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->clear();
|
||||
llama_memory_clear(kv);
|
||||
}
|
||||
|
||||
bool llama_kv_self_seq_rm(
|
||||
@@ -2496,12 +2547,12 @@ bool llama_kv_self_seq_rm(
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return kv->seq_rm(seq_id, p0, p1);
|
||||
return llama_memory_seq_rm(kv, seq_id, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_cp(
|
||||
@@ -2510,21 +2561,21 @@ void llama_kv_self_seq_cp(
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
|
||||
llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->seq_keep(seq_id);
|
||||
llama_memory_seq_keep(kv, seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_add(
|
||||
@@ -2533,12 +2584,12 @@ void llama_kv_self_seq_add(
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->seq_add(seq_id, p0, p1, delta);
|
||||
llama_memory_seq_add(kv, seq_id, p0, p1, delta);
|
||||
}
|
||||
|
||||
void llama_kv_self_seq_div(
|
||||
@@ -2547,30 +2598,30 @@ void llama_kv_self_seq_div(
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d) {
|
||||
auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return;
|
||||
}
|
||||
|
||||
kv->seq_div(seq_id, p0, p1, d);
|
||||
llama_memory_seq_div(kv, seq_id, p0, p1, d);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return kv->seq_pos_min(seq_id);
|
||||
return llama_memory_seq_pos_min(kv, seq_id);
|
||||
}
|
||||
|
||||
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return kv->seq_pos_max(seq_id);
|
||||
return llama_memory_seq_pos_max(kv, seq_id);
|
||||
}
|
||||
|
||||
// deprecated
|
||||
@@ -2580,12 +2631,12 @@ void llama_kv_self_defrag(llama_context * ctx) {
|
||||
}
|
||||
|
||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||
const auto * kv = ctx->get_kv_self();
|
||||
auto * kv = llama_get_memory(ctx);
|
||||
if (!kv) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return kv->get_can_shift();
|
||||
return llama_memory_can_shift(kv);
|
||||
}
|
||||
|
||||
// llama state API
|
||||
|
Reference in New Issue
Block a user