# -*- coding: utf-8 -*-
import collections.abc

import torch
import torch.cuda as cuda
import torch.nn as nn
from torch.nn.parallel._functions import Gather

__all__ = ['UserScatteredDataParallel', 'user_scattered_collate',
           'async_copy_to']


def async_copy_to(obj, dev, main_stream=None):
    """Recursively copies tensors in ``obj`` to device ``dev`` asynchronously.

    If ``main_stream`` is given, the copy is recorded on it so the tensor's
    memory is not reused before the main stream has consumed it.
    """
    if torch.is_tensor(obj):
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            v.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        # The str guard avoids infinite recursion: a str is a Sequence whose
        # elements are themselves strs.
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj


def dict_gather(outputs, target_device, dim=0):
    """
    Gathers tensors from different GPUs on a specified device
    (-1 means the CPU), with support for dicts and nested sequences.
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            # MJY(20180330) HACK: Gather needs at least one dimension,
            # so promote 0-dim tensors before gathering.
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        elif isinstance(out, collections.abc.Mapping):
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        elif isinstance(out, collections.abc.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs)


class DictGatherDataParallel(nn.DataParallel):
    """DataParallel variant whose gather step understands dict outputs."""

    def gather(self, outputs, output_device):
        return dict_gather(outputs, output_device, dim=self.dim)


class UserScatteredDataParallel(DictGatherDataParallel):
    """DataParallel variant that skips automatic scattering.

    The caller passes a single list containing one already-batched input per
    GPU; ``scatter`` only copies each element to its device asynchronously.
    """

    def scatter(self, inputs, kwargs, device_ids):
        assert len(inputs) == 1
        inputs = inputs[0]
        inputs = _async_copy_stream(inputs, device_ids)
        inputs = [[i] for i in inputs]
        assert len(kwargs) == 0
        kwargs = [{} for _ in range(len(inputs))]

        return inputs, kwargs


def user_scattered_collate(batch):
    """Collate function that returns the batch unmodified (a list of
    samples), leaving the per-GPU splitting to the caller."""
    return batch


def _async_copy(inputs, device_ids):
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    for i, dev in zip(inputs, device_ids):
        with cuda.device(dev):
            outputs.append(async_copy_to(i, dev))

    return tuple(outputs)


def _async_copy_stream(inputs, device_ids):
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    streams = [_get_stream(d) for d in device_ids]
    for i, dev, stream in zip(inputs, device_ids, streams):
        with cuda.device(dev):
            main_stream = cuda.current_stream()
            with cuda.stream(stream):
                outputs.append(async_copy_to(i, dev, main_stream=main_stream))
            # Make the main stream wait for the copy stream, so later kernels
            # only run once the transfer has finished.
            main_stream.wait_stream(stream)

    return outputs


# Adapted from: torch/nn/parallel/_functions.py

# background streams used for copying
_streams = None


def _get_stream(device):
    """Gets a background stream for copying between CPU and GPU."""
    global _streams
    if device == -1:
        return None
    if _streams is None:
        _streams = [None] * cuda.device_count()
    if _streams[device] is None:
        _streams[device] = cuda.Stream(device)
    return _streams[device]
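

# --- Usage sketch (illustrative, not part of the original module) ----------
# A minimal end-to-end example of the intended wiring: user_scattered_collate
# keeps the DataLoader batch as a plain list of per-sample dicts, and
# UserScatteredDataParallel then copies one list element to each GPU and
# gathers the dict outputs.  ``TinyModel``, the 'img'/'pred' keys, and the
# two-GPU setup are assumptions made for this demo, not part of the module.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    class TinyModel(nn.Module):
        """Toy module that consumes one of the scattered per-GPU dicts."""
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 4)

        def forward(self, feed):
            return {'pred': self.fc(feed['img'])}

    # Each dataset item is already a dict; the collate function returns the
    # list of samples unmodified instead of stacking them into one tensor.
    dataset = [{'img': torch.randn(3, 8)} for _ in range(8)]
    loader = DataLoader(dataset, batch_size=2,
                        collate_fn=user_scattered_collate)

    model = UserScatteredDataParallel(TinyModel().cuda(), device_ids=[0, 1])

    for batch in loader:
        # len(batch) must equal len(device_ids): each dict is copied to its
        # GPU on a background stream, and outputs are merged by dict_gather.
        out = model(batch)
        print(out['pred'].shape)  # torch.Size([6, 4]) after gathering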