Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-07-19 09:08:04 +00:00)
context : remove batch_manager
ggml-ci
@@ -42,9 +42,9 @@ struct llama_sbatch {
     bool logits_all; // TODO: remove once lctx.logits_all is removed too

     // sorted indices into the batch
-    std::vector<size_t> ids;
+    std::vector<int64_t> ids;
     // batch indices of the output
-    std::vector<size_t> out_ids;
+    std::vector<int64_t> out_ids;
     std::vector<llama_sbatch_seq> seq;

     const llama_batch * batch = nullptr;
@@ -161,7 +161,7 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
+        if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
             LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
             throw std::runtime_error("failed to reserve initial output buffer");
         }
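
With output_reserve() now returning a signed int32_t (see the later hunks), comparing its result against the unsigned params.n_seq_max would mix signedness; the explicit (uint32_t) cast above keeps the comparison warning-free. A minimal standalone sketch, with reserve_stub as a hypothetical stand-in for output_reserve():

    #include <cstdint>
    #include <cstdio>

    // hypothetical stand-in for output_reserve(), which now returns a signed int32_t
    static int32_t reserve_stub(int32_t n) { return n; }

    int main() {
        const uint32_t n_seq_max = 8;
        // comparing signed vs unsigned would trigger -Wsign-compare; the explicit
        // cast mirrors the hunk above and documents the intended comparison
        if ((uint32_t) reserve_stub((int32_t) n_seq_max) < n_seq_max) {
            std::printf("failed to reserve initial output buffer\n");
            return 1;
        }
        std::printf("reserved space for %u outputs\n", n_seq_max);
        return 0;
    }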
@@ -747,11 +747,11 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
     );
 }

-size_t llama_context::output_reserve(size_t n_outputs) {
+int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
     const auto & vocab = model.vocab;

-    const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+    const int64_t n_outputs_max = std::max<int64_t>(n_outputs, cparams.n_seq_max);

     const auto n_batch = cparams.n_batch;
     const auto n_vocab = vocab.n_tokens();
@@ -817,7 +817,7 @@ size_t llama_context::output_reserve(size_t n_outputs) {
 }

 void llama_context::output_reorder() {
-    std::vector<size_t> & out_ids = sbatch.out_ids;
+    auto & out_ids = sbatch.out_ids;
     if (!out_ids.empty()) {
         const uint32_t n_vocab = model.vocab.n_tokens();
         const uint32_t n_embd = model.hparams.n_embd;
@@ -1320,7 +1320,7 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
     {
         output_reorder();

-        const uint32_t n_outputs = this->n_outputs;
+        const auto n_outputs = this->n_outputs;
         const auto & output_ids = this->output_ids;

         std::vector<int32_t> w_output_pos;
@@ -1334,7 +1334,7 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
             // map an output id to a position in the batch
             int32_t pos = output_ids[i];
             if (pos >= 0) {
-                GGML_ASSERT((uint32_t) pos < n_outputs);
+                GGML_ASSERT(pos < n_outputs);
                 w_output_pos[pos] = i;
             }
         }
@@ -1386,15 +1386,15 @@ size_t llama_context::state_set_data(llama_io_read_i & io) {

     // read output ids
     {
-        std::vector<int32_t> output_pos;
-
-        uint32_t n_outputs;
+        auto n_outputs = this->n_outputs;
         io.read_to(&n_outputs, sizeof(n_outputs));

         if (n_outputs > output_reserve(n_outputs)) {
             throw std::runtime_error("could not reserve outputs");
         }

+        std::vector<int32_t> output_pos;
+
         if (n_outputs) {
             output_pos.resize(n_outputs);
             io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
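
A small self-contained sketch of the read-count, reserve, then read-payload pattern in this hunk; reader and reserve_outputs are simplified stand-ins, not llama.cpp types:

    #include <cstdint>
    #include <cstring>
    #include <stdexcept>
    #include <vector>

    struct reader {
        const uint8_t * data;
        size_t pos = 0;
        void read_to(void * dst, size_t n) { std::memcpy(dst, data + pos, n); pos += n; }
    };

    // pretend the reservation always succeeds and returns the requested count
    static int32_t reserve_outputs(int32_t n) { return n; }

    int main() {
        // serialized state: count = 2, followed by two int32 positions
        const int32_t blob[] = { 2, 5, 7 };
        reader io { reinterpret_cast<const uint8_t *>(blob) };

        int32_t n_outputs = 0;
        io.read_to(&n_outputs, sizeof(n_outputs));

        if (n_outputs > reserve_outputs(n_outputs)) {
            throw std::runtime_error("could not reserve outputs");
        }

        std::vector<int32_t> output_pos(n_outputs);
        io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
        return output_pos[0] == 5 ? 0 : 1;
    }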
@@ -1543,228 +1543,6 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
     return llama_context::graph_init();
 }

-struct llama_context_kv_self::batch_manager {
-    batch_manager(llama_context_kv_self & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
-        const auto & model = lctx.model;
-        const auto & cparams = lctx.cparams;
-        const auto & hparams = lctx.model.hparams;
-
-        const auto & kv_self = lctx.kv_self;
-
-        const int64_t n_tokens_all = batch.n_tokens;
-        const int64_t n_embd = hparams.n_embd;
-
-        GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-        if (batch.token) {
-            for (int64_t i = 0; i < n_tokens_all; ++i) {
-                if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                    LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                    throw std::runtime_error("invalid token");
-                }
-            }
-        }
-
-        GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-        GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-        if (lctx.t_compute_start_us == 0) {
-            lctx.t_compute_start_us = ggml_time_us();
-        }
-        lctx.n_queued_tokens += n_tokens_all;
-
-        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-        lctx.embd_seq.clear();
-
-        // count outputs
-        if (batch.logits && !embd_pooled) {
-            for (uint32_t i = 0; i < n_tokens_all; ++i) {
-                n_outputs_all += batch.logits[i] != 0;
-            }
-        } else if (lctx.logits_all || embd_pooled) {
-            n_outputs_all = n_tokens_all;
-        } else {
-            // keep last output only
-            n_outputs_all = 1;
-        }
-
-        const bool logits_all = n_outputs_all == n_tokens_all;
-
-        lctx.sbatch.from_batch(batch, n_embd,
-                /* simple_split */ !kv_self.recurrent,
-                /* logits_all */ logits_all);
-    }
-
-    ~batch_manager() {
-    }
-
-    bool is_done() const {
-        return lctx.sbatch.n_tokens == 0;
-    }
-
-    llama_ubatch next() {
-        llama_ubatch ubatch = llama_ubatch();
-
-        const auto & cparams = lctx.cparams;
-        const auto & kv_self = lctx.kv_self;
-
-        const auto & n_ubatch = cparams.n_ubatch;
-
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
-
-        return ubatch;
-    }
-
-    bool prepare(const llama_ubatch & ubatch) {
-        const auto & cparams = lctx.cparams;
-        const auto & hparams = lctx.model.hparams;
-        const auto & batch = lctx.sbatch.batch;
-
-        const auto n_tokens_all = batch->n_tokens;
-
-        auto & kv_self = lctx.kv_self;
-
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs_all == n_tokens_all) {
-                n_outputs_new = ubatch.n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
-            }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
-        }
-
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            lctx.kv_self_update();
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
-                kv_self.head = 0;
-            }
-
-            const auto slot_info = kv_self.find_slot(ubatch);
-            if (!slot_info) {
-                return false;
-            }
-
-            kv_slot_restorer.save(slot_info);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self.get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
-        }
-
-        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
-
-        // reserve a worst case graph if needed
-        if (lctx.need_reserve) {
-            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
-
-            const auto & cparams = lctx.cparams;
-            const auto & model = lctx.model;
-
-            // build worst-case graph
-            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            ggml_cgraph * gf = lctx.build_graph(ubatch, true);
-
-            // initialize scheduler with the worst-case graph
-            ggml_backend_sched_reset(lctx.sched.get());
-            if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-            }
-
-            lctx.need_reserve = false;
-        }
-
-        return true;
-    }
-
-    void restore() {
-        kv_slot_restorer.restore(lctx.kv_self);
-    }
-
-    void update(const llama_ubatch & ubatch) {
-        auto & kv_self = lctx.kv_self;
-
-        // update the kv ring buffer
-        {
-            kv_self.head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self.head >= kv_self.size) {
-                kv_self.head = 0;
-            }
-        }
-    }
-
-    void finalize() {
-        const auto & cparams = lctx.cparams;
-
-        auto & kv_self = lctx.kv_self;
-
-        // decide if we need to defrag the kv cache
-        if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-            // - do not defrag small contexts (i.e. < 2048 tokens)
-            // - count the padding towards the number of used tokens
-            const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + lctx.get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;
-
-            // queue defragmentation for next llama_kv_cache_update
-            if (fragmentation > cparams.defrag_thold) {
-                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-                kv_self.defrag();
-            }
-        }
-    }
-
-    int64_t n_outputs_all = 0;
-
-    llama_context_kv_self & lctx;
-
-    const llama_batch & batch;
-
-    llama_kv_slot_restorer kv_slot_restorer;
-};
-
-std::unique_ptr<llama_context_kv_self::batch_manager> llama_context_kv_self::prepare_batch(const llama_batch & batch) {
-    return std::make_unique<batch_manager>(*this, batch);
-}
-
 int llama_context_kv_self::decode(llama_batch & inp_batch) {
     is_encoding = false;

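
The removed batch_manager::next() chose how to slice the batch into ubatches; the same decision is inlined into decode() in the next hunk. A standalone sketch of that dispatch under the same flags (the enum and function names are illustrative, not llama.cpp API):

    #include <cstdio>

    enum class split_mode { seq, equal, simple };

    static split_mode pick_split(bool recurrent, bool embd_pooled) {
        if (recurrent) {
            // pooled embeddings cannot be split across ubatches (yet)
            if (embd_pooled) {
                return split_mode::seq;
            }
            // recurrent architectures prefer equal-length sequences
            return split_mode::equal;
        }
        return split_mode::simple;
    }

    int main() {
        std::printf("%d %d %d\n",
            (int) pick_split(true,  true),    // seq
            (int) pick_split(true,  false),   // equal
            (int) pick_split(false, false));  // simple
        return 0;
    }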
@@ -1783,31 +1561,181 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     const auto & hparams = model.hparams;

     const int32_t n_vocab = vocab.n_tokens();

+    const int64_t n_tokens_all = batch.n_tokens;
     const int64_t n_embd = hparams.n_embd;

-    // TODO: try catch
-    auto bman = prepare_batch(batch);
-
-    const auto n_outputs_all = bman->n_outputs_all;
+    // TODO: remove this stuff
+    class batch_guard {
+    public:
+        batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
+        }
+
+        ~batch_guard() {
+            if (!is_done) {
+                kv_slot_restorer.restore();
+            }
+        }
+
+        void done() {
+            is_done = true;
+        }
+
+        void save(const llama_kv_cache_slot_info & slot_info) {
+            kv_slot_restorer.save(slot_info);
+        }
+
+    private:
+        bool is_done = false;
+
+        llama_kv_slot_restorer kv_slot_restorer;
+    };
+
+    batch_guard bg(kv_self);
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (int64_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                throw std::runtime_error("invalid token");
+            }
+        }
+    }
+
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
+    }
+    n_queued_tokens += n_tokens_all;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    embd_seq.clear();
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    if (batch.logits && !embd_pooled) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            n_outputs_all += batch.logits[i] != 0;
+        }
+    } else if (logits_all || embd_pooled) {
+        n_outputs_all = n_tokens_all;
+    } else {
+        // keep last output only
+        n_outputs_all = 1;
+    }
+
+    const bool logits_all = n_outputs_all == n_tokens_all;
+
+    sbatch.from_batch(batch, n_embd,
+            /* simple_split */ !kv_self.recurrent,
+            /* logits_all */ logits_all);

     // reserve output buffer
     // TODO: move to batch manager?
-    if (output_reserve(bman->n_outputs_all) < (size_t) n_outputs_all) {
+    if (output_reserve(n_outputs_all) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
         return -2;
     };

     int64_t n_outputs_prev = 0;

-    while (!bman->is_done()) {
-        llama_ubatch ubatch = bman->next();
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch = llama_ubatch();

-        if (!bman->prepare(ubatch)) {
+        const auto & n_ubatch = cparams.n_ubatch;
+
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        if (kv_self.recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = sbatch.split_seq(n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = sbatch.split_equal(n_ubatch);
+            }
+        } else {
+            ubatch = sbatch.split_simple(n_ubatch);
+        }
+
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            n_outputs = n_outputs_new;
+        }
+
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            kv_self_update();
+
+            // if we have enough unused cells before the current head ->
+            //   better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+                kv_self.head = 0;
+            }
+
+            const auto slot_info = kv_self.find_slot(ubatch);
+            if (!slot_info) {
                 LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-            bman->restore();
                 return -3;
             }
+
+            bg.save(slot_info);
+
+            if (!kv_self.recurrent) {
+                // a heuristic, to avoid attending the full cache if it is not yet utilized
+                // after enough generations, the benefit from this heuristic disappears
+                // if we start defragmenting the cache, the benefit from this will be more important
+                const uint32_t pad = kv_self.get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
+                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+            }
+        }
+
+        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
+
+        // reserve a worst case graph if needed
+        if (need_reserve) {
+            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+            // build worst-case graph
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            ggml_cgraph * gf = build_graph(ubatch, true);
+
+            // initialize scheduler with the worst-case graph
+            ggml_backend_sched_reset(sched.get());
+            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            }
+
+            need_reserve = false;
+        }

         ggml_backend_sched_reset(sched.get());
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);

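
The new batch_guard replaces batch_manager's explicit restore() calls with an RAII guard: unless done() is called, its destructor rolls the KV-cache slot state back. A minimal self-contained sketch of the same pattern, with simplified stand-in types rather than llama.cpp's:

    #include <cstdio>

    struct kv_state { int head = 0; int n = 0; };

    struct slot_restorer {
        kv_state & kv;
        kv_state   old_state;
        explicit slot_restorer(kv_state & kv) : kv(kv), old_state(kv) {}
        void restore() { kv = old_state; }
    };

    class batch_guard {
    public:
        explicit batch_guard(kv_state & kv) : restorer(kv) {}
        ~batch_guard() { if (!is_done) { restorer.restore(); } }
        void done() { is_done = true; }
    private:
        bool is_done = false;
        slot_restorer restorer;
    };

    int main() {
        kv_state kv;
        {
            batch_guard bg(kv);
            kv.head = 32; // simulate a partial decode that fails...
            // ...and returns early without bg.done(): the guard restores the state
        }
        std::printf("head after failed decode: %d\n", kv.head); // prints 0
        return 0;
    }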
@@ -1844,7 +1772,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {

         const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            bman->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -1856,7 +1783,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
             }
         }

-        bman->update(ubatch);
+        // update the kv ring buffer
+        {
+            kv_self.head += ubatch.n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }

         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
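
A tiny standalone illustration of the head-advance-and-wrap step now inlined above, using a plain struct instead of the real KV cache:

    #include <cassert>
    #include <cstdint>

    struct ring { uint32_t head = 0; uint32_t size = 0; };

    static void advance(ring & r, uint32_t n_tokens) {
        r.head += n_tokens;
        // ensure the head points to a valid index (reset, as in the hunk above)
        if (r.head >= r.size) {
            r.head = 0;
        }
    }

    int main() {
        ring r { 30, 32 };
        advance(r, 4);          // 34 >= 32 -> wraps back to 0
        assert(r.head == 0);
        return 0;
    }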
@@ -1936,14 +1871,17 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         n_outputs_prev += n_outputs;
     }

+    // finalize the batch processing
+    bg.done();
+
     // set output mappings
     {
         bool sorted_output = true;

         GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);

-        for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
-            size_t out_id = sbatch.out_ids[i];
+        for (int64_t i = 0; i < n_outputs_all; ++i) {
+            int64_t out_id = sbatch.out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
                 sorted_output = false;
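
A self-contained sketch of the output-mapping step above: out_ids records, per produced output, its position in the user batch; output_ids is the inverse map, and sorted_output tells whether a later reorder pass is needed at all:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<int64_t> out_ids = { 2, 0, 1 };   // batch positions, in production order
        std::vector<int32_t> output_ids(3, -1);             // batch position -> output row

        bool sorted_output = true;
        for (int64_t i = 0; i < (int64_t) out_ids.size(); ++i) {
            const int64_t out_id = out_ids[i];
            output_ids[out_id] = (int32_t) i;
            if (out_id != i) {
                sorted_output = false;
            }
        }

        std::printf("sorted: %s, output_ids = [%d %d %d]\n",
            sorted_output ? "yes" : "no", output_ids[0], output_ids[1], output_ids[2]);
        return 0;
    }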
@@ -1961,7 +1899,19 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();

-    bman->finalize();
+    // decide if we need to defrag the kv cache
+    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        // - do not defrag small contexts (i.e. < 2048 tokens)
+        // - count the padding towards the number of used tokens
+        const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;
+
+        // queue defragmentation for next llama_kv_cache_update
+        if (fragmentation > cparams.defrag_thold) {
+            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+            kv_self.defrag();
+        }
+    }

     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
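
A standalone sketch of the fragmentation estimate inlined above: the fraction of the attended KV window not covered by used cells plus padding, compared against the defrag threshold (plain ints stand in for the real cache fields):

    #include <algorithm>
    #include <cstdio>

    static float fragmentation(int n, int used, int padding) {
        // do not defrag small contexts (i.e. < 2048 tokens)
        if (n < 2048) {
            return 0.0f;
        }
        // count the padding towards the number of used tokens
        return std::max(0.0f, 1.0f - float(used + padding) / float(n));
    }

    int main() {
        const float defrag_thold = 0.1f;
        const float f = fragmentation(/*n*/ 4096, /*used*/ 3000, /*padding*/ 32);
        std::printf("fragmentation: %.2f -> %s\n", f, f > defrag_thold ? "defrag" : "keep");
        return 0;
    }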
@@ -1983,14 +1933,14 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);

     const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens = batch.n_tokens;
+    const int32_t n_tokens = batch.n_tokens;

     const auto & hparams = model.hparams;

     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

     if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
+        for (int32_t i = 0; i < n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
@@ -1999,7 +1949,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     }

     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
+    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");

     if (t_compute_start_us == 0) {
         t_compute_start_us = ggml_time_us();
@@ -2019,7 +1969,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
         return -2;
     };

-    for (uint32_t i = 0; i < n_tokens; ++i) {
+    for (int32_t i = 0; i < n_tokens; ++i) {
         output_ids[i] = i;
     }

@@ -2087,7 +2037,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {

                 // remember the sequence ids used during the encoding - needed for cross attention later
                 seq_ids_enc.resize(n_tokens);
-                for (uint32_t i = 0; i < n_tokens; i++) {
+                for (int32_t i = 0; i < n_tokens; i++) {
                     for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
                         llama_seq_id seq_id = ubatch.seq_id[i][s];
                         seq_ids_enc[i].insert(seq_id);
@@ -2116,7 +2066,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {

                     GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

-                    for (uint32_t i = 0; i < n_tokens; i++) {
+                    for (int32_t i = 0; i < n_tokens; i++) {
                         const llama_seq_id seq_id = ubatch.seq_id[i][0];
                         if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                             continue;
@@ -2448,7 +2398,7 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());

         auto ctx = graph_init();
-        auto ctx0 = ctx.get();
+        auto * ctx0 = ctx.get();

         ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);

@@ -2477,7 +2427,7 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());

         auto ctx = graph_init();
-        auto ctx0 = ctx.get();
+        auto * ctx0 = ctx.get();

         ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);

@@ -92,6 +92,7 @@ struct llama_context : public llama_graph_i {

     virtual void synchronize();

+    // zero-out inputs and create ggml_context
     virtual ggml_context_ptr graph_init();

     // returns the result of ggml_backend_sched_graph_compute_async execution
@@ -103,13 +104,40 @@ struct llama_context : public llama_graph_i {

     // Make sure enough space is available for outputs.
     // Returns max number of outputs for which space was reserved.
-    virtual size_t output_reserve(size_t n_outputs);
+    virtual int32_t output_reserve(int32_t n_outputs);

     // make the outputs have the same order they had in the user-provided batch
     // TODO: maybe remove this
     virtual void output_reorder();

+    // decode a batch of tokens by evaluating the transformer
+    // in case of unsuccessful decoding (error or warning),
+    // the kv_cache state will be returned to its original state
+    // (for non-recurrent models) or cleaned (for recurrent models)
+    //
+    // - lctx: llama context
+    // - inp_batch: batch to evaluate
+    //
+    // return 0 on success
+    // return positive int on warning
+    // return negative int on error
+    //
+    virtual int decode(llama_batch & inp_batch) = 0;
+
+    // encode a batch of tokens by evaluating the encoder part of the transformer
+    //
+    // - lctx: llama context
+    // - batch: batch to evaluate
+    //
+    // return 0 on success
+    // return positive int on warning
+    // return negative int on error
+    //
+    virtual int encode(llama_batch & inp_batch) = 0;
+
+    //
     // graph build API (generic)
+    //

     virtual void build_cb(
             ggml_tensor * cur,
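
Illustrative caller-side handling of the decode() return-code contract documented above (0 = success, positive = warning, negative = error); ctx_decode is a hypothetical stand-in so the sketch stays self-contained:

    #include <cstdio>

    static int ctx_decode() { return 1; } // pretend decode() returned a warning

    int main() {
        const int ret = ctx_decode();
        if (ret < 0) {
            std::fprintf(stderr, "decode failed with error %d\n", ret);
            return 1;
        }
        if (ret > 0) {
            std::fprintf(stderr, "decode returned warning %d (e.g. aborted), continuing\n", ret);
        }
        std::printf("done\n");
        return 0;
    }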
@@ -141,31 +169,6 @@ struct llama_context : public llama_graph_i {

     virtual ggml_tensor * build_rope_factors(int il);

-    // decode a batch of tokens by evaluating the transformer
-    // in case of unsuccessful decoding (error or warning),
-    // the kv_cache state will be returned to its original state
-    // (for non-recurrent models) or cleaned (for recurrent models)
-    //
-    // - lctx: llama context
-    // - inp_batch: batch to evaluate
-    //
-    // return 0 on success
-    // return positive int on warning
-    // return negative int on error
-    //
-    virtual int decode(llama_batch & inp_batch) = 0;
-
-    // encode a batch of tokens by evaluating the encoder part of the transformer
-    //
-    // - lctx: llama context
-    // - batch: batch to evaluate
-    //
-    // return 0 on success
-    // return positive int on warning
-    // return negative int on error
-    //
-    virtual int encode(llama_batch & inp_batch) = 0;
-
     // state save/load

     virtual size_t state_get_size();
@@ -268,7 +271,7 @@ protected:
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;

-    size_t output_size = 0; // capacity (of tokens positions) for the output buffers
+    int32_t output_size = 0; // capacity (of tokens positions) for the output buffers
     int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
@@ -291,8 +294,6 @@ protected:
 // transformer with a self-attention KV cache
 class llama_context_kv_self : public llama_context {
 public:
-    struct batch_manager;
-
     llama_context_kv_self(
             const llama_model & model,
             const llama_context_params & params);
@@ -313,8 +314,6 @@ public:
     virtual int decode(llama_batch & inp_batch) override;
     virtual int encode(llama_batch & inp_batch) override;

-    virtual std::unique_ptr<batch_manager> prepare_batch(const llama_batch & batch);
-
     // max token position across all sequences in the current context
     llama_pos pos_max() const;

@@ -150,7 +150,9 @@ struct llama_kv_slot_restorer {

     bool do_restore = false;

-    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
+    llama_kv_cache & cache;
+
+    explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) {
         old_state.head = cache.head;
         old_state.n = cache.n;
     }
@@ -167,7 +169,7 @@ struct llama_kv_slot_restorer {

     // must be explicitly called to restore the kv_cache state
     // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore(struct llama_kv_cache & cache) {
+    void restore() {
         if (do_restore) {
             cache.head = old_state.head;
             cache.n = old_state.n;