import collections import os from os.path import join import io import matplotlib.pyplot as plt import numpy as np import torch.multiprocessing import torch.nn as nn import torch.nn.functional as F import wget import datetime from dateutil.relativedelta import relativedelta from PIL import Image from scipy.optimize import linear_sum_assignment from torch._six import string_classes from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format from torchmetrics import Metric from torchvision import models from torchvision import transforms as T from torch.utils.tensorboard.summary import hparams import matplotlib as mpl from PIL import Image import matplotlib as mpl import torch.multiprocessing import torchvision.transforms as T import plotly.graph_objects as go import plotly.express as px import numpy as np from plotly.subplots import make_subplots import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey") class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background') mapping_class = { "Buildings": 1, "Cultivation": 2, "Natural green": 3, "Wetland": 4, "Water": 5, "Infrastructure": 6, "Background": 0, } score_attribution = { "Buildings" : 0., "Cultivation": 0.3, "Natural green": 1., "Wetland": 0.9, "Water": 0.9, "Infrastructure": 0., "Background": 0. } bounds = list(np.arange(len(mapping_class.keys()) + 1) + 1) cmap = mpl.colors.ListedColormap(colors) norm = mpl.colors.BoundaryNorm(bounds, cmap.N) def compute_biodiv_score(class_image): """Compute the biodiversity score of an image Args: image (_type_): _description_ Returns: biodiversity_score: the biodiversity score associated to the landscape of the image """ score_matrice = class_image.copy().astype(int) for key in mapping_class.keys(): score_matrice = np.where(score_matrice==mapping_class[key], score_attribution[key], score_matrice) number_of_pixel = np.prod(list(score_matrice.shape)) score = np.sum(score_matrice)/number_of_pixel score_details = { key: np.sum(np.where(class_image == mapping_class[key], 1, 0)) for key in mapping_class.keys() if key not in ["background"] } return score, score_details def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) : scores = [0.89, 0.70, 0.3, 0.2] # fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True) # fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True) # # Scores # scatters = [go.Scatter( # x=months[:i+1], # y=scores[:i+1], # mode="lines+markers+text", # marker_color="black", # text = [f"{score:.4f}" for score in scores[:i+1]], # textposition="top center", # ) for i in range(len(scores))] # fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1) # fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2) # fig.add_trace(go.Pie(labels = class_names, # values = [nb_values[0][key] for key in mapping_class.keys()], # marker_colors = colors, # name="Segment repartition", # textposition='inside', # texttemplate = "%{percent:.0%}", # textfont_size=14 # ), # row=1, col=3) # fig.add_trace(scatters[0], row=1, col=4) # # fig.update_traces(selector=dict(type='scatter')) # number_frames = len(imgs) # frames = [dict( # name = k, # data = [ fig2["frames"][k]["data"][0], # fig3["frames"][k]["data"][0], # go.Pie(labels = class_names, # values = [nb_values[k][key] for key in mapping_class.keys()], # marker_colors = colors, # name="Segment repartition", # textposition='inside', # texttemplate = "%{percent:.0%}", # textfont_size=14 # ), # scatters[k] # ], # traces=[0, 1, 2, 3] # ) for k in range(number_frames)] # updatemenus = [dict(type='buttons', # buttons=[dict( # label='Play', # method='animate', # args=[ # [f'{k}' for k in range(number_frames)], # dict( # frame=dict(duration=500, redraw=False), # transition=dict(duration=0), # # easing='linear', # # fromcurrent=True, # # mode='immediate' # ) # ]) # ], # direction= 'left', # pad=dict(r= 10, t=85), # showactive=True, x= 0.1, y= 0.1, xanchor= 'right', yanchor= 'bottom') # ] # sliders = [{'yanchor': 'top', # 'xanchor': 'left', # 'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'}, # 'transition': {'duration': 500.0, 'easing': 'linear'}, # 'pad': {'b': 10, 't': 50}, # 'len': 0.9, 'x': 0.1, 'y': 0, # 'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False}, # 'transition': {'duration': 0, 'easing': 'linear'}}], # 'label': months[k], 'method': 'animate'} for k in range(number_frames) # ]}] # fig.update(frames=frames, # layout={ # "xaxis1": { # "autorange":True, # 'showgrid': False, # 'zeroline': False, # thick line at x=0 # 'visible': False, # numbers below # }, # "yaxis1": { # "autorange":True, # 'showgrid': False, # 'zeroline': False, # 'visible': False,}, # "xaxis2": { # "autorange":True, # 'showgrid': False, # 'zeroline': False, # 'visible': False, # }, # "yaxis2": { # "autorange":True, # 'showgrid': False, # 'zeroline': False, # 'visible': False,}, # "xaxis4": { # "ticktext": months, # "tickvals": months, # "tickangle": 90, # }, # "yaxis4": { # 'range': [min(scores)*0.9, max(scores)* 1.1], # 'showgrid': False, # 'zeroline': False, # 'visible': True # }, # }) # fig.update_layout( # updatemenus=updatemenus, # sliders=sliders, # # legend=dict( # # yanchor= 'bottom', # # xanchor= 'center', # # orientation="h"), # ) # Scores fig = make_subplots( rows=1, cols=4, specs=[[{"type": "image"},{"type": "image"}, {"type": "pie"}, {"type": "scatter"}]], subplot_titles=("Localisation visualization", "Labeled visualisation", "Segments repartition", "Biodiversity scores") ) fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True) fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True) pie_charts = [go.Pie(labels = class_names, values = [nb_values[k][key] for key in mapping_class.keys()], marker_colors = colors, name="Segment repartition", textposition='inside', texttemplate = "%{percent:.0%}", textfont_size=14, ) for k in range(len(scores))] scatters = [go.Scatter( x=months[:i+1], y=scores[:i+1], mode="lines+markers+text", marker_color="black", text = [f"{score:.4f}" for score in scores[:i+1]], textposition="top center", ) for i in range(len(scores))] fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1) fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2) fig.add_trace(pie_charts[0], row=1, col=3) fig.add_trace(scatters[0], row=1, col=4) start_date = datetime.datetime.strptime(months[0], "%Y-%m-%d") - relativedelta(months=1) end_date = datetime.datetime.strptime(months[-1], "%Y-%m-%d") + relativedelta(months=1) interval = [start_date.strftime("%Y-%m-%d"),end_date.strftime("%Y-%m-%d")] fig.update_layout({ "xaxis": { "autorange":True, 'showgrid': False, 'zeroline': False, # thick line at x=0 'visible': False, # numbers below }, "yaxis": { "autorange":True, 'showgrid': False, 'zeroline': False, 'visible': False,}, "xaxis1": { "range":[0,imgs[0].shape[1]], 'showgrid': False, 'zeroline': False, 'visible': False, }, "yaxis1": { "range":[imgs[0].shape[0],0], 'showgrid': False, 'zeroline': False, 'visible': False,}, "xaxis3": { "dtick":"M3", "range":interval }, "yaxis3": { 'range': [min(scores)*0.9, max(scores)* 1.1], 'showgrid': False, 'zeroline': False, 'visible': True }} ) frames = [dict( name = k, data = [ fig2["frames"][k]["data"][0], fig3["frames"][k]["data"][0], pie_charts[k], scatters[k] ], traces=[0,1,2,3] ) for k in range(len(scores))] updatemenus = [dict(type='buttons', buttons=[dict(label='Play', method='animate', args=[ [f'{k}' for k in range(len(scores))], dict( frame=dict(duration=500, redraw=False), transition=dict(duration=0), # easing='linear', # fromcurrent=True, # mode='immediate' ) ] )], direction= 'left', pad=dict(r= 10, t=85), showactive =True, x= 0.1, y= 0, xanchor= 'right', yanchor= 'top') ] sliders = [{'yanchor': 'top', 'xanchor': 'left', 'currentvalue': { 'font': {'size': 16}, 'visible': True, 'xanchor': 'right'}, 'transition': { 'duration': 500.0, 'easing': 'linear'}, 'pad': {'b': 10, 't': 50}, 'len': 0.9, 'x': 0.1, 'y': 0, 'steps': [{'args': [None, {'frame': {'duration': 500.0,'redraw': False}, 'transition': {'duration': 0}}], 'label': k, 'method': 'animate'} for k in range(len(scores)) ] }] fig.update_layout(updatemenus=updatemenus, sliders=sliders, ) fig.update(frames=frames) return fig def transform_to_pil(output, alpha=0.3): # Transform img with torch img = torch.moveaxis(prep_for_plot(output['img']),-1,0) img=T.ToPILImage()(img) cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)]) labels = np.array(output['linear_preds'])-1 label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8)) # Overlay labels with img wit alpha background = img.convert("RGBA") overlay = label.convert("RGBA") labeled_img = Image.blend(background, overlay, alpha) return img, label, labeled_img def prep_for_plot(img, rescale=True, resize=None): if resize is not None: img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear") else: img = img.unsqueeze(0) plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0) if rescale: plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min()) return plot_img def add_plot(writer, name, step): buf = io.BytesIO() plt.savefig(buf, format='jpeg', dpi=100) buf.seek(0) image = Image.open(buf) image = T.ToTensor()(image) writer.add_image(name, image, step) plt.clf() plt.close() @torch.jit.script def shuffle(x): return x[torch.randperm(x.shape[0])] def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step): exp, ssi, sei = hparams(hparam_dict, metric_dict) writer.file_writer.add_summary(exp) writer.file_writer.add_summary(ssi) writer.file_writer.add_summary(sei) for k, v in metric_dict.items(): writer.add_scalar(k, v, global_step) @torch.jit.script def resize(classes: torch.Tensor, size: int): return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False) def one_hot_feats(labels, n_classes): return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32) def load_model(model_type, data_dir): if model_type == "robust_resnet50": model = models.resnet50(pretrained=False) model_file = join(data_dir, 'imagenet_l2_3_0.pt') if not os.path.exists(model_file): wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt", model_file) model_weights = torch.load(model_file) model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if 'model' in name} model.load_state_dict(model_weights_modified) model = nn.Sequential(*list(model.children())[:-1]) elif model_type == "densecl": model = models.resnet50(pretrained=False) model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth') if not os.path.exists(model_file): wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download", model_file) model_weights = torch.load(model_file) # model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if # 'model' in name} model.load_state_dict(model_weights['state_dict'], strict=False) model = nn.Sequential(*list(model.children())[:-1]) elif model_type == "resnet50": model = models.resnet50(pretrained=True) model = nn.Sequential(*list(model.children())[:-1]) elif model_type == "mocov2": model = models.resnet50(pretrained=False) model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar') if not os.path.exists(model_file): wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/" "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file) checkpoint = torch.load(model_file) # rename moco pre-trained keys state_dict = checkpoint['state_dict'] for k in list(state_dict.keys()): # retain only encoder_q up to before the embedding layer if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): # remove prefix state_dict[k[len("module.encoder_q."):]] = state_dict[k] # delete renamed or unused k del state_dict[k] msg = model.load_state_dict(state_dict, strict=False) assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} model = nn.Sequential(*list(model.children())[:-1]) elif model_type == "densenet121": model = models.densenet121(pretrained=True) model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))]) elif model_type == "vgg11": model = models.vgg11(pretrained=True) model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))]) else: raise ValueError("No model: {} found".format(model_type)) model.eval() model.cuda() return model class UnNormalize(object): def __init__(self, mean, std): self.mean = mean self.std = std def __call__(self, image): image2 = torch.clone(image) for t, m, s in zip(image2, self.mean, self.std): t.mul_(s).add_(m) return image2 normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) class ToTargetTensor(object): def __call__(self, target): return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0) def prep_args(): import sys old_args = sys.argv new_args = [old_args.pop(0)] while len(old_args) > 0: arg = old_args.pop(0) if len(arg.split("=")) == 2: new_args.append(arg) elif arg.startswith("--"): new_args.append(arg[2:] + "=" + old_args.pop(0)) else: raise ValueError("Unexpected arg style {}".format(arg)) sys.argv = new_args def get_transform(res, is_label, crop_type): if crop_type == "center": cropper = T.CenterCrop(res) elif crop_type == "random": cropper = T.RandomCrop(res) elif crop_type is None: cropper = T.Lambda(lambda x: x) res = (res, res) else: raise ValueError("Unknown Cropper {}".format(crop_type)) if is_label: return T.Compose([T.Resize(res, Image.NEAREST), cropper, ToTargetTensor()]) else: return T.Compose([T.Resize(res, Image.NEAREST), cropper, T.ToTensor(), normalize]) def _remove_axes(ax): ax.xaxis.set_major_formatter(plt.NullFormatter()) ax.yaxis.set_major_formatter(plt.NullFormatter()) ax.set_xticks([]) ax.set_yticks([]) def remove_axes(axes): if len(axes.shape) == 2: for ax1 in axes: for ax in ax1: _remove_axes(ax) else: for ax in axes: _remove_axes(ax) class UnsupervisedMetrics(Metric): def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool, dist_sync_on_step=True): # call `self.add_state`for every internal state that is needed for the metrics computations # dist_reduce_fx indicates the function that should be used to reduce # state from multiple processes super().__init__(dist_sync_on_step=dist_sync_on_step) self.n_classes = n_classes self.extra_clusters = extra_clusters self.compute_hungarian = compute_hungarian self.prefix = prefix self.add_state("stats", default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor): with torch.no_grad(): actual = target.reshape(-1) preds = preds.reshape(-1) mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes) actual = actual[mask] preds = preds[mask] self.stats += torch.bincount( (self.n_classes + self.extra_clusters) * actual + preds, minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \ .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device) def map_clusters(self, clusters): if self.extra_clusters == 0: return torch.tensor(self.assignments[1])[clusters] else: missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))) cluster_to_class = self.assignments[1] for missing_entry in missing: if missing_entry == cluster_to_class.shape[0]: cluster_to_class = np.append(cluster_to_class, -1) else: cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1) cluster_to_class = torch.tensor(cluster_to_class) return cluster_to_class[clusters] def compute(self): if self.compute_hungarian: self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True) # print(self.assignments) if self.extra_clusters == 0: self.histogram = self.stats[np.argsort(self.assignments[1]), :] if self.extra_clusters > 0: self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True) histogram = self.stats[self.assignments_t[1], :] missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])) new_row = self.stats[missing, :].sum(0, keepdim=True) histogram = torch.cat([histogram, new_row], axis=0) new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device) self.histogram = torch.cat([histogram, new_col], axis=1) else: self.assignments = (torch.arange(self.n_classes).unsqueeze(1), torch.arange(self.n_classes).unsqueeze(1)) self.histogram = self.stats tp = torch.diag(self.histogram) fp = torch.sum(self.histogram, dim=0) - tp fn = torch.sum(self.histogram, dim=1) - tp iou = tp / (tp + fp + fn) prc = tp / (tp + fn) opc = torch.sum(tp) / torch.sum(self.histogram) metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(), self.prefix + "Accuracy": opc.item()} return {k: 100 * v for k, v in metric_dict.items()} def flexible_collate(batch): r"""Puts each data field into a tensor with outer dimension batch size""" elem = batch[0] elem_type = type(elem) if isinstance(elem, torch.Tensor): out = None if torch.utils.data.get_worker_info() is not None: # If we're in a background process, concatenate directly into a # shared memory tensor to avoid an extra copy numel = sum([x.numel() for x in batch]) storage = elem.storage()._new_shared(numel) out = elem.new(storage) try: return torch.stack(batch, 0, out=out) except RuntimeError: return batch elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ and elem_type.__name__ != 'string_': if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': # array of string classes and object if np_str_obj_array_pattern.search(elem.dtype.str) is not None: raise TypeError(default_collate_err_msg_format.format(elem.dtype)) return flexible_collate([torch.as_tensor(b) for b in batch]) elif elem.shape == (): # scalars return torch.as_tensor(batch) elif isinstance(elem, float): return torch.tensor(batch, dtype=torch.float64) elif isinstance(elem, int): return torch.tensor(batch) elif isinstance(elem, string_classes): return batch elif isinstance(elem, collections.abc.Mapping): return {key: flexible_collate([d[key] for d in batch]) for key in elem} elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple return elem_type(*(flexible_collate(samples) for samples in zip(*batch))) elif isinstance(elem, collections.abc.Sequence): # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) if not all(len(elem) == elem_size for elem in it): raise RuntimeError('each element in list of batch should be of equal size') transposed = zip(*batch) return [flexible_collate(samples) for samples in transposed] raise TypeError(default_collate_err_msg_format.format(elem_type)) if __name__ == "__main__": fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)