import warnings

import torch
from . import comm
from torch.autograd import Function
from torch._utils import _get_device_index
from typing import List, Optional
class Broadcast(Function):

    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        assert all(i.device.type != 'cpu' for i in inputs), (
            'Broadcast function not implemented for CPU tensors'
        )
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.target_gpus = target_gpus
        if len(inputs) == 0:
            return tuple()
        ctx.num_inputs = len(inputs)
        ctx.input_device = inputs[0].get_device()
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        non_differentiables = []
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                for output in outputs:
                    non_differentiables.append(output[idx])
        ctx.mark_non_differentiable(*non_differentiables)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None,) + ReduceAddCoalesced.apply(ctx.input_device, ctx.num_inputs, *grad_outputs)
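
# Illustrative usage sketch (not part of the original module): Broadcast.apply
# copies each input tensor to every listed device and, on backward, reduce-adds
# the replica gradients back onto the source device via ReduceAddCoalesced.
# Assumes at least two CUDA devices; the helper name `_example_broadcast` is
# hypothetical.
def _example_broadcast():
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return
    x = torch.randn(4, device="cuda:0", requires_grad=True)
    # One input, two target devices -> a flat tuple (copy on 0, copy on 1).
    x0, x1 = Broadcast.apply([0, 1], x)
    (x0.sum() + x1.sum()).backward()
    # x.grad is 2.0 everywhere: the per-replica gradients were summed on cuda:0.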
class ReduceAddCoalesced(Function):

    @staticmethod
    def forward(ctx, destination, num_inputs, *grads):
        ctx.target_gpus = [grads[i].get_device() for i in range(0, len(grads), num_inputs)]
        grads_ = [grads[i:i + num_inputs]
                  for i in range(0, len(grads), num_inputs)]
        return comm.reduce_add_coalesced(grads_, destination)

    @staticmethod
    def backward(ctx, *grad_outputs):
        return (None, None,) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
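
# Illustrative usage sketch (not part of the original module): the flat *grads
# argument is laid out device by device, num_inputs tensors per device, and the
# per-input sums come back on `destination`. Assumes two CUDA devices; the
# helper name is hypothetical.
def _example_reduce_add_coalesced():
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return
    g0 = torch.ones(3, device="cuda:0")
    g1 = torch.ones(3, device="cuda:1")
    # num_inputs=1, so the layout is [dev0_grad, dev1_grad]; result on device 0.
    (total,) = ReduceAddCoalesced.apply(0, 1, g0, g1)
    # total equals g0 + g1 and lives on cuda:0.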
class Gather(Function):

    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        assert all(i.device.type != 'cpu' for i in inputs), (
            'Gather function not implemented for CPU tensors'
        )
        if target_device == 'cpu':
            ctx.target_device = 'cpu'
        else:
            target_device = _get_device_index(target_device, True)
            ctx.target_device = target_device
        ctx.dim = dim
        ctx.input_gpus = tuple(i.get_device() for i in inputs)
        if all(t.dim() == 0 for t in inputs) and dim == 0:
            inputs = tuple(t.view(1) for t in inputs)
            warnings.warn('Was asked to gather along dimension 0, but all '
                          'input tensors were scalars; will instead unsqueeze '
                          'and return a vector.')
            ctx.unsqueezed_scalar = True
        else:
            ctx.unsqueezed_scalar = False
        ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
        return comm.gather(inputs, ctx.dim, ctx.target_device)

    @staticmethod
    def backward(ctx, grad_output):
        scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
        if ctx.unsqueezed_scalar:
            scattered_grads = tuple(g[0] for g in scattered_grads)
        return (None, None) + scattered_grads
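
# Illustrative usage sketch (not part of the original module): Gather.apply
# concatenates per-device chunks along `dim` onto the target device; backward
# scatters the gradient back using the recorded per-input sizes. Assumes two
# CUDA devices; the helper name is hypothetical.
def _example_gather():
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return
    a = torch.randn(2, 4, device="cuda:0", requires_grad=True)
    b = torch.randn(3, 4, device="cuda:1", requires_grad=True)
    out = Gather.apply(0, 0, a, b)  # shape (5, 4) on cuda:0
    out.sum().backward()
    # a.grad and b.grad are ones tensors on their original devices.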
class Scatter(Function):

    @staticmethod
    def forward(ctx, target_gpus, chunk_sizes, dim, input):
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.dim = dim
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
        streams = None
        if torch.cuda.is_available() and ctx.input_device == -1:
            # Perform CPU to GPU copies in a background stream
            streams = [_get_stream(torch.device("cuda", device)) for device in target_gpus]
        outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
        # Synchronize with the copy stream
        if streams is not None:
            for i, output in enumerate(outputs):
                with torch.cuda.device(target_gpus[i]):
                    main_stream = torch.cuda.current_stream()
                    main_stream.wait_stream(streams[i])
                    output.record_stream(main_stream)
        return outputs

    @staticmethod
    def backward(ctx, *grad_output):
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
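
# Illustrative usage sketch (not part of the original module): Scatter.apply
# splits a tensor along `dim` into chunks of the given sizes and copies each
# chunk to its target device, using background streams when the source is on
# CPU. Assumes two CUDA devices; the helper name is hypothetical.
def _example_scatter():
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return
    x = torch.arange(12.0).view(6, 2)  # CPU source tensor
    c0, c1 = Scatter.apply([0, 1], [2, 4], 0, x)
    # c0 holds rows 0-1 on cuda:0; c1 holds rows 2-5 on cuda:1.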
# background streams used for copying
_streams: Optional[List[Optional[torch.Stream]]] = None


def _get_stream(device: torch.device):
    """Get a background stream for copying between CPU and target device."""
    global _streams
    if device.type == "cpu":
        return None
    device_mod = getattr(torch, device.type, None)
    if device_mod is None:
        return None
    if _streams is None:
        _streams = [None] * device_mod.device_count()
    if _streams[device.index] is None:
        _streams[device.index] = device_mod.Stream(device.index)
    return _streams[device.index]
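
# Illustrative usage sketch (not part of the original module): _get_stream
# lazily creates one background stream per accelerator device and caches it in
# the module-level `_streams` list; CPU devices get None. The helper name is
# hypothetical.
def _example_get_stream():
    assert _get_stream(torch.device("cpu")) is None
    if torch.cuda.is_available():
        s = _get_stream(torch.device("cuda", 0))
        # Subsequent calls return the cached stream object for that device.
        assert s is _get_stream(torch.device("cuda", 0))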