|
|
|
import torch
|
|
from torch.nn.parallel._functions import _get_stream
|
|
|
|
|
|
def scatter(input, devices, streams=None):
    """Move a tensor, or a (nested) list of tensors, onto target devices.

    Args:
        input: A ``torch.Tensor`` or a list whose items are tensors or
            further lists.
        devices: List of device ids. The single-element list ``[-1]``
            selects the CPU path, in which no CUDA call is made.
        streams: Optional list of streams aligned with ``devices``. When
            omitted, one ``None`` entry per device is used.

    Returns:
        For a list, a list with the same structure whose items have been
        scattered. For a tensor, the contiguous tensor copied to
        ``devices[0]``, or, on the CPU path, the contiguous tensor with a
        new leading dimension of size 1.

    Raises:
        Exception: If ``input`` is neither a list nor a ``torch.Tensor``.
    """
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        # Ceil division: consecutive runs of ``per_device`` items share one
        # device/stream pair.
        per_device = (len(input) - 1) // len(devices) + 1
        scattered = []
        for index, item in enumerate(input):
            slot = index // per_device
            scattered.append(
                scatter(item, [devices[slot]], [streams[slot]]))
        return scattered

    if isinstance(input, torch.Tensor):
        tensor = input.contiguous()
        # Empty tensors are copied without selecting a dedicated stream.
        copy_stream = streams[0] if tensor.numel() > 0 else None
        if devices != [-1]:
            with torch.cuda.device(devices[0]), \
                    torch.cuda.stream(copy_stream):
                return tensor.cuda(devices[0], non_blocking=True)
        # CPU path: only add a leading dimension of size 1.
        return tensor.unsqueeze(0)

    raise Exception(f'Unknown type {type(input)}.')
|
|
|
|
|
|
def synchronize_stream(output, devices, streams):
    """Make each device's current stream wait for its copy stream.

    Walks ``output`` recursively. For every non-empty tensor, the current
    CUDA stream of the associated device waits on the associated stream,
    and the tensor is marked as in use on that current stream through
    ``record_stream``. Empty tensors are skipped.

    Args:
        output: A ``torch.Tensor`` or a (nested) list of tensors, as
            produced by ``scatter``.
        devices: List of device ids aligned with ``streams``.
        streams: List of streams, one per device.

    Raises:
        Exception: If ``output`` is neither a list nor a ``torch.Tensor``.
    """
    if isinstance(output, list):
        # Use the same ceil-based item-to-device mapping as ``scatter``.
        # Floor division here would skip trailing items (or all items when
        # there are fewer items than devices) and pair others with the
        # wrong device whenever the length is not a multiple of the
        # device count.
        chunk_size = (len(output) - 1) // len(devices) + 1
        for index, item in enumerate(output):
            slot = index // chunk_size
            synchronize_stream(item, [devices[slot]], [streams[slot]])
    elif isinstance(output, torch.Tensor):
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                output.record_stream(main_stream)
    else:
        raise Exception(f'Unknown type {type(output)}.')
|
|
|
|
|
|
def get_input_device(input):
    """Return the CUDA device index of the first CUDA tensor in ``input``.

    Args:
        input: A ``torch.Tensor`` or a (nested) list of tensors.

    Returns:
        The value of ``get_device()`` for the first CUDA tensor found in
        depth-first order, or ``-1`` if no CUDA tensor is found (including
        for an empty list).

    Raises:
        Exception: If ``input``, or an item reached during the search, is
            neither a list nor a ``torch.Tensor``.
    """
    if isinstance(input, list):
        # Lazy generator: the search stops at the first item that reports
        # a device, so later items are never inspected.
        candidates = (get_input_device(entry) for entry in input)
        return next((device for device in candidates if device != -1), -1)

    if isinstance(input, torch.Tensor):
        if input.is_cuda:
            return input.get_device()
        return -1

    raise Exception(f'Unknown type {type(input)}.')
|
|
|
|
|
|
class Scatter:
    """Namespace for the scatter entry point used with lists of tensors."""

    @staticmethod
    def forward(target_gpus, input):
        """Scatter ``input`` onto ``target_gpus`` and return a tuple.

        Args:
            target_gpus: List of device ids; ``[-1]`` selects the CPU path.
            input: A ``torch.Tensor`` or a (nested) list of tensors.

        Returns:
            ``tuple(outputs)`` of whatever ``scatter`` produced. For a
            list this is a tuple of its items; for a tensor, ``tuple``
            iterates over its first dimension.
        """
        source_device = get_input_device(input)

        # Dedicated streams are requested only when the input holds no
        # CUDA tensor and the targets are not the CPU placeholder.
        # NOTE(review): `_get_stream` is a private torch helper; confirm
        # the argument type it expects for the torch version in use.
        use_copy_streams = source_device == -1 and target_gpus != [-1]
        if use_copy_streams:
            streams = [_get_stream(gpu) for gpu in target_gpus]
        else:
            streams = None

        outputs = scatter(input, target_gpus, streams)

        # Let each device's current stream wait for the stream used above.
        if use_copy_streams:
            synchronize_stream(outputs, target_gpus, streams)

        return tuple(outputs)
|
|
|