# Source path: Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/torch/utils/benchmark/examples/fuzzer.py
# mypy: allow-untyped-defs
"""Example of the Timer and Fuzzer APIs:

$ python -m examples.fuzzer
"""
import sys | |
import torch.utils.benchmark as benchmark_utils | |
def main():
    """Time ``x + y`` over fuzzed tensor pairs and print a ranked report.

    Takes 250 configurations from a ``Fuzzer`` constructed with ``seed=0``,
    times the statement ``x + y`` for each one with
    ``Timer.blocked_autorange(min_run_time=0.1)``, and prints the 15 lowest
    and 15 highest per-element times (median divided by ``x.numel()``)
    alongside the element count, shape and layout label of each input.
    Progress is written to stdout while the measurements run.
    """
    # Three size parameters named k0, k1 and k2, all configured identically.
    size_params = [
        benchmark_utils.FuzzedParameter(
            name=f"k{i}",
            minval=16,
            maxval=16 * 1024,
            distribution="loguniform",
        )
        for i in range(3)
    ]
    # NOTE(review): "d" presumably controls how many of k0..k2 are used as
    # tensor dimensions (it is passed as `dim_parameter` below) -- confirm
    # against the FuzzedTensor documentation.
    dim_param = benchmark_utils.FuzzedParameter(
        name="d",
        distribution={2: 0.6, 3: 0.4},
    )
    # The two operands "x" and "y" share one tensor specification.
    operands = [
        benchmark_utils.FuzzedTensor(
            name=name,
            size=("k0", "k1", "k2"),
            dim_parameter="d",
            probability_contiguous=0.75,
            min_elements=64 * 1024,
            max_elements=128 * 1024,
        )
        for name in ("x", "y")
    ]
    add_fuzzer = benchmark_utils.Fuzzer(
        parameters=[size_params, dim_param],
        tensors=[operands],
        seed=0,
    )

    n = 250
    results = []
    for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)):
        x = tensors["x"]
        y = tensors["y"]
        x_order = str(tensor_properties["x"]["order"])
        y_order = str(tensor_properties["y"]["order"])

        # Contiguous inputs are labelled as such; otherwise show the reported
        # "order" property.
        x_layout = "contiguous" if x.is_contiguous() else x_order
        y_layout = "contiguous" if y.is_contiguous() else y_order
        shape = ", ".join(f"{dim:>4}" for dim in x.shape)
        description = (
            f"{x.numel():>7} | {shape:<16} | "
            f"{x_layout:<12} | "
            f"{y_layout:<12} | "
        )

        result = benchmark_utils.Timer(
            stmt="x + y",
            globals=tensors,
            description=description,
        ).blocked_autorange(min_run_time=0.1)
        # Keep the element count with the measurement so results can be
        # normalised per element when ranking.
        result.metadata = {"numel": x.numel()}
        results.append(result)

        # "\r" rewinds to the start of the line so the counter updates in place.
        sys.stdout.write(f"\r{i + 1} / {n}")
        sys.stdout.flush()
    print()

    # The rest is string formatting for the printed report.
    attempts_per_config = 1. / (1. - add_fuzzer.rejection_rate)
    print(f"Average attempts per valid config: {attempts_per_config:.1f}")

    def time_fn(m):
        """Return the median time of ``m`` divided by its element count."""
        return m.median / m.metadata["numel"]

    ranked = sorted(results, key=time_fn)

    template = f"{{:>6}}{' ' * 19}Size Shape{' ' * 13}X order Y order\n{'-' * 80}"
    sections = (
        ("", "Best:", ranked[:15]),
        ("\n", "Worst:", ranked[-15:]),
    )
    for lead, title, rows in sections:
        print(lead + template.format(title))
        for m in rows:
            # Scaled by 1e9 and labelled as nanoseconds per element.
            print(f"{time_fn(m) * 1e9:>4.1f} ns / element {m.description}")
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    main()