|
import os |
|
import random |
|
import unittest |
|
from distutils.util import strtobool |
|
|
|
import torch |
|
|
|
from packaging import version |
|
|
|
|
|
# Shared default source of randomness for the tensor helpers in this module.
global_rng = random.Random()

# Prefer CUDA when present; an available Apple MPS backend (checked below) takes precedence.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# Compare only the release segment: base_version drops local/dev suffixes such as
# "+cu117" or ".dev2022...", so locally built or nightly torch versions still compare sanely.
is_torch_higher_equal_than_1_12 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("1.12")

# torch.backends.mps was introduced in torch 1.12, so only query it behind the version guard.
if is_torch_higher_equal_than_1_12 and torch.backends.mps.is_available():
    torch_device = "mps"
|
|
|
|
|
# Spellings accepted by the (deprecated, removed in Python 3.12) distutils.util.strtobool.
_TRUTHY_STRINGS = frozenset({"y", "yes", "t", "true", "on", "1"})
_FALSY_STRINGS = frozenset({"n", "no", "f", "false", "off", "0"})


def _str_to_bool(value):
    """Convert a yes/no style string to ``True`` or ``False`` (case-insensitive).

    Replaces ``distutils.util.strtobool`` and accepts the same spellings, so existing
    environment settings keep working on Python versions without ``distutils``.

    Raises:
        ValueError: if ``value`` is not one of the recognised spellings.
    """
    lowered = value.lower()
    if lowered in _TRUTHY_STRINGS:
        return True
    if lowered in _FALSY_STRINGS:
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable ``key``.

    Args:
        key: Name of the environment variable to read.
        default: Value returned unchanged when ``key`` is not set.

    Returns:
        ``True``/``False`` parsed from the variable's value, or ``default`` if it is unset.

    Raises:
        ValueError: if the variable is set to an unrecognised value.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # Unset: hand back the caller's default as-is (no conversion).
        return default

    try:
        return _str_to_bool(value)
    except ValueError as err:
        # Chain the original error so the offending value stays visible in the traceback.
        raise ValueError(f"If set, {key} must be yes or no.") from err
|
|
|
|
|
# Evaluated once at import time; the `slow` decorator consults this flag.
_run_slow_tests = parse_flag_from_env(key="RUN_SLOW", default=False)
|
|
|
|
|
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Create a float32 tensor of the given shape filled with random values.

    Each element is ``rng.random() * scale``; with a ``random.Random`` generator and a
    positive ``scale`` this lies in ``[0, scale)``. Values are drawn in row-major order,
    so a seeded ``rng`` produces a reproducible tensor.

    Args:
        shape: Iterable of ints giving the tensor's dimensions.
        scale: Multiplier applied to every sampled value.
        rng: Object exposing ``random()``; defaults to the module-level ``global_rng``.
        name: Unused; accepted only for call-site compatibility.

    Returns:
        A contiguous ``torch.float32`` tensor of shape ``shape``.
    """
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    # One draw per element, in the same order as an explicit loop would make them.
    values = [rng.random() * scale for _ in range(total_dims)]

    return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
|
|
|
|
|
def slow(test_case):
    """Mark ``test_case`` as a slow test.

    The test is skipped (reason "test is slow") unless the ``RUN_SLOW`` environment
    variable held a truthy value when this module was imported.
    """
    skip_unless_slow_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_slow_enabled(test_case)
|
|