Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-06-28 04:15:21 +00:00
context : simplify output counting logic during decode (#14142)
* batch : remove logits_all flag

ggml-ci

* context : simplify output counting logic during decode

ggml-ci

* cont : fix comments
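The net effect of the change: llama_batch_allocr now always fills in the per-token output flags (defaulting to last-token-only when the caller provides none), so llama_context::decode can count outputs with a single pass instead of branching on whether batch.logits is set. A condensed, illustrative sketch of the counting logic before and after — the two helper functions are invented here purely for comparison; the real code operates on the local variables shown in the hunks below:

#include <cstdint>

// before: decode() branched three ways depending on whether output flags were given
static int64_t count_outputs_before(const int8_t * logits, uint32_t n_tokens_all, bool embd_pooled) {
    int64_t n_outputs_all = 0;
    if (logits && !embd_pooled) {
        for (uint32_t i = 0; i < n_tokens_all; ++i) {
            n_outputs_all += logits[i] != 0;
        }
    } else if (embd_pooled) {
        n_outputs_all = n_tokens_all; // pooled embeddings: every token is an output
    } else {
        n_outputs_all = 1;            // no flags: keep the last output only
    }
    return n_outputs_all;
}

// after: the allocator guarantees the flags exist, so a single pass suffices;
// pooled embeddings are only validated, no longer special-cased
static int64_t count_outputs_after(const int8_t * logits, uint32_t n_tokens_all, bool embd_pooled) {
    int64_t n_outputs_all = 0;
    for (uint32_t i = 0; i < n_tokens_all; ++i) {
        n_outputs_all += logits[i] != 0;
    }
    if (embd_pooled && n_outputs_all != (int64_t) n_tokens_all) {
        return -1; // pooled embedding requires that all tokens are output
    }
    return n_outputs_all;
}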
@@ -306,9 +306,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
         batch.seq_id = seq_id.data();
     }
     if (!batch.logits) {
-        logits.resize(batch.n_tokens);
-        logits[logits.size() - 1] = true;
-        batch.logits = logits.data();
+        // by default return the output only for the last token
+        output.resize(batch.n_tokens);
+        output[output.size() - 1] = true;
+        batch.logits = output.data();
     }
 }
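For callers, the observable behavior of this default is unchanged: a batch submitted without explicit output flags still produces logits only for its last token; the default now simply lives in the renamed output vector. A minimal usage sketch under that assumption — the helper function below is hypothetical, and model/context setup is omitted:

#include "llama.h"
#include <vector>

// Hypothetical helper, not part of the library: decode a prompt and return the
// logits of its last token, relying on the last-token-only default shown above.
static const float * decode_last_token_logits(llama_context * ctx, std::vector<llama_token> & tokens) {
    // no per-token output flags are provided here, so llama_batch_allocr marks
    // only the last token as an output
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    if (llama_decode(ctx, batch) != 0) {
        return nullptr;
    }

    // index -1 refers to the last output of the previous decode call
    return llama_get_logits_ith(ctx, -1);
}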
@@ -85,7 +85,7 @@ struct llama_batch_allocr {
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
-    std::vector<int8_t>         logits;
+    std::vector<int8_t>         output;

     // optionally fulfill the batch returned by llama_batch_get_one
     llama_batch_allocr(struct llama_batch in_batch, llama_pos p0);
@@ -758,6 +758,7 @@ int llama_context::encode(llama_batch & inp_batch) {
         t_compute_start_us = ggml_time_us();
     }

+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();

     n_queued_tokens += n_tokens;
@@ -940,6 +941,25 @@ int llama_context::decode(llama_batch & inp_batch) {
         }
     }

+    // this indicates we are doing pooled embedding
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    for (uint32_t i = 0; i < n_tokens_all; ++i) {
+        n_outputs_all += batch.logits[i] != 0;
+    }
+
+    if (embd_pooled) {
+        // require that all tokens are output
+        if (n_outputs_all != n_tokens_all) {
+            LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %" PRId64 ", n_tokens_all = %" PRId64 ")\n",
+                    __func__, n_outputs_all, n_tokens_all);
+            return -1;
+        }
+    }
+
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);

     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
@@ -949,25 +969,9 @@ int llama_context::decode(llama_batch & inp_batch) {
     }
     n_queued_tokens += n_tokens_all;

-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
+    // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();

-    int64_t n_outputs_all = 0;
-
-    // count outputs
-    if (batch.logits && !embd_pooled) {
-        for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs_all += batch.logits[i] != 0;
-        }
-    } else if (embd_pooled) {
-        n_outputs_all = n_tokens_all;
-    } else {
-        // keep last output only
-        n_outputs_all = 1;
-    }
-
     bool did_optimize = false;

     // handle any pending defrags/shifts
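One consequence of relocating the count: decode no longer ignores batch.logits when pooled embeddings are enabled (the removed comment above stated that it did); every token now has to be flagged as an output, otherwise decode returns -1. A hypothetical sketch of building a single-sequence batch that satisfies the new check with the public llama_batch API (the helper name is invented):

#include "llama.h"
#include <vector>

// Hypothetical helper: build a single-sequence batch in which every token is
// marked as an output, as the pooled-embedding check introduced above requires.
static llama_batch make_pooled_embd_batch(const std::vector<llama_token> & tokens, llama_seq_id seq) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), /*embd*/ 0, /*n_seq_max*/ 1);

    for (int32_t i = 0; i < (int32_t) tokens.size(); ++i) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = seq;
        batch.logits  [i]    = true; // every token must be an output when pooling is enabled
    }
    batch.n_tokens = (int32_t) tokens.size();

    return batch; // the caller releases it with llama_batch_free()
}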
@@ -1029,7 +1033,7 @@ int llama_context::decode(llama_batch & inp_batch) {
     do {
         const auto & ubatch = mstate->get_ubatch();

-        // count the outputs in this u_batch
+        // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;
@@ -2073,7 +2077,7 @@ void llama_context::opt_epoch_iter(

         n_queued_tokens += n_tokens_all;

-        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        // this indicates we are doing pooled embedding
         const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

         embd_seq.clear();
|