diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 2f8108043..8c69e2f24 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
 
                 // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
+
+                has_cpl = true;
             }
         }
     }
@@ -403,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+    return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
     return out_ids;
 }
@@ -418,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
 
+    n_used = 0;
+
     used.clear();
     used.resize(get_n_tokens(), false);
 
@@ -442,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         ++cur_idx;
 
@@ -458,6 +467,12 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 }
 
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+    if (sequential && has_cpl) {
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+        return {};
+    }
+
     std::vector<seq_set_t> cur_seq_set;
 
     llama_seq_id last_seq_id = -1;
@@ -536,6 +551,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential)
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
+            ++n_used;
 
             ++cur_idx[s];
         }
@@ -577,6 +593,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
         if (idxs.size() >= n_ubatch) {
             break;
         }
diff --git a/src/llama-batch.h b/src/llama-batch.h
index edda1505d..3420803ff 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -54,6 +54,7 @@ public:
 
     uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
+    uint32_t get_n_used()    const;
 
     // the array of output indices in the order they were encountered during the ubatch splitting
    std::vector<int32_t> & get_out_ids();
@@ -113,6 +114,9 @@ private:
     using pos_set_t = std::set<llama_pos>;
     using seq_cpl_t = std::vector<bool>;
 
+    // helper flag to quickly determine if there are any coupled sequences in the batch
+    bool has_cpl;
+
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
@@ -126,6 +130,8 @@ private:
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    uint32_t n_used;
+
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp
index 6220640a1..9518412d5 100644
--- a/src/llama-kv-cache-unified-iswa.cpp
+++ b/src/llama-kv-cache-unified-iswa.cpp
@@ -119,6 +119,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos_base = kv_base->prepare(ubatches);
         if (sinfos_base.empty()) {
             break;
@@ -150,6 +155,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos_base = kv_base->prepare(ubatches);
         if (sinfos_base.empty()) {
             break;
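
The batch-allocator half of the patch does two things: init() raises the new has_cpl flag in the same loop that records the seq_cpl[s1][s0] coupling (i.e. whenever a token is shared between two sequence ids), and split_equal() now refuses a sequential split whenever that flag is set. The sketch below is a minimal illustration of that idea, not the llama.cpp API; toy_batch and detect_coupling() are hypothetical names:

    #include <cstdio>
    #include <vector>

    // toy_batch is a hypothetical stand-in for the real input batch:
    // seq_ids[i] lists every sequence id that token i belongs to.
    struct toy_batch {
        std::vector<std::vector<int>> seq_ids;
    };

    // One pass over the batch is enough to set the coupling flag: a token
    // that belongs to two or more sequences couples those sequences.
    static bool detect_coupling(const toy_batch & batch) {
        for (const auto & ids : batch.seq_ids) {
            if (ids.size() > 1) {
                return true;
            }
        }
        return false;
    }

    int main() {
        toy_batch batch;
        batch.seq_ids = {{0}, {0, 1}, {1}}; // token 1 is shared by sequences 0 and 1

        const bool has_cpl = detect_coupling(batch);

        // mirrors the new guard in split_equal(): sequential splits are not
        // supported for coupled sequences, so the batch is rejected up front
        if (has_cpl) {
            std::printf("sequential split rejected: batch has coupled sequences\n");
            return 1;
        }

        std::printf("sequential split allowed\n");
        return 0;
    }

Because the flag is derived during the validation pass that init() already performs, the guard in split_equal() costs a single boolean check.
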
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 2080925df..a11f23000 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -427,6 +427,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         auto sinfos = prepare(ubatches);
         if (sinfos.empty()) {
             break;
diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp
index 85470f755..6c1304856 100644
--- a/src/llama-memory-hybrid.cpp
+++ b/src/llama-memory-hybrid.cpp
@@ -81,6 +81,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
             ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
         // prepare the recurrent batches first
         if (!mem_recr->prepare(ubatches)) {
             // TODO: will the recurrent cache be in an undefined context at this point?
diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 55f78eb20..815e57868 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -365,26 +365,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     std::vector<llama_ubatch> ubatches;
 
-    while (true) {
-        llama_ubatch ubatch;
+    do {
+        balloc.split_reset();
 
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch, false);
+        while (true) {
+            llama_ubatch ubatch;
+
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch, false);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
-        if (ubatch.n_tokens == 0) {
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
             break;
         }
 
-        ubatches.push_back(std::move(ubatch)); // NOLINT
-    }
-
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        if (!prepare(ubatches)) {
+            return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
+    } while (false);
 
     return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
 }
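
All four memory implementations now follow the same recipe: run the split loop, then treat get_n_used() < get_n_tokens() as "some tokens could not be placed in any ubatch" and bail out instead of silently dropping them. Below is a self-contained sketch of that bookkeeping, using a hypothetical toy_allocr in place of the real llama_batch_allocr:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // toy_allocr is a hypothetical stand-in for llama_batch_allocr, reduced
    // to the bookkeeping this patch adds: split_reset() zeroes the counter,
    // each split marks consumed tokens and bumps n_used, and get_n_used()
    // exposes the total.
    class toy_allocr {
    public:
        explicit toy_allocr(uint32_t n_tokens) : used(n_tokens, false) {}

        uint32_t get_n_tokens() const { return (uint32_t) used.size(); }
        uint32_t get_n_used()   const { return n_used; }

        void split_reset() {
            n_used = 0;
            std::fill(used.begin(), used.end(), false);
        }

        // consume up to n_ubatch unused tokens; returns how many were taken
        uint32_t split_simple(uint32_t n_ubatch) {
            uint32_t taken = 0;
            for (size_t i = 0; i < used.size() && taken < n_ubatch; ++i) {
                if (!used[i]) {
                    used[i] = true;
                    ++n_used; // the counter the new get_n_used() accessor exposes
                    ++taken;
                }
            }
            return taken;
        }

    private:
        uint32_t n_used = 0;
        std::vector<bool> used;
    };

    int main() {
        toy_allocr balloc(10);

        // mirrors the do { ... } while (false) pattern from the patch: a
        // failed check breaks out to the common exit after the block
        do {
            balloc.split_reset();

            while (balloc.split_simple(4) > 0) {
                // each iteration would normally produce one ubatch
            }

            if (balloc.get_n_used() < balloc.get_n_tokens()) {
                std::printf("failed to find a suitable split\n");
                break;
            }

            std::printf("all %u tokens were placed into ubatches\n", balloc.get_n_tokens());
            return 0;
        } while (false);

        return 1; // failure path
    }

The do/while (false) wrapper, introduced explicitly in the llama-memory-recurrent.cpp hunk, keeps the success return inside the block while letting each failed check break to a single shared exit.
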