Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-07-19 00:57:41 +00:00)
Docs: script to auto-generate ggml operations docs (#14598)
* Docs: script to auto-generate ggml operations docs
* Review: formatting changes + change github action
* Use built-in types instead of typing
* docs : add BLAS and Metal ops

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
40  .github/workflows/update-ops-docs.yml (vendored, new file)
@@ -0,0 +1,40 @@
name: Update Operations Documentation

on:
  push:
    paths:
      - 'docs/ops/**'
      - 'scripts/create_ops_docs.py'
  pull_request:
    paths:
      - 'docs/ops/**'
      - 'scripts/create_ops_docs.py'

jobs:
  update-ops-docs:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.x'

      - name: Generate operations documentation to temporary file
        run: |
          mkdir -p /tmp/ops_check
          ./scripts/create_ops_docs.py /tmp/ops_check/ops.md

      - name: Check if docs/ops.md matches generated version
        run: |
          if ! diff -q docs/ops.md /tmp/ops_check/ops.md; then
            echo "Operations documentation (docs/ops.md) is not up to date with the backend CSV files."
            echo "To fix: run ./scripts/create_ops_docs.py and commit the updated docs/ops.md along with your changes"
            echo "Differences found:"
            diff docs/ops.md /tmp/ops_check/ops.md || true
            exit 1
          fi
          echo "Operations documentation is up to date."
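The check step above boils down to comparing the committed `docs/ops.md` byte-for-byte with a freshly generated copy; a minimal Python sketch of the same idea (the two table snippets here are made-up illustration data, not real `ops.md` content):

```python
import difflib

# Hypothetical committed vs. freshly generated table contents.
committed = "# GGML Operations\n| ADD | CPU |\n"
generated = "# GGML Operations\n| ADD | GPU |\n"

# Equivalent in spirit to `diff docs/ops.md /tmp/ops_check/ops.md`.
changes = list(difflib.unified_diff(committed.splitlines(),
                                    generated.splitlines(), lineterm=""))
up_to_date = not changes
print(up_to_date)  # False: the committed table drifted from the generated one
```

When the two files match, `unified_diff` yields nothing and the workflow step passes; any drift makes the step fail with the diff printed for the contributor.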
95  docs/ops.md (new file)
@@ -0,0 +1,95 @@
# GGML Operations

List of GGML operations and backend support status.

Legend:
- ✅ Fully supported by this backend
- 🟡 Partially supported by this backend
- ❌ Not supported by this backend

| Operation | BLAS | CPU | CUDA | Metal |
|-----------|------|------|------|------|
| ABS | ❌ | ✅ | 🟡 | ❌ |
| ACC | ❌ | ✅ | ✅ | ✅ |
| ADD | ❌ | ✅ | ✅ | 🟡 |
| ADD1 | ❌ | ✅ | ✅ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ |
| CLAMP | ❌ | ✅ | ✅ | 🟡 |
| CONCAT | ❌ | ✅ | 🟡 | ✅ |
| CONT | ❌ | ✅ | 🟡 | ✅ |
| CONV_2D_DW | ❌ | ✅ | ✅ | ❌ |
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ |
| CONV_TRANSPOSE_2D | ❌ | ✅ | ✅ | ❌ |
| COS | ❌ | ✅ | ✅ | 🟡 |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ✅ | ✅ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | 🟡 |
| DIV | ❌ | ✅ | ✅ | 🟡 |
| DUP | ❌ | ✅ | 🟡 | 🟡 |
| ELU | ❌ | ✅ | ❌ | 🟡 |
| EXP | ❌ | ✅ | 🟡 | ❌ |
| FLASH_ATTN_EXT | ❌ | ✅ | 🟡 | 🟡 |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | 🟡 |
| GEGLU_ERF | ❌ | ✅ | ✅ | 🟡 |
| GEGLU_QUICK | ❌ | ✅ | ✅ | 🟡 |
| GELU | ❌ | ✅ | 🟡 | 🟡 |
| GELU_ERF | ❌ | ✅ | 🟡 | 🟡 |
| GELU_QUICK | ❌ | ✅ | 🟡 | 🟡 |
| GET_ROWS | ❌ | ✅ | 🟡 | ✅ |
| GET_ROWS_BACK | ❌ | 🟡 | 🟡 | ❌ |
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ |
| HARDSIGMOID | ❌ | ✅ | 🟡 | ❌ |
| HARDSWISH | ❌ | ✅ | 🟡 | ❌ |
| IM2COL | ❌ | ✅ | ✅ | 🟡 |
| L2_NORM | ❌ | ✅ | ✅ | ✅ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ |
| LOG | ❌ | ✅ | ✅ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ |
| MUL | ❌ | ✅ | ✅ | 🟡 |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 |
| MUL_MAT_ID | ❌ | ✅ | ✅ | ✅ |
| NEG | ❌ | ✅ | 🟡 | 🟡 |
| NORM | ❌ | ✅ | ✅ | 🟡 |
| OPT_STEP_ADAMW | ❌ | ✅ | ✅ | ❌ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | ❌ |
| PAD | ❌ | ✅ | ✅ | ✅ |
| PAD_REFLECT_1D | ❌ | ✅ | ❌ | ✅ |
| POOL_2D | ❌ | ✅ | ✅ | ✅ |
| REGLU | ❌ | ✅ | ✅ | 🟡 |
| RELU | ❌ | ✅ | 🟡 | 🟡 |
| REPEAT | ❌ | ✅ | 🟡 | ✅ |
| REPEAT_BACK | ❌ | ✅ | ✅ | ❌ |
| RMS_NORM | ❌ | ✅ | ✅ | 🟡 |
| RMS_NORM_BACK | ❌ | ✅ | ✅ | ❌ |
| RMS_NORM_MUL | ❌ | ✅ | ✅ | ✅ |
| ROPE | ❌ | ✅ | ✅ | ✅ |
| ROPE_BACK | ❌ | ✅ | ✅ | ❌ |
| RWKV_WKV6 | ❌ | ✅ | ✅ | ✅ |
| RWKV_WKV7 | ❌ | ✅ | ✅ | ✅ |
| SCALE | ❌ | ✅ | ✅ | ✅ |
| SET | ❌ | ✅ | ❌ | ✅ |
| SET_ROWS | ❌ | 🟡 | ❌ | 🟡 |
| SGN | ❌ | ✅ | 🟡 | ❌ |
| SIGMOID | ❌ | ✅ | 🟡 | 🟡 |
| SILU | ❌ | ✅ | 🟡 | 🟡 |
| SILU_BACK | ❌ | ✅ | ✅ | ❌ |
| SIN | ❌ | ✅ | ✅ | 🟡 |
| SOFT_MAX | ❌ | ✅ | ✅ | ✅ |
| SOFT_MAX_BACK | ❌ | 🟡 | 🟡 | ❌ |
| SQR | ❌ | ✅ | ✅ | 🟡 |
| SQRT | ❌ | ✅ | ✅ | 🟡 |
| SSM_CONV | ❌ | ✅ | ✅ | ✅ |
| SSM_SCAN | ❌ | ✅ | ✅ | ✅ |
| STEP | ❌ | ✅ | 🟡 | ❌ |
| SUB | ❌ | ✅ | ✅ | 🟡 |
| SUM | ❌ | ✅ | ✅ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | ✅ |
| SWIGLU | ❌ | ✅ | ✅ | 🟡 |
| TANH | ❌ | ✅ | 🟡 | 🟡 |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ |
| UPSCALE | ❌ | ✅ | ✅ | 🟡 |
6534  docs/ops/BLAS.csv (new file; diff suppressed because it is too large)
6534  docs/ops/CPU.csv (new file; diff suppressed because it is too large)
6534  docs/ops/CUDA.csv (new file; diff suppressed because it is too large)
6534  docs/ops/Metal.csv (new file; diff suppressed because it is too large)
196  scripts/create_ops_docs.py (new executable file)
@@ -0,0 +1,196 @@
#!/usr/bin/env python3

"""
This script parses docs/ops/*.csv and creates the ops.md, which is a table documenting supported operations on various ggml backends.
"""
import csv
import logging
import sys
from pathlib import Path
from collections import defaultdict


class DocsGenerator:
    def __init__(self, ggml_root: str, output_filename: str = "ops.md"):
        self.ggml_root = Path(ggml_root)
        self.ops_dir = self.ggml_root / "docs" / "ops"
        self.output_filename = output_filename
        self.backend_support: dict[str, dict[str, list[bool]]] = defaultdict(
            lambda: defaultdict(list)
        )
        self.all_operations: set[str] = set()
        self.all_backends: set[str] = set()
        self.logger = logging.getLogger(__name__)
    def parse_support_files(self) -> None:
        if not self.ops_dir.exists():
            self.logger.warning(f"ops directory not found: {self.ops_dir}")
            return

        self.logger.info(f"Parsing support files from {self.ops_dir}...")

        for support_file in self.ops_dir.glob("*.csv"):
            self.logger.info(f"  Reading: {support_file.name}")
            self._parse_support_file(support_file)

    def _parse_support_file(self, file_path: Path) -> None:
        try:
            with open(file_path, "r", newline='') as f:
                reader = csv.DictReader(f)

                for row in reader:
                    # Skip rows that don't have support mode
                    if row.get('test_mode') != 'support':
                        continue

                    backend_name = row.get('backend_name', '').strip()
                    operation = row.get('op_name', '').strip()
                    supported_str = row.get('error_message', '').strip()  # "yes" or "no"
                    backend_reg_name = row.get('backend_reg_name', '').strip()

                    # Skip invalid or error operations
                    if not operation or not backend_name or operation in [
                        "CONTEXT_ERROR",
                        "BUILD_ERROR",
                    ]:
                        continue

                    is_supported = supported_str.lower() == "yes"

                    # Use backend_reg_name for grouping, fallback to backend_name
                    backend_key = backend_reg_name if backend_reg_name else backend_name

                    self.all_backends.add(backend_key)
                    self.backend_support[backend_key][operation].append(is_supported)
                    self.all_operations.add(operation)

        except Exception as e:
            self.logger.error(f"  Error parsing {file_path}: {e}")
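The per-row filtering above can be exercised in isolation; a minimal sketch using a made-up two-row CSV in the same column layout the parser expects (the backend and op values are illustrative, not real test output):

```python
import csv
import io

# Hypothetical sample: one 'support' row and one 'perf' row.
sample = io.StringIO(
    "test_mode,backend_name,op_name,error_message,backend_reg_name\n"
    "support,Metal0,ADD,yes,Metal\n"
    "perf,Metal0,ADD,,Metal\n"
)

# Keep only rows recorded in support mode, as _parse_support_file does.
rows = [r for r in csv.DictReader(sample) if r.get("test_mode") == "support"]
print(len(rows))                   # 1
print(rows[0]["error_message"])    # yes
```

Note that for support rows the `error_message` column is overloaded to carry the "yes"/"no" support flag rather than an actual error string.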
    def get_backend_support_status(self, backend: str, operation: str) -> str:
        support_list = self.backend_support[backend].get(operation, [])

        if not support_list:
            return "unsupported"

        all_supported = all(support_list)
        any_supported = any(support_list)

        if all_supported:
            return "supported"
        elif any_supported:
            return "partially supported"
        else:
            return "unsupported"
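The classification rule is a simple all/any aggregation over the per-test booleans collected for one backend and operation; a standalone sketch (the function name `classify` is mine, not from the script):

```python
def classify(results: list[bool]) -> str:
    # Mirrors get_backend_support_status: all pass -> supported,
    # some pass -> partially supported, none recorded or none pass -> unsupported.
    if not results or not any(results):
        return "unsupported"
    return "supported" if all(results) else "partially supported"

print(classify([True, True]))   # supported
print(classify([True, False]))  # partially supported
print(classify([]))             # unsupported
```

This is why an op like MUL_MAT shows 🟡 on every backend: some tested parameter combinations pass and some do not.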
    def get_support_status(self, operation: str) -> str:
        if operation not in self.all_operations:
            return "unsupported"

        support_count = 0
        total_backends = len(self.all_backends)

        for backend in self.all_backends:
            # Count a backend only if at least one of its tests passed;
            # the raw value is a list of booleans, which is truthy even
            # when every entry is False.
            if any(self.backend_support[backend].get(operation, [])):
                support_count += 1

        if support_count == 0:
            return "unsupported"
        elif support_count == total_backends:
            return "supported"
        else:
            return "partially supported"

    def get_support_symbol(self, status: str) -> str:
        symbols = {"supported": "✅", "partially supported": "🟡", "unsupported": "❌"}
        return symbols.get(status, "❓")
    def generate_markdown(self) -> str:
        lines = []

        lines.append("# GGML Operations")
        lines.append("")
        lines.append("List of GGML operations and backend support status.")
        lines.append("")
        lines.append("Legend:")
        lines.append("- ✅ Fully supported by this backend")
        lines.append("- 🟡 Partially supported by this backend")
        lines.append("- ❌ Not supported by this backend")
        lines.append("")

        backends = sorted(self.all_backends)
        header = "| Operation |"
        for backend in backends:
            header += f" {backend} |"

        separator = "|-----------|"
        for _ in backends:
            separator += "------|"

        lines.append(header)
        lines.append(separator)

        sorted_operations = sorted(self.all_operations)

        for operation in sorted_operations:
            row = f"| {operation:>32} |"

            for backend in backends:
                status = self.get_backend_support_status(backend, operation)
                if status == "supported":
                    symbol = "✅"
                elif status == "partially supported":
                    symbol = "🟡"
                else:
                    symbol = "❌"
                row += f" {symbol} |"

            lines.append(row)

        lines.append("")

        return "\n".join(lines)
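The table assembly can be seen in miniature with a hypothetical two-backend example (backend names here are illustrative):

```python
# Build the markdown header and separator the same way generate_markdown does:
# one fixed "Operation" column plus one column per sorted backend name.
backends = ["CPU", "CUDA"]
header = "| Operation |" + "".join(f" {b} |" for b in backends)
separator = "|-----------|" + "------|" * len(backends)

print(header)     # | Operation | CPU | CUDA |
print(separator)  # |-----------|------|------|
```

Operation names are right-padded to 32 characters in each data row, which keeps the raw markdown columns visually aligned without affecting the rendered table.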
    def run(self) -> None:
        self.logger.info("Parsing GGML operation support files...")
        self.parse_support_files()

        if not self.all_operations:
            self.logger.error(
                "No operations found. Make sure to run test-backend-ops support --output csv > docs/ops/file.csv first."
            )
            return

        self.logger.info(
            f"Found {len(self.all_operations)} operations across {len(self.all_backends)} backends"
        )

        self.logger.info("Generating markdown...")
        markdown_content = self.generate_markdown()

        docs_dir = self.ggml_root / "docs"
        docs_dir.mkdir(exist_ok=True)

        ops_file = docs_dir / self.output_filename
        with open(ops_file, "w") as f:
            f.write(markdown_content)

        self.logger.info(f"Generated: {ops_file}")
        self.logger.info(f"Operations: {len(self.all_operations)}")
        self.logger.info(f"Backends: {len(self.all_backends)}")


def main():
    logging.basicConfig(level=logging.INFO)

    if len(sys.argv) > 1:
        output_filename = sys.argv[1]
    else:
        output_filename = "ops.md"

    generator = DocsGenerator(".", output_filename)
    generator.run()


if __name__ == "__main__":
    main()
@@ -317,10 +317,11 @@ enum test_mode {
     MODE_TEST,
     MODE_PERF,
     MODE_GRAD,
+    MODE_SUPPORT,
 };

 // Output format support similar to llama-bench
-enum output_formats { CONSOLE, SQL };
+enum output_formats { CONSOLE, SQL, CSV };

 static const char * output_format_str(output_formats format) {
     switch (format) {
@@ -328,6 +329,8 @@ static const char * output_format_str(output_formats format) {
             return "console";
         case SQL:
             return "sql";
+        case CSV:
+            return "csv";
         default:
             GGML_ABORT("invalid output format");
     }
@@ -338,6 +341,8 @@ static bool output_format_from_str(const std::string & s, output_formats & forma
         format = CONSOLE;
     } else if (s == "sql") {
         format = SQL;
+    } else if (s == "csv") {
+        format = CSV;
     } else {
         return false;
     }
@@ -360,6 +365,8 @@ struct test_result {
     double bandwidth_gb_s;
     size_t memory_kb;
     int n_runs;
+    std::string device_description;
+    std::string backend_reg_name;

     test_result() {
         // Initialize with default values
@@ -384,7 +391,7 @@ struct test_result {
     test_result(const std::string & backend_name, const std::string & op_name, const std::string & op_params,
                 const std::string & test_mode, bool supported, bool passed, const std::string & error_message = "",
                 double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0, size_t memory_kb = 0,
-                int n_runs = 0) :
+                int n_runs = 0, const std::string & device_description = "", const std::string & backend_reg_name = "") :
         backend_name(backend_name),
         op_name(op_name),
         op_params(op_params),
@@ -396,7 +403,9 @@ struct test_result {
         flops(flops),
         bandwidth_gb_s(bandwidth_gb_s),
         memory_kb(memory_kb),
-        n_runs(n_runs) {
+        n_runs(n_runs),
+        device_description(device_description),
+        backend_reg_name(backend_reg_name) {
         // Set test time
         time_t t = time(NULL);
         char buf[32];
@@ -410,7 +419,8 @@ struct test_result {
     static const std::vector<std::string> & get_fields() {
         static const std::vector<std::string> fields = {
             "test_time", "build_commit", "backend_name", "op_name", "op_params", "test_mode", "supported",
-            "passed", "error_message", "time_us", "flops", "bandwidth_gb_s", "memory_kb", "n_runs"
+            "passed", "error_message", "time_us", "flops", "bandwidth_gb_s", "memory_kb", "n_runs",
+            "device_description", "backend_reg_name"
         };
         return fields;
     }
@@ -444,7 +454,9 @@ struct test_result {
             std::to_string(flops),
             std::to_string(bandwidth_gb_s),
             std::to_string(memory_kb),
-            std::to_string(n_runs) };
+            std::to_string(n_runs),
+            device_description,
+            backend_reg_name };
     }
 };

@@ -633,6 +645,8 @@ struct console_printer : public printer {
             print_test_console(result);
         } else if (result.test_mode == "perf") {
             print_perf_console(result);
+        } else if (result.test_mode == "support") {
+            print_support_console(result);
         }
     }

@@ -799,6 +813,17 @@ struct console_printer : public printer {
         }
         printf("\n");
     }
+
+    void print_support_console(const test_result & result) {
+        printf("  %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
+        fflush(stdout);
+
+        if (result.supported) {
+            printf("\033[1;32mSUPPORTED\033[0m\n");
+        } else {
+            printf("\033[1;31mNOT SUPPORTED\033[0m\n");
+        }
+    }
 };

 struct sql_printer : public printer {
@@ -841,12 +866,39 @@ struct sql_printer : public printer {
     }
 };

+struct csv_printer : public printer {
+    void print_header() override {
+        std::vector<std::string> fields = test_result::get_fields();
+        for (size_t i = 0; i < fields.size(); i++) {
+            printf("\"%s\"%s", fields[i].c_str(), i < fields.size() - 1 ? "," : "");
+        }
+        printf("\n");
+    }
+
+    void print_test_result(const test_result & result) override {
+        std::vector<std::string> values = result.get_values();
+        for (size_t i = 0; i < values.size(); i++) {
+            // Escape quotes and wrap in quotes for CSV
+            std::string escaped_value = values[i];
+            size_t pos = 0;
+            while ((pos = escaped_value.find("\"", pos)) != std::string::npos) {
+                escaped_value.replace(pos, 1, "\"\"");
+                pos += 2;
+            }
+            printf("\"%s\"%s", escaped_value.c_str(), i < values.size() - 1 ? "," : "");
+        }
+        printf("\n");
+    }
+};
+
 static std::unique_ptr<printer> create_printer(output_formats format) {
     switch (format) {
         case CONSOLE:
             return std::make_unique<console_printer>();
         case SQL:
             return std::make_unique<sql_printer>();
+        case CSV:
+            return std::make_unique<csv_printer>();
     }
     GGML_ABORT("invalid output format");
 }
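The quote-doubling loop in csv_printer implements RFC 4180-style field escaping: embedded double quotes are doubled and the whole field is wrapped in quotes. A Python sketch of the same rule (the helper name `csv_field` is mine):

```python
def csv_field(value: str) -> str:
    # Double any embedded quotes, then wrap the field in quotes,
    # matching csv_printer's find/replace loop over std::string.
    return '"' + value.replace('"', '""') + '"'

print(csv_field('say "hi"'))  # "say ""hi"""
```

This matters here because op_params strings can contain commas (tensor shapes, parameter lists), so every field is quoted unconditionally rather than only when needed.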
@@ -928,7 +980,7 @@ struct test_case {
     std::vector<ggml_tensor *> sentinels;

     void add_sentinel(ggml_context * ctx) {
-        if (mode == MODE_PERF || mode == MODE_GRAD) {
+        if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) {
             return;
         }
         ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
@@ -1153,15 +1205,12 @@ struct test_case {
             return true;
         }

-        // check if backends support op
         if (!ggml_backend_supports_op(backend, out)) {
             // Create test result for unsupported performance test
             test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", false, false,
                                "not supported");
-
-            if (output_printer) {
-                output_printer->print_test_result(result);
-            }
+
+            output_printer->print_test_result(result);

             return true;
         }
@@ -1266,6 +1315,38 @@ struct test_case {
         return true;
     }

+    bool eval_support(ggml_backend_t backend, const char * op_name, printer * output_printer) {
+        mode = MODE_SUPPORT;
+
+        static const size_t graph_nodes = 8192;
+
+        ggml_init_params params = {
+            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
+            /* .mem_base = */ NULL,
+            /* .no_alloc = */ true,
+        };
+        ggml_context_ptr ctx(ggml_init(params)); // smart ptr
+        GGML_ASSERT(ctx);
+
+        ggml_tensor * out = build_graph(ctx.get());
+        std::string current_op_name = op_desc(out);
+        if (op_name != nullptr && current_op_name != op_name) {
+            return true;
+        }
+
+        bool supported = ggml_backend_supports_op(backend, out);
+
+        std::string device_desc = ggml_backend_dev_description(ggml_backend_get_device(backend));
+        std::string backend_reg_name = ggml_backend_reg_name(ggml_backend_dev_backend_reg(ggml_backend_get_device(backend)));
+
+        test_result result(ggml_backend_name(backend), current_op_name, vars(), "support", supported, supported,
+                           supported ? "yes" : "no", 0.0, 0.0, 0.0, 0, 0, device_desc, backend_reg_name);
+
+        output_printer->print_test_result(result);
+
+        return true;
+    }
+
     bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) {
         mode = MODE_GRAD;
         const std::vector<float> expect = grad_expect();
@@ -5599,17 +5680,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         return true;
     }

+    if (mode == MODE_SUPPORT) {
+        auto test_cases = make_test_cases_eval();
+        filter_test_cases(test_cases, params_filter);
+        for (auto & test : test_cases) {
+            test->eval_support(backend, op_name, output_printer);
+        }
+        return true;
+    }
+
     GGML_ABORT("fatal error");
 }

 static void usage(char ** argv) {
-    printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql>]\n", argv[0]);
+    printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]);
     printf("  valid modes:\n");
     printf("  - test (default, compare with CPU backend for correctness)\n");
     printf("  - grad (compare gradients from backpropagation with method of finite differences)\n");
     printf("  - perf (performance evaluation)\n");
+    printf("  - support (probe backend operation support)\n");
     printf("  op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n");
-    printf("  --output specifies output format (default: console)\n");
+    printf("  --output specifies output format (default: console, options: console, sql, csv)\n");
 }

 int main(int argc, char ** argv) {
@@ -5626,6 +5717,8 @@ int main(int argc, char ** argv) {
             mode = MODE_PERF;
         } else if (strcmp(argv[i], "grad") == 0) {
             mode = MODE_GRAD;
+        } else if (strcmp(argv[i], "support") == 0) {
+            mode = MODE_SUPPORT;
         } else if (strcmp(argv[i], "-o") == 0) {
             if (i + 1 < argc) {
                 op_name_filter = argv[++i];