mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-08-07 09:22:39 -04:00
kv-cells : fix tracking of seq_pos (#14339)
* kv-cells : fix tracking of seq_pos during cache reuse ggml-ci * cont : improve error message ggml-ci * cont : add more comments
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <cassert>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <map>
|
||||
|
||||
// meta information about KV cells that can be part of multiple sequences at the same time
|
||||
// TODO: add unit tests
|
||||
@@ -164,7 +165,7 @@ public:
|
||||
assert(seq_id >= 0);
|
||||
|
||||
seq[i].reset(seq_id);
|
||||
seq_pos[seq_id].erase(pos[i]);
|
||||
seq_pos_dec(seq_id, pos[i]);
|
||||
|
||||
if (seq[i].none()) {
|
||||
pos[i] = -1;
|
||||
@@ -187,7 +188,7 @@ public:
|
||||
seq[i].reset();
|
||||
|
||||
seq[i].set(seq_id);
|
||||
seq_pos[seq_id].insert(pos[i]);
|
||||
seq_pos_inc(seq_id, pos[i]);
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -232,7 +233,7 @@ public:
|
||||
assert(!seq[i].test(seq_id));
|
||||
|
||||
seq[i].set(seq_id);
|
||||
seq_pos[seq_id].insert(pos[i]);
|
||||
seq_pos_inc(seq_id, pos[i]);
|
||||
}
|
||||
|
||||
// return the sequence id of this cell
|
||||
@@ -259,7 +260,9 @@ public:
|
||||
return -1;
|
||||
}
|
||||
|
||||
return *seq_pos[seq_id].begin();
|
||||
assert(seq_pos[seq_id].begin()->second > 0);
|
||||
|
||||
return seq_pos[seq_id].begin()->first;
|
||||
}
|
||||
|
||||
// the maximum position of sequence seq_id currently present in any of the cells
|
||||
@@ -272,7 +275,9 @@ public:
|
||||
return -1;
|
||||
}
|
||||
|
||||
return *seq_pos[seq_id].rbegin();
|
||||
assert(seq_pos[seq_id].rbegin()->second > 0);
|
||||
|
||||
return seq_pos[seq_id].rbegin()->first;
|
||||
}
|
||||
|
||||
// note: call only if the cell is not empty
|
||||
@@ -389,17 +394,36 @@ private:
|
||||
// the bitset seq[i] tells us which sequences are currently occupying the i-th cell
|
||||
std::vector<seq_set_t> seq;
|
||||
|
||||
// the set seq_pos[s] tells us which positions are currently present for sequence s
|
||||
// the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
|
||||
// if the position p is not present, seq_pos[s][p] is not set
|
||||
// this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
|
||||
std::set<llama_pos> seq_pos[LLAMA_MAX_SEQ];
|
||||
//
|
||||
// note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
|
||||
// - during performing a cache reuse via (rm + add)
|
||||
// - some vision models have input embeddings with repeating positions
|
||||
//
|
||||
std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
|
||||
|
||||
// helper functions for updating `seq_pos`, once cell at a time:
|
||||
|
||||
void seq_pos_dec(llama_seq_id s, llama_pos p) {
|
||||
auto it = seq_pos[s].find(p);
|
||||
assert(it != seq_pos[s].end());
|
||||
|
||||
if (--it->second == 0) {
|
||||
seq_pos[s].erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
void seq_pos_inc(llama_seq_id s, llama_pos p) {
|
||||
seq_pos[s][p]++;
|
||||
}
|
||||
|
||||
// remove cell i
|
||||
void seq_pos_rm(uint32_t i) {
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
seq_pos[s].erase(pos[i]);
|
||||
seq_pos_dec(s, pos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -408,7 +432,7 @@ private:
|
||||
void seq_pos_add(uint32_t i) {
|
||||
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
if (seq[i].test(s)) {
|
||||
seq_pos[s].insert(pos[i]);
|
||||
seq_pos_inc(s, pos[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user