recurrent : call balloc split_reset() in init_batch() (#14414)

ggml-ci
Georgi Gerganov
2025-06-27 17:55:45 +03:00
committed by GitHub
parent 8d94219a4a
commit 43678060c1


@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    std::vector<llama_ubatch> ubatches;
+    do {
+        balloc.split_reset();
 
-    while (true) {
-        llama_ubatch ubatch;
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            llama_ubatch ubatch;
 
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch);
-        }
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
 
-        if (ubatch.n_tokens == 0) {
-            break;
-        }
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
 
-        ubatches.push_back(std::move(ubatch)); // NOLINT
-    }
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
 
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        if (!prepare(ubatches)) {
+            break;
+        }
 
-    return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_full() {
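
For context, the change above does two things: split_reset() now runs at the top of init_batch(), so each call consumes the batch allocator from the start rather than from wherever a previous pass left the split cursor, and the do { ... } while (false) wrapper routes every failure path to a single FAILED_PREPARE return. Below is a standalone sketch of that pattern under assumed, simplified types; batch_allocr, context, prepare, and init_batch here are illustrative stand-ins, not the actual llama.cpp API.

    #include <cstddef>
    #include <cstdio>
    #include <memory>
    #include <utility>
    #include <vector>

    enum class status { ok, failed_prepare };

    // hypothetical stand-in for llama_batch_allocr: hands out fixed-size chunks
    struct batch_allocr {
        std::vector<int> tokens;
        std::size_t cursor = 0;

        // rewind the split cursor so the batch can be consumed again
        void split_reset() { cursor = 0; }

        // return the next chunk of up to n tokens (empty chunk == done)
        std::vector<int> split_equal(std::size_t n) {
            std::vector<int> out;
            while (out.size() < n && cursor < tokens.size()) {
                out.push_back(tokens[cursor++]);
            }
            return out;
        }
    };

    // hypothetical stand-in for the memory context returned by init_batch()
    struct context {
        status st = status::ok;
        std::vector<std::vector<int>> ubatches;
    };

    // placeholder acceptance check standing in for prepare()
    static bool prepare(const std::vector<std::vector<int>> & ubatches) {
        return !ubatches.empty();
    }

    std::unique_ptr<context> init_batch(batch_allocr & balloc, std::size_t n_ubatch) {
        do {
            // the fix: always start splitting from the beginning of the batch
            balloc.split_reset();

            std::vector<std::vector<int>> ubatches;
            while (true) {
                auto ub = balloc.split_equal(n_ubatch);
                if (ub.empty()) {
                    break;
                }
                ubatches.push_back(std::move(ub));
            }

            if (!prepare(ubatches)) {
                break; // fall through to the single failure return
            }

            auto res = std::make_unique<context>();
            res->ubatches = std::move(ubatches);
            return res;
        } while (false);

        auto res = std::make_unique<context>();
        res->st = status::failed_prepare;
        return res;
    }

    int main() {
        batch_allocr balloc{{1, 2, 3, 4, 5}};

        // first pass splits into {1,2} {3,4} {5}
        auto ctx1 = init_batch(balloc, 2);
        std::printf("pass 1: %zu ubatches\n", ctx1->ubatches.size());

        // without split_reset() the cursor would still sit at the end here,
        // yielding zero ubatches and a failed_prepare status
        auto ctx2 = init_batch(balloc, 2);
        std::printf("pass 2: %zu ubatches\n", ctx2->ubatches.size());
    }

The do { ... } while (false) form is a structured single-exit idiom: any break inside it lands on the shared failure return after the loop, leaving the success path as the only early return.
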