diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index dd909df59..2f8108043 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -457,7 +457,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { return ubatch_add(idxs, idxs.size(), false); } -llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { +llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) { std::vector cur_seq_set; llama_seq_id last_seq_id = -1; @@ -479,7 +479,9 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { } // accept only increasing sequence ids - add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1); + if (sequential) { + add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1); + } if (add) { cur_seq_set.push_back(seq_set[i]); diff --git a/src/llama-batch.h b/src/llama-batch.h index d2c537618..edda1505d 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -69,7 +69,8 @@ public: llama_ubatch split_simple(uint32_t n_ubatch); // make ubatches of equal-length sequences sets - llama_ubatch split_equal(uint32_t n_ubatch); + // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids + llama_ubatch split_equal(uint32_t n_ubatch, bool sequential); // sequence-set-wise split - each ubatch contains a single sequence-set llama_ubatch split_seq(uint32_t n_ubatch); diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index f0aac929c..6220640a1 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -102,7 +102,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all // first try simple split do { if (n_seq_virt > 1) { - // requires equal splits + // requires equal splits, so we skip the simple split break; } @@ -141,7 +141,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch); + auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1); if (ubatch.n_tokens == 0) { break; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index c9d359f65..2080925df 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -418,7 +418,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( std::vector ubatches; while (true) { - auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch); + auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); if (ubatch.n_tokens == 0) { break; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index e8d3b581a..85470f755 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -71,7 +71,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } if (ubatch.n_tokens == 0) { diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 1b1e95d56..55f78eb20 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -372,7 +372,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch); + ubatch = balloc.split_equal(n_ubatch, false); } if (ubatch.n_tokens == 0) {