import asyncio |
|
import os |
|
import shutil |
|
import subprocess |
|
import sys |
|
import tempfile |
|
import unittest |
|
from contextlib import contextmanager |
|
from functools import partial |
|
from pathlib import Path |
|
from typing import List, Type, Union
|
from unittest import mock |
|
|
|
import torch |
|
|
|
from ..state import AcceleratorState, PartialState |
|
from ..utils import ( |
|
gather, |
|
is_bnb_available, |
|
is_comet_ml_available, |
|
is_datasets_available, |
|
is_deepspeed_available, |
|
is_mps_available, |
|
is_safetensors_available, |
|
is_tensorboard_available, |
|
is_timm_available, |
|
is_torch_version, |
|
is_tpu_available, |
|
is_transformers_available, |
|
is_wandb_available, |
|
is_xpu_available, |
|
str_to_bool, |
|
) |
|
|
|
|
|
def parse_flag_from_env(key, default=False): |
|
try: |
|
value = os.environ[key] |
|
except KeyError: |
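        # KEY isn't set in the environment; fall back to the provided default.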
|
|
|
_value = default |
|
else: |
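        # KEY is set; convert truthy/falsy strings (e.g. "yes"/"no", "1"/"0") to a bool.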
|
|
|
try: |
|
_value = str_to_bool(value) |
|
except ValueError: |
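            # More values are supported, but keep the error message simple.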
|
|
|
raise ValueError(f"If set, {key} must be yes or no.") |
|
return _value |
|
|
|
|
|
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) |
|
|
|
|
|
def skip(test_case): |
|
"Decorator that skips a test unconditionally" |
|
return unittest.skip("Test was skipped")(test_case) |
|
|
|
|
|
def slow(test_case): |
|
""" |
|
Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a |
|
truthy value to run them. |
|
""" |
|
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) |
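
# Usage sketch (the test method below is hypothetical): decorate slow tests with `@slow` and opt in by
# setting the environment variable, e.g. `RUN_SLOW=1 python -m pytest tests/`.
#
#     @slow
#     def test_full_finetuning_run(self):
#         ...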
|
|
|
|
|
def require_cpu(test_case): |
|
""" |
|
    Decorator marking a test that must run only on the CPU. These tests are skipped when a GPU is available.
|
""" |
|
return unittest.skipUnless(not torch.cuda.is_available(), "test requires only a CPU")(test_case) |
|
|
|
|
|
def require_cuda(test_case): |
|
""" |
|
    Decorator marking a test that requires CUDA. These tests are skipped when no GPU is available.
|
""" |
|
return unittest.skipUnless(torch.cuda.is_available(), "test requires a GPU")(test_case) |
|
|
|
|
|
def require_xpu(test_case): |
|
""" |
|
    Decorator marking a test that requires XPU. These tests are skipped when no XPU is available.
|
""" |
|
    return unittest.skipUnless(is_xpu_available(), "test requires an XPU")(test_case)
|
|
|
|
|
def require_mps(test_case): |
|
""" |
|
Decorator marking a test that requires MPS backend. These tests are skipped when torch doesn't support `mps` |
|
backend. |
|
""" |
|
return unittest.skipUnless(is_mps_available(), "test requires a `mps` backend support in `torch`")(test_case) |
|
|
|
|
|
def require_huggingface_suite(test_case): |
|
""" |
|
    Decorator marking a test that requires transformers and datasets. These tests are skipped when they are not
    installed.
|
""" |
|
return unittest.skipUnless( |
|
is_transformers_available() and is_datasets_available(), "test requires the Hugging Face suite" |
|
)(test_case) |
|
|
|
|
|
def require_transformers(test_case): |
|
""" |
|
    Decorator marking a test that requires transformers. These tests are skipped when it is not installed.
|
""" |
|
return unittest.skipUnless(is_transformers_available(), "test requires the transformers library")(test_case) |
|
|
|
|
|
def require_timm(test_case): |
|
""" |
|
    Decorator marking a test that requires timm. These tests are skipped when it is not installed.
|
""" |
|
return unittest.skipUnless(is_timm_available(), "test requires the timm library")(test_case) |
|
|
|
|
|
def require_bnb(test_case): |
|
""" |
|
    Decorator marking a test that requires bitsandbytes. These tests are skipped when it is not installed.
|
""" |
|
return unittest.skipUnless(is_bnb_available(), "test requires the bitsandbytes library")(test_case) |
|
|
|
|
|
def require_tpu(test_case): |
|
""" |
|
Decorator marking a test that requires TPUs. These tests are skipped when there are no TPUs available. |
|
""" |
|
return unittest.skipUnless(is_tpu_available(), "test requires TPU")(test_case) |
|
|
|
|
|
def require_single_gpu(test_case): |
|
""" |
|
    Decorator marking a test that requires CUDA on exactly one GPU. These tests are skipped when no GPU is
    available or when more than one GPU is available.
|
""" |
|
    return unittest.skipUnless(torch.cuda.device_count() == 1, "test requires a single GPU")(test_case)
|
|
|
|
|
def require_single_xpu(test_case): |
|
""" |
|
    Decorator marking a test that requires exactly one XPU. These tests are skipped when no XPU is available or
    when more than one XPU is available.
|
""" |
|
    return unittest.skipUnless(torch.xpu.device_count() == 1, "test requires a single XPU")(test_case)
|
|
|
|
|
def require_multi_gpu(test_case): |
|
""" |
|
Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple |
|
GPUs. |
|
""" |
|
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) |
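
# Usage sketch (hypothetical test): these markers compose like ordinary unittest decorators, so a slow,
# multi-GPU-only test can be written as:
#
#     @slow
#     @require_multi_gpu
#     def test_distributed_data_parallel(self):
#         ...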
|
|
|
|
|
def require_multi_xpu(test_case): |
|
""" |
|
Decorator marking a test that requires a multi-XPU setup. These tests are skipped on a machine without multiple |
|
XPUs. |
|
""" |
|
return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case) |
|
|
|
|
|
def require_safetensors(test_case): |
|
""" |
|
Decorator marking a test that requires safetensors installed. These tests are skipped when safetensors isn't |
|
installed |
|
""" |
|
return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case) |
|
|
|
|
|
def require_deepspeed(test_case): |
|
""" |
|
Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed |
|
""" |
|
return unittest.skipUnless(is_deepspeed_available(), "test requires DeepSpeed")(test_case) |
|
|
|
|
|
def require_fsdp(test_case): |
|
""" |
|
    Decorator marking a test that requires FSDP. These tests are skipped when the installed torch version is older
    than 1.12.0, the minimum version needed for FSDP support.
|
""" |
|
return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case) |
|
|
|
|
|
def require_torch_min_version(test_case=None, version=None): |
|
""" |
|
Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an |
|
installed torch version is less than the required one. |
|
""" |
|
if test_case is None: |
|
return partial(require_torch_min_version, version=version) |
|
return unittest.skipUnless(is_torch_version(">=", version), f"test requires torch version >= {version}")(test_case) |
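
# Usage sketch (hypothetical test): the decorator takes the version as a keyword argument, either applied
# directly or pre-bound with `functools.partial` for reuse across a module:
#
#     @require_torch_min_version(version="1.12.0")
#     def test_feature_needing_recent_torch(self):
#         ...
#
#     require_torch_1_13 = partial(require_torch_min_version, version="1.13.0")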
|
|
|
|
|
def require_tensorboard(test_case): |
|
""" |
|
Decorator marking a test that requires tensorboard installed. These tests are skipped when tensorboard isn't |
|
installed |
|
""" |
|
return unittest.skipUnless(is_tensorboard_available(), "test requires Tensorboard")(test_case) |
|
|
|
|
|
def require_wandb(test_case): |
|
""" |
|
Decorator marking a test that requires wandb installed. These tests are skipped when wandb isn't installed |
|
""" |
|
return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case) |
|
|
|
|
|
def require_comet_ml(test_case): |
|
""" |
|
Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed |
|
""" |
|
return unittest.skipUnless(is_comet_ml_available(), "test requires comet_ml")(test_case) |
|
|
|
|
|
_atleast_one_tracker_available = ( |
|
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available() |
|
) |
|
|
|
|
|
def require_trackers(test_case): |
|
""" |
|
Decorator marking that a test requires at least one tracking library installed. These tests are skipped when none |
|
are installed |
|
""" |
|
return unittest.skipUnless( |
|
_atleast_one_tracker_available, |
|
"test requires at least one tracker to be available and for `comet_ml` to not be installed", |
|
)(test_case) |
|
|
|
|
|
class TempDirTestCase(unittest.TestCase): |
|
""" |
|
    A TestCase class that keeps a single temporary directory open for the duration of the class, wipes its data at
    the start of each test, and then destroys it at the end of the TestCase.

    Useful for when a class or API requires a single constant folder throughout its use, such as Weights and Biases.
|
|
|
The temporary directory location will be stored in `self.tmpdir` |
|
""" |
|
|
|
clear_on_setup = True |
|
|
|
@classmethod |
|
def setUpClass(cls): |
|
"Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`" |
|
cls.tmpdir = tempfile.mkdtemp() |
|
|
|
@classmethod |
|
def tearDownClass(cls): |
|
"Remove `cls.tmpdir` after test suite has finished" |
|
if os.path.exists(cls.tmpdir): |
|
shutil.rmtree(cls.tmpdir) |
|
|
|
def setUp(self): |
|
"Destroy all contents in `self.tmpdir`, but not `self.tmpdir`" |
|
if self.clear_on_setup: |
|
for path in Path(self.tmpdir).glob("**/*"): |
|
if path.is_file(): |
|
path.unlink() |
|
elif path.is_dir(): |
|
shutil.rmtree(path) |
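
# Usage sketch (hypothetical subclass): tests share `self.tmpdir` for the lifetime of the class, and its
# contents are wiped before each test unless `clear_on_setup` is set to False.
#
#     class TrackerOutputTester(TempDirTestCase):
#         def test_writes_log_file(self):
#             log_file = os.path.join(self.tmpdir, "run.log")
#             ...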
|
|
|
|
|
class AccelerateTestCase(unittest.TestCase): |
|
""" |
|
A TestCase class that will reset the accelerator state at the end of every test. Every test that checks or utilizes |
|
the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between |
|
tests. |
|
""" |
|
|
|
def tearDown(self): |
|
super().tearDown() |
|
|
|
AcceleratorState._reset_state() |
|
PartialState._reset_state() |
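
# Usage sketch (hypothetical subclass): inheriting from `AccelerateTestCase` guarantees that any
# `AcceleratorState`/`PartialState` configured inside one test cannot leak into the next one.
#
#     class StateTester(AccelerateTestCase):
#         def test_state_is_fresh(self):
#             state = PartialState()
#             ...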
|
|
|
|
|
class MockingTestCase(unittest.TestCase): |
|
""" |
|
    A TestCase class designed to dynamically add mocks that should be applied in every test, mimicking the behavior
    of a class-wide mock when defining one normally will not do.

    Useful when a mock requires information that only becomes available after `TestCase.setUpClass`, such as setting
    an environment variable based on that information.

    The `add_mocks` method should be called at the end of a `TestCase`'s `setUp` method, after a call to
    `super().setUp()`, for example:
|
```python |
|
def setUp(self): |
|
super().setUp() |
|
        mocks = mock.patch.dict(os.environ, {"SOME_ENV_VAR": "SOME_VALUE"})
|
self.add_mocks(mocks) |
|
``` |
|
""" |
|
|
|
def add_mocks(self, mocks: Union[mock.Mock, List[mock.Mock]]): |
|
""" |
|
Add custom mocks for tests that should be repeated on each test. Should be called during |
|
`MockingTestCase.setUp`, after `super().setUp()`. |
|
|
|
Args: |
|
mocks (`mock.Mock` or list of `mock.Mock`): |
|
Mocks that should be added to the `TestCase` after `TestCase.setUpClass` has been run |
|
""" |
|
self.mocks = mocks if isinstance(mocks, (tuple, list)) else [mocks] |
|
for m in self.mocks: |
|
m.start() |
|
self.addCleanup(m.stop) |
|
|
|
|
|
def are_the_same_tensors(tensor): |
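    "Checks that `tensor` holds the same values on every process by gathering it across processes."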
|
state = AcceleratorState() |
|
tensor = tensor[None].clone().to(state.device) |
|
tensors = gather(tensor).cpu() |
|
tensor = tensor[0].cpu() |
|
for i in range(tensors.shape[0]): |
|
if not torch.equal(tensors[i], tensor): |
|
return False |
|
return True |
|
|
|
|
|
class _RunOutput: |
|
def __init__(self, returncode, stdout, stderr): |
|
self.returncode = returncode |
|
self.stdout = stdout |
|
self.stderr = stderr |
|
|
|
|
|
async def _read_stream(stream, callback): |
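    "Reads `stream` line by line, calling `callback` on each line until the stream is exhausted."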
|
while True: |
|
line = await stream.readline() |
|
if line: |
|
callback(line) |
|
else: |
|
break |
|
|
|
|
|
async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput: |
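    "Runs `cmd` in an asyncio subprocess, streaming and capturing its stdout and stderr, and returns a `_RunOutput`."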
|
if echo: |
|
print("\nRunning: ", " ".join(cmd)) |
|
|
|
p = await asyncio.create_subprocess_exec( |
|
cmd[0], |
|
*cmd[1:], |
|
stdin=stdin, |
|
stdout=asyncio.subprocess.PIPE, |
|
stderr=asyncio.subprocess.PIPE, |
|
env=env, |
|
) |
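
    # Read stdout and stderr concurrently so worker output is surfaced live rather than only once the process exits.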
|
|
out = [] |
|
err = [] |
|
|
|
def tee(line, sink, pipe, label=""): |
|
line = line.decode("utf-8").rstrip() |
|
sink.append(line) |
|
if not quiet: |
|
print(label, line, file=pipe) |
|
|
|
|
|
await asyncio.wait( |
|
[ |
|
asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))), |
|
asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))), |
|
], |
|
timeout=timeout, |
|
) |
|
return _RunOutput(await p.wait(), out, err) |
|
|
|
|
|
def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput: |
|
loop = asyncio.get_event_loop() |
|
result = loop.run_until_complete( |
|
_stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo) |
|
) |
|
|
|
cmd_str = " ".join(cmd) |
|
if result.returncode > 0: |
|
stderr = "\n".join(result.stderr) |
|
raise RuntimeError( |
|
f"'{cmd_str}' failed with returncode {result.returncode}\n\n" |
|
f"The combined stderr from workers follows:\n{stderr}" |
|
) |
|
|
|
return result |
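
# Usage sketch (the script path is hypothetical): commands are passed as argument lists, and a non-zero
# return code raises a `RuntimeError` that carries the combined worker stderr.
#
#     execute_subprocess_async(
#         ["accelerate", "launch", "--num_processes", "2", "tests/scripts/some_test_script.py"],
#         env=os.environ.copy(),
#     )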
|
|
|
|
|
class SubprocessCallException(Exception): |
|
pass |
|
|
|
|
|
def run_command(command: List[str], return_stdout=False): |
|
""" |
|
    Runs `command` with `subprocess.check_output` and optionally returns the decoded `stdout`. Also properly captures
    and reports any error that occurred while running `command`.
|
""" |
|
try: |
|
output = subprocess.check_output(command, stderr=subprocess.STDOUT) |
|
if return_stdout: |
|
if hasattr(output, "decode"): |
|
output = output.decode("utf-8") |
|
return output |
|
except subprocess.CalledProcessError as e: |
|
raise SubprocessCallException( |
|
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" |
|
) from e |
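
# Usage sketch: `run_command` raises `SubprocessCallException` with the captured output on failure and can
# return the decoded stdout when `return_stdout=True`.
#
#     python_version = run_command([sys.executable, "--version"], return_stdout=True)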
|
|
|
|
|
@contextmanager |
|
def assert_exception(exception_class: Type[Exception], msg: str = None):
|
""" |
|
Context manager to assert that the right `Exception` class was raised. |
|
|
|
If `msg` is provided, will check that the message is contained in the raised exception. |
|
""" |
|
was_ran = False |
|
try: |
|
yield |
|
was_ran = True |
|
except Exception as e: |
|
assert isinstance(e, exception_class), f"Expected exception of type {exception_class} but got {type(e)}" |
|
if msg is not None: |
|
assert msg in str(e), f"Expected message '{msg}' to be in exception but got '{str(e)}'" |
|
if was_ran: |
|
raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.") |
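
# Usage sketch: the context manager fails the test both when no exception is raised and when an exception of
# the wrong type (or without the expected message) is raised.
#
#     with assert_exception(ValueError, msg="must be yes or no"):
#         raise ValueError("If set, RUN_SLOW must be yes or no.")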
|
|