| |
| import torch |
| import triton |
| from typing import Callable, List, Union, Tuple, Dict |
| import json |
| import os |
| import time |
|
|
| |
class do_bench_config:
    """Bundle of timing options for one benchmark run.

    The values are stored unchanged as attributes. ``PytestBenchmarker``
    forwards them to ``triton.testing.do_bench`` as ``warmup``, ``rep``,
    ``quantiles`` and ``return_mode`` respectively.

    Args:
        warm_up: Value forwarded as ``warmup``.
        repetition: Value forwarded as ``rep``.
        quantiles: Quantiles to request from ``do_bench``. ``None`` selects
            a freshly built ``[0.5, 0.8, 0.2]`` list.
        return_mode: Value forwarded as ``return_mode``.
    """

    def __init__(
        self,
        warm_up: int = 25,
        repetition: int = 100,
        quantiles: List[float] = None,
        return_mode: str = "median"
    ):
        if quantiles is None:
            # Build the default per instance so that instances never share
            # (and accidentally mutate) one list object.
            quantiles = [0.5, 0.8, 0.2]

        self.warm_up, self.repetition = warm_up, repetition
        self.quantiles = quantiles
        self.return_mode = return_mode
|
|
| |
| |
| |
| |
| |
| |
|
|
# Process-wide registry of collected results: maps an operation name to the
# list of result dictionaries recorded for it. Filled by
# add_benchmark_result() and emptied by save_all_benchmark_results().
PYTEST_BENCHMARK_RESULTS: Dict[str, List[Dict]] = {}
|
|
def add_benchmark_result(op_name: str, result_dict: Dict):
    """Record one result dictionary under the given operation name.

    The per-operation list in ``PYTEST_BENCHMARK_RESULTS`` is created the
    first time ``op_name`` is seen; ``result_dict`` is then appended to it
    (the same object is stored, not a copy).

    Args:
        op_name: Key identifying the benchmarked operation.
        result_dict: Result entry to append.
    """
    PYTEST_BENCHMARK_RESULTS.setdefault(op_name, []).append(result_dict)
|
|
def save_all_benchmark_results(output_dir: str):
    """Write the collected results to one JSON file per operation, then reset.

    For every operation in ``PYTEST_BENCHMARK_RESULTS`` that has at least one
    result, the list of results is written to ``<output_dir>/<op_name>.json``
    (UTF-8, 4-space indent, non-ASCII characters kept as-is). Saving is
    best-effort per operation: a failure is printed and the remaining
    operations are still processed. The registry is cleared afterwards,
    whether or not every file could be written.

    Args:
        output_dir: Directory to write into; created if it does not exist.

    Raises:
        OSError: If ``output_dir`` itself cannot be created.
    """
    os.makedirs(output_dir, exist_ok=True)

    for op_name, results_list in PYTEST_BENCHMARK_RESULTS.items():
        if not results_list:
            continue

        file_path = os.path.join(output_dir, op_name + ".json")
        try:
            # Serialize before opening the file: if a result holds a value
            # json cannot encode, nothing on disk is truncated or left
            # half-written.
            payload = json.dumps(results_list, indent=4, ensure_ascii=False)
            with open(file_path, 'w', encoding='utf8') as f:
                f.write(payload)
            print(f"Benchmark results for {op_name} saved to: {file_path}")
        except (OSError, TypeError, ValueError) as e:
            # OSError covers file-system failures (IOError is its alias);
            # TypeError/ValueError cover results that are not serializable.
            # One bad operation must not prevent the others from being saved.
            print(f"Error saving results for {op_name} to {file_path}: {e}")

    PYTEST_BENCHMARK_RESULTS.clear()
