From d5e8e1a2ba315599d09e6d5fbb37a2b98f841c07 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 14 Feb 2025 16:10:55 +0200
Subject: [PATCH] context : remove batch_manager

ggml-ci
---
 src/llama-batch.h     |   4 +-
 src/llama-context.cpp | 462 +++++++++++++++++++-----------------
 src/llama-context.h   |  61 +++---
 src/llama-kv-cache.h  |   6 +-
 4 files changed, 242 insertions(+), 291 deletions(-)

diff --git a/src/llama-batch.h b/src/llama-batch.h
index 773c3808b..f1df40d27 100644
--- a/src/llama-batch.h
+++ b/src/llama-batch.h
@@ -42,9 +42,9 @@ struct llama_sbatch {
     bool logits_all; // TODO: remove once lctx.logits_all is removed too
 
     // sorted indices into the batch
-    std::vector<size_t> ids;
+    std::vector<int64_t> ids;
     // batch indices of the output
-    std::vector<size_t> out_ids;
+    std::vector<int64_t> out_ids;
     std::vector<llama_sbatch_seq> seq;
 
     const llama_batch * batch = nullptr;
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 31085f644..f3fa4c592 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -161,7 +161,7 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
+        if ((uint32_t) output_reserve(params.n_seq_max) < params.n_seq_max) {
             LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
             throw std::runtime_error("failed to reserve initial output buffer");
         }
@@ -747,11 +747,11 @@ void llama_context::input_set(const llama_ubatch & ubatch) {
     );
 }
 
-size_t llama_context::output_reserve(size_t n_outputs) {
+int32_t llama_context::output_reserve(int32_t n_outputs) {
     const auto & hparams = model.hparams;
     const auto & vocab = model.vocab;
 
-    const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+    const int64_t n_outputs_max = std::max<int64_t>(n_outputs, cparams.n_seq_max);
 
     const auto n_batch = cparams.n_batch;
     const auto n_vocab = vocab.n_tokens();
@@ -817,7 +817,7 @@ size_t llama_context::output_reserve(size_t n_outputs) {
 }
 
 void llama_context::output_reorder() {
-    std::vector<size_t> & out_ids = sbatch.out_ids;
+    auto & out_ids = sbatch.out_ids;
     if (!out_ids.empty()) {
         const uint32_t n_vocab = model.vocab.n_tokens();
         const uint32_t n_embd = model.hparams.n_embd;
@@ -1320,8 +1320,8 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
     {
         output_reorder();
 
-        const uint32_t n_outputs = this->n_outputs;
-        const auto & output_ids = this->output_ids;
+        const auto n_outputs = this->n_outputs;
+        const auto & output_ids = this->output_ids;
 
         std::vector<int32_t> w_output_pos;
 
@@ -1334,7 +1334,7 @@ size_t llama_context::state_get_data(llama_io_write_i & io) {
             // map an output id to a position in the batch
             int32_t pos = output_ids[i];
             if (pos >= 0) {
-                GGML_ASSERT((uint32_t) pos < n_outputs);
+                GGML_ASSERT(pos < n_outputs);
                 w_output_pos[pos] = i;
             }
         }
@@ -1386,15 +1386,15 @@ size_t llama_context::state_set_data(llama_io_read_i & io) {
 
     // read output ids
     {
-        std::vector<int32_t> output_pos;
-
-        uint32_t n_outputs;
+        auto n_outputs = this->n_outputs;
         io.read_to(&n_outputs, sizeof(n_outputs));
 
         if (n_outputs > output_reserve(n_outputs)) {
             throw std::runtime_error("could not reserve outputs");
         }
 
+        std::vector<int32_t> output_pos;
+
         if (n_outputs) {
             output_pos.resize(n_outputs);
             io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
@@ -1543,228 +1543,6 @@ ggml_context_ptr llama_context_kv_self::graph_init() {
     return llama_context::graph_init();
 }
 
-struct llama_context_kv_self::batch_manager {
-    batch_manager(llama_context_kv_self & lctx, const llama_batch & batch) : lctx(lctx), batch(batch), kv_slot_restorer(lctx.kv_self) {
-        const auto & model = lctx.model;
-        const auto & cparams = lctx.cparams;
-        const auto & hparams = lctx.model.hparams;
-
-        const auto & kv_self = lctx.kv_self;
-
-        const int64_t n_tokens_all = batch.n_tokens;
-        const int64_t n_embd = hparams.n_embd;
-
-        GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
-
-        if (batch.token) {
-            for (int64_t i = 0; i < n_tokens_all; ++i) {
-                if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
-                    LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
-                    throw std::runtime_error("invalid token");
-                }
-            }
-        }
-
-        GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
-        GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
-
-        if (lctx.t_compute_start_us == 0) {
-            lctx.t_compute_start_us = ggml_time_us();
-        }
-        lctx.n_queued_tokens += n_tokens_all;
-
-        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-        lctx.embd_seq.clear();
-
-        // count outputs
-        if (batch.logits && !embd_pooled) {
-            for (uint32_t i = 0; i < n_tokens_all; ++i) {
-                n_outputs_all += batch.logits[i] != 0;
-            }
-        } else if (lctx.logits_all || embd_pooled) {
-            n_outputs_all = n_tokens_all;
-        } else {
-            // keep last output only
-            n_outputs_all = 1;
-        }
-
-        const bool logits_all = n_outputs_all == n_tokens_all;
-
-        lctx.sbatch.from_batch(batch, n_embd,
-                /* simple_split */ !kv_self.recurrent,
-                /* logits_all   */ logits_all);
-    }
-
-    ~batch_manager() {
-    }
-
-    bool is_done() const {
-        return lctx.sbatch.n_tokens == 0;
-    }
-
-    llama_ubatch next() {
-        llama_ubatch ubatch = llama_ubatch();
-
-        const auto & cparams = lctx.cparams;
-        const auto & kv_self = lctx.kv_self;
-
-        const auto & n_ubatch = cparams.n_ubatch;
-
-        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
-
-        return ubatch;
-    }
-
-    bool prepare(const llama_ubatch & ubatch) {
-        const auto & cparams = lctx.cparams;
-        const auto & hparams = lctx.model.hparams;
-        const auto & batch = lctx.sbatch.batch;
-
-        const auto n_tokens_all = batch->n_tokens;
-
-        auto & kv_self = lctx.kv_self;
-
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs_all == n_tokens_all) {
-                n_outputs_new = ubatch.n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
-            }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
-        }
-
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            lctx.kv_self_update();
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
-                kv_self.head = 0;
-            }
-
-            const auto slot_info = kv_self.find_slot(ubatch);
-            if (!slot_info) {
-                return false;
-            }
-            kv_slot_restorer.save(slot_info);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = kv_self.get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
-        }
-
-        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
-
-        // reserve a worst case graph if needed
-        if (lctx.need_reserve) {
-            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
-
-            const auto & cparams = lctx.cparams;
-            const auto & model = lctx.model;
-
-            // build worst-case graph
-            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
-            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
-            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-
-            ggml_cgraph * gf = lctx.build_graph(ubatch, true);
-
-            // initialize scheduler with the worst-case graph
-            ggml_backend_sched_reset(lctx.sched.get());
-            if (!ggml_backend_sched_reserve(lctx.sched.get(), gf)) {
-                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
-            }
-
-            lctx.need_reserve = false;
-        }
-
-        return true;
-    }
-
-    void restore() {
-        kv_slot_restorer.restore(lctx.kv_self);
-    }
-
-    void update(const llama_ubatch & ubatch) {
-        auto & kv_self = lctx.kv_self;
-
-        // update the kv ring buffer
-        {
-            kv_self.head += ubatch.n_tokens;
-
-            // Ensure kv cache head points to a valid index.
-            if (kv_self.head >= kv_self.size) {
-                kv_self.head = 0;
-            }
-        }
-    }
-
-    void finalize() {
-        const auto & cparams = lctx.cparams;
-
-        auto & kv_self = lctx.kv_self;
-
-        // decide if we need to defrag the kv cache
-        if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
-            // - do not defrag small contexts (i.e. < 2048 tokens)
-            // - count the padding towards the number of used tokens
-            const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + lctx.get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;
-
-            // queue defragmentation for next llama_kv_cache_update
-            if (fragmentation > cparams.defrag_thold) {
-                LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
-
-                kv_self.defrag();
-            }
-        }
-    }
-
-    int64_t n_outputs_all = 0;
-
-    llama_context_kv_self & lctx;
-
-    const llama_batch & batch;
-
-    llama_kv_slot_restorer kv_slot_restorer;
-};
-
-std::unique_ptr<llama_context_kv_self::batch_manager> llama_context_kv_self::prepare_batch(const llama_batch & batch) {
-    return std::make_unique<batch_manager>(*this, batch);
-}
-
 int llama_context_kv_self::decode(llama_batch & inp_batch) {
     is_encoding = false;
 
@@ -1783,29 +1561,179 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     const auto & hparams = model.hparams;
 
     const int32_t n_vocab = vocab.n_tokens();
-    const int64_t n_embd = hparams.n_embd;
 
-    // TODO: try catch
-    auto bman = prepare_batch(batch);
+    const int64_t n_tokens_all = batch.n_tokens;
+    const int64_t n_embd = hparams.n_embd;
 
-    const auto n_outputs_all = bman->n_outputs_all;
+    // TODO: remove this stuff
+    class batch_guard {
+    public:
+        batch_guard(llama_kv_cache & kv_self) : kv_slot_restorer(kv_self) {
+        }
+
+        ~batch_guard() {
+            if (!is_done) {
+                kv_slot_restorer.restore();
+            }
+        }
+
+        void done() {
+            is_done = true;
+        }
+
+        void save(const llama_kv_cache_slot_info & slot_info) {
+            kv_slot_restorer.save(slot_info);
+        }
+
+    private:
+        bool is_done = false;
+
+        llama_kv_slot_restorer kv_slot_restorer;
+    };
+
+    batch_guard bg(kv_self);
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+
+    if (batch.token) {
+        for (int64_t i = 0; i < n_tokens_all; ++i) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+                LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]);
+                throw std::runtime_error("invalid token");
+            }
+        }
+    }
+
+    GGML_ASSERT(n_tokens_all <= cparams.n_batch);
+
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
+    if (t_compute_start_us == 0) {
+        t_compute_start_us = ggml_time_us();
+    }
+    n_queued_tokens += n_tokens_all;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    embd_seq.clear();
+
+    int64_t n_outputs_all = 0;
+
+    // count outputs
+    if (batch.logits && !embd_pooled) {
+        for (uint32_t i = 0; i < n_tokens_all; ++i) {
+            n_outputs_all += batch.logits[i] != 0;
+        }
+    } else if (logits_all || embd_pooled) {
+        n_outputs_all = n_tokens_all;
+    } else {
+        // keep last output only
+        n_outputs_all = 1;
+    }
+
+    const bool logits_all = n_outputs_all == n_tokens_all;
+
+    sbatch.from_batch(batch, n_embd,
+            /* simple_split */ !kv_self.recurrent,
+            /* logits_all   */ logits_all);
 
     // reserve output buffer
     // TODO: move to batch manager?
-    if (output_reserve(bman->n_outputs_all) < (size_t) n_outputs_all) {
+    if (output_reserve(n_outputs_all) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
         return -2;
     };
 
     int64_t n_outputs_prev = 0;
 
-    while (!bman->is_done()) {
-        llama_ubatch ubatch = bman->next();
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch = llama_ubatch();
 
-        if (!bman->prepare(ubatch)) {
-            LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
-            bman->restore();
-            return -3;
+        const auto & n_ubatch = cparams.n_ubatch;
+
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        if (kv_self.recurrent) {
+            if (embd_pooled) {
+                // Pooled embeddings cannot be split across ubatches (yet)
+                ubatch = sbatch.split_seq(n_ubatch);
+            } else {
+                // recurrent model architectures are easier to implement
+                // with equal-length sequences
+                ubatch = sbatch.split_equal(n_ubatch);
+            }
+        } else {
+            ubatch = sbatch.split_simple(n_ubatch);
+        }
+
+        // count the outputs in this u_batch
+        {
+            int32_t n_outputs_new = 0;
+
+            if (n_outputs_all == n_tokens_all) {
+                n_outputs_new = ubatch.n_tokens;
+            } else {
+                GGML_ASSERT(ubatch.output);
+                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
+                }
+            }
+
+            // needs to happen before the graph is built
+            n_outputs = n_outputs_new;
+        }
+
+        // non-causal masks do not use the KV cache
+        if (hparams.causal_attn) {
+            kv_self_update();
+
+            // if we have enough unused cells before the current head ->
+            //   better to start searching from the beginning of the cache, hoping to fill it
+            if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+                kv_self.head = 0;
+            }
+
+            const auto slot_info = kv_self.find_slot(ubatch);
+            if (!slot_info) {
+                LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
+                return -3;
+            }
+
+            bg.save(slot_info);
+
+            if (!kv_self.recurrent) {
+                // a heuristic, to avoid attending the full cache if it is not yet utilized
+                // after enough generations, the benefit from this heuristic disappears
+                // if we start defragmenting the cache, the benefit from this will be more important
+                const uint32_t pad = kv_self.get_padding(cparams);
+                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(kv_self.cell_max(), pad)));
+                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+            }
+        }
+
+        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
+
+        // reserve a worst case graph if needed
+        if (need_reserve) {
+            LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__);
+
+            // build worst-case graph
+            uint32_t n_seqs = 1; // TODO: worst-case number of sequences
+            uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+            llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
+
+            ggml_cgraph * gf = build_graph(ubatch, true);
+
+            // initialize scheduler with the worst-case graph
+            ggml_backend_sched_reset(sched.get());
+            if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+                LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
+            }
+
+            need_reserve = false;
         }
 
         ggml_backend_sched_reset(sched.get());
@@ -1844,7 +1772,6 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
 
         const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1);
         if (compute_status != GGML_STATUS_SUCCESS) {
-            bman->restore();
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -1856,7 +1783,15 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
             }
         }
 
-        bman->update(ubatch);
+        // update the kv ring buffer
+        {
+            kv_self.head += ubatch.n_tokens;
+
+            // Ensure kv cache head points to a valid index.
+            if (kv_self.head >= kv_self.size) {
+                kv_self.head = 0;
+            }
+        }
 
         // plot the computation graph in dot format (for debugging purposes)
         //if (n_past%100 == 0) {
@@ -1936,14 +1871,17 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         n_outputs_prev += n_outputs;
     }
 
+    // finalize the batch processing
+    bg.done();
+
     // set output mappings
     {
         bool sorted_output = true;
 
         GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all);
 
-        for (size_t i = 0; i < (size_t) n_outputs_all; ++i) {
-            size_t out_id = sbatch.out_ids[i];
+        for (int64_t i = 0; i < n_outputs_all; ++i) {
+            int64_t out_id = sbatch.out_ids[i];
             output_ids[out_id] = i;
             if (out_id != i) {
                 sorted_output = false;
@@ -1961,7 +1899,19 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     // wait for the computation to finish (automatically done when obtaining the model output)
     //synchronize();
 
-    bman->finalize();
+    // decide if we need to defrag the kv cache
+    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        // - do not defrag small contexts (i.e. < 2048 tokens)
+        // - count the padding towards the number of used tokens
+        const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + get_ctx_padding(cparams))/float(kv_self.n)) : 0.0f;
+
+        // queue defragmentation for next llama_kv_cache_update
+        if (fragmentation > cparams.defrag_thold) {
+            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
+
+            kv_self.defrag();
+        }
+    }
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
     // overlap with device computation.
@@ -1983,14 +1933,14 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : pos_max() + 1);
 
     const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens = batch.n_tokens;
+    const int32_t n_tokens = batch.n_tokens;
 
     const auto & hparams = model.hparams;
 
     GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     if (batch.token) {
-        for (uint32_t i = 0; i < n_tokens; ++i) {
+        for (int32_t i = 0; i < n_tokens; ++i) {
             if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
@@ -1999,7 +1949,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     }
 
     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
-    GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
+    GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens");
 
     if (t_compute_start_us == 0) {
         t_compute_start_us = ggml_time_us();
@@ -2019,7 +1969,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
         return -2;
     };
 
-    for (uint32_t i = 0; i < n_tokens; ++i) {
+    for (int32_t i = 0; i < n_tokens; ++i) {
         output_ids[i] = i;
     }
 
@@ -2087,7 +2037,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
 
         // remember the sequence ids used during the encoding - needed for cross attention later
         seq_ids_enc.resize(n_tokens);
-        for (uint32_t i = 0; i < n_tokens; i++) {
+        for (int32_t i = 0; i < n_tokens; i++) {
             for (int s = 0; s < ubatch.n_seq_id[i]; s++) {
                 llama_seq_id seq_id = ubatch.seq_id[i][s];
                 seq_ids_enc[i].insert(seq_id);
@@ -2116,7 +2066,7 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
 
                 GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits
 
-                for (uint32_t i = 0; i < n_tokens; i++) {
+                for (int32_t i = 0; i < n_tokens; i++) {
                     const llama_seq_id seq_id = ubatch.seq_id[i][0];
                     if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
                         continue;
@@ -2448,7 +2398,7 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());
 
         auto ctx = graph_init();
-        auto ctx0 = ctx.get();
+        auto * ctx0 = ctx.get();
 
         ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
@@ -2477,7 +2427,7 @@ void llama_context_kv_self::kv_self_update() {
         ggml_backend_sched_reset(sched.get());
 
         auto ctx = graph_init();
-        auto ctx0 = ctx.get();
+        auto * ctx0 = ctx.get();
 
         ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
diff --git a/src/llama-context.h b/src/llama-context.h
index e70c99f33..f2ebf4f13 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -92,6 +92,7 @@ struct llama_context : public llama_graph_i {
 
     virtual void synchronize();
 
+    // zero-out inputs and create ggml_context
     virtual ggml_context_ptr graph_init();
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
@@ -103,13 +104,40 @@ struct llama_context : public llama_graph_i {
 
     // Make sure enough space is available for outputs.
     // Returns max number of outputs for which space was reserved.
-    virtual size_t output_reserve(size_t n_outputs);
+    virtual int32_t output_reserve(int32_t n_outputs);
 
     // make the outputs have the same order they had in the user-provided batch
     // TODO: maybe remove this
     virtual void output_reorder();
 
+    // decode a batch of tokens by evaluating the transformer
+    // in case of unsuccessful decoding (error or warning),
+    // the kv_cache state will be returned to its original state
+    // (for non-recurrent models) or cleaned (for recurrent models)
+    //
+    // - lctx:      llama context
+    // - inp_batch: batch to evaluate
+    //
+    // return 0 on success
+    // return positive int on warning
+    // return negative int on error
+    //
+    virtual int decode(llama_batch & inp_batch) = 0;
+
+    // encode a batch of tokens by evaluating the encoder part of the transformer
+    //
+    // - lctx:  llama context
+    // - batch: batch to evaluate
+    //
+    // return 0 on success
+    // return positive int on warning
+    // return negative int on error
+    //
+    virtual int encode(llama_batch & inp_batch) = 0;
+
+    //
     // graph build API (generic)
+    //
 
     virtual void build_cb(
             ggml_tensor * cur,
@@ -141,31 +169,6 @@ struct llama_context : public llama_graph_i {
 
     virtual ggml_tensor * build_rope_factors(int il);
 
-    // decode a batch of tokens by evaluating the transformer
-    // in case of unsuccessful decoding (error or warning),
-    // the kv_cache state will be returned to its original state
-    // (for non-recurrent models) or cleaned (for recurrent models)
-    //
-    // - lctx:      llama context
-    // - inp_batch: batch to evaluate
-    //
-    // return 0 on success
-    // return positive int on warning
-    // return negative int on error
-    //
-    virtual int decode(llama_batch & inp_batch) = 0;
-
-    // encode a batch of tokens by evaluating the encoder part of the transformer
-    //
-    // - lctx:  llama context
-    // - batch: batch to evaluate
-    //
-    // return 0 on success
-    // return positive int on warning
-    // return negative int on error
-    //
-    virtual int encode(llama_batch & inp_batch) = 0;
-
     // state save/load
 
     virtual size_t state_get_size();
@@ -268,7 +271,7 @@ protected:
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
     std::map<llama_seq_id, std::vector<float>> embd_seq;
 
-    size_t output_size = 0; // capacity (of tokens positions) for the output buffers
+    int32_t output_size = 0; // capacity (of tokens positions) for the output buffers
     int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
 
     std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
@@ -291,8 +294,6 @@ protected:
 // transformer with a self-attention KV cache
 class llama_context_kv_self : public llama_context {
 public:
-    struct batch_manager;
-
     llama_context_kv_self(
             const llama_model & model,
             const llama_context_params & params);
@@ -313,8 +314,6 @@ public:
     virtual int decode(llama_batch & inp_batch) override;
     virtual int encode(llama_batch & inp_batch) override;
 
-    virtual std::unique_ptr<batch_manager> prepare_batch(const llama_batch & batch);
-
     // max token position across all sequences in the current context
     llama_pos pos_max() const;
 
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 6ea497297..3bb07ca9d 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -150,7 +150,9 @@ struct llama_kv_slot_restorer {
 
     bool do_restore = false;
 
-    explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
+    llama_kv_cache & cache;
+
+    explicit llama_kv_slot_restorer(llama_kv_cache & cache) : cache(cache) {
         old_state.head = cache.head;
         old_state.n = cache.n;
     }
@@ -167,7 +169,7 @@ struct llama_kv_slot_restorer {
 
     // must be explicitly called to restore the kv_cache state
     // and rollback changes from all llama_kv_cache_find_slot calls
-    void restore(struct llama_kv_cache & cache) {
+    void restore() {
         if (do_restore) {
             cache.head = old_state.head;
             cache.n = old_state.n;
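
The patch replaces batch_manager's explicit restore()/update()/finalize() calls with a function-local RAII guard (batch_guard) inside decode(): unless done() is called after all ubatches are processed, the guard's destructor rolls the KV-cache slot state back. Below is a minimal, self-contained sketch of that commit-or-rollback pattern; kv_cache_state, batch_guard_sketch and decode_sketch are illustrative stand-ins, not the actual llama.cpp types or API.

// Illustrative sketch only: simplified stand-ins for the guard the patch
// introduces inside llama_context_kv_self::decode(). In the real code the
// guard wraps llama_kv_slot_restorer and operates on llama_kv_cache.
#include <cstdint>

struct kv_cache_state {
    uint32_t head = 0;
    uint32_t n    = 0;
};

class batch_guard_sketch {
public:
    explicit batch_guard_sketch(kv_cache_state & cache) : cache(cache), saved(cache) {}

    // if done() was never called, roll the cache back to the saved state
    ~batch_guard_sketch() {
        if (!is_done) {
            cache = saved;
        }
    }

    // mark the batch as fully processed - keep the new cache state
    void done() { is_done = true; }

private:
    bool is_done = false;

    kv_cache_state & cache;
    kv_cache_state   saved;
};

// usage: any early return (invalid token, failed find_slot, aborted compute)
// leaves the cache exactly as it was before the batch started
int decode_sketch(kv_cache_state & cache) {
    batch_guard_sketch bg(cache);

    cache.head += 32; // mutate the cache while processing ubatches

    const bool ok = true; // stand-in for graph_compute() succeeding
    if (!ok) {
        return -3; // guard destructor restores cache.head / cache.n
    }

    bg.done(); // commit
    return 0;
}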