Spaces:
Runtime error
Runtime error
import os | |
import random | |
import unittest | |
from distutils.util import strtobool | |
import torch | |
from packaging import version | |
# Module-wide RNG used by helpers that are not handed an explicit generator.
global_rng = random.Random()

# Default test device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

# Compare only the release part of the torch version (``base_version``), so
# local/dev suffixes on the installed build do not affect the check.
is_torch_higher_equal_than_1_12 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("1.12")

# On torch >= 1.12, an available Apple MPS backend overrides the choice above.
# The version check short-circuits, so the MPS backend is never queried on
# older torch.
if is_torch_higher_equal_than_1_12 and torch.backends.mps.is_available():
    torch_device = "mps"
def _str_to_bool(value):
    """Convert a yes/no style string to ``True`` or ``False``.

    Local replacement for ``distutils.util.strtobool``. ``distutils`` was
    deprecated by PEP 632 and removed in Python 3.12. The same spellings are
    accepted case-insensitively, and surrounding whitespace is ignored.
    ``strtobool`` returned ``1``/``0``; the ``bool`` results here compare equal
    to those.

    Raises:
        ValueError: if ``value`` is not one of the recognised spellings.
    """
    normalized = value.strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable ``key``.

    Args:
        key: Name of the environment variable to read.
        default: Value returned unchanged when ``key`` is not set.

    Returns:
        ``default`` if the variable is unset. Otherwise ``True`` for
        y/yes/t/true/on/1 or ``False`` for n/no/f/false/off/0
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        ValueError: if the variable is set to an unrecognised value.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set: fall back to `default` as-is.
        return default
    try:
        return _str_to_bool(value)
    except ValueError as err:
        # More values are supported, but keep the message simple.
        raise ValueError(f"If set, {key} must be yes or no.") from err
# Evaluated once at import time; the `slow` decorator skips tests unless this
# is truthy (i.e. RUN_SLOW is set to a yes-like value).
_run_slow_tests = parse_flag_from_env(key="RUN_SLOW", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Create a float32 tensor of the given shape filled with random values.

    Each element is ``rng.random() * scale``. Values are drawn in row-major
    order, so a seeded ``random.Random`` yields reproducible tensors.

    Args:
        shape: Sequence of dimension sizes for the result.
        scale: Multiplier applied to every drawn value.
        rng: Object with a ``random()`` method; defaults to the module-level
            ``global_rng`` when ``None``.
        name: Ignored by this function.

    Returns:
        A contiguous ``torch.float`` (float32) tensor with shape ``shape``.
    """
    if rng is None:
        rng = global_rng

    # Number of elements; an empty shape gives 1 (a scalar tensor).
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = [rng.random() * scale for _ in range(total_dims)]
    return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
def slow(test_case):
    """Decorator that skips ``test_case`` unless slow tests are enabled.

    The decision uses the module-level ``_run_slow_tests`` flag. That flag is
    computed from the RUN_SLOW environment variable once, when the module is
    imported. Skipped tests report the reason "test is slow".
    """
    skip_unless_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_enabled(test_case)