Spaces:
Runtime error
Runtime error
import collections | |
import os | |
from os.path import join | |
import io | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch.multiprocessing | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import wget | |
from PIL import Image | |
from scipy.optimize import linear_sum_assignment | |
from torch._six import string_classes | |
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format | |
from torchmetrics import Metric | |
from torchvision import models | |
from torchvision import transforms as T | |
from torch.utils.tensorboard.summary import hparams | |
import matplotlib as mpl | |
# Share tensors between processes using torch's "file_system" strategy.
torch.multiprocessing.set_sharing_strategy("file_system")

# Land-cover class names; presumably paired by position with `colors` below.
class_names = (
    "Buildings",
    "Cultivation",
    "Natural green",
    "Wetland",
    "Water",
    "Infrastructure",
    "Background",
)
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")

# Bin edges 1, 2, ..., len(class_names) + 1: BoundaryNorm sends a value v in
# [k, k + 1) to colour bin k - 1, so integer labels 1..len(class_names) pick one colour each.
bounds = list(np.arange(1, len(class_names) + 2))
cmap = mpl.colors.ListedColormap(colors)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
def compute_biodiv_score(image):
    """Return the biodiversity score of an image.

    The score is currently the arithmetic mean of every value produced by
    ``image.getdata()`` (all pixels and all bands pooled together).

    Args:
        image: a PIL-style image object exposing ``getdata()``.

    Returns:
        A numpy floating-point scalar: the mean pixel value used as the score.
    """
    pixel_values = np.array(image.getdata())
    return pixel_values.mean()
import cv2 | |
def create_video(array_images, output_path="output.mp4", fps=2):
    """Encode a sequence of frames into a video file with OpenCV.

    Args:
        array_images: non-empty sequence of frames, each an array of shape
            (height, width, channels); every frame must match the first frame's size.
            Presumably uint8 BGR as cv2.VideoWriter expects -- TODO confirm with callers.
        output_path: path of the video file to write. Defaults to "output.mp4".
        fps: frames per second of the output video. Defaults to 2.

    Returns:
        The cv2.VideoWriter that was used (already released), kept for backward
        compatibility with callers that used the return value.

    Raises:
        ValueError: if array_images is empty.
    """
    if len(array_images) == 0:
        raise ValueError("create_video needs at least one frame")
    height, width, _channels = array_images[0].shape
    size = (width, height)  # OpenCV expects (width, height)
    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    # Honour output_path (it used to be ignored in favour of a hard-coded "output.mp4").
    out = cv2.VideoWriter(output_path, fourcc, fps, size)
    try:
        for frame in array_images:
            out.write(frame)
    finally:
        # Always finalize the file, even if writing a frame raises.
        out.release()
    return out
def transform_to_pil(outputs, alpha=0.3):
    """Turn a model output into PIL images for display.

    Args:
        outputs: mapping with an "img" entry (batch of normalized image tensors) and a
            "linear_preds" entry (batch of predicted label maps); only item 0 of each
            batch is used.
        alpha (float, optional): weight of the label overlay given to Image.blend
            (0 = image only, 1 = labels only). Defaults to 0.3.

    Returns:
        tuple: (img, image_label, labeled_img) -- the un-normalized input image, the
        colour-mapped label image read back from "output/label.png", and the RGB blend
        of the two.
    """
    # prep_for_plot returns (H, W, C); ToPILImage wants (C, H, W).
    img = torch.moveaxis(prep_for_plot(outputs["img"][0]), -1, 0)
    img = T.ToPILImage()(img)
    # detach()/cpu() so predictions that live on the GPU or require grad convert too.
    label = outputs["linear_preds"][0].detach().cpu().numpy()
    # Colour the label map with the module colormap by saving it, then reading it back.
    # NOTE(review): imsave scales the colormap to this label map's own min/max, so a
    # given class may get different colours across images -- consider vmin/vmax.
    os.makedirs("output", exist_ok=True)  # imsave fails if the directory is missing
    plt.imsave("output/label.png", label, cmap=cmap)
    image_label = Image.open("output/label.png")
    # Overlay the labels on the image with weight alpha.
    background = img.convert("RGBA")
    overlay = image_label.convert("RGBA")
    labeled_img = Image.blend(background, overlay, alpha)
    labeled_img = labeled_img.convert("RGB")
    return img, image_label, labeled_img
