recurrent : call balloc split_reset() in init_batch() (#14414)

ggml-ci
Georgi Gerganov
2025-06-27 17:55:45 +03:00
committed by GitHub
parent 8d94219a4a
commit 43678060c1


@@ -363,30 +363,35 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
-    std::vector<llama_ubatch> ubatches;
+    do {
+        balloc.split_reset();
 
-    while (true) {
-        llama_ubatch ubatch;
+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            llama_ubatch ubatch;
 
-        if (embd_all) {
-            // if all tokens are output, split by sequence
-            ubatch = balloc.split_seq(n_ubatch);
-        } else {
-            ubatch = balloc.split_equal(n_ubatch);
-        }
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
 
-        if (ubatch.n_tokens == 0) {
-            break;
-        }
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
 
-        ubatches.push_back(std::move(ubatch)); // NOLINT
-    }
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
 
-    if (!prepare(ubatches)) {
-        return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        if (!prepare(ubatches)) {
+            break;
+        }
 
-    return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_context_ptr llama_memory_recurrent::init_full() {
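
The reset presumably matters because the batch allocator keeps internal split state between calls: without split_reset(), a later init_batch() on the same balloc would start its split loop mid-batch. Below is a minimal standalone sketch of the pattern, using toy stand-ins rather than the real llama.cpp types (batch_allocr, ubatch, prepare() here are hypothetical and simplified); it shows the split_reset()-then-consume loop and the do { ... } while (false) wrapper that funnels every failure path to one fallback return.

#include <algorithm>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Toy stand-ins for the llama.cpp types -- hypothetical, for illustration only.
struct ubatch {
    uint32_t n_tokens = 0;
};

struct batch_allocr {
    std::vector<uint32_t> pending; // token counts still to be handed out
    size_t cur = 0;                // split cursor

    // rewind the cursor so the same batch can be split again from scratch
    void split_reset() { cur = 0; }

    // hand out the next ubatch; n_tokens == 0 signals "no more splits"
    ubatch split_equal(uint32_t n_ubatch) {
        if (cur >= pending.size()) {
            return {};
        }
        return { std::min(pending[cur++], n_ubatch) };
    }
};

struct context {
    bool ok = false;
    std::vector<ubatch> ubatches;
};

static bool prepare(const std::vector<ubatch> & ubatches) {
    return !ubatches.empty(); // placeholder for the real capacity check
}

static std::unique_ptr<context> init_batch(batch_allocr & balloc, uint32_t n_ubatch) {
    do {
        // reset first: a previous call may have left the cursor mid-batch
        balloc.split_reset();

        std::vector<ubatch> ubatches;
        while (true) {
            ubatch ub = balloc.split_equal(n_ubatch);
            if (ub.n_tokens == 0) {
                break; // allocator exhausted
            }
            ubatches.push_back(ub);
        }

        if (!prepare(ubatches)) {
            break; // jump to the single failure return below
        }

        return std::make_unique<context>(context{true, std::move(ubatches)});
    } while (false);

    return std::make_unique<context>(context{false, {}});
}

int main() {
    batch_allocr balloc{{32, 32, 8}, 0};
    auto ctx1 = init_batch(balloc, 16); // first attempt
    auto ctx2 = init_batch(balloc, 16); // safe to call again: cursor was reset
    return (ctx1->ok && ctx2->ok) ? 0 : 1;
}

The do { ... } while (false) block is a common C++ idiom for structured early exit: break leaves the block without goto or nested flags, so a failed prepare() falls through to the single failure return, mirroring the LLAMA_MEMORY_STATUS_FAILED_PREPARE path in the diff above.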