Spaces:
Paused
Paused
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
# Modified by Xueyan Zou
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
from typing import List, Optional | |
import torch | |
import torch.distributed as dist | |
import torchvision | |
from torch import Tensor | |
def _max_by_axis(the_list): | |
# type: (List[List[int]]) -> List[int] | |
maxes = the_list[0] | |
for sublist in the_list[1:]: | |
for index, item in enumerate(sublist): | |
maxes[index] = max(maxes[index], item) | |
return maxes | |
class NestedTensor:
    """Pair a batched tensor with an optional mask tensor.

    ``tensors`` is the batched data; ``mask`` is either ``None`` or a tensor
    that travels alongside it (the padding helpers in this module use it to
    mark padded positions).
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors, self.mask = tensors, mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor whose data (and mask, if any) live on ``device``."""
        moved_tensors = self.tensors.to(device)
        # Bind to a local so the None check refines the type for TorchScript.
        current_mask = self.mask
        moved_mask = current_mask.to(device) if current_mask is not None else None
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Split into the ``(tensors, mask)`` pair."""
        pair = (self.tensors, self.mask)
        return pair

    def __repr__(self):
        # Only the data is shown; the mask is deliberately left out.
        return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Zero-pad a list of tensors to a common size and batch them into a NestedTensor.

    Each tensor is padded at the end of every dimension up to the
    per-dimension maximum over the list, then stacked along a new leading
    batch dimension. The mask is True on padded positions and False on the
    original data, and it drops the first (channel) dimension:

    * 3-D inputs (C, H, W) -> tensor (B, C, H_max, W_max), mask (B, H_max, W_max)
    * 2-D inputs (C, L)    -> tensor (B, C, L_max),        mask (B, L_max)

    The output dtype and device are taken from ``tensor_list[0]``.

    Args:
        tensor_list: non-empty list of tensors that all have the same ndim.

    Returns:
        NestedTensor(tensor, mask).

    Raises:
        ValueError: if the tensors are neither 2-D nor 3-D.
    """
    ndim = tensor_list[0].ndim
    if ndim not in (2, 3):
        raise ValueError("not supported")
    if ndim == 3 and torchvision._is_tracing():
        # The in-place slice copies below do not export well to ONNX, so use the
        # ONNX-friendly variant. It only handles 3-D input (it indexes three
        # padding dims), so 2-D input always takes the eager path below.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    max_size = _max_by_axis([list(t.shape) for t in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device
    tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
    # Start fully masked; original data regions are cleared below.
    mask = torch.ones([batch_shape[0]] + max_size[1:], dtype=torch.bool, device=device)
    for src, pad_dst, m in zip(tensor_list, tensor, mask):
        # Top-left aligned copy: pad_dst[:d0, :d1, ...] = src
        pad_dst[tuple(slice(0, s) for s in src.shape)].copy_(src)
        # The mask has no channel dim, so skip src.shape[0].
        m[tuple(slice(0, s) for s in src.shape[1:])] = False
    return NestedTensor(tensor, mask)
def _collate_and_pad_divisibility(tensor_list: list, div=32): | |
max_size = [] | |
for i in range(tensor_list[0].dim()): | |
max_size_i = torch.max( | |
torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32) | |
).to(torch.int64) | |
max_size.append(max_size_i) | |
max_size = tuple(max_size) | |
c,h,w = max_size | |
pad_h = (div - h % div) if h % div != 0 else 0 | |
pad_w = (div - w % div) if w % div != 0 else 0 | |
max_size = (c,h+pad_h,w+pad_w) | |
# work around for | |
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) | |
# m[: img.shape[1], :img.shape[2]] = False | |
# which is not yet supported in onnx | |
padded_imgs = [] | |
padded_masks = [] | |
for img in tensor_list: | |
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] | |
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) | |
padded_imgs.append(padded_img) | |
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) | |
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) | |
padded_masks.append(padded_mask.to(torch.bool)) | |
return padded_imgs | |
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Pad and stack a list of 3-D tensors into a NestedTensor using trace-friendly ops.

    Padding is done with ``torch.nn.functional.pad`` instead of in-place slice
    assignment. Each tensor is padded at the end of every dimension up to the
    per-dimension maximum. The mask is built from the first channel
    (``img[0]``), so it drops the leading dimension. It is False on original
    data and True on padded positions.

    Expects 3-D tensors (presumably (C, H, W)); ``padding[2]`` below requires
    at least three dimensions.

    Returns:
        NestedTensor(stacked padded tensors, stacked bool masks).
    """
    max_size = []
    for i in range(tensor_list[0].dim()):
        # NOTE(review): torch.stack needs tensors, so this presumably relies on
        # img.shape[i] being a traced tensor during torch.jit tracing; in eager
        # mode the entries are Python ints and this likely fails — confirm.
        # The float32 round-trip is presumably an ONNX-export workaround — confirm.
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)
    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount of padding needed in each dimension to reach max_size.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad pairs start from the last dim: pad only at the end of dims 2, 1, 0.
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)
        # 0 marks original data; the padded region is filled with 1, then cast to bool.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))
    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)
    return NestedTensor(tensor, mask=mask)
def is_dist_avail_and_initialized():
    """Report whether torch.distributed is both available and initialized.

    Returns:
        True only when the distributed package is built in AND a process group
        has been initialized; False otherwise.
    """
    # Short-circuit keeps is_initialized() from being consulted when the
    # distributed backend is unavailable.
    return bool(dist.is_available() and dist.is_initialized())
def get_iou(gt_masks, pred_masks, ignore_label=-1):
    """Compute the intersection-over-union of each ground-truth/prediction pair.

    Positions where ``gt_masks == ignore_label`` are excluded from both the
    intersection and the union. Any other non-zero value in ``gt_masks`` counts
    as foreground.

    Args:
        gt_masks: tensor of shape (N, ...) (typically (N, H, W)). It holds
            foreground/background values and optionally ``ignore_label``.
        pred_masks: prediction mask, presumably boolean and of the same shape
            as ``gt_masks`` (it is combined with ``&``/``|``).
        ignore_label: value in ``gt_masks`` that marks pixels to ignore.

    Returns:
        Float tensor of shape (N,) holding intersection / union per sample.
        A sample whose (non-ignored) union is empty yields NaN (0 / 0).
    """
    # Build the ignore mask before casting: after .bool() the ignore label
    # would be indistinguishable from foreground.
    valid = gt_masks != ignore_label
    gt_masks = gt_masks.bool()
    # Flatten everything after the sample dim, so any trailing shape works.
    intersection = ((gt_masks & pred_masks) & valid).flatten(1).sum(dim=-1)
    union = ((gt_masks | pred_masks) & valid).flatten(1).sum(dim=-1)
    return intersection / union