From 3753b30d658c93c62f1481d4ed0b2d0800f0d284 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 21 Feb 2025 15:50:27 +0200 Subject: [PATCH] context : fix n_outputs init ggml-ci --- src/llama-context.cpp | 8 +++----- src/llama-context.h | 4 ++-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 40d4e47a4..ce68d410a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1274,14 +1274,13 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { logits = has_logits ? output_base : nullptr; embd = has_embd ? output_base + logits_size : nullptr; - output_size = n_outputs_max; - // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); ggml_backend_buffer_clear(buf_output.get(), 0); - n_outputs = 0; + this->n_outputs = 0; + this->n_outputs_max = n_outputs_max; return n_outputs_max; } @@ -2131,7 +2130,7 @@ size_t llama_context::state_get_data(llama_io_write_i & io) { std::vector w_output_pos; - GGML_ASSERT(n_outputs <= output_size); + GGML_ASSERT(n_outputs <= n_outputs_max); w_output_pos.resize(n_outputs); @@ -2682,7 +2681,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) { /* logits_all */ logits_all); // reserve output buffer - // TODO: move to batch manager? if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); return -2; diff --git a/src/llama-context.h b/src/llama-context.h index ccb84874f..f8f01e1bd 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -375,8 +375,8 @@ protected: // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; - int32_t output_size = 0; // capacity (of tokens positions) for the output buffers - int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch + int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch + int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers std::vector output_ids; // map batch token positions to ids of the logits and embd buffers