# torch/utils/benchmark/examples/compare.py
# mypy: allow-untyped-defs
"""Example of Timer and Compare APIs:

$ python -m examples.compare
"""

import pickle
import sys
import time

import torch
import torch.utils.benchmark as benchmark_utils
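
# A minimal sketch of the Timer API used below (illustrative only: the
# tensor sizes are arbitrary, and `median` is a field of the Measurement
# that `blocked_autorange()` returns):
#
#   t = benchmark_utils.Timer(
#       stmt="x + y",
#       globals={"x": torch.ones(8), "y": torch.ones(8)},
#   )
#   print(t.blocked_autorange().median)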

class FauxTorch:
    """Emulate different versions of PyTorch.

    In normal circumstances this would be done with multiple processes
    writing serialized measurements, but this simplifies that model to
    make the example clearer.
    """
    def __init__(self, real_torch, extra_ns_per_element):
        self._real_torch = real_torch
        self._extra_ns_per_element = extra_ns_per_element

    def extra_overhead(self, result):
        # time.sleep has a ~65 us overhead, so only fake a
        # per-element overhead if numel is large enough.
        numel = int(result.numel())
        if numel > 5000:
            time.sleep(numel * self._extra_ns_per_element * 1e-9)
        return result

    def add(self, *args, **kwargs):
        return self.extra_overhead(self._real_torch.add(*args, **kwargs))

    def mul(self, *args, **kwargs):
        return self.extra_overhead(self._real_torch.mul(*args, **kwargs))

    def cat(self, *args, **kwargs):
        return self.extra_overhead(self._real_torch.cat(*args, **kwargs))

    def matmul(self, *args, **kwargs):
        return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
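
# For instance, FauxTorch(torch, 5).add(x, y) behaves like torch.add but
# sleeps roughly 5 ns per output element once the result crosses the
# 5000-element threshold, which is how "severe_regression" is faked below.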

def main():
    tasks = [
        ("add", "add", "torch.add(x, y)"),
        ("add", "add (extra +0)", "torch.add(x, y + zero)"),
    ]

    serialized_results = []
    repeats = 2
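
    # One Timer per (branch, task, size, num_threads) combination. `env`
    # tags the emulated branch so Compare can lay the branches out side
    # by side in its table.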
    timers = [
        benchmark_utils.Timer(
            stmt=stmt,
            globals={
                "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
                "x": torch.ones((size, 4)),
                "y": torch.ones((1, 4)),
                "zero": torch.zeros(()),
            },
            label=label,
            sub_label=sub_label,
            description=f"size: {size}",
            env=branch,
            num_threads=num_threads,
        )
        for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
        for label, sub_label, stmt in tasks
        for size in [1, 10, 100, 1000, 10000, 50000]
        for num_threads in [1, 4]
    ]
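
    # Run each timer `repeats` times and round-trip the Measurements
    # through pickle, mimicking the usual workflow where separate
    # benchmark processes serialize results for later comparison.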
    for i, timer in enumerate(timers * repeats):
        serialized_results.append(pickle.dumps(
            timer.blocked_autorange(min_run_time=0.05)
        ))
        print(f"\r{i + 1} / {len(timers) * repeats}", end="")
        sys.stdout.flush()
    print()
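
    # Compare groups Measurements by label, sub_label, description, and
    # env; the repeated runs of each timer are pooled into a single cell.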
    comparison = benchmark_utils.Compare([
        pickle.loads(i) for i in serialized_results
    ])

    print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
    comparison.print()

    print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()


if __name__ == "__main__":
    main()