diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0f77f98b2..8d2fceb17 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { for (uint32_t i = 0; i < n_kv; ++i) { const uint32_t cell_id = i + kv_self->head; - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; + const llama_kv_cell & kv_cell = kv_self->cells[cell_id]; + + int32_t src = kv_cell.src0; // prevent out-of-bound sources - if (kv_cell.src < 0) { + if (src < 0) { GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source - kv_cell.src = kv_self->rs_z; + src = kv_self->rs_z; } - if ((uint32_t) kv_cell.src >= kv_self->size) { + if ((uint32_t) src >= kv_self->size) { // ignore out-of-bound sources - kv_cell.src = cell_id; + src = cell_id; } - data[i] = kv_cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } + data[i] = src; } } } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 108c07731..743b30bad 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot( // Find first to-be-cleared cell rs_z = -1; for (int i = min; i <= max; ++i) { - if (cells[i].src == -1) { + if (rs_z < 0 && cells[i].src == -1) { rs_z = i; - break; } + // Stage the source ids for all used cells to allow correct seq_* behavior + // and still make these values available when setting the inputs + cells[i].src0 = cells[i].src; + cells[i].src = i; } // allow getting the range of used cells, from head to head + n diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 7939bc6b8..6b115e8f7 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -47,6 +47,7 @@ struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; int32_t src = -1; // used by recurrent state models to copy states + int32_t src0 = -1; // like src, but used when setting the inputs (allowing to copy once) int32_t tail = -1; std::set seq_id;