mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-27 12:05:03 +00:00
batch : rework llama_batch_allocr (#14153)
* batch : rework llama_batch_allocr ggml-ci * cont : move validation inside class ggml-ci * cont : move output counting to class ggml-ci * cont : minor ggml-ci * batch : add TODOs ggml-ci
This commit is contained in:
@ -1,5 +1,9 @@
|
||||
#include "llama-batch.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-cparams.h"
|
||||
#include "llama-vocab.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
@ -279,9 +283,42 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
|
||||
);
|
||||
}
|
||||
|
||||
llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) {
|
||||
batch = in_batch;
|
||||
llama_batch_allocr::llama_batch_allocr() = default;
|
||||
|
||||
bool llama_batch_allocr::init(const llama_batch & batch_inp, const llama_vocab & vocab, llama_pos p0) {
|
||||
clear();
|
||||
|
||||
batch = batch_inp;
|
||||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
if (!batch.pos) {
|
||||
if (batch.seq_id) {
|
||||
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.token) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
|
||||
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.seq_id) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_PARALLEL_SEQUENCES)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_PARALLEL_SEQUENCES);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!batch.pos) {
|
||||
assert(p0 >= 0);
|
||||
pos.resize(batch.n_tokens);
|
||||
@ -290,6 +327,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
|
||||
}
|
||||
batch.pos = pos.data();
|
||||
}
|
||||
|
||||
if (!batch.n_seq_id) {
|
||||
n_seq_id.resize(batch.n_tokens);
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
@ -297,6 +335,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
|
||||
}
|
||||
batch.n_seq_id = n_seq_id.data();
|
||||
}
|
||||
|
||||
if (!batch.seq_id) {
|
||||
seq_id.resize(batch.n_tokens + 1);
|
||||
seq_id[batch.n_tokens] = NULL;
|
||||
@ -305,12 +344,37 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
|
||||
}
|
||||
batch.seq_id = seq_id.data();
|
||||
}
|
||||
|
||||
if (!batch.logits) {
|
||||
// by default return the output only for the last token
|
||||
output.resize(batch.n_tokens);
|
||||
output[output.size() - 1] = true;
|
||||
batch.logits = output.data();
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
n_outputs += batch.logits[i] != 0;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
const llama_batch & llama_batch_allocr::get_batch() const {
|
||||
return batch;
|
||||
}
|
||||
|
||||
uint32_t llama_batch_allocr::get_n_outputs() const {
|
||||
return n_outputs;
|
||||
}
|
||||
|
||||
void llama_batch_allocr::clear() {
|
||||
n_outputs = 0;
|
||||
|
||||
batch = {};
|
||||
pos.clear();
|
||||
n_seq_id.clear();
|
||||
seq_id.clear();
|
||||
output.clear();
|
||||
}
|
||||
|
||||
//
|
||||
|
Reference in New Issue
Block a user