""" |
|
Misc functions, including distributed helpers. |
|
|
|
Mostly copy-paste from torchvision references. |
|
""" |
|
import csv
from collections import OrderedDict
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torchvision
from PIL import Image
from scipy.io import loadmat
from torch import Tensor
|
|
|
|
|
def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    """Return the per-axis maximum over a list of shape lists."""
    maxes = list(the_list[0])  # copy, so the caller's list is not mutated in place
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes
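
# For example, with a batch of CHW image shapes:
#   _max_by_axis([[3, 480, 640], [3, 512, 600]]) -> [3, 512, 640]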
|
|
|
def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()
|
|
|
def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that all
    processes have the averaged results.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to average or sum the values

    Returns a dict with the same fields as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
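
# Usage sketch (illustrative; not called in this module): averaging a per-rank
# loss dict in a training loop. Assumes torch.distributed is already initialized
# (e.g. via dist.init_process_group); the loss names are hypothetical.
def _example_reduce_dict():
    loss_dict = {
        'loss_ce': torch.tensor(0.7),
        'loss_bbox': torch.tensor(0.3),
    }
    # every rank passes its own values; the result is the mean across ranks
    loss_dict_reduced = reduce_dict(loss_dict, average=True)
    return sum(loss_dict_reduced.values())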
|
|
|
class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            # redundant at runtime, but lets TorchScript narrow the Optional type
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
|
|
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX,
            # so call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # pad every image up to the per-axis maximum over the batch
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        # mask is True at padded positions, False where real image pixels are
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)
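
# Usage sketch (illustrative; the image sizes are arbitrary examples): batch two
# different-sized CHW images into one zero-padded tensor plus a padding mask.
def _example_nested_tensor():
    imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 600)]
    nt = nested_tensor_from_tensor_list(imgs)
    tensor, mask = nt.decompose()
    # tensor: (2, 3, 512, 640), zero-padded; mask: (2, 512, 640), True at padding
    return tensor, mask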
|
|
|
|
|
|
|
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    #     pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    #     m[: img.shape[1], : img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
|
|
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True
|
|
|
def load_parallal_model(model, state_dict_):
    """Load a state dict into ``model``, stripping the 'module.' prefix that
    nn.DataParallel adds and skipping parameters whose shapes do not match."""
    state_dict = OrderedDict()
    for key in state_dict_:
        if key.startswith('module') and not key.startswith('module_list'):
            state_dict[key[7:]] = state_dict_[key]  # drop the 'module.' prefix
        else:
            state_dict[key] = state_dict_[key]

    # check loaded parameters against the created model's parameters
    model_state_dict = model.state_dict()
    for key in state_dict:
        if key in model_state_dict:
            if state_dict[key].shape != model_state_dict[key].shape:
                print('Skip loading parameter {}, required shape {}, loaded shape {}.'.format(
                    key, model_state_dict[key].shape, state_dict[key].shape))
                state_dict[key] = model_state_dict[key]
        else:
            print('Drop parameter {}.'.format(key))
    for key in model_state_dict:
        if key not in state_dict:
            print('No param {}.'.format(key))
            state_dict[key] = model_state_dict[key]
    model.load_state_dict(state_dict, strict=False)

    return model
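
# Usage sketch (illustrative): resuming from a checkpoint that may have been saved
# from an nn.DataParallel-wrapped model. The checkpoint path and the 'model' key
# are hypothetical; they depend on how the checkpoint was written.
def _example_load_checkpoint(model):
    checkpoint = torch.load('checkpoint.pth', map_location='cpu')
    state_dict = checkpoint.get('model', checkpoint)
    return load_parallal_model(model, state_dict)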
|
|
|
class ADEVisualize(object):
    """Visualize ADE20K segmentation predictions with the 150-class color palette."""
    def __init__(self):
        self.colors = loadmat('dataset/color150.mat')['colors']
        self.names = {}
        with open('dataset/object150_info.csv') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            for row in reader:
                self.names[int(row[0])] = row[5].split(";")[0]
|
|
|
    def unique(self, ar, return_index=False, return_inverse=False, return_counts=False):
        """Find the unique elements of an array; mirrors np.unique."""
        ar = np.asanyarray(ar).flatten()

        optional_indices = return_index or return_inverse
        optional_returns = optional_indices or return_counts

        if ar.size == 0:
            if not optional_returns:
                ret = ar
            else:
                ret = (ar,)
                if return_index:
                    ret += (np.empty(0, np.intp),)
                if return_inverse:
                    ret += (np.empty(0, np.intp),)
                if return_counts:
                    ret += (np.empty(0, np.intp),)
            return ret

        if optional_indices:
            perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
            aux = ar[perm]
        else:
            ar.sort()
            aux = ar
        flag = np.concatenate(([True], aux[1:] != aux[:-1]))

        if not optional_returns:
            ret = aux[flag]
        else:
            ret = (aux[flag],)
            if return_index:
                ret += (perm[flag],)
            if return_inverse:
                iflag = np.cumsum(flag) - 1
                inv_idx = np.empty(ar.shape, dtype=np.intp)
                inv_idx[perm] = iflag
                ret += (inv_idx,)
            if return_counts:
                idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
                ret += (np.diff(idx),)
        return ret
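
    # For example:
    #   self.unique(np.array([2, 1, 2, 3]), return_counts=True)
    #   -> (array([1, 2, 3]), array([1, 2, 1]))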
|
|
|
    def colorEncode(self, labelmap, colors, mode='RGB'):
        """Map an (H, W) integer label map to an (H, W, 3) uint8 color image."""
        labelmap = labelmap.astype('int')
        labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
                                dtype=np.uint8)
        for label in self.unique(labelmap):
            if label < 0:
                continue
            labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
                np.tile(colors[label],
                        (labelmap.shape[0], labelmap.shape[1], 1))

        if mode == 'BGR':
            return labelmap_rgb[:, :, ::-1]
        return labelmap_rgb
|
|
|
    def show_result(self, img, pred, save_path=None):
        """Blend the colorized prediction over ``img``; save it or display it."""
        pred = np.int32(pred)
        pred_color = self.colorEncode(pred, self.colors)
        pil_img = img.convert('RGBA')
        pred_color = Image.fromarray(pred_color).convert('RGBA')
        im_vis = Image.blend(pil_img, pred_color, 0.6)
        if save_path is not None:
            im_vis.save(save_path)
        else:
            plt.imshow(im_vis)
            plt.show()
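
# Usage sketch (illustrative): colorize a prediction and save the overlay. The
# image path and the dummy prediction are hypothetical; `pred` should be an
# (H, W) array of ADE20K class indices. Requires dataset/color150.mat and
# dataset/object150_info.csv to be present.
def _example_visualize():
    vis = ADEVisualize()
    img = Image.open('demo.jpg')
    pred = np.zeros((img.height, img.width), dtype=np.int32)  # all class 0
    vis.show_result(img, pred, save_path='demo_overlay.png')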