From c79184d2d192489e3c918bab8ed717d22f8c02bd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 4 Jul 2025 09:04:59 +0300 Subject: [PATCH] batch : add n_used count (#14512) ggml-ci --- src/llama-batch.cpp | 9 +++++++++ src/llama-batch.h | 3 +++ src/llama-kv-cache-unified-iswa.cpp | 10 ++++++++++ src/llama-kv-cache-unified.cpp | 5 +++++ src/llama-memory-hybrid.cpp | 5 +++++ src/llama-memory-recurrent.cpp | 3 ++- 6 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 91b1d6078..8d84c6805 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -405,6 +405,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const { return n_outputs; } +uint32_t llama_batch_allocr::get_n_used() const { + return n_used; +} + std::vector & llama_batch_allocr::get_out_ids() { return out_ids; } @@ -420,6 +424,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const { void llama_batch_allocr::split_reset() { out_ids.clear(); + n_used = 0; + used.clear(); used.resize(get_n_tokens(), false); @@ -444,6 +450,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; ++cur_idx; @@ -529,6 +536,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) { idxs_per_seq[s].push_back(idx); used[idx] = true; + ++n_used; ++cur_idx[s]; } @@ -570,6 +578,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) { idxs.push_back(cur_idx); used[cur_idx] = true; + ++n_used; if (idxs.size() >= n_ubatch) { break; diff --git a/src/llama-batch.h b/src/llama-batch.h index d2c537618..edff8cdd6 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -54,6 +54,7 @@ public: uint32_t get_n_tokens() const; uint32_t get_n_outputs() const; + uint32_t get_n_used() const; // the array of output indices in the order they were encountered during the ubatch splitting std::vector & get_out_ids(); @@ -125,6 +126,8 @@ private: // batch indices of the output std::vector out_ids; + uint32_t n_used; + // used[i] indicates if token i has already been used in a previous ubatch std::vector used; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index ee202cc71..ab4c41c78 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -113,6 +113,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos_base = kv_base->prepare(ubatches); if (sinfos_base.empty()) { break; @@ -144,6 +149,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos_base = kv_base->prepare(ubatches); if (sinfos_base.empty()) { break; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index ff2207985..d3129cc53 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -360,6 +360,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + auto sinfos = prepare(ubatches); if (sinfos.empty()) { break; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 03d974d85..908e927fa 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba ubatches.push_back(std::move(ubatch)); // NOLINT } + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + // prepare the recurrent batches first if (!mem_recr->prepare(ubatches)) { // TODO: will the recurrent cache be in an undefined context at this point? diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 6ed84057c..ca0c8dd56 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -377,7 +377,8 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & ubatch = balloc.split_equal(n_ubatch); } - if (ubatch.n_tokens == 0) { + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split break; }