From e4868d16d24dec55e61bcaadaca28feed8f98b13 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 24 Jul 2025 16:31:48 +0300
Subject: [PATCH] context : perform output reorder lazily upon access after sync (#14853)

* context : perform output reorder lazily upon access after sync

ggml-ci

* cont : add TODO
---
 include/llama.h       |  2 ++
 src/llama-context.cpp | 49 +++++++++++++++++++++++++++++++------------
 src/llama-context.h   |  9 ++++++++
 3 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 1c3a1cd1b..6f454a508 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -956,6 +956,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // Rows: number of tokens for which llama_batch.logits[i] != 0
     // Cols: n_vocab
+    // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
     // Logits for the ith token. For positive indices, Equivalent to:
@@ -970,6 +971,7 @@ extern "C" {
     // in the order they have appeared in the batch.
     // shape: [n_outputs*n_embd]
     // Otherwise, returns NULL.
+    // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Get the embeddings for the ith token. For positive indices, Equivalent to:
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 6eb344736..a91d157e2 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -508,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const {
 }
 
 float * llama_context::get_logits() {
+    output_reorder();
+
     return logits;
 }
 
 float * llama_context::get_logits_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (logits == nullptr) {
             throw std::runtime_error("no logits");
@@ -550,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) {
 }
 
 float * llama_context::get_embeddings() {
+    output_reorder();
+
     return embd;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;
 
+    output_reorder();
+
     try {
         if (embd == nullptr) {
             throw std::runtime_error("no embeddings");
@@ -970,6 +978,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    output_swaps.clear();
 
     bool did_optimize = false;
 
@@ -1189,9 +1198,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         // make the outputs have the same order they had in the user-provided batch
         // note: this is mostly relevant for recurrent models atm
         if (!sorted_output) {
-            const uint32_t n_vocab = model.vocab.n_tokens();
-            const uint64_t n_embd  = model.hparams.n_embd;
-
             GGML_ASSERT((size_t) n_outputs == out_ids.size());
 
             // TODO: is there something more efficient which also minimizes swaps?
@@ -1207,16 +1213,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     continue;
                 }
                 std::swap(out_ids[i], out_ids[j_min]);
-                if (logits_size > 0) {
-                    for (uint32_t k = 0; k < n_vocab; k++) {
-                        std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
-                    }
-                }
-                if (embd_size > 0) {
-                    for (uint32_t k = 0; k < n_embd; k++) {
-                        std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
-                    }
-                }
+
+                // remember the swaps and apply them lazily upon logits/embeddings access
+                output_swaps.push_back({ i, j_min });
             }
 
             std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1307,6 +1306,30 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     return n_outputs_max;
 }
 
+void llama_context::output_reorder() {
+    const uint32_t n_vocab = model.vocab.n_tokens();
+    const uint64_t n_embd  = model.hparams.n_embd;
+
+    for (uint32_t s = 0; s < output_swaps.size(); ++s) {
+        const uint32_t i0 = output_swaps[s].i0;
+        const uint32_t i1 = output_swaps[s].i1;
+
+        if (logits_size > 0) {
+            for (uint32_t k = 0; k < n_vocab; k++) {
+                std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]);
+            }
+        }
+
+        if (embd_size > 0) {
+            for (uint32_t k = 0; k < n_embd; k++) {
+                std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]);
+            }
+        }
+    }
+
+    output_swaps.clear();
+}
+
 //
 // graph
 //
diff --git a/src/llama-context.h b/src/llama-context.h
index 1601ac682..fdbe61207 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -181,6 +181,8 @@ private:
     // Returns max number of outputs for which space was reserved.
     uint32_t output_reserve(int32_t n_outputs);
 
+    void output_reorder();
+
     //
     // graph
     //
@@ -250,6 +252,13 @@ private:
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
 
+    struct swap_info {
+        uint32_t i0;
+        uint32_t i1;
+    };
+
+    std::vector<swap_info> output_swaps;
+
     ggml_backend_sched_ptr sched;
 
     ggml_backend_t backend_cpu = nullptr;
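
Note: the selection sort in decode() still reorders the small out_ids array eagerly, but the expensive row swaps on the logits/embeddings buffers (n_vocab and n_embd floats per row) are now only recorded as (i0, i1) pairs and replayed on the first access through the getters. If the buffers are never read before the next decode(), which clears output_swaps, the copies are skipped entirely. Below is a minimal self-contained C++ sketch of this lazy-swap pattern; the names (lazy_rows, pending, rows) are hypothetical illustrations, not llama.cpp API.

// lazy_reorder_sketch.cpp - standalone sketch, not llama.cpp code
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct swap_info {
    uint32_t i0;
    uint32_t i1;
};

struct lazy_rows {
    std::vector<float>     data;    // n_rows x n_cols payload, row-major
    uint32_t               n_cols = 0;
    std::vector<swap_info> pending; // row swaps recorded, not yet applied

    // replay the recorded swaps in order, then forget them
    // (plays the role of llama_context::output_reorder in the patch)
    void reorder() {
        for (const swap_info & s : pending) {
            for (uint32_t k = 0; k < n_cols; ++k) {
                std::swap(data[s.i0*n_cols + k], data[s.i1*n_cols + k]);
            }
        }
        pending.clear();
    }

    // accessors apply any pending reorder first, so callers always see
    // rows in batch order (plays the role of llama_get_logits et al.)
    float * rows() {
        reorder();
        return data.data();
    }
};

int main() {
    lazy_rows r;
    r.n_cols = 2;
    r.data   = { 2,2,  1,1,  0,0 }; // three rows, produced out of batch order

    // the producer sorts its id array and records which rows it swapped,
    // instead of moving the wide rows immediately
    r.pending.push_back({ 0, 2 });

    const float * p = r.rows();                  // swap applied here, on access
    std::printf("%g %g %g\n", p[0], p[2], p[4]); // prints: 0 1 2
    return 0;
}

Two invariants the sketch shares with the patch: the swaps must be replayed in the order they were recorded (sequences of swaps do not commute in general), and the queue must be cleared both after replay and at the start of every decode (the new output_swaps.clear() next to embd_seq.clear()), so stale swaps from a previous batch are never applied to freshly written outputs.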