memory : migrate from llama_kv_cache to more generic llama_memory (#14006)

* memory : merge llama_kv_cache into llama_memory + new `llama_memory` API

ggml-ci

* context : fix casts

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-05 15:29:22 +03:00
committed by GitHub
parent 3a077146a4
commit 7f37b6cf1e
11 changed files with 324 additions and 220 deletions

View File

@ -61,7 +61,10 @@ extern "C" {
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct llama_sampler; struct llama_sampler;
struct llama_kv_cache;
typedef struct llama_memory_i * llama_memory_t;
struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
typedef int32_t llama_pos; typedef int32_t llama_pos;
typedef int32_t llama_token; typedef int32_t llama_token;
@ -493,9 +496,11 @@ extern "C" {
DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx); LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
@ -609,7 +614,78 @@ extern "C" {
int32_t il_end); int32_t il_end);
// //
// KV cache // Memory
//
// Clear the memory contents
LLAMA_API void llama_memory_clear(llama_memory_t mem);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
// seq_id < 0 : match any sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API bool llama_memory_seq_rm(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);
// Copy all tokens that belong to the specified sequence to another sequence
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_cp(
llama_memory_t mem,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);
// Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_memory_seq_keep(
llama_memory_t mem,
llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_add(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);
// Integer division of the positions by factor of `d > 1`
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_memory_seq_div(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);
// Returns the smallest position present in the memory for the specified sequence
// This is typically non-zero only for SWA caches
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_memory_seq_pos_min(
llama_memory_t mem,
llama_seq_id seq_id);
// Returns the largest position present in the memory for the specified sequence
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_memory_seq_pos_max(
llama_memory_t mem,
llama_seq_id seq_id);
// Check if the memory supports shifting
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
//
// KV cache for self-attention (TODO: deprecate in favor of llama_memory)
// //
// Returns the number of tokens in the KV cache (slow, use only for debug) // Returns the number of tokens in the KV cache (slow, use only for debug)
@ -623,7 +699,7 @@ extern "C" {
// Clear the KV cache - both cell info is erased and KV data is zeroed // Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_self_clear( LLAMA_API void llama_kv_self_clear(
struct llama_context * ctx); struct llama_context * ctx);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
@ -694,14 +770,14 @@ extern "C" {
// Defragment the KV cache // Defragment the KV cache
// This will be applied: // This will be applied:
// - lazily on next llama_decode() // - lazily on next llama_decode()
LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx), DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx),
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
// Check if the context supports KV cache shifting // Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.) // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx), DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx),
"simply remove this call, updates are applied lazily on the next llama_decode()"); "simply remove this call, updates are applied lazily on the next llama_decode()");
// //
@ -709,7 +785,7 @@ extern "C" {
// //
// Returns the *actual* size in bytes of the state // Returns the *actual* size in bytes of the state
// (logits, embedding and kv_cache) // (logits, embedding and memory)
// Only use when saving the state, not when restoring it, otherwise the size may be too small. // Only use when saving the state, not when restoring it, otherwise the size may be too small.
LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
@ -765,12 +841,12 @@ extern "C" {
size_t n_token_count), size_t n_token_count),
"use llama_state_save_file instead"); "use llama_state_save_file instead");
// Get the exact size needed to copy the KV cache of a single sequence // Get the exact size needed to copy the state of a single sequence
LLAMA_API size_t llama_state_seq_get_size( LLAMA_API size_t llama_state_seq_get_size(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id); llama_seq_id seq_id);
// Copy the KV cache of a single sequence into the specified buffer // Copy the state of a single sequence into the specified buffer
LLAMA_API size_t llama_state_seq_get_data( LLAMA_API size_t llama_state_seq_get_data(
struct llama_context * ctx, struct llama_context * ctx,
uint8_t * dst, uint8_t * dst,
@ -836,16 +912,16 @@ extern "C" {
// For encode-decoder contexts, processes the batch using the encoder. // For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers. // Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success // 0 - success
// < 0 - error. the KV cache state is restored to the state before this call // < 0 - error. the memory state is restored to the state before this call
LLAMA_API int32_t llama_encode( LLAMA_API int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch); struct llama_batch batch);
// Process a batch of tokens. // Process a batch of tokens.
// Requires KV cache. // Requires the context to have a memory.
// For encode-decoder contexts, processes the batch using the decoder. // For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning. // Positive return values does not mean a fatal error, but rather a warning.
// Upon non-zero return values, the KV cache state is restored to the state before this call // Upon non-zero return values, the memory state is restored to the state before this call
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// 2 - aborted // 2 - aborted

View File

@ -20,7 +20,6 @@ add_library(llama
llama-hparams.cpp llama-hparams.cpp
llama-impl.cpp llama-impl.cpp
llama-io.cpp llama-io.cpp
llama-kv-cache.cpp
llama-kv-cache-unified.cpp llama-kv-cache-unified.cpp
llama-kv-cache-unified-iswa.cpp llama-kv-cache-unified-iswa.cpp
llama-kv-cache-recurrent.cpp llama-kv-cache-recurrent.cpp

View File

@ -2,9 +2,9 @@
#include "llama-impl.h" #include "llama-impl.h"
#include "llama-io.h" #include "llama-io.h"
#include "llama-memory.h"
#include "llama-mmap.h" #include "llama-mmap.h"
#include "llama-model.h" #include "llama-model.h"
#include "llama-kv-cache.h"
#include <cinttypes> #include <cinttypes>
#include <cstring> #include <cstring>
@ -277,10 +277,9 @@ llama_context::llama_context(
int n_nodes_tg = -1; int n_nodes_tg = -1;
// simulate full KV cache // simulate full KV cache
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
const auto kv_state = kv_self->init_full(); const auto mstate = memory->init_full();
if (!kv_state) { if (!mstate) {
throw std::runtime_error("failed to initialize KV cache"); throw std::runtime_error("failed to initialize KV cache");
} }
@ -288,7 +287,7 @@ llama_context::llama_context(
// reserve pp graph first so that buffers are only allocated once // reserve pp graph first so that buffers are only allocated once
{ {
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
if (!gf) { if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers"); throw std::runtime_error("failed to allocate compute pp buffers");
} }
@ -299,7 +298,7 @@ llama_context::llama_context(
// reserve with tg graph to get the number of splits and nodes // reserve with tg graph to get the number of splits and nodes
{ {
auto * gf = graph_reserve(1, 1, 1, kv_state.get()); auto * gf = graph_reserve(1, 1, 1, mstate.get());
if (!gf) { if (!gf) {
throw std::runtime_error("failed to allocate compute tg buffers"); throw std::runtime_error("failed to allocate compute tg buffers");
} }
@ -310,7 +309,7 @@ llama_context::llama_context(
// reserve again with pp graph to avoid ggml-alloc reallocations during inference // reserve again with pp graph to avoid ggml-alloc reallocations during inference
{ {
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
if (!gf) { if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers"); throw std::runtime_error("failed to allocate compute pp buffers");
} }
@ -419,14 +418,8 @@ uint32_t llama_context::n_threads_batch() const {
return cparams.n_threads_batch; return cparams.n_threads_batch;
} }
llama_kv_cache * llama_context::get_kv_self() { llama_memory_t llama_context::get_memory() const {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); return memory.get();
return kv_self;
}
const llama_kv_cache * llama_context::get_kv_self() const {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
return kv_self;
} }
void llama_context::kv_self_defrag_sched() { void llama_context::kv_self_defrag_sched() {
@ -442,15 +435,13 @@ bool llama_context::kv_self_update(bool optimize) {
return false; return false;
} }
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
{ {
// TODO: remove in the future // TODO: remove in the future
optimize |= memory_force_optimize; optimize |= memory_force_optimize;
memory_force_optimize = false; memory_force_optimize = false;
const auto kv_state = kv_self->init_update(this, optimize); const auto mstate = memory->init_update(this, optimize);
switch (kv_state->get_status()) { switch (mstate->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS: case LLAMA_MEMORY_STATUS_SUCCESS:
{ {
// noop // noop
@ -468,23 +459,25 @@ bool llama_context::kv_self_update(bool optimize) {
} }
} }
if (!kv_state->apply()) { if (!mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
} }
} }
// if the KV cache did any computation, we have to reserve a new worst-case graph // if the memory module did any computation, we have to reserve a new worst-case graph
const auto kv_state = kv_self->init_full(); {
if (!kv_state) { const auto mstate = memory->init_full();
throw std::runtime_error("failed to initialize memory state"); if (!mstate) {
} throw std::runtime_error("failed to initialize memory state");
}
const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get()); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get());
if (!gf) { if (!gf) {
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
}
} }
return true; return true;
@ -912,10 +905,8 @@ int llama_context::decode(llama_batch & inp_batch) {
} }
} }
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1); llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : memory->seq_pos_max(0) + 1);
const llama_batch & batch = batch_allocr.batch; const llama_batch & batch = batch_allocr.batch;
@ -977,21 +968,21 @@ int llama_context::decode(llama_batch & inp_batch) {
// handle any pending defrags/shifts // handle any pending defrags/shifts
kv_self_update(false); kv_self_update(false);
llama_memory_state_ptr kv_state; llama_memory_state_ptr mstate;
while (true) { while (true) {
kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all); mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
if (!kv_state) { if (!mstate) {
return -2; return -2;
} }
switch (kv_state->get_status()) { switch (mstate->get_status()) {
case LLAMA_MEMORY_STATUS_SUCCESS: case LLAMA_MEMORY_STATUS_SUCCESS:
{ {
} break; } break;
case LLAMA_MEMORY_STATUS_NO_UPDATE: case LLAMA_MEMORY_STATUS_NO_UPDATE:
{ {
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status()); LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status());
return -2; return -2;
} }
@ -1031,7 +1022,7 @@ int llama_context::decode(llama_batch & inp_batch) {
int64_t n_outputs_prev = 0; int64_t n_outputs_prev = 0;
do { do {
const auto & ubatch = kv_state->get_ubatch(); const auto & ubatch = mstate->get_ubatch();
// count the outputs in this u_batch // count the outputs in this u_batch
{ {
@ -1054,7 +1045,7 @@ int llama_context::decode(llama_batch & inp_batch) {
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
ggml_status status; ggml_status status;
const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, kv_state.get(), status); const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status);
if (!res) { if (!res) {
// the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@ -1076,7 +1067,7 @@ int llama_context::decode(llama_batch & inp_batch) {
LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]);
llama_kv_self_seq_rm(this, s, pos_min[s], -1); memory->seq_rm(s, pos_min[s], -1);
} }
switch (status) { switch (status) {
@ -1170,7 +1161,7 @@ int llama_context::decode(llama_batch & inp_batch) {
} }
n_outputs_prev += n_outputs; n_outputs_prev += n_outputs;
} while (kv_state->next()); } while (mstate->next());
// set to total number of outputs in the batch, for use in llama_get_logits_ith // set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all; n_outputs = n_outputs_all;
@ -1179,7 +1170,7 @@ int llama_context::decode(llama_batch & inp_batch) {
{ {
bool sorted_output = true; bool sorted_output = true;
auto & out_ids = kv_state->out_ids(); auto & out_ids = mstate->out_ids();
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
@ -1847,11 +1838,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
} }
} }
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); if (memory != nullptr) {
if (kv_self != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
kv_self->state_write(io); memory->state_write(io);
} }
return io.n_bytes(); return io.n_bytes();
@ -1938,9 +1927,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
if (memory) { if (memory) {
LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); memory->state_read(io);
kv_self->state_read(io);
} }
return io.n_bytes(); return io.n_bytes();
@ -1950,9 +1937,7 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
GGML_UNUSED(seq_id); GGML_UNUSED(seq_id);
if (memory) { if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); memory->state_write(io, seq_id);
kv_self->state_write(io, seq_id);
} }
return io.n_bytes(); return io.n_bytes();
@ -1962,9 +1947,7 @@ size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq
GGML_UNUSED(seq_id); GGML_UNUSED(seq_id);
if (memory) { if (memory) {
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); memory->state_read(io, seq_id);
kv_self->state_read(io, seq_id);
} }
return io.n_bytes(); return io.n_bytes();
@ -2069,9 +2052,7 @@ void llama_context::opt_epoch_iter(
const uint32_t n_batch = std::min(this->n_batch(), n_ctx); const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get()); memory->clear();
kv_self->clear();
for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
batch.n_tokens = n_batch; batch.n_tokens = n_batch;
@ -2094,8 +2075,8 @@ void llama_context::opt_epoch_iter(
int64_t n_outputs_all = n_tokens_all; int64_t n_outputs_all = n_tokens_all;
auto kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true); auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
if (!kv_state || kv_state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break; break;
} }
@ -2108,17 +2089,17 @@ void llama_context::opt_epoch_iter(
uint32_t pos_batch = 0; uint32_t pos_batch = 0;
do { do {
const auto & ubatch = kv_state->get_ubatch(); const auto & ubatch = mstate->get_ubatch();
n_outputs = ubatch.n_tokens; n_outputs = ubatch.n_tokens;
if (!kv_state->apply()) { if (!mstate->apply()) {
LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__);
break; break;
} }
auto * gf = graph_init(); auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, kv_state.get()); auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get());
struct ggml_context * ctx_compute_opt; struct ggml_context * ctx_compute_opt;
{ {
@ -2153,7 +2134,7 @@ void llama_context::opt_epoch_iter(
ggml_free(ctx_compute_opt); ggml_free(ctx_compute_opt);
pos_batch += ubatch.n_tokens; pos_batch += ubatch.n_tokens;
} while (kv_state->next()); } while (mstate->next());
} }
} }
@ -2314,8 +2295,9 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model(); return &ctx->get_model();
} }
// deprecated
llama_kv_cache * llama_get_kv_self(llama_context * ctx) { llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
return ctx->get_kv_self(); return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
} }
// deprecated // deprecated
@ -2435,13 +2417,82 @@ int32_t llama_apply_adapter_cvec(
return res ? 0 : -1; return res ? 0 : -1;
} }
//
// memory
//
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
return ctx->get_memory();
}
void llama_memory_clear(llama_memory_t mem) {
mem->clear();
}
bool llama_memory_seq_rm(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return mem->seq_rm(seq_id, p0, p1);
}
void llama_memory_seq_cp(
llama_memory_t mem,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
mem->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
void llama_memory_seq_keep(
llama_memory_t mem,
llama_seq_id seq_id) {
mem->seq_keep(seq_id);
}
void llama_memory_seq_add(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
mem->seq_add(seq_id, p0, p1, delta);
}
void llama_memory_seq_div(
llama_memory_t mem,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
mem->seq_div(seq_id, p0, p1, d);
}
llama_pos llama_memory_seq_pos_min(
llama_memory_t mem,
llama_seq_id seq_id) {
return mem->seq_pos_min(seq_id);
}
llama_pos llama_memory_seq_pos_max(
llama_memory_t mem,
llama_seq_id seq_id) {
return mem->seq_pos_max(seq_id);
}
bool llama_memory_can_shift(llama_memory_t mem) {
return mem->get_can_shift();
}
// //
// kv cache // kv cache
// //
// deprecated // deprecated
int32_t llama_kv_self_n_tokens(const llama_context * ctx) { int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self(); const auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return 0; return 0;
} }
@ -2463,7 +2514,7 @@ int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
// deprecated // deprecated
// note: this is the same as above - will be removed anyway, so it's ok // note: this is the same as above - will be removed anyway, so it's ok
int32_t llama_kv_self_used_cells(const llama_context * ctx) { int32_t llama_kv_self_used_cells(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self(); const auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return 0; return 0;
} }
@ -2483,12 +2534,12 @@ int32_t llama_kv_self_used_cells(const llama_context * ctx) {
} }
void llama_kv_self_clear(llama_context * ctx) { void llama_kv_self_clear(llama_context * ctx) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->clear(); llama_memory_clear(kv);
} }
bool llama_kv_self_seq_rm( bool llama_kv_self_seq_rm(
@ -2496,12 +2547,12 @@ bool llama_kv_self_seq_rm(
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return true; return true;
} }
return kv->seq_rm(seq_id, p0, p1); return llama_memory_seq_rm(kv, seq_id, p0, p1);
} }
void llama_kv_self_seq_cp( void llama_kv_self_seq_cp(
@ -2510,21 +2561,21 @@ void llama_kv_self_seq_cp(
llama_seq_id seq_id_dst, llama_seq_id seq_id_dst,
llama_pos p0, llama_pos p0,
llama_pos p1) { llama_pos p1) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1);
} }
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_keep(seq_id); llama_memory_seq_keep(kv, seq_id);
} }
void llama_kv_self_seq_add( void llama_kv_self_seq_add(
@ -2533,12 +2584,12 @@ void llama_kv_self_seq_add(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta) { llama_pos delta) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_add(seq_id, p0, p1, delta); llama_memory_seq_add(kv, seq_id, p0, p1, delta);
} }
void llama_kv_self_seq_div( void llama_kv_self_seq_div(
@ -2547,30 +2598,30 @@ void llama_kv_self_seq_div(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d) { int d) {
auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return; return;
} }
kv->seq_div(seq_id, p0, p1, d); llama_memory_seq_div(kv, seq_id, p0, p1, d);
} }
llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return -1; return -1;
} }
return kv->seq_pos_min(seq_id); return llama_memory_seq_pos_min(kv, seq_id);
} }
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return -1; return -1;
} }
return kv->seq_pos_max(seq_id); return llama_memory_seq_pos_max(kv, seq_id);
} }
// deprecated // deprecated
@ -2580,12 +2631,12 @@ void llama_kv_self_defrag(llama_context * ctx) {
} }
bool llama_kv_self_can_shift(const llama_context * ctx) { bool llama_kv_self_can_shift(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self(); auto * kv = llama_get_memory(ctx);
if (!kv) { if (!kv) {
return false; return false;
} }
return kv->get_can_shift(); return llama_memory_can_shift(kv);
} }
// llama state API // llama state API

