Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| from torch.nn.parallel._functions import _get_stream | |
def scatter(input, devices, streams=None):
    """Copy a tensor, or a (nested) list of tensors, onto target devices.

    A list is cut into contiguous chunks of ``ceil(len(input) / len(devices))``
    items; the k-th chunk is sent to ``devices[k]`` using ``streams[k]``.
    Nested lists are handled recursively and stay on their chunk's device.

    Args:
        input (list | torch.Tensor): Data to scatter.
        devices (list[int]): Target CUDA device ids, or ``[-1]`` to keep the
            data on CPU.
        streams (list | None): One stream (or ``None``) per device, made the
            current stream while copying. Defaults to ``None`` for every
            device.

    Returns:
        list | torch.Tensor: Mirrors the structure of ``input``. Each tensor
        is made contiguous and then either moved with
        ``.cuda(device, non_blocking=True)`` or, for ``devices == [-1]``,
        given a leading dimension of size 1 via ``unsqueeze(0)``.

    Raises:
        Exception: If ``input`` (or a nested item) is neither a list nor a
            tensor.
    """
    if streams is None:
        streams = [None] * len(devices)

    if isinstance(input, list):
        # Ceil division: every device gets at most this many items.
        per_device = (len(input) - 1) // len(devices) + 1
        scattered = []
        for position, item in enumerate(input):
            slot = position // per_device
            scattered.append(scatter(item, [devices[slot]], [streams[slot]]))
        return scattered

    if not isinstance(input, torch.Tensor):
        raise Exception(f'Unknown type {type(input)}.')

    tensor = input.contiguous()
    # TODO: copy to a pinned buffer first (if copying from CPU)
    copy_stream = None
    if tensor.numel() > 0:
        # Empty tensors are copied without a dedicated stream.
        copy_stream = streams[0]

    if devices == [-1]:
        # Stay on CPU; unsqueeze the first dimension so the tensor's shape is
        # the same as those scattered with GPU.
        return tensor.unsqueeze(0)

    with torch.cuda.device(devices[0]), torch.cuda.stream(copy_stream):
        return tensor.cuda(devices[0], non_blocking=True)
def synchronize_stream(output, devices, streams):
    """Make each target device's current stream wait for its copy stream.

    ``output`` is expected to have the structure ``scatter`` produced for the
    same ``devices``/``streams``. List items are paired with devices using
    the same chunking as ``scatter`` (contiguous chunks of
    ``ceil(len(output) / len(devices))`` items), so every tensor is
    synchronized with the stream it was copied on.

    Args:
        output (list | torch.Tensor): A tensor, or a (possibly nested) list of
            tensors.
        devices (list[int]): CUDA device ids the data was scattered to.
        streams (list): Copy streams, one per entry of ``devices``.

    Raises:
        Exception: If ``output`` (or a nested item) is neither a list nor a
            tensor.
    """
    if isinstance(output, list):
        # Must match the ceil-division chunking in ``scatter``. Floor division
        # mis-pairs items with devices and skips trailing items whenever
        # len(output) is not a multiple of len(devices).
        chunk_size = (len(output) - 1) // len(devices) + 1
        for i, item in enumerate(output):
            slot = i // chunk_size
            synchronize_stream(item, [devices[slot]], [streams[slot]])
    elif isinstance(output, torch.Tensor):
        # scatter() uses no copy stream for empty tensors: nothing to wait on.
        if output.numel() != 0:
            with torch.cuda.device(devices[0]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[0])
                # Tell the caching allocator the tensor is used on the main
                # stream so its memory is not reused prematurely.
                output.record_stream(main_stream)
    else:
        raise Exception(f'Unknown type {type(output)}.')
def get_input_device(input):
    """Return the CUDA device index ``input`` lives on, or -1 for CPU.

    For a (possibly nested) list, items are inspected in order and the first
    one found on a CUDA device determines the result. If none is on a CUDA
    device (including an empty list), -1 is returned.

    Raises:
        Exception: If ``input`` or any inspected item is neither a list nor a
            tensor.
    """
    if isinstance(input, torch.Tensor):
        if not input.is_cuda:
            return -1
        return input.get_device()

    if not isinstance(input, list):
        raise Exception(f'Unknown type {type(input)}.')

    for element in input:
        found = get_input_device(element)
        if found != -1:
            return found
    return -1
class Scatter:
    """Scatter a tensor or a list of tensors to target devices."""

    @staticmethod
    def forward(target_gpus, input):
        """Scatter ``input`` to ``target_gpus``.

        When the input is entirely on CPU and the targets are GPUs, the copies
        are issued on background streams and the devices' current streams
        are then made to wait for them.

        Args:
            target_gpus (list[int]): Target CUDA device ids; ``[-1]`` keeps
                the data on CPU.
            input (list | torch.Tensor): Data to scatter.

        Returns:
            tuple: ``tuple(...)`` of the result of ``scatter``; for list input
            this has one entry per input item.
        """
        input_device = get_input_device(input)
        streams = None
        if input_device == -1 and target_gpus != [-1]:
            # Perform CPU to GPU copies in a background stream
            streams = [Scatter._copy_stream(device) for device in target_gpus]

        outputs = scatter(input, target_gpus, streams)
        # Synchronize with the copy stream
        if streams is not None:
            synchronize_stream(outputs, target_gpus, streams)

        return tuple(outputs)

    @staticmethod
    def _copy_stream(device):
        """Return PyTorch's background copy stream for CUDA ``device``.

        ``torch.nn.parallel._functions._get_stream`` takes an ``int`` device
        index in older PyTorch releases but a ``torch.device`` in newer ones,
        where passing an int fails with ``AttributeError`` (``int`` has no
        ``.type``). Try the int form first and fall back to ``torch.device``.
        """
        try:
            return _get_stream(device)
        except AttributeError:
            return _get_stream(torch.device('cuda', device))