import logging
from collections import OrderedDict
from pathlib import Path
from typing import Union, List
import torch
import torchvision
def check_is_valid_torchvision_architecture(architecture: str):
"""Raises an ValueError if architecture is not part of available torchvision models
"""
available = sorted(
name
for name in torchvision.models.__dict__
if name.islower()
and not name.startswith("__")
and callable(torchvision.models.__dict__[name])
)
if architecture not in available:
raise ValueError(f"{architecture} not in {available}")
def build_base_model(arch: str):
    model = torchvision.models.__dict__[arch](pretrained=True)  # legacy API; torchvision >= 0.13 prefers `weights=`
# get input dimension before classification layer
if arch in ["mobilenet_v2"]:
nfeatures = model.classifier[-1].in_features
model = torch.nn.Sequential(*list(model.children())[:-1])
elif arch in ["densenet121", "densenet161", "densenet169"]:
nfeatures = model.classifier.in_features
model = torch.nn.Sequential(*list(model.children())[:-1])
elif "resne" in arch:
        # matches resnet*, resnext*, and wide_resnet* variants
nfeatures = model.fc.in_features
model = torch.nn.Sequential(*list(model.children())[:-2])
else:
        raise NotImplementedError(f"Unsupported architecture: {arch}")
    # Assigning these attributes registers them as the last submodules of the
    # Sequential, so the backbone output is globally pooled and flattened to
    # a (batch_size, nfeatures) tensor.
    model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    model.flatten = torch.nn.Flatten(start_dim=1)
return model, nfeatures
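
# Illustrative usage sketch (an addition, not part of the original code): the
# returned backbone maps an image batch to (batch_size, nfeatures) features
# that a linear head can consume; `num_classes` and `images` are hypothetical.
#
#     backbone, nfeatures = build_base_model("resnet50")
#     classifier = torch.nn.Linear(nfeatures, num_classes)
#     logits = classifier(backbone(images))  # images: (B, 3, H, W)
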
def load_weights_if_available(
model: torch.nn.Module, classifier: torch.nn.Module, weights_path: Union[str, Path]
):
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
state_dict_features = OrderedDict()
state_dict_classifier = OrderedDict()
for k, w in checkpoint["state_dict"].items():
if k.startswith("model"):
state_dict_features[k.replace("model.", "")] = w
elif k.startswith("classifier"):
state_dict_classifier[k.replace("classifier.", "")] = w
else:
logging.warning(f"Unexpected prefix in state_dict: {k}")
    model.load_state_dict(state_dict_features, strict=True)
    # also restore the classifier head if its weights are present in the checkpoint
    if len(state_dict_classifier) > 0:
        classifier.load_state_dict(state_dict_classifier, strict=True)
    return model, classifier
def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt):
    """Haversine (great-circle) distance in kilometres between coordinate tensors given in degrees."""
    R = 6371  # mean Earth radius in km
    factor_rad = 0.01745329252  # degrees -> radians (pi / 180)
longitudes = factor_rad * longitudes
longitudes_gt = factor_rad * longitudes_gt
latitudes = factor_rad * latitudes
latitudes_gt = factor_rad * latitudes_gt
delta_long = longitudes_gt - longitudes
delta_lat = latitudes_gt - latitudes
subterm0 = torch.sin(delta_lat / 2) ** 2
subterm1 = torch.cos(latitudes) * torch.cos(latitudes_gt)
subterm2 = torch.sin(delta_long / 2) ** 2
subterm1 = subterm1 * subterm2
a = subterm0 + subterm1
c = 2 * torch.asin(torch.sqrt(a))
gcd = R * c
return gcd
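
# Worked example (illustrative, not from the original code): inputs are in
# degrees, the output is in kilometres. For Paris (48.8566 N, 2.3522 E) vs.
# Berlin (52.5200 N, 13.4050 E) the haversine distance is roughly 878 km:
#
#     vectorized_gc_distance(torch.tensor([48.8566]), torch.tensor([2.3522]),
#                            torch.tensor([52.5200]), torch.tensor([13.4050]))
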
def gcd_threshold_eval(gc_dists, thresholds=(1, 25, 200, 750, 2500)):
    # fraction of samples whose great-circle distance is within each threshold (km)
results = {}
for thres in thresholds:
results[thres] = torch.true_divide(
torch.sum(gc_dists <= thres), len(gc_dists)
).item()
return results
def accuracy(output, target, partitioning_shortnames: list, topk=(1, 5, 10)):
def _accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = {}
for k in topk:
                # reshape (not view): the slice of the transposed tensor may be non-contiguous
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res[k] = correct_k / batch_size
return res
with torch.no_grad():
out_dict = {}
for i, pname in enumerate(partitioning_shortnames):
res_dict = _accuracy(output[i], target[i], topk=topk)
for k, v in res_dict.items():
out_dict[f"acc{k}_val/{pname}"] = v
return out_dict
def summarize_gcd_stats(pnames: List[str], outputs, hierarchy=None):
gcd_dict = {}
metric_names = [f"gcd_{p}_val" for p in pnames]
if hierarchy is not None:
metric_names.append("gcd_hierarchy_val")
for metric_name in metric_names:
distances_flat = [output[metric_name] for output in outputs]
distances_flat = torch.cat(distances_flat, dim=0)
gcd_results = gcd_threshold_eval(distances_flat)
for gcd_thres, acc in gcd_results.items():
gcd_dict[f"{metric_name}/{gcd_thres}"] = acc
return gcd_dict
def summarize_test_gcd(pnames, outputs, hierarchy=None):
def _eval(output):
# calculate acc@km for a list of given thresholds
accuracy_outputs = {}
        # work on a copy so repeated _eval calls do not keep appending to the caller's list
        eval_pnames = list(pnames)
        if hierarchy is not None:
            eval_pnames.append("hierarchy")
        for pname in eval_pnames:
# concat batches of distances
distances_flat = torch.cat([x[pname] for x in output], dim=0)
# acc for all distances
acc_dict = gcd_threshold_eval(distances_flat)
accuracy_outputs[f"acc_test/{pname}"] = acc_dict
return accuracy_outputs
result = {}
if isinstance(outputs[0], dict): # only one testset
result = _eval(outputs)
elif isinstance(outputs[0], list): # multiple testsets
for testset_index, output in enumerate(outputs):
result[testset_index] = _eval(output)
else:
        raise TypeError(f"Unexpected outputs element type: {type(outputs[0])}")
return result
def summarize_loss_acc_stats(pnames: List[str], outputs, topk=(1, 5, 10)):
loss_acc_dict = {}
metric_names = []
for k in topk:
accuracy_names = [f"acc{k}_val/{p}" for p in pnames]
metric_names.extend(accuracy_names)
metric_names.extend([f"loss_val/{p}" for p in pnames])
for metric_name in ["loss_val/total", *metric_names]:
metric_total = 0
for output in outputs:
metric_value = output[metric_name]
metric_total += metric_value
loss_acc_dict[metric_name] = metric_total / len(outputs)
return loss_acc_dict
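

if __name__ == "__main__":
    # Minimal smoke-test sketch (an addition, not part of the original training
    # code). Building the backbone downloads the pretrained torchvision weights.
    backbone, nfeatures = build_base_model("resnet50")
    with torch.no_grad():
        features = backbone(torch.randn(2, 3, 224, 224))
    print(f"backbone output: {tuple(features.shape)}, nfeatures: {nfeatures}")

    # Great-circle distances (km) for two dummy prediction/ground-truth pairs:
    # Paris vs. Paris and Berlin vs. Paris.
    dists = vectorized_gc_distance(
        torch.tensor([48.8566, 52.5200]),
        torch.tensor([2.3522, 13.4050]),
        torch.tensor([48.8566, 48.8566]),
        torch.tensor([2.3522, 2.3522]),
    )
    print(f"distances (km): {dists.tolist()}")
    print(f"threshold accuracies: {gcd_threshold_eval(dists)}")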