kv-cache : rework kv_cell (#13706)

* kv-cache : rework kv_cell

ggml-ci

* kv-cells : use "shift" instead of "delta" consistently

ggml-ci

* llama : add llama_max_parallel_sequences()

ggml-ci

* kv-cells : update comments [no ci]

* context : fail upon construction if sequences exceed max value

ggml-ci

* kv-cells : get_pos() -> pos_get() + comments

ggml-ci

* kv-cells : fix tracking of "used" cells

ggml-ci
Author:    Georgi Gerganov
Date:      2025-05-25 16:34:36 +03:00
Committer: GitHub
Parent:    c508256db2
Commit:    de2ef53a4b

8 changed files with 470 additions and 253 deletions

include/llama.h

@@ -471,6 +471,7 @@ extern "C" {
     LLAMA_API int64_t llama_time_us(void);
 
     LLAMA_API size_t llama_max_devices(void);
+    LLAMA_API size_t llama_max_parallel_sequences(void);
 
     LLAMA_API bool llama_supports_mmap (void);
     LLAMA_API bool llama_supports_mlock (void);

src/llama-context.cpp

@@ -25,7 +25,11 @@ llama_context::llama_context(
     const auto & hparams = model.hparams;
 
     cparams.n_seq_max = std::max(1u, params.n_seq_max);
+    if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
+        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
+    }
+
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
     cparams.yarn_ext_factor = params.yarn_ext_factor;

src/llama-cparams.cpp

@@ -1 +1,5 @@
 #include "llama-cparams.h"
+
+size_t llama_max_parallel_sequences(void) {
+    return LLAMA_MAX_PARALLEL_SEQUENCES;
+}

src/llama-cparams.h

@@ -4,6 +4,8 @@
 #include <cstdint>
 
+#define LLAMA_MAX_PARALLEL_SEQUENCES 64
+
 struct llama_cparams {
     uint32_t n_ctx;   // context size used during inference
     uint32_t n_batch;
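Note: taken together, the three additions above let an application query the sequence limit at runtime instead of hard-coding 64, and sidestep the new constructor throw by clamping up front. A minimal caller-side sketch (not part of this commit; the requested value and the clamping policy are illustrative):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    #include "llama.h"

    int main() {
        // requested number of parallel sequences, e.g. from a CLI flag (illustrative)
        const uint32_t n_seq_req = 128;

        // query the hard limit exposed by the new API
        const size_t n_seq_lim = llama_max_parallel_sequences();

        // clamp instead of letting the llama_context constructor throw
        llama_context_params cparams = llama_context_default_params();
        cparams.n_seq_max = std::min<uint32_t>(n_seq_req, (uint32_t) n_seq_lim);

        printf("using n_seq_max = %u (limit = %zu)\n", cparams.n_seq_max, n_seq_lim);
        return 0;
    }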

src/llama-kv-cache.cpp

@@ -65,8 +65,6 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     };
 
     head = 0;
-    size = kv_size;
-    used = 0;
 
     cells.resize(kv_size);
@@ -138,13 +136,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 }
 
 void llama_kv_cache_unified::clear() {
-    for (uint32_t i = 0; i < size; ++i) {
-        cells[i].pos = -1;
-        cells[i].seq_id.clear();
-    }
+    cells.reset();
 
     head = 0;
-    used = 0;
 
     for (auto & buf : bufs) {
         ggml_backend_buffer_clear(buf.get(), 0);
@@ -152,7 +146,7 @@ void llama_kv_cache_unified::clear() {
 }
 
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
     if (p0 < 0) {
         p0 = 0;
@@ -162,33 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].pos >= p0 && cells[i].pos < p1) {
-            if (seq_id < 0) {
-                cells[i].seq_id.clear();
-            } else if (cells[i].has_seq_id(seq_id)) {
-                cells[i].seq_id.erase(seq_id);
-            } else {
-                continue;
-            }
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
 
-            if (cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cells[i].pos >= 0) {
-                    used--;
-                }
-
-                cells[i].pos = -1;
-
-                if (new_head == size) {
-                    new_head = i;
-                }
+        if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
+            if (new_head == cells.size()) {
+                new_head = i;
             }
         }
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != size && new_head < head) {
+    if (new_head != cells.size() && new_head < head) {
         head = new_head;
     }
@@ -208,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // otherwise, this is the KV of a Transformer-like model
-    head = 0;
-
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            cells[i].seq_id.insert(seq_id_dst);
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
+
+        if (cells.seq_has(i, seq_id_src)) {
+            cells.seq_add(i, seq_id_dst);
         }
     }
 }
 
 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (!cells[i].has_seq_id(seq_id)) {
-            if (cells[i].pos >= 0) {
-                used--;
-            }
-
-            cells[i].pos = -1;
-            cells[i].seq_id.clear();
-
-            if (new_head == size){
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_keep(i, seq_id)) {
+            if (new_head == cells.size()) {
                 new_head = i;
             }
-        } else {
-            cells[i].seq_id.clear();
-            cells[i].seq_id.insert(seq_id);
         }
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != size && new_head < head) {
+    if (new_head != cells.size() && new_head < head) {
        head = new_head;
     }
 }
 
-void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
+void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    if (shift == 0) {
         return;
     }
 
-    uint32_t new_head = size;
+    uint32_t new_head = cells.size();
 
     if (p0 < 0) {
         p0 = 0;
@@ -260,25 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
         p1 = std::numeric_limits<llama_pos>::max();
     }
 
-    // If there is no range then return early to avoid looping over the
+    // If there is no range then return early to avoid looping over all cells.
     if (p0 == p1) {
         return;
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            has_shift = true;
-
-            cells[i].pos   += delta;
-            cells[i].delta += delta;
-
-            if (cells[i].pos < 0) {
-                if (!cells[i].is_empty()) {
-                    used--;
-                }
-                cells[i].pos = -1;
-                cells[i].seq_id.clear();
-
-                if (new_head == size) {
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
+
+        if (cells.seq_has(i, seq_id)) {
+            if (cells.pos_add(i, shift)) {
+                if (new_head == cells.size()) {
                     new_head = i;
                 }
             }
@@ -287,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 
     // If we freed up a slot, set head to it so searching can start there.
     // Otherwise we just start the next search from the beginning.
-    head = new_head != size ? new_head : 0;
+    head = new_head != cells.size() ? new_head : 0;
 }
 
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
@@ -308,15 +274,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
         return;
     }
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
-            has_shift = true;
-
-            {
-                llama_pos p_old = cells[i].pos;
-                cells[i].pos   /= d;
-                cells[i].delta += cells[i].pos - p_old;
-            }
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.pos_in(i, p0, p1)) {
+            continue;
+        }
+
+        if (cells.seq_has(i, seq_id)) {
+            cells.pos_div(i, d);
         }
     }
 }
@@ -324,9 +288,9 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
     llama_pos result = std::numeric_limits<llama_pos>::max();
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id)) {
-            result = std::min(result, cells[i].pos);
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_has(i, seq_id)) {
+            result = std::min(result, cells.pos_get(i));
         }
     }
 
@@ -340,9 +304,9 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     llama_pos result = -1;
 
-    for (uint32_t i = 0; i < size; ++i) {
-        if (cells[i].has_seq_id(seq_id)) {
-            result = std::max(result, cells[i].pos);
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (cells.seq_has(i, seq_id)) {
+            result = std::max(result, cells.pos_get(i));
         }
     }
@@ -350,25 +314,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 void llama_kv_cache_unified::restore() {
-    for (const auto & [id, cell] : recovery.cells) {
-        // TODO: move to new `struct kv_cells`
-        const bool is_empty0 = cells[id].is_empty();
-        const bool is_empty1 = cell.is_empty();
-
-        if (!is_empty0 && is_empty1) {
-            used--;
-        } else if (is_empty0 && !is_empty1) {
-            used++;
-        }
-
-        cells[id] = cell;
+    for (auto & state : recovery.states) {
+        cells.set(state.i, state.cells);
     }
 
     recovery.clear();
 }
 
 void llama_kv_cache_unified::commit() {
-    if (recovery.cells.empty()) {
+    if (recovery.states.empty()) {
         LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
             __func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
         return;
@@ -382,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
     auto * sched = lctx.get_sched();
 
-    if (has_shift) {
+    if (cells.get_has_shift()) {
         if (!get_can_shift()) {
             GGML_ABORT("The current KV cache / model configuration does not support K-shift");
         }
@@ -406,13 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
             need_reserve = true;
         }
 
-        {
-            has_shift = false;
-
-            for (uint32_t i = 0; i < size; ++i) {
-                cells[i].delta = 0;
-            }
-        }
+        cells.reset_shift();
     }
 
     if (do_defrag) {
void llama_kv_cache_unified::defrag_sched(float thold) { void llama_kv_cache_unified::defrag_sched(float thold) {
// - do not defrag small contexts (i.e. < 2048 tokens) // - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens // - count the padding towards the number of used tokens
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f; const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update // queue defragmentation for next llama_kv_cache_update
if (fragmentation > thold) { if (fragmentation > thold) {
@ -454,7 +402,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
} }
void llama_kv_cache_unified::set_full() { void llama_kv_cache_unified::set_full() {
n = size; n = cells.size();
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not // when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views. // affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
@@ -478,14 +426,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
 
     // if we have enough unused cells before the current head ->
     //   better to start searching from the beginning of the cache, hoping to fill it
-    if (head > used + 2*ubatch.n_tokens) {
+    if (head > cells.get_used() + 2*ubatch.n_tokens) {
         head = 0;
     }
 
     // otherwise, one cell per token.
 
-    if (n_tokens > size) {
-        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
+    if (n_tokens > cells.size()) {
+        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
         return false;
     }
@@ -498,10 +446,10 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     std::string ss;
     if (n_swa > 0) {
         for (uint32_t i = 0; i < size; ++i) {
-            if (cells[i].pos == -1) {
+            if (cells.is_empty(i)) {
                 ss += '.';
             } else {
-                ss += std::to_string(*cells[i].seq_id.begin());
+                ss += 'x';
             }
             if (i%256 == 255) {
                 ss += '\n';
@@ -515,15 +463,16 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     uint32_t n_tested = 0;
 
     while (true) {
-        if (head + n_tokens > size) {
-            n_tested += size - head;
+        if (head + n_tokens > cells.size()) {
+            n_tested += cells.size() - head;
             head = 0;
             continue;
         }
 
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            if (cells[head + i].pos >= 0) {
+            // TODO: improve to accept cells that are masked by the SWA
+            if (!cells.is_empty(head + i)) {
                 found = false;
                 head     += i + 1;
                 n_tested += i + 1;
@@ -535,31 +484,27 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
             break;
         }
 
-        if (n_tested >= size) {
+        if (n_tested >= cells.size()) {
             //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
             return false;
         }
     }
 
-    for (uint32_t i = 0; i < n_tokens; ++i) {
-        // remember the original state
-        if (recovery.cells.find(head + i) == recovery.cells.end()) {
-            recovery.cells[head + i] = cells[head + i];
-        }
+    // store the old state of the cells in the recovery stack
+    recovery.states.push_back({head, cells.cp(head, n_tokens)});
 
-        cells[head + i].pos = ubatch.pos[i];
+    for (uint32_t i = 0; i < n_tokens; ++i) {
+        cells.pos_set(head + i, ubatch.pos[i]);
 
         for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells[head + i].seq_id.insert(ubatch.seq_id[i][j]);
+            cells.seq_add(head + i, ubatch.seq_id[i][j]);
         }
     }
 
-    used += n_tokens;
-
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
+    n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
 
 #ifdef FIND_SLOT_DEBUG
     LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@@ -577,7 +522,7 @@ uint32_t llama_kv_cache_unified::get_n() const {
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
-    return size;
+    return cells.size();
 }
 
 ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
@@ -661,30 +606,19 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam
     int n_attended = 0;
 
-    for (uint32_t i = 0; i < size; ++i) {
-        const llama_pos p0 = cells[i].pos;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.seq_has(i, seq_id)) {
+            continue;
+        }
+
+        const llama_pos p0 = cells.pos_get(i);
 
         if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
             n_attended++;
         }
 
         if (is_masked_swa(p0, pmax)) {
-            if (seq_id < 0) {
-                cells[i].seq_id.clear();
-            } else if (cells[i].has_seq_id(seq_id)) {
-                cells[i].seq_id.erase(seq_id);
-            } else {
-                continue;
-            }
-
-            if (cells[i].is_empty()) {
-                // keep count of the number of used cells
-                if (cells[i].pos >= 0) {
-                    used--;
-                }
-
-                cells[i].pos = -1;
-            }
+            cells.seq_rm(i, seq_id);
         }
     }
@@ -723,25 +657,31 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
             const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
 
             for (int i = 0; i < n_kv; ++i) {
-                const llama_pos p0 = cells[i].pos;
+                float f = 0.0f;
 
                 bool masked = false;
 
-                // mask the token if not the same sequence
-                masked = masked || (!cells[i].has_seq_id(seq_id));
+                if (cells.is_empty(i)) {
+                    masked = true;
+                } else {
+                    const llama_pos p0 = cells.pos_get(i);
 
-                // mask future tokens
-                masked = masked || (causal_attn && p0 > p1);
+                    // mask the token if not the same sequence
+                    masked = masked || (!cells.seq_has(i, seq_id));
 
-                // apply SWA if any
-                masked = masked || (is_masked_swa(p0, p1));
+                    // mask future tokens
+                    masked = masked || (causal_attn && p0 > p1);
 
-                float f = 0.0f;
+                    // apply SWA if any
+                    masked = masked || (is_masked_swa(p0, p1));
+
+                    if (!masked && hparams.use_alibi) {
+                        f = -std::abs(p0 - p1);
+                    }
+                }
 
                 if (masked) {
                     f = -INFINITY;
-                } else if (hparams.use_alibi) {
-                    f = -std::abs(p0 - p1);
                 }
 
                 data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
@@ -765,8 +705,8 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
     int32_t * data = (int32_t *) dst->data;
 
-    for (uint32_t i = 0; i < size; ++i) {
-        data[i] = cells[i].delta;
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
     }
 }
@@ -783,7 +723,10 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     for (int h = 0; h < 1; ++h) {
         for (int j = 0; j < n_tokens; ++j) {
             for (int i = 0; i < n_kv; ++i) {
-                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+                // the position when the cells is empty is irrelevant - it will be masked out later in the attention
+                const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
+
+                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
             }
         }
     }
@@ -910,7 +853,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, size,
+                n_embd_head_k, n_head_kv, cells.size(),
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
                 0);
@@ -1050,12 +993,12 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
         } else {
             view_v_src = ggml_view_2d(ctx, layer.v,
                     nm, n_embd_v_gqa,
-                    ggml_row_size(layer.v->type, size),
+                    ggml_row_size(layer.v->type, cells.size()),
                     ggml_row_size(layer.v->type, i));
 
             view_v_dst = ggml_view_2d(ctx, layer.v,
                     nm, n_embd_v_gqa,
-                    ggml_row_size(layer.v->type, size),
+                    ggml_row_size(layer.v->type, cells.size()),
                     ggml_row_size(layer.v->type, id));
         }
@@ -1076,7 +1019,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     const uint32_t n_layer = layers.size();
 
     const uint32_t n_kv   = cell_max();
-    const uint32_t n_used = used;
+    const uint32_t n_used = cells.get_used();
 
     assert(n_used <= n_kv);
@@ -1104,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
     ids.resize(n_kv, n_kv);
 
     for (uint32_t i0 = 0; i0 < n_used; ++i0) {
-        const auto & cell0 = cells[i0];
-
-        if (!cell0.is_empty()) {
+        if (!cells.is_empty(i0)) {
             ids[i0] = i0;
 
             continue;
@@ -1117,7 +1058,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         uint32_t nh = 1;
 
         // determine the size of the hole
-        while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
+        while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
             nh++;
         }
@@ -1126,9 +1067,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         // starting from the end, find nh non-empty cells
         for (; is > i0; --is) {
-            const auto & cell1 = cells[is];
-
-            if (cell1.is_empty() || ids[is] != n_kv) {
+            if (cells.is_empty(is) || ids[is] != n_kv) {
                 continue;
             }
@@ -1155,9 +1094,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
         // go back and move the nf cells to the hole
         for (; i1 < n_kv; ++i1) {
-            auto & cell1 = cells[i1];
-
-            if (cell1.is_empty() || ids[i1] != n_kv) {
+            if (cells.is_empty(i1) || ids[i1] != n_kv) {
                 if (n_moves == max_moves) {
                     stop = true;
                     break;
@@ -1171,10 +1108,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
             ids[i1] = i0 + nf;
 
             // move the cell meta data
-            cells[i0 + nf] = cell1;
-
-            // clear the old cell and move the head there
-            cell1 = kv_cell();
+            cells.mv(i1, i0 + nf);
 
             head = n_used;
 
             if (!cont) {
@@ -1210,10 +1145,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
 }
 
 uint32_t llama_kv_cache_unified::cell_max() const {
-    for (uint32_t i = size; i > 0; --i) {
-        const kv_cell & cell = cells[i - 1];
-
-        if (cell.pos >= 0 && !cell.is_empty()) {
+    for (uint32_t i = cells.size(); i > 0; --i) {
+        if (!cells.is_empty(i - 1)) {
            return i;
         }
     }
@@ -1222,9 +1155,7 @@ uint32_t llama_kv_cache_unified::cell_max() const {
 }
 
 bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    if (p0 < 0) {
-        return true;
-    }
+    assert(p0 >= 0 && p1 >= 0);
 
     switch (swa_type) {
         case LLAMA_SWA_TYPE_NONE:
@@ -1255,23 +1186,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
     // Count the number of cells with the specified seq_id
     // Find all the ranges of cells with this seq id (or all, when -1)
-    uint32_t cell_range_begin = size;
-    for (uint32_t i = 0; i < size; ++i) {
-        const auto & cell = cells[i];
-        if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+    uint32_t cell_range_begin = cells.size();
+
+    for (uint32_t i = 0; i < cells.size(); ++i) {
+        if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
             ++cell_count;
-            if (cell_range_begin == size) {
+            if (cell_range_begin == cells.size()) {
                 cell_range_begin = i;
             }
         } else {
-            if (cell_range_begin != size) {
+            if (cell_range_begin != cells.size()) {
                 cell_ranges.emplace_back(cell_range_begin, i);
-                cell_range_begin = size;
+                cell_range_begin = cells.size();
             }
         }
     }
-    if (cell_range_begin != size) {
-        cell_ranges.emplace_back(cell_range_begin, size);
+
+    if (cell_range_begin != cells.size()) {
+        cell_ranges.emplace_back(cell_range_begin, cells.size());
     }
 
     // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
@@ -1308,17 +1240,24 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
     for (const auto & range : cell_ranges) {
         for (uint32_t i = range.first; i < range.second; ++i) {
-            const auto & cell = cells[i];
-            const llama_pos pos      = cell.pos;
-            const uint32_t  n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
+            std::vector<llama_seq_id> seq_ids;
+
+            for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
+                if (cur == seq_id || seq_id == -1) {
+                    if (cells.seq_has(i, cur)) {
+                        seq_ids.push_back(cur);
+                    }
+                }
+            }
+
+            const llama_pos pos      = cells.pos_get(i);
+            const uint32_t  n_seq_id = seq_ids.size();
 
             io.write(&pos,      sizeof(pos));
             io.write(&n_seq_id, sizeof(n_seq_id));
 
-            if (n_seq_id) {
-                for (auto seq_id : cell.seq_id) {
-                    io.write(&seq_id, sizeof(seq_id));
-                }
+            for (const auto & seq_id : seq_ids) {
+                io.write(&seq_id, sizeof(seq_id));
             }
         }
     }
@@ -1379,7 +1318,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
         }
     } else {
         // When v is transposed, we also need the element size and get the element ranges from each row
-        const uint32_t kv_size = size;
+        const uint32_t kv_size = cells.size();
 
         for (const auto & layer : layers) {
             const uint32_t il = layer.il;
@@ -1429,14 +1368,20 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             io.read_to(&pos,      sizeof(pos));
             io.read_to(&n_seq_id, sizeof(n_seq_id));
 
-            if (n_seq_id != 0) {
+            if (n_seq_id != 1) {
                 LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
                 return false;
             }
 
-            batch.pos[i] = pos;
-            batch.n_seq_id[i] = 1;
-            batch.seq_id[i] = &dest_seq_id;
+            // read the sequence id, but directly discard it - we will use dest_seq_id instead
+            {
+                llama_seq_id seq_id;
+                io.read_to(&seq_id, sizeof(seq_id));
+            }
+
+            batch.pos[i]      = pos;
+            batch.n_seq_id[i] = n_seq_id;
+            batch.seq_id[i]   = &dest_seq_id;
         }
 
         if (!find_slot(batch)) {
@@ -1448,15 +1393,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 
         // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head + cell_count <= size);
-        GGML_ASSERT(cells[head].pos == batch.pos[0]);
-        GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-        GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
-        GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
+        GGML_ASSERT(head + cell_count <= cells.size());
+        GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.seq_has(head, dest_seq_id));
+        GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
     } else {
         // whole KV cache restore
 
-        if (cell_count > size) {
+        if (cell_count > cells.size()) {
             LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
             return false;
         }
@@ -1464,15 +1409,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         clear();
 
         for (uint32_t i = 0; i < cell_count; ++i) {
-            kv_cell & cell = cells[i];
-
             llama_pos pos;
             uint32_t  n_seq_id;
 
             io.read_to(&pos,      sizeof(pos));
             io.read_to(&n_seq_id, sizeof(n_seq_id));
 
-            cell.pos = pos;
+            cells.pos_set(i, pos);
 
             for (uint32_t j = 0; j < n_seq_id; ++j) {
                 llama_seq_id seq_id;
@@ -1483,12 +1426,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                     return false;
                 }
 
-                cell.seq_id.insert(seq_id);
+                cells.seq_add(i, seq_id);
             }
         }
 
         head = 0;
-        used = cell_count;
     }
 
     return true;
@@ -1505,8 +1447,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
         return false;
     }
-    if (cell_count > size) {
-        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
+    if (cell_count > cells.size()) {
+        LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
         return false;
     }
     if (this->v_trans != (bool) v_trans) {
@@ -1609,7 +1551,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
         if (cell_count) {
             // For each row in the transposed matrix, read the values for the whole cell range
             for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                const size_t dst_offset = (head + j * size) * v_size_el;
+                const size_t dst_offset = (head + j * cells.size()) * v_size_el;
                 ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
             }
         }
@@ -1689,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
     kv_swa ->seq_keep(seq_id);
 }
 
-void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    kv_base->seq_add(seq_id, p0, p1, delta);
-    kv_swa ->seq_add(seq_id, p0, p1, delta);
+void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_base->seq_add(seq_id, p0, p1, shift);
+    kv_swa ->seq_add(seq_id, p0, p1, shift);
 }
 
 void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
@@ -2063,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
     }
 }
 
-void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
-    if (delta == 0) {
+void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    if (shift == 0) {
         return;
     }
@@ -2087,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
     if (tail_id >= 0) {
         kv_cell & cell = cells[tail_id];
         if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
-            cell.pos += delta;
+            cell.pos += shift;
         }
     }
 }
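Note: the net effect in this file is that every manual pos / seq_id / used update now goes through llama_kv_cells_unified, and the old per-cell recovery map becomes a stack of range snapshots: find_slot() pushes a copy of the cells it is about to overwrite, restore() replays the copies, commit() discards them. A condensed, standalone sketch of that snapshot/rollback idea (the struct and free functions are illustrative stand-ins, not the actual members):

    #include <cstdint>
    #include <vector>

    #include "llama-kv-cells.h"

    // illustrative stand-in for the recovery member of llama_kv_cache_unified
    struct recovery_info {
        struct state {
            uint32_t i;                    // first cell of the snapshot
            llama_kv_cells_unified cells;  // copy of the soon-to-be-overwritten cells
        };
        std::vector<state> states;         // stack: one entry per ubatch
    };

    // sketch: what find_slot() does before placing n_tokens starting at head
    void snapshot(recovery_info & rec, const llama_kv_cells_unified & cells,
                  uint32_t head, uint32_t n_tokens) {
        rec.states.push_back({head, cells.cp(head, n_tokens)});
    }

    // sketch: what restore() does when a batch fails part-way through
    void rollback(recovery_info & rec, llama_kv_cells_unified & cells) {
        for (auto & st : rec.states) {
            cells.set(st.i, st.cells);  // set() also re-balances the "used" counter
        }
        rec.states.clear();
    }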

src/llama-kv-cache.h

@@ -4,6 +4,7 @@
 #include "llama-io.h"
 #include "llama-graph.h"
 #include "llama-memory.h"
+#include "llama-kv-cells.h"
 
 #include "ggml-cpp.h"
@@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i {
     virtual void defrag_sched(float thold) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
+    // TODO: remove
     virtual void set_full() = 0;
 
     //
@@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i {
     //
 
     // =============================================================================================================
-    // TODO: refactor and simplify this
+    // TODO: refactor and simplify this [TAG: KV_API]
 
     virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
@@ -121,7 +123,7 @@ public:
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
@@ -159,7 +161,7 @@ public:
     // llama_kv_cache_unified specific API
     //
 
     uint32_t get_n() const;
-
     uint32_t get_size() const;
+
     // get views of the current state of the cache
@@ -180,26 +182,6 @@ private:
     const llama_model & model;
     const llama_hparams & hparams;
 
-    struct kv_cell {
-        llama_pos pos   = -1;
-        llama_pos delta =  0;
-
-        // TODO: replace with bitset uint64_t
-        std::set<llama_seq_id> seq_id;
-
-        bool has_seq_id(const llama_seq_id & id) const {
-            return seq_id.find(id) != seq_id.end();
-        }
-
-        bool is_empty() const {
-            return seq_id.empty();
-        }
-
-        bool is_same_seq(const kv_cell & other) const {
-            return seq_id == other.seq_id;
-        }
-    };
-
     struct kv_layer {
         // layer index in the model
         // note: can be different from the layer index in the KV cache
@@ -209,15 +191,13 @@ private:
         ggml_tensor * v;
     };
 
-    bool has_shift = false;
     bool do_defrag = false;
 
     bool v_trans = true;  // the value tensor is transposed
 
     uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
-    uint32_t size = 0; // total number of cells, shared across all sequences
-    uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
 
     // computed before each graph build
+    // TODO: cells should start to maintain this value dynamically based on the edits
     uint32_t n = 0;
 
     const uint32_t n_seq_max = 1;
@@ -233,19 +213,29 @@ private:
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
-    std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
+    llama_kv_cells_unified cells;
+
     std::vector<kv_layer> layers;
 
     // model layer id -> KV cache layer id
     std::unordered_map<int32_t, int32_t> map_layer_ids;
 
     // recovery information used to restore the KV cells to their original state in case of a failure
+    // TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
+    //       to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
     struct {
         void clear() {
-            cells.clear();
+            states.clear();
         }
 
-        std::unordered_map<uint32_t, kv_cell> cells;
+        struct state {
+            uint32_t i;
+
+            llama_kv_cells_unified cells;
+        };
+
+        // stack with the partial states before each ubatch
+        std::vector<state> states;
     } recovery;
 
     // defrag
@@ -257,6 +247,7 @@ private:
     bool defrag_prepare(int32_t n_max_nodes);
 
     // find how many cells are currently in use
+    // TODO: optimize
     uint32_t cell_max() const;
 
     size_t total_size() const;
@@ -325,7 +316,7 @@ public:
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
@@ -431,7 +422,7 @@ public:
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
     void seq_keep(llama_seq_id seq_id) override;
-    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
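Note: the new include and the hard cap are two sides of the same design: llama_kv_cells_unified (next file) stores each cell's sequence membership as a std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES> instead of the removed std::set<llama_seq_id>, so every seq_id must fit in [0, 64). A small illustrative comparison of the two representations (not part of this commit):

    #include <bitset>
    #include <cassert>
    #include <set>

    #include "llama-cparams.h" // LLAMA_MAX_PARALLEL_SEQUENCES

    int main() {
        // old representation: ordered set, one heap node per sequence id
        std::set<int> seq_old;
        seq_old.insert(3);                      // kv_cell::seq_id.insert()
        const bool has_old = seq_old.count(3);  // kv_cell::has_seq_id()

        // new representation: one flat 64-bit mask per cell, no allocation
        std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES> seq_new;
        seq_new.set(3);                         // cells.seq_add()
        const bool has_new = seq_new.test(3);   // cells.seq_has()
        seq_new.reset(3);                       // cells.seq_rm()

        assert(has_old && has_new && seq_new.none());
        return 0;
    }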

src/llama-kv-cells.h (new file, 273 lines)

@@ -0,0 +1,273 @@
#pragma once

#include "llama.h"

#include "llama-cparams.h"

#include <bitset>
#include <cassert>
#include <vector>

// meta information about KV cells that can be part of multiple sequences at the same time
// TODO: add unit tests
class llama_kv_cells_unified {
public:
    void reset() {
        for (uint32_t i = 0; i < pos.size(); ++i) {
            pos[i]   = -1;
            shift[i] =  0;
            seq[i].reset();
        }

        used = 0;

        has_shift = false;
    }

    void reset_shift() {
        has_shift = false;

        for (uint32_t i = 0; i < shift.size(); ++i) {
            shift[i] = 0;
        }
    }

    uint32_t size() const {
        return pos.size();
    }

    void resize(uint32_t n) {
        pos.resize(n);
        shift.resize(n);
        seq.resize(n);

        reset();
    }

    bool is_empty(uint32_t i) const {
        assert(i < pos.size());
        assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);

        return pos[i] == -1;
    }

    uint32_t get_used() const {
        return used;
    }

    bool get_has_shift() const {
        return has_shift;
    }

    // move cell isrc to idst (used during defrag)
    void mv(uint32_t isrc, uint32_t idst) {
        assert(isrc < pos.size());
        assert(idst < pos.size());

        pos  [idst] = pos  [isrc];
        shift[idst] = shift[isrc];
        seq  [idst] = seq  [isrc];

        pos  [isrc] = -1;
        shift[isrc] =  0;
        seq  [isrc].reset();
    }

    // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
    llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
        assert(i + n <= pos.size());

        llama_kv_cells_unified res;

        res.resize(n);

        for (uint32_t j = 0; j < n; ++j) {
            res.pos[j] = pos[i + j];
            res.seq[j] = seq[i + j];

            assert(shift[i + j] == 0);
        }

        return res;
    }

    // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
    void set(uint32_t i, const llama_kv_cells_unified & other) {
        assert(i + other.pos.size() <= pos.size());

        for (uint32_t j = 0; j < other.pos.size(); ++j) {
            if (pos[i + j] == -1 && other.pos[j] != -1) {
                used++;
            }

            if (pos[i + j] != -1 && other.pos[j] == -1) {
                used--;
            }

            pos[i + j] = other.pos[j];
            seq[i + j] = other.seq[j];

            assert(shift[i + j] == 0);
        }
    }

    // note: call only if the cell has seq_id
    // return true if the cell becomes empty
    bool seq_rm(uint32_t i, llama_seq_id seq_id) {
        assert(i < pos.size());
        assert(seq[i].test(seq_id));
        assert(pos[i] != -1);
        assert(seq_id >= 0);

        seq[i].reset(seq_id);

        if (seq[i].none()) {
            pos[i] = -1;

            used--;

            return true;
        }

        return false;
    }

    // return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
    bool seq_keep(uint32_t i, llama_seq_id seq_id) {
        assert(i < pos.size());

        if (seq[i].test(seq_id)) {
            seq[i].reset();
            seq[i].set(seq_id);

            return false;
        }

        if (seq[i].any()) {
            seq[i].reset();
            pos[i] = -1;

            used--;

            return true;
        }

        assert(pos[i] == -1);

        return false;
    }

    bool seq_has(uint32_t i, llama_seq_id seq_id) const {
        assert(i < pos.size());
        assert(seq_id >= 0);

        return seq[i].test(seq_id);
    }

    // note: call only if the cell is not empty and the seq_id is not in the cell
    void seq_add(uint32_t i, llama_seq_id seq_id) {
        assert(i < pos.size());
        assert(pos[i] != -1);
        assert(!seq[i].test(seq_id));

        seq[i].set(seq_id);
    }

    // note: call only if the cell is not empty
    llama_pos pos_get(uint32_t i) const {
        assert(i < pos.size());
        assert(pos[i] != -1);

        return pos[i];
    }

    // note: call only if the cell is not empty
    llama_pos get_shift(uint32_t i) const {
        assert(i < pos.size());
        assert(pos[i] != -1);

        return shift[i];
    }

    // check if a cell is not empty and its position is within [p0, p1)
    bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
        assert(i < pos.size());

        return pos[i] >= p0 && pos[i] < p1;
    }

    // set the position of an empty cell
    // does not modify "has_shift"
    // note: call only if the cell is empty
    void pos_set(uint32_t i, llama_pos p) {
        assert(i < pos.size());
        assert(pos[i] == -1);

        pos[i] = p;
        used++;
    }

    // pos[i] = pos[i] + d
    // sets "has_shift" to true
    // note: call only if the cell is not empty
    bool pos_add(uint32_t i, llama_pos d) {
        assert(i < pos.size());
        assert(pos[i] != -1);

        pos[i]   += d;
        shift[i] += d;

        has_shift = true;

        if (pos[i] < 0) {
            pos[i] = -1;
            seq[i].reset();

            used--;

            return true;
        }

        return false;
    }

    // pos[i] = pos[i] / d
    // sets "has_shift" to true
    // note: call only if the cell is not empty
    void pos_div(uint32_t i, int d) {
        assert(i < pos.size());
        assert(pos[i] != -1);

        const llama_pos p_old = pos[i];

        pos[i]   /= d;
        shift[i] += p_old - pos[i];

        has_shift = true;
    }

private:
    uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)

    bool has_shift = false;

    std::vector<llama_pos> pos;

    // this array accumulates any applied shifts to the pos array since the last reset_shift() call
    // this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
    //
    //   cells.pos_add(x, shift_x);
    //   cells.pos_div(y, shift_y);
    //   ...
    //
    //   if (cells.has_shift()) {
    //       for (int i = 0; i < n; ++i) {
    //           auto shift_i = cells.get_shift(i);
    //           ...
    //       }
    //
    //       cells.reset_shift();
    //   }
    //
    std::vector<llama_pos> shift;

    std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
};
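Note: a standalone sketch of driving the new class directly, following the shift-batching pattern documented in the comment above (the cache size and positions are arbitrary, not part of this commit):

    #include <cstdio>

    #include "llama-kv-cells.h"

    int main() {
        llama_kv_cells_unified cells;
        cells.resize(8);

        // occupy two cells for sequence 0 at positions 10 and 11
        cells.pos_set(0, 10); cells.seq_add(0, 0);
        cells.pos_set(1, 11); cells.seq_add(1, 0);

        // queue a shift of -4 on both cells (e.g. after dropping old tokens)
        cells.pos_add(0, -4);
        cells.pos_add(1, -4);

        // consume the accumulated shifts in one pass, then clear them
        if (cells.get_has_shift()) {
            for (uint32_t i = 0; i < cells.size(); ++i) {
                if (!cells.is_empty(i)) {
                    printf("cell %u: pos = %d, shift = %d\n", i, cells.pos_get(i), cells.get_shift(i));
                }
            }
            cells.reset_shift();
        }

        printf("used = %u\n", cells.get_used()); // -> 2
        return 0;
    }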

src/llama-memory.h

@@ -22,7 +22,7 @@ public:
    virtual bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) = 0;
    virtual void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
    virtual void seq_keep(llama_seq_id seq_id) = 0;
-   virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0;
+   virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) = 0;
    virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
 
    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;