|
import logging
|
|
from collections import OrderedDict
|
|
from pathlib import Path
|
|
from typing import Union, List
|
|
|
|
import torch
|
|
import torchvision
|
|
|
|
|
|
def check_is_valid_torchvision_architecture(architecture: str):
    """Raise a ValueError if *architecture* is not an available torchvision model.

    A valid architecture is any lowercase, non-dunder, callable attribute of
    ``torchvision.models`` (i.e. a model factory function).
    """
    namespace = torchvision.models.__dict__
    available = []
    for attr_name, attr_value in namespace.items():
        if attr_name.islower() and not attr_name.startswith("__") and callable(attr_value):
            available.append(attr_name)
    available.sort()

    if architecture not in available:
        raise ValueError(f"{architecture} not in {available}")
def build_base_model(arch: str):
    """Instantiate an ImageNet-pretrained torchvision backbone as a feature extractor.

    The architecture's classification head is stripped and replaced by global
    average pooling followed by a flatten, so the returned module maps an image
    batch to a ``(batch, nfeatures)`` feature tensor.

    Args:
        arch: torchvision model name, e.g. "resnet50", "mobilenet_v2",
            "densenet121".

    Returns:
        Tuple ``(model, nfeatures)``: the truncated backbone and the
        dimensionality of the feature vector it produces.

    Raises:
        NotImplementedError: if the architecture's head layout is not handled.
    """
    model = torchvision.models.__dict__[arch](pretrained=True)

    # Read the feature dimension from the original classification head before
    # stripping it; which attribute holds the head differs per architecture.
    if arch in ["mobilenet_v2"]:
        nfeatures = model.classifier[-1].in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif arch in ["densenet121", "densenet161", "densenet169"]:
        nfeatures = model.classifier.in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif "resne" in arch:
        # matches resnet* as well as resnext* / wide_resnet* names
        nfeatures = model.fc.in_features
        # drop the last two children (the original avgpool and fc layers)
        model = torch.nn.Sequential(*list(model.children())[:-2])
    else:
        raise NotImplementedError

    # NOTE(review): assigning module attributes on an nn.Sequential registers
    # them as submodules, so they run in forward() *after* the retained
    # children — confirm this ordering against the torch version in use.
    model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    model.flatten = torch.nn.Flatten(start_dim=1)
    return model, nfeatures
def load_weights_if_available(
    model: torch.nn.Module, classifier: torch.nn.Module, weights_path: Union[str, Path]
):
    """Load backbone and classifier weights from a checkpoint on CPU.

    The checkpoint is expected to hold a ``"state_dict"`` mapping whose keys
    are prefixed with ``"model."`` (backbone) or ``"classifier."`` (head);
    keys with any other prefix are logged and skipped.

    Fixes over the previous version:
      * ``str.replace(prefix, "")`` removed *every* occurrence of the prefix
        inside a key (e.g. ``model.model.weight``); now only the leading
        prefix is stripped.
      * ``state_dict_classifier`` was collected but never loaded, silently
        discarding the classifier weights; it is now loaded as well.

    Args:
        model: backbone module to receive the ``model.``-prefixed weights.
        classifier: head module to receive the ``classifier.``-prefixed weights.
        weights_path: path to a checkpoint readable by ``torch.load``.

    Returns:
        The ``(model, classifier)`` pair with weights loaded in place.
    """
    # map_location keeps everything on CPU regardless of where it was saved
    checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

    state_dict_features = OrderedDict()
    state_dict_classifier = OrderedDict()
    for k, w in checkpoint["state_dict"].items():
        # strip only the *leading* prefix, never interior occurrences
        if k.startswith("model."):
            state_dict_features[k[len("model."):]] = w
        elif k.startswith("classifier."):
            state_dict_classifier[k[len("classifier."):]] = w
        else:
            logging.warning(f"Unexpected prefix in state_dict: {k}")
    model.load_state_dict(state_dict_features, strict=True)
    classifier.load_state_dict(state_dict_classifier, strict=True)
    return model, classifier
def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt):
    """Haversine great-circle distance (km) between predicted and GT coordinates.

    All arguments are tensors of matching shape holding degrees; the result is
    an elementwise tensor of distances in kilometers.
    """
    earth_radius_km = 6371
    deg_to_rad = 0.01745329252

    lat_pred = deg_to_rad * latitudes
    lon_pred = deg_to_rad * longitudes
    lat_true = deg_to_rad * latitudes_gt
    lon_true = deg_to_rad * longitudes_gt

    half_dlat = (lat_true - lat_pred) / 2
    half_dlon = (lon_true - lon_pred) / 2

    # haversine term: sin^2(dlat/2) + cos(lat1) * cos(lat2) * sin^2(dlon/2)
    haversine = (
        torch.sin(half_dlat) ** 2
        + torch.cos(lat_pred) * torch.cos(lat_true) * torch.sin(half_dlon) ** 2
    )
    return earth_radius_km * 2 * torch.asin(torch.sqrt(haversine))
