From bbd0f917797e9d524680f1b30d34a46eb06d7651 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 29 Jul 2025 10:40:50 +0200 Subject: [PATCH] server-bench: make seed choice configurable (#14929) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * server-bench: make seed choice configurable * Update scripts/server-bench.py Co-authored-by: Sigbjørn Skjæret * Update scripts/server-bench.py Co-authored-by: Sigbjørn Skjæret * fix error formatting * Update scripts/server-bench.py Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- scripts/server-bench.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/scripts/server-bench.py b/scripts/server-bench.py index 3afad66ce..9326be8d5 100755 --- a/scripts/server-bench.py +++ b/scripts/server-bench.py @@ -32,11 +32,12 @@ def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]: return ret -def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]: +def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]: assert n_prompts >= 0 ret: list[int] = [] for i in range(n_prompts): - random.seed(13 * i + 0) + if seed_offset >= 0: + random.seed(3 * (seed_offset + 1000 * i) + 0) ret.append(random.randint(prompt_length_min, prompt_length_max)) return ret @@ -46,12 +47,20 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]: def get_server(path_server: str, path_log: Optional[str]) -> dict: - logger.info("Starting the llama.cpp server...") - hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1") - port: str = os.environ.get("LLAMA_ARG_PORT", "8080") + if os.environ.get("LLAMA_ARG_HOST") is None: + logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1") + os.environ["LLAMA_ARG_HOST"] = "127.0.0.1" + if os.environ.get("LLAMA_ARG_PORT") is None: + logger.info("LLAMA_ARG_PORT not explicitly set, using 8080") + os.environ["LLAMA_ARG_PORT"] = "8080" + hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST") + port: Optional[str] = os.environ.get("LLAMA_ARG_PORT") + assert hostname is not None + assert port is not None address: str = f"http://{hostname}:{port}" + logger.info(f"Starting the llama.cpp server under {address}...") - fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL + fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT) n_failures: int = 0 @@ -60,7 +69,7 @@ def get_server(path_server: str, path_log: Optional[str]) -> dict: sleep(1.0) exit_code = process.poll() if exit_code is not None: - raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}") + raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}") response = requests.get(f"{address}/health") if response.status_code == 200: break @@ -128,7 +137,7 @@ def send_prompt(data: dict) -> tuple[float, list[float]]: return (t_submit, token_arrival_times) -def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int): +def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int): if os.environ.get("LLAMA_ARG_N_PARALLEL") is None: logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32") os.environ["LLAMA_ARG_N_PARALLEL"] = "32" @@ -139,7 +148,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'") os.environ["LLAMA_ARG_FLASH_ATTN"] = "true" - parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1)) + parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts) synthetic_prompts: bool = prompts is None prompt_n = [] @@ -151,7 +160,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p prompt_length_min: int = int(prompt_source_split[1]) prompt_length_max: int = int(prompt_source_split[2]) logger.info("Generating random prompts...") - prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max) + prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset) prompts = get_prompts_rng(prompt_n) else: n_predict_min = n_predict @@ -176,10 +185,11 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p data: list[dict] = [] for i, p in enumerate(prompts): - random.seed(13 * i + 1) + if seed_offset >= 0: + random.seed(3 * (seed_offset + 1000 * i) + 1) data.append({ "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts, - "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2}) + "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1}) if not synthetic_prompts: logger.info("Getting the prompt lengths...") @@ -251,7 +261,7 @@ if __name__ == "__main__": "Results are printed to console and visualized as plots (saved to current working directory). " "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).") parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary") - parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the model to use for the benchmark") + parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark") parser.add_argument( "--prompt_source", type=str, default="rng-1024-2048", help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or " @@ -261,5 +271,7 @@ if __name__ == "__main__": parser.add_argument( "--n_predict_min", type=int, default=1024, help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)") + parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. " + "Corelations between seeds can occur when set >= 1000. Negative values mean no seed.") args = parser.parse_args() benchmark(**vars(args))