Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-08-14 04:17:53 -04:00)
server-bench: external OAI servers, sqlite (#15179)
* server-bench: external OAI servers, sqlite

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update scripts/server-bench.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* raise_for_status

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
scripts/server-bench.py

@@ -4,6 +4,7 @@ import argparse
 import json
 import os
 import random
+import sqlite3
 import subprocess
 from time import sleep, time
 from typing import Optional, Union
@@ -47,6 +48,8 @@ def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
 
 
 def get_server(path_server: str, path_log: Optional[str]) -> dict:
+    if path_server.startswith("http://") or path_server.startswith("https://"):
+        return {"process": None, "address": path_server, "fout": None}
     if os.environ.get("LLAMA_ARG_HOST") is None:
         logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
         os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
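With this change, a --path_server value that starts with http:// or https:// is treated as an already-running server: nothing is launched, and the returned "process" of None is what the later cleanup code keys off. A minimal standalone sketch of that early-return path (hypothetical address, local-launch branch omitted):

    from typing import Optional

    def get_server_sketch(path_server: str, path_log: Optional[str]) -> dict:
        # External server: nothing to launch, log, or terminate later.
        if path_server.startswith("http://") or path_server.startswith("https://"):
            return {"process": None, "address": path_server, "fout": None}
        raise NotImplementedError("launching a local llama-server is omitted in this sketch")

    server = get_server_sketch("http://127.0.0.1:8080", None)  # hypothetical address
    assert server["process"] is None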
@@ -89,15 +92,13 @@ def get_prompt_length(data: dict) -> int:
         f"{server_address}/apply-template",
         json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     prompt: str = json.loads(response.text)["prompt"]
     response = session.post(
         f"{server_address}/tokenize",
         json={"content": prompt, "add_special": True}
     )
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    response.raise_for_status()
     tokens: list[str] = json.loads(response.text)["tokens"]
     return len(tokens)
 
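response.raise_for_status() is the requests built-in that replaces the hand-rolled status check: it does nothing on a 2xx response and raises requests.exceptions.HTTPError otherwise. A minimal sketch against a hypothetical local llama-server endpoint:

    import requests

    # Hypothetical local endpoint; a non-2xx reply now surfaces as
    # requests.exceptions.HTTPError instead of a custom RuntimeError.
    response = requests.post(
        "http://127.0.0.1:8080/tokenize",
        json={"content": "Hello, world!", "add_special": True})
    response.raise_for_status()
    print(len(response.json()["tokens"]))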
@@ -107,7 +108,12 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
     server_address: str = data["server_address"]
 
     t_submit = time()
-    if data["synthetic_prompt"]:
+    if data["external_server"]:
+        json_data: dict = {
+            "prompt": data["prompt"], "ignore_eos": True,
+            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
+        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
+    elif data["synthetic_prompt"]:
         json_data: dict = {
             "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
             "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
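Against an external OpenAI-compatible server the benchmark posts to /v1/completions and uses the OAI-style max_tokens field instead of llama.cpp's n_predict, still reading the reply as a stream of SSE lines. A hedged sketch of that request shape (hypothetical address; ignore_eos is passed through as in the script but may not be honored by every OAI-compatible backend):

    import requests

    server_address = "http://127.0.0.1:8080"  # hypothetical OAI-compatible server
    json_data = {
        "prompt": "Hello",        # the script itself sends synthetic prompts as token ID lists
        "ignore_eos": True,
        "seed": 42,
        "max_tokens": 32,         # OAI-style name for llama.cpp's n_predict
        "stream": True,
    }
    with requests.Session() as session:
        response = session.post(f"{server_address}/v1/completions", json=json_data, stream=True)
        response.raise_for_status()
        for line in response.iter_lines(decode_unicode=False):
            if line.startswith(b"data: "):
                print(line[:80])  # raw SSE chunk, truncated for display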
@@ -117,34 +123,38 @@ def send_prompt(data: dict) -> tuple[float, list[float]]:
             f"{server_address}/apply-template",
             json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
         )
-        if response.status_code != 200:
-            raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+        response.raise_for_status()
         prompt: str = json.loads(response.text)["prompt"]
 
         json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
         response = session.post(f"{server_address}/completion", json=json_data, stream=True)
+    response.raise_for_status()
 
     lines = []
     token_arrival_times: list[float] = []
     for line in response.iter_lines(decode_unicode=False):
         if not line.startswith(b"data: "):
             continue
         lines.append(line)
         token_arrival_times.append(time())
-    token_arrival_times = token_arrival_times[:-1]
 
-    if response.status_code != 200:
-        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
+    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
+        token_arrival_times = token_arrival_times[:-1]
 
     return (t_submit, token_arrival_times)
 
 
-def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, n_predict_min: int, seed_offset: int):
+def benchmark(
+        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
+        n_predict: int, n_predict_min: int, seed_offset: int):
+    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
     if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
         logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
         os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
-    if os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_N_GPU_LAYERS") is None:
         logger.info("LLAMA_ARG_N_GPU_LAYERS not explicitly set, using 999")
         os.environ["LLAMA_ARG_N_GPU_LAYERS"] = "999"
-    if os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_FLASH_ATTN") is None:
         logger.info("LLAMA_ARG_FLASH_ATTN not explicitly set, using 'true'")
         os.environ["LLAMA_ARG_FLASH_ATTN"] = "true"
 
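The stream is collected as raw "data: " lines with one arrival timestamp per line; the script then drops the last timestamp when the second-to-last payload contains a "timings" object, i.e. when the tail of the stream is bookkeeping rather than a generated token. A small sketch of that trimming logic with illustrative, hand-written payloads (not captured from a real server):

    import json
    from time import time

    lines = [  # illustrative payloads only
        b'data: {"content": "Hello"}',
        b'data: {"content": " world", "timings": {"predicted_per_second": 100.0}}',
        b'data: [DONE]',
    ]
    token_arrival_times = [time(), time(), time()]

    if len(lines) > 1 and "timings" in json.loads(lines[-2][6:]):
        token_arrival_times = token_arrival_times[:-1]
    print(len(token_arrival_times))  # 2: the trailing bookkeeping chunk's timestamp was dropped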
@@ -165,7 +175,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     else:
         n_predict_min = n_predict
 
-    if os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
+    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
         context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
         context_total: int = context_per_slot * parallel
         os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
@@ -176,6 +186,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     try:
         server = get_server(path_server, path_log)
         server_address: str = server["address"]
+        assert external_server == (server["process"] is None)
 
         adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
@@ -188,8 +199,9 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         if seed_offset >= 0:
             random.seed(3 * (seed_offset + 1000 * i) + 1)
         data.append({
-            "session": session, "server_address": server_address, "prompt": p, "synthetic_prompt": synthetic_prompts,
-            "n_predict": random.randint(n_predict_min, n_predict), "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
+            "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
+            "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
+            "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})
 
     if not synthetic_prompts:
         logger.info("Getting the prompt lengths...")
@@ -199,7 +211,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
         t0 = time()
         results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
     finally:
-        if server is not None:
+        if server is not None and server["process"] is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
@@ -233,15 +245,24 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
     logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
     logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
     logger.info("")
-    logger.info(
-        "The above numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
-        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
-
+    if path_db is not None:
+        con = sqlite3.connect(path_db)
+        cursor = con.cursor()
+        cursor.execute(
+            "CREATE TABLE IF NOT EXISTS server_bench"
+            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
+            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
+        cursor.execute(
+            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
+            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])
+        con.commit()
+
     plt.figure()
     plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
     plt.xlim(0, 1.05e0 * np.max(prompt_n))
     plt.ylim(0, 1.05e3 * np.max(prompt_t))
+    plt.title(name or "")
     plt.xlabel("Prompt length [tokens]")
     plt.ylabel("Time to first token [ms]")
     plt.savefig("prompt_time.png", dpi=240)
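When --path_db is given, each run appends one row to the server_bench table, so different servers or settings can be compared later with plain SQL. A small sketch of reading the results back (hypothetical database file name; the column list matches the CREATE TABLE above):

    import sqlite3

    con = sqlite3.connect("server_bench.sqlite")  # hypothetical path passed earlier via --path_db
    rows = con.execute(
        "SELECT name, n_parallel, n_prompts, runtime FROM server_bench ORDER BY runtime;")
    for name, n_parallel, n_prompts, runtime in rows:
        print(f"{name or '<unnamed>'}: n_parallel={n_parallel}, n_prompts={n_prompts}, runtime={runtime:.2f} s")
    con.close()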
@@ -250,6 +271,7 @@ def benchmark(path_server: str, path_log: Optional[str], prompt_source: str, n_p
     plt.figure()
     plt.hist(token_t, np.arange(0, bin_max))
     plt.xlim(0, bin_max + 1)
+    plt.title(name or "")
     plt.xlabel("Time [s]")
     plt.ylabel("Num. tokens generated per second")
     plt.savefig("gen_rate.png", dpi=240)
@@ -259,9 +281,13 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
         "Results are printed to console and visualized as plots (saved to current working directory). "
-        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help).")
+        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
+        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
+        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
     parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
     parser.add_argument("--path_log", type=str, default="server-bench-{port}.log", help="Path to the model to use for the benchmark")
+    parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
+    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
     parser.add_argument(
         "--prompt_source", type=str, default="rng-1024-2048",
         help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "