kv-cache : fix split_equal handling in unified implementation (#14130)

ggml-ci
Georgi Gerganov
2025-06-12 10:02:15 +03:00
committed by GitHub
parent a20b2b05bc
commit 9596506965
3 changed files with 128 additions and 71 deletions

View File

@@ -877,6 +877,8 @@ int llama_context::encode(llama_batch & inp_batch) {
         memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));

         // remember the sequence ids used during the encoding - needed for cross attention later
+        // TODO: the sequence indexing here is likely not correct in the general case
+        //       probably works only for split_simple
         cross.seq_ids_enc.resize(n_tokens);
         for (int32_t i = 0; i < n_tokens; i++) {
             cross.seq_ids_enc[i].clear();

View File

@@ -98,33 +98,66 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
     GGML_UNUSED(embd_pooled);

-    // TODO: if we fail with split_simple, we should attempt different splitting strategies
-    //       but to do that properly, we first have to refactor the batches to be more flexible
-
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
-
-    std::vector<llama_ubatch> ubatches;
-
-    while (sbatch.n_tokens > 0) {
-        auto ubatch = sbatch.split_simple(n_ubatch);
-
-        ubatches.push_back(ubatch);
-    }
-
-    auto heads_base = kv_base->prepare(ubatches);
-    if (heads_base.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    auto heads_swa = kv_swa->prepare(ubatches);
-    if (heads_swa.empty()) {
-        return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    assert(heads_base.size() == heads_swa.size());
-
-    return std::make_unique<llama_kv_cache_unified_iswa_state>(
-            this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    // first try simple split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+        std::vector<llama_ubatch> ubatches;
+
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_simple(n_ubatch);
+
+            ubatches.push_back(ubatch);
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // if it fails, try equal split
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+
+        std::vector<llama_ubatch> ubatches;
+
+        while (sbatch.n_tokens > 0) {
+            auto ubatch = sbatch.split_equal(n_ubatch);
+
+            ubatches.push_back(ubatch);
+        }
+
+        auto heads_base = kv_base->prepare(ubatches);
+        if (heads_base.empty()) {
+            break;
+        }
+
+        auto heads_swa = kv_swa->prepare(ubatches);
+        if (heads_swa.empty()) {
+            break;
+        }
+
+        assert(heads_base.size() == heads_swa.size());
+
+        return std::make_unique<llama_kv_cache_unified_iswa_state>(
+                this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+    } while (false);
+
+    // TODO: if we fail again, we should attempt different splitting strategies
+    // but to do that properly, we first have to refactor the batches to be more flexible
+
+    return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }

 llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
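
The new control flow attempts split_simple first and only falls back to split_equal before reporting LLAMA_MEMORY_STATUS_FAILED_PREPARE. Below is a minimal, self-contained sketch of the do { ... } while (false) fallback idiom used here; try_simple_split and try_equal_split are hypothetical stand-ins for the sbatch/prepare sequence, not llama.cpp API:

// sketch of the fallback pattern in init_batch() above (illustrative only)
#include <cstdio>
#include <optional>

static std::optional<int> try_simple_split() { return std::nullopt; } // assume it fails
static std::optional<int> try_equal_split()  { return 42; }           // assume it fits

static std::optional<int> init_batch_sketch() {
    // first try simple split
    do {
        auto heads = try_simple_split();
        if (!heads) {
            break; // does not fit -> fall through to the next strategy
        }
        return heads;
    } while (false);

    // if it fails, try equal split
    do {
        auto heads = try_equal_split();
        if (!heads) {
            break;
        }
        return heads;
    } while (false);

    return std::nullopt; // all strategies failed -> FAILED_PREPARE
}

int main() {
    std::printf("prepared: %d\n", init_batch_sketch().value_or(-1)); // prints "prepared: 42"
}

The single-iteration do/while gives each attempt its own scope, and break acts as a structured jump to the next strategy.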

View File

@@ -314,20 +314,24 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         bool logits_all) {
     GGML_UNUSED(embd_pooled);

-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
-
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
-    }
-
-    auto heads = prepare(ubatches);
-    if (heads.empty()) {
-        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
-
-    return std::make_unique<llama_kv_cache_unified_state>(
-        this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
+
+        auto heads = prepare(ubatches);
+        if (heads.empty()) {
+            break;
+        }
+
+        return std::make_unique<llama_kv_cache_unified_state>(
+            this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }

 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
@@ -521,7 +525,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     }

     if (debug > 0) {
-        LLAMA_LOG_CONT("\n");
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);

         if ((debug == 2 && n_swa > 0) || debug > 2) {
@@ -530,7 +533,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 if (cells.is_empty(i)) {
                     ss += '.';
                 } else {
-                    ss += std::to_string(cells.seq_get(i));
+                    assert(cells.seq_count(i) >= 1);
+
+                    if (cells.seq_count(i) == 1) {
+                        ss += std::to_string(cells.seq_get(i));
+                    } else {
+                        ss += 'M';
+                    }
                 }
                 if (i%256 == 255) {
                     ss += " *";
@@ -636,6 +645,12 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }

 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s:   n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+    }
+
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
@@ -643,22 +658,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }

-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
-
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos pos = cells.pos_get(head_cur + i);
-
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
-
-            cells.rm(head_cur + i);
-        }
-
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
-
-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;
+
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);
+
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos pos = cells.pos_get(head_cur + idx);
+
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(head_cur + idx);
+            }
+
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
+
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
         }
     }
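
The rewritten loop walks the ubatch sequence-major, which is the layout split_equal produces: every sequence contributes n_seq_tokens tokens, stored contiguously, so per-sequence metadata is indexed by s while per-token data is indexed by the flattened idx. A toy illustration of the (s, j) -> idx mapping (sizes are made up):

// illustration of the flattened token index used in apply_ubatch() above
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_seqs       = 3; // hypothetical ubatch shape
    const uint32_t n_seq_tokens = 4;

    for (uint32_t s = 0; s < n_seqs; ++s) {
        for (uint32_t j = 0; j < n_seq_tokens; ++j) {
            const uint32_t idx = s*n_seq_tokens + j; // same mapping as apply_ubatch()
            std::printf("seq %u, token %u -> flat index %u\n", s, j, idx);
        }
    }
    return 0;
}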
@@ -677,7 +696,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
-
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
@@ -774,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }

 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const int64_t n_tokens = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs = ubatch->n_seqs;
+    const uint32_t n_tokens = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs = ubatch->n_seqs;

     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;

-    const auto n_kv = dst->ne[0];
+    const int64_t n_kv = dst->ne[0];

     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -795,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     //   xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];

-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;
+
+                const llama_pos p1 = ubatch->pos[idx];

                 for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
@@ -830,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                         f = -INFINITY;
                     }

-                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
                 }
             }
         }

         // mask padded tokens
         if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
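
The mask write now addresses each token row by the same flattened index used in apply_ubatch(), so the layout works for both split modes. A sketch of the buffer addressing under assumed sizes (the GGML_KQ_MASK_PAD padding is approximated by a hand-picked n_pad; none of these values come from the commit):

// sketch of the KQ mask layout: one row of n_kv floats per (padded) token
#include <cmath>
#include <cstdint>
#include <vector>

int main() {
    const int64_t  n_kv     = 8; // KV cells visible to the ubatch (assumed)
    const uint32_t n_tokens = 5; // tokens in the ubatch (assumed)
    const uint32_t n_pad    = 8; // n_tokens padded up to GGML_KQ_MASK_PAD (assumed)

    // with h == 0 the h*(n_kv*n_tokens) term vanishes, leaving row*n_kv + column
    std::vector<float> data(n_kv*n_pad, 0.0f);

    const uint32_t idx = 3; // flat token index, as computed above
    const int64_t  i   = 2; // KV cell index
    data[idx*n_kv + i] = -INFINITY; // mask KV cell i for token idx

    // rows for padded tokens are fully masked
    for (uint32_t j = n_tokens; j < n_pad; ++j) {
        for (int64_t k = 0; k < n_kv; ++k) {
            data[j*n_kv + k] = -INFINITY;
        }
    }
    return 0;
}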
@@ -1490,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         seq_rm(dest_seq_id, -1, -1);

         llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);

-        batch.n_tokens = cell_count;
+        ubatch.n_tokens = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs = 1;

         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1512,18 +1534,18 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 io.read_to(&seq_id, sizeof(seq_id));
             }

-            batch.pos[i] = pos;
-            batch.n_seq_id[i] = n_seq_id;
-            batch.seq_id[i] = &dest_seq_id;
+            ubatch.pos[i] = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i] = &dest_seq_id;
         }

-        const auto head_cur = find_slot(batch);
+        const auto head_cur = find_slot(ubatch);
         if (head_cur < 0) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }

-        apply_ubatch(head_cur, batch);
+        apply_ubatch(head_cur, ubatch);

         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
@@ -1531,8 +1553,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur) == batch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
         GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
     } else {
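
Because apply_ubatch() now iterates (n_seqs, n_seq_tokens) instead of a flat n_tokens, the restore path has to describe the loaded cells as a single-sequence ubatch, hence the new n_seq_tokens = cell_count, n_seqs = 1 initialization. A tiny sketch of the invariant this preserves (values assumed for illustration):

// sketch of the restored single-sequence ubatch shape (assumed values)
#include <cassert>
#include <cstdint>

int main() {
    const uint32_t cell_count = 7; // hypothetical number of cells read from the session file

    // the restored block is presented as one sequence of cell_count tokens
    const uint32_t n_seqs       = 1;
    const uint32_t n_seq_tokens = cell_count;
    const uint32_t n_tokens     = cell_count;

    // apply_ubatch() walks s in [0, n_seqs) and j in [0, n_seq_tokens),
    // so the flattened index s*n_seq_tokens + j covers exactly n_tokens cells
    assert(n_tokens == n_seqs * n_seq_tokens);

    return 0;
}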