"""
Modified from DETR https://github.com/facebookresearch/detr
Misc functions.
Mostly copy-paste from torchvision references.
"""
import pickle
import re
from typing import Optional, List
from collections import OrderedDict, defaultdict, deque
import time
import datetime

import torch
import torch.distributed as dist
from torch import Tensor
from torchvision.ops.boxes import box_area
import numpy as np
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn
import matplotlib.pyplot as plt

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision


def _parse_major_minor(version):
    """Return the ``(major, minor)`` integers of a version string.

    Local / pre-release suffixes such as ``"0.15.2+cu117"`` or ``"0.8.0a0"``
    are tolerated; a missing or non-numeric component counts as 0.
    """
    numbers = []
    for piece in version.split("+")[0].split(".")[:2]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    numbers += [0] * (2 - len(numbers))
    return tuple(numbers)


# Compare (major, minor) tuples instead of only the minor component, so that
# e.g. torchvision "1.0" is not mistaken for a pre-0.7 release.
_TORCHVISION_BEFORE_0_7 = _parse_major_minor(torchvision.__version__) < (0, 7)
if _TORCHVISION_BEFORE_0_7:
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank (a single-element
        list when not running distributed)
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # Serialize to a uint8 tensor. np.frombuffer returns a read-only view of
    # the bytes, so copy it before handing it to torch.
    buffer = pickle.dumps(data)
    tensor = torch.from_numpy(np.frombuffer(buffer, dtype=np.uint8).copy()).to("cuda")

    # obtain Tensor size of each rank
    local_size = tensor.numel()
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, torch.tensor([local_size], device="cuda"))
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device="cuda") for _ in size_list
    ]
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, gathered in zip(size_list, tensor_list):
        # Strip the padding before unpickling.
        payload = gathered.cpu().numpy().tobytes()[:size]
        # NOTE: unpickling can execute code; payloads come only from the peer
        # ranks of this same job.
        data_list.append(pickle.loads(payload))
    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    """Return the element-wise maximum over a list of equal-length int lists.

    The first sublist is copied, so the caller's input is never mutated.
    """
    maxes = list(the_list[0])
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def _round_up_spatial(max_size, stride):
    """Round the last two entries (H, W) of ``max_size`` up to a multiple of ``stride``, in place."""
    max_size[-2] = (max_size[-2] + (stride - 1)) // stride * stride
    max_size[-1] = (max_size[-1] + (stride - 1)) // stride * stride
    return max_size


