136 lines
4.1 KiB
Python
136 lines
4.1 KiB
Python
"""Shared script helpers for deterministic project tooling."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import os
|
|
import pathlib
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
from collections.abc import Iterable, Mapping, Sequence
|
|
|
|
|
|
class ScriptError(RuntimeError):
    """Error type for failures that should be reported to the script user.

    Derives from ``RuntimeError``.
    """
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class CommandPlan:
    """Immutable description of one command, to be executed or shown in a dry run.

    Instances are printed by ``print_plan`` and executed (or described) by
    ``run_plan``.
    """

    # Human-readable name used in plan and dry-run output.
    label: str
    # Argument vector; run_plan passes it to subprocess as a list (no shell).
    command: tuple[str, ...]
    # Working directory for the command.
    cwd: pathlib.Path
    # Executable name that must be resolvable before running, if any.
    required_tool: str | None = dataclasses.field(default=None)
    # When set, the command is not executed; the text explains why.
    deferred: str | None = dataclasses.field(default=None)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class CommandOutcome:
    """Immutable record of what happened to one planned command."""

    # The plan this outcome belongs to.
    plan: CommandPlan
    # Child exit status when executed; otherwise a code chosen by the
    # caller (run_plan reports 0 for dry runs and 2 for deferred plans).
    returncode: int
    # True only when a subprocess was actually started.
    executed: bool
|
|
|
|
|
|
def find_project_root(start: pathlib.Path | None = None) -> pathlib.Path:
    """Return the nearest ancestor directory that looks like the project root.

    The search begins at *start* (default: the current working directory),
    resolved to an absolute path; if that is a file, its parent directory is
    used instead. The first directory at or above it containing both a
    ``CMakeLists.txt`` file and a ``scripts`` directory is returned.

    Raises:
        ScriptError: if no such directory exists up to the filesystem root.
    """
    origin = pathlib.Path.cwd() if start is None else start
    here = origin.resolve()
    if here.is_file():
        here = here.parent

    match = next(
        (
            directory
            for directory in (here, *here.parents)
            if (directory / "CMakeLists.txt").is_file()
            and (directory / "scripts").is_dir()
        ),
        None,
    )
    if match is None:
        raise ScriptError(f"Unable to find project root from {here}")
    return match
|
|
|
|
|
|
def format_command(command: Sequence[str]) -> str:
    """Render *command* as one shell-style string for log output.

    Each argument is quoted with ``_quote_arg`` and the pieces are joined with
    single spaces. Nothing is executed and no shell is involved.
    """
    return " ".join(map(_quote_arg, command))
|
|
|
|
|
|
def print_plan(plans: Iterable[CommandPlan], *, prefix: str = "plan") -> None:
    """Print each plan as ``<prefix>[<n>].<field>: <value>`` lines.

    Plans are numbered from 1 in iteration order. ``label``, ``cwd`` and
    ``command`` are always printed; ``deferred`` only when it is set.
    """
    number = 0
    for plan in plans:
        number += 1
        tag = f"{prefix}[{number}]"
        print(f"{tag}.label: {plan.label}")
        print(f"{tag}.cwd: {plan.cwd}")
        print(f"{tag}.command: {format_command(plan.command)}")
        if plan.deferred is None:
            continue
        print(f"{tag}.deferred: {plan.deferred}")
|
|
|
|
|
|
def find_tool(name: str) -> str | None:
    """Resolve executable *name*, preferring the project's ``scripts/bin``.

    Returns ``<project root>/scripts/bin/<name>`` when that is a regular file
    the current user may execute; otherwise returns ``shutil.which(name)``,
    which is ``None`` when *name* is not found on ``PATH``.

    Raises:
        ScriptError: propagated from ``find_project_root`` if the project root
            cannot be located from this module's own path.
    """
    local_candidate = find_project_root(pathlib.Path(__file__)).joinpath(
        "scripts", "bin", name
    )
    is_local_executable = local_candidate.is_file() and os.access(
        local_candidate, os.X_OK
    )
    return str(local_candidate) if is_local_executable else shutil.which(name)
|
|
|
|
|
|
def require_tool(name: str) -> str:
    """Return the path ``find_tool`` resolves for *name*; a missing tool is an error.

    Raises:
        ScriptError: if ``find_tool`` returns ``None`` for *name*.
    """
    resolved = find_tool(name)
    if resolved is not None:
        return resolved
    raise ScriptError(f"Required tool not found on PATH: {name}")
|
|
|
|
|
|
def python_command(script: pathlib.Path, *args: str) -> tuple[str, ...]:
    """Return an argv tuple that runs *script* with the current interpreter.

    The first element is ``sys.executable``, so the child uses the same Python
    interpreter as the caller; *args* are appended unchanged.
    """
    return (sys.executable, str(script)) + args
|
|
|
|
|
|
def run_plan(plan: CommandPlan, *, dry_run: bool, env: Mapping[str, str] | None = None) -> CommandOutcome:
    """Execute one planned command, or only describe it in dry-run mode.

    Args:
        plan: The command to run.
        dry_run: When true, print the command (and any deferral note) and
            return ``returncode=0`` without executing anything.
        env: Environment for the child process, passed to ``subprocess.run``
            as a plain dict; ``None`` inherits the current environment.

    Returns:
        A ``CommandOutcome``. Outside dry-run mode, deferred plans are not
        executed and report ``returncode=2``.

    Raises:
        ScriptError: if ``plan.required_tool`` cannot be resolved.
    """
    if dry_run:
        print(f"DRY-RUN {plan.label}: {format_command(plan.command)}")
        if plan.deferred is not None:
            print(f"DEFERRED {plan.label}: {plan.deferred}")
        return CommandOutcome(plan=plan, returncode=0, executed=False)

    if plan.deferred is not None:
        print(f"DEFERRED {plan.label}: {plan.deferred}")
        return CommandOutcome(plan=plan, returncode=2, executed=False)

    argv = list(plan.command)
    if plan.required_tool is not None:
        tool_path = require_tool(plan.required_tool)
        # require_tool may resolve a project-local scripts/bin executable that
        # is not on PATH. Run that exact file rather than letting the OS
        # re-resolve the bare name, which could fail or pick another binary.
        if argv[:1] == [plan.required_tool]:
            argv[0] = tool_path

    # Flush our own buffered stdout so earlier output is not reordered after
    # the child's output when stdout is a pipe or file (e.g. CI logs).
    sys.stdout.flush()
    completed = subprocess.run(
        argv,
        cwd=str(plan.cwd),
        env=None if env is None else dict(env),
        check=False,
    )
    return CommandOutcome(plan=plan, returncode=completed.returncode, executed=True)
|
|
|
|
|
|
def run_plans(
    plans: Iterable[CommandPlan],
    *,
    dry_run: bool,
    env: Mapping[str, str] | None = None,
) -> int:
    """Run *plans* in order, stopping at the first non-zero return code.

    Args:
        plans: Commands to run, in execution order. Any iterable is accepted.
        dry_run: Forwarded to ``run_plan`` for every plan.
        env: Optional child environment forwarded to each ``run_plan`` call;
            ``None`` (the default) inherits the current environment, matching
            the previous behavior.

    Returns:
        The first non-zero return code encountered, or 0 if every plan
        reported 0.
    """
    for plan in plans:
        outcome = run_plan(plan, dry_run=dry_run, env=env)
        if outcome.returncode != 0:
            return outcome.returncode
    return 0
|
|
|
|
|
|
# Characters that _quote_arg leaves unquoted in log output. Built once at
# import time instead of being rebuilt on every call.
_SHELL_SAFE_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_+-=./:"
)


def _quote_arg(value: str) -> str:
    """Quote *value* for display as a single POSIX-shell-style word.

    Empty strings become ``''``. Values made only of characters in
    ``_SHELL_SAFE_CHARS`` are returned unchanged. Anything else is wrapped in
    single quotes, with each embedded single quote written as ``'\\''``
    (close quote, escaped quote, reopen quote).
    """
    if not value:
        return "''"
    if _SHELL_SAFE_CHARS.issuperset(value):
        return value
    return "'" + value.replace("'", "'\\''") + "'"
|