llama-context: add ability to get logits (#14672)
@@ -732,6 +732,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const auto & hparams = model.hparams;
 
     const int64_t n_embd = hparams.n_embd;
+    const int32_t n_vocab = model.vocab.n_tokens();
 
     // note: during encode, we always pass the full sequence starting from pos = 0
     if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
@@ -791,10 +792,20 @@ int llama_context::encode(const llama_batch & batch_inp) {
         }
     }
 
+    auto * t_logits = res->get_logits();
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
+    // extract logits
+    if (logits && t_logits) {
+        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+        GGML_ASSERT(backend_res != nullptr);
+        GGML_ASSERT(logits != nullptr);
+
+        ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float));
+    }
+
     // extract embeddings
-    if (t_embd) {
+    if (embd && t_embd) {
         ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
         GGML_ASSERT(backend_embd != nullptr);
 
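For context, here is a minimal usage sketch (not part of this commit) of what the change enables: after llama_encode() runs, the encoder's per-token logits are copied into the context and can be read back through the existing llama_get_logits() accessor, the same way as after decode(). The model path and token ids are placeholders, the llama.h calls are assumed from the public API around the time of this commit, and whether encoder logits actually exist still depends on the model's encode graph producing an output tensor (t_logits may be null otherwise).

// usage sketch, assuming the public llama.h API; placeholders marked in comments
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    llama_backend_init();

    // "model.gguf" and the token ids below are placeholders
    llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());
    if (model == nullptr) {
        return 1;
    }

    llama_context * ctx = llama_init_from_model(model, llama_context_default_params());

    std::vector<llama_token> tokens = { 1, 2, 3 };
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    if (llama_encode(ctx, batch) == 0) {
        // with this change, encode() copies the encoder logits into the context's
        // output buffer, so they can be read back just like after decode()
        const float * logits  = llama_get_logits(ctx); // n_tokens * n_vocab floats
        const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

        if (logits != nullptr) {
            printf("first logit of first token: %f (n_vocab = %d)\n", logits[0], n_vocab);
        }
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();

    return 0;
}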