Spaces:
Running
Running
import contextlib | |
import importlib | |
import logging | |
import sys | |
import torch | |
import torch.testing | |
from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] | |
IS_WINDOWS, | |
TEST_WITH_CROSSREF, | |
TEST_WITH_TORCHDYNAMO, | |
TestCase as TorchTestCase, | |
) | |
from . import config, reset, utils | |
log = logging.getLogger(__name__)


def run_tests(needs=()):
    """Run this module's tests unless the environment cannot support them.

    Returns without running anything when:
    - running under TorchDynamo, on Windows, with crossref checks, or on
      Python 3.12+; or
    - any requirement listed in ``needs`` is unavailable.

    Args:
        needs: A single requirement name (``str``) or an iterable of names.
            The special name ``"cuda"`` is satisfied when
            ``torch.cuda.is_available()`` is true. Every other name must be
            importable with ``importlib.import_module``.
    """
    # Aliased so it does not shadow this function's own name.
    from torch.testing._internal.common_utils import (
        run_tests as _common_run_tests,
    )

    if (
        TEST_WITH_TORCHDYNAMO
        or IS_WINDOWS
        or TEST_WITH_CROSSREF
        or sys.version_info >= (3, 12)
    ):
        return  # skip testing

    # A bare string is a single requirement, not an iterable of characters.
    if isinstance(needs, str):
        needs = (needs,)
    for need in needs:
        if need == "cuda":
            # "cuda" is a device requirement, not a module name. Previously an
            # available CUDA fell through to import_module("cuda"). That fails
            # unless a top-level module named "cuda" happens to be installed,
            # which silently skipped the tests.
            if not torch.cuda.is_available():
                return
        else:
            try:
                importlib.import_module(need)
            except ImportError:
                return
    _common_run_tests()
class TestCase(TorchTestCase):
    """Base test case that isolates dynamo state between tests.

    For the lifetime of the class, it enters ``config.patch(
    raise_on_ctx_manager_usage=True, suppress_errors=False,
    log_compilation_metrics=False)`` on a class-level ``ExitStack``. The stack
    is closed in ``tearDownClass``.

    Around each test, it calls ``reset()`` and clears ``utils.counters``. It
    also prints the counters the test recorded and restores the autograd mode
    if the test left it changed.
    """

    # Created in setUpClass, closed in tearDownClass.
    _exit_stack: contextlib.ExitStack

    @classmethod
    def tearDownClass(cls):
        """Close the class-level exit stack, then run base-class teardown."""
        try:
            cls._exit_stack.close()
        finally:
            # Base-class teardown still runs if closing the stack raised.
            super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        """Run base-class setup, then enter the class-wide config patch."""
        super().setUpClass()
        cls._exit_stack = contextlib.ExitStack()  # type: ignore[attr-defined]
        cls._exit_stack.enter_context(  # type: ignore[attr-defined]
            config.patch(
                raise_on_ctx_manager_usage=True,
                suppress_errors=False,
                log_compilation_metrics=False,
            ),
        )

    def setUp(self):
        """Record the grad mode, then reset dynamo state and counters."""
        # Snapshot before anything else runs so tearDown can restore it.
        self._prior_is_grad_enabled = torch.is_grad_enabled()
        super().setUp()
        reset()
        utils.counters.clear()

    def tearDown(self):
        """Print and clear counters, reset state, and restore grad mode."""
        # Print what the test recorded. Each value is presumably a
        # collections.Counter (it exposes most_common()); verify in utils.
        for k, v in utils.counters.items():
            print(k, v.most_common())
        reset()
        utils.counters.clear()
        super().tearDown()
        # A test that leaves grad mode flipped would leak into later tests.
        if self._prior_is_grad_enabled != torch.is_grad_enabled():
            log.warning("Running test changed grad mode")
            torch.set_grad_enabled(self._prior_is_grad_enabled)