cont : kv-cells cp/set for non-cont slots

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-21 15:26:01 +03:00
parent f875d6cb72
commit 39d0b1e8df
3 changed files with 64 additions and 32 deletions

View File

@ -388,7 +388,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
struct state { struct state {
uint32_t head_old; // old position of the head, before placing the ubatch uint32_t head_old; // old position of the head, before placing the ubatch
uint32_t head_new; // new position of the head, after placing the ubatch
slot_info sinfo; // slot info for the ubatch
llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
}; };
@ -409,13 +410,8 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
// remeber the position that we found // remeber the position that we found
res.push_back(sinfo_new); res.push_back(sinfo_new);
// TODO: temporary
if (supports_set_rows) {
GGML_ASSERT(sinfo_new.is_cont());
}
// store the old state of the cells in the recovery stack // store the old state of the cells in the recovery stack
states.push_back({head, sinfo_new.head(), cells.cp(sinfo_new.head(), ubatch.n_tokens)}); states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
// now emplace the ubatch // now emplace the ubatch
apply_ubatch(sinfo_new, ubatch); apply_ubatch(sinfo_new, ubatch);
@ -423,7 +419,7 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
// iterate backwards and restore the cells to their original state // iterate backwards and restore the cells to their original state
for (auto it = states.rbegin(); it != states.rend(); ++it) { for (auto it = states.rbegin(); it != states.rend(); ++it) {
cells.set(it->head_new, it->cells); cells.set(it->sinfo.idxs, it->cells);
head = it->head_old; head = it->head_old;
} }

View File

@ -49,20 +49,6 @@ public:
return idxs.empty(); return idxs.empty();
} }
// TODO: tmp until kv cells support non-cont slots
bool is_cont() const {
bool res = true;
for (uint32_t i = 1; i < idxs.size(); ++i) {
if (idxs[i] != idxs[i - 1] + 1) {
res = false;
break;
}
}
return res;
}
// TODO: implement // TODO: implement
//std::vector<idx_vec_t> seq_idxs; //std::vector<idx_vec_t> seq_idxs;
}; };

View File

@ -105,10 +105,29 @@ public:
res.resize(n); res.resize(n);
for (uint32_t j = 0; j < n; ++j) { for (uint32_t j = 0; j < n; ++j) {
res.pos[j] = pos[i + j]; const auto idx = i + j;
res.seq[j] = seq[i + j];
assert(shift[i + j] == 0); res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
}
return res;
}
llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
llama_kv_cells_unified res;
res.resize(idxs.size());
for (uint32_t j = 0; j < idxs.size(); ++j) {
const auto idx = idxs[j];
res.pos[j] = pos[idx];
res.seq[j] = seq[idx];
assert(shift[idx] == 0);
} }
return res; return res;
@ -119,26 +138,57 @@ public:
assert(i + other.pos.size() <= pos.size()); assert(i + other.pos.size() <= pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) { for (uint32_t j = 0; j < other.pos.size(); ++j) {
if (pos[i + j] == -1 && other.pos[j] != -1) { const auto idx = i + j;
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(i + j); used.insert(i + j);
} }
if (pos[i + j] != -1 && other.pos[j] == -1) { if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(i + j); used.erase(i + j);
} }
if (pos[i + j] != -1) { if (pos[idx] != -1) {
seq_pos_rm(i + j); seq_pos_rm(i + j);
} }
pos[i + j] = other.pos[j]; pos[idx] = other.pos[j];
seq[i + j] = other.seq[j]; seq[idx] = other.seq[j];
if (pos[i + j] != -1) { if (pos[idx] != -1) {
seq_pos_add(i + j); seq_pos_add(i + j);
} }
assert(shift[i + j] == 0); assert(shift[idx] == 0);
}
}
void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
assert(idxs.size() == other.pos.size());
for (uint32_t j = 0; j < other.pos.size(); ++j) {
const auto idx = idxs[j];
if (pos[idx] == -1 && other.pos[j] != -1) {
used.insert(idx);
}
if (pos[idx] != -1 && other.pos[j] == -1) {
used.erase(idx);
}
if (pos[idx] != -1) {
seq_pos_rm(idx);
}
pos[idx] = other.pos[j];
seq[idx] = other.seq[j];
if (pos[idx] != -1) {
seq_pos_add(idx);
}
assert(shift[idx] == 0);
} }
} }