mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-29 20:45:04 +00:00
llama : refactor kv cache guard (#12695)
* llama : refactor kv cache guard ggml-ci * cont : fix comment [no ci] * llama : fix kv_cache restore logic ggml-ci * context : simplify kv cache updates ggml-ci * cont : better name [no ci] * llama : fix llama_decode return code when could not find KV slot ggml-ci * context : change log err -> warn [no ci] * kv-cache : add comment + warning
This commit is contained in:
@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
const int64_t n_tokens_all = batch.n_tokens;
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
// TODO: remove this stuff
|
||||
class batch_guard {
|
||||
public:
|
||||
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
|
||||
}
|
||||
|
||||
~batch_guard() {
|
||||
if (!is_done) {
|
||||
kv_slot_restorer.restore();
|
||||
}
|
||||
}
|
||||
|
||||
void done() {
|
||||
is_done = true;
|
||||
}
|
||||
|
||||
void save(const llama_kv_cache_slot_info & slot_info) {
|
||||
kv_slot_restorer.save(slot_info);
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_done = false;
|
||||
|
||||
llama_kv_slot_restorer kv_slot_restorer;
|
||||
};
|
||||
|
||||
batch_guard bg(*kv_self);
|
||||
llama_kv_cache_guard kv_guard(kv_self.get());
|
||||
|
||||
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
||||
|
||||
@ -1280,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
return -2;
|
||||
};
|
||||
|
||||
// handle any pending defrags/shifts
|
||||
kv_self_update();
|
||||
|
||||
int64_t n_outputs_prev = 0;
|
||||
|
||||
while (sbatch.n_tokens > 0) {
|
||||
@ -1319,22 +1296,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
|
||||
// find KV slot
|
||||
{
|
||||
kv_self_update();
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
|
||||
kv_self->head = 0;
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto slot_info = kv_self->find_slot(ubatch);
|
||||
if (!slot_info) {
|
||||
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
|
||||
return -3;
|
||||
}
|
||||
|
||||
bg.save(slot_info);
|
||||
|
||||
if (!kv_self->recurrent) {
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
@ -1371,16 +1338,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
}
|
||||
}
|
||||
|
||||
// update the kv ring buffer
|
||||
{
|
||||
kv_self->head += ubatch.n_tokens;
|
||||
|
||||
// Ensure kv cache head points to a valid index.
|
||||
if (kv_self->head >= kv_self->size) {
|
||||
kv_self->head = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// plot the computation graph in dot format (for debugging purposes)
|
||||
//if (n_past%100 == 0) {
|
||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
@ -1467,7 +1424,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
}
|
||||
|
||||
// finalize the batch processing
|
||||
bg.done();
|
||||
kv_guard.commit();
|
||||
|
||||
// set output mappings
|
||||
{
|
||||
|
Reference in New Issue
Block a user