Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-07-18 00:27:31 +00:00)
context : remove batch_manager
ggml-ci
@@ -42,9 +42,9 @@ struct llama_sbatch {
bool logits_all; // TODO: remove once lctx.logits_all is removed too

// sorted indices into the batch
std::vector<size_t> ids;
std::vector<int64_t> ids;
// batch indices of the output
std::vector<size_t> out_ids;
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;

const llama_batch * batch = nullptr;
@@ -161,7 +161,7 @@ llama_context::llama_context(
// graph outputs buffer
{
// resized during inference when a batch uses more outputs
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
throw std::runtime_error("failed to reserve initial output buffer");
}
@@ -747,11 +747,11 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
);
}

size_t llama_context::output_reserve(size_t n_outputs) {
int32_t llama_context::output_reserve(int32_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;

const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, cparams.n_seq_max);

const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
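For reference, the hunk above narrows output_reserve() to a signed 32-bit interface while still reserving room for at least one output per sequence (n_outputs_max). A minimal, self-contained sketch of that sizing rule follows; the names (reserve_outputs, n_seq_max, n_vocab) are illustrative and not the llama.cpp API.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Reserve row storage for at least max(requested outputs, number of sequences),
    // mirroring the n_outputs_max clamp in the hunk above.
    static int32_t reserve_outputs(std::vector<float> & logits, int32_t n_outputs,
                                   uint32_t n_seq_max, int32_t n_vocab) {
        const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max);
        logits.resize(static_cast<size_t>(n_outputs_max) * n_vocab);
        return static_cast<int32_t>(n_outputs_max); // number of outputs reserved
    }

    int main() {
        std::vector<float> logits;
        // even when 0 outputs are requested, one row per sequence is kept
        std::printf("reserved: %d\n", reserve_outputs(logits, 0, 4, 32000));
        std::printf("reserved: %d\n", reserve_outputs(logits, 10, 4, 32000));
        return 0;
    }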
@@ -817,7 +817,7 @@ size_t llama_context::output_reserve(size_t n_outputs) {
}

void llama_context::output_reorder() {
std::vector<size_t> & out_ids = sbatch.out_ids;
auto & out_ids = sbatch.out_ids;
if (!out_ids.empty()) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
@@ -1320,8 +1320,8 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
{
output_reorder();

const uint32_t n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;

std::vector<int32_t> w_output_pos;
@@ -1334,7 +1334,7 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
// map an output id to a position in the batch
int32_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT((uint32_t) pos < n_outputs);
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
}
}
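For reference, the assertion change above sits inside the loop that inverts output_ids before serializing the state: every batch position i that produced an output is written at w_output_pos[output_ids[i]]. A standalone sketch of that inversion, with made-up data and without the llama_io_write_i plumbing:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // output_ids maps batch position -> output row, -1 means "no output stored"
        const std::vector<int32_t> output_ids = { -1, 0, -1, 2, 1 };
        const int32_t n_outputs = 3;

        // invert it: output row -> batch position, as done before writing the state
        std::vector<int32_t> w_output_pos(n_outputs, -1);
        for (int32_t i = 0; i < (int32_t) output_ids.size(); ++i) {
            const int32_t pos = output_ids[i];
            if (pos >= 0) {
                assert(pos < n_outputs);
                w_output_pos[pos] = i;
            }
        }

        for (int32_t j = 0; j < n_outputs; ++j) {
            std::printf("output %d came from batch position %d\n", j, w_output_pos[j]);
        }
        return 0;
    }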
@@ -1386,15 +1386,15 @@ size_t llama_context::state_set_data(llama_io_read_i & io) {

// read output ids
{
std::vector<int32_t> output_pos;

uint32_t n_outputs;
auto n_outputs = this->n_outputs;
io.read_to(&n_outputs, sizeof(n_outputs));

if (n_outputs > output_reserve(n_outputs)) {
throw std::runtime_error("could not reserve outputs");
}

std::vector<int32_t> output_pos;

if (n_outputs) {
output_pos.resize(n_outputs);
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
@@ -1543,228 +1543,6 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
return llama_context::graph_init();
}

struct llama_context_kv_self::batch_manager {
batch_manager(llama_context_kv_self & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
const auto & model = lctx.model;
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;

const auto & kv_self = lctx.kv_self;

const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
}
}

GGML_ASSERT(n_tokens_all <= cparams.n_batch);

GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");

if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
lctx.n_queued_tokens += n_tokens_all;

// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

lctx.embd_seq.clear();

// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (lctx.logits_all || embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}

const bool logits_all = n_outputs_all == n_tokens_all;

lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ logits_all);
}

~batch_manager() {
}

bool is_done() const {
return lctx.sbatch.n_tokens == 0;
}

llama_ubatch next() {
llama_ubatch ubatch = llama_ubatch();

const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;

const auto & n_ubatch = cparams.n_ubatch;

const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
}

return ubatch;
}

bool prepare(const llama_ubatch & ubatch) {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
const auto & batch = lctx.sbatch.batch;

const auto n_tokens_all = batch->n_tokens;

auto & kv_self = lctx.kv_self;

// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;

if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}

// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}

// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
lctx.kv_self_update();

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}

const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
return false;
}

kv_slot_restorer.save(slot_info);

if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}

//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

// reserve a worst case graph if needed
if (lctx.need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);

const auto & cparams = lctx.cparams;
const auto & model = lctx.model;

// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

ggml_cgraph * gf = lctx.build_graph(ubatch, true);

// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched.get());
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}

lctx.need_reserve = false;
}

return true;
}

void restore() {
kv_slot_restorer.restore(lctx.kv_self);
}

void update(const llama_ubatch & ubatch) {
auto & kv_self = lctx.kv_self;

// update the kv ring buffer
{
kv_self.head += ubatch.n_tokens;

// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}
}

void finalize() {
const auto & cparams = lctx.cparams;

auto & kv_self = lctx.kv_self;

// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + lctx.get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;

// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);

kv_self.defrag();
}
}
}

int64_t n_outputs_all = 0;

llama_context_kv_self & lctx;

const llama_batch & batch;

llama_kv_slot_restorer kv_slot_restorer;
};

std::unique_ptr<llama_context_kv_self::batch_manager> llama_context_kv_self::prepare_batch(const llama_batch & batch) {
return std::make_unique<batch_manager>(*this, batch);
}

int llama_context_kv_self::decode(llama_batch & inp_batch) {
is_encoding = false;
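The struct removed above drove decode() through an is_done()/next()/prepare()/update()/finalize() cycle; this commit inlines those steps into decode() itself. A compressed, self-contained sketch of that control flow, using toy types rather than the real llama classes:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Toy stand-ins: "prepare" can fail, "update" advances a ring-buffer-like head.
    struct toy_batch_manager {
        std::vector<int> pending; // tokens still to be processed
        size_t           head = 0;

        bool is_done() const { return pending.empty(); }

        std::vector<int> next(size_t n_ubatch) { // split off the next micro-batch
            const size_t n = std::min(n_ubatch, pending.size());
            std::vector<int> ub(pending.begin(), pending.begin() + n);
            pending.erase(pending.begin(), pending.begin() + n);
            return ub;
        }

        bool prepare (const std::vector<int> & ub) const { return !ub.empty(); }
        void update  (const std::vector<int> & ub) { head += ub.size(); }
        void finalize() const { std::printf("processed up to head = %zu\n", head); }
    };

    int main() {
        toy_batch_manager bman;
        bman.pending = { 1, 2, 3, 4, 5, 6, 7 };

        while (!bman.is_done()) {
            const auto ubatch = bman.next(3);
            if (!bman.prepare(ubatch)) {
                return -3; // mirrors the error path in decode()
            }
            // ... graph build / compute would happen here ...
            bman.update(ubatch);
        }
        bman.finalize();
        return 0;
    }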
@@ -1783,29 +1561,179 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
const auto & hparams = model.hparams;

const int32_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;

// TODO: try catch
auto bman = prepare_batch(batch);
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;

const auto n_outputs_all = bman->n_outputs_all;
// TODO: remove this stuff
class batch_guard {
public:
batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
}

~batch_guard() {
if (!is_done) {
kv_slot_restorer.restore();
}
}

void done() {
is_done = true;
}

void save(const llama_kv_cache_slot_info & slot_info) {
kv_slot_restorer.save(slot_info);
}

private:
bool is_done = false;

llama_kv_slot_restorer kv_slot_restorer;
};

batch_guard bg(kv_self);

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
}
}

GGML_ASSERT(n_tokens_all <= cparams.n_batch);

GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");

if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
}
n_queued_tokens += n_tokens_all;

// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

embd_seq.clear();

int64_t n_outputs_all = 0;

// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (logits_all || embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}

const bool logits_all = n_outputs_all == n_tokens_all;

sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ logits_all);

// reserve output buffer
// TODO: move to batch manager?
if (output_reserve(bman->n_outputs_all) < (size_t) n_outputs_all) {
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
return -2;
};

int64_t n_outputs_prev = 0;

while (!bman->is_done()) {
llama_ubatch ubatch = bman->next();
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch = llama_ubatch();

if (!bman->prepare(ubatch)) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
bman->restore();
return -3;
const auto & n_ubatch = cparams.n_ubatch;

const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;

if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = sbatch.split_equal(n_ubatch);
}
} else {
ubatch = sbatch.split_simple(n_ubatch);
}

// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;

if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}

// needs to happen before the graph is built
n_outputs = n_outputs_new;
}

// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
kv_self_update();

// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}

const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
return -3;
}

bg.save(slot_info);

if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}

//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

// reserve a worst case graph if needed
if (need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);

// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

ggml_cgraph * gf = build_graph(ubatch, true);

// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(sched.get());
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}

need_reserve = false;
}

ggml_backend_sched_reset(sched.get());
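In the inlined loop above, the micro-batch is taken with split_seq for pooled embeddings on recurrent caches, split_equal for recurrent caches otherwise, and split_simple in the common case. A small sketch of that three-way decision, with an illustrative enum that is not part of llama.cpp:

    #include <cstdio>

    enum class split_mode { seq, equal, simple };

    // Mirrors the branch in decode(): recurrent caches need sequence-aware splits,
    // and pooled embeddings additionally must keep each sequence inside one ubatch.
    static split_mode pick_split(bool recurrent, bool embd_pooled) {
        if (recurrent) {
            return embd_pooled ? split_mode::seq : split_mode::equal;
        }
        return split_mode::simple;
    }

    int main() {
        std::printf("%d\n", (int) pick_split(true,  true));   // seq
        std::printf("%d\n", (int) pick_split(true,  false));  // equal
        std::printf("%d\n", (int) pick_split(false, false));  // simple
        return 0;
    }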
@@ -1844,7 +1772,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {

const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
bman->restore();
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
@@ -1856,7 +1783,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
}
}

bman->update(ubatch);
// update the kv ring buffer
{
kv_self.head += ubatch.n_tokens;

// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}

// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
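The hunk above replaces bman->update() with the inline ring-buffer bookkeeping: the cache head advances by the micro-batch size and wraps to 0 once it reaches the cache size. A minimal sketch of that wrap-around rule, using plain integers instead of llama_kv_cache:

    #include <cstdint>
    #include <cstdio>

    // Advance the cache head by n_tokens and wrap when it runs past the end,
    // matching the "Ensure kv cache head points to a valid index" check above.
    static uint32_t advance_head(uint32_t head, uint32_t n_tokens, uint32_t size) {
        head += n_tokens;
        if (head >= size) {
            head = 0;
        }
        return head;
    }

    int main() {
        uint32_t head = 0;
        for (int step = 0; step < 5; ++step) {
            head = advance_head(head, 48, 128);
            std::printf("step %d -> head = %u\n", step, head);
        }
        return 0;
    }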
@@ -1936,14 +1871,17 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
n_outputs_prev += n_outputs;
}

// finalize the batch processing
bg.done();

// set output mappings
{
bool sorted_output = true;

GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);

for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
size_t out_id = sbatch.out_ids[i];
for (int64_t i = 0; i < n_outputs_all; ++i) {
int64_t out_id = sbatch.out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
sorted_output = false;
@@ -1961,7 +1899,19 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();

bman->finalize();
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;

// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);

kv_self.defrag();
}
}

// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
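The defragmentation trigger above computes fragmentation as 1 - (used + padding) / n for caches with at least 2048 cells and compares it against defrag_thold. A tiny numeric sketch of that check; all values are made up for illustration:

    #include <algorithm>
    #include <cstdio>

    int main() {
        // made-up numbers: 4096 active cells, 2500 used, 32 cells of padding
        const float n    = 4096.0f;
        const float used = 2500.0f;
        const float pad  = 32.0f;
        const float defrag_thold = 0.10f;

        // same shape as the expression above; small caches (< 2048) never defrag
        const float fragmentation = n >= 2048.0f
            ? std::max(0.0f, 1.0f - (used + pad) / n)
            : 0.0f;

        std::printf("fragmentation = %.3f -> defrag: %s\n",
                    fragmentation, fragmentation > defrag_thold ? "yes" : "no");
        return 0;
    }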
@@ -1983,14 +1933,14 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);

const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens = batch.n_tokens;
const int32_t n_tokens = batch.n_tokens;

const auto & hparams = model.hparams;

GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT

if (batch.token) {
for (uint32_t i = 0; i < n_tokens; ++i) {
for (int32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
@@ -1999,7 +1949,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
}

// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");

if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
@@ -2019,7 +1969,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
return -2;
};

for (uint32_t i = 0; i < n_tokens; ++i) {
for (int32_t i = 0; i < n_tokens; ++i) {
output_ids[i] = i;
}
@@ -2087,7 +2037,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {

// remember the sequence ids used during the encoding - needed for cross attention later
seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
for (int32_t i = 0; i < n_tokens; i++) {
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
llama_seq_id seq_id = ubatch.seq_id[i][s];
seq_ids_enc[i].insert(seq_id);
@@ -2116,7 +2066,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {

GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits

for (uint32_t i = 0; i < n_tokens; i++) {
for (int32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
@@ -2448,7 +2398,7 @@ void llama_context_kv_self::kv_self_update() {
ggml_backend_sched_reset(sched.get());

auto ctx = graph_init();
auto ctx0 = ctx.get();
auto * ctx0 = ctx.get();

ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@@ -2477,7 +2427,7 @@ void llama_context_kv_self::kv_self_update() {
ggml_backend_sched_reset(sched.get());

auto ctx = graph_init();
auto ctx0 = ctx.get();
auto * ctx0 = ctx.get();

ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@@ -92,6 +92,7 @@ struct llama_context : public llama_graph_i {

virtual void synchronize();

// zero-out inputs and create ggml_context
virtual ggml_context_ptr graph_init();

// returns the result of ggml_backend_sched_graph_compute_async execution
@@ -103,13 +104,40 @@ struct llama_context : public llama_graph_i {

// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
virtual size_t output_reserve(size_t n_outputs);
virtual int32_t output_reserve(int32_t n_outputs);

// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
virtual void output_reorder();

// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int decode(llama_batch & inp_batch) = 0;

// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;

//
// graph build API (generic)
//

virtual void build_cb(
ggml_tensor * cur,
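The comments above keep the existing return convention for decode() and encode(): 0 is success, a positive value is a warning, a negative value is an error. A short sketch of how a caller might branch on that convention; the handler is hypothetical and not part of the llama.cpp API:

    #include <cstdio>

    // Hypothetical handler following the documented convention:
    // 0 = success, > 0 = warning (keep going), < 0 = error (abort).
    static bool handle_decode_result(int ret) {
        if (ret == 0) {
            return true;
        }
        if (ret > 0) {
            std::fprintf(stderr, "decode warning %d, continuing\n", ret);
            return true;
        }
        std::fprintf(stderr, "decode error %d, aborting\n", ret);
        return false;
    }

    int main() {
        handle_decode_result(0);   // success
        handle_decode_result(1);   // warning path
        handle_decode_result(-2);  // error path
        return 0;
    }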
@@ -141,31 +169,6 @@ struct llama_context : public llama_graph_i {

virtual ggml_tensor * build_rope_factors(int il);

// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int decode(llama_batch & inp_batch) = 0;

// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;

// state save/load

virtual size_t state_get_size();
@@ -268,7 +271,7 @@ protected:
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;

size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
@@ -291,8 +294,6 @@
// transformer with a self-attention KV cache
class llama_context_kv_self : public llama_context {
public:
struct batch_manager;

llama_context_kv_self(
const llama_model & model,
const llama_context_params & params);
@@ -313,8 +314,6 @@ public:
virtual int decode(llama_batch & inp_batch) override;
virtual int encode(llama_batch & inp_batch) override;

virtual std::unique_ptr<batch_manager> prepare_batch(const llama_batch & batch);

// max token position across all sequences in the current context
llama_pos pos_max() const;
@@ -150,7 +150,9 @@ struct llama_kv_slot_restorer {

bool do_restore = false;

explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
llama_kv_cache & cache;

explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) {
old_state.head = cache.head;
old_state.n = cache.n;
}
@@ -167,7 +169,7 @@ struct llama_kv_slot_restorer {

// must be explicitly called to restore the kv_cache state
// and rollback changes from all llama_kv_cache_find_slot calls
void restore(struct llama_kv_cache & cache) {
void restore() {
if (do_restore) {
cache.head = old_state.head;
cache.n = old_state.n;
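The last hunk makes llama_kv_slot_restorer hold a reference to the cache it guards, so restore() no longer takes the cache as an argument; combined with the batch_guard in decode(), failed batches roll the cache back automatically. A generic sketch of that save-and-restore-unless-done guard pattern, with toy types in place of the real cache:

    #include <cstdio>

    struct toy_kv_cache {
        unsigned head = 0;
        unsigned n    = 0;
    };

    // Saves the cache state on construction and restores it unless done() is called,
    // the same shape as batch_guard / llama_kv_slot_restorer after this change.
    struct toy_slot_guard {
        explicit toy_slot_guard(toy_kv_cache & cache) : cache(cache), saved(cache) {}

        ~toy_slot_guard() {
            if (!is_done) {
                cache = saved; // roll back on error paths
            }
        }

        void done() { is_done = true; }

    private:
        toy_kv_cache & cache;
        toy_kv_cache   saved;
        bool           is_done = false;
    };

    int main() {
        toy_kv_cache kv;
        {
            toy_slot_guard guard(kv);
            kv.head = 32; kv.n = 64; // simulate a failed decode: done() is never called
        }
        std::printf("after rollback: head=%u n=%u\n", kv.head, kv.n); // prints 0 0
        return 0;
    }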