# Copyright (c) OpenMMLab. All rights reserved. import torch from torch.nn.parallel._functions import Scatter as OrigScatter from ._functions import Scatter from .data_container import DataContainer def scatter(inputs, target_gpus, dim=0): """Scatter inputs to target gpus. The only difference from original :func:`scatter` is to add support for :type:`~mmcv.parallel.DataContainer`. """ def scatter_map(obj): if isinstance(obj, torch.Tensor): if target_gpus != [-1]: return OrigScatter.apply(target_gpus, None, dim, obj) else: # for CPU inference we use self-implemented scatter return Scatter.forward(target_gpus, obj) if isinstance(obj, DataContainer): if obj.cpu_only: return obj.data else: return Scatter.forward(target_gpus, obj.data) if isinstance(obj, tuple) and len(obj) > 0: return list(zip(*map(scatter_map, obj))) if isinstance(obj, list) and len(obj) > 0: out = list(map(list, zip(*map(scatter_map, obj)))) return out if isinstance(obj, dict) and len(obj) > 0: out = list(map(type(obj), zip(*map(scatter_map, obj.items())))) return out return [obj for targets in target_gpus] # After scatter_map is called, a scatter_map cell will exist. This cell # has a reference to the actual function scatter_map, which has references # to a closure that has a reference to the scatter_map cell (because the # fn is recursive). To avoid this reference cycle, we set the function to # None, clearing the cell try: return scatter_map(inputs) finally: scatter_map = None def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): """Scatter with support for kwargs dictionary.""" inputs = scatter(inputs, target_gpus, dim) if inputs else [] kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] if len(inputs) < len(kwargs): inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) elif len(kwargs) < len(inputs): kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) inputs = tuple(inputs) kwargs = tuple(kwargs) return inputs, kwargs