mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 12:35:16 +00:00
kv-cache : refactor the update/defrag mechanism (#13988)
* kv-cache : refactor update mechanism ggml-ci * memory : improve status handling * defrag : reset head + add comments ggml-ci * cont : minor fixes ggml-ci
This commit is contained in:
@ -123,26 +123,16 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
|
||||
|
||||
assert(heads_base.size() == heads_swa.size());
|
||||
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS,
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(
|
||||
this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(this);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::update(llama_context & lctx) {
|
||||
bool res = false;
|
||||
|
||||
res = res | kv_base->update(lctx);
|
||||
res = res | kv_swa ->update(lctx);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void llama_kv_cache_unified_iswa::defrag_sched(float thold) {
|
||||
kv_base->defrag_sched(thold);
|
||||
kv_swa ->defrag_sched(thold);
|
||||
llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
|
||||
return std::make_unique<llama_kv_cache_unified_iswa_state>(this, lctx, optimize);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_unified_iswa::get_can_shift() const {
|
||||
@ -174,26 +164,38 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified_iswa * kv) : status(status) {
|
||||
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base()));
|
||||
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa ()));
|
||||
llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
state_base = kv->get_base()->init_full();
|
||||
state_swa = kv->get_swa ()->init_full();
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_context * lctx,
|
||||
bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
|
||||
state_base = kv->get_base()->init_update(lctx, optimize);
|
||||
state_swa = kv->get_swa ()->init_update(lctx, optimize);
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_unified_iswa * kv,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<uint32_t> heads_base,
|
||||
std::vector<uint32_t> heads_swa,
|
||||
std::vector<llama_ubatch> ubatches)
|
||||
: status(status),
|
||||
sbatch(std::move(sbatch)),
|
||||
ubatches(std::move(ubatches)) {
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
state_base.reset(new llama_kv_cache_unified_state(status, kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
||||
state_swa .reset(new llama_kv_cache_unified_state(status, kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
||||
}
|
||||
: status(LLAMA_MEMORY_STATUS_SUCCESS),
|
||||
sbatch(std::move(sbatch)),
|
||||
ubatches(std::move(ubatches)) {
|
||||
// note: here we copy the ubatches. not sure if this is ideal
|
||||
state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
|
||||
state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
|
||||
|
||||
status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
|
||||
}
|
||||
|
||||
llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
|
||||
|
||||
@ -233,17 +235,18 @@ llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
|
||||
|
||||
const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return state_base.get();
|
||||
return static_cast<const llama_kv_cache_unified_state *>(state_base.get());
|
||||
}
|
||||
|
||||
const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return state_swa.get();
|
||||
return static_cast<const llama_kv_cache_unified_state *>(state_swa.get());
|
||||
}
|
||||
|
Reference in New Issue
Block a user