|
""" |
|
Misc functions, including distributed helpers. |
|
|
|
Mostly copied from the torchvision references.
|
""" |
|
from typing import List, Optional |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import torchvision |
|
from torch import Tensor |
|
|
|
if float(torchvision.__version__.split(".")[1]) < 7.0: |
|
from torchvision.ops import _new_empty_tensor |
|
from torchvision.ops.misc import _output_size |
|
|
|
|
|
def _max_by_axis(the_list): |
|
|
|
maxes = the_list[0] |
|
for sublist in the_list[1:]: |
|
for index, item in enumerate(sublist): |
|
maxes[index] = max(maxes[index], item) |
|
return maxes |
|
|
|
|
|
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.

    On torchvision releases whose minor version is below 7, empty inputs are
    handled here using torchvision's private helpers. On newer releases the call
    is forwarded to ``torchvision.ops.misc.interpolate`` when that function still
    exists. Recent torchvision releases no longer provide it, in which case
    ``torch.nn.functional.interpolate`` is called directly.

    Args:
        input: tensor to resize.
        size, scale_factor, mode, align_corners: forwarded unchanged, with the
            same meaning as in ``torch.nn.functional.interpolate``.

    Returns:
        The resized tensor.

    This will eventually be supported natively by PyTorch, and this
    function can go away.
    """
    if float(torchvision.__version__.split(".")[1]) < 7.0:
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        # Empty input: compute the spatial output size and build an empty
        # tensor of the right shape instead of calling the kernel.
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)

    # torchvision.ops.misc.interpolate was removed from torchvision; use it only
    # if this installation still ships it, otherwise go straight to PyTorch.
    tv_interpolate = getattr(torchvision.ops.misc, "interpolate", None)
    if tv_interpolate is not None:
        return tv_interpolate(input, size, scale_factor, mode, align_corners)
    return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
|
|
|
|
|
class NestedTensor:
    """
    Container that keeps a tensor batch together with an optional mask.

    As built by ``nested_tensor_from_tensor_list``, the mask is a boolean tensor
    that is ``True`` over padded positions and ``False`` over real data.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        """Return a new NestedTensor with the tensors (and mask, if present) on ``device``."""
        moved_tensors = self.tensors.to(device)
        moved_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        # Only the data tensor is shown; the mask is omitted.
        return str(self.tensors)
|
|
|
|
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """
    Pad a list of 3-D image tensors to a common size and batch them.

    Each tensor is indexed as (C, H, W). Every tensor is zero-padded at the end
    of each dimension up to the per-dimension maximum. The results are stacked
    into a single (B, C, H, W) tensor. A boolean mask of shape (B, H, W) is built
    alongside it: ``False`` over each image's original region and ``True`` over
    padding.

    Args:
        tensor_list: non-empty list of 3-D tensors. The dtype and device of the
            output are taken from the first tensor.

    Returns:
        NestedTensor wrapping the padded batch and its mask.

    Raises:
        ValueError: if ``tensor_list`` is empty or its first tensor is not 3-D.
    """
    if not tensor_list:
        raise ValueError("not supported: tensor_list is empty")
    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError(
            f"not supported: expected 3-D (C, H, W) tensors, got {first.ndim}-D"
        )

    if torchvision._is_tracing():
        # Under tracing, a separate pad-and-stack implementation is used.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, _, h, w = batch_shape
    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    # Start with every position marked as padding, then clear each image's region.
    mask = torch.ones((b, h, w), dtype=torch.bool, device=first.device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        m[: img.shape[1], : img.shape[2]] = False
    return NestedTensor(tensor, mask)
|
|
|
|
|
|
|
|
|
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """
    Tracing-friendly variant of ``nested_tensor_from_tensor_list``.

    ``nested_tensor_from_tensor_list`` dispatches here when
    ``torchvision._is_tracing()`` is true. Rather than copying each image into a
    preallocated zero tensor via slice assignment, it pads every image (and its
    mask) with ``torch.nn.functional.pad`` and stacks the results.

    Args:
        tensor_list: tensors with the same number of dimensions. The padding
            below indexes dims 0..2, i.e. assumes 3-D tensors laid out as
            (C, H, W) like the main code path.

    Returns:
        NestedTensor whose mask is ``False`` over each image's original region
        and ``True`` over the padding.
    """
    max_size = []
    for i in range(tensor_list[0].dim()):
        # NOTE(review): torch.stack requires tensor elements, so this presumably
        # relies on shapes being traced as tensors (in eager mode img.shape[i]
        # is a plain int). The float32 round-trip around torch.max is presumably
        # an export workaround — confirm.
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Per-dimension amount of padding needed to reach max_size.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (before, after) pairs starting from the LAST dimension, so
        # this zero-pads only at the end of dims 2, 1 and 0.
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # 2-D mask shaped like one channel: 0 over the image, padded with 1,
        # then cast to bool so padding becomes True.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
|
|
|
|
def is_dist_avail_and_initialized():
    """
    Report whether torch.distributed can be used right now.

    Returns:
        ``True`` only when the distributed package is available and the
        default process group has been initialized; ``False`` otherwise.
    """
    # Short-circuit so is_initialized() is only queried when the package exists.
    return bool(dist.is_available() and dist.is_initialized())
|
|