class NestedTensor(object):
    """A padded batch tensor paired with an optional padding mask."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        """Return a new NestedTensor with tensors and mask (if any) moved to ``device``."""
        cast_tensor = self.tensors.to(device)
        cast_mask = self.mask.to(device) if self.mask is not None else None
        return NestedTensor(cast_tensor, cast_mask)

    @property
    def device(self):
        return self.tensors.device

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


def nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisibility=1, split=True):
    """
    This function receives a list of image tensors and returns a NestedTensor of the
    padded images, along with their padding masks (true for padding areas, false otherwise).

    When ``split`` is True, every tensor is chunked into 3-channel pieces along
    dim 0 (a stacked ``[T*3, H, W]`` clip becomes T images) and the pieces are
    flattened into one list. All images are zero-padded at the bottom/right to
    the largest (H, W), optionally rounded up to a multiple of ``size_divisibility``.

    Raises:
        ValueError: if the (split) images are not 3-dimensional.
    """
    # make this more general
    # if image tensor is stacked as [T*3, H, W], then use split
    if split:
        chunks = [tensor.split(3, dim=0) for tensor in tensor_list]
        # list[tensor], length = batch_size x time
        tensor_list = [item for sublist in chunks for item in sublist]

    if tensor_list[0].ndim != 3:
        raise ValueError('not supported')

    # make it support different-sized images
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    if size_divisibility > 1:
        # so that the mask downsample can be matched; the last two dims are
        # [H, W], both subject to the divisibility requirement
        _round_up_spatial(max_size, size_divisibility)

    batch_shape = [len(tensor_list)] + max_size
    b, c, h, w = batch_shape
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device
    tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
    for img, pad_img, m in zip(tensor_list, tensor, mask):
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
        m[: img.shape[1], : img.shape[2]] = False  # valid locations
    return NestedTensor(tensor, mask)


def nested_tensor_from_videos_list(videos_list: List[Tensor], size_divisibility=1):
    """
    This function receives a list of videos (each of shape [T, C, H, W]) and returns a NestedTensor of the padded
    videos (shape [B, T, C, PH, PW], along with their padding masks (true for padding areas, false otherwise, of shape
    [B, T, PH, PW].
    """
    max_size = _max_by_axis([list(img.shape) for img in videos_list])
    if size_divisibility > 1:
        # so that the mask downsample can be matched; the last two dims are
        # [H, W], both subject to the divisibility requirement
        _round_up_spatial(max_size, size_divisibility)

    padded_batch_shape = [len(videos_list)] + max_size
    b, t, c, h, w = padded_batch_shape
    dtype = videos_list[0].dtype
    device = videos_list[0].device
    padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)
    # generate a bigger template, put fg features into template, other places/pad mask convert to True.
    # TODO: currently padding blank on bottom right. roialign only need to consider use initial h w to rescale.
    for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):
        pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)
        vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False
    return NestedTensor(padded_videos, videos_pad_masks)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process.
    Non-master processes can still print by passing ``force=True``.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """Call ``torch.save(*args, **kwargs)`` on rank 0 only."""
    if is_main_process():
        torch.save(*args, **kwargs)


def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) along the last dim."""
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this class can go away.
    """
    if _TORCHVISION_BEFORE_0_7:
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    # From torchvision 0.7 on, ``torchvision.ops.misc.interpolate`` is only an
    # alias of the torch function; call it directly instead of relying on it.
    return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Append ``value`` to the window and add ``value * n`` / ``n`` to the global totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    """Collection of named SmoothedValue meters with periodic progress logging."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal lookup fails. Read ``meters`` through
        # ``__dict__`` so a half-built instance (e.g. during copy/unpickling,
        # before __init__ ran) raises AttributeError instead of recursing.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = ["{}: {}".format(name, str(meter)) for name, meter in self.meters.items()]
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items of ``iterable``, printing progress every ``print_freq`` steps and on the last step."""
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        total = len(iterable)
        space_fmt = ':' + str(len(str(total))) + 'd'
        use_cuda = torch.cuda.is_available()
        parts = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if use_cuda:
            parts.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(parts)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = dict(eta=eta_string, meters=str(self),
                              time=str(iter_time), data=str(data_time))
                if use_cuda:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids ZeroDivisionError for an empty iterable.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(total, 1)))


def clip_iou(boxes1, boxes2):
    """Element-wise IoU of paired xyxy boxes ``boxes1[i]`` / ``boxes2[i]`` (1e-6 smoothed)."""
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)
    lt = torch.max(boxes1[:, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    union = area1 + area2 - inter
    iou = (inter + 1e-6) / (union + 1e-6)
    return iou


def multi_iou(boxes1, boxes2):
    """Element-wise IoU of xyxy boxes laid out along the last dim (1e-6 smoothed)."""
    lt = torch.max(boxes1[..., :2], boxes2[..., :2])
    rb = torch.min(boxes1[..., 2:], boxes2[..., 2:])
    wh = (rb - lt).clamp(min=0)
    wh_1 = boxes1[..., 2:] - boxes1[..., :2]
    wh_2 = boxes2[..., 2:] - boxes2[..., :2]
    inter = wh[..., 0] * wh[..., 1]
    union = wh_1[..., 0] * wh_1[..., 1] + wh_2[..., 0] * wh_2[..., 1] - inter
    iou = (inter + 1e-6) / (union + 1e-6)
    return iou


def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) along the last dim."""
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise [N, M] IoU (1e-6 smoothed) and union between xyxy boxes."""
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = (inter + 1e-6) / (union + 1e-6)
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2). Numerator and denominator of the enclosing-area
    penalty are smoothed with 1e-6.
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - ((area - union) + 1e-6) / (area + 1e-6)


def inverse_sigmoid(x, eps=1e-5):
    """Return log(x / (1 - x)) after clamping x to [0, 1] and both terms to >= eps."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.

    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    # Build the coordinate grids on the masks' device so CUDA masks work.
    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    y, x = torch.meshgrid(y, x)

    # Background pixels are pushed to 1e8 before taking the minimum so they
    # never win; for the maximum they contribute 0.
    background = ~(masks.bool())

    x_mask = (masks * x.unsqueeze(0))
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(background, 1e8).flatten(1).min(-1)[0]

    y_mask = (masks * y.unsqueeze(0))
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(background, 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)


def clean_state_dict(state_dict):
    """Return an OrderedDict copy of ``state_dict`` with a leading ``module.`` stripped from keys."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            k = k[len('module.'):]  # remove `module.`
        new_state_dict[k] = v
    return new_state_dict


def get_batch_observer(cur_epoch, total_epochs, batch_num, disable=False):
    """Create a rich Progress bar and a task that displays per-batch loss fields.

    Returns:
        (Progress, task_id): the progress object and the id of its single task,
        whose fields (loss, cls, bbox, giou, mask, dice, proj) start at 0.
    """
    batch_ob = Progress(
        TextColumn("[bold cyan]Epoch:{task.fields[epoch]}/{task.fields[total_epoch]}"),
        BarColumn(bar_width=40),
        "{task.completed}/{task.total}",
        "•",
        TimeRemainingColumn(),
        "•",
        TextColumn("[bold red]loss={task.fields[loss]:.4f}"),
        "•",
        TextColumn("[bold deep_sky_blue1]cls={task.fields[cls]:.4f}"),
        "•",
        TextColumn("[bold magenta]bbox={task.fields[bbox]:.4f}"),
        "•",
        TextColumn("[bold magenta]giou={task.fields[giou]:.4f}"),
        "•",
        TextColumn("[bold gold1]mask={task.fields[mask]:.4f}"),
        "•",
        TextColumn("[bold gold1]dice={task.fields[dice]:.4f}"),
        "•",
        TextColumn("[bold gold1]proj={task.fields[proj]:.4f}"),
        disable=disable,
    )
    pg = batch_ob.add_task(description="Training Observer", total=batch_num,
                           epoch=cur_epoch, total_epoch=total_epochs,
                           loss=0, cls=0, bbox=0, giou=0, mask=0, dice=0, proj=0)
    return batch_ob, pg


def colormap(rgb=False, num_colors=10):
    """Return a ``(num_colors, 3)`` float array of colors in [0, 255] sampled from matplotlib's Set1.

    Args:
        rgb: if False (default) the channel order is reversed to BGR.
        num_colors: how many evenly spaced samples to draw from the colormap.
    """
    # cmap = plt.cm.tab10
    cmap = plt.cm.Set1
    color_list = cmap(np.linspace(0, 1, num_colors))[:, :3] * 255
    if not rgb:
        color_list = color_list[:, ::-1]
    return color_list