Merge branch 'master' into gg/llama-kv-cache
ggml-ci
@@ -7,6 +7,7 @@
 #include <cmath>
 #include <cstring>
 #include <stdexcept>
+#include <cinttypes>
 
 static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
     // TODO move to hparams if a T5 variant appears that uses a different value
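Side note on the new include: <cinttypes> provides the PRId64 macro used by the logging call added later in this diff; it expands to the correct printf format specifier for int64_t. A minimal standalone illustration:

#include <cinttypes> // defines PRId64
#include <cstdint>
#include <cstdio>

int main() {
    int64_t i = 42;
    // "%" PRId64 concatenates into the right format string for an int64_t
    printf("invalid token[%" PRId64 "] = %d\n", i, -1); // prints: invalid token[42] = -1
    return 0;
}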
@@ -336,12 +337,55 @@ llama_context::llama_context(const llama_model & model, const llama_context_para
 }
 
 struct llama_batch_manager : public llama_batch_manager_i {
-    llama_batch_manager(llama_context & lctx, const llama_batch & batch, bool logits_all) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
+    llama_batch_manager(llama_context & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
         const auto & model   = lctx.model;
         const auto & cparams = lctx.cparams;
         const auto & hparams = lctx.model.hparams;
-        const auto & n_embd  = hparams.n_embd;
 
         const auto & kv_self = lctx.kv_self;
 
         const int64_t n_tokens_all = batch.n_tokens;
+        const int64_t n_embd       = hparams.n_embd;
+
+        GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+        if (batch.token) {
+            for (int64_t i = 0; i < n_tokens_all; ++i) {
+                if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                    LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                    throw std::runtime_error("invalid token");
+                }
+            }
+        }
+
+        GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+        GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+        if (lctx.t_compute_start_us == 0) {
+            lctx.t_compute_start_us = ggml_time_us();
+        }
+        lctx.n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        lctx.embd_seq.clear();
+
+        // count outputs
+        if (batch.logits && !embd_pooled) {
+            for (uint32_t i = 0; i < n_tokens_all; ++i) {
+                n_outputs_all += batch.logits[i] != 0;
+            }
+        } else if (lctx.logits_all || embd_pooled) {
+            n_outputs_all = n_tokens_all;
+        } else {
+            // keep last output only
+            n_outputs_all = 1;
+        }
+
+        const bool logits_all = n_outputs_all == n_tokens_all;
+
+        lctx.sbatch.from_batch(batch, n_embd,
+                /* simple_split */ !kv_self.recurrent,
+                /* logits_all   */ logits_all);
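The constructor now owns the output-counting rule: per-token output flags win unless pooled embeddings or logits_all force every token, and the fallback is the last token only. A minimal standalone sketch of that rule, with illustrative names rather than the real llama.cpp API:

#include <cstdint>

// Sketch of the output-counting rule from the constructor above.
// `logits` mirrors batch.logits (per-token output flags, may be null).
static int64_t count_outputs(const int8_t * logits, int64_t n_tokens,
                             bool logits_all, bool embd_pooled) {
    if (logits && !embd_pooled) {
        int64_t n = 0;
        for (int64_t i = 0; i < n_tokens; ++i) {
            n += logits[i] != 0; // one output per flagged token
        }
        return n;
    }
    if (logits_all || embd_pooled) {
        return n_tokens; // every token produces an output
    }
    return 1; // default: keep only the last token's output
}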
@@ -379,9 +423,29 @@ struct llama_batch_manager : public llama_batch_manager_i {
     virtual bool prepare() override {
         const auto & cparams = lctx.cparams;
         const auto & hparams = lctx.model.hparams;
+        const auto & batch   = lctx.sbatch.batch;
+
+        const auto n_tokens_all = batch->n_tokens;
 
         auto & kv_self = lctx.kv_self;
 
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            lctx.n_outputs = n_outputs_new;
+        }
+
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
             lctx.kv_self_update();
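prepare() repeats the count per micro-batch so that lctx.n_outputs is correct before each graph build; summed over all micro-batch splits, these per-ubatch counts must equal n_outputs_all. A small self-contained check of that invariant, using mock data rather than llama.cpp types:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    std::vector<int8_t> output = {0, 1, 0, 0, 1, 1}; // per-token output flags
    const size_t n_ubatch = 2;                       // micro-batch split size

    int32_t n_outputs_all = 0; // total, as counted in the constructor
    for (int8_t f : output) n_outputs_all += f != 0;

    int32_t n_outputs_sum = 0;
    for (size_t i = 0; i < output.size(); i += n_ubatch) {
        int32_t n_outputs_new = 0; // per-ubatch count, as in prepare()
        for (size_t j = i; j < i + n_ubatch && j < output.size(); ++j) {
            n_outputs_new += output[j] != 0;
        }
        n_outputs_sum += n_outputs_new;
    }

    assert(n_outputs_sum == n_outputs_all); // splits never change the total
    return 0;
}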
@@ -459,8 +523,8 @@ struct llama_batch_manager : public llama_batch_manager_i {
     llama_kv_slot_restorer kv_slot_restorer;
 };
 
-std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch, bool logits_all) {
-    return std::make_unique<llama_batch_manager>(*this, batch, logits_all);
+std::unique_ptr<llama_batch_manager_i> llama_context::prepare_batch(const llama_batch & batch) {
+    return std::make_unique<llama_batch_manager>(*this, batch);
 }
 
 enum ggml_status llama_context::compute_graph(
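The net effect of the signature change: callers no longer pass logits_all into prepare_batch; the batch manager derives it while counting outputs in its own constructor. A mock sketch of that pattern, with stand-in types rather than the real llama.cpp declarations:

#include <cstdint>
#include <memory>

struct mock_batch { int32_t n_tokens; const int8_t * logits; };

struct batch_manager_i {
    virtual ~batch_manager_i() = default;
    virtual bool prepare() = 0;
};

struct batch_manager : batch_manager_i {
    explicit batch_manager(const mock_batch & b) : b(b) {
        // logits_all is derived here, as in the new constructor above,
        // instead of being passed in by the caller
        int32_t n_outputs = 0;
        if (b.logits) {
            for (int32_t i = 0; i < b.n_tokens; ++i) {
                n_outputs += b.logits[i] != 0;
            }
        } else {
            n_outputs = 1; // keep last output only
        }
        logits_all = n_outputs == b.n_tokens;
    }

    bool prepare() override { return b.n_tokens > 0; }

    mock_batch b;
    bool logits_all;
};

// mirrors the new single-argument prepare_batch factory
std::unique_ptr<batch_manager_i> prepare_batch(const mock_batch & b) {
    return std::make_unique<batch_manager>(b);
}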