|
|
| |
class PytestBenchmarker:
    """Time a callable with ``triton.testing.do_bench`` and record the outcome.

    Every call to :meth:`run_benchmark` produces one result dictionary and
    registers it under ``op_name`` via ``add_benchmark_result``.

    Args:
        op_callable: Callable to time; it is handed to ``do_bench`` as-is, so
            its inputs must already be bound (e.g. a lambda).
        op_name: Name under which results are recorded.
        config: Timing options. A default ``do_bench_config()`` is used when
            omitted.
    """

    def __init__(self, op_callable: Callable, op_name: str, config: do_bench_config = None):
        self.op_callable = op_callable
        self.op_name = op_name
        self.config = config if config is not None else do_bench_config()

    @staticmethod
    def _round_metric(value, ndigits: int):
        """Round a metric for reporting; return non-numeric values unchanged."""
        if isinstance(value, (float, int)):
            return round(value, ndigits)
        if isinstance(value, str):
            # Placeholder such as "N/A".
            return value
        try:
            # Scalar-like objects (0-dim tensors, NumPy scalars, ...) are
            # converted so the stored result stays JSON-serializable.
            return round(float(value), ndigits)
        except (TypeError, ValueError, RuntimeError):
            return value

    def _compute_metric(self, label: str, calculator: Callable, params: Dict, ms):
        """Return ``calculator(params, ms)``, or "N/A" if absent or failing."""
        if not calculator:
            return "N/A"
        try:
            return calculator(params, ms)
        except Exception as exc:
            # A broken metric must not discard a valid timing measurement.
            print(f"Warning: {label} calculation failed for {self.op_name} with params {params}: {exc}")
            return "N/A"

    def run_benchmark(self, current_params_dict: Dict,
                      gbps_calculator: Callable = None,
                      tflops_calculator: Callable = None
                      ):
        """Benchmark ``op_callable`` once and record the result.

        Args:
            current_params_dict: Dictionary describing the current pytest
                parameters; stored in the result under ``"params"`` and
                passed to the calculators.
            gbps_calculator: Optional function called as
                ``gbps_calculator(current_params_dict, ms)`` returning GB/s.
            tflops_calculator: Optional function called as
                ``tflops_calculator(current_params_dict, ms)`` returning
                TFLOPS.

        Returns:
            On success, a dict with keys ``"params"``, ``"ms"``,
            ``"min_ms"``, ``"max_ms"``, ``"GB/s"`` and ``"TFLOPS"``; a metric
            is ``"N/A"`` when its calculator is missing or raised. On
            failure, a dict with keys ``"params"`` and ``"error"``. Either
            dict is also registered through ``add_benchmark_result``; the
            method itself does not raise for benchmark errors.
        """
        try:
            # Exactly three quantiles are expected; any other shape of the
            # return value fails the unpacking and is reported as an error.
            ms, tail_a, tail_b = triton.testing.do_bench(
                self.op_callable,
                warmup=self.config.warm_up,
                rep=self.config.repetition,
                quantiles=self.config.quantiles,
                return_mode=self.config.return_mode
            )

            # do_bench returns values in the order the quantiles were
            # requested. The default config asks for [0.5, 0.8, 0.2], i.e. the
            # upper tail first, so order the two tails explicitly to keep
            # min_ms <= max_ms whatever order was configured.
            min_ms, max_ms = min(tail_a, tail_b), max(tail_a, tail_b)

            gbps = self._compute_metric("GB/s", gbps_calculator, current_params_dict, ms)
            tflops = self._compute_metric("TFLOPS", tflops_calculator, current_params_dict, ms)

            result = {
                "params": current_params_dict,
                "ms": round(ms, 4),
                "min_ms": round(min_ms, 4),
                "max_ms": round(max_ms, 4),
                "GB/s": self._round_metric(gbps, 2),
                "TFLOPS": self._round_metric(tflops, 2),
            }
            add_benchmark_result(self.op_name, result)
            return result

        except Exception as e:
            # Deliberately broad: a failing kernel or configuration is
            # recorded as an error entry instead of aborting the caller.
            print(f"Error during benchmark for {self.op_name} with params {current_params_dict}: {e}")
            error_result = {
                "params": current_params_dict,
                "error": str(e)
            }
            add_benchmark_result(self.op_name, error_result)
            return error_result