scripts: synthetic prompt mode for server-bench.py (#14695)

Author: Johannes Gäßler
Date: 2025-07-16 09:33:28 +02:00
Committed by: GitHub
Parent: 4b91d6f71f
Commit: 5cae766541
2 changed files with 122 additions and 67 deletions

scripts/server-bench.py: 187 changed lines (normal file → executable file)

@@ -2,9 +2,11 @@
 import argparse
 import json
+import os
+import random
 import subprocess
 from time import sleep, time
-from typing import Optional
+from typing import Optional, Union

 import datasets
 import logging
@@ -18,31 +20,39 @@ logging.basicConfig(level=logging.INFO, format='%(message)s')
 logger = logging.getLogger("server-bench")


-def get_prompts(n_prompts: int) -> list[str]:
-    logger.info("Loading MMLU dataset...")
-    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
+    ret = []
+    if dataset_name.lower() == "mmlu":
+        logger.info("Loading MMLU dataset...")
+        ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
+    else:
+        return None
     if n_prompts >= 0:
         ret = ret[:n_prompts]
     return ret


-def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
-    logger.info("Starting the llama.cpp server...")
-    address = f"http://localhost:{port}"
-
-    popen_args: list[str] = [
-        path_server,
-        "--flash-attn",
-        "--n-gpu-layers", str(n_gpu_layers),
-        "--parallel", str(parallel),
-        "--ctx-size", str(parallel * ctx_size),
-        "--model", path_model,
-        "--port", str(port),
-        "--swa-full",  # FIXME performance bad otherwise
-        # "--attn-streams",
-    ]
-    fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
-    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
+def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int) -> list[int]:
+    assert n_prompts >= 0
+    ret: list[int] = []
+    for i in range(n_prompts):
+        random.seed(13 * i + 0)
+        ret.append(random.randint(prompt_length_min, prompt_length_max))
+    return ret
+
+
+def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
+    return [[random.randint(100, 10000) for _ in range(pl)] for pl in prompt_lengths]
+
+
+def get_server(path_server: str, path_log: Optional[str]) -> dict:
+    logger.info("Starting the llama.cpp server...")
+    hostname: str = os.environ.get("LLAMA_ARG_HOST", "127.0.0.1")
+    port: str = os.environ.get("LLAMA_ARG_PORT", "8080")
+    address: str = f"http://{hostname}:{port}"
+
+    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
+    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

     n_failures: int = 0
     while True:
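Note: the synthetic mode introduced above builds prompts as raw lists of token IDs rather than text, with a fixed per-prompt seed so runs are reproducible. Below is a minimal standalone sketch of how an `rng-MIN-MAX` spec (parsed later in `benchmark`) expands into such prompts, mirroring `get_prompt_lengths_rng`/`get_prompts_rng`; the helper name is hypothetical and the exact token IDs differ from the script's RNG stream.

```python
import random

def synthetic_prompts(spec: str, n_prompts: int) -> list[list[int]]:
    # spec has the form "rng-MIN-MAX", e.g. "rng-1024-2048"
    prefix, length_min, length_max = spec.split("-")
    assert prefix.lower() == "rng"
    prompts: list[list[int]] = []
    for i in range(n_prompts):
        random.seed(13 * i + 0)  # per-prompt seed -> identical lengths on every run
        length = random.randint(int(length_min), int(length_max))
        # token IDs are drawn from an arbitrary mid-vocabulary range, as in get_prompts_rng
        prompts.append([random.randint(100, 10000) for _ in range(length)])
    return prompts

# e.g. synthetic_prompts("rng-1024-2048", 3) -> 3 token-ID prompts of length 1024 to 2048
```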
@@ -50,14 +60,14 @@ def get_server(path_server: str, path_model: str, path_log: Optional[str], port:
             sleep(1.0)
             exit_code = process.poll()
             if exit_code is not None:
-                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
+                raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}, see {path_log}")
             response = requests.get(f"{address}/health")
             if response.status_code == 200:
                 break
         except requests.ConnectionError:
             n_failures += 1
             if n_failures >= 10:
-                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
+                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

     return {"process": process, "address": address, "fout": fout}
@@ -87,58 +97,97 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     session = data["session"]
     server_address: str = data["server_address"]

-    response = session.post(
-        f"{server_address}/apply-template",
-        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
-    )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
-    prompt: str = json.loads(response.text)["prompt"]
-
-    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
-    response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    t_submit = time()
+    if data["synthetic_prompt"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
+            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    else:
+        response = session.post(
+            f"{server_address}/apply-template",
+            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
+        )
+        if response.status_code != 200:
+            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        prompt: str = json.loads(response.text)["prompt"]
+
+        json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/completion", json=json_data, stream=True)

-    last_valid_line: str = ""
     token_arrival_times: list[float] = []
-    for line in response.iter_lines(decode_unicode=True):
-        if not line.startswith("data: "):
+    for line in response.iter_lines(decode_unicode=False):
+        if not line.startswith(b"data: "):
             continue
-        last_valid_line = line
         token_arrival_times.append(time())
     token_arrival_times = token_arrival_times[:-1]

     if response.status_code != 200:
         raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")

-    timings: dict = json.loads(last_valid_line[6:])["timings"]
-    return (timings["prompt_ms"], token_arrival_times)
+    return (t_submit, token_arrival_times)


-def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
-    num_workers: int = parallel + 1
-    prompts: list[str] = get_prompts(n_prompts)
+def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int):
+    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
+        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
+        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
+    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+        logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
+        os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
+    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+        logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
+        os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
+
+    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL", 1))
+    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
+    synthetic_prompts: bool = prompts is None
+    prompt_n = []
+
+    if synthetic_prompts:
+        prompt_source_split: list[str] = prompt_source.split("-")
+        assert len(prompt_source_split) == 3
+        assert prompt_source_split[0].lower() == "rng"
+        prompt_length_min: int = int(prompt_source_split[1])
+        prompt_length_max: int = int(prompt_source_split[2])
+        logger.info("Generating random prompts...")
+        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max)
+        prompts = get_prompts_rng(prompt_n)
+    else:
+        n_predict_min = n_predict
+
+    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
+        context_total: int = context_per_slot * parallel
+        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
+        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

     server: Optional[dict] = None
     session = None
     try:
-        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
+        server = get_server(path_server, path_log)
         server_address: str = server["address"]

-        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
+        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
         session.mount("http://", adapter)
         session.mount("https://", adapter)

         data: list[dict] = []
-        for i, p in enumerate(prompts):
-            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})

-        logger.info("Getting the prompt lengths...")
-        prompt_n = [get_prompt_length(d) for d in data]
+        for i, p in enumerate(prompts):
+            random.seed(13 * i + 1)
+            data.append({
+                "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
+                "n_predict": random.randint(n_predict_min, n_predict), "seed": 13 * i + 2})
+
+        if not synthetic_prompts:
+            logger.info("Getting the prompt lengths...")
+            prompt_n = [get_prompt_length(d) for d in data]

         logger.info("Starting the benchmark...\n")
         t0 = time()
-        results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
+        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
         if server is not None:
             server["process"].terminate()
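To make the automatic context sizing above concrete, here is a small worked example using the script's defaults (`rng-1024-2048` prompts, `n_predict=2048`, 32 parallel slots); the maximum sampled prompt length is assumed to hit the upper bound of 2048.

```python
n_predict = 2048       # --n_predict default
max_prompt_len = 2048  # assumed worst case for the default "rng-1024-2048" source
parallel = 32          # LLAMA_ARG_N_PARALLEL default set by the script

context_per_slot = int(1.05 * (n_predict + max_prompt_len))  # int(1.05 * 4096) = 4300
context_total = context_per_slot * parallel                  # 4300 * 32 = 137600
# LLAMA_ARG_CTX_SIZE would be set to 137600 if not already provided
```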
@@ -146,17 +195,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
         if session is not None:
             session.close()

-    prompt_ms = []
+    prompt_t = []
     token_t = []
     depth_sum: int = 0
-    for pn, (pms, tat) in zip(prompt_n, results):
-        prompt_ms.append(pms)
+    for pn, (t_submit, tat) in zip(prompt_n, results):
+        prompt_t.append(tat[0] - t_submit)
         token_t += tat
         n_tokens: int = len(tat)
         depth_sum += n_tokens * pn
         depth_sum += n_tokens * (n_tokens + 1) // 2
+    assert len(token_t) > 0

     prompt_n = np.array(prompt_n, dtype=np.int64)
-    prompt_ms = np.array(prompt_ms, dtype=np.float64)
+    prompt_t = np.array(prompt_t, dtype=np.float64)
     token_t = np.array(token_t, dtype=np.float64)
     token_t -= t0
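The depth bookkeeping above uses the fact that the j-th generated token of a request with prompt length `pn` is produced at context depth `pn + j`, so summing over `j = 1..n_tokens` gives `n_tokens * pn + n_tokens * (n_tokens + 1) / 2`, which is the closed form accumulated into `depth_sum`. A quick sanity check of that identity:

```python
pn, n_tokens = 1500, 4  # example prompt length and number of generated tokens
explicit = sum(pn + j for j in range(1, n_tokens + 1))        # 1501 + 1502 + 1503 + 1504 = 6010
closed_form = n_tokens * pn + n_tokens * (n_tokens + 1) // 2  # 6000 + 10 = 6010
assert explicit == closed_form
```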
@@ -167,18 +217,21 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
     logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
     logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
-    logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
-    logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
+    logger.info(f"Average prompt speed: {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
     logger.info(f"Total generated tokens: {token_t.shape[0]}")
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(
+        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")

     plt.figure()
-    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
-    plt.xlim(0, 1.05 * np.max(prompt_n))
-    plt.ylim(0, 1.05 * np.max(prompt_ms))
-    plt.title(path_model)
+    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
+    plt.xlim(0, 1.05e0 * np.max(prompt_n))
+    plt.ylim(0, 1.05e3 * np.max(prompt_t))
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
@@ -187,7 +240,6 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
-    plt.title(path_model)
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -196,15 +248,18 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
-        "Results are printed to console and visualized as plots (saved to current working directory).")
+        "Results are printed to console and visualized as plots (saved to current working directory). "
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
-    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
-    parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
-    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
-    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
-    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
-    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
-    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
+    parser.add_argument("--path_log", type=str, default="server-bench.log", help="Path to the log file for the server output")
+    parser.add_argument(
+        "--prompt_source", type=str, default="rng-1024-2048",
+        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
+        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
+    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
     parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
+    parser.add_argument(
+        "--n_predict_min", type=int, default=1024,
+        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
     args = parser.parse_args()
     benchmark(**vars(args))
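With this change the script no longer takes `--path_model`, `--port`, etc.; server settings are passed through `LLAMA_ARG_*` environment variables instead. Below is a hedged end-to-end sketch of one way to drive the updated script with synthetic prompts; `LLAMA_ARG_MODEL` and the model path follow the llama-server environment-variable convention referenced in the help text and are assumptions, not part of this diff.

```python
# Sketch only: run the benchmark with synthetic prompts against a local llama-server build.
import os
import subprocess

env = dict(os.environ)
env["LLAMA_ARG_MODEL"] = "/models/model.gguf"  # assumed env var, see `llama-server --help`
env["LLAMA_ARG_N_PARALLEL"] = "32"             # otherwise the script defaults this to 32

subprocess.run(
    ["python3", "scripts/server-bench.py",
     "--path_server", "llama-server",
     "--prompt_source", "rng-1024-2048",  # synthetic prompts with lengths in [1024, 2048]
     "--n_prompts", "100",
     "--n_predict", "2048",
     "--n_predict_min", "1024"],
    env=env, check=True)
```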


@@ -7,7 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
 **Features:**
 * LLM inference of F16 and quantized models on GPU and CPU
 * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
-* Reranking endoint (https://github.com/ggml-org/llama.cpp/pull/9510)
+* Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510)
 * Parallel decoding with multi-user support
 * Continuous batching
 * Multimodal ([documentation](../../docs/multimodal.md)) / with OpenAI-compatible API support