mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-12 22:23:13 +00:00
kv-cache : remove const_cast when setting inputs for s_copy
And also fix multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy.
This commit is contained in:
@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
|||||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||||
const uint32_t cell_id = i + kv_self->head;
|
const uint32_t cell_id = i + kv_self->head;
|
||||||
|
|
||||||
//////////////////////////////////////////////
|
const llama_kv_cell & kv_cell = kv_self->cells[cell_id];
|
||||||
// TODO: this should not mutate the KV cache !
|
|
||||||
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
|
int32_t src = kv_cell.src0;
|
||||||
|
|
||||||
// prevent out-of-bound sources
|
// prevent out-of-bound sources
|
||||||
if (kv_cell.src < 0) {
|
if (src < 0) {
|
||||||
GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source
|
GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source
|
||||||
kv_cell.src = kv_self->rs_z;
|
src = kv_self->rs_z;
|
||||||
}
|
}
|
||||||
if ((uint32_t) kv_cell.src >= kv_self->size) {
|
if ((uint32_t) src >= kv_self->size) {
|
||||||
// ignore out-of-bound sources
|
// ignore out-of-bound sources
|
||||||
kv_cell.src = cell_id;
|
src = cell_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
data[i] = kv_cell.src;
|
data[i] = src;
|
||||||
|
|
||||||
// TODO: do not mutate the KV cache
|
|
||||||
// ensure copy only happens once
|
|
||||||
if (kv_cell.src != (int32_t) cell_id) {
|
|
||||||
kv_cell.src = cell_id;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot(
|
|||||||
// Find first to-be-cleared cell
|
// Find first to-be-cleared cell
|
||||||
rs_z = -1;
|
rs_z = -1;
|
||||||
for (int i = min; i <= max; ++i) {
|
for (int i = min; i <= max; ++i) {
|
||||||
if (cells[i].src == -1) {
|
if (rs_z < 0 && cells[i].src == -1) {
|
||||||
rs_z = i;
|
rs_z = i;
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
// Stage the source ids for all used cells to allow correct seq_* behavior
|
||||||
|
// and still make these values available when setting the inputs
|
||||||
|
cells[i].src0 = cells[i].src;
|
||||||
|
cells[i].src = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
// allow getting the range of used cells, from head to head + n
|
// allow getting the range of used cells, from head to head + n
|
||||||
|
@ -47,6 +47,7 @@ struct llama_kv_cell {
|
|||||||
llama_pos pos = -1;
|
llama_pos pos = -1;
|
||||||
llama_pos delta = 0;
|
llama_pos delta = 0;
|
||||||
int32_t src = -1; // used by recurrent state models to copy states
|
int32_t src = -1; // used by recurrent state models to copy states
|
||||||
|
int32_t src0 = -1; // like src, but used when setting the inputs (allowing to copy once)
|
||||||
int32_t tail = -1;
|
int32_t tail = -1;
|
||||||
|
|
||||||
std::set<llama_seq_id> seq_id;
|
std::set<llama_seq_id> seq_id;
|
||||||
|
Reference in New Issue
Block a user