mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 20:25:20 +00:00
kv-cache : refactor the update/defrag mechanism (#13988)
* kv-cache : refactor update mechanism ggml-ci * memory : improve status handling * defrag : reset head + add comments ggml-ci * cont : minor fixes ggml-ci
This commit is contained in:
@ -429,22 +429,54 @@ const llama_kv_cache * llama_context::get_kv_self() const {
|
|||||||
return kv_self;
|
return kv_self;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_context::kv_self_update() {
|
void llama_context::kv_self_defrag_sched() {
|
||||||
|
if (!memory) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory_force_optimize = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llama_context::kv_self_update(bool optimize) {
|
||||||
if (!memory) {
|
if (!memory) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
||||||
|
|
||||||
if (!kv_self->update(*this)) {
|
{
|
||||||
// no updates have been performed
|
// TODO: remove in the future
|
||||||
return false;
|
optimize |= memory_force_optimize;
|
||||||
|
memory_force_optimize = false;
|
||||||
|
|
||||||
|
const auto kv_state = kv_self->init_update(this, optimize);
|
||||||
|
switch (kv_state->get_status()) {
|
||||||
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
|
{
|
||||||
|
// noop
|
||||||
|
} break;
|
||||||
|
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||||
|
{
|
||||||
|
// no updates need to be performed
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||||
|
{
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!kv_state->apply()) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the KV cache did any computation, we have to reserve a new worst-case graph
|
// if the KV cache did any computation, we have to reserve a new worst-case graph
|
||||||
const auto kv_state = kv_self->init_full();
|
const auto kv_state = kv_self->init_full();
|
||||||
if (!kv_state) {
|
if (!kv_state) {
|
||||||
throw std::runtime_error("failed to initialize KV cache");
|
throw std::runtime_error("failed to initialize memory state");
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint32_t n_seqs = cparams.n_seq_max;
|
const uint32_t n_seqs = cparams.n_seq_max;
|
||||||
@ -452,7 +484,7 @@ bool llama_context::kv_self_update() {
|
|||||||
|
|
||||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, kv_state.get());
|
||||||
if (!gf) {
|
if (!gf) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to reserve graph after the KV cache update\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
@ -940,13 +972,13 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
n_outputs_all = 1;
|
n_outputs_all = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool did_optimize = false;
|
||||||
|
|
||||||
// handle any pending defrags/shifts
|
// handle any pending defrags/shifts
|
||||||
kv_self_update();
|
kv_self_update(false);
|
||||||
|
|
||||||
llama_memory_state_ptr kv_state;
|
llama_memory_state_ptr kv_state;
|
||||||
|
|
||||||
bool did_defrag = false;
|
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
kv_state = kv_self->init_batch(batch, cparams.n_ubatch, embd_pooled, /* logits_all */ n_outputs_all == n_tokens_all);
|
||||||
if (!kv_state) {
|
if (!kv_state) {
|
||||||
@ -957,25 +989,32 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
case LLAMA_MEMORY_STATUS_SUCCESS:
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
{
|
{
|
||||||
} break;
|
} break;
|
||||||
|
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||||
|
{
|
||||||
|
LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, kv_state->get_status());
|
||||||
|
|
||||||
|
return -2;
|
||||||
|
}
|
||||||
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||||
{
|
{
|
||||||
if (!did_defrag) {
|
if (!did_optimize) {
|
||||||
did_defrag = true;
|
did_optimize = true;
|
||||||
|
|
||||||
kv_self->defrag_sched(-1.0f);
|
if (kv_self_update(true)) {
|
||||||
if (kv_self_update()) {
|
LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens);
|
||||||
LLAMA_LOG_DEBUG("%s: failed to init batch of size %d, retrying after defrag\n", __func__, batch.n_tokens);
|
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
|
LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens);
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||||
{
|
{
|
||||||
|
LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens);
|
||||||
|
|
||||||
return -2;
|
return -2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1189,11 +1228,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|||||||
// wait for the computation to finish (automatically done when obtaining the model output)
|
// wait for the computation to finish (automatically done when obtaining the model output)
|
||||||
//synchronize();
|
//synchronize();
|
||||||
|
|
||||||
// decide if we need to defrag the kv cache
|
|
||||||
if (cparams.defrag_thold > 0.0f) {
|
|
||||||
kv_self->defrag_sched(cparams.defrag_thold);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
||||||
// overlap with device computation.
|
// overlap with device computation.
|
||||||
ggml_backend_sched_reset(sched.get());
|
ggml_backend_sched_reset(sched.get());
|
||||||
@ -2283,7 +2317,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
|
|||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
void llama_kv_self_update(llama_context * ctx) {
|
void llama_kv_self_update(llama_context * ctx) {
|
||||||
ctx->kv_self_update();
|
ctx->kv_self_update(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
enum llama_pooling_type llama_pooling_type(const llama_context * ctx) {
|
||||||
@ -2538,13 +2572,8 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
|
|||||||
|
|
||||||
// deprecated
|
// deprecated
|
||||||
void llama_kv_self_defrag(llama_context * ctx) {
|
void llama_kv_self_defrag(llama_context * ctx) {
|
||||||
auto * kv = ctx->get_kv_self();
|
|
||||||
if (!kv) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// force defrag
|
// force defrag
|
||||||
kv->defrag_sched(-1.0f);
|
ctx->kv_self_defrag_sched();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
bool llama_kv_self_can_shift(const llama_context * ctx) {
|
||||||
|
@ -52,7 +52,8 @@ struct llama_context {
|
|||||||
|
|
||||||
// return true of the KV cache was updated
|
// return true of the KV cache was updated
|
||||||
// TODO: remove
|
// TODO: remove
|
||||||
bool kv_self_update();
|
bool kv_self_update(bool optimize);
|
||||||
|
void kv_self_defrag_sched();
|
||||||
|
|
||||||
enum llama_pooling_type pooling_type() const;
|
enum llama_pooling_type pooling_type() const;
|
||||||
|
|
||||||
@ -231,6 +232,9 @@ private:
|
|||||||
|
|
||||||
std::unique_ptr<llama_memory_i> memory;
|
std::unique_ptr<llama_memory_i> memory;
|
||||||
|
|
||||||
|
// TODO: temporary, until the llama_kv_self_defrag() API is removed
|
||||||
|
bool memory_force_optimize = false;
|
||||||
|
|
||||||
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
// decode output (2-dimensional array: [n_outputs][n_vocab])
|
||||||
size_t logits_size = 0; // capacity (of floats) for logits
|
size_t logits_size = 0; // capacity (of floats) for logits
|
||||||
float * logits = nullptr;
|
float * logits = nullptr;
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#include "llama-kv-cache-recurrent.h"
|
#include "llama-kv-cache-recurrent.h"
|
||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
|
#include "llama-io.h"
|
||||||
#include "llama-batch.h"
|
#include "llama-batch.h"
|
||||||
#include "llama-model.h"
|
#include "llama-model.h"
|
||||||
|
|
||||||
@ -386,6 +387,13 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
|
|||||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
|
||||||
|
GGML_UNUSED(lctx);
|
||||||
|
GGML_UNUSED(optimize);
|
||||||
|
|
||||||
|
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
||||||
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||||
// simply remember the full state because it is very small for this type of cache
|
// simply remember the full state because it is very small for this type of cache
|
||||||
// TODO: optimize
|
// TODO: optimize
|
||||||
@ -419,17 +427,6 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|||||||
return success;
|
return success;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_recurrent::update(llama_context & lctx) {
|
|
||||||
GGML_UNUSED(lctx);
|
|
||||||
// noop
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_recurrent::defrag_sched(float thold) {
|
|
||||||
GGML_UNUSED(thold);
|
|
||||||
// noop
|
|
||||||
}
|
|
||||||
|
|
||||||
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||||
const uint32_t n_tokens = ubatch.n_tokens;
|
const uint32_t n_tokens = ubatch.n_tokens;
|
||||||
const uint32_t n_seqs = ubatch.n_seqs;
|
const uint32_t n_seqs = ubatch.n_seqs;
|
||||||
|
@ -52,9 +52,7 @@ public:
|
|||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
bool update(llama_context & lctx) override;
|
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
void defrag_sched(float thold) override;
|
|
||||||
|
|
||||||
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
||||||
|
|
||||||
|
@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
|||||||
|
|
||||||
assert(heads_base.size() == heads_swa.size());
|
assert(heads_base.size() == heads_swa.size());
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||||
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
|
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||||
bool res = false;
|
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
|
||||||
|
|
||||||
res = res | kv_base->update(lctx);
|
|
||||||
res = res | kv_swa ->update(lctx);
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
|
|
||||||
kv_base->defrag_sched(thold);
|
|
||||||
kv_swa ->defrag_sched(thold);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||||
@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
|||||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
||||||
|
|
||||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||||
llama_memory_status status,
|
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||||
llama_kv_cache_unified_iswa * kv) : status(status) {
|
state_base = kv->get_base()->init_full();
|
||||||
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
|
state_swa = kv->get_swa ()->init_full();
|
||||||
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
|
|
||||||
|
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||||
|
llama_kv_cache_unified_iswa * kv,
|
||||||
|
llama_context * lctx,
|
||||||
|
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||||
|
state_base = kv->get_base()->init_update(lctx, optimize);
|
||||||
|
state_swa = kv->get_swa ()->init_update(lctx, optimize);
|
||||||
|
|
||||||
|
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||||
llama_memory_status status,
|
|
||||||
llama_kv_cache_unified_iswa * kv,
|
llama_kv_cache_unified_iswa * kv,
|
||||||
llama_sbatch sbatch,
|
llama_sbatch sbatch,
|
||||||
std::vector<uint32_t> heads_base,
|
std::vector<uint32_t> heads_base,
|
||||||
std::vector<uint32_t> heads_swa,
|
std::vector<uint32_t> heads_swa,
|
||||||
std::vector<llama_ubatch> ubatches)
|
std::vector<llama_ubatch> ubatches)
|
||||||
: status(status),
|
: status(LLAMA_MEMORY_STATUS_SUCCESS),
|
||||||
sbatch(std::move(sbatch)),
|
sbatch(std::move(sbatch)),
|
||||||
ubatches(std::move(ubatches)) {
|
ubatches(std::move(ubatches)) {
|
||||||
// note: here we copy the ubatches. not sure if this is ideal
|
// note: here we copy the ubatches. not sure if this is ideal
|
||||||
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
||||||
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
||||||
}
|
|
||||||
|
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||||
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
||||||
|
|
||||||
@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
|||||||
|
|
||||||
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return ubatches[i_next];
|
return ubatches[i_next];
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return state_base.get();
|
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
return state_swa.get();
|
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
|
||||||
}
|
}
|
||||||
|
@ -54,9 +54,7 @@ public:
|
|||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
bool update(llama_context & lctx) override;
|
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
void defrag_sched(float thold) override;
|
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
@ -86,12 +84,16 @@ public:
|
|||||||
|
|
||||||
// used to create a full-cache state
|
// used to create a full-cache state
|
||||||
llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_state(
|
||||||
llama_memory_status status,
|
|
||||||
llama_kv_cache_unified_iswa * kv);
|
llama_kv_cache_unified_iswa * kv);
|
||||||
|
|
||||||
|
// used to create an update state
|
||||||
|
llama_kv_cache_unified_iswa_state(
|
||||||
|
llama_kv_cache_unified_iswa * kv,
|
||||||
|
llama_context * lctx,
|
||||||
|
bool optimize);
|
||||||
|
|
||||||
// used to create a state from a batch
|
// used to create a state from a batch
|
||||||
llama_kv_cache_unified_iswa_state(
|
llama_kv_cache_unified_iswa_state(
|
||||||
llama_memory_status status,
|
|
||||||
llama_kv_cache_unified_iswa * kv,
|
llama_kv_cache_unified_iswa * kv,
|
||||||
llama_sbatch sbatch,
|
llama_sbatch sbatch,
|
||||||
std::vector<uint32_t> heads_base,
|
std::vector<uint32_t> heads_base,
|
||||||
@ -120,7 +122,7 @@ public:
|
|||||||
const llama_kv_cache_unified_state * get_swa() const;
|
const llama_kv_cache_unified_state * get_swa() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const llama_memory_status status;
|
llama_memory_status status;
|
||||||
|
|
||||||
//llama_kv_cache_unified_iswa * kv;
|
//llama_kv_cache_unified_iswa * kv;
|
||||||
|
|
||||||
@ -131,6 +133,6 @@ private:
|
|||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
std::unique_ptr<llama_kv_cache_unified_state> state_base;
|
llama_memory_state_ptr state_base;
|
||||||
std::unique_ptr<llama_kv_cache_unified_state> state_swa;
|
llama_memory_state_ptr state_swa;
|
||||||
};
|
};
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#include "llama-kv-cache-unified.h"
|
#include "llama-kv-cache-unified.h"
|
||||||
|
|
||||||
#include "llama-impl.h"
|
#include "llama-impl.h"
|
||||||
|
#include "llama-io.h"
|
||||||
#include "llama-model.h"
|
#include "llama-model.h"
|
||||||
#include "llama-context.h"
|
#include "llama-context.h"
|
||||||
|
|
||||||
@ -320,16 +321,49 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
|
|||||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS,
|
return std::make_unique<llama_kv_cache_unified_state>(
|
||||||
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
this, std::move(sbatch), std::move(heads), std::move(ubatches));
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
llama_memory_state_ptr llama_kv_cache_unified::init_full() {
|
||||||
return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
return std::make_unique<llama_kv_cache_unified_state>(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
|
||||||
std::vector<uint32_t> res;
|
bool do_shift = get_has_shift();
|
||||||
|
|
||||||
|
defrag_info dinfo;
|
||||||
|
|
||||||
|
// see if we need to defrag
|
||||||
|
{
|
||||||
|
bool do_defrag = optimize;
|
||||||
|
|
||||||
|
const auto thold = lctx->get_cparams().defrag_thold;
|
||||||
|
|
||||||
|
if (!do_defrag && thold > 0.0f) {
|
||||||
|
const auto n_kv = cells.used_max_p1();
|
||||||
|
|
||||||
|
// - do not defrag small contexts (i.e. < 2048 tokens)
|
||||||
|
// - count the padding towards the number of used tokens
|
||||||
|
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
||||||
|
|
||||||
|
if (fragmentation > thold) {
|
||||||
|
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
||||||
|
|
||||||
|
do_defrag = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (do_defrag) {
|
||||||
|
dinfo = defrag_prepare(lctx->graph_max_nodes());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_unique<llama_kv_cache_unified_state>(this, lctx, do_shift, std::move(dinfo));
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||||
|
llama_kv_cache_unified::ubatch_heads res;
|
||||||
|
|
||||||
struct state {
|
struct state {
|
||||||
uint32_t head_old; // old position of the head, before placing the ubatch
|
uint32_t head_old; // old position of the head, before placing the ubatch
|
||||||
@ -374,12 +408,12 @@ std::vector<uint32_t> llama_kv_cache_unified::prepare(const std::vector<llama_ub
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified::update(llama_context & lctx) {
|
bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) {
|
||||||
bool updated = false;
|
bool updated = false;
|
||||||
|
|
||||||
auto * sched = lctx.get_sched();
|
auto * sched = lctx->get_sched();
|
||||||
|
|
||||||
if (cells.get_has_shift()) {
|
if (do_shift) {
|
||||||
if (!get_can_shift()) {
|
if (!get_can_shift()) {
|
||||||
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
|
||||||
}
|
}
|
||||||
@ -390,9 +424,9 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||||||
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
||||||
ggml_backend_sched_reset(sched);
|
ggml_backend_sched_reset(sched);
|
||||||
|
|
||||||
auto * gf = lctx.graph_init();
|
auto * gf = lctx->graph_init();
|
||||||
|
|
||||||
auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf);
|
||||||
if (!res) {
|
if (!res) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__);
|
||||||
return updated;
|
return updated;
|
||||||
@ -405,7 +439,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||||||
|
|
||||||
res->set_inputs(nullptr);
|
res->set_inputs(nullptr);
|
||||||
|
|
||||||
if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
||||||
LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
|
LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__);
|
||||||
return updated;
|
return updated;
|
||||||
}
|
}
|
||||||
@ -416,56 +450,55 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
|
|||||||
cells.reset_shift();
|
cells.reset_shift();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (do_defrag) {
|
if (!dinfo.empty()) {
|
||||||
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
||||||
|
|
||||||
if (defrag_prepare(lctx.graph_max_nodes())) {
|
// apply moves:
|
||||||
ggml_backend_sched_reset(sched);
|
{
|
||||||
|
const auto n_kv = dinfo.ids.size();
|
||||||
|
|
||||||
auto * gf = lctx.graph_init();
|
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||||
|
assert(dinfo.ids[i] <= n_kv);
|
||||||
|
|
||||||
auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
|
if (dinfo.ids[i] == n_kv) {
|
||||||
if (!res) {
|
continue;
|
||||||
LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
|
}
|
||||||
return updated;
|
|
||||||
|
cells.mv(i, dinfo.ids[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
// reset the head so we can find the first free slot during the next ubatch
|
||||||
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
|
head = 0;
|
||||||
return updated;
|
|
||||||
}
|
|
||||||
|
|
||||||
res->set_inputs(nullptr);
|
|
||||||
|
|
||||||
if (lctx.graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
|
||||||
return updated;
|
|
||||||
}
|
|
||||||
|
|
||||||
updated = true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
do_defrag = false;
|
ggml_backend_sched_reset(sched);
|
||||||
|
|
||||||
|
auto * gf = lctx->graph_init();
|
||||||
|
|
||||||
|
auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo);
|
||||||
|
if (!res) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__);
|
||||||
|
return updated;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ggml_backend_sched_alloc_graph(sched, gf)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__);
|
||||||
|
return updated;
|
||||||
|
}
|
||||||
|
|
||||||
|
res->set_inputs(nullptr);
|
||||||
|
|
||||||
|
if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__);
|
||||||
|
return updated;
|
||||||
|
}
|
||||||
|
|
||||||
|
updated = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return updated;
|
return updated;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_unified::defrag_sched(float thold) {
|
|
||||||
const auto n_kv = cells.used_max_p1();
|
|
||||||
|
|
||||||
// - do not defrag small contexts (i.e. < 2048 tokens)
|
|
||||||
// - count the padding towards the number of used tokens
|
|
||||||
const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f;
|
|
||||||
|
|
||||||
// queue defragmentation for next llama_kv_cache_update
|
|
||||||
if (fragmentation > thold) {
|
|
||||||
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
|
||||||
|
|
||||||
do_defrag = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
|
||||||
const uint32_t n_tokens = ubatch.n_tokens;
|
const uint32_t n_tokens = ubatch.n_tokens;
|
||||||
|
|
||||||
@ -612,6 +645,10 @@ uint32_t llama_kv_cache_unified::get_size() const {
|
|||||||
return cells.size();
|
return cells.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool llama_kv_cache_unified::get_has_shift() const {
|
||||||
|
return cells.get_has_shift();
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t llama_kv_cache_unified::get_n_kv() const {
|
uint32_t llama_kv_cache_unified::get_n_kv() const {
|
||||||
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
|
||||||
}
|
}
|
||||||
@ -941,12 +978,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
|
|||||||
}
|
}
|
||||||
|
|
||||||
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
||||||
const llama_cparams & cparams,
|
const llama_cparams & cparams,
|
||||||
ggml_context * ctx,
|
ggml_context * ctx,
|
||||||
ggml_cgraph * gf) const {
|
ggml_cgraph * gf,
|
||||||
|
const defrag_info & dinfo) const {
|
||||||
auto res = std::make_unique<llm_graph_result>();
|
auto res = std::make_unique<llm_graph_result>();
|
||||||
|
|
||||||
const auto & ids = defrag_info.ids;
|
const auto & ids = dinfo.ids;
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// CPU defrag
|
// CPU defrag
|
||||||
@ -1087,7 +1125,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
|
||||||
const uint32_t n_layer = layers.size();
|
const uint32_t n_layer = layers.size();
|
||||||
|
|
||||||
const uint32_t n_kv = cells.used_max_p1();
|
const uint32_t n_kv = cells.used_max_p1();
|
||||||
@ -1108,14 +1146,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer);
|
||||||
|
|
||||||
// determine which KV cells to move where
|
// determine which KV cells to move where
|
||||||
//
|
defrag_info res;
|
||||||
// cell i moves to ids[i]
|
auto & ids = res.ids;
|
||||||
//
|
|
||||||
// if ids[i] == i || ids[i] == n_kv, then cell i is not moved
|
|
||||||
//
|
|
||||||
auto & ids = defrag_info.ids;
|
|
||||||
|
|
||||||
ids.clear();
|
|
||||||
ids.resize(n_kv, n_kv);
|
ids.resize(n_kv, n_kv);
|
||||||
|
|
||||||
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||||
@ -1179,11 +1212,6 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
// this cell goes to (i0 + nf)
|
// this cell goes to (i0 + nf)
|
||||||
ids[i1] = i0 + nf;
|
ids[i1] = i0 + nf;
|
||||||
|
|
||||||
// move the cell meta data
|
|
||||||
cells.mv(i1, i0 + nf);
|
|
||||||
|
|
||||||
head = n_used;
|
|
||||||
|
|
||||||
if (!cont) {
|
if (!cont) {
|
||||||
n_moves++;
|
n_moves++;
|
||||||
cont = true;
|
cont = true;
|
||||||
@ -1206,14 +1234,14 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (n_moves == 0) {
|
if (n_moves == 0) {
|
||||||
return false;
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves);
|
||||||
|
|
||||||
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer);
|
||||||
|
|
||||||
return true;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
|
||||||
@ -1636,24 +1664,27 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
|
|||||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
|
llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {}
|
||||||
|
|
||||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||||
llama_memory_status status,
|
llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
|
||||||
llama_kv_cache_unified * kv) : status(status), kv(kv) {
|
n_kv = kv->get_size();
|
||||||
n_kv = kv->get_size();
|
head = 0;
|
||||||
head = 0;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||||
llama_memory_status status,
|
llama_kv_cache_unified * kv,
|
||||||
llama_kv_cache_unified * kv,
|
llama_context * lctx,
|
||||||
llama_sbatch sbatch,
|
bool do_shift,
|
||||||
std::vector<uint32_t> heads,
|
defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) {
|
||||||
std::vector<llama_ubatch> ubatches)
|
if (!do_shift && dinfo.empty()) {
|
||||||
: status(status),
|
status = LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||||
kv(kv),
|
|
||||||
sbatch(std::move(sbatch)),
|
|
||||||
heads(std::move(heads)),
|
|
||||||
ubatches(std::move(ubatches)) {
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_unified_state::llama_kv_cache_unified_state(
|
||||||
|
llama_kv_cache_unified * kv,
|
||||||
|
llama_sbatch sbatch,
|
||||||
|
llama_kv_cache_unified::ubatch_heads heads,
|
||||||
|
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
|
||||||
|
}
|
||||||
|
|
||||||
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
|
llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
|
||||||
|
|
||||||
@ -1670,6 +1701,13 @@ bool llama_kv_cache_unified_state::next() {
|
|||||||
bool llama_kv_cache_unified_state::apply() {
|
bool llama_kv_cache_unified_state::apply() {
|
||||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
|
// no ubatches -> this is a KV cache update
|
||||||
|
if (ubatches.empty()) {
|
||||||
|
kv->update(lctx, do_shift, dinfo);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
|
kv->apply_ubatch(heads[i_next], ubatches[i_next]);
|
||||||
|
|
||||||
n_kv = kv->get_n_kv();
|
n_kv = kv->get_n_kv();
|
||||||
|
@ -24,6 +24,19 @@ public:
|
|||||||
// this callback is used to filter out layers that should not be included in the cache
|
// this callback is used to filter out layers that should not be included in the cache
|
||||||
using layer_filter_cb = std::function<bool(int32_t il)>;
|
using layer_filter_cb = std::function<bool(int32_t il)>;
|
||||||
|
|
||||||
|
using ubatch_heads = std::vector<uint32_t>;
|
||||||
|
|
||||||
|
struct defrag_info {
|
||||||
|
bool empty() const {
|
||||||
|
return ids.empty();
|
||||||
|
}
|
||||||
|
|
||||||
|
// contains information about which cell moves where:
|
||||||
|
// - cell i moves to ids[i]
|
||||||
|
// - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved
|
||||||
|
std::vector<uint32_t> ids;
|
||||||
|
};
|
||||||
|
|
||||||
llama_kv_cache_unified(
|
llama_kv_cache_unified(
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
layer_filter_cb && filter,
|
layer_filter_cb && filter,
|
||||||
@ -66,9 +79,7 @@ public:
|
|||||||
|
|
||||||
llama_memory_state_ptr init_full() override;
|
llama_memory_state_ptr init_full() override;
|
||||||
|
|
||||||
bool update(llama_context & lctx) override;
|
llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
void defrag_sched(float thold) override;
|
|
||||||
|
|
||||||
bool get_can_shift() const override;
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
@ -83,6 +94,8 @@ public:
|
|||||||
|
|
||||||
uint32_t get_size() const;
|
uint32_t get_size() const;
|
||||||
|
|
||||||
|
bool get_has_shift() const;
|
||||||
|
|
||||||
//
|
//
|
||||||
// graph_build API
|
// graph_build API
|
||||||
//
|
//
|
||||||
@ -103,7 +116,9 @@ public:
|
|||||||
|
|
||||||
// find places for the provided ubatches in the cache, returns the head locations
|
// find places for the provided ubatches in the cache, returns the head locations
|
||||||
// return empty vector on failure
|
// return empty vector on failure
|
||||||
std::vector<uint32_t> prepare(const std::vector<llama_ubatch> & ubatches);
|
ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
|
||||||
|
|
||||||
|
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||||
|
|
||||||
// return the cell position where we can insert the ubatch
|
// return the cell position where we can insert the ubatch
|
||||||
// return -1 on failure to find a contiguous slot of kv cells
|
// return -1 on failure to find a contiguous slot of kv cells
|
||||||
@ -133,8 +148,7 @@ private:
|
|||||||
ggml_tensor * v;
|
ggml_tensor * v;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool do_defrag = false;
|
bool v_trans = true; // the value tensor is transposed
|
||||||
bool v_trans = true; // the value tensor is transposed
|
|
||||||
|
|
||||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||||
@ -160,13 +174,8 @@ private:
|
|||||||
// model layer id -> KV cache layer id
|
// model layer id -> KV cache layer id
|
||||||
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
std::unordered_map<int32_t, int32_t> map_layer_ids;
|
||||||
|
|
||||||
// defrag
|
// return non-empty vector if cells have been moved
|
||||||
struct {
|
defrag_info defrag_prepare(int32_t n_max_nodes) const;
|
||||||
std::vector<uint32_t> ids;
|
|
||||||
} defrag_info;
|
|
||||||
|
|
||||||
// return true if cells have been moved
|
|
||||||
bool defrag_prepare(int32_t n_max_nodes);
|
|
||||||
|
|
||||||
size_t total_size() const;
|
size_t total_size() const;
|
||||||
|
|
||||||
@ -192,7 +201,8 @@ private:
|
|||||||
llm_graph_result_ptr build_graph_defrag(
|
llm_graph_result_ptr build_graph_defrag(
|
||||||
const llama_cparams & cparams,
|
const llama_cparams & cparams,
|
||||||
ggml_context * ctx,
|
ggml_context * ctx,
|
||||||
ggml_cgraph * gf) const;
|
ggml_cgraph * gf,
|
||||||
|
const defrag_info & dinfo) const;
|
||||||
|
|
||||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||||
@ -203,20 +213,29 @@ private:
|
|||||||
|
|
||||||
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
class llama_kv_cache_unified_state : public llama_memory_state_i {
|
||||||
public:
|
public:
|
||||||
|
// some shorthands
|
||||||
|
using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
|
||||||
|
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||||
|
|
||||||
// used for errors
|
// used for errors
|
||||||
llama_kv_cache_unified_state(llama_memory_status status);
|
llama_kv_cache_unified_state(llama_memory_status status);
|
||||||
|
|
||||||
// used to create a full-cache state
|
// used to create a full-cache state
|
||||||
llama_kv_cache_unified_state(
|
llama_kv_cache_unified_state(
|
||||||
llama_memory_status status,
|
|
||||||
llama_kv_cache_unified * kv);
|
llama_kv_cache_unified * kv);
|
||||||
|
|
||||||
// used to create a state from a batch
|
// used to create an update state
|
||||||
|
llama_kv_cache_unified_state(
|
||||||
|
llama_kv_cache_unified * kv,
|
||||||
|
llama_context * lctx,
|
||||||
|
bool do_shift,
|
||||||
|
defrag_info dinfo);
|
||||||
|
|
||||||
|
// used to create a decode state from a batch
|
||||||
llama_kv_cache_unified_state(
|
llama_kv_cache_unified_state(
|
||||||
llama_memory_status status,
|
|
||||||
llama_kv_cache_unified * kv,
|
llama_kv_cache_unified * kv,
|
||||||
llama_sbatch sbatch,
|
llama_sbatch sbatch,
|
||||||
std::vector<uint32_t> heads,
|
ubatch_heads heads,
|
||||||
std::vector<llama_ubatch> ubatches);
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
virtual ~llama_kv_cache_unified_state();
|
virtual ~llama_kv_cache_unified_state();
|
||||||
@ -253,16 +272,30 @@ public:
|
|||||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const llama_memory_status status;
|
llama_memory_status status;
|
||||||
|
|
||||||
llama_kv_cache_unified * kv;
|
llama_kv_cache_unified * kv;
|
||||||
|
llama_context * lctx;
|
||||||
|
|
||||||
|
//
|
||||||
|
// update state
|
||||||
|
//
|
||||||
|
|
||||||
|
bool do_shift = false;
|
||||||
|
|
||||||
|
defrag_info dinfo;
|
||||||
|
|
||||||
|
//
|
||||||
|
// batch processing state
|
||||||
|
//
|
||||||
|
|
||||||
llama_sbatch sbatch;
|
llama_sbatch sbatch;
|
||||||
|
|
||||||
// the index of the next ubatch to process
|
// the index of the next ubatch to process
|
||||||
size_t i_next = 0;
|
size_t i_next = 0;
|
||||||
|
|
||||||
std::vector<uint32_t> heads;
|
ubatch_heads heads;
|
||||||
|
|
||||||
std::vector<llama_ubatch> ubatches;
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -1,12 +1,16 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "llama-io.h"
|
|
||||||
#include "llama-memory.h"
|
#include "llama-memory.h"
|
||||||
|
|
||||||
|
class llama_io_write_i;
|
||||||
|
class llama_io_read_i;
|
||||||
|
|
||||||
struct llama_kv_cache : public llama_memory_i {
|
struct llama_kv_cache : public llama_memory_i {
|
||||||
virtual ~llama_kv_cache() = default;
|
virtual ~llama_kv_cache() = default;
|
||||||
|
|
||||||
|
// TODO: move the init_ interfaces to llama_memory_i
|
||||||
|
|
||||||
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
// split the input batch into a set of ubatches and verify that they can fit into the cache
|
||||||
// return a state object containing the ubatches and KV cache state required to process them
|
// return a state object containing the ubatches and KV cache state required to process them
|
||||||
// check the llama_memory_state_i::get_status() for the result
|
// check the llama_memory_state_i::get_status() for the result
|
||||||
@ -19,16 +23,9 @@ struct llama_kv_cache : public llama_memory_i {
|
|||||||
// simulate full cache, used for allocating worst-case compute buffers
|
// simulate full cache, used for allocating worst-case compute buffers
|
||||||
virtual llama_memory_state_ptr init_full() = 0;
|
virtual llama_memory_state_ptr init_full() = 0;
|
||||||
|
|
||||||
// process any pending defrag/shift/etc. operations
|
// prepare for any pending memory updates, such as shifts, defrags, etc.
|
||||||
// optionally call once before processing a new batch
|
// status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
|
||||||
// return true if any operations were performed
|
virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0;
|
||||||
virtual bool update(llama_context & lctx) = 0;
|
|
||||||
|
|
||||||
// schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing
|
|
||||||
// TODO: change to
|
|
||||||
// llama_memory_state_ptr init_defrag(float thold) = 0;
|
|
||||||
//
|
|
||||||
virtual void defrag_sched(float thold) = 0;
|
|
||||||
|
|
||||||
// getters
|
// getters
|
||||||
virtual bool get_can_shift() const = 0;
|
virtual bool get_can_shift() const = 0;
|
||||||
|
@ -1 +1,42 @@
|
|||||||
#include "llama-memory.h"
|
#include "llama-memory.h"
|
||||||
|
|
||||||
|
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
|
||||||
|
bool has_update = false;
|
||||||
|
|
||||||
|
switch (s0) {
|
||||||
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
|
{
|
||||||
|
has_update = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||||
|
{
|
||||||
|
return s0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (s1) {
|
||||||
|
case LLAMA_MEMORY_STATUS_SUCCESS:
|
||||||
|
{
|
||||||
|
has_update = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_NO_UPDATE:
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
|
||||||
|
case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
|
||||||
|
{
|
||||||
|
return s1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if either status has an update, then the combined status has an update
|
||||||
|
return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
|
||||||
|
}
|
||||||
|
@ -36,12 +36,19 @@ public:
|
|||||||
virtual bool get_can_edit() const = 0;
|
virtual bool get_can_edit() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
|
||||||
|
|
||||||
enum llama_memory_status {
|
enum llama_memory_status {
|
||||||
LLAMA_MEMORY_STATUS_SUCCESS = 0,
|
LLAMA_MEMORY_STATUS_SUCCESS = 0,
|
||||||
|
LLAMA_MEMORY_STATUS_NO_UPDATE,
|
||||||
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
|
LLAMA_MEMORY_STATUS_FAILED_PREPARE,
|
||||||
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// helper function for combining the status of two memory states
|
||||||
|
// useful for implementing hybrid memory types (e.g. iSWA)
|
||||||
|
llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
|
||||||
|
|
||||||
// the interface for managing the memory state during batch processing
|
// the interface for managing the memory state during batch processing
|
||||||
// this interface is implemented per memory type. see:
|
// this interface is implemented per memory type. see:
|
||||||
// - llama_kv_cache_unified_state
|
// - llama_kv_cache_unified_state
|
||||||
@ -69,7 +76,7 @@ public:
|
|||||||
// get the current ubatch
|
// get the current ubatch
|
||||||
virtual const llama_ubatch & get_ubatch() const = 0;
|
virtual const llama_ubatch & get_ubatch() const = 0;
|
||||||
|
|
||||||
// get the status of the memory state
|
// get the status of the memory state - used for error handling and checking if any updates would be applied
|
||||||
virtual llama_memory_status get_status() const = 0;
|
virtual llama_memory_status get_status() const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user