kv-cache : remove const_cast when setting inputs for s_copy

And also fix multi-user inference for recurrent models
by using cell_id instead of i as the kv cell index
when populating s_copy.
This commit is contained in:
Francis Couture-Harpin
2025-05-01 22:18:57 -04:00
parent 791998b42d
commit 94c3d53043
3 changed files with 14 additions and 16 deletions

View File

@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self->head;
//////////////////////////////////////////////
// TODO: this should not mutate the KV cache !
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
const llama_kv_cell & kv_cell = kv_self->cells[cell_id];
int32_t src = kv_cell.src0;
// prevent out-of-bound sources
if (kv_cell.src < 0) {
if (src < 0) {
GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source
kv_cell.src = kv_self->rs_z;
src = kv_self->rs_z;
}
if ((uint32_t) kv_cell.src >= kv_self->size) {
if ((uint32_t) src >= kv_self->size) {
// ignore out-of-bound sources
kv_cell.src = cell_id;
src = cell_id;
}
data[i] = kv_cell.src;
// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
data[i] = src;
}
}
}

View File

@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot(
// Find first to-be-cleared cell
rs_z = -1;
for (int i = min; i <= max; ++i) {
if (cells[i].src == -1) {
if (rs_z < 0 && cells[i].src == -1) {
rs_z = i;
break;
}
// Stage the source ids for all used cells to allow correct seq_* behavior
// and still make these values available when setting the inputs
cells[i].src0 = cells[i].src;
cells[i].src = i;
}
// allow getting the range of used cells, from head to head + n

View File

@ -47,6 +47,7 @@ struct llama_kv_cell {
llama_pos pos = -1;
llama_pos delta = 0;
int32_t src = -1; // used by recurrent state models to copy states
int32_t src0 = -1; // like src, but used when setting the inputs (allowing to copy once)
int32_t tail = -1;
std::set<llama_seq_id> seq_id;