kv-cache : relax SWA masking condition (#14119)

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-11 16:48:45 +03:00
committed by GitHub
parent 2baf07727f
commit 89a184fa71

View File

@ -582,21 +582,15 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
continue;
}
// keep track of what the minimum sequence positions would be if we accept the ubatch
llama_seq_id seq_pos_min[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
seq_pos_min[s] = cells.seq_pos_min(s);
}
bool found = true;
for (uint32_t i = 0; i < n_tokens; i++) {
const llama_pos pos = ubatch.pos[i];
const llama_seq_id seq_id = ubatch.seq_id[i][0];
//const llama_pos pos = ubatch.pos[i];
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
// can we use this cell? either:
// - the cell is empty
// - the cell is occupied only by one sequence:
// - mask causally, if the sequence is the same as the one we are inserting
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
// - mask SWA, using current max pos for that sequence in the cache
// always insert in the cell with minimum pos
bool can_use = cells.is_empty(head_cur + i);
@ -604,21 +598,17 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
if (!can_use && cells.seq_count(head_cur + i) == 1) {
const llama_pos pos_cell = cells.pos_get(head_cur + i);
// causal mask
if (cells.seq_has(head_cur + i, seq_id)) {
can_use = pos_cell >= pos;
}
// (disabled) causal mask
// note: it's better to purge any "future" tokens beforehand
//if (cells.seq_has(head_cur + i, seq_id)) {
// can_use = pos_cell >= pos;
//}
if (!can_use) {
const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
// SWA mask
// note: we insert only in the cell with minimum pos in order to preserve the invariant that
// all positions between [pos_min, pos_max] for each sequence will be present in the cache
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
if (pos_cell == seq_pos_min[seq_id_cell] &&
is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
seq_pos_min[seq_id_cell]++;
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
can_use = true;
}
}
@ -646,8 +636,22 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
}
void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
// keep track of the max sequence position that we would overwrite with this ubatch
// for non-SWA cache, this would be always empty
llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
seq_pos_max_rm[s] = -1;
}
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
if (!cells.is_empty(head_cur + i)) {
assert(cells.seq_count(head_cur + i) == 1);
const llama_seq_id seq_id = cells.seq_get(head_cur + i);
const llama_pos pos = cells.pos_get(head_cur + i);
seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
cells.rm(head_cur + i);
}
@ -658,6 +662,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
}
}
// note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
// will be present in the cache. so we have to purge any position which is less than those we would overwrite
// ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
if (seq_pos_max_rm[s] == -1) {
continue;
}
if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
__func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
}
}
// move the head at the end of the slot
head = head_cur + ubatch.n_tokens;
}