def prep_for_plot(img, rescale=True, resize=None):
    """Undo input normalization on one image tensor and lay it out for plotting.

    Args:
        img: a single image tensor indexed (C, H, W), presumably normalized with the
            module's ``normalize`` transform.
        rescale: if True, min-max rescale the result to [0, 1]. A constant image,
            where min == max, becomes all zeros instead of NaN.
        resize: optional target size forwarded to F.interpolate (bilinear).

    Returns:
        A CPU tensor of shape (H, W, C).
    """
    if resize is not None:
        img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear",
                            align_corners=False).squeeze(0)
    # Un-normalize the 3-D (C, H, W) tensor so the per-channel statistics line up
    # with the channel axis (a 4-D batch would pair them with the batch axis instead).
    plot_img = unnorm(img).cpu().permute(1, 2, 0)
    if rescale:
        lo, hi = plot_img.min(), plot_img.max()
        if hi > lo:
            plot_img = (plot_img - lo) / (hi - lo)
        else:
            # Constant image: avoid 0/0 producing NaN everywhere.
            plot_img = torch.zeros_like(plot_img)
    return plot_img
def add_plot(writer, name, step):
    """Rasterize the current matplotlib figure and log it to ``writer`` as an image.

    The figure is rendered to an in-memory JPEG (100 dpi), decoded with PIL,
    converted to a tensor, and logged under ``name`` at ``step``. The current
    figure is then cleared and closed.
    """
    jpeg_buffer = io.BytesIO()
    plt.savefig(jpeg_buffer, format='jpeg', dpi=100)
    jpeg_buffer.seek(0)
    as_tensor = T.ToTensor()(Image.open(jpeg_buffer))
    writer.add_image(name, as_tensor, step)
    plt.clf()
    plt.close()
def shuffle(x):
    """Return a copy of ``x`` with its entries along dim 0 in random order."""
    order = torch.randperm(x.shape[0])
    return x[order]
def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
    """Log hyper-parameters and metrics to TensorBoard through ``writer``.

    The three summaries produced by ``hparams`` are written directly to
    ``writer.file_writer``. Presumably this avoids the separate sub-run that
    ``SummaryWriter.add_hparams`` creates. Each metric is then logged as a
    scalar at ``global_step``.
    """
    experiment, session_start, session_end = hparams(hparam_dict, metric_dict)
    for summary in (experiment, session_start, session_end):
        writer.file_writer.add_summary(summary)
    for metric_name, metric_value in metric_dict.items():
        writer.add_scalar(metric_name, metric_value, global_step)
def resize(classes: torch.Tensor, size: int):
    """Bilinearly resample a 4-D (N, C, H, W) tensor to spatial size ``size x size``."""
    target_hw = (size, size)
    return F.interpolate(classes, target_hw, mode="bilinear", align_corners=False)
def one_hot_feats(labels, n_classes):
    """One-hot encode a (N, H, W) integer label map into float32 (N, n_classes, H, W) planes."""
    encoded = F.one_hot(labels, n_classes)  # (N, H, W, n_classes)
    channels_first = encoded.permute(0, 3, 1, 2)
    return channels_first.to(torch.float32)
def _download_if_missing(data_dir, filename, url):
    """Return join(data_dir, filename), downloading it from url first if it is absent."""
    model_file = join(data_dir, filename)
    if not os.path.exists(model_file):
        wget.download(url, model_file)
    return model_file


def _drop_last_child(model, *extra):
    """Return nn.Sequential of model's children without the last one, plus any extra modules."""
    return nn.Sequential(*list(model.children())[:-1], *extra)


