import torch
from torch.nn import DataParallel
from torch.cuda._utils import _get_device_index
from torch.nn.parallel._functions import Scatter
from itertools import chain

def scatter_imbalance(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            # Hand-tuned imbalanced splits for specific total sizes along `dim`:
            # the first GPU always receives a smaller chunk, since device_ids[0]
            # also gathers the outputs and carries extra memory load.
            if (len(target_gpus) == 4) and (obj.size(dim) == 22):
                return Scatter.apply(target_gpus, (4, 6, 6, 6), dim, obj)
            elif (len(target_gpus) == 4) and (obj.size(dim) == 60):
                return Scatter.apply(target_gpus, (12, 16, 16, 16), dim, obj)
            elif (len(target_gpus) == 4) and (obj.size(dim) == 144):
                return Scatter.apply(target_gpus, (24, 40, 40, 40), dim, obj)
            elif (len(target_gpus) == 8) and (obj.size(dim) == 46):
                return Scatter.apply(target_gpus, (4, 6, 6, 6, 6, 6, 6, 6), dim, obj)
            elif (len(target_gpus) == 8) and (obj.size(dim) == 62):
                return Scatter.apply(target_gpus, (6, 8, 8, 8, 8, 8, 8, 8), dim, obj)
            elif (len(target_gpus) == 8) and (obj.size(dim) == 94):
                return Scatter.apply(target_gpus, (10, 12, 12, 12, 12, 12, 12, 12), dim, obj)
            elif (len(target_gpus) == 8) and (obj.size(dim) == 110):
                return Scatter.apply(target_gpus, (12, 14, 14, 14, 14, 14, 14, 14), dim, obj)
            elif (len(target_gpus) == 8) and (obj.size(dim) == 118):
                return Scatter.apply(target_gpus, (13, 15, 15, 15, 15, 15, 15, 15), dim, obj)
            elif (len(target_gpus) == 8) and (obj.size(dim) == 126):
                return Scatter.apply(target_gpus, (14, 16, 16, 16, 16, 16, 16, 16), dim, obj)
            elif (len(target_gpus) == 8) and (obj.size(dim) == 134):
                return Scatter.apply(target_gpus, (15, 17, 17, 17, 17, 17, 17, 17), dim, obj)
            elif (len(target_gpus) == 8) and (obj.size(dim) == 142):
                return Scatter.apply(target_gpus, (16, 18, 18, 18, 18, 18, 18, 18), dim, obj)
            elif (len(target_gpus) == 16) and (obj.size(dim) == 222):
                return Scatter.apply(target_gpus, (12, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14), dim, obj)
            # Any other size falls back to an (approximately) even split.
            return Scatter.apply(target_gpus, None, dim, obj)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        # Non-tensor, non-container objects are replicated by reference.
        return [obj for _ in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
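
# Illustrative helper, not part of the original module: a minimal sketch of how
# scatter_imbalance splits a single tensor. The shapes and the [0, 1] device
# list are assumptions for the example; a size along `dim` that matches none of
# the hand-tuned cases above falls through to the even split, so a batch of 8
# on two GPUs becomes two chunks of 4.
def _demo_scatter_imbalance():
    assert torch.cuda.device_count() >= 2, "example assumes two CUDA devices"
    batch = torch.randn(8, 16, device="cuda:0")
    scattered = scatter_imbalance((batch,), [0, 1], dim=0)
    # scattered is a list with one entry per GPU, each a 1-tuple holding that
    # GPU's chunk: [(4x16 tensor on cuda:0,), (4x16 tensor on cuda:1,)]
    for (chunk,) in scattered:
        print(chunk.shape, chunk.device)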

def scatter_kwargs_imbalance(inputs, kwargs, target_gpus, dim=0):
    r"""Scatter with support for kwargs dictionary."""
    inputs = scatter_imbalance(inputs, target_gpus, dim) if inputs else []
    kwargs = scatter_imbalance(kwargs, target_gpus, dim) if kwargs else []
    # Pad the shorter list so every replica receives both a positional-argument
    # tuple and a keyword-argument dict (possibly empty).
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
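
# Illustrative helper, not part of the original module: a sketch of how
# positional inputs and keyword arguments are scattered together. The tensors
# and the [0, 1] device list are assumptions for the example.
def _demo_scatter_kwargs_imbalance():
    assert torch.cuda.device_count() >= 2, "example assumes two CUDA devices"
    x = torch.randn(8, 16, device="cuda:0")
    mask = torch.ones(8, device="cuda:0")
    ins, kws = scatter_kwargs_imbalance((x,), {"mask": mask}, [0, 1], dim=0)
    # ins == ((4x16 chunk on cuda:0,), (4x16 chunk on cuda:1,))
    # kws == ({"mask": first 4 entries on cuda:0}, {"mask": last 4 on cuda:1})
    for args, kwargs in zip(ins, kws):
        print(args[0].shape, kwargs["mask"].shape, args[0].device)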

class DataParallelImbalance(DataParallel):
    r"""DataParallel variant whose scatter step uses ``scatter_imbalance``,
    so that for the hard-coded batch sizes device_ids[0] receives a smaller
    chunk than the remaining GPUs (and an even split otherwise)."""

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallelImbalance, self).__init__(
            module, device_ids, output_device, dim)

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        if not all(t.is_cuda and t.device.index == device_ids[0]
                   for t in chain(module.parameters(), module.buffers())):
            raise RuntimeError("module must have its parameters and buffers "
                               "on device %d (device_ids[0])" % device_ids[0])

        self.dim = dim
        self.module = module
        self.device_ids = list(
            map(lambda x: _get_device_index(x, True), device_ids))
        self.output_device = _get_device_index(output_device, True)
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter_imbalance(
            inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def scatter_imbalance(self, inputs, kwargs, device_ids):
        return scatter_kwargs_imbalance(inputs, kwargs, device_ids, dim=self.dim)
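
# Minimal end-to-end usage sketch, not part of the original module. It assumes
# at least one visible CUDA device; the toy torch.nn.Linear model and the
# tensor shapes are illustrative only.
if __name__ == "__main__" and torch.cuda.is_available():
    model = DataParallelImbalance(torch.nn.Linear(16, 4).cuda(0))
    x = torch.randn(32, 16, device="cuda:0")
    y = model(x)  # scattered across model.device_ids, gathered on cuda:0
    print(y.shape)  # -> torch.Size([32, 4])
    if torch.cuda.device_count() >= 2:
        _demo_scatter_imbalance()
        _demo_scatter_kwargs_imbalance()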