context : remove batch_manager

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-02-14 16:10:55 +02:00
parent 131743ff4f
commit d5e8e1a2ba
4 changed files with 242 additions and 291 deletions

View File

@ -42,9 +42,9 @@ struct llama_sbatch {
bool logits_all; // TODO: remove once lctx.logits_all is removed too
// sorted indices into the batch
std::vector<size_t> ids;
std::vector<int64_t> ids;
// batch indices of the output
std::vector<size_t> out_ids;
std::vector<int64_t> out_ids;
std::vector<llama_sbatch_seq> seq;
const llama_batch * batch = nullptr;

View File

@ -161,7 +161,7 @@ llama_context::llama_context(
// graph outputs buffer
{
// resized during inference when a batch uses more outputs
if (output_reserve(params.n_seq_max) < params.n_seq_max) {
if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
throw std::runtime_error("failed to reserve initial output buffer");
}
@ -747,11 +747,11 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
);
}
size_t llama_context::output_reserve(size_t n_outputs) {
int32_t llama_context::output_reserve(int32_t n_outputs) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
const int64_t n_outputs_max = std::max<int64_t>(n_outputs, cparams.n_seq_max);
const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
@ -817,7 +817,7 @@ size_t llama_context::output_reserve(size_t n_outputs) {
}
void llama_context::output_reorder() {
std::vector<size_t> & out_ids = sbatch.out_ids;
auto & out_ids = sbatch.out_ids;
if (!out_ids.empty()) {
const uint32_t n_vocab = model.vocab.n_tokens();
const uint32_t n_embd = model.hparams.n_embd;
@ -1320,8 +1320,8 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
{
output_reorder();
const uint32_t n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
std::vector<int32_t> w_output_pos;
@ -1334,7 +1334,7 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
// map an output id to a position in the batch
int32_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT((uint32_t) pos < n_outputs);
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
}
}
@ -1386,15 +1386,15 @@ size_t llama_context::state_set_data(llama_io_read_i & io) {
// read output ids
{
std::vector<int32_t> output_pos;
uint32_t n_outputs;
auto n_outputs = this->n_outputs;
io.read_to(&n_outputs, sizeof(n_outputs));
if (n_outputs > output_reserve(n_outputs)) {
throw std::runtime_error("could not reserve outputs");
}
std::vector<int32_t> output_pos;
if (n_outputs) {
output_pos.resize(n_outputs);
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
@ -1543,228 +1543,6 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
return llama_context::graph_init();
}
struct llama_context_kv_self::batch_manager {
batch_manager(llama_context_kv_self & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
const auto & model = lctx.model;
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
const auto & kv_self = lctx.kv_self;
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
}
}
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (lctx.t_compute_start_us == 0) {
lctx.t_compute_start_us = ggml_time_us();
}
lctx.n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
lctx.embd_seq.clear();
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (lctx.logits_all || embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}
const bool logits_all = n_outputs_all == n_tokens_all;
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ logits_all);
}
~batch_manager() {
}
bool is_done() const {
return lctx.sbatch.n_tokens == 0;
}
llama_ubatch next() {
llama_ubatch ubatch = llama_ubatch();
const auto & cparams = lctx.cparams;
const auto & kv_self = lctx.kv_self;
const auto & n_ubatch = cparams.n_ubatch;
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = lctx.sbatch.split_equal(n_ubatch);
}
} else {
ubatch = lctx.sbatch.split_simple(n_ubatch);
}
return ubatch;
}
bool prepare(const llama_ubatch & ubatch) {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;
const auto & batch = lctx.sbatch.batch;
const auto n_tokens_all = batch->n_tokens;
auto & kv_self = lctx.kv_self;
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
lctx.kv_self_update();
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}
const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
return false;
}
kv_slot_restorer.save(slot_info);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
// reserve a worst case graph if needed
if (lctx.need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
const auto & cparams = lctx.cparams;
const auto & model = lctx.model;
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = lctx.build_graph(ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(lctx.sched.get());
if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
lctx.need_reserve = false;
}
return true;
}
void restore() {
kv_slot_restorer.restore(lctx.kv_self);
}
void update(const llama_ubatch & ubatch) {
auto & kv_self = lctx.kv_self;
// update the kv ring buffer
{
kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}
}
void finalize() {
const auto & cparams = lctx.cparams;
auto & kv_self = lctx.kv_self;
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + lctx.get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
kv_self.defrag();
}
}
}
int64_t n_outputs_all = 0;
llama_context_kv_self & lctx;
const llama_batch & batch;
llama_kv_slot_restorer kv_slot_restorer;
};
std::unique_ptr<llama_context_kv_self::batch_manager> llama_context_kv_self::prepare_batch(const llama_batch & batch) {
return std::make_unique<batch_manager>(*this, batch);
}
int llama_context_kv_self::decode(llama_batch & inp_batch) {
is_encoding = false;
@ -1783,29 +1561,179 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
const auto & hparams = model.hparams;
const int32_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
// TODO: try catch
auto bman = prepare_batch(batch);
const int64_t n_tokens_all = batch.n_tokens;
const int64_t n_embd = hparams.n_embd;
const auto n_outputs_all = bman->n_outputs_all;
// TODO: remove this stuff
class batch_guard {
public:
batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
}
~batch_guard() {
if (!is_done) {
kv_slot_restorer.restore();
}
}
void done() {
is_done = true;
}
void save(const llama_kv_cache_slot_info & slot_info) {
kv_slot_restorer.save(slot_info);
}
private:
bool is_done = false;
llama_kv_slot_restorer kv_slot_restorer;
};
batch_guard bg(kv_self);
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (int64_t i = 0; i < n_tokens_all; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
throw std::runtime_error("invalid token");
}
}
}
GGML_ASSERT(n_tokens_all <= cparams.n_batch);
GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
}
n_queued_tokens += n_tokens_all;
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
embd_seq.clear();
int64_t n_outputs_all = 0;
// count outputs
if (batch.logits && !embd_pooled) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
n_outputs_all += batch.logits[i] != 0;
}
} else if (logits_all || embd_pooled) {
n_outputs_all = n_tokens_all;
} else {
// keep last output only
n_outputs_all = 1;
}
const bool logits_all = n_outputs_all == n_tokens_all;
sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* logits_all */ logits_all);
// reserve output buffer
// TODO: move to batch manager?
if (output_reserve(bman->n_outputs_all) < (size_t) n_outputs_all) {
if (output_reserve(n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
return -2;
};
int64_t n_outputs_prev = 0;
while (!bman->is_done()) {
llama_ubatch ubatch = bman->next();
while (sbatch.n_tokens > 0) {
llama_ubatch ubatch = llama_ubatch();
if (!bman->prepare(ubatch)) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
bman->restore();
return -3;
const auto & n_ubatch = cparams.n_ubatch;
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
if (kv_self.recurrent) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = sbatch.split_seq(n_ubatch);
} else {
// recurrent model architectures are easier to implement
// with equal-length sequences
ubatch = sbatch.split_equal(n_ubatch);
}
} else {
ubatch = sbatch.split_simple(n_ubatch);
}
// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;
if (n_outputs_all == n_tokens_all) {
n_outputs_new = ubatch.n_tokens;
} else {
GGML_ASSERT(ubatch.output);
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
n_outputs_new += (int32_t) (ubatch.output[i] != 0);
}
}
// needs to happen before the graph is built
n_outputs = n_outputs_new;
}
// non-causal masks do not use the KV cache
if (hparams.causal_attn) {
kv_self_update();
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
kv_self.head = 0;
}
const auto slot_info = kv_self.find_slot(ubatch);
if (!slot_info) {
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
return -3;
}
bg.save(slot_info);
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
const uint32_t pad = kv_self.get_padding(cparams);
kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
//kv_self.n = llama_kv_cache_cell_max(kv_self);
}
}
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
// reserve a worst case graph if needed
if (need_reserve) {
LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
// build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = build_graph(ubatch, true);
// initialize scheduler with the worst-case graph
ggml_backend_sched_reset(sched.get());
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
}
need_reserve = false;
}
ggml_backend_sched_reset(sched.get());
@ -1844,7 +1772,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
if (compute_status != GGML_STATUS_SUCCESS) {
bman->restore();
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
@ -1856,7 +1783,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
}
}
bman->update(ubatch);
// update the kv ring buffer
{
kv_self.head += ubatch.n_tokens;
// Ensure kv cache head points to a valid index.
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
}
// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
@ -1936,14 +1871,17 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
n_outputs_prev += n_outputs;
}
// finalize the batch processing
bg.done();
// set output mappings
{
bool sorted_output = true;
GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
size_t out_id = sbatch.out_ids[i];
for (int64_t i = 0; i < n_outputs_all; ++i) {
int64_t out_id = sbatch.out_ids[i];
output_ids[out_id] = i;
if (out_id != i) {
sorted_output = false;
@ -1961,7 +1899,19 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
// wait for the computation to finish (automatically done when obtaining the model output)
//synchronize();
bman->finalize();
// decide if we need to defrag the kv cache
if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;
// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
kv_self.defrag();
}
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
// overlap with device computation.
@ -1983,14 +1933,14 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
const llama_batch & batch = batch_allocr.batch;
const uint32_t n_tokens = batch.n_tokens;
const int32_t n_tokens = batch.n_tokens;
const auto & hparams = model.hparams;
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
if (batch.token) {
for (uint32_t i = 0; i < n_tokens; ++i) {
for (int32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
return -1;
@ -1999,7 +1949,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
}
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
if (t_compute_start_us == 0) {
t_compute_start_us = ggml_time_us();
@ -2019,7 +1969,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
return -2;
};
for (uint32_t i = 0; i < n_tokens; ++i) {
for (int32_t i = 0; i < n_tokens; ++i) {
output_ids[i] = i;
}
@ -2087,7 +2037,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
// remember the sequence ids used during the encoding - needed for cross attention later
seq_ids_enc.resize(n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
for (int32_t i = 0; i < n_tokens; i++) {
for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
llama_seq_id seq_id = ubatch.seq_id[i][s];
seq_ids_enc[i].insert(seq_id);
@ -2116,7 +2066,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
for (uint32_t i = 0; i < n_tokens; i++) {
for (int32_t i = 0; i < n_tokens; i++) {
const llama_seq_id seq_id = ubatch.seq_id[i][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
@ -2448,7 +2398,7 @@ void llama_context_kv_self::kv_self_update() {
ggml_backend_sched_reset(sched.get());
auto ctx = graph_init();
auto ctx0 = ctx.get();
auto * ctx0 = ctx.get();
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
@ -2477,7 +2427,7 @@ void llama_context_kv_self::kv_self_update() {
ggml_backend_sched_reset(sched.get());
auto ctx = graph_init();
auto ctx0 = ctx.get();
auto * ctx0 = ctx.get();
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);

View File

@ -92,6 +92,7 @@ struct llama_context : public llama_graph_i {
virtual void synchronize();
// zero-out inputs and create ggml_context
virtual ggml_context_ptr graph_init();
// returns the result of ggml_backend_sched_graph_compute_async execution
@ -103,13 +104,40 @@ struct llama_context : public llama_graph_i {
// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
virtual size_t output_reserve(size_t n_outputs);
virtual int32_t output_reserve(int32_t n_outputs);
// make the outputs have the same order they had in the user-provided batch
// TODO: maybe remove this
virtual void output_reorder();
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int decode(llama_batch & inp_batch) = 0;
// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;
//
// graph build API (generic)
//
virtual void build_cb(
ggml_tensor * cur,
@ -141,31 +169,6 @@ struct llama_context : public llama_graph_i {
virtual ggml_tensor * build_rope_factors(int il);
// decode a batch of tokens by evaluating the transformer
// in case of unsuccessful decoding (error or warning),
// the kv_cache state will be returned to its original state
// (for non-recurrent models) or cleaned (for recurrent models)
//
// - lctx: llama context
// - inp_batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int decode(llama_batch & inp_batch) = 0;
// encode a batch of tokens by evaluating the encoder part of the transformer
//
// - lctx: llama context
// - batch: batch to evaluate
//
// return 0 on success
// return positive int on warning
// return negative int on error
//
virtual int encode(llama_batch & inp_batch) = 0;
// state save/load
virtual size_t state_get_size();
@ -268,7 +271,7 @@ protected:
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
@ -291,8 +294,6 @@ protected:
// transformer with a self-attention KV cache
class llama_context_kv_self : public llama_context {
public:
struct batch_manager;
llama_context_kv_self(
const llama_model & model,
const llama_context_params & params);
@ -313,8 +314,6 @@ public:
virtual int decode(llama_batch & inp_batch) override;
virtual int encode(llama_batch & inp_batch) override;
virtual std::unique_ptr<batch_manager> prepare_batch(const llama_batch & batch);
// max token position across all sequences in the current context
llama_pos pos_max() const;

View File

@ -150,7 +150,9 @@ struct llama_kv_slot_restorer {
bool do_restore = false;
explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
llama_kv_cache & cache;
explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) {
old_state.head = cache.head;
old_state.n = cache.n;
}
@ -167,7 +169,7 @@ struct llama_kv_slot_restorer {
// must be explicitly called to restore the kv_cache state
// and rollback changes from all llama_kv_cache_find_slot calls
void restore(struct llama_kv_cache & cache) {
void restore() {
if (do_restore) {
cache.head = old_state.head;
cache.n = old_state.n;