def load_model(model_type, data_dir):
    """Build a backbone feature extractor by name, in eval mode.

    Checkpoints that are not bundled with torchvision are cached in ``data_dir`` and
    downloaded on first use. The classification head (last child module) is removed.

    Args:
        model_type: one of "robust_resnet50", "densecl", "resnet50", "mocov2",
            "densenet121", "vgg11".
        data_dir: directory used to cache downloaded checkpoints.

    Returns:
        nn.Sequential in eval mode, moved to the GPU when CUDA is available
        (it stays on the CPU otherwise instead of crashing).

    Raises:
        ValueError: if model_type is not recognised.
        RuntimeError: if the MoCo v2 checkpoint does not have the expected keys.
    """
    # NOTE(review): torch.load unpickles the checkpoint; only load files from trusted
    # sources (the robust_resnet50 URL below is plain http).
    if model_type == "robust_resnet50":
        model = models.resnet50(pretrained=False)
        model_file = _download_if_missing(
            data_dir, 'imagenet_l2_3_0.pt',
            "http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt")
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts.
        model_weights = torch.load(model_file, map_location="cpu")
        model_weights_modified = {name.split('model.')[1]: value
                                  for name, value in model_weights['model'].items()
                                  if 'model' in name}
        model.load_state_dict(model_weights_modified)
        model = _drop_last_child(model)
    elif model_type == "densecl":
        model = models.resnet50(pretrained=False)
        model_file = _download_if_missing(
            data_dir, 'densecl_r50_coco_1600ep.pth',
            "https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download")
        model_weights = torch.load(model_file, map_location="cpu")
        model.load_state_dict(model_weights['state_dict'], strict=False)
        model = _drop_last_child(model)
    elif model_type == "resnet50":
        model = _drop_last_child(models.resnet50(pretrained=True))
    elif model_type == "mocov2":
        model = models.resnet50(pretrained=False)
        model_file = _download_if_missing(
            data_dir, 'moco_v2_800ep_pretrain.pth.tar',
            "https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
            "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar")
        checkpoint = torch.load(model_file, map_location="cpu")
        # Keep only the query encoder up to (not including) its fc embedding head,
        # with the "module.encoder_q." prefix removed.
        prefix = "module.encoder_q."
        state_dict = {k[len(prefix):]: v for k, v in checkpoint['state_dict'].items()
                      if k.startswith(prefix) and not k.startswith(prefix + "fc")}
        msg = model.load_state_dict(state_dict, strict=False)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if set(msg.missing_keys) != {"fc.weight", "fc.bias"}:
            raise RuntimeError(
                "Unexpected missing keys in MoCo v2 checkpoint: {}".format(msg.missing_keys))
        model = _drop_last_child(model)
    elif model_type == "densenet121":
        model = _drop_last_child(models.densenet121(pretrained=True),
                                 nn.AdaptiveAvgPool2d((1, 1)))
    elif model_type == "vgg11":
        model = _drop_last_child(models.vgg11(pretrained=True),
                                 nn.AdaptiveAvgPool2d((1, 1)))
    else:
        raise ValueError("No model: {} found".format(model_type))
    model.eval()
    if torch.cuda.is_available():
        model.cuda()
    return model
class UnNormalize(object):
    """Invert a per-channel ``(x - mean) / std`` normalization: returns ``x * std + mean``."""

    def __init__(self, mean, std):
        # Per-channel statistics, listed in channel order.
        self.mean = mean
        self.std = std

    def __call__(self, image):
        """Return an un-normalized copy of ``image``; the input is not modified.

        For tensors with at least 3 dims the channel axis is dim -3, so both
        (C, H, W) and batched (N, C, H, W) inputs get per-channel statistics.
        Channels beyond len(mean)/len(std) are left untouched.
        """
        image2 = torch.clone(image)
        if image2.dim() < 3:
            # Fewer than 3 dims: keep the historical behaviour of pairing dim 0
            # with the statistics.
            for t, m, s in zip(image2, self.mean, self.std):
                t.mul_(s).add_(m)
            return image2
        n = min(image2.shape[-3], len(self.mean), len(self.std))
        # Match the image's float precision (e.g. keep float64 exact).
        dtype = image2.dtype if image2.is_floating_point() else None
        std = torch.as_tensor(self.std[:n], dtype=dtype, device=image2.device).reshape(-1, 1, 1)
        mean = torch.as_tensor(self.mean[:n], dtype=dtype, device=image2.device).reshape(-1, 1, 1)
        # A basic slice is a view, so the in-place ops update image2.
        channels = image2[..., :n, :, :]
        channels.mul_(std).add_(mean)
        return image2
