@@ -582,21 +582,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             continue;
         }
 
-        // keep track of what the minimum sequence positions would be if we accept the ubatch
-        llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
-        for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
-            seq_pos_min[s] = cells.seq_pos_min(s);
-        }
-
         bool found = true;
         for (uint32_t i = 0; i < n_tokens; i++) {
-            const llama_pos    pos    = ubatch.pos[i];
-            const llama_seq_id seq_id = ubatch.seq_id[i][0];
+            //const llama_pos    pos    = ubatch.pos[i];
+            //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
             // can we use this cell? either:
             // - the cell is empty
             // - the cell is occupied only by one sequence:
-            //   - mask causally, if the sequence is the same as the one we are inserting
+            //   - (disabled) mask causally, if the sequence is the same as the one we are inserting
             //   - mask SWA, using current max pos for that sequence in the cache
             //     always insert in the cell with minimum pos
             bool can_use = cells.is_empty(head_cur + i);
@@ -604,21 +598,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
             if (!can_use && cells.seq_count(head_cur + i) == 1) {
                 const llama_pos pos_cell = cells.pos_get(head_cur + i);
 
-                // causal mask
-                if (cells.seq_has(head_cur + i, seq_id)) {
-                    can_use = pos_cell >= pos;
-                }
+                // (disabled) causal mask
+                // note: it's better to purge any "future" tokens beforehand
+                //if (cells.seq_has(head_cur + i, seq_id)) {
+                //    can_use = pos_cell >= pos;
+                //}
 
                 if (!can_use) {
                     const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
 
                     // SWA mask
-                    // note: we insert only in the cell with minimum pos in order to preserve the invariant that
-                    //       all positions between [pos_min, pos_max] for each sequence will be present in the cache
-                    // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-                    if (pos_cell == seq_pos_min[seq_id_cell] &&
-                        is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
-                        seq_pos_min[seq_id_cell]++;
+                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                         can_use = true;
                     }
                 }
             }
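The reuse rule in the new find_slot relies entirely on the SWA mask: an occupied cell may be recycled once its position can no longer be attended to by any future token of its sequence, i.e. it is already masked for the earliest possible future query position, seq_pos_max + 1. A minimal sketch of that check, assuming a standard sliding window of n_swa tokens (the real is_masked_swa() is a method of the cache that also handles other SWA variants, so the names and window arithmetic here are illustrative):

#include <cstdint>

using llama_pos = int32_t;

// sliding-window visibility: the key/value at position p0 is masked for a
// query at position p1 once it has slid out of the window (standard SWA only)
static bool is_masked_swa(int32_t n_swa, llama_pos p0, llama_pos p1) {
    return p1 - p0 >= n_swa;
}

// a cell owned by a single sequence can be overwritten if even the earliest
// possible future query of that sequence (seq_pos_max + 1) cannot see it
static bool can_reuse_cell(int32_t n_swa, llama_pos pos_cell, llama_pos seq_pos_max) {
    return is_masked_swa(n_swa, pos_cell, seq_pos_max + 1);
}

Compared to the old code, the pos_cell == seq_pos_min[seq_id_cell] restriction is gone: any sufficiently old cell qualifies, and the gaps this can leave in a sequence's position range are repaired by apply_ubatch below.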
@@ -646,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    // keep track of the max sequence position that we would overwrite with this ubatch
+    // for non-SWA cache, this would be always empty
+    llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        seq_pos_max_rm[s] = -1;
+    }
+
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
         if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
             cells.rm(head_cur + i);
         }
 
@@ -658,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         }
     }
 
+    // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
+    //       will be present in the cache. so we have to purge any position which is less than those we would overwrite
+    // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
+    for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
+        if (seq_pos_max_rm[s] == -1) {
+            continue;
+        }
+
+        if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
+            LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
+                    __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
+
+            seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
+        }
+    }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
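The purge loop exists because find_slot may now recycle cells holding positions from the middle of a sequence's [pos_min, pos_max] range, which would leave a hole in the cached positions. Removing everything at or below the largest overwritten position restores contiguity. A toy model of this bookkeeping, using a std::set of positions in place of the real cell array (toy_seq_cache and purge_overwritten are illustrative names, not llama.cpp API):

#include <cassert>
#include <cstdint>
#include <set>

using llama_pos = int32_t;

// toy stand-in for one sequence's cached positions
struct toy_seq_cache {
    std::set<llama_pos> pos;

    llama_pos pos_min() const { return pos.empty() ? -1 : *pos.begin(); }

    // remove all cached positions in [p0, p1)
    void seq_rm(llama_pos p0, llama_pos p1) {
        pos.erase(pos.lower_bound(p0), pos.lower_bound(p1));
    }
};

// after a ubatch overwrote cells holding positions up to max_rm, purge every
// position <= max_rm so that [pos_min, pos_max] stays gap-free
static void purge_overwritten(toy_seq_cache & seq, llama_pos max_rm) {
    if (max_rm == -1) {
        return; // nothing was overwritten for this sequence
    }
    if (seq.pos_min() != -1 && seq.pos_min() <= max_rm) {
        seq.seq_rm(seq.pos_min(), max_rm + 1);
    }
}

int main() {
    // sequence has positions 10..20 cached; find_slot recycles the cells
    // holding 14 and 15 (both SWA-masked), leaving a hole
    toy_seq_cache seq;
    for (llama_pos p = 10; p <= 20; ++p) { seq.pos.insert(p); }
    seq.pos.erase(14);
    seq.pos.erase(15);

    // seq_pos_max_rm for this sequence is 15 -> purge [10, 16)
    purge_overwritten(seq, 15);

    assert(*seq.pos.begin() == 16); // cache is contiguous again: 16..20
    return 0;
}

In the actual code the same decision is driven per sequence by seq_pos_max_rm[s] and executed via seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1), as shown in the hunk above.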