batch : add n_used count (#14512)

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-07-04 09:04:59 +03:00
committed by GitHub
parent 499a8f5a78
commit c79184d2d1
6 changed files with 34 additions and 1 deletions

View File

@ -405,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}
uint32_t llama_batch_allocr::get_n_used() const {
return n_used;
}
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
return out_ids;
}
@ -420,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
void llama_batch_allocr::split_reset() {
out_ids.clear();
n_used = 0;
used.clear();
used.resize(get_n_tokens(), false);
@ -444,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
++cur_idx;
@ -529,6 +536,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
idxs_per_seq[s].push_back(idx);
used[idx] = true;
++n_used;
++cur_idx[s];
}
@ -570,6 +578,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
idxs.push_back(cur_idx);
used[cur_idx] = true;
++n_used;
if (idxs.size() >= n_ubatch) {
break;

View File

@ -54,6 +54,7 @@ public:
uint32_t get_n_tokens() const;
uint32_t get_n_outputs() const;
uint32_t get_n_used() const;
// the array of output indices in the order they were encountered during the ubatch splitting
std::vector<int32_t> & get_out_ids();
@ -125,6 +126,8 @@ private:
// batch indices of the output
std::vector<int32_t> out_ids;
uint32_t n_used;
// used[i] indicates if token i has already been used in a previous ubatch
std::vector<bool> used;

View File

@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;

View File

@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
auto sinfos = prepare(ubatches);
if (sinfos.empty()) {
break;

View File

@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
// prepare the recurrent batches first
if (!mem_recr->prepare(ubatches)) {
// TODO: will the recurrent cache be in an undefined context at this point?

View File

@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
ubatch = balloc.split_equal(n_ubatch);
}
if (ubatch.n_tokens == 0) {
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}