import threading
from typing import Any, cast, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch._utils import ExceptionWrapper
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast

from ..modules import Module

__all__ = ['get_a_var', 'parallel_apply']


def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None
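
# Illustrative note (descriptive comment, not part of the upstream API):
# ``get_a_var`` walks nested lists/tuples/dicts depth-first and returns the
# first tensor it finds; ``parallel_apply`` uses it to infer a device when
# none is given. Hypothetical inputs, assuming a CUDA device is available:
#
#     x = torch.zeros(1, device="cuda:0")
#     get_a_var([("label", 3), {"feat": x}])   # returns ``x``
#     get_a_var(["no", "tensors", "here"])     # returns ``None``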


def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`. | |
Args: | |
modules (Module): modules to be parallelized | |
inputs (tensor): inputs to the modules | |
devices (list of int or torch.device): CUDA devices | |
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and | |
:attr:`devices` (if given) should all have same length. Moreover, each | |
element of :attr:`inputs` can either be a single object as the only argument | |
to a module, or a collection of positional arguments. | |
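
    Example (a minimal illustrative sketch, assuming two CUDA devices are
    available; the replica and tensor names below are hypothetical)::

        >>> replicas = [torch.nn.Linear(4, 2).cuda(0), torch.nn.Linear(4, 2).cuda(1)]
        >>> inputs = [torch.randn(8, 4, device='cuda:0'), torch.randn(8, 4, device='cuda:1')]
        >>> outputs = parallel_apply(replicas, inputs, devices=[0, 1])
        >>> [tuple(out.shape) for out in outputs]
        [(8, 2), (8, 2)]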
""" | |
    assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
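
    # Note (descriptive comment, not in the original source): grad mode and
    # autocast state are thread-local in PyTorch, so they are read here on the
    # calling thread and re-applied inside each worker thread below via
    # ``torch.set_grad_enabled`` and ``autocast(enabled=...)``.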
    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            t = get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs