|
|
|
|
|
import logging |
|
from contextlib import contextmanager |
|
from functools import wraps |
|
import torch |
|
|
|
# Public API of this module; underscore-prefixed helpers are internal.
__all__ = ["retry_if_cuda_oom"]
|
|
|
|
|
@contextmanager |
|
def _ignore_torch_cuda_oom(): |
|
""" |
|
A context which ignores CUDA OOM exception from pytorch. |
|
""" |
|
try: |
|
yield |
|
except RuntimeError as e: |
|
|
|
if "CUDA out of memory. " in str(e): |
|
pass |
|
else: |
|
raise |
|
|
|
|
|
def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.

    The wrapped callable makes up to three attempts:

    1. Call ``func`` with the original arguments.
    2. On CUDA OOM, call ``torch.cuda.empty_cache()`` and call ``func`` again.
    3. On a second CUDA OOM, copy top-level CUDA tensor-like arguments to CPU
       and call ``func`` a final time. Any exception from this last call,
       including another OOM, propagates to the caller.

    In the third case, it expects the function to dispatch to a CPU
    implementation. The return values may become CPU tensors as well, and it
    is the user's responsibility to convert them back to CUDA tensors if
    needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.
    """

    def maybe_to_cpu(x):
        # Duck-typed check: an object whose ``.device.type`` is "cuda" and which
        # has a ``.to`` method is copied to CPU. Objects lacking ``.device`` (or
        # whose device lacks ``.type``) are returned unchanged.
        try:
            like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
        except AttributeError:
            like_gpu_tensor = False
        if like_gpu_tensor:
            return x.to(device="cpu")
        return x

    @wraps(func)
    def wrapped(*args, **kwargs):
        # Attempt 1: plain call. A CUDA OOM is swallowed and we fall through.
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 2: release unoccupied memory cached by PyTorch's allocator,
        # then retry with the same arguments.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Attempt 3: move top-level CUDA inputs to CPU. Not guarded, so any
        # failure here reaches the caller.
        logger = logging.getLogger(__name__)
        # Lazy %-style args: formatting happens only if the record is emitted.
        logger.info("Attempting to copy inputs of %s to CPU due to CUDA OOM", func)
        new_args = tuple(maybe_to_cpu(x) for x in args)
        new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    return wrapped
|
|