# The widely used ImageNet per-channel mean and standard deviation (RGB order).
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
# Forward normalization for model inputs, and its inverse for visualisation.
# Each object receives its own fresh list, as before.
normalize = T.Normalize(list(_IMAGENET_MEAN), list(_IMAGENET_STD))
unnorm = UnNormalize(list(_IMAGENET_MEAN), list(_IMAGENET_STD))
class ToTargetTensor(object):
    """Convert a label map (anything np.array accepts) into an int64 tensor
    with an extra leading dimension of size 1."""

    def __call__(self, target):
        label_array = np.array(target)
        label_tensor = torch.as_tensor(label_array, dtype=torch.int64)
        return label_tensor[None]
def prep_args():
    """Rewrite ``sys.argv`` so ``--key value`` pairs become ``key=value`` overrides.

    Arguments that are already ``key=value`` (exactly one "=") are kept as-is.
    ``sys.argv`` is replaced by a new list. The original list object is not mutated,
    and on error ``sys.argv`` is left unchanged.

    Raises:
        ValueError: for an argument that is neither ``key=value`` nor ``--key``,
            or for a trailing ``--key`` with no value after it.
    """
    import sys
    args = iter(list(sys.argv))
    program = next(args, None)
    if program is None:
        return  # nothing to rewrite
    new_args = [program]
    for arg in args:
        if len(arg.split("=")) == 2:
            new_args.append(arg)
        elif arg.startswith("--"):
            value = next(args, None)
            if value is None:
                raise ValueError("Missing value for arg {}".format(arg))
            new_args.append(arg[2:] + "=" + value)
        else:
            raise ValueError("Unexpected arg style {}".format(arg))
    sys.argv = new_args
def get_transform(res, is_label, crop_type):
    """Build the torchvision preprocessing pipeline for images or label maps.

    Args:
        res: target resolution (int).
        is_label: if True, produce an int64 label tensor (ToTargetTensor); otherwise
            a float image tensor normalized with ``normalize``.
        crop_type: "center" or "random". With either, the image is resized with
            ``T.Resize(res)`` and then cropped to ``res``. With None, the image is
            resized directly to (res, res) with no crop.

    Returns:
        T.Compose pipeline.

    Raises:
        ValueError: unknown crop_type.
    """
    if crop_type == "center":
        cropper = T.CenterCrop(res)
    elif crop_type == "random":
        cropper = T.RandomCrop(res)
    elif crop_type is None:
        cropper = T.Lambda(lambda x: x)
        res = (res, res)
    else:
        raise ValueError("Unknown Cropper {}".format(crop_type))
    # Use torchvision's InterpolationMode; passing PIL integer constants is deprecated.
    # Nearest-neighbour is used for both images and labels, as before.
    resizer = T.Resize(res, T.InterpolationMode.NEAREST)
    if is_label:
        return T.Compose([resizer, cropper, ToTargetTensor()])
    return T.Compose([resizer, cropper, T.ToTensor(), normalize])
def _remove_axes(ax):
    """Hide tick labels and tick marks on a single matplotlib Axes."""
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])
def remove_axes(axes):
    """Apply ``_remove_axes`` to every Axes in a 1-D or 2-D array-like of Axes."""
    if len(axes.shape) == 2:
        flat_axes = [ax for row in axes for ax in row]
    else:
        flat_axes = axes
    for ax in flat_axes:
        _remove_axes(ax)
