|
|
|
from typing import List, Optional |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from torch import Tensor |
|
|
|
|
|
def _max_by_axis(the_list): |
|
|
|
maxes = the_list[0] |
|
for sublist in the_list[1:]: |
|
for index, item in enumerate(sublist): |
|
maxes[index] = max(maxes[index], item) |
|
return maxes |
|
|
|
|
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a list of 3-D tensors to a common size and batch them.

    Each input is copied into the leading corner of a zero-filled batch
    tensor. For each dimension, the batch size is the maximum over all
    inputs. A boolean mask records the padding: ``False`` over each input's
    own (dim 1, dim 2) extent and ``True`` everywhere else.

    Args:
        tensor_list: Non-empty list of 3-D tensors. The code names the dims
            ``(c, h, w)``. The batch uses the dtype and device of the first
            tensor; other tensors are converted by ``copy_``.

    Returns:
        NestedTensor whose ``tensors`` has shape ``(B, C_max, H_max, W_max)``
        and whose ``mask`` is a bool tensor of shape ``(B, H_max, W_max)``.

    Raises:
        ValueError: if ``tensor_list`` is empty or any tensor is not 3-D.
    """
    if not tensor_list:
        raise ValueError('not supported: tensor_list is empty')
    # Validate every tensor, not just the first: a later tensor of another
    # rank would otherwise fail with an obscure IndexError further down.
    for img in tensor_list:
        if img.ndim != 3:
            raise ValueError(
                f'not supported: expected 3-D tensors, got a {img.ndim}-D tensor'
            )

    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, _, h, w = batch_shape
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device
    tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
    # Start with everything marked as padding, then clear each valid region.
    mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        m[: img.shape[1], : img.shape[2]] = False
    return NestedTensor(tensor, mask)
|
|
|
|
|
class NestedTensor:
    """Pair a batched tensor with an optional mask.

    Both values are stored exactly as given; no shape checks are done.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        # tensors: the batched data; mask: a tensor or None.
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        """Return a new NestedTensor with both parts moved to ``device``.

        A ``None`` mask stays ``None``. This instance is not modified.
        """
        moved_tensors = self.tensors.to(device)
        moved_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return (self.tensors, self.mask)

    def __repr__(self):
        # Uses str() rather than an f-string: Tensor.__format__ renders 0-d
        # tensors as bare numbers, which would change the output.
        return str(self.tensors)
|
|
|
|
|
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available AND initialized.

    ``dist.is_initialized`` is consulted only after ``dist.is_available``
    returns true, thanks to short-circuit evaluation.
    """
    return bool(dist.is_available() and dist.is_initialized())
|
|
|
|
|
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
|
|
|
|
def is_main_process():
    """Return True when ``get_rank()`` reports rank 0."""
    rank = get_rank()
    return rank == 0