From 132143938ffa37697fe7f6b08e0efdba50c2cb41 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 24 Jun 2025 15:02:58 +0300
Subject: [PATCH] tools : tmp adjustments (TMP)

ggml-ci
---
 examples/parallel/parallel.cpp        | 16 +++++++++++-----
 tools/batched-bench/batched-bench.cpp |  8 ++++----
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index d53e089a4..83f55747b 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+    llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
@@ -289,8 +289,11 @@
             // all sequences have ended - clear the entire KV cache
             for (int i = 1; i <= n_clients; ++i) {
                 llama_memory_seq_rm(mem, i, -1, -1);
-                // but keep the system prompt
-                llama_memory_seq_cp(mem, 0, i, -1, -1);
+
+                if (is_sp_shared) {
+                    // but keep the system prompt
+                    llama_memory_seq_cp(mem, 0, i, -1, -1);
+                }
             }
 
             LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -449,8 +452,11 @@
             }
 
             // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-            llama_memory_seq_rm(mem, client.id + 1, -1, -1);
-            llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
+            llama_memory_seq_rm(mem, client.id + 1, -1, -1);
+
+            if (is_sp_shared) {
+                llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
+            }
 
             const auto t_main_end = ggml_time_us();
diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp
index a0a2e5ac5..0bac1fc96 100644
--- a/tools/batched-bench/batched-bench.cpp
+++ b/tools/batched-bench/batched-bench.cpp
@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
 
     const int32_t n_kv_max = llama_n_ctx(ctx);
 
-    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
+    llama_batch batch = llama_batch_init(n_kv_max*8, 0, 1); // TODO: tmp!!!
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@@ -119,9 +119,9 @@
 
                 const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
 
-                if (n_ctx_req > n_kv_max) {
-                    continue;
-                }
+                //if (n_ctx_req > n_kv_max) {
+                //    continue;
+                //}
 
                 common_batch_clear(batch);
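
Note (sketch, not part of the patch): the comment in parallel.cpp above explains that the staging batch is sized for the worst case, while the main loop still decodes at most params.n_batch tokens per llama_decode() call. Below is a minimal sketch of that chunking, assuming the llama_batch layout from llama.h; decode_in_chunks is a hypothetical helper name (the examples use a similar decode_helper lambda):

    #include <algorithm>
    #include "llama.h"

    // Decode a large staging batch in views of at most n_batch tokens.
    // With the patch above, the batch itself may hold up to n_ctx*n_clients
    // tokens (worst case: every client submits a full-context prompt in the
    // same iteration), so decoding must proceed in fixed-size chunks.
    static bool decode_in_chunks(llama_context * ctx, llama_batch & batch, int32_t n_batch) {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

            // view into the staging batch: pointers offset by i, no copies
            llama_batch batch_view = {
                n_tokens,
                batch.token    + i,
                nullptr,            // embeddings are not used here
                batch.pos      + i,
                batch.n_seq_id + i,
                batch.seq_id   + i,
                batch.logits   + i,
            };

            if (llama_decode(ctx, batch_view) != 0) {
                return false; // caller may retry with a smaller n_batch
            }
        }
        return true;
    }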