class UnsupervisedMetrics(Metric):
    """Segmentation metrics (mIoU, pixel accuracy) for predicted clusters vs. ground truth.

    Accumulates a confusion matrix ``stats`` of shape
    (n_classes + extra_clusters, n_classes): rows are predicted cluster ids and columns
    are ground-truth class ids. With ``compute_hungarian`` set, clusters are matched to
    classes with the Hungarian algorithm before scoring.
    """

    def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
                 dist_sync_on_step=True):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.n_classes = n_classes
        self.extra_clusters = extra_clusters
        self.compute_hungarian = compute_hungarian
        # Prepended to every key returned by compute().
        self.prefix = prefix
        # dist_reduce_fx="sum": confusion counts from different processes are added.
        self.add_state("stats",
                       default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
                       dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Add one batch of (prediction, target) pixels to the confusion matrix.

        Pixels whose target is outside [0, n_classes), or whose prediction is outside
        [0, n_classes + extra_clusters), are ignored.
        """
        with torch.no_grad():
            n_clusters = self.n_classes + self.extra_clusters
            actual = target.reshape(-1)
            preds = preds.reshape(-1)
            # Predictions may use the extra cluster ids, so bound them by n_clusters.
            # Bounding them by n_classes would drop every extra-cluster pixel.
            mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < n_clusters)
            actual = actual[mask]
            preds = preds[mask]
            # Flat index actual * n_clusters + pred counts each (class, cluster) pair.
            # Reshape to (n_classes, n_clusters), then transpose to match stats.
            counts = torch.bincount(n_clusters * actual + preds,
                                    minlength=self.n_classes * n_clusters)
            self.stats += counts.reshape(self.n_classes, n_clusters).t().to(self.stats.device)

    def map_clusters(self, clusters):
        """Translate cluster ids into class ids using the assignment made by compute().

        When extra_clusters > 0, clusters left unmatched by the assignment map to -1.
        compute() must have been called first (this reads self.assignments).
        """
        if self.extra_clusters == 0:
            return torch.tensor(self.assignments[1])[clusters]
        # Lookup table indexed by cluster id: matched clusters get their class and
        # unmatched clusters get -1.
        cluster_to_class = np.full(self.n_classes + self.extra_clusters, -1, dtype=np.int64)
        cluster_to_class[np.asarray(self.assignments[0])] = np.asarray(self.assignments[1])
        return torch.tensor(cluster_to_class)[clusters]

    def compute(self):
        """Return {prefix + "mIoU": ..., prefix + "Accuracy": ...} as percentages.

        Also stores self.assignments and self.histogram, plus self.assignments_t when
        extra_clusters > 0 and compute_hungarian is set. NaN IoUs (classes with no
        tp/fp/fn) are excluded from the mean.
        """
        if self.compute_hungarian:
            # Rows (clusters) -> columns (classes), maximizing matched pixel counts.
            self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
            if self.extra_clusters == 0:
                # Reorder the cluster rows so that row i holds the cluster matched to class i.
                self.histogram = self.stats[np.argsort(self.assignments[1]), :]
            if self.extra_clusters > 0:
                # Match each class to one cluster; its col_ind gives clusters in class order.
                self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
                histogram = self.stats[self.assignments_t[1], :]
                # All clusters not used above are pooled into one extra "other" row.
                # Derive them from the same assignment so no cluster is counted twice.
                missing = sorted(set(range(self.n_classes + self.extra_clusters))
                                 - set(self.assignments_t[1].tolist()))
                new_row = self.stats[missing, :].sum(0, keepdim=True)
                histogram = torch.cat([histogram, new_row], dim=0)
                # Matching zero column so the histogram is square; same dtype as the counts.
                new_col = torch.zeros(self.n_classes + 1, 1, dtype=histogram.dtype,
                                      device=histogram.device)
                self.histogram = torch.cat([histogram, new_col], dim=1)
        else:
            # NOTE(review): with extra_clusters > 0 this path leaves a non-square histogram,
            # and the fn computation below will fail on mismatched shapes.
            self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
                                torch.arange(self.n_classes).unsqueeze(1))
            self.histogram = self.stats

        tp = torch.diag(self.histogram)
        fp = torch.sum(self.histogram, dim=0) - tp
        fn = torch.sum(self.histogram, dim=1) - tp
        iou = tp / (tp + fp + fn)
        opc = torch.sum(tp) / torch.sum(self.histogram)
        metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
                       self.prefix + "Accuracy": opc.item()}
        return {k: 100 * v for k, v in metric_dict.items()}
def flexible_collate(batch):
    r"""Collate a list of samples into a batch, like torch's default_collate but lenient.

    Unlike a strict collate, tensors that cannot be stacked (e.g. different shapes) are
    returned as the original list instead of raising.

    Args:
        batch: non-empty list of samples; every sample is assumed to share the
            structure of ``batch[0]``.

    Returns:
        A stacked tensor, a tensor of numbers, the list of strings, or a
        dict / namedtuple / list collated recursively, depending on ``type(batch[0])``.

    Raises:
        TypeError: for unsupported element types or numpy arrays of strings/objects.
        RuntimeError: when sequence samples have different lengths.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        try:
            return torch.stack(batch, 0, out=out)
        except RuntimeError:
            # Ragged tensors cannot be stacked: hand them back as a plain list.
            return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return flexible_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        # (str, bytes) is what torch._six.string_classes was on Python 3; torch._six
        # was removed in torch 2.0.
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: flexible_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(other) == elem_size for other in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [flexible_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))