ubatch : new splitting logic (#14217)

ggml-ci
Georgi Gerganov
2025-06-20 10:14:14 +03:00
committed by GitHub
parent 9eaa51e7f0
commit 4c9fdfbe15
19 changed files with 992 additions and 915 deletions


@@ -20,7 +20,7 @@ llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
model(model),
- batch_allocr(std::make_unique<llama_batch_allocr>()) {
+ balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);
t_start_us = model.t_start_us;
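
The allocator is now owned by the context and constructed with hparams.n_pos_per_embd(), i.e. the number of position values stored per token (more than one for multi-dimensional M-RoPE-style positions). Below is a standalone sketch of why that dimension has to be fixed up front when sizing the position buffer; batch_allocr_sketch and its fields are illustrative stand-ins, not the real llama_batch_allocr:

    #include <cstdint>
    #include <vector>

    // illustrative stand-in for llama_batch_allocr: the position buffer must hold
    // n_pos_per_embd values per token, so that dimension is fixed at construction
    struct batch_allocr_sketch {
        explicit batch_allocr_sketch(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}

        void init(uint32_t n_tokens) {
            pos.assign(n_tokens * n_pos_per_embd, 0); // one row of positions per token
        }

        uint32_t n_pos_per_embd;
        std::vector<int32_t> pos;
    };

    int main() {
        // a context would pass model.hparams.n_pos_per_embd(); 4 mimics an M-RoPE-style model
        batch_allocr_sketch balloc(/*n_pos_per_embd =*/ 4);
        balloc.init(/*n_tokens =*/ 8);
        return balloc.pos.size() == 8*4 ? 0 : 1;
    }
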
@@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
}
int llama_context::encode(const llama_batch & batch_inp) {
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (batch_inp.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
}
+ const auto & hparams = model.hparams;
+ const int64_t n_embd = hparams.n_embd;
// note: during encode, we always pass the full sequence starting from pos = 0
- if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) {
+ if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
return -1;
}
- const llama_batch & batch = batch_allocr->get_batch();
+ const uint32_t n_tokens = balloc->get_n_tokens();
- const uint32_t n_tokens = batch.n_tokens;
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+ const llama_ubatch ubatch = balloc->split_simple(n_tokens);
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
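
encode() no longer builds a llama_sbatch: it validates the batch through the context-owned allocator and takes a single "simple" split covering every token, since non-causal encoding cannot be micro-batched. A toy model of that flow; allocr_sketch and ubatch_sketch are made-up stand-ins for llama_batch_allocr and llama_ubatch:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct ubatch_sketch { // stand-in for llama_ubatch: a contiguous window of the batch
        uint32_t offs;
        uint32_t n_tokens;
    };

    struct allocr_sketch { // stand-in for llama_batch_allocr
        std::vector<int32_t> tokens;
        uint32_t used = 0;

        uint32_t get_n_tokens() const { return (uint32_t) tokens.size(); }

        // "simple" split: the next n_ubatch tokens in batch order, sequences are not rearranged
        ubatch_sketch split_simple(uint32_t n_ubatch) {
            const uint32_t n = std::min<uint32_t>(n_ubatch, get_n_tokens() - used);
            const ubatch_sketch ub = { used, n };
            used += n;
            return ub;
        }
    };

    int main() {
        allocr_sketch balloc;
        balloc.tokens = {1, 2, 3, 4, 5};

        const uint32_t n_tokens = balloc.get_n_tokens();
        const ubatch_sketch ub  = balloc.split_simple(n_tokens); // encode: the whole batch in one shot

        return (ub.offs == 0 && ub.n_tokens == n_tokens) ? 0 : 1;
    }
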
@@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {
n_queued_tokens += n_tokens;
- const auto & hparams = model.hparams;
- const int64_t n_embd = hparams.n_embd;
- llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
- const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
// reserve output buffer
if (output_reserve(n_tokens) < n_tokens) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
@@ -817,34 +813,28 @@ int llama_context::encode(const llama_batch & batch_inp) {
{
// extract sequence embeddings
auto & embd_seq_out = embd_seq;
embd_seq_out.clear();
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
- // TODO: fix indexing [UBATCH_IDX]
- for (uint32_t i = 0; i < n_tokens; i++) {
- const llama_seq_id seq_id = ubatch.seq_id[i][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
embd_seq_out[seq_id].resize(n_embd);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
const uint32_t n_cls_out = hparams.n_cls_out;
- // TODO: fix indexing [UBATCH_IDX]
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_cls_out);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float));
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
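
Pooled results are now addressed through the ubatch's unique-sequence arrays: seq_id_unq lists the distinct sequence ids and seq_idx maps a sequence id to its row in the pooled tensor, replacing the old per-token loop that skipped already-seen ids. A self-contained sketch of that mapping (the real arrays live inside llama_ubatch; the buffer here is dummy data):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // per-token sequence ids of one ubatch (tokens from sequences 0 and 2)
        const std::vector<int32_t> seq_id_of_token = {0, 0, 2, 2, 2};

        std::vector<int32_t> seq_id_unq;      // distinct sequence ids, in order of first appearance
        std::vector<int32_t> seq_idx(8, -1);  // seq_id -> row in the pooled tensor (-1 = not present)

        for (const int32_t sid : seq_id_of_token) {
            if (seq_idx[sid] < 0) {
                seq_idx[sid] = (int32_t) seq_id_unq.size();
                seq_id_unq.push_back(sid);
            }
        }

        // the pooled tensor has one row per *unique* sequence, laid out by seq_idx, not by the sparse seq_id
        const int n_embd = 4;
        const std::vector<float> pooled(n_embd * seq_id_unq.size(), 1.0f);

        for (size_t s = 0; s < seq_id_unq.size(); ++s) {
            const int32_t sid = seq_id_unq[s];
            const int32_t idx = seq_idx[sid];
            std::printf("seq_id %d -> row %d (offset %d floats of %zu total)\n", sid, idx, n_embd*idx, pooled.size());
        }
        return 0;
    }
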
@@ -869,12 +859,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
cross.v_embd.resize(cross.n_embd*cross.n_enc);
memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));
+ const auto & batch = balloc->get_batch();
// remember the sequence ids used during the encoding - needed for cross attention later
cross.seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
cross.seq_ids_enc[i].clear();
for (int s = 0; s < batch.n_seq_id[i]; s++) {
- llama_seq_id seq_id = batch.seq_id[i][s];
+ const llama_seq_id seq_id = batch.seq_id[i][s];
cross.seq_ids_enc[i].insert(seq_id);
}
}
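
For the cross-attention path the encoder keeps, per encoded position, the set of sequence ids it belonged to, so the decoder can later restrict attention to positions that share a sequence with the decoded token. A small illustrative sketch of building and querying such sets:

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <set>
    #include <vector>

    int main() {
        // per-token sequence ids of the encoded batch (position 1 is shared by sequences 0 and 1)
        const std::vector<std::vector<int32_t>> seq_ids = { {0}, {0, 1}, {1} };

        std::vector<std::set<int32_t>> seq_ids_enc(seq_ids.size());
        for (size_t i = 0; i < seq_ids.size(); ++i) {
            for (const int32_t sid : seq_ids[i]) {
                seq_ids_enc[i].insert(sid);
            }
        }

        // later, cross-attention can ask: may a token of sequence 1 attend to encoded position 0?
        const bool allowed = seq_ids_enc[0].count(1) > 0; // position 0 belongs only to sequence 0
        std::printf("seq 1 may attend to enc pos 0: %s\n", allowed ? "yes" : "no");
        return 0;
    }
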
@@ -884,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
}
int llama_context::decode(const llama_batch & batch_inp) {
+ GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
if (!memory) {
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
return encode(batch_inp);
@@ -894,29 +890,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
return -1;
}
- // when computing embeddings, all tokens are output
- const bool embd_all = cparams.embeddings;
- if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
- LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
- return -1;
- }
- const llama_batch & batch = batch_allocr->get_batch();
const auto & vocab = model.vocab;
const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
- const uint32_t n_tokens_all = batch.n_tokens;
+ // when computing embeddings, all tokens are output
+ const bool output_all = cparams.embeddings;
- GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+ if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return -1;
+ }
- const uint32_t n_outputs_all = batch_allocr->get_n_outputs();
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
+ const uint32_t n_outputs_all = balloc->get_n_outputs();
- if (embd_all) {
+ if (output_all) {
// require that all tokens are output
if (n_outputs_all != n_tokens_all) {
LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +936,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
llama_memory_state_ptr mstate;
while (true) {
- mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all);
+ mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
if (!mstate) {
return -2;
}
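
The memory module now receives the batch allocator itself instead of a pre-split llama_sbatch, so each memory type can apply its own splitting strategy (the PR's "new splitting logic"); a recurrent cache could, for instance, ask for equal-length per-sequence splits while a unified KV cache keeps simple contiguous ones. A hedged sketch of that inversion of control with invented mini-interfaces:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    struct ubatch_sketch {
        uint32_t n_tokens;
    };

    // stand-in for llama_batch_allocr: hands out consecutive windows of the batch
    struct allocr_sketch {
        uint32_t n_total = 0;
        uint32_t used    = 0;

        ubatch_sketch split_simple(uint32_t n_ubatch) {
            const uint32_t n = std::min(n_ubatch, n_total - used);
            used += n;
            return ubatch_sketch{ n };
        }
    };

    // stand-in for a memory module's init_batch(): the memory type now drives the splitting itself
    static std::vector<ubatch_sketch> init_batch(allocr_sketch & balloc, uint32_t n_ubatch) {
        std::vector<ubatch_sketch> ubatches;
        while (true) {
            const ubatch_sketch ub = balloc.split_simple(n_ubatch); // a recurrent cache could pick an equal-length split here instead
            if (ub.n_tokens == 0) {
                break;
            }
            ubatches.push_back(ub);
        }
        return ubatches;
    }

    int main() {
        allocr_sketch balloc;
        balloc.n_total = 10;

        const auto ubatches = init_batch(balloc, /*n_ubatch =*/ 4); // 4 + 4 + 2
        return ubatches.size() == 3 ? 0 : 1;
    }
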
@@ -966,19 +957,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
did_optimize = true;
if (kv_self_update(true)) {
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());
continue;
}
}
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());
return 1;
}
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
{
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());
return -2;
}
@@ -1005,7 +996,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
- GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
@@ -1105,27 +1095,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
// extract sequence embeddings (cleared before processing each batch)
auto & embd_seq_out = embd_seq;
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
embd_seq_out[seq_id].resize(n_embd);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_RANK:
{
- // extract the rerank score - a single float per sequence
+ // extract the rerank score - n_cls_out floats per sequence
auto & embd_seq_out = embd_seq;
- for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
- const llama_seq_id seq_id = ubatch.seq_id[s][0];
- if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
- continue;
- }
- embd_seq_out[seq_id].resize(1);
- ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
+ const uint32_t n_cls_out = hparams.n_cls_out;
+ for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+ const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+ const int32_t seq_idx = ubatch.seq_idx[seq_id];
+ embd_seq_out[seq_id].resize(n_cls_out);
+ ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
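
The RANK pooling path in decode() used to copy a single float per sequence; it now copies hparams.n_cls_out floats per sequence and indexes the classifier output by seq_idx, matching the encode path. A standalone sketch of that extraction from a flat output buffer (sizes and values are dummy data):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <map>
    #include <vector>

    int main() {
        const uint32_t n_cls_out = 2;                        // e.g. a 2-class classification head
        const std::vector<int32_t> seq_id_unq = {0, 3};      // unique sequences in the ubatch
        std::map<int32_t, int32_t> seq_idx = { {0, 0}, {3, 1} };

        // flattened classifier output: n_cls_out floats per unique sequence, ordered by seq_idx
        const std::vector<float> t_embd = {0.9f, 0.1f, 0.2f, 0.8f};

        std::map<int32_t, std::vector<float>> embd_seq_out;
        for (size_t s = 0; s < seq_id_unq.size(); ++s) {
            const int32_t sid = seq_id_unq[s];
            const int32_t idx = seq_idx[sid];

            embd_seq_out[sid].resize(n_cls_out);
            std::memcpy(embd_seq_out[sid].data(), t_embd.data() + n_cls_out*idx, n_cls_out*sizeof(float));
        }

        // sequence 3 gets both scores [0.2, 0.8]; previously only a single float would have been copied
        return embd_seq_out[3][1] == 0.8f ? 0 : 1;
    }
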
@@ -1145,7 +1135,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
if (n_outputs > 0) {
bool sorted_output = true;
- auto & out_ids = mstate->out_ids();
+ auto & out_ids = balloc->get_out_ids();
GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
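
The out_ids mapping from computed output rows back to batch order now comes from the allocator (balloc->get_out_ids()) rather than the memory state, which makes sense since the allocator is what decides the split order. A minimal sketch of how such an index list restores batch order; the containers are illustrative:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    int main() {
        // out_ids[i] = position in the original batch of the i-th computed output row
        const std::vector<int32_t> out_ids     = {2, 0, 1};
        const std::vector<float>   logits_rows = {30.f, 10.f, 20.f}; // one value per row for brevity

        // detect whether reordering is needed, as decode() does with sorted_output
        const bool sorted_output = std::is_sorted(out_ids.begin(), out_ids.end());

        std::vector<float> logits_in_batch_order(out_ids.size());
        for (size_t i = 0; i < out_ids.size(); ++i) {
            logits_in_batch_order[out_ids[i]] = logits_rows[i];
        }

        return (!sorted_output && logits_in_batch_order[0] == 10.f) ? 0 : 1;
    }
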
@@ -1318,8 +1308,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
this->n_outputs = n_outputs;
- llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
- llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+ llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
+ llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
auto * gf = graph_init();
auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
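
graph_reserve() no longer hand-assembles a llama_ubatch from an initializer list; it creates a temporary allocator and asks it for a dummy ubatch of the worst-case shape (n_tokens/n_seqs tokens for each of n_seqs sequences). A toy counterpart of such a reserve helper; the real ubatch_reserve() presumably also wires up dummy position and sequence buffers:

    #include <cstdint>

    struct ubatch_shape_sketch {     // only the fields that determine the reserved graph's shape
        uint32_t n_tokens;
        uint32_t n_seq_tokens;
        uint32_t n_seqs;
    };

    // illustrative counterpart of llama_batch_allocr::ubatch_reserve()
    static ubatch_shape_sketch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
        return { n_seq_tokens * n_seqs, n_seq_tokens, n_seqs };
    }

    int main() {
        const uint32_t n_tokens = 512, n_seqs = 4;
        const ubatch_shape_sketch ub = ubatch_reserve(n_tokens / n_seqs, n_seqs); // worst-case shape
        return ub.n_tokens == 512 ? 0 : 1;
    }
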
@@ -2039,7 +2029,12 @@ void llama_context::opt_epoch_iter(
batch.logits [pos_batch] = true;
}
- const auto n_tokens_all = batch.n_tokens;
+ if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+ LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+ return;
+ }
+ const uint32_t n_tokens_all = balloc->get_n_tokens();
n_queued_tokens += n_tokens_all;
@@ -2047,7 +2042,7 @@ void llama_context::opt_epoch_iter(
uint32_t n_outputs_all = n_tokens_all;
- auto mstate = memory->init_batch(batch, cparams.n_ubatch, true);
+ auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
break;
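
The training path goes through the same machinery: opt_epoch_iter() validates the epoch batch with balloc->init(..., true) so every token is an output, then hands the allocator to memory->init_batch() and steps through the resulting ubatches. A toy model of that control flow, with plain counters standing in for the real objects:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // toy model of one opt_epoch_iter window: every token is an output, and the
    // batch is consumed ubatch by ubatch as the memory module's splitting provides
    int main() {
        const uint32_t n_ctx    = 8;
        const uint32_t n_ubatch = 4;

        const std::vector<int8_t> output_flags(n_ctx, 1);            // init(..., /*output_all =*/ true)
        const uint32_t n_outputs_all = (uint32_t) output_flags.size();

        uint32_t processed  = 0;
        uint32_t n_ubatches = 0;
        while (processed < n_ctx) {                                   // init_batch(*balloc, n_ubatch, true)
            const uint32_t n = std::min(n_ubatch, n_ctx - processed);
            processed  += n;                                          // one ubatch per forward/backward step
            n_ubatches += 1;
        }

        return (n_outputs_all == n_ctx && n_ubatches == 2) ? 0 : 1;
    }
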