| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Encoding Data Parallel""" |
| | import threading |
| | import functools |
| | import torch |
| | from torch.autograd import Variable, Function |
| | import torch.cuda.comm as comm |
| | from torch.nn.parallel.data_parallel import DataParallel |
| | from torch.nn.parallel.parallel_apply import get_a_var |
| | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast |
| |
|
# First three characters of the torch version string (e.g. "0.3", "1.0");
# used below to skip grad-mode handling that torch 0.3 does not support.
torch_ver = torch.__version__[:3]


# Public API of this module. NOTE(review): `patch_replication_callback` is not
# defined in this chunk — presumably defined elsewhere in the file; verify.
__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 'patch_replication_callback']
| |
|
def allreduce(*inputs):
    """Cross-GPU all-reduce autograd operation.

    Thin functional wrapper around :class:`AllReduce`; used to average
    per-device statistics (mean and variance) in SyncBN.
    """
    reduced = AllReduce.apply(*inputs)
    return reduced
| |
|
class AllReduce(Function):
    """Autograd all-reduce across GPUs.

    Forward sums groups of ``num_inputs`` tensors (one group per device) and
    broadcasts the summed result back to every participating device; backward
    applies the same reduce-then-broadcast to the incoming gradients.
    """

    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        # Stash the group size and one representative device per group so
        # backward can repeat the same reduce/broadcast pattern.
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[j].get_device()
                           for j in range(0, len(inputs), num_inputs)]
        grouped = [inputs[j:j + num_inputs]
                   for j in range(0, len(inputs), num_inputs)]
        # Deterministic device ordering before the coalesced reduction.
        grouped.sort(key=lambda group: group[0].get_device())
        reduced = comm.reduce_add_coalesced(grouped, ctx.target_gpus[0])
        broadcast = comm.broadcast_coalesced(reduced, ctx.target_gpus)
        # Flatten the per-device tensor groups back into one tuple.
        return tuple(t for group in broadcast for t in group)

    @staticmethod
    def backward(ctx, *inputs):
        grads = [g.data for g in inputs]
        grouped = [grads[j:j + ctx.num_inputs]
                   for j in range(0, len(grads), ctx.num_inputs)]
        reduced = comm.reduce_add_coalesced(grouped, ctx.target_gpus[0])
        broadcast = comm.broadcast_coalesced(reduced, ctx.target_gpus)
        # Leading None matches the non-tensor `num_inputs` forward argument.
        return (None,) + tuple(Variable(t) for group in broadcast for t in group)
| |
|
class Reduce(Function):
    """Autograd op summing tensors that live on several GPUs onto one device.

    Backward broadcasts the gradient back to every source device.
    """

    @staticmethod
    def forward(ctx, *inputs):
        # Remember the source device of every input so backward knows
        # where to broadcast the gradient.
        ctx.target_gpus = [t.get_device() for t in inputs]
        ordered = sorted(inputs, key=lambda t: t.get_device())
        return comm.reduce_add(ordered)

    @staticmethod
    def backward(ctx, gradOutput):
        return Broadcast.apply(ctx.target_gpus, gradOutput)
| |
|
| |
|
class DataParallelModel(DataParallel):
    """Module-level data parallelism that keeps outputs on their GPUs.

    Input batches are scattered across ``device_ids`` along the batch
    dimension; the module is replicated per device and each replica handles
    its chunk. Unlike the stock :class:`~torch.nn.DataParallel`, the
    per-device outputs are NOT gathered back to a single device — pair this
    with :class:`encoding.parallel.DataParallelCriterion`, which consumes
    them in place. Gradients from all replicas are summed into the original
    module during backward.

    The batch size should exceed (and ideally be a multiple of) the number
    of GPUs so every device receives an equally sized chunk.

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation.
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """

    def gather(self, outputs, output_device):
        # Intentionally a no-op: leave each output on its own GPU for the
        # companion DataParallelCriterion.
        return outputs

    def replicate(self, module, device_ids):
        # Plain delegation to the stock DataParallel replication.
        return super(DataParallelModel, self).replicate(module, device_ids)
| |
|
| |
|
class DataParallelCriterion(DataParallel):
    """Criterion wrapper that computes the loss on each GPU separately.

    Targets are scattered across ``device_ids`` along the batch dimension and
    matched with the per-GPU outputs produced by
    :class:`encoding.parallel.DataParallelModel`; the per-device losses are
    then reduced to their mean. Computing losses in place balances memory
    usage for semantic segmentation.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang,
        Ambrish Tyagi, Amit Agrawal. "Context Encoding for Semantic Segmentation.
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::

        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """

    def forward(self, inputs, *targets, **kwargs):
        # `inputs` are the per-GPU outputs of DataParallelModel (not scattered
        # here); only the targets/kwargs need scattering.
        if not self.device_ids:
            # No devices configured: run the wrapped criterion directly.
            return self.module(inputs, *targets, **kwargs)
        scattered_targets, scattered_kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            # Single device: no replication needed.
            return self.module(inputs, *scattered_targets[0], **scattered_kwargs[0])
        criterion_replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        losses = _criterion_parallel_apply(criterion_replicas, inputs,
                                           scattered_targets, scattered_kwargs)
        # Average the per-GPU losses into a single value.
        return Reduce.apply(*losses) / len(losses)
| |
|
| |
|
| | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): |
| | assert len(modules) == len(inputs) |
| | assert len(targets) == len(inputs) |
| | if kwargs_tup: |
| | assert len(modules) == len(kwargs_tup) |
| | else: |
| | kwargs_tup = ({},) * len(modules) |
| | if devices is not None: |
| | assert len(modules) == len(devices) |
| | else: |
| | devices = [None] * len(modules) |
| |
|
| | lock = threading.Lock() |
| | results = {} |
| | if torch_ver != "0.3": |
| | grad_enabled = torch.is_grad_enabled() |
| |
|
| | def _worker(i, module, input, target, kwargs, device=None): |
| | if torch_ver != "0.3": |
| | torch.set_grad_enabled(grad_enabled) |
| | if device is None: |
| | device = get_a_var(input).get_device() |
| | try: |
| | if not isinstance(input, tuple): |
| | input = (input,) |
| | with torch.cuda.device(device): |
| | output = module(*(input + target), **kwargs) |
| | with lock: |
| | results[i] = output |
| | except Exception as e: |
| | with lock: |
| | results[i] = e |
| |
|
| | if len(modules) > 1: |
| | threads = [threading.Thread(target=_worker, |
| | args=(i, module, input, target, |
| | kwargs, device),) |
| | for i, (module, input, target, kwargs, device) in |
| | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] |
| |
|
| | for thread in threads: |
| | thread.start() |
| | for thread in threads: |
| | thread.join() |
| | else: |
| | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) |
| |
|
| | outputs = [] |
| | for i in range(len(inputs)): |
| | output = results[i] |
| | if isinstance(output, Exception): |
| | raise output |
| | outputs.append(output) |
| | return outputs |
| |
|