From 9c9e4fc6354fc811efa06a8eb7a86d3315cec9c8 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Mon, 14 Jul 2025 21:01:41 +0800
Subject: [PATCH] llama-context: add ability to get logits (#14672)

---
 src/llama-context.cpp | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 06e93b19c..7c07b047b 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -731,7 +731,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 
     const auto & hparams = model.hparams;
 
-    const int64_t n_embd = hparams.n_embd;
+    const int64_t n_embd  = hparams.n_embd;
+    const int32_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
@@ -791,10 +792,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
+    auto * t_logits = res->get_logits();
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
+    // extract logits
+    if (logits && t_logits) {
+        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+        GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(logits != nullptr);
+
+        ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+    }
+
     // extract embeddings
-    if (t_embd) {
+    if (embd && t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
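Usage note (not part of the commit message): with this patch, llama_context::encode() copies the logits tensor into the context's logits buffer whenever the encoder graph produces one, so callers can read the values through the public API after llama_encode(). A minimal sketch of the expected call pattern, assuming a model whose encode graph emits logits; the context setup and the tokens array are placeholders, not part of this patch:

    #include "llama.h"

    // ctx is an already-initialized llama_context *; tokens holds a
    // tokenized prompt (placeholder values for illustration)
    llama_token tokens[4] = { /* ... */ };
    llama_batch batch = llama_batch_get_one(tokens, 4);

    if (llama_encode(ctx, batch) == 0) {
        // after this patch, encode() fills the logits buffer when the
        // graph result exposes a logits tensor: n_tokens * n_vocab floats
        const float * logits = llama_get_logits(ctx);
        // ... consume logits here ...
    }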