diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index 58787fdba..69e0d7549 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -306,9 +306,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         batch.seq_id = seq_id.data();
     }
     if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+        // by default return the output only for the last token
+        output.resize(batch.n_tokens);
+        output[output.size() - 1] = true;
+        batch.logits = output.data();
     }
 }
diff --git a/src/llama-batch.h b/src/llama-batch.h
index 989fb6cf9..7ad82b528 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -85,7 +85,7 @@ struct llama_batch_allocr {
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         output;
 
     // optionally fulfill the batch returned by llama_batch_get_one
     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ebcba6993..2e551bf6e 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -758,6 +758,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }
 
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
     n_queued_tokens += n_tokens;
@@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }
 
+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
@@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;
 
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs_all += batch.logits[i] != 0;
-        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
-
     bool did_optimize = false;
 
     // handle any pending defrags/shifts
@@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     do {
         const auto & ubatch = mstate->get_ubatch();
 
-        // count the outputs in this u_batch
+        // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;
 
@@ -2073,7 +2077,7 @@ void llama_context::opt_epoch_iter(
 
         n_queued_tokens += n_tokens_all;
 
-        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        // this indicates we are doing pooled embedding
        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
         embd_seq.clear();
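
For reference, a minimal caller-side sketch of what this change implies (hypothetical example, not part of the patch; the helper name `embed_prompt` is illustrative): with embeddings enabled and a pooling type other than LLAMA_POOLING_TYPE_NONE, `batch.logits` is no longer ignored, so every token must be flagged for output or `llama_decode()` fails with -1.

```cpp
#include "llama.h"

// Hypothetical helper: decode one prompt for pooled embeddings.
// After this change, batch.logits must be set for *every* token,
// otherwise llama_decode() returns -1 when pooling is enabled.
static int embed_prompt(llama_context * ctx, const llama_token * tokens, int32_t n_tokens, llama_seq_id seq) {
    llama_batch batch = llama_batch_init(n_tokens, 0, 1);

    for (int32_t i = 0; i < n_tokens; ++i) {
        batch.token   [i] = tokens[i];
        batch.pos     [i] = i;
        batch.n_seq_id[i] = 1;
        batch.seq_id  [i][0] = seq;
        batch.logits  [i] = true; // previously ignored for pooled embeddings, now required
    }
    batch.n_tokens = n_tokens;

    const int ret = llama_decode(ctx, batch);

    llama_batch_free(batch);

    return ret; // pooled result is then available via llama_get_embeddings_seq(ctx, seq)
}
```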