import collections
import io
import os
from os.path import join

import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing
import torch.nn as nn
import torch.nn.functional as F
import wget
from PIL import Image
from scipy.optimize import linear_sum_assignment
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
from torch.utils.tensorboard.summary import hparams
from torchmetrics import Metric
from torchvision import models
from torchvision import transforms as T

try:
    from torch._six import string_classes
except ImportError:  # torch >= 1.13 removed torch._six; this is its definition
    string_classes = (str, bytes)

torch.multiprocessing.set_sharing_strategy("file_system")

colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
class_names = (
    "Buildings",
    "Cultivation",
    "Natural green",
    "Wetland",
    "Water",
    "Infrastructure",
    "Background",
)
bounds = list(np.arange(len(class_names) + 1) + 1)
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)


def compute_biodiv_score(image):
    """Compute the biodiversity score of an image.

    Args:
        image (PIL.Image.Image): image of a landscape.

    Returns:
        float: the biodiversity score associated with the landscape of the
        image, here computed as the mean pixel value.
    """
    pix = np.array(image.getdata())
    return np.mean(pix)


def create_video(array_images, output_path="output.mp4"):
    """Write a list of (H, W, 3) uint8 BGR frames to a video file at 2 fps."""
    height, width, _ = array_images[0].shape
    size = (width, height)
    fourcc = cv2.VideoWriter_fourcc(*"VP90")
    out = cv2.VideoWriter(output_path, fourcc, 2, size)
    for image in array_images:
        out.write(image)
    out.release()
    return out


def transform_to_pil(outputs, alpha=0.3):
    """Turn a model output into PIL images.

    Args:
        outputs (dict): model outputs holding an "img" tensor batch and
            "linear_preds" label predictions.
        alpha (float, optional): opacity of the label overlay. Defaults to 0.3.

    Returns:
        tuple: the input image, the colorized label map, and their blend,
        all as PIL images.
    """
    # Transform img with torch
    img = torch.moveaxis(prep_for_plot(outputs["img"][0]), -1, 0)
    img = T.ToPILImage()(img)

    # Transform label by saving it then opening it
    label = outputs["linear_preds"][0].numpy()
    plt.imsave("output/label.png", label, cmap=cmap)
    image_label = Image.open("output/label.png")

    # Overlay labels on img with alpha
    background = img.convert("RGBA")
    overlay = image_label.convert("RGBA")
    labeled_img = Image.blend(background, overlay, alpha)
    labeled_img = labeled_img.convert("RGB")

    return img, image_label, labeled_img


def prep_for_plot(img, rescale=True, resize=None):
    if resize is not None:
        img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
    else:
        img = img.unsqueeze(0)

    plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0)
    if rescale:
        plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min())
    return plot_img


def add_plot(writer, name, step):
    buf = io.BytesIO()
    plt.savefig(buf, format="jpeg", dpi=100)
    buf.seek(0)
    image = Image.open(buf)
    image = T.ToTensor()(image)
    writer.add_image(name, image, step)
    plt.clf()
    plt.close()


@torch.jit.script
def shuffle(x):
    return x[torch.randperm(x.shape[0])]


def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k, v in metric_dict.items():
        writer.add_scalar(k, v, global_step)


@torch.jit.script
def resize(classes: torch.Tensor, size: int):
    return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)


def one_hot_feats(labels, n_classes):
    return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32)
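
# Illustrative usage sketch (not part of the original pipeline): render a
# sequence of model outputs as an overlay video. `list_of_outputs` is a
# hypothetical iterable of dicts carrying the "img" and "linear_preds" keys
# that transform_to_pil expects; an output/ directory must exist, since
# transform_to_pil writes a temporary PNG there.
def _example_outputs_to_video(list_of_outputs, video_path="overlays.mp4"):
    frames = []
    for outputs in list_of_outputs:
        _, _, labeled_img = transform_to_pil(outputs)
        # PIL yields RGB; cv2.VideoWriter expects BGR uint8 frames.
        frames.append(cv2.cvtColor(np.array(labeled_img), cv2.COLOR_RGB2BGR))
    create_video(frames, output_path=video_path)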
def load_model(model_type, data_dir):
    if model_type == "robust_resnet50":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'imagenet_l2_3_0.pt')
        if not os.path.exists(model_file):
            wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt", model_file)
        model_weights = torch.load(model_file)
        model_weights_modified = {name.split('model.')[1]: value
                                  for name, value in model_weights['model'].items()
                                  if 'model' in name}
        model.load_state_dict(model_weights_modified)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densecl":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
        if not os.path.exists(model_file):
            wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download", model_file)
        model_weights = torch.load(model_file)
        model.load_state_dict(model_weights['state_dict'], strict=False)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "resnet50":
        model = models.resnet50(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "mocov2":
        model = models.resnet50(pretrained=False)
        model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
        if not os.path.exists(model_file):
            wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
                          "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
        checkpoint = torch.load(model_file)
        # rename moco pre-trained keys
        state_dict = checkpoint['state_dict']
        for k in list(state_dict.keys()):
            # retain only encoder_q up to before the embedding layer
            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                # remove prefix
                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]
        msg = model.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
        model = nn.Sequential(*list(model.children())[:-1])
    elif model_type == "densenet121":
        model = models.densenet121(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    elif model_type == "vgg11":
        model = models.vgg11(pretrained=True)
        model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
    else:
        raise ValueError("No model: {} found".format(model_type))
    model.eval()
    model.cuda()
    return model


class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image2 = torch.clone(image)
        for t, m, s in zip(image2, self.mean, self.std):
            t.mul_(s).add_(m)
        return image2


normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


class ToTargetTensor(object):
    def __call__(self, target):
        return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)


def prep_args():
    import sys

    old_args = sys.argv
    new_args = [old_args.pop(0)]
    while len(old_args) > 0:
        arg = old_args.pop(0)
        if len(arg.split("=")) == 2:
            new_args.append(arg)
        elif arg.startswith("--"):
            new_args.append(arg[2:] + "=" + old_args.pop(0))
        else:
            raise ValueError("Unexpected arg style {}".format(arg))
    sys.argv = new_args


def get_transform(res, is_label, crop_type):
    if crop_type == "center":
        cropper = T.CenterCrop(res)
    elif crop_type == "random":
        cropper = T.RandomCrop(res)
    elif crop_type is None:
        cropper = T.Lambda(lambda x: x)
        res = (res, res)
    else:
        raise ValueError("Unknown Cropper {}".format(crop_type))

    if is_label:
        return T.Compose([T.Resize(res, Image.NEAREST), cropper, ToTargetTensor()])
    else:
        return T.Compose([T.Resize(res, Image.NEAREST), cropper, T.ToTensor(), normalize])
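
# Illustrative usage sketch: load a frozen backbone and embed one image with
# the matching preprocessing. The image path is a placeholder, and a GPU is
# required since load_model calls model.cuda().
def _example_embed_image(image_path, data_dir="."):
    model = load_model("resnet50", data_dir)
    transform = get_transform(224, is_label=False, crop_type="center")
    img = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).cuda()
    with torch.no_grad():
        feats = model(img)   # (1, 2048, 1, 1) for a ResNet-50 trunk
    return feats.flatten(1)  # (1, 2048) embedding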
def _remove_axes(ax):
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)


class UnsupervisedMetrics(Metric):
    def __init__(self, prefix: str, n_classes: int, extra_clusters: int,
                 compute_hungarian: bool, dist_sync_on_step=True):
        # call `self.add_state` for every internal state that is needed for the
        # metric computations; dist_reduce_fx indicates the function that should
        # be used to reduce state from multiple processes
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.n_classes = n_classes
        self.extra_clusters = extra_clusters
        self.compute_hungarian = compute_hungarian
        self.prefix = prefix
        self.add_state("stats",
                       default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
                       dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        with torch.no_grad():
            actual = target.reshape(-1)
            preds = preds.reshape(-1)
            mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
            actual = actual[mask]
            preds = preds[mask]
            self.stats += torch.bincount(
                (self.n_classes + self.extra_clusters) * actual + preds,
                minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
                .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)

    def map_clusters(self, clusters):
        if self.extra_clusters == 0:
            return torch.tensor(self.assignments[1])[clusters]
        else:
            missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
            cluster_to_class = self.assignments[1]
            for missing_entry in missing:
                if missing_entry == cluster_to_class.shape[0]:
                    cluster_to_class = np.append(cluster_to_class, -1)
                else:
                    cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
            cluster_to_class = torch.tensor(cluster_to_class)
            return cluster_to_class[clusters]

    def compute(self):
        if self.compute_hungarian:
            self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
            if self.extra_clusters == 0:
                self.histogram = self.stats[np.argsort(self.assignments[1]), :]
            if self.extra_clusters > 0:
                self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
                histogram = self.stats[self.assignments_t[1], :]
                missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
                new_row = self.stats[missing, :].sum(0, keepdim=True)
                histogram = torch.cat([histogram, new_row], axis=0)
                new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
                self.histogram = torch.cat([histogram, new_col], axis=1)
        else:
            self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
                                torch.arange(self.n_classes).unsqueeze(1))
            self.histogram = self.stats

        tp = torch.diag(self.histogram)
        fp = torch.sum(self.histogram, dim=0) - tp
        fn = torch.sum(self.histogram, dim=1) - tp

        iou = tp / (tp + fp + fn)
        prc = tp / (tp + fn)
        opc = torch.sum(tp) / torch.sum(self.histogram)

        metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
                       self.prefix + "Accuracy": opc.item()}
        return {k: 100 * v for k, v in metric_dict.items()}
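
# Illustrative usage sketch: accumulate cluster/label co-occurrence counts
# over a few batches, then read out Hungarian-matched scores. The random
# tensors are placeholders; n_classes mirrors the 7 classes in `class_names`.
def _example_unsupervised_metrics():
    metrics = UnsupervisedMetrics("cluster/", n_classes=7, extra_clusters=0,
                                  compute_hungarian=True)
    for _ in range(4):  # stand-in for a validation loop
        preds = torch.randint(0, 7, (2, 64, 64))
        target = torch.randint(0, 7, (2, 64, 64))
        metrics.update(preds, target)
    return metrics.compute()  # {"cluster/mIoU": ..., "cluster/Accuracy": ...}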
def flexible_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size."""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        try:
            return torch.stack(batch, 0, out=out)
        except RuntimeError:
            return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return flexible_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: flexible_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [flexible_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
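
# Illustrative usage sketch: flexible_collate mirrors PyTorch's default_collate
# but falls back to returning the raw list when tensors cannot be stacked
# (e.g. ragged image sizes), so it can serve as a DataLoader collate_fn:
#   loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=flexible_collate)
def _example_flexible_collate():
    same = [torch.zeros(3, 4), torch.ones(3, 4)]
    ragged = [torch.zeros(3, 4), torch.ones(3, 5)]
    stacked = flexible_collate(same)     # tensor of shape (2, 3, 4)
    fallback = flexible_collate(ragged)  # stacking fails, the list is returned as-is
    return stacked, fallback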