mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 12:25:03 +00:00
llama : refactor kv cache guard (#12695)
* llama : refactor kv cache guard ggml-ci * cont : fix comment [no ci] * llama : fix kv_cache restore logic ggml-ci * context : simplify kv cache updates ggml-ci * cont : better name [no ci] * llama : fix llama_decode return code when could not find KV slot ggml-ci * context : change log err -> warn [no ci] * kv-cache : add comment + warning
This commit is contained in:
@ -106,6 +106,8 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
common_params params;
|
common_params params;
|
||||||
|
|
||||||
|
params.n_predict = 128;
|
||||||
|
|
||||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
|
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
const int64_t n_tokens_all = batch.n_tokens;
|
const int64_t n_tokens_all = batch.n_tokens;
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
|
|
||||||
// TODO: remove this stuff
|
llama_kv_cache_guard kv_guard(kv_self.get());
|
||||||
class batch_guard {
|
|
||||||
public:
|
|
||||||
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
|
|
||||||
}
|
|
||||||
|
|
||||||
~batch_guard() {
|
|
||||||
if (!is_done) {
|
|
||||||
kv_slot_restorer.restore();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void done() {
|
|
||||||
is_done = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
void save(const llama_kv_cache_slot_info & slot_info) {
|
|
||||||
kv_slot_restorer.save(slot_info);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
bool is_done = false;
|
|
||||||
|
|
||||||
llama_kv_slot_restorer kv_slot_restorer;
|
|
||||||
};
|
|
||||||
|
|
||||||
batch_guard bg(*kv_self);
|
|
||||||
|
|
||||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||||
|
|
||||||
@ -1280,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
return -2;
|
return -2;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// handle any pending defrags/shifts
|
||||||
|
kv_self_update();
|
||||||
|
|
||||||
int64_t n_outputs_prev = 0;
|
int64_t n_outputs_prev = 0;
|
||||||
|
|
||||||
while (sbatch.n_tokens > 0) {
|
while (sbatch.n_tokens > 0) {
|
||||||
@ -1319,22 +1296,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
|
|
||||||
// find KV slot
|
// find KV slot
|
||||||
{
|
{
|
||||||
kv_self_update();
|
if (!kv_self->find_slot(ubatch)) {
|
||||||
|
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
||||||
|
|
||||||
// if we have enough unused cells before the current head ->
|
return 1;
|
||||||
// better to start searching from the beginning of the cache, hoping to fill it
|
|
||||||
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
|
|
||||||
kv_self->head = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto slot_info = kv_self->find_slot(ubatch);
|
|
||||||
if (!slot_info) {
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
|
|
||||||
return -3;
|
|
||||||
}
|
|
||||||
|
|
||||||
bg.save(slot_info);
|
|
||||||
|
|
||||||
if (!kv_self->recurrent) {
|
if (!kv_self->recurrent) {
|
||||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||||
// after enough generations, the benefit from this heuristic disappears
|
// after enough generations, the benefit from this heuristic disappears
|
||||||
@ -1371,16 +1338,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// update the kv ring buffer
|
|
||||||
{
|
|
||||||
kv_self->head += ubatch.n_tokens;
|
|
||||||
|
|
||||||
// Ensure kv cache head points to a valid index.
|
|
||||||
if (kv_self->head >= kv_self->size) {
|
|
||||||
kv_self->head = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// plot the computation graph in dot format (for debugging purposes)
|
// plot the computation graph in dot format (for debugging purposes)
|
||||||
//if (n_past%100 == 0) {
|
//if (n_past%100 == 0) {
|
||||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||||
@ -1467,7 +1424,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// finalize the batch processing
|
// finalize the batch processing
|
||||||
bg.done();
|
kv_guard.commit();
|
||||||
|
|
||||||
// set output mappings
|
// set output mappings
|
||||||
{
|
{
|
||||||
|
@ -11,8 +11,6 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
|
||||||
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
|
|
||||||
|
|
||||||
llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
|
llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -206,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
@ -446,16 +446,66 @@ void llama_kv_cache_unified::defrag() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified::restore() {
|
||||||
|
if (pending.ranges.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: tmp - move to llama_kv_cache_recurrent
|
||||||
|
if (recurrent) {
|
||||||
|
seq_rm(-1, -1, -1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t new_head = size;
|
||||||
|
|
||||||
|
for (auto & range : pending.ranges) {
|
||||||
|
for (uint32_t i = range.c0; i < range.c1; ++i) {
|
||||||
|
cells[i].seq_id.clear();
|
||||||
|
|
||||||
|
// keep count of the number of used cells
|
||||||
|
if (cells[i].pos >= 0) {
|
||||||
|
used--;
|
||||||
|
}
|
||||||
|
|
||||||
|
cells[i].pos = -1;
|
||||||
|
cells[i].src = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
new_head = std::min(new_head, range.c0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (new_head != size && new_head < head) {
|
||||||
|
head = new_head;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_unified::commit() {
|
||||||
|
if (pending.ranges.empty()) {
|
||||||
|
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
|
||||||
|
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
pending.ranges.clear();
|
||||||
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified::get_can_shift() const {
|
bool llama_kv_cache_unified::get_can_shift() const {
|
||||||
return can_shift;
|
return can_shift;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
|
bool llama_kv_cache_unified::find_slot(
|
||||||
const llama_ubatch & ubatch) {
|
const llama_ubatch & ubatch) {
|
||||||
const uint32_t n_tokens = ubatch.n_tokens;
|
const uint32_t n_tokens = ubatch.n_tokens;
|
||||||
const uint32_t n_seqs = ubatch.n_seqs;
|
const uint32_t n_seqs = ubatch.n_seqs;
|
||||||
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||||
|
|
||||||
|
// if we have enough unused cells before the current head ->
|
||||||
|
// better to start searching from the beginning of the cache, hoping to fill it
|
||||||
|
if (head > used + 2*ubatch.n_tokens) {
|
||||||
|
head = 0;
|
||||||
|
}
|
||||||
|
|
||||||
if (recurrent) {
|
if (recurrent) {
|
||||||
// For recurrent state architectures (like Mamba or RWKV),
|
// For recurrent state architectures (like Mamba or RWKV),
|
||||||
// each cache cell can store the state for a whole sequence.
|
// each cache cell can store the state for a whole sequence.
|
||||||
@ -477,7 +527,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
|
|||||||
// too big seq_id
|
// too big seq_id
|
||||||
// TODO: would it be possible to resize the cache instead?
|
// TODO: would it be possible to resize the cache instead?
|
||||||
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
|
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
|
||||||
return llama_kv_cache_slot_info_failed;
|
return false;
|
||||||
}
|
}
|
||||||
if (j > 0) {
|
if (j > 0) {
|
||||||
llama_kv_cell & seq = cells[seq_id];
|
llama_kv_cell & seq = cells[seq_id];
|
||||||
@ -616,14 +666,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
|
|||||||
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
|
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
|
||||||
|
|
||||||
// sanity check
|
// sanity check
|
||||||
return llama_kv_cache_slot_info(n >= n_seqs);
|
return n >= n_seqs;
|
||||||
}
|
}
|
||||||
|
|
||||||
// otherwise, one cell per token.
|
// otherwise, one cell per token.
|
||||||
|
|
||||||
if (n_tokens > size) {
|
if (n_tokens > size) {
|
||||||
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
|
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
|
||||||
return llama_kv_cache_slot_info_failed;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_tested = 0;
|
uint32_t n_tested = 0;
|
||||||
@ -651,7 +701,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
|
|||||||
|
|
||||||
if (n_tested >= size) {
|
if (n_tested >= size) {
|
||||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||||
return llama_kv_cache_slot_info_failed;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -668,7 +718,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
|
|||||||
|
|
||||||
used += n_tokens;
|
used += n_tokens;
|
||||||
|
|
||||||
return llama_kv_cache_slot_info(head, head + n_tokens);
|
pending.ranges.push_back({head, head + n_tokens});
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
|
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
|
||||||
@ -1033,6 +1085,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
commit();
|
||||||
|
|
||||||
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||||
// Assume that this is one contiguous block of cells
|
// Assume that this is one contiguous block of cells
|
||||||
|
@ -17,6 +17,9 @@ struct llama_ubatch;
|
|||||||
struct llama_kv_cache : public llama_memory_i {
|
struct llama_kv_cache : public llama_memory_i {
|
||||||
using llama_memory_i::llama_memory_i;
|
using llama_memory_i::llama_memory_i;
|
||||||
|
|
||||||
|
virtual void restore() = 0; // call if batch processing fails - restores the cache state
|
||||||
|
virtual void commit() = 0; // call after successful batch processing - clears any pending state
|
||||||
|
|
||||||
virtual int32_t get_n_tokens() const = 0;
|
virtual int32_t get_n_tokens() const = 0;
|
||||||
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
|
||||||
|
|
||||||
@ -25,9 +28,24 @@ struct llama_kv_cache : public llama_memory_i {
|
|||||||
bool get_can_edit() const override { return get_can_shift(); }
|
bool get_can_edit() const override { return get_can_shift(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct llama_kv_cache_guard {
|
||||||
|
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
|
||||||
|
|
||||||
|
~llama_kv_cache_guard() {
|
||||||
|
kv->restore();
|
||||||
|
}
|
||||||
|
|
||||||
|
void commit() {
|
||||||
|
kv->commit();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
llama_kv_cache * kv;
|
||||||
|
};
|
||||||
|
|
||||||
struct llama_kv_cell {
|
struct llama_kv_cell {
|
||||||
llama_pos pos = -1;
|
llama_pos pos = -1;
|
||||||
llama_pos delta = 0;
|
llama_pos delta = 0;
|
||||||
int32_t src = -1; // used by recurrent state models to copy states
|
int32_t src = -1; // used by recurrent state models to copy states
|
||||||
int32_t tail = -1;
|
int32_t tail = -1;
|
||||||
|
|
||||||
@ -46,17 +64,6 @@ struct llama_kv_cell {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// a structure holds information about the slot found in llama_kv_cache_find_slot
|
|
||||||
struct llama_kv_cache_slot_info {
|
|
||||||
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
|
|
||||||
bool found = false; // the slot was found
|
|
||||||
|
|
||||||
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
|
|
||||||
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
|
|
||||||
|
|
||||||
operator bool() const { return found; }
|
|
||||||
};
|
|
||||||
|
|
||||||
// ring-buffer of cached KV data
|
// ring-buffer of cached KV data
|
||||||
// TODO: pimpl
|
// TODO: pimpl
|
||||||
// TODO: add notion of max sequences
|
// TODO: add notion of max sequences
|
||||||
@ -93,6 +100,9 @@ public:
|
|||||||
void clear() override;
|
void clear() override;
|
||||||
void defrag() override;
|
void defrag() override;
|
||||||
|
|
||||||
|
virtual void restore() override;
|
||||||
|
virtual void commit() override;
|
||||||
|
|
||||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_keep(llama_seq_id seq_id) override;
|
void seq_keep(llama_seq_id seq_id) override;
|
||||||
@ -105,10 +115,9 @@ public:
|
|||||||
|
|
||||||
// find an empty slot of size "n_tokens" in the cache
|
// find an empty slot of size "n_tokens" in the cache
|
||||||
// updates the cache head
|
// updates the cache head
|
||||||
// returns a structure holding information about the slot found
|
|
||||||
// Note: On success, it's important that cache.head points
|
// Note: On success, it's important that cache.head points
|
||||||
// to the first cell of the slot.
|
// to the first cell of the slot.
|
||||||
llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
|
bool find_slot(const llama_ubatch & batch);
|
||||||
|
|
||||||
// TODO: maybe not needed
|
// TODO: maybe not needed
|
||||||
uint32_t get_padding(const llama_cparams & cparams) const;
|
uint32_t get_padding(const llama_cparams & cparams) const;
|
||||||
@ -128,7 +137,19 @@ public:
|
|||||||
// return true if cells have been moved
|
// return true if cells have been moved
|
||||||
bool defrag_prepare(int32_t n_max_nodes);
|
bool defrag_prepare(int32_t n_max_nodes);
|
||||||
|
|
||||||
// state save/load
|
// commit/restore cache
|
||||||
|
|
||||||
|
struct slot_range {
|
||||||
|
uint32_t c0 = 0; // note: these are cell indices, not sequence positions
|
||||||
|
uint32_t c1 = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
// pending cell updates that are not yet committed
|
||||||
|
struct {
|
||||||
|
std::vector<slot_range> ranges;
|
||||||
|
} pending;
|
||||||
|
|
||||||
|
// state write/load
|
||||||
|
|
||||||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
|
||||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
|
||||||
@ -183,59 +204,6 @@ private:
|
|||||||
// using llama_kv_cache_unified::llama_kv_cache_unified;
|
// using llama_kv_cache_unified::llama_kv_cache_unified;
|
||||||
//};
|
//};
|
||||||
|
|
||||||
//
|
|
||||||
// kv cache restore
|
|
||||||
//
|
|
||||||
|
|
||||||
// saves the kv_cache state for future recovery.
|
|
||||||
// used to rollback llama_kv_cache_find_slot changes.
|
|
||||||
struct llama_kv_slot_restorer {
|
|
||||||
struct llama_kv_cache_state {
|
|
||||||
uint32_t head = 0;
|
|
||||||
uint32_t n = 0;
|
|
||||||
} old_state;
|
|
||||||
|
|
||||||
// for non-recurrent models only
|
|
||||||
// list of slots to restore
|
|
||||||
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
|
|
||||||
|
|
||||||
bool do_restore = false;
|
|
||||||
|
|
||||||
llama_kv_cache_unified & cache;
|
|
||||||
|
|
||||||
explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
|
|
||||||
old_state.head = cache.head;
|
|
||||||
old_state.n = cache.n;
|
|
||||||
}
|
|
||||||
|
|
||||||
// saves a slot information for future restoration
|
|
||||||
void save(const llama_kv_cache_slot_info & slot) {
|
|
||||||
if (slot) {
|
|
||||||
do_restore = true;
|
|
||||||
if (slot.boundaries.first != slot.boundaries.second) {
|
|
||||||
slot_boundaries.push_back(slot.boundaries);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// must be explicitly called to restore the kv_cache state
|
|
||||||
// and rollback changes from all llama_kv_cache_find_slot calls
|
|
||||||
void restore() {
|
|
||||||
if (do_restore) {
|
|
||||||
cache.head = old_state.head;
|
|
||||||
cache.n = old_state.n;
|
|
||||||
|
|
||||||
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
|
|
||||||
cache.seq_rm(-1, -1, -1);
|
|
||||||
} else {
|
|
||||||
for (auto & slot : slot_boundaries) {
|
|
||||||
cache.seq_rm(-1, slot.first, slot.second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: maybe become part of the public llama_kv_cache in the future
|
// TODO: maybe become part of the public llama_kv_cache in the future
|
||||||
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
|
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user