From c3ee46fab49a765d2e32e171e9ed7a5fa121dd9c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 12 Jun 2025 11:49:26 +0300
Subject: [PATCH] batch : remove logits_all flag (#14141)

ggml-ci
---
 src/llama-batch.cpp                 | 10 ++--------
 src/llama-batch.h                   |  4 +---
 src/llama-context.cpp               |  6 +++---
 src/llama-kv-cache-recurrent.cpp    |  4 ++--
 src/llama-kv-cache-recurrent.h      |  3 +--
 src/llama-kv-cache-unified-iswa.cpp |  6 +++---
 src/llama-kv-cache-unified-iswa.h   |  3 +--
 src/llama-kv-cache-unified.cpp      |  5 ++---
 src/llama-kv-cache-unified.h        |  3 +--
 src/llama-memory.h                  |  3 +--
 10 files changed, 17 insertions(+), 30 deletions(-)

diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 6a19a2431..58787fdba 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -105,12 +105,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
             ubatch.seq_id = batch->seq_id + seq.offset;
         }
     }
-    if (logits_all) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.output[ubatch.n_tokens + i] = 1;
-            out_ids.push_back(ids[seq.offset + i]);
-        }
-    } else if (batch->logits) {
+    if (batch->logits) {
         if (ubatch.equal_seqs) {
             for (size_t i = 0; i < length; ++i) {
                 size_t id = ids[seq.offset + i];
@@ -197,11 +192,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
     return ubatch;
 }
 
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
+llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
     GGML_ASSERT(batch.n_tokens >= 0);
     this->batch = &batch;
     this->n_embd = n_embd;
-    this->logits_all = logits_all;
 
     n_tokens = batch.n_tokens;
     ids.resize(n_tokens);
diff --git a/src/llama-batch.h b/src/llama-batch.h
index b8260b94f..989fb6cf9 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -39,8 +39,6 @@ struct llama_sbatch {
 
     size_t n_embd;
 
-    bool logits_all; // TODO: remove once lctx.logits_all is removed too
-
     // sorted indices into the batch
     std::vector<size_t> ids;
     // batch indices of the output
@@ -76,7 +74,7 @@ struct llama_sbatch {
     llama_ubatch split_seq(size_t n_ubatch);
 
     llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
+    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
 };
 
 // temporary allocate memory for the input batch if needed
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 8cea21d69..ebcba6993 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -764,7 +764,7 @@ int llama_context::encode(llama_batch & inp_batch) {
 
     const int64_t n_embd = hparams.n_embd;
 
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
+    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
 
     const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
 
@@ -976,7 +976,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     llama_memory_state_ptr mstate;
 
     while (true) {
-        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
+        mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate) {
             return -2;
         }
@@ -2080,7 +2080,7 @@ void llama_context::opt_epoch_iter(
 
         int64_t n_outputs_all = n_tokens_all;
 
-        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ true);
+        auto mstate = memory->init_batch(batch, cparams.n_ubatch, embd_pooled);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp
index f8cdd5280..de23b4ad2 100644
--- a/src/llama-kv-cache-recurrent.cpp
+++ b/src/llama-kv-cache-recurrent.cpp
@@ -359,10 +359,10 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
-llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
     std::vector<llama_ubatch> ubatches;
 
diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h
index 4b33bafd7..d7c02ea87 100644
--- a/src/llama-kv-cache-recurrent.h
+++ b/src/llama-kv-cache-recurrent.h
@@ -32,8 +32,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp
index caa58ea9a..9814f7663 100644
--- a/src/llama-kv-cache-unified-iswa.cpp
+++ b/src/llama-kv-cache-unified-iswa.cpp
@@ -95,12 +95,12 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     // first try simple split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
 
@@ -128,7 +128,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
 
     // if it fails, try equal split
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
 
         std::vector<llama_ubatch> ubatches;
 
diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h
index 3dbf33ed7..d114c7378 100644
--- a/src/llama-kv-cache-unified-iswa.h
+++ b/src/llama-kv-cache-unified-iswa.h
@@ -34,8 +34,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index ddeb138f3..89606c598 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -310,12 +310,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) {
+            bool embd_pooled) {
     GGML_UNUSED(embd_pooled);
 
     do {
-        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true);
 
         std::vector<llama_ubatch> ubatches;
         while (sbatch.n_tokens > 0) {
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index cf4c691ba..d6dcd19f2 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -59,8 +59,7 @@ public:
     llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) override;
+            bool embd_pooled) override;
 
     llama_memory_state_ptr init_full() override;
 
diff --git a/src/llama-memory.h b/src/llama-memory.h
index 991aae781..42e226dc0 100644
--- a/src/llama-memory.h
+++ b/src/llama-memory.h
@@ -73,8 +73,7 @@ struct llama_memory_i {
     virtual llama_memory_state_ptr init_batch(
             const llama_batch & batch,
             uint32_t n_ubatch,
-            bool embd_pooled,
-            bool logits_all) = 0;
+            bool embd_pooled) = 0;
 
     // simulate full cache, used for allocating worst-case compute buffers
     virtual llama_memory_state_ptr init_full() = 0;
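
Note (editorial, not part of the patch): with logits_all removed, output selection is driven entirely by the per-token logits flags carried in llama_batch, so llama_sbatch and the llama_memory_i::init_batch overrides no longer need a separate switch. The sketch below illustrates how a caller of the public llama.h API can populate those flags, either for every token (the old logits_all behavior) or for the last token only. The make_batch helper is hypothetical; the llama_batch fields and llama_batch_init/llama_batch_free are the existing public API.

    #include "llama.h"
    #include <vector>

    // Build a single-sequence batch and mark which tokens should produce output.
    // Assumes the tokens fit in one batch, use one sequence id, and start at position 0.
    static llama_batch make_batch(const std::vector<llama_token> & tokens,
                                  llama_seq_id seq_id, bool outputs_for_all_tokens) {
        llama_batch batch = llama_batch_init((int32_t) tokens.size(), /*embd*/ 0, /*n_seq_max*/ 1);

        for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) {
            batch.token   [i]    = tokens[i];
            batch.pos     [i]    = i;
            batch.n_seq_id[i]    = 1;
            batch.seq_id  [i][0] = seq_id;
            // per-token output flag: this is what replaces the removed logits_all switch
            batch.logits  [i]    = outputs_for_all_tokens || i == (int32_t) tokens.size() - 1;
        }
        batch.n_tokens = (int32_t) tokens.size();

        return batch; // caller releases it with llama_batch_free()
    }

Keeping the decision per token avoids threading an extra boolean through every memory implementation and lets the encode and decode paths construct sbatches the same way.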