def gcd_threshold_eval(gc_dists, thresholds=(1, 25, 200, 750, 2500)):
    """Fraction of samples whose great-circle distance is within each threshold.

    Fix: the default thresholds were a mutable list; an immutable tuple avoids
    the shared-mutable-default pitfall while remaining backward compatible.

    Args:
        gc_dists: 1-D tensor of great-circle distances (km).
        thresholds: iterable of distance thresholds (km) to evaluate.

    Returns:
        Dict mapping each threshold to the fraction (float) of distances
        less than or equal to it.
    """
    return {
        threshold: torch.true_divide(
            torch.sum(gc_dists <= threshold), len(gc_dists)
        ).item()
        for threshold in thresholds
    }
def accuracy(output, target, partitioning_shortnames: list, topk=(1, 5, 10)):
    """Top-k accuracy per partitioning, keyed as ``acc{k}_val/{shortname}``.

    Args:
        output: per-partitioning logits; ``output[i]`` has shape
            (batch, num_classes_i).
        target: per-partitioning class indices; ``target[i]`` has shape (batch,).
        partitioning_shortnames: one short name per partitioning, used in the
            result keys.
        topk: the k values to evaluate.

    Returns:
        Dict mapping ``f"acc{k}_val/{pname}"`` to a 1-element accuracy tensor.
    """

    def _accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = {}
            for k in topk:
                # Fix: use reshape(-1) instead of view(-1); slices of the
                # transposed prediction matrix need not be contiguous, and
                # view() raises a RuntimeError on non-contiguous tensors.
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res[k] = correct_k / batch_size
            return res

    with torch.no_grad():
        out_dict = {}
        for i, pname in enumerate(partitioning_shortnames):
            res_dict = _accuracy(output[i], target[i], topk=topk)
            for k, v in res_dict.items():
                out_dict[f"acc{k}_val/{pname}"] = v

    return out_dict
def summarize_gcd_stats(pnames: List[str], outputs, hierarchy=None):
    """Aggregate per-batch great-circle distances into threshold accuracies.

    Concatenates the ``gcd_{p}_val`` distance tensors collected in *outputs*
    (one dict per batch) and reports, per metric and threshold, the fraction
    of samples within that distance, keyed ``{metric_name}/{threshold}``.
    """
    metric_names = [f"gcd_{p}_val" for p in pnames]
    if hierarchy is not None:
        metric_names.append("gcd_hierarchy_val")

    summary = {}
    for metric_name in metric_names:
        all_distances = torch.cat(
            [batch_output[metric_name] for batch_output in outputs], dim=0
        )
        threshold_accs = gcd_threshold_eval(all_distances)
        for threshold, acc in threshold_accs.items():
            summary[f"{metric_name}/{threshold}"] = acc
    return summary
def summarize_test_gcd(pnames, outputs, hierarchy=None):
    """Threshold accuracies on test great-circle distances, per test set.

    Fix: the previous version appended "hierarchy" to the *caller's* pnames
    list inside ``_eval``, mutating the argument and appending a duplicate
    entry for every test set when *outputs* holds several test sets. The
    names are now computed once on a copy.

    Args:
        pnames: partitioning names; each batch dict maps name -> distance tensor.
        outputs: either a list of batch dicts (single test set) or a list of
            such lists (multiple test sets).
        hierarchy: if not None, the "hierarchy" metric is evaluated as well.

    Returns:
        For a single test set: dict ``acc_test/{pname}`` -> threshold-accuracy
        dict. For multiple test sets: dict of test-set index -> the above.

    Raises:
        TypeError: if ``outputs[0]`` is neither a dict nor a list.
    """
    # work on a copy so the caller's list is never mutated
    eval_names = list(pnames)
    if hierarchy is not None:
        eval_names.append("hierarchy")

    def _eval(output):
        # evaluate one test set: concatenate per-batch distances, then score
        accuracy_outputs = {}
        for pname in eval_names:
            distances_flat = torch.cat([x[pname] for x in output], dim=0)
            acc_dict = gcd_threshold_eval(distances_flat)
            accuracy_outputs[f"acc_test/{pname}"] = acc_dict
        return accuracy_outputs

    if isinstance(outputs[0], dict):
        return _eval(outputs)
    elif isinstance(outputs[0], list):
        return {
            testset_index: _eval(output)
            for testset_index, output in enumerate(outputs)
        }
    else:
        raise TypeError
def summarize_loss_acc_stats(pnames: List[str], outputs, topk=[1, 5, 10]):
    """Average validation loss and top-k accuracy metrics over all batches.

    Args:
        pnames: partitioning short names used in the metric keys.
        outputs: list of per-batch dicts containing ``loss_val/total``,
            ``loss_val/{p}`` and ``acc{k}_val/{p}`` entries.
        topk: the k values whose accuracy metrics are summarized.

    Returns:
        Dict mapping each metric name to its mean over the batches.
    """
    metric_names = ["loss_val/total"]
    for k in topk:
        metric_names.extend(f"acc{k}_val/{p}" for p in pnames)
    metric_names.extend(f"loss_val/{p}" for p in pnames)

    num_batches = len(outputs)
    return {
        name: sum(batch[name] for batch in outputs) / num_batches
        for name in metric_names
    }