import torch
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable


def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Arguments:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    if len(tensors) == 1:
        return tensors[0].contiguous().view(-1)
    flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
    return flat
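
# Illustrative sketch (not used by this module): flattening two tensors of the
# same dense type yields a single 1D buffer holding their elements back to back.
#
#   a = torch.zeros(2, 3)
#   b = torch.ones(4)
#   flat = _flatten_dense_tensors([a, b])   # flat.shape == torch.Size([10])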


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Arguments:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    outputs = []
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
        offset += numel
    return tuple(outputs)
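
# Illustrative sketch (not used by this module): _unflatten_dense_tensors is the
# inverse view, so a flatten/unflatten round trip recovers tensors with the
# original shapes and values.
#
#   a, b = torch.zeros(2, 3), torch.ones(4)
#   flat = _flatten_dense_tensors([a, b])
#   a2, b2 = _unflatten_dense_tensors(flat, [a, b])
#   assert a2.shape == (2, 3) and torch.equal(b2, b)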


'''
This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
launcher included with this example. It assumes that your run uses multiprocessing with one
GPU per process, that the model is already on the correct device, and that torch.cuda.set_device
has been used to set the device.

Parameters are broadcast to the other processes on initialization of DistributedDataParallel,
and will be allreduced at the finish of the backward pass.
'''
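
# A minimal per-process setup sketch showing the intended call order; the
# backend / init_method choice and the names local_rank and MyModel are
# placeholder assumptions, not part of this module:
#
#   torch.cuda.set_device(local_rank)
#   dist.init_process_group(backend='nccl', init_method='env://')
#   model = MyModel().cuda()
#   model = DistributedDataParallel(model)
#   # forward() arms the reduction; loss.backward() then triggers the gradient
#   # allreduce queued by the parameter hooks.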


class DistributedDataParallel(Module):

    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()

        # Half-precision allreduce over the gloo backend is very slow; warn once.
        if not hasattr(dist, '_backend'):
            self.warn_on_half = True
        else:
            self.warn_on_half = (dist._backend == dist.dist_backend.GLOO)

        self.module = module

        # Broadcast parameters and buffers from rank 0 so every process starts
        # from identical weights.
        for p in self.module.state_dict().values():
            if not torch.is_tensor(p):
                continue
            dist.broadcast(p, 0)

        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False
                # Bucket gradients by tensor type so each bucket can be
                # coalesced into a single allreduce call.
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                if self.warn_on_half:
                    if torch.cuda.HalfTensor in buckets:
                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                              " It is recommended to use the NCCL backend in this case. This currently requires" +
                              " PyTorch built from top of tree master.")
                        self.warn_on_half = False

                # Allreduce each bucket as one flat buffer, average by world
                # size, and copy the result back into the individual gradients.
                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        # Queue the allreduce to run once at the end of each backward pass,
        # after all gradients for this process have been computed.
        for param in list(self.module.parameters()):
            def allreduce_hook(*unused):
                Variable._execution_engine.queue_callback(allreduce_params)
            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def forward(self, *inputs, **kwargs):
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)

    '''
    def _sync_buffers(self):
        buffers = list(self.module._all_buffers())
        if len(buffers) > 0:
            # cross-node buffer sync
            flat_buffers = _flatten_dense_tensors(buffers)
            dist.broadcast(flat_buffers, 0)
            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
                buf.copy_(synced)

    def train(self, mode=True):
        # Clear the NCCL communicator and CUDA event cache of the default group ID.
        # These caches will be recreated on a later call. This is currently a
        # work-around for a potential NCCL deadlock.
        if dist._backend == dist.dist_backend.NCCL:
            dist._clear_group_cache()
        super(DistributedDataParallel, self).train(mode)
        self.module.train(mode)
    '''


'''
Modifies an existing model in place to allreduce its gradients, without changing
its class, so you don't need to go through a "module" attribute to reach it
(see the usage sketch after this function).
'''
def apply_gradient_allreduce(module):
    if not hasattr(dist, '_backend'):
        module.warn_on_half = True
    else:
        module.warn_on_half = (dist._backend == dist.dist_backend.GLOO)

    # Broadcast parameters and buffers from rank 0 so every process starts
    # from identical weights.
    for p in module.state_dict().values():
        if not torch.is_tensor(p):
            continue
        dist.broadcast(p, 0)

    def allreduce_params():
        if module.needs_reduction:
            module.needs_reduction = False
            # Bucket gradients by dtype so each bucket can be coalesced into a
            # single allreduce call.
            buckets = {}
            for param in module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.dtype
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
            if module.warn_on_half:
                if torch.float16 in buckets:
                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                          " It is recommended to use the NCCL backend in this case. This currently requires" +
                          " PyTorch built from top of tree master.")
                    module.warn_on_half = False

            # Allreduce each bucket as one flat buffer, average by world size,
            # and copy the result back into the individual gradients.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                dist.all_reduce(coalesced)
                coalesced /= dist.get_world_size()
                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                    buf.copy_(synced)

    # Queue the allreduce to run once at the end of each backward pass, after
    # all gradients for this process have been computed.
    for param in list(module.parameters()):
        def allreduce_hook(*unused):
            Variable._execution_engine.queue_callback(allreduce_params)
        if param.requires_grad:
            param.register_hook(allreduce_hook)

    # A forward pass arms the reduction for the next backward pass.
    def set_needs_reduction(self, input, output):
        self.needs_reduction = True

    module.register_forward_hook(set_needs_reduction)
    return module
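
# A minimal usage sketch, assuming the process group is already initialized as
# above; MyModel is a placeholder. Unlike the DistributedDataParallel wrapper,
# the returned object is the original module, so checkpoints and attribute
# access need no "module." prefix.
#
#   model = MyModel().cuda()
#   model = apply_gradient_allreduce(model)
#   # train as usual; backward() averages gradients across processes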
|