From 2e42be42bd6bf1dcc643d6ac4e77419bfe5dd24f Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 14 Jun 2025 16:34:20 +0800 Subject: [PATCH] compare-llama-bench: add option to plot (#14169) * compare llama-bench: add option to plot * Address review comments: convert case + add type hints * Add matplotlib to requirements * fix tests * Improve comment and fix assert condition for test * Add back default test_name, add --plot_log_scale * use log_scale regardless of x_values --- .../requirements-compare-llama-bench.txt | 1 + scripts/compare-llama-bench.py | 169 +++++++++++++++++- 2 files changed, 169 insertions(+), 1 deletion(-) diff --git a/requirements/requirements-compare-llama-bench.txt b/requirements/requirements-compare-llama-bench.txt index e0aaa3204..d87e897e1 100644 --- a/requirements/requirements-compare-llama-bench.txt +++ b/requirements/requirements-compare-llama-bench.txt @@ -1,2 +1,3 @@ tabulate~=0.9.0 GitPython~=3.1.43 +matplotlib~=3.10.0 diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index a1013c3b7..30e3cf864 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -19,6 +19,7 @@ except ImportError as e: print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100 raise e + logger = logging.getLogger("compare-llama-bench") # All llama-bench SQL fields @@ -122,11 +123,15 @@ help_s = ( parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed") parser.add_argument("-s", "--show", help=help_s) parser.add_argument("--verbose", action="store_true", help="increase output verbosity") +parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)") +parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth") +parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)") known_args, unknown_args = parser.parse_known_args() logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) + if known_args.check: # Check if all required Python libraries are installed. Would have failed earlier if not. sys.exit(0) @@ -499,7 +504,6 @@ else: name_compare = bench_data.get_commit_name(hexsha8_compare) - # If the user provided columns to group the results by, use them: if known_args.show is not None: show = known_args.show.split(",") @@ -544,6 +548,14 @@ else: show.remove(prop) except ValueError: pass + + # Add plot_x parameter to parameters to show if it's not already present: + if known_args.plot: + for k, v in PRETTY_NAMES.items(): + if v == known_args.plot_x and k not in show: + show.append(k) + break + rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) if not rows_show: @@ -600,6 +612,161 @@ if "gpu_info" in show: headers = [PRETTY_NAMES[p] for p in show] headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] +if known_args.plot: + def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False): + try: + import matplotlib.pyplot as plt + import matplotlib + matplotlib.use('Agg') + except ImportError as e: + logger.error("matplotlib is required for --plot.") + raise e + + data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup) + plot_x_index = None + plot_x_label = plot_x_param + + if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]: + pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param) + if pretty_name in data_headers: + plot_x_index = data_headers.index(pretty_name) + plot_x_label = pretty_name + elif plot_x_param in data_headers: + plot_x_index = data_headers.index(plot_x_param) + plot_x_label = plot_x_param + else: + logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}") + return + + grouped_data = {} + + for i, row in enumerate(table_data): + group_key_parts = [] + test_name = row[-4] + + base_test = "" + x_value = None + + if plot_x_param in ["n_prompt", "n_gen", "n_depth"]: + for j, val in enumerate(row[:-4]): + header_name = data_headers[j] + if val is not None and str(val).strip(): + group_key_parts.append(f"{header_name}={val}") + + if plot_x_param == "n_prompt" and "pp" in test_name: + base_test = test_name.split("@")[0] + x_value = base_test + elif plot_x_param == "n_gen" and "tg" in test_name: + x_value = test_name.split("@")[0] + elif plot_x_param == "n_depth" and "@d" in test_name: + base_test = test_name.split("@d")[0] + x_value = int(test_name.split("@d")[1]) + else: + base_test = test_name + + if base_test.strip(): + group_key_parts.append(f"Test={base_test}") + else: + for j, val in enumerate(row[:-4]): + if j != plot_x_index: + header_name = data_headers[j] + if val is not None and str(val).strip(): + group_key_parts.append(f"{header_name}={val}") + else: + x_value = val + + group_key_parts.append(f"Test={test_name}") + + group_key = tuple(group_key_parts) + + if group_key not in grouped_data: + grouped_data[group_key] = [] + + grouped_data[group_key].append({ + 'x_value': x_value, + 'baseline': float(row[-3]), + 'compare': float(row[-2]), + 'speedup': float(row[-1]) + }) + + if not grouped_data: + logger.error("No data available for plotting") + return + + def make_axes(num_groups, max_cols=2, base_size=(8, 4)): + from math import ceil + cols = 1 if num_groups == 1 else min(max_cols, num_groups) + rows = ceil(num_groups / cols) + + # Scale figure size by grid dimensions + w, h = base_size + fig, ax_arr = plt.subplots(rows, cols, + figsize=(w * cols, h * rows), + squeeze=False) + + axes = ax_arr.flatten()[:num_groups] + return fig, axes + + num_groups = len(grouped_data) + fig, axes = make_axes(num_groups) + + plot_idx = 0 + + for group_key, points in grouped_data.items(): + if plot_idx >= len(axes): + break + ax = axes[plot_idx] + + try: + points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0) + x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted] + except ValueError: + points_sorted = sorted(points, key=lambda p: group_key) + x_values = [p['x_value'] for p in points_sorted] + + baseline_vals = [p['baseline'] for p in points_sorted] + compare_vals = [p['compare'] for p in points_sorted] + + ax.plot(x_values, baseline_vals, 'o-', color='skyblue', + label=f'{baseline_name}', linewidth=2, markersize=6) + ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8, + label=f'{compare_name}', linewidth=2, markersize=6) + + if log_scale: + ax.set_xscale('log', base=2) + unique_x = sorted(set(x_values)) + ax.set_xticks(unique_x) + ax.set_xticklabels([str(int(x)) for x in unique_x]) + + title_parts = [] + for part in group_key: + if '=' in part: + key, value = part.split('=', 1) + title_parts.append(f"{key}: {value}") + + title = ', '.join(title_parts) if title_parts else "Performance comparison" + + ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold') + ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold') + ax.set_title(title, fontsize=12, fontweight='bold') + ax.legend(loc='best', fontsize=10) + ax.grid(True, alpha=0.3) + + plot_idx += 1 + + for i in range(plot_idx, len(axes)): + axes[i].set_visible(False) + + fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}', + fontsize=14, fontweight='bold') + fig.subplots_adjust(top=1) + + plt.tight_layout() + plt.savefig(output_file, dpi=300, bbox_inches='tight') + plt.close() + + create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale) + print(tabulate( # noqa: NP100 table, headers=headers,