mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-07-10 13:30:27 +00:00
kv-cache : remove const_cast when setting inputs for s_copy
Also fix multi-user inference for recurrent models by using cell_id instead of i as the KV cell index when populating s_copy.
This commit is contained in:
@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
const uint32_t cell_id = i + kv_self->head;
|
||||
|
||||
//////////////////////////////////////////////
|
||||
// TODO: this should not mutate the KV cache !
|
||||
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
|
||||
const llama_kv_cell & kv_cell = kv_self->cells[cell_id];
|
||||
|
||||
int32_t src = kv_cell.src0;
|
||||
|
||||
// prevent out-of-bound sources
|
||||
if (kv_cell.src < 0) {
|
||||
if (src < 0) {
|
||||
GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source
|
||||
kv_cell.src = kv_self->rs_z;
|
||||
src = kv_self->rs_z;
|
||||
}
|
||||
if ((uint32_t) kv_cell.src >= kv_self->size) {
|
||||
if ((uint32_t) src >= kv_self->size) {
|
||||
// ignore out-of-bound sources
|
||||
kv_cell.src = cell_id;
|
||||
src = cell_id;
|
||||
}
|
||||
|
||||
data[i] = kv_cell.src;
|
||||
|
||||
// TODO: do not mutate the KV cache
|
||||
// ensure copy only happens once
|
||||
if (kv_cell.src != (int32_t) cell_id) {
|
||||
kv_cell.src = cell_id;
|
||||
}
|
||||
data[i] = src;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot(
|
||||
// Find first to-be-cleared cell
|
||||
rs_z = -1;
|
||||
for (int i = min; i <= max; ++i) {
|
||||
if (cells[i].src == -1) {
|
||||
if (rs_z < 0 && cells[i].src == -1) {
|
||||
rs_z = i;
|
||||
break;
|
||||
}
|
||||
// Stage the source ids for all used cells to allow correct seq_* behavior
|
||||
// and still make these values available when setting the inputs
|
||||
cells[i].src0 = cells[i].src;
|
||||
cells[i].src = i;
|
||||
}
|
||||
|
||||
// allow getting the range of used cells, from head to head + n
|
||||
|
@ -47,6 +47,7 @@ struct llama_kv_cell {
|
||||
llama_pos pos = -1;
|
||||
llama_pos delta = 0;
|
||||
int32_t src = -1; // used by recurrent state models to copy states
|
||||
int32_t src0 = -1; // like src, but used when setting the inputs (allowing to copy once)
|
||||
int32_t tail = -1;
|
||||
|
||||
std::set<llama_seq_id> seq_id;
|
||||
|
Reference in New Issue
Block a user