batch : rework llama_batch_allocr (#14153)

* batch : rework llama_batch_allocr

ggml-ci

* cont : move validation inside class

ggml-ci

* cont : move output counting to class

ggml-ci

* cont : minor

ggml-ci

* batch : add TODOs

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-06-13 13:47:55 +03:00
committed by GitHub
parent b7cc7745e3
commit 60c666347b
7 changed files with 162 additions and 106 deletions

View File

@ -1,5 +1,9 @@
#include "llama-batch.h"
#include "llama-impl.h"
#include "llama-cparams.h"
#include "llama-vocab.h"
#include <cassert>
#include <cstring>
#include <algorithm>
@ -279,9 +283,42 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
);
}
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
batch = in_batch;
llama_batch_allocr::llama_batch_allocr() = default;
bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
clear();
batch = batch_inp;
GGML_ASSERT(batch.n_tokens > 0);
if (!batch.pos) {
if (batch.seq_id) {
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
return false;
}
}
if (batch.token) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return false;
}
}
}
if (batch.seq_id) {
for (int32_t i = 0; i < batch.n_tokens; ++i) {
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
return false;
}
}
}
}
if (!batch.pos) {
assert(p0 >= 0);
pos.resize(batch.n_tokens);
@ -290,6 +327,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.pos = pos.data();
}
if (!batch.n_seq_id) {
n_seq_id.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
@ -297,6 +335,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.n_seq_id = n_seq_id.data();
}
if (!batch.seq_id) {
seq_id.resize(batch.n_tokens + 1);
seq_id[batch.n_tokens] = NULL;
@ -305,12 +344,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.seq_id = seq_id.data();
}
if (!batch.logits) {
// by default return the output only for the last token
output.resize(batch.n_tokens);
output[output.size() - 1] = true;
batch.logits = output.data();
}
for (int32_t i = 0; i < batch.n_tokens; ++i) {
n_outputs += batch.logits[i] != 0;
}
return true;
}
const llama_batch & llama_batch_allocr::get_batch() const {
return batch;
}
uint32_t llama_batch_allocr::get_n_outputs() const {
return n_outputs;
}
void llama_batch_allocr::clear() {
n_outputs = 0;
batch = {};
pos.clear();
n_seq_id.clear();
seq_id.clear();
output.clear();
}
//