Merge branch 'master' into gg/llama-kv-cache

ggml-ci
Georgi Gerganov
2025-01-27 14:00:56 +02:00
6 changed files with 106 additions and 96 deletions


@@ -7,6 +7,7 @@
#include <cmath>
#include <cstring>
#include <stdexcept>
#include <cinttypes>
static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
// TODO move to hparams if a T5 variant appears that uses a different value
@@ -336,12 +337,55 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
}
struct llama_batch_manager : public llama_batch_manager_i {
-llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
llama_batch_manager(llama_context & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
const auto & model = lctx.model;
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
-const auto & n_embd = hparams.n_embd;
const auto & kv_self = lctx.kv_self;
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
}
}
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
lctx.n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
lctx.embd_seq.clear();
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (lctx.logits_all || embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}
const bool logits_all = n_outputs_all == n_tokens_all;
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ logits_all);
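
For reference, the batch-level rule encoded above, written out as a small self-contained helper; count_batch_outputs and its raw-pointer parameters are invented for this sketch, the real code reads them from the batch and the context params:

#include <cstdint>

// Sketch of the output-counting rule in the constructor above:
//  - explicit per-token flags win, unless pooled embeddings are requested;
//  - logits_all or pooled embeddings mean every token is an output;
//  - otherwise only the last token of the batch produces an output.
static int64_t count_batch_outputs(const int8_t * logits, int64_t n_tokens,
                                   bool logits_all, bool embd_pooled) {
    if (logits && !embd_pooled) {
        int64_t n = 0;
        for (int64_t i = 0; i < n_tokens; ++i) {
            n += logits[i] != 0;
        }
        return n;
    }
    if (logits_all || embd_pooled) {
        return n_tokens;
    }
    return 1; // keep last output only
}

The derived condition n_outputs_all == n_tokens_all then stands in for the logits_all argument that the constructor used to take.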
@@ -379,9 +423,29 @@ struct llama_batch_manager : public llama_batch_manager_i {
virtual bool prepare() override {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
const auto & batch = lctx.sbatch.batch;
const auto n_tokens_all = batch->n_tokens;
auto & kv_self = lctx.kv_self;
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
lctx.kv_self_update();
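
A rough, self-contained sketch of the per-ubatch bookkeeping shown in prepare(); count_ubatch_outputs and the vector-based parameter are illustrative stand-ins for ubatch.output, not the real splitting code:

#include <cstdint>
#include <vector>

// Sketch: count the outputs of one ubatch given the batch-level totals.
// If the whole batch outputs every token, so does each ubatch; otherwise
// the per-token output flags of the ubatch are summed, so the per-ubatch
// counts add up to n_outputs_all across the split.
static int32_t count_ubatch_outputs(const std::vector<int8_t> & ubatch_output,
                                    int64_t n_outputs_all, int64_t n_tokens_all) {
    if (n_outputs_all == n_tokens_all) {
        return (int32_t) ubatch_output.size();
    }
    int32_t n = 0;
    for (int8_t f : ubatch_output) {
        n += f != 0;
    }
    return n;
}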
@@ -459,8 +523,8 @@ struct llama_batch_manager : public llama_batch_manager_i {
llama_kv_slot_restorer kv_slot_restorer;
};
-std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
-return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch) {
return std::make_unique<llama_batch_manager>(*this, batch);
}
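
On the caller side, the effect of the signature change is that only the batch is handed over; whether all tokens produce logits is now decided inside the manager. A toy, self-contained mock of that calling pattern; batch_manager_i, my_manager and prepare_batch_sketch are invented names, not the real llama.cpp types:

#include <cstdint>
#include <memory>

// Mock of the batch-manager pattern after this change: the factory takes
// only the batch size, and per-ubatch work is driven through prepare().
struct batch_manager_i {
    virtual ~batch_manager_i() = default;
    virtual bool prepare() = 0;          // per-ubatch setup (outputs, KV slots, ...)
};

struct my_manager : batch_manager_i {
    explicit my_manager(int64_t n_tokens) : n_tokens(n_tokens) {}
    bool prepare() override { return n_tokens > 0; } // trivial stand-in
    int64_t n_tokens;
};

static std::unique_ptr<batch_manager_i> prepare_batch_sketch(int64_t n_tokens) {
    return std::make_unique<my_manager>(n_tokens);   // no logits_all argument
}

int main() {
    auto bman = prepare_batch_sketch(/*n_tokens=*/8);
    return bman->prepare() ? 0 : 1;      // the caller never sees logits_all
}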
enum ggml_status llama_context::compute_graph(