View File

@ -13,13 +13,12 @@
#include <vector> #include <vector>
struct llama_model; struct llama_model;
struct llama_kv_cache;
class llama_io_read_i; class llama_io_read_i;
class llama_io_write_i; class llama_io_write_i;
class llama_memory_i; struct llama_memory_i;
class llama_memory_state_i; struct llama_memory_state_i;
struct llama_context { struct llama_context {
// init scheduler and compute buffers, reserve worst-case graphs // init scheduler and compute buffers, reserve worst-case graphs
@ -47,8 +46,7 @@ struct llama_context {
uint32_t n_threads() const; uint32_t n_threads() const;
uint32_t n_threads_batch() const; uint32_t n_threads_batch() const;
llama_kv_cache * get_kv_self(); llama_memory_t get_memory() const;
const llama_kv_cache * get_kv_self() const;
// return true of the KV cache was updated // return true of the KV cache was updated
// TODO: remove // TODO: remove

View File

@ -17,7 +17,7 @@ struct ggml_tensor;
struct llama_ubatch; struct llama_ubatch;
struct llama_cparams; struct llama_cparams;
class llama_memory_state_i; struct llama_memory_state_i;
class llama_kv_cache_unified_state; class llama_kv_cache_unified_state;
class llama_kv_cache_unified_iswa_state; class llama_kv_cache_unified_iswa_state;

View File

@ -2,7 +2,7 @@
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-graph.h" #include "llama-graph.h"
#include "llama-kv-cache.h" #include "llama-memory.h"
#include <set> #include <set>
#include <vector> #include <vector>
@ -13,7 +13,7 @@
// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i // TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_memory_i {
public: public:
llama_kv_cache_recurrent( llama_kv_cache_recurrent(
const llama_model & model, const llama_model & model,
@ -29,6 +29,16 @@ public:
// llama_memory_i // llama_memory_i
// //
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
void clear() override; void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
@ -40,20 +50,6 @@ public:
llama_pos seq_pos_min(llama_seq_id seq_id) const override; llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override; llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) override;
llama_memory_state_ptr init_full() override;
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
bool prepare(const std::vector<llama_ubatch> & ubatches); bool prepare(const std::vector<llama_ubatch> & ubatches);
// find a contiguous slot of kv cells and emplace the ubatch there // find a contiguous slot of kv cells and emplace the ubatch there

View File

@ -11,7 +11,7 @@
// utilizes two instances of llama_kv_cache_unified // utilizes two instances of llama_kv_cache_unified
// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers // the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
class llama_kv_cache_unified_iswa : public llama_kv_cache { class llama_kv_cache_unified_iswa : public llama_memory_i {
public: public:
llama_kv_cache_unified_iswa( llama_kv_cache_unified_iswa(
const llama_model & model, const llama_model & model,
@ -31,21 +31,6 @@ public:
// llama_memory_i // llama_memory_i
// //
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch( llama_memory_state_ptr init_batch(
const llama_batch & batch, const llama_batch & batch,
uint32_t n_ubatch, uint32_t n_ubatch,
@ -58,6 +43,17 @@ public:
bool get_can_shift() const override; bool get_can_shift() const override;
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;

View File

@ -2,8 +2,8 @@
#include "llama-batch.h" #include "llama-batch.h"
#include "llama-graph.h" #include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-kv-cells.h" #include "llama-kv-cells.h"
#include "llama-memory.h"
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
@ -17,7 +17,7 @@ struct llama_context;
// llama_kv_cache_unified // llama_kv_cache_unified
// //
class llama_kv_cache_unified : public llama_kv_cache { class llama_kv_cache_unified : public llama_memory_i {
public: public:
static uint32_t get_padding(const llama_cparams & cparams); static uint32_t get_padding(const llama_cparams & cparams);
@ -56,21 +56,6 @@ public:
// llama_memory_i // llama_memory_i
// //
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
//
// llama_kv_cache
//
llama_memory_state_ptr init_batch( llama_memory_state_ptr init_batch(
const llama_batch & batch, const llama_batch & batch,
uint32_t n_ubatch, uint32_t n_ubatch,
@ -83,6 +68,17 @@ public:
bool get_can_shift() const override; bool get_can_shift() const override;
void clear() override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
// state write/load // state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;

View File

@ -1 +0,0 @@
#include "llama-kv-cache.h"

View File

@ -1,41 +0,0 @@
#pragma once
#include "llama.h"
#include "llama-memory.h"
class llama_io_write_i;
class llama_io_read_i;
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
// TODO: move the init_ interfaces to llama_memory_i
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters
virtual bool get_can_shift() const = 0;
bool get_can_edit() const override { return get_can_shift(); }
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};

View File

@ -7,6 +7,9 @@
struct llama_ubatch; struct llama_ubatch;
class llama_io_write_i;
class llama_io_read_i;
struct llama_memory_params { struct llama_memory_params {
// kv cache // kv cache
ggml_type type_k; ggml_type type_k;
@ -16,28 +19,6 @@ struct llama_memory_params {
bool swa_full; bool swa_full;
}; };
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
class llama_memory_i {
public:
virtual ~llama_memory_i() = default;
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
virtual bool get_can_edit() const = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
enum llama_memory_status { enum llama_memory_status {
LLAMA_MEMORY_STATUS_SUCCESS = 0, LLAMA_MEMORY_STATUS_SUCCESS = 0,
LLAMA_MEMORY_STATUS_NO_UPDATE, LLAMA_MEMORY_STATUS_NO_UPDATE,
@ -58,8 +39,7 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
// the only method that can mutate the memory and the memory state is llama_memory_i::apply() // the only method that can mutate the memory and the memory state is llama_memory_i::apply()
// //
// TODO: rename to llama_memory_context_i ? // TODO: rename to llama_memory_context_i ?
class llama_memory_state_i { struct llama_memory_state_i {
public:
virtual ~llama_memory_state_i() = default; virtual ~llama_memory_state_i() = default;
// consume the current ubatch from the state and proceed to the next one // consume the current ubatch from the state and proceed to the next one
@ -81,3 +61,57 @@ public:
}; };
using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>; using llama_memory_state_ptr = std::unique_ptr<llama_memory_state_i>;
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i {
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
// return a state object containing the ubatches and KV cache state required to process them
// check the llama_memory_state_i::get_status() for the result
virtual llama_memory_state_ptr init_batch(
const llama_batch & batch,
uint32_t n_ubatch,
bool embd_pooled,
bool logits_all) = 0;
// simulate full cache, used for allocating worst-case compute buffers
virtual llama_memory_state_ptr init_full() = 0;
// prepare for any pending memory updates, such as shifts, defrags, etc.
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
// getters
virtual bool get_can_shift() const = 0;
//
// ops
//
virtual void clear() = 0;
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
virtual void seq_keep(llama_seq_id seq_id) = 0;
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
//
// state write/read
//
virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0;
virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0;
};
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
// TODO: temporary until the llama_kv_cache is removed from the public API
struct llama_kv_cache : public llama_memory_i {
virtual ~llama_kv_cache() = default;
};