kv-cache : refactor + add llama_memory_state_i (#13746)

* kv-cache : simplify the "struct llama_kv_cache" interface

ggml-ci

* kv-cache : revert the (n_swa + n_ubatch) change (for next PR)

ggml-ci

* kv-cache : some comments

ggml-ci

* context : fix graph reserve for multiple sequences

ggml-ci

* kv-cache : fix typo [no ci]

* kv-cache : fix find_slot() logic for free slots

ggml-ci

* llama : add TODO for deprecating the defrag API in the future

* kv-cache : improve find_slot() using min/max seq pos info

ggml-ci

* llama : handle aborts and compute errors

ggml-ci

* memory : extract state into llama_memory_state

ggml-ci

* kv-cache : add comments

ggml-ci

* server : update batching logic to reset n_batch on successful decode

* server : upon full re-processing, remove the sequence from the cache

* kv-cache : add TODO for doing split_equal when split_simple fails

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-05-31 10:24:04 +03:00
committed by GitHub
parent eb3949938e
commit 12d0188c0d
14 changed files with 1304 additions and 655 deletions

View File

@ -362,7 +362,9 @@ int main(int argc, char ** argv) {
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
int32_t i_next = 0;
for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
// experiment: process in powers of 2
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
// n_batch /= 2;
@ -370,7 +372,7 @@ int main(int argc, char ** argv) {
// continue;
//}
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
llama_batch batch_view = {
n_tokens,
@ -396,13 +398,18 @@ int main(int argc, char ** argv) {
// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
continue;
}
LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
// move the head of the batch forward with the number of tokens we just processed
i_next = i + n_tokens;
// on successful decode, restore the original batch size
n_batch = params.n_batch;
for (auto & client : clients) {
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
continue;