diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 0de2f69d8..19bc16cd5 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -388,7 +388,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st struct state { uint32_t head_old; // old position of the head, before placing the ubatch - uint32_t head_new; // new position of the head, after placing the ubatch + + slot_info sinfo; // slot info for the ubatch llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch }; @@ -409,13 +410,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st // remeber the position that we found res.push_back(sinfo_new); - // TODO: temporary - if (supports_set_rows) { - GGML_ASSERT(sinfo_new.is_cont()); - } - // store the old state of the cells in the recovery stack - states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)}); + states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)}); // now emplace the ubatch apply_ubatch(sinfo_new, ubatch); @@ -423,7 +419,7 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st // iterate backwards and restore the cells to their original state for (auto it = states.rbegin(); it != states.rend(); ++it) { - cells.set(it->head_new, it->cells); + cells.set(it->sinfo.idxs, it->cells); head = it->head_old; } diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 698c2458a..6edf7b588 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -49,20 +49,6 @@ public: return idxs.empty(); } - // TODO: tmp until kv cells support non-cont slots - bool is_cont() const { - bool res = true; - - for (uint32_t i = 1; i < idxs.size(); ++i) { - if (idxs[i] != idxs[i - 1] + 1) { - res = false; - break; - } - } - - return res; - } - // TODO: implement //std::vector seq_idxs; }; diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h index c95d63594..eb52832c5 100644 --- a/src/llama-kv-cells.h +++ b/src/llama-kv-cells.h @@ -105,10 +105,29 @@ public: res.resize(n); for (uint32_t j = 0; j < n; ++j) { - res.pos[j] = pos[i + j]; - res.seq[j] = seq[i + j]; + const auto idx = i + j; - assert(shift[i + j] == 0); + res.pos[j] = pos[idx]; + res.seq[j] = seq[idx]; + + assert(shift[idx] == 0); + } + + return res; + } + + llama_kv_cells_unified cp(const std::vector & idxs) const { + llama_kv_cells_unified res; + + res.resize(idxs.size()); + + for (uint32_t j = 0; j < idxs.size(); ++j) { + const auto idx = idxs[j]; + + res.pos[j] = pos[idx]; + res.seq[j] = seq[idx]; + + assert(shift[idx] == 0); } return res; @@ -119,26 +138,57 @@ public: assert(i + other.pos.size() <= pos.size()); for (uint32_t j = 0; j < other.pos.size(); ++j) { - if (pos[i + j] == -1 && other.pos[j] != -1) { + const auto idx = i + j; + + if (pos[idx] == -1 && other.pos[j] != -1) { used.insert(i + j); } - if (pos[i + j] != -1 && other.pos[j] == -1) { + if (pos[idx] != -1 && other.pos[j] == -1) { used.erase(i + j); } - if (pos[i + j] != -1) { + if (pos[idx] != -1) { seq_pos_rm(i + j); } - pos[i + j] = other.pos[j]; - seq[i + j] = other.seq[j]; + pos[idx] = other.pos[j]; + seq[idx] = other.seq[j]; - if (pos[i + j] != -1) { + if (pos[idx] != -1) { seq_pos_add(i + j); } - assert(shift[i + j] == 0); + assert(shift[idx] == 0); + } + } + + void set(const std::vector & idxs, const llama_kv_cells_unified & other) { + assert(idxs.size() == other.pos.size()); + + for (uint32_t j = 0; j < other.pos.size(); ++j) { + const auto idx = idxs[j]; + + if (pos[idx] == -1 && other.pos[j] != -1) { + used.insert(idx); + } + + if (pos[idx] != -1 && other.pos[j] == -1) { + used.erase(idx); + } + + if (pos[idx] != -1) { + seq_pos_rm(idx); + } + + pos[idx] = other.pos[j]; + seq[idx] = other.seq[j]; + + if (pos[idx] != -1) { + seq_pos_add(idx); + } + + assert(shift[idx] == 0); } }