mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 03:55:20 +00:00
kv-cache : rework kv_cell (#13706)
* kv-cache : rework kv_cell ggml-ci * kv-cells : use "shift" instead of "delta" consistently ggml-ci * llama : add llama_max_parallel_sequences() ggml-ci * kv-cells : update comments [no ci] * context : fail upon construction if sequences exceed max value ggml-ci * kv-cells : get_pos() -> pos_get() + comments ggml-ci * kv-cells : fix tracking of "used" cells ggml-ci
This commit is contained in:
@ -471,6 +471,7 @@ extern "C" {
|
|||||||
LLAMA_API int64_t llama_time_us(void);
|
LLAMA_API int64_t llama_time_us(void);
|
||||||
|
|
||||||
LLAMA_API size_t llama_max_devices(void);
|
LLAMA_API size_t llama_max_devices(void);
|
||||||
|
LLAMA_API size_t llama_max_parallel_sequences(void);
|
||||||
|
|
||||||
LLAMA_API bool llama_supports_mmap (void);
|
LLAMA_API bool llama_supports_mmap (void);
|
||||||
LLAMA_API bool llama_supports_mlock (void);
|
LLAMA_API bool llama_supports_mlock (void);
|
||||||
|
@ -26,6 +26,10 @@ llama_context::llama_context(
|
|||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
cparams.n_seq_max = std::max(1u, params.n_seq_max);
|
||||||
|
if (cparams.n_seq_max > LLAMA_MAX_PARALLEL_SEQUENCES) {
|
||||||
|
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_PARALLEL_SEQUENCES));
|
||||||
|
}
|
||||||
|
|
||||||
cparams.n_threads = params.n_threads;
|
cparams.n_threads = params.n_threads;
|
||||||
cparams.n_threads_batch = params.n_threads_batch;
|
cparams.n_threads_batch = params.n_threads_batch;
|
||||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||||
|
@ -1 +1,5 @@
|
|||||||
#include "llama-cparams.h"
|
#include "llama-cparams.h"
|
||||||
|
|
||||||
|
size_t llama_max_parallel_sequences(void) {
|
||||||
|
return LLAMA_MAX_PARALLEL_SEQUENCES;
|
||||||
|
}
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
|
#define LLAMA_MAX_PARALLEL_SEQUENCES 64
|
||||||
|
|
||||||
struct llama_cparams {
|
struct llama_cparams {
|
||||||
uint32_t n_ctx; // context size used during inference
|
uint32_t n_ctx; // context size used during inference
|
||||||
uint32_t n_batch;
|
uint32_t n_batch;
|
||||||
|
@ -65,8 +65,6 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||||||
};
|
};
|
||||||
|
|
||||||
head = 0;
|
head = 0;
|
||||||
size = kv_size;
|
|
||||||
used = 0;
|
|
||||||
|
|
||||||
cells.resize(kv_size);
|
cells.resize(kv_size);
|
||||||
|
|
||||||
@ -138,13 +136,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::clear() {
|
void llama_kv_cache_unified::clear() {
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
cells.reset();
|
||||||
cells[i].pos = -1;
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
}
|
|
||||||
|
|
||||||
head = 0;
|
head = 0;
|
||||||
used = 0;
|
|
||||||
|
|
||||||
for (auto & buf : bufs) {
|
for (auto & buf : bufs) {
|
||||||
ggml_backend_buffer_clear(buf.get(), 0);
|
ggml_backend_buffer_clear(buf.get(), 0);
|
||||||
@ -152,7 +146,7 @@ void llama_kv_cache_unified::clear() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||||
uint32_t new_head = size;
|
uint32_t new_head = cells.size();
|
||||||
|
|
||||||
if (p0 < 0) {
|
if (p0 < 0) {
|
||||||
p0 = 0;
|
p0 = 0;
|
||||||
@ -162,33 +156,20 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||||||
p1 = std::numeric_limits<llama_pos>::max();
|
p1 = std::numeric_limits<llama_pos>::max();
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
if (cells[i].pos >= p0 && cells[i].pos < p1) {
|
if (!cells.pos_in(i, p0, p1)) {
|
||||||
if (seq_id < 0) {
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
} else if (cells[i].has_seq_id(seq_id)) {
|
|
||||||
cells[i].seq_id.erase(seq_id);
|
|
||||||
} else {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cells[i].is_empty()) {
|
if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) {
|
||||||
// keep count of the number of used cells
|
if (new_head == cells.size()) {
|
||||||
if (cells[i].pos >= 0) {
|
|
||||||
used--;
|
|
||||||
}
|
|
||||||
|
|
||||||
cells[i].pos = -1;
|
|
||||||
|
|
||||||
if (new_head == size) {
|
|
||||||
new_head = i;
|
new_head = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
if (new_head != size && new_head < head) {
|
if (new_head != cells.size() && new_head < head) {
|
||||||
head = new_head;
|
head = new_head;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -208,49 +189,40 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
|
|||||||
p1 = std::numeric_limits<llama_pos>::max();
|
p1 = std::numeric_limits<llama_pos>::max();
|
||||||
}
|
}
|
||||||
|
|
||||||
// otherwise, this is the KV of a Transformer-like model
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
head = 0;
|
if (!cells.pos_in(i, p0, p1)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
if (cells.seq_has(i, seq_id_src)) {
|
||||||
if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
cells.seq_add(i, seq_id_dst);
|
||||||
cells[i].seq_id.insert(seq_id_dst);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
|
||||||
uint32_t new_head = size;
|
uint32_t new_head = cells.size();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
if (!cells[i].has_seq_id(seq_id)) {
|
if (cells.seq_keep(i, seq_id)) {
|
||||||
if (cells[i].pos >= 0) {
|
if (new_head == cells.size()) {
|
||||||
used--;
|
|
||||||
}
|
|
||||||
|
|
||||||
cells[i].pos = -1;
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
|
|
||||||
if (new_head == size){
|
|
||||||
new_head = i;
|
new_head = i;
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
cells[i].seq_id.clear();
|
|
||||||
cells[i].seq_id.insert(seq_id);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
if (new_head != size && new_head < head) {
|
if (new_head != cells.size() && new_head < head) {
|
||||||
head = new_head;
|
head = new_head;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||||
if (delta == 0) {
|
if (shift == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t new_head = size;
|
uint32_t new_head = cells.size();
|
||||||
|
|
||||||
if (p0 < 0) {
|
if (p0 < 0) {
|
||||||
p0 = 0;
|
p0 = 0;
|
||||||
@ -260,25 +232,19 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||||||
p1 = std::numeric_limits<llama_pos>::max();
|
p1 = std::numeric_limits<llama_pos>::max();
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there is no range then return early to avoid looping over the
|
// If there is no range then return early to avoid looping over all cells.
|
||||||
if (p0 == p1) {
|
if (p0 == p1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
if (!cells.pos_in(i, p0, p1)) {
|
||||||
has_shift = true;
|
continue;
|
||||||
|
|
||||||
cells[i].pos += delta;
|
|
||||||
cells[i].delta += delta;
|
|
||||||
|
|
||||||
if (cells[i].pos < 0) {
|
|
||||||
if (!cells[i].is_empty()) {
|
|
||||||
used--;
|
|
||||||
}
|
}
|
||||||
cells[i].pos = -1;
|
|
||||||
cells[i].seq_id.clear();
|
if (cells.seq_has(i, seq_id)) {
|
||||||
if (new_head == size) {
|
if (cells.pos_add(i, shift)) {
|
||||||
|
if (new_head == cells.size()) {
|
||||||
new_head = i;
|
new_head = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -287,7 +253,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||||||
|
|
||||||
// If we freed up a slot, set head to it so searching can start there.
|
// If we freed up a slot, set head to it so searching can start there.
|
||||||
// Otherwise we just start the next search from the beginning.
|
// Otherwise we just start the next search from the beginning.
|
||||||
head = new_head != size ? new_head : 0;
|
head = new_head != cells.size() ? new_head : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||||
@ -308,15 +274,13 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) {
|
if (!cells.pos_in(i, p0, p1)) {
|
||||||
has_shift = true;
|
continue;
|
||||||
|
|
||||||
{
|
|
||||||
llama_pos p_old = cells[i].pos;
|
|
||||||
cells[i].pos /= d;
|
|
||||||
cells[i].delta += cells[i].pos - p_old;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (cells.seq_has(i, seq_id)) {
|
||||||
|
cells.pos_div(i, d);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -324,9 +288,9 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
|
|||||||
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
||||||
llama_pos result = std::numeric_limits<llama_pos>::max();
|
llama_pos result = std::numeric_limits<llama_pos>::max();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
if (cells[i].has_seq_id(seq_id)) {
|
if (cells.seq_has(i, seq_id)) {
|
||||||
result = std::min(result, cells[i].pos);
|
result = std::min(result, cells.pos_get(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -340,9 +304,9 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
|
|||||||
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
||||||
llama_pos result = -1;
|
llama_pos result = -1;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
if (cells[i].has_seq_id(seq_id)) {
|
if (cells.seq_has(i, seq_id)) {
|
||||||
result = std::max(result, cells[i].pos);
|
result = std::max(result, cells.pos_get(i));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -350,25 +314,15 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::restore() {
|
void llama_kv_cache_unified::restore() {
|
||||||
for (const auto & [id, cell] : recovery.cells) {
|
for (auto & state : recovery.states) {
|
||||||
// TODO: move to new `struct kv_cells`
|
cells.set(state.i, state.cells);
|
||||||
const bool is_empty0 = cells[id].is_empty();
|
|
||||||
const bool is_empty1 = cell.is_empty();
|
|
||||||
|
|
||||||
if (!is_empty0 && is_empty1) {
|
|
||||||
used--;
|
|
||||||
} else if (is_empty0 && !is_empty1) {
|
|
||||||
used++;
|
|
||||||
}
|
|
||||||
|
|
||||||
cells[id] = cell;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
recovery.clear();
|
recovery.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::commit() {
|
void llama_kv_cache_unified::commit() {
|
||||||
if (recovery.cells.empty()) {
|
if (recovery.states.empty()) {
|
||||||
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
|
LLAMA_LOG_WARN("%s: the recovery information upon a commit was empty - might indicate a bug (ref: %s)\n",
|
||||||
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
|
__func__, "https://github.com/ggml-org/llama.cpp/pull/13194");
|
||||||
return;
|
return;
|
||||||
@ -382,7 +336,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||||||
|
|
||||||
auto * sched = lctx.get_sched();
|
auto * sched = lctx.get_sched();
|
||||||
|
|
||||||
if (has_shift) {
|
if (cells.get_has_shift()) {
|
||||||
if (!get_can_shift()) {
|
if (!get_can_shift()) {
|
||||||
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
||||||
}
|
}
|
||||||
@ -406,13 +360,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||||||
need_reserve = true;
|
need_reserve = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
cells.reset_shift();
|
||||||
has_shift = false;
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
|
||||||
cells[i].delta = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (do_defrag) {
|
if (do_defrag) {
|
||||||
@ -443,7 +391,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||||||
void llama_kv_cache_unified::defrag_sched(float thold) {
|
void llama_kv_cache_unified::defrag_sched(float thold) {
|
||||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||||
// - count the padding towards the number of used tokens
|
// - count the padding towards the number of used tokens
|
||||||
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
|
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n)) : 0.0f;
|
||||||
|
|
||||||
// queue defragmentation for next llama_kv_cache_update
|
// queue defragmentation for next llama_kv_cache_update
|
||||||
if (fragmentation > thold) {
|
if (fragmentation > thold) {
|
||||||
@ -454,7 +402,7 @@ void llama_kv_cache_unified::defrag_sched(float thold) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::set_full() {
|
void llama_kv_cache_unified::set_full() {
|
||||||
n = size;
|
n = cells.size();
|
||||||
|
|
||||||
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
// when simulating a full KV cache, the specific value of the "head" pointer is not important because it does not
|
||||||
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
// affect the shapes of the tensors in the compute graph - it only affects the offsets of the K/V views.
|
||||||
@ -478,14 +426,14 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||||||
|
|
||||||
// if we have enough unused cells before the current head ->
|
// if we have enough unused cells before the current head ->
|
||||||
// better to start searching from the beginning of the cache, hoping to fill it
|
// better to start searching from the beginning of the cache, hoping to fill it
|
||||||
if (head > used + 2*ubatch.n_tokens) {
|
if (head > cells.get_used() + 2*ubatch.n_tokens) {
|
||||||
head = 0;
|
head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// otherwise, one cell per token.
|
// otherwise, one cell per token.
|
||||||
|
|
||||||
if (n_tokens > size) {
|
if (n_tokens > cells.size()) {
|
||||||
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
|
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -498,10 +446,10 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||||||
std::string ss;
|
std::string ss;
|
||||||
if (n_swa > 0) {
|
if (n_swa > 0) {
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < size; ++i) {
|
||||||
if (cells[i].pos == -1) {
|
if (cells.is_empty(i)) {
|
||||||
ss += '.';
|
ss += '.';
|
||||||
} else {
|
} else {
|
||||||
ss += std::to_string(*cells[i].seq_id.begin());
|
ss += 'x';
|
||||||
}
|
}
|
||||||
if (i%256 == 255) {
|
if (i%256 == 255) {
|
||||||
ss += '\n';
|
ss += '\n';
|
||||||
@ -515,15 +463,16 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||||||
uint32_t n_tested = 0;
|
uint32_t n_tested = 0;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
if (head + n_tokens > size) {
|
if (head + n_tokens > cells.size()) {
|
||||||
n_tested += size - head;
|
n_tested += cells.size() - head;
|
||||||
head = 0;
|
head = 0;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool found = true;
|
bool found = true;
|
||||||
for (uint32_t i = 0; i < n_tokens; i++) {
|
for (uint32_t i = 0; i < n_tokens; i++) {
|
||||||
if (cells[head + i].pos >= 0) {
|
// TODO: improve to accept cells that are masked by the SWA
|
||||||
|
if (!cells.is_empty(head + i)) {
|
||||||
found = false;
|
found = false;
|
||||||
head += i + 1;
|
head += i + 1;
|
||||||
n_tested += i + 1;
|
n_tested += i + 1;
|
||||||
@ -535,31 +484,27 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_tested >= size) {
|
if (n_tested >= cells.size()) {
|
||||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (uint32_t i = 0; i < n_tokens; ++i) {
|
// store the old state of the cells in the recovery stack
|
||||||
// remember the original state
|
recovery.states.push_back({head, cells.cp(head, n_tokens)});
|
||||||
if (recovery.cells.find(head + i) == recovery.cells.end()) {
|
|
||||||
recovery.cells[head + i] = cells[head + i];
|
|
||||||
}
|
|
||||||
|
|
||||||
cells[head + i].pos = ubatch.pos[i];
|
for (uint32_t i = 0; i < n_tokens; ++i) {
|
||||||
|
cells.pos_set(head + i, ubatch.pos[i]);
|
||||||
|
|
||||||
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
|
for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
|
||||||
cells[head + i].seq_id.insert(ubatch.seq_id[i][j]);
|
cells.seq_add(head + i, ubatch.seq_id[i][j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
used += n_tokens;
|
|
||||||
|
|
||||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||||
// after enough generations, the benefit from this heuristic disappears
|
// after enough generations, the benefit from this heuristic disappears
|
||||||
// if we start defragmenting the cache, the benefit from this will be more important
|
// if we start defragmenting the cache, the benefit from this will be more important
|
||||||
n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
|
n = std::min(cells.size(), std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
|
||||||
|
|
||||||
#ifdef FIND_SLOT_DEBUG
|
#ifdef FIND_SLOT_DEBUG
|
||||||
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
|
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
|
||||||
@ -577,7 +522,7 @@ uint32_t llama_kv_cache_unified::get_n() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_unified::get_size() const {
|
uint32_t llama_kv_cache_unified::get_size() const {
|
||||||
return size;
|
return cells.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
|
ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il) const {
|
||||||
@ -661,30 +606,19 @@ void llama_kv_cache_unified::prune_swa(llama_seq_id seq_id, llama_pos pmin, llam
|
|||||||
|
|
||||||
int n_attended = 0;
|
int n_attended = 0;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
const llama_pos p0 = cells[i].pos;
|
if (!cells.seq_has(i, seq_id)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_pos p0 = cells.pos_get(i);
|
||||||
|
|
||||||
if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
|
if (p0 <= pmin && !is_masked_swa(p0, pmin)) {
|
||||||
n_attended++;
|
n_attended++;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_masked_swa(p0, pmax)) {
|
if (is_masked_swa(p0, pmax)) {
|
||||||
if (seq_id < 0) {
|
cells.seq_rm(i, seq_id);
|
||||||
cells[i].seq_id.clear();
|
|
||||||
} else if (cells[i].has_seq_id(seq_id)) {
|
|
||||||
cells[i].seq_id.erase(seq_id);
|
|
||||||
} else {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cells[i].is_empty()) {
|
|
||||||
// keep count of the number of used cells
|
|
||||||
if (cells[i].pos >= 0) {
|
|
||||||
used--;
|
|
||||||
}
|
|
||||||
|
|
||||||
cells[i].pos = -1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -723,12 +657,17 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||||||
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
|
||||||
|
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
const llama_pos p0 = cells[i].pos;
|
float f = 0.0f;
|
||||||
|
|
||||||
bool masked = false;
|
bool masked = false;
|
||||||
|
|
||||||
|
if (cells.is_empty(i)) {
|
||||||
|
masked = true;
|
||||||
|
} else {
|
||||||
|
const llama_pos p0 = cells.pos_get(i);
|
||||||
|
|
||||||
// mask the token if not the same sequence
|
// mask the token if not the same sequence
|
||||||
masked = masked || (!cells[i].has_seq_id(seq_id));
|
masked = masked || (!cells.seq_has(i, seq_id));
|
||||||
|
|
||||||
// mask future tokens
|
// mask future tokens
|
||||||
masked = masked || (causal_attn && p0 > p1);
|
masked = masked || (causal_attn && p0 > p1);
|
||||||
@ -736,12 +675,13 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||||||
// apply SWA if any
|
// apply SWA if any
|
||||||
masked = masked || (is_masked_swa(p0, p1));
|
masked = masked || (is_masked_swa(p0, p1));
|
||||||
|
|
||||||
float f = 0.0f;
|
if (!masked && hparams.use_alibi) {
|
||||||
|
f = -std::abs(p0 - p1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (masked) {
|
if (masked) {
|
||||||
f = -INFINITY;
|
f = -INFINITY;
|
||||||
} else if (hparams.use_alibi) {
|
|
||||||
f = -std::abs(p0 - p1);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
|
||||||
@ -765,8 +705,8 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
|
|||||||
|
|
||||||
int32_t * data = (int32_t *) dst->data;
|
int32_t * data = (int32_t *) dst->data;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
data[i] = cells[i].delta;
|
data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -783,7 +723,10 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
|
|||||||
for (int h = 0; h < 1; ++h) {
|
for (int h = 0; h < 1; ++h) {
|
||||||
for (int j = 0; j < n_tokens; ++j) {
|
for (int j = 0; j < n_tokens; ++j) {
|
||||||
for (int i = 0; i < n_kv; ++i) {
|
for (int i = 0; i < n_kv; ++i) {
|
||||||
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
// the position when the cells is empty is irrelevant - it will be masked out later in the attention
|
||||||
|
const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
|
||||||
|
|
||||||
|
data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -910,7 +853,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
|||||||
|
|
||||||
ggml_tensor * k =
|
ggml_tensor * k =
|
||||||
ggml_view_3d(ctx, layer.k,
|
ggml_view_3d(ctx, layer.k,
|
||||||
n_embd_head_k, n_head_kv, size,
|
n_embd_head_k, n_head_kv, cells.size(),
|
||||||
ggml_row_size(layer.k->type, n_embd_head_k),
|
ggml_row_size(layer.k->type, n_embd_head_k),
|
||||||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||||
0);
|
0);
|
||||||
@ -1050,12 +993,12 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|||||||
} else {
|
} else {
|
||||||
view_v_src = ggml_view_2d(ctx, layer.v,
|
view_v_src = ggml_view_2d(ctx, layer.v,
|
||||||
nm, n_embd_v_gqa,
|
nm, n_embd_v_gqa,
|
||||||
ggml_row_size(layer.v->type, size),
|
ggml_row_size(layer.v->type, cells.size()),
|
||||||
ggml_row_size(layer.v->type, i));
|
ggml_row_size(layer.v->type, i));
|
||||||
|
|
||||||
view_v_dst = ggml_view_2d(ctx, layer.v,
|
view_v_dst = ggml_view_2d(ctx, layer.v,
|
||||||
nm, n_embd_v_gqa,
|
nm, n_embd_v_gqa,
|
||||||
ggml_row_size(layer.v->type, size),
|
ggml_row_size(layer.v->type, cells.size()),
|
||||||
ggml_row_size(layer.v->type, id));
|
ggml_row_size(layer.v->type, id));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1076,7 +1019,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
const uint32_t n_layer = layers.size();
|
const uint32_t n_layer = layers.size();
|
||||||
|
|
||||||
const uint32_t n_kv = cell_max();
|
const uint32_t n_kv = cell_max();
|
||||||
const uint32_t n_used = used;
|
const uint32_t n_used = cells.get_used();
|
||||||
|
|
||||||
assert(n_used <= n_kv);
|
assert(n_used <= n_kv);
|
||||||
|
|
||||||
@ -1104,9 +1047,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
ids.resize(n_kv, n_kv);
|
ids.resize(n_kv, n_kv);
|
||||||
|
|
||||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||||
const auto & cell0 = cells[i0];
|
if (!cells.is_empty(i0)) {
|
||||||
|
|
||||||
if (!cell0.is_empty()) {
|
|
||||||
ids[i0] = i0;
|
ids[i0] = i0;
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
@ -1117,7 +1058,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
uint32_t nh = 1;
|
uint32_t nh = 1;
|
||||||
|
|
||||||
// determine the size of the hole
|
// determine the size of the hole
|
||||||
while (i0 + nh < n_used && cells[i0 + nh].is_empty()) {
|
while (i0 + nh < n_used && cells.is_empty(i0 + nh)) {
|
||||||
nh++;
|
nh++;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1126,9 +1067,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
|
|
||||||
// starting from the end, find nh non-empty cells
|
// starting from the end, find nh non-empty cells
|
||||||
for (; is > i0; --is) {
|
for (; is > i0; --is) {
|
||||||
const auto & cell1 = cells[is];
|
if (cells.is_empty(is) || ids[is] != n_kv) {
|
||||||
|
|
||||||
if (cell1.is_empty() || ids[is] != n_kv) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1155,9 +1094,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
|
|
||||||
// go back and move the nf cells to the hole
|
// go back and move the nf cells to the hole
|
||||||
for (; i1 < n_kv; ++i1) {
|
for (; i1 < n_kv; ++i1) {
|
||||||
auto & cell1 = cells[i1];
|
if (cells.is_empty(i1) || ids[i1] != n_kv) {
|
||||||
|
|
||||||
if (cell1.is_empty() || ids[i1] != n_kv) {
|
|
||||||
if (n_moves == max_moves) {
|
if (n_moves == max_moves) {
|
||||||
stop = true;
|
stop = true;
|
||||||
break;
|
break;
|
||||||
@ -1171,10 +1108,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
ids[i1] = i0 + nf;
|
ids[i1] = i0 + nf;
|
||||||
|
|
||||||
// move the cell meta data
|
// move the cell meta data
|
||||||
cells[i0 + nf] = cell1;
|
cells.mv(i1, i0 + nf);
|
||||||
|
|
||||||
// clear the old cell and move the head there
|
|
||||||
cell1 = kv_cell();
|
|
||||||
head = n_used;
|
head = n_used;
|
||||||
|
|
||||||
if (!cont) {
|
if (!cont) {
|
||||||
@ -1210,10 +1145,8 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_unified::cell_max() const {
|
uint32_t llama_kv_cache_unified::cell_max() const {
|
||||||
for (uint32_t i = size; i > 0; --i) {
|
for (uint32_t i = cells.size(); i > 0; --i) {
|
||||||
const kv_cell & cell = cells[i - 1];
|
if (!cells.is_empty(i - 1)) {
|
||||||
|
|
||||||
if (cell.pos >= 0 && !cell.is_empty()) {
|
|
||||||
return i;
|
return i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1222,9 +1155,7 @@ uint32_t llama_kv_cache_unified::cell_max() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||||
if (p0 < 0) {
|
assert(p0 >= 0 && p1 >= 0);
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (swa_type) {
|
switch (swa_type) {
|
||||||
case LLAMA_SWA_TYPE_NONE:
|
case LLAMA_SWA_TYPE_NONE:
|
||||||
@ -1255,23 +1186,24 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
|
|||||||
|
|
||||||
// Count the number of cells with the specified seq_id
|
// Count the number of cells with the specified seq_id
|
||||||
// Find all the ranges of cells with this seq id (or all, when -1)
|
// Find all the ranges of cells with this seq id (or all, when -1)
|
||||||
uint32_t cell_range_begin = size;
|
uint32_t cell_range_begin = cells.size();
|
||||||
for (uint32_t i = 0; i < size; ++i) {
|
|
||||||
const auto & cell = cells[i];
|
for (uint32_t i = 0; i < cells.size(); ++i) {
|
||||||
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
|
||||||
++cell_count;
|
++cell_count;
|
||||||
if (cell_range_begin == size) {
|
if (cell_range_begin == cells.size()) {
|
||||||
cell_range_begin = i;
|
cell_range_begin = i;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (cell_range_begin != size) {
|
if (cell_range_begin != cells.size()) {
|
||||||
cell_ranges.emplace_back(cell_range_begin, i);
|
cell_ranges.emplace_back(cell_range_begin, i);
|
||||||
cell_range_begin = size;
|
cell_range_begin = cells.size();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (cell_range_begin != size) {
|
|
||||||
cell_ranges.emplace_back(cell_range_begin, size);
|
if (cell_range_begin != cells.size()) {
|
||||||
|
cell_ranges.emplace_back(cell_range_begin, cells.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||||
@ -1308,20 +1240,27 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
|
|||||||
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
||||||
for (const auto & range : cell_ranges) {
|
for (const auto & range : cell_ranges) {
|
||||||
for (uint32_t i = range.first; i < range.second; ++i) {
|
for (uint32_t i = range.first; i < range.second; ++i) {
|
||||||
const auto & cell = cells[i];
|
std::vector<llama_seq_id> seq_ids;
|
||||||
const llama_pos pos = cell.pos;
|
|
||||||
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0;
|
for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) {
|
||||||
|
if (cur == seq_id || seq_id == -1) {
|
||||||
|
if (cells.seq_has(i, cur)) {
|
||||||
|
seq_ids.push_back(cur);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const llama_pos pos = cells.pos_get(i);
|
||||||
|
const uint32_t n_seq_id = seq_ids.size();
|
||||||
|
|
||||||
io.write(&pos, sizeof(pos));
|
io.write(&pos, sizeof(pos));
|
||||||
io.write(&n_seq_id, sizeof(n_seq_id));
|
io.write(&n_seq_id, sizeof(n_seq_id));
|
||||||
|
|
||||||
if (n_seq_id) {
|
for (const auto & seq_id : seq_ids) {
|
||||||
for (auto seq_id : cell.seq_id) {
|
|
||||||
io.write(&seq_id, sizeof(seq_id));
|
io.write(&seq_id, sizeof(seq_id));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
||||||
@ -1379,7 +1318,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// When v is transposed, we also need the element size and get the element ranges from each row
|
// When v is transposed, we also need the element size and get the element ranges from each row
|
||||||
const uint32_t kv_size = size;
|
const uint32_t kv_size = cells.size();
|
||||||
|
|
||||||
for (const auto & layer : layers) {
|
for (const auto & layer : layers) {
|
||||||
const uint32_t il = layer.il;
|
const uint32_t il = layer.il;
|
||||||
@ -1429,13 +1368,19 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
io.read_to(&pos, sizeof(pos));
|
io.read_to(&pos, sizeof(pos));
|
||||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||||
|
|
||||||
if (n_seq_id != 0) {
|
if (n_seq_id != 1) {
|
||||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// read the sequence id, but directly discard it - we will use dest_seq_id instead
|
||||||
|
{
|
||||||
|
llama_seq_id seq_id;
|
||||||
|
io.read_to(&seq_id, sizeof(seq_id));
|
||||||
|
}
|
||||||
|
|
||||||
batch.pos[i] = pos;
|
batch.pos[i] = pos;
|
||||||
batch.n_seq_id[i] = 1;
|
batch.n_seq_id[i] = n_seq_id;
|
||||||
batch.seq_id[i] = &dest_seq_id;
|
batch.seq_id[i] = &dest_seq_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1448,15 +1393,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
|
|
||||||
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||||
// Assume that this is one contiguous block of cells
|
// Assume that this is one contiguous block of cells
|
||||||
GGML_ASSERT(head + cell_count <= size);
|
GGML_ASSERT(head + cell_count <= cells.size());
|
||||||
GGML_ASSERT(cells[head].pos == batch.pos[0]);
|
GGML_ASSERT(cells.pos_get(head) == batch.pos[0]);
|
||||||
GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
GGML_ASSERT(cells.pos_get(head + cell_count - 1) == batch.pos[cell_count - 1]);
|
||||||
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
GGML_ASSERT(cells.seq_has(head, dest_seq_id));
|
||||||
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
GGML_ASSERT(cells.seq_has(head + cell_count - 1, dest_seq_id));
|
||||||
} else {
|
} else {
|
||||||
// whole KV cache restore
|
// whole KV cache restore
|
||||||
|
|
||||||
if (cell_count > size) {
|
if (cell_count > cells.size()) {
|
||||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -1464,15 +1409,13 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
clear();
|
clear();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||||
kv_cell & cell = cells[i];
|
|
||||||
|
|
||||||
llama_pos pos;
|
llama_pos pos;
|
||||||
uint32_t n_seq_id;
|
uint32_t n_seq_id;
|
||||||
|
|
||||||
io.read_to(&pos, sizeof(pos));
|
io.read_to(&pos, sizeof(pos));
|
||||||
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
io.read_to(&n_seq_id, sizeof(n_seq_id));
|
||||||
|
|
||||||
cell.pos = pos;
|
cells.pos_set(i, pos);
|
||||||
|
|
||||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||||
llama_seq_id seq_id;
|
llama_seq_id seq_id;
|
||||||
@ -1483,12 +1426,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
cell.seq_id.insert(seq_id);
|
cells.seq_add(i, seq_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
head = 0;
|
head = 0;
|
||||||
used = cell_count;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -1505,8 +1447,8 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||||||
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (cell_count > size) {
|
if (cell_count > cells.size()) {
|
||||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (this->v_trans != (bool) v_trans) {
|
if (this->v_trans != (bool) v_trans) {
|
||||||
@ -1609,7 +1551,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||||||
if (cell_count) {
|
if (cell_count) {
|
||||||
// For each row in the transposed matrix, read the values for the whole cell range
|
// For each row in the transposed matrix, read the values for the whole cell range
|
||||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||||
const size_t dst_offset = (head + j * size) * v_size_el;
|
const size_t dst_offset = (head + j * cells.size()) * v_size_el;
|
||||||
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1689,9 +1631,9 @@ void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
|
|||||||
kv_swa ->seq_keep(seq_id);
|
kv_swa ->seq_keep(seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||||
kv_base->seq_add(seq_id, p0, p1, delta);
|
kv_base->seq_add(seq_id, p0, p1, shift);
|
||||||
kv_swa ->seq_add(seq_id, p0, p1, delta);
|
kv_swa ->seq_add(seq_id, p0, p1, shift);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
||||||
@ -2063,8 +2005,8 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
|
void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
||||||
if (delta == 0) {
|
if (shift == 0) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2087,7 +2029,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|||||||
if (tail_id >= 0) {
|
if (tail_id >= 0) {
|
||||||
kv_cell & cell = cells[tail_id];
|
kv_cell & cell = cells[tail_id];
|
||||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||||
cell.pos += delta;
|
cell.pos += shift;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
#include "llama-io.h"
|
#include "llama-io.h"
|
||||||
#include "llama-graph.h"
|
#include "llama-graph.h"
|
||||||
#include "llama-memory.h"
|
#include "llama-memory.h"
|
||||||
|
#include "llama-kv-cells.h"
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
#include "ggml-cpp.h"
|
||||||
|
|
||||||
@ -35,6 +36,7 @@ struct llama_kv_cache : public llama_memory_i {
|
|||||||
virtual void defrag_sched(float thold) = 0;
|
virtual void defrag_sched(float thold) = 0;
|
||||||
|
|
||||||
// simulate full cache, used for allocating worst-case compute buffers
|
// simulate full cache, used for allocating worst-case compute buffers
|
||||||
|
// TODO: remove
|
||||||
virtual void set_full() = 0;
|
virtual void set_full() = 0;
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -42,7 +44,7 @@ struct llama_kv_cache : public llama_memory_i {
|
|||||||
//
|
//
|
||||||
|
|
||||||
// =============================================================================================================
|
// =============================================================================================================
|
||||||
// TODO: refactor and simplify this
|
// TODO: refactor and simplify this [TAG: KV_API]
|
||||||
|
|
||||||
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
|
||||||
|
|
||||||
@ -121,7 +123,7 @@ public:
|
|||||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_keep(llama_seq_id seq_id) override;
|
void seq_keep(llama_seq_id seq_id) override;
|
||||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||||
|
|
||||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||||
@ -180,26 +182,6 @@ private:
|
|||||||
const llama_model & model;
|
const llama_model & model;
|
||||||
const llama_hparams & hparams;
|
const llama_hparams & hparams;
|
||||||
|
|
||||||
struct kv_cell {
|
|
||||||
llama_pos pos = -1;
|
|
||||||
llama_pos delta = 0;
|
|
||||||
|
|
||||||
// TODO: replace with bitset uint64_t
|
|
||||||
std::set<llama_seq_id> seq_id;
|
|
||||||
|
|
||||||
bool has_seq_id(const llama_seq_id & id) const {
|
|
||||||
return seq_id.find(id) != seq_id.end();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_empty() const {
|
|
||||||
return seq_id.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_same_seq(const kv_cell & other) const {
|
|
||||||
return seq_id == other.seq_id;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct kv_layer {
|
struct kv_layer {
|
||||||
// layer index in the model
|
// layer index in the model
|
||||||
// note: can be different from the layer index in the KV cache
|
// note: can be different from the layer index in the KV cache
|
||||||
@ -209,15 +191,13 @@ private:
|
|||||||
ggml_tensor * v;
|
ggml_tensor * v;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool has_shift = false;
|
|
||||||
bool do_defrag = false;
|
bool do_defrag = false;
|
||||||
bool v_trans = true; // the value tensor is transposed
|
bool v_trans = true; // the value tensor is transposed
|
||||||
|
|
||||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
|
||||||
uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
|
|
||||||
|
|
||||||
// computed before each graph build
|
// computed before each graph build
|
||||||
|
// TODO: cells should start to maintain this value dynamically based on the edits
|
||||||
uint32_t n = 0;
|
uint32_t n = 0;
|
||||||
|
|
||||||
const uint32_t n_seq_max = 1;
|
const uint32_t n_seq_max = 1;
|
||||||
@ -233,19 +213,29 @@ private:
|
|||||||
std::vector<ggml_context_ptr> ctxs;
|
std::vector<ggml_context_ptr> ctxs;
|
||||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||||
|
|
||||||
std::vector<kv_cell> cells; // TODO: replace with `struct kv_cells`
|
llama_kv_cells_unified cells;
|
||||||
|
|
||||||
std::vector<kv_layer> layers;
|
std::vector<kv_layer> layers;
|
||||||
|
|
||||||
// model layer id -> KV cache layer id
|
// model layer id -> KV cache layer id
|
||||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||||
|
|
||||||
// recovery information used to restore the KV cells to their original state in case of a failure
|
// recovery information used to restore the KV cells to their original state in case of a failure
|
||||||
|
// TODO: do not store as a state in the llama_kv_cache object, instead return upon batch preparation
|
||||||
|
// to achieve that, first need to refactor the llama_kv_cache interface [TAG: KV_API]
|
||||||
struct {
|
struct {
|
||||||
void clear() {
|
void clear() {
|
||||||
cells.clear();
|
states.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unordered_map<uint32_t, kv_cell> cells;
|
struct state {
|
||||||
|
uint32_t i;
|
||||||
|
|
||||||
|
llama_kv_cells_unified cells;
|
||||||
|
};
|
||||||
|
|
||||||
|
// stack with the partial states before each ubatch
|
||||||
|
std::vector<state> states;
|
||||||
} recovery;
|
} recovery;
|
||||||
|
|
||||||
// defrag
|
// defrag
|
||||||
@ -257,6 +247,7 @@ private:
|
|||||||
bool defrag_prepare(int32_t n_max_nodes);
|
bool defrag_prepare(int32_t n_max_nodes);
|
||||||
|
|
||||||
// find how many cells are currently in use
|
// find how many cells are currently in use
|
||||||
|
// TODO: optimize
|
||||||
uint32_t cell_max() const;
|
uint32_t cell_max() const;
|
||||||
|
|
||||||
size_t total_size() const;
|
size_t total_size() const;
|
||||||
@ -325,7 +316,7 @@ public:
|
|||||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_keep(llama_seq_id seq_id) override;
|
void seq_keep(llama_seq_id seq_id) override;
|
||||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||||
|
|
||||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||||
@ -431,7 +422,7 @@ public:
|
|||||||
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
void seq_keep(llama_seq_id seq_id) override;
|
void seq_keep(llama_seq_id seq_id) override;
|
||||||
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||||
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||||
|
|
||||||
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||||
|
273
src/llama-kv-cells.h
Normal file
273
src/llama-kv-cells.h
Normal file
@ -0,0 +1,273 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama.h"
|
||||||
|
#include "llama-cparams.h"
|
||||||
|
|
||||||
|
#include <bitset>
|
||||||
|
#include <cassert>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||||
|
// TODO: add unit tests
|
||||||
|
class llama_kv_cells_unified {
|
||||||
|
public:
|
||||||
|
void reset() {
|
||||||
|
for (uint32_t i = 0; i < pos.size(); ++i) {
|
||||||
|
pos[i] = -1;
|
||||||
|
shift[i] = 0;
|
||||||
|
seq[i].reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
used = 0;
|
||||||
|
has_shift = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset_shift() {
|
||||||
|
has_shift = false;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < shift.size(); ++i) {
|
||||||
|
shift[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t size() const {
|
||||||
|
return pos.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void resize(uint32_t n) {
|
||||||
|
pos.resize(n);
|
||||||
|
shift.resize(n);
|
||||||
|
seq.resize(n);
|
||||||
|
|
||||||
|
reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_empty(uint32_t i) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0);
|
||||||
|
|
||||||
|
return pos[i] == -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t get_used() const {
|
||||||
|
return used;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool get_has_shift() const {
|
||||||
|
return has_shift;
|
||||||
|
}
|
||||||
|
|
||||||
|
// move cell isrc to idst (used during defrag)
|
||||||
|
void mv(uint32_t isrc, uint32_t idst) {
|
||||||
|
assert(isrc < pos.size());
|
||||||
|
assert(idst < pos.size());
|
||||||
|
|
||||||
|
pos [idst] = pos [isrc];
|
||||||
|
shift[idst] = shift[isrc];
|
||||||
|
seq [idst] = seq [isrc];
|
||||||
|
|
||||||
|
pos [isrc] = -1;
|
||||||
|
shift[isrc] = 0;
|
||||||
|
seq [isrc].reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy the state of cells [i, i + n) (used for save/restore the state of the cells)
|
||||||
|
llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
|
||||||
|
assert(i + n <= pos.size());
|
||||||
|
|
||||||
|
llama_kv_cells_unified res;
|
||||||
|
|
||||||
|
res.resize(n);
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < n; ++j) {
|
||||||
|
res.pos[j] = pos[i + j];
|
||||||
|
res.seq[j] = seq[i + j];
|
||||||
|
|
||||||
|
assert(shift[i + j] == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
|
||||||
|
void set(uint32_t i, const llama_kv_cells_unified & other) {
|
||||||
|
assert(i + other.pos.size() <= pos.size());
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < other.pos.size(); ++j) {
|
||||||
|
if (pos[i + j] == -1 && other.pos[j] != -1) {
|
||||||
|
used++;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (pos[i + j] != -1 && other.pos[j] == -1) {
|
||||||
|
used--;
|
||||||
|
}
|
||||||
|
|
||||||
|
pos[i + j] = other.pos[j];
|
||||||
|
seq[i + j] = other.seq[j];
|
||||||
|
|
||||||
|
assert(shift[i + j] == 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: call only if the cell has seq_id
|
||||||
|
// return true if the cell becomes empty
|
||||||
|
bool seq_rm(uint32_t i, llama_seq_id seq_id) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(seq[i].test(seq_id));
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
assert(seq_id >= 0);
|
||||||
|
|
||||||
|
seq[i].reset(seq_id);
|
||||||
|
|
||||||
|
if (seq[i].none()) {
|
||||||
|
pos[i] = -1;
|
||||||
|
|
||||||
|
used--;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// return true if the cell becomes empty (i.e. it did not contain seq_id before the call)
|
||||||
|
bool seq_keep(uint32_t i, llama_seq_id seq_id) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
|
||||||
|
if (seq[i].test(seq_id)) {
|
||||||
|
seq[i].reset();
|
||||||
|
seq[i].set(seq_id);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (seq[i].any()) {
|
||||||
|
seq[i].reset();
|
||||||
|
pos[i] = -1;
|
||||||
|
|
||||||
|
used--;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(pos[i] == -1);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool seq_has(uint32_t i, llama_seq_id seq_id) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(seq_id >= 0);
|
||||||
|
|
||||||
|
return seq[i].test(seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: call only if the cell is not empty and the seq_id is not in the cell
|
||||||
|
void seq_add(uint32_t i, llama_seq_id seq_id) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
assert(!seq[i].test(seq_id));
|
||||||
|
|
||||||
|
seq[i].set(seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: call only if the cell is not empty
|
||||||
|
llama_pos pos_get(uint32_t i) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
return pos[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// note: call only if the cell is not empty
|
||||||
|
llama_pos get_shift(uint32_t i) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
return shift[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if a cell is not empty and its position is within [p0, p1)
|
||||||
|
bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const {
|
||||||
|
assert(i < pos.size());
|
||||||
|
|
||||||
|
return pos[i] >= p0 && pos[i] < p1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the position of an empty cell
|
||||||
|
// does not modify "has_shift"
|
||||||
|
// note: call only if the cell is empty
|
||||||
|
void pos_set(uint32_t i, llama_pos p) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] == -1);
|
||||||
|
|
||||||
|
pos[i] = p;
|
||||||
|
used++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// pos[i] = pos[i] + d
|
||||||
|
// sets "has_shift" to true
|
||||||
|
// note: call only if the cell is not empty
|
||||||
|
bool pos_add(uint32_t i, llama_pos d) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
pos[i] += d;
|
||||||
|
shift[i] += d;
|
||||||
|
|
||||||
|
has_shift = true;
|
||||||
|
|
||||||
|
if (pos[i] < 0) {
|
||||||
|
pos[i] = -1;
|
||||||
|
seq[i].reset();
|
||||||
|
|
||||||
|
used--;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// pos[i] = pos[i] / d
|
||||||
|
// sets "has_shift" to true
|
||||||
|
// note: call only if the cell is not empty
|
||||||
|
void pos_div(uint32_t i, int d) {
|
||||||
|
assert(i < pos.size());
|
||||||
|
assert(pos[i] != -1);
|
||||||
|
|
||||||
|
const llama_pos p_old = pos[i];
|
||||||
|
|
||||||
|
pos[i] /= d;
|
||||||
|
shift[i] += p_old - pos[i];
|
||||||
|
|
||||||
|
has_shift = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
uint32_t used = 0; // used cells (i.e. pos[i] != -1, allowed to not have any seq_id)
|
||||||
|
|
||||||
|
bool has_shift = false;
|
||||||
|
|
||||||
|
std::vector<llama_pos> pos;
|
||||||
|
|
||||||
|
// this array accumulates any applied shifts to the pos array since the last reset_shift() call
|
||||||
|
// this is used to queue multiple updates to the pos array, which in the end can be applied in one go:
|
||||||
|
//
|
||||||
|
// cells.pos_add(x, shift_x);
|
||||||
|
// cells.pos_div(y, shift_y);
|
||||||
|
// ...
|
||||||
|
//
|
||||||
|
// if (cells.has_shift()) {
|
||||||
|
// for (int i = 0; i < n; ++i) {
|
||||||
|
// auto shift_i = cells.get_shift(i);
|
||||||
|
// ...
|
||||||
|
// }
|
||||||
|
// cells.reset_shift();
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
std::vector<llama_pos> shift;
|
||||||
|
|
||||||
|
std::vector<std::bitset<LLAMA_MAX_PARALLEL_SEQUENCES>> seq;
|
||||||
|
};
|
||||||
|
|
@ -22,7 +22,7 @@ public:
|
|||||||
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
|
||||||
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0;
|
||||||
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
virtual void seq_keep(llama_seq_id seq_id) = 0;
|
||||||
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
|
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0;
|
||||||
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
|
||||||
|
|
||||||
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
|
||||||
|
Reference in New Issue
Block a user