batch : add optional for sequential equal split (#14511)

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-07-04 09:08:59 +03:00
committed by GitHub
parent 7b50f7c025
commit 67d1ef23c6
5 changed files with 26 additions and 5 deletions

View File

@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
// note: tracking the other way around is not necessary for now
//seq_cpl[s0][s1] = true;
has_cpl = true;
}
}
}
@ -466,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
return ubatch_add(idxs, idxs.size(), false);
}
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
if (sequential && has_cpl) {
LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
return {};
}
std::vector<seq_set_t> cur_seq_set;
llama_seq_id last_seq_id = -1;
// determine the non-overlapping sequence sets participating in this ubatch
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (used[i]) {
@ -485,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
}
}
// accept only increasing sequence ids
if (sequential) {
add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
}
if (add) {
cur_seq_set.push_back(seq_set[i]);
last_seq_id = batch.seq_id[i][0];
if (cur_seq_set.size() > n_ubatch) {
break;
}

View File

@ -70,7 +70,8 @@ public:
llama_ubatch split_simple(uint32_t n_ubatch);
// make ubatches of equal-length sequences sets
llama_ubatch split_equal(uint32_t n_ubatch);
// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
// sequence-set-wise split - each ubatch contains a single sequence-set
llama_ubatch split_seq(uint32_t n_ubatch);
@ -113,6 +114,9 @@ private:
using pos_set_t = std::set<llama_pos>;
using seq_cpl_t = std::vector<bool>;
// helper flag to quickly determine if there are any coupled sequences in the batch
bool has_cpl;
std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1

View File

@ -140,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = balloc.split_equal(n_ubatch);
auto ubatch = balloc.split_equal(n_ubatch, false);
if (ubatch.n_tokens == 0) {
break;

View File

@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch, false);
}
if (ubatch.n_tokens == 0) {

View File

@ -374,7 +374,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch);
ubatch = balloc.split_equal(n_ubatch, false);
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {