From 7f37b6cf1e2c1b90bf0d9c8d91904b4b6c512748 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Jun 2025 15:29:22 +0300 Subject: [PATCH] memory : migrate from llama_kv_cache to more generic llama_memory (#14006) * memory : merge llama_kv_cache into llama_memory + new `llama_memory` API ggml-ci * context : fix casts ggml-ci --- include/llama.h | 100 ++++++++++++-- src/CMakeLists.txt | 1 - src/llama-context.cpp | 223 ++++++++++++++++++------------ src/llama-context.h | 8 +- src/llama-graph.h | 2 +- src/llama-kv-cache-recurrent.h | 28 ++-- src/llama-kv-cache-unified-iswa.h | 28 ++-- src/llama-kv-cache-unified.h | 30 ++-- src/llama-kv-cache.cpp | 1 - src/llama-kv-cache.h | 41 ------ src/llama-memory.h | 82 +++++++---- 11 files changed, 324 insertions(+), 220 deletions(-) delete mode 100644 src/llama-kv-cache.cpp delete mode 100644 src/llama-kv-cache.h diff --git a/include/llama.h b/include/llama.h index da0f652cf..21808c881 100644 --- a/include/llama.h +++ b/include/llama.h @@ -61,7 +61,10 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_kv_cache; + + typedef struct llama_memory_i * llama_memory_t; + + struct llama_kv_cache; // DEPRECATED (use llama_memory instead) typedef int32_t llama_pos; typedef int32_t llama_token; @@ -493,9 +496,11 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); - LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx); + LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type + DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); @@ -609,7 +614,78 @@ extern "C" { int32_t il_end); // - // KV cache + // Memory + // + + // Clear the memory contents + LLAMA_API void llama_memory_clear(llama_memory_t mem); + + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails + // seq_id < 0 : match any sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API bool llama_memory_seq_rm( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1); + + // Copy all tokens that belong to the specified sequence to another sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_memory_seq_cp( + llama_memory_t mem, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + + // Removes all tokens that do not belong to the specified sequence + LLAMA_API void llama_memory_seq_keep( + llama_memory_t mem, + llama_seq_id seq_id); + + // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_memory_seq_add( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta); + + // Integer division of the positions by factor of `d > 1` + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_memory_seq_div( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + + // Returns the smallest position present in the memory for the specified sequence + // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_memory_seq_pos_min( + llama_memory_t mem, + llama_seq_id seq_id); + + // Returns the largest position present in the memory for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_memory_seq_pos_max( + llama_memory_t mem, + llama_seq_id seq_id); + + // Check if the memory supports shifting + LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); + + // + // KV cache for self-attention (TODO: deprecate in favor of llama_memory) // // Returns the number of tokens in the KV cache (slow, use only for debug) @@ -623,7 +699,7 @@ extern "C" { // Clear the KV cache - both cell info is erased and KV data is zeroed LLAMA_API void llama_kv_self_clear( - struct llama_context * ctx); + struct llama_context * ctx); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails @@ -694,14 +770,14 @@ extern "C" { // Defragment the KV cache // This will be applied: // - lazily on next llama_decode() - LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx), + DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); // Check if the context supports KV cache shifting LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) - LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx), + DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), "simply remove this call, updates are applied lazily on the next llama_decode()"); // @@ -709,7 +785,7 @@ extern "C" { // // Returns the *actual* size in bytes of the state - // (logits, embedding and kv_cache) + // (logits, embedding and memory) // Only use when saving the state, not when restoring it, otherwise the size may be too small. LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), @@ -765,12 +841,12 @@ extern "C" { size_t n_token_count), "use llama_state_save_file instead"); - // Get the exact size needed to copy the KV cache of a single sequence + // Get the exact size needed to copy the state of a single sequence LLAMA_API size_t llama_state_seq_get_size( struct llama_context * ctx, llama_seq_id seq_id); - // Copy the KV cache of a single sequence into the specified buffer + // Copy the state of a single sequence into the specified buffer LLAMA_API size_t llama_state_seq_get_data( struct llama_context * ctx, uint8_t * dst, @@ -836,16 +912,16 @@ extern "C" { // For encode-decoder contexts, processes the batch using the encoder. // Can store the encoder output internally for later use by the decoder's cross-attention layers. // 0 - success - // < 0 - error. the KV cache state is restored to the state before this call + // < 0 - error. the memory state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); // Process a batch of tokens. - // Requires KV cache. + // Requires the context to have a memory. // For encode-decoder contexts, processes the batch using the decoder. // Positive return values does not mean a fatal error, but rather a warning. - // Upon non-zero return values, the KV cache state is restored to the state before this call + // Upon non-zero return values, the memory state is restored to the state before this call // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 2 - aborted diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d20bd4fe2..70be604e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -20,7 +20,6 @@ add_library(llama llama-hparams.cpp llama-impl.cpp llama-io.cpp - llama-kv-cache.cpp llama-kv-cache-unified.cpp llama-kv-cache-unified-iswa.cpp llama-kv-cache-recurrent.cpp diff --git a/src/llama-context.cpp b/src/llama-context.cpp index f1b43b9cc..c29fe7e4c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2,9 +2,9 @@ #include "llama-impl.h" #include "llama-io.h" +#include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" -#include "llama-kv-cache.h" #include #include @@ -277,10 +277,9 @@ llama_context::llama_context( int n_nodes_tg = -1; // simulate full KV cache - llama_kv_cache * kv_self = static_cast(memory.get()); - const auto kv_state = kv_self->init_full(); - if (!kv_state) { + const auto mstate = memory->init_full(); + if (!mstate) { throw std::runtime_error("failed to initialize KV cache"); } @@ -288,7 +287,7 @@ llama_context::llama_context( // reserve pp graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -299,7 +298,7 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1, kv_state.get()); + auto * gf = graph_reserve(1, 1, 1, mstate.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -310,7 +309,7 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -419,14 +418,8 @@ uint32_t llama_context::n_threads_batch() const { return cparams.n_threads_batch; } -llama_kv_cache * llama_context::get_kv_self() { - llama_kv_cache * kv_self = static_cast(memory.get()); - return kv_self; -} - -const llama_kv_cache * llama_context::get_kv_self() const { - llama_kv_cache * kv_self = static_cast(memory.get()); - return kv_self; +llama_memory_t llama_context::get_memory() const { + return memory.get(); } void llama_context::kv_self_defrag_sched() { @@ -442,15 +435,13 @@ bool llama_context::kv_self_update(bool optimize) { return false; } - llama_kv_cache * kv_self = static_cast(memory.get()); - { // TODO: remove in the future optimize |= memory_force_optimize; memory_force_optimize = false; - const auto kv_state = kv_self->init_update(this, optimize); - switch (kv_state->get_status()) { + const auto mstate = memory->init_update(this, optimize); + switch (mstate->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: { // noop @@ -468,23 +459,25 @@ bool llama_context::kv_self_update(bool optimize) { } } - if (!kv_state->apply()) { + if (!mstate->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); } } - // if the KV cache did any computation, we have to reserve a new worst-case graph - const auto kv_state = kv_self->init_full(); - if (!kv_state) { - throw std::runtime_error("failed to initialize memory state"); - } + // if the memory module did any computation, we have to reserve a new worst-case graph + { + const auto mstate = memory->init_full(); + if (!mstate) { + throw std::runtime_error("failed to initialize memory state"); + } - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); - if (!gf) { - LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); + } } return true; @@ -912,10 +905,8 @@ int llama_context::decode(llama_batch & inp_batch) { } } - llama_kv_cache * kv_self = static_cast(memory.get()); - // temporary allocate memory for the input batch if needed - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1); + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1); const llama_batch & batch = batch_allocr.batch; @@ -977,21 +968,21 @@ int llama_context::decode(llama_batch & inp_batch) { // handle any pending defrags/shifts kv_self_update(false); - llama_memory_state_ptr kv_state; + llama_memory_state_ptr mstate; while (true) { - kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); - if (!kv_state) { + mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); + if (!mstate) { return -2; } - switch (kv_state->get_status()) { + switch (mstate->get_status()) { case LLAMA_MEMORY_STATUS_SUCCESS: { } break; case LLAMA_MEMORY_STATUS_NO_UPDATE: { - LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status()); + LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status()); return -2; } @@ -1031,7 +1022,7 @@ int llama_context::decode(llama_batch & inp_batch) { int64_t n_outputs_prev = 0; do { - const auto & ubatch = kv_state->get_ubatch(); + const auto & ubatch = mstate->get_ubatch(); // count the outputs in this u_batch { @@ -1054,7 +1045,7 @@ int llama_context::decode(llama_batch & inp_batch) { ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_status status; - const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status); + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1076,7 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) { LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); - llama_kv_self_seq_rm(this, s, pos_min[s], -1); + memory->seq_rm(s, pos_min[s], -1); } switch (status) { @@ -1170,7 +1161,7 @@ int llama_context::decode(llama_batch & inp_batch) { } n_outputs_prev += n_outputs; - } while (kv_state->next()); + } while (mstate->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith n_outputs = n_outputs_all; @@ -1179,7 +1170,7 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - auto & out_ids = kv_state->out_ids(); + auto & out_ids = mstate->out_ids(); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); @@ -1847,11 +1838,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } - llama_kv_cache * kv_self = static_cast(memory.get()); - - if (kv_self != nullptr) { + if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); - kv_self->state_write(io); + memory->state_write(io); } return io.n_bytes(); @@ -1938,9 +1927,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { if (memory) { LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); - llama_kv_cache * kv_self = static_cast(memory.get()); - - kv_self->state_read(io); + memory->state_read(io); } return io.n_bytes(); @@ -1950,9 +1937,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s GGML_UNUSED(seq_id); if (memory) { - llama_kv_cache * kv_self = static_cast(memory.get()); - - kv_self->state_write(io, seq_id); + memory->state_write(io, seq_id); } return io.n_bytes(); @@ -1962,9 +1947,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq GGML_UNUSED(seq_id); if (memory) { - llama_kv_cache * kv_self = static_cast(memory.get()); - - kv_self->state_read(io, seq_id); + memory->state_read(io, seq_id); } return io.n_bytes(); @@ -2069,9 +2052,7 @@ void llama_context::opt_epoch_iter( const uint32_t n_batch = std::min(this->n_batch(), n_ctx); const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); - llama_kv_cache * kv_self = static_cast(memory.get()); - - kv_self->clear(); + memory->clear(); for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { batch.n_tokens = n_batch; @@ -2094,8 +2075,8 @@ void llama_context::opt_epoch_iter( int64_t n_outputs_all = n_tokens_all; - auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); - if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); + if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); break; } @@ -2108,17 +2089,17 @@ void llama_context::opt_epoch_iter( uint32_t pos_batch = 0; do { - const auto & ubatch = kv_state->get_ubatch(); + const auto & ubatch = mstate->get_ubatch(); n_outputs = ubatch.n_tokens; - if (!kv_state->apply()) { + if (!mstate->apply()) { LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); break; } auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get()); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get()); struct ggml_context * ctx_compute_opt; { @@ -2153,7 +2134,7 @@ void llama_context::opt_epoch_iter( ggml_free(ctx_compute_opt); pos_batch += ubatch.n_tokens; - } while (kv_state->next()); + } while (mstate->next()); } } @@ -2314,8 +2295,9 @@ const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } +// deprecated llama_kv_cache * llama_get_kv_self(llama_context * ctx) { - return ctx->get_kv_self(); + return dynamic_cast(ctx->get_memory()); } // deprecated @@ -2435,13 +2417,82 @@ int32_t llama_apply_adapter_cvec( return res ? 0 : -1; } +// +// memory +// + +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + return ctx->get_memory(); +} + +void llama_memory_clear(llama_memory_t mem) { + mem->clear(); +} + +bool llama_memory_seq_rm( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + return mem->seq_rm(seq_id, p0, p1); +} + +void llama_memory_seq_cp( + llama_memory_t mem, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_seq_keep( + llama_memory_t mem, + llama_seq_id seq_id) { + mem->seq_keep(seq_id); +} + +void llama_memory_seq_add( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + mem->seq_add(seq_id, p0, p1, delta); +} + +void llama_memory_seq_div( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + mem->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_seq_pos_min( + llama_memory_t mem, + llama_seq_id seq_id) { + return mem->seq_pos_min(seq_id); +} + +llama_pos llama_memory_seq_pos_max( + llama_memory_t mem, + llama_seq_id seq_id) { + return mem->seq_pos_max(seq_id); +} + +bool llama_memory_can_shift(llama_memory_t mem) { + return mem->get_can_shift(); +} + // // kv cache // // deprecated int32_t llama_kv_self_n_tokens(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + const auto * kv = llama_get_memory(ctx); if (!kv) { return 0; } @@ -2463,7 +2514,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) { // deprecated // note: this is the same as above - will be removed anyway, so it's ok int32_t llama_kv_self_used_cells(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + const auto * kv = llama_get_memory(ctx); if (!kv) { return 0; } @@ -2483,12 +2534,12 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) { } void llama_kv_self_clear(llama_context * ctx) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->clear(); + llama_memory_clear(kv); } bool llama_kv_self_seq_rm( @@ -2496,12 +2547,12 @@ bool llama_kv_self_seq_rm( llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return true; } - return kv->seq_rm(seq_id, p0, p1); + return llama_memory_seq_rm(kv, seq_id, p0, p1); } void llama_kv_self_seq_cp( @@ -2510,21 +2561,21 @@ void llama_kv_self_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); + llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1); } void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->seq_keep(seq_id); + llama_memory_seq_keep(kv, seq_id); } void llama_kv_self_seq_add( @@ -2533,12 +2584,12 @@ void llama_kv_self_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->seq_add(seq_id, p0, p1, delta); + llama_memory_seq_add(kv, seq_id, p0, p1, delta); } void llama_kv_self_seq_div( @@ -2547,30 +2598,30 @@ void llama_kv_self_seq_div( llama_pos p0, llama_pos p1, int d) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->seq_div(seq_id, p0, p1, d); + llama_memory_seq_div(kv, seq_id, p0, p1, d); } llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { - const auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return -1; } - return kv->seq_pos_min(seq_id); + return llama_memory_seq_pos_min(kv, seq_id); } llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - const auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return -1; } - return kv->seq_pos_max(seq_id); + return llama_memory_seq_pos_max(kv, seq_id); } // deprecated @@ -2580,12 +2631,12 @@ void llama_kv_self_defrag(llama_context * ctx) { } bool llama_kv_self_can_shift(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return false; } - return kv->get_can_shift(); + return llama_memory_can_shift(kv); } // llama state API diff --git a/src/llama-context.h b/src/llama-context.h index c1c7efb31..2e0da8c83 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -13,13 +13,12 @@ #include struct llama_model; -struct llama_kv_cache; class llama_io_read_i; class llama_io_write_i; -class llama_memory_i; -class llama_memory_state_i; +struct llama_memory_i; +struct llama_memory_state_i; struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs @@ -47,8 +46,7 @@ struct llama_context { uint32_t n_threads() const; uint32_t n_threads_batch() const; - llama_kv_cache * get_kv_self(); - const llama_kv_cache * get_kv_self() const; + llama_memory_t get_memory() const; // return true of the KV cache was updated // TODO: remove diff --git a/src/llama-graph.h b/src/llama-graph.h index d1c5dd1bf..2b1cfa5b7 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -17,7 +17,7 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -class llama_memory_state_i; +struct llama_memory_state_i; class llama_kv_cache_unified_state; class llama_kv_cache_unified_iswa_state; diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h index b32f258fb..cb813dfe8 100644 --- a/src/llama-kv-cache-recurrent.h +++ b/src/llama-kv-cache-recurrent.h @@ -2,7 +2,7 @@ #include "llama-batch.h" #include "llama-graph.h" -#include "llama-kv-cache.h" +#include "llama-memory.h" #include #include @@ -13,7 +13,7 @@ // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i // see the implementation of llama_kv_cache_unified_state_i for an example how to do it -class llama_kv_cache_recurrent : public llama_kv_cache { +class llama_kv_cache_recurrent : public llama_memory_i { public: llama_kv_cache_recurrent( const llama_model & model, @@ -29,6 +29,16 @@ public: // llama_memory_i // + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + void clear() override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; @@ -40,20 +50,6 @@ public: llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override; - // - // llama_kv_cache - // - - llama_memory_state_ptr init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; - - llama_memory_state_ptr init_full() override; - - llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; - bool prepare(const std::vector & ubatches); // find a contiguous slot of kv cells and emplace the ubatch there diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index cba5bbe95..3fabcd6b8 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -11,7 +11,7 @@ // utilizes two instances of llama_kv_cache_unified // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers -class llama_kv_cache_unified_iswa : public llama_kv_cache { +class llama_kv_cache_unified_iswa : public llama_memory_i { public: llama_kv_cache_unified_iswa( const llama_model & model, @@ -31,21 +31,6 @@ public: // llama_memory_i // - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, @@ -58,6 +43,17 @@ public: bool get_can_shift() const override; + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 6ff388a88..d01a9abd7 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -2,8 +2,8 @@ #include "llama-batch.h" #include "llama-graph.h" -#include "llama-kv-cache.h" #include "llama-kv-cells.h" +#include "llama-memory.h" #include #include @@ -17,7 +17,7 @@ struct llama_context; // llama_kv_cache_unified // -class llama_kv_cache_unified : public llama_kv_cache { +class llama_kv_cache_unified : public llama_memory_i { public: static uint32_t get_padding(const llama_cparams & cparams); @@ -56,21 +56,6 @@ public: // llama_memory_i // - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, @@ -83,6 +68,17 @@ public: bool get_can_shift() const override; + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp deleted file mode 100644 index aefd23e32..000000000 --- a/src/llama-kv-cache.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "llama-kv-cache.h" diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h deleted file mode 100644 index 17a5e5cb8..000000000 --- a/src/llama-kv-cache.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include "llama.h" -#include "llama-memory.h" - -class llama_io_write_i; -class llama_io_read_i; - -struct llama_kv_cache : public llama_memory_i { - virtual ~llama_kv_cache() = default; - - // TODO: move the init_ interfaces to llama_memory_i - - // split the input batch into a set of ubatches and verify that they can fit into the cache - // return a state object containing the ubatches and KV cache state required to process them - // check the llama_memory_state_i::get_status() for the result - virtual llama_memory_state_ptr init_batch( - const llama_batch & batch, - uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) = 0; - - // simulate full cache, used for allocating worst-case compute buffers - virtual llama_memory_state_ptr init_full() = 0; - - // prepare for any pending memory updates, such as shifts, defrags, etc. - // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update - virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; - - // getters - virtual bool get_can_shift() const = 0; - - bool get_can_edit() const override { return get_can_shift(); } - - // - // state write/read - // - - virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; - virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; -}; diff --git a/src/llama-memory.h b/src/llama-memory.h index ab0d399c4..5993b59be 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -7,6 +7,9 @@ struct llama_ubatch; +class llama_io_write_i; +class llama_io_read_i; + struct llama_memory_params { // kv cache ggml_type type_k; @@ -16,28 +19,6 @@ struct llama_memory_params { bool swa_full; }; -// general concept of LLM memory -// the KV cache is a type of LLM memory, but there can be other types -class llama_memory_i { -public: - virtual ~llama_memory_i() = default; - - virtual void clear() = 0; - - virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; - virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; - virtual void seq_keep(llama_seq_id seq_id) = 0; - virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; - virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; - - virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; - virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; - - virtual bool get_can_edit() const = 0; -}; - -using llama_memory_ptr = std::unique_ptr; - enum llama_memory_status { LLAMA_MEMORY_STATUS_SUCCESS = 0, LLAMA_MEMORY_STATUS_NO_UPDATE, @@ -58,8 +39,7 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me // the only method that can mutate the memory and the memory state is llama_memory_i::apply() // // TODO: rename to llama_memory_context_i ? -class llama_memory_state_i { -public: +struct llama_memory_state_i { virtual ~llama_memory_state_i() = default; // consume the current ubatch from the state and proceed to the next one @@ -81,3 +61,57 @@ public: }; using llama_memory_state_ptr = std::unique_ptr; + +// general concept of LLM memory +// the KV cache is a type of LLM memory, but there can be other types +struct llama_memory_i { + virtual ~llama_memory_i() = default; + + // split the input batch into a set of ubatches and verify that they can fit into the cache + // return a state object containing the ubatches and KV cache state required to process them + // check the llama_memory_state_i::get_status() for the result + virtual llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) = 0; + + // simulate full cache, used for allocating worst-case compute buffers + virtual llama_memory_state_ptr init_full() = 0; + + // prepare for any pending memory updates, such as shifts, defrags, etc. + // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update + virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; + + // getters + virtual bool get_can_shift() const = 0; + + // + // ops + // + + virtual void clear() = 0; + + virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; + virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; + virtual void seq_keep(llama_seq_id seq_id) = 0; + virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; + virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; + + virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; + virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; + + // + // state write/read + // + + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; +}; + +using llama_memory_ptr = std::unique_ptr; + +// TODO: temporary until the llama_kv_cache is removed from the public API +struct llama_kv_cache : public llama_memory_i { + virtual ~llama_kv_cache() = default; +};