# https://raw.githubusercontent.com/facebookresearch/dino/main/utils.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
Misc functions. | |
Mostly copy-paste from torchvision references or other public repos like DETR: | |
https://github.com/facebookresearch/detr/blob/master/util/misc.py | |
""" | |
import argparse
import os
import sys
import time
import math
import random
import datetime
import warnings
import subprocess
from collections import defaultdict, deque

import numpy as np
import torch
from torch import nn
import torch.distributed as dist
from PIL import ImageFilter, ImageOps


class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = random.random() <= self.prob
        if not do_it:
            return img
        return img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)))


class Solarization(object):
    """
    Apply Solarization to the PIL image.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img
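

# Example (illustrative sketch, not from the original file): both augmentations
# slot into a torchvision pipeline; `pil_image` is an assumed PIL input.
#
#   import torchvision.transforms as transforms
#   augment = transforms.Compose([
#       transforms.RandomResizedCrop(224),
#       GaussianBlur(p=0.5),
#       Solarization(p=0.2),
#       transforms.ToTensor(),
#   ])
#   tensor = augment(pil_image)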


def load_pretrained_weights(model, pretrained_weights, checkpoint_key,
                            model_name, patch_size):
    if os.path.isfile(pretrained_weights):
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(
            pretrained_weights, msg))
    else:
        print(
            "Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
        )
        url = None
        if model_name == "vit_small" and patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
        elif model_name == "vit_small" and patch_size == 8:
            url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
        elif model_name == "vit_base" and patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        elif model_name == "vit_base" and patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
        elif model_name == "xcit_small_12_p16":
            url = "dino_xcit_small_12_p16_pretrain/dino_xcit_small_12_p16_pretrain.pth"
        elif model_name == "xcit_small_12_p8":
            url = "dino_xcit_small_12_p8_pretrain/dino_xcit_small_12_p8_pretrain.pth"
        elif model_name == "xcit_medium_24_p16":
            url = "dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth"
        elif model_name == "xcit_medium_24_p8":
            url = "dino_xcit_medium_24_p8_pretrain/dino_xcit_medium_24_p8_pretrain.pth"
        elif model_name == "resnet50":
            url = "dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
        if url is not None:
            print(
                "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
            )
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/" + url)
            model.load_state_dict(state_dict, strict=True)
        else:
            print(
                "There are no reference weights available for this model => We use random weights."
            )
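

# Example (hedged usage sketch; `vit_small` stands for a hypothetical backbone
# constructor, e.g. the one in DINO's vision_transformer.py):
#
#   model = vit_small(patch_size=16)
#   # Pass an empty path to fall back to the reference DINO checkpoint:
#   load_pretrained_weights(model, "", checkpoint_key="teacher",
#                           model_name="vit_small", patch_size=16)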


def load_pretrained_linear_weights(linear_classifier, model_name, patch_size):
    url = None
    if model_name == "vit_small" and patch_size == 16:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth"
    elif model_name == "vit_small" and patch_size == 8:
        url = "dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth"
    elif model_name == "vit_base" and patch_size == 16:
        url = "dino_vitbase16_pretrain/dino_vitbase16_linearweights.pth"
    elif model_name == "vit_base" and patch_size == 8:
        url = "dino_vitbase8_pretrain/dino_vitbase8_linearweights.pth"
    elif model_name == "resnet50":
        url = "dino_resnet50_pretrain/dino_resnet50_linearweights.pth"
    if url is not None:
        print("We load the reference pretrained linear weights.")
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/" + url)["state_dict"]
        linear_classifier.load_state_dict(state_dict, strict=True)
    else:
        print("We use random linear weights.")


def clip_gradients(model, clip):
    norms = []
    for name, p in model.named_parameters():
        if p.grad is not None:
            param_norm = p.grad.data.norm(2)
            norms.append(param_norm.item())
            clip_coef = clip / (param_norm + 1e-6)
            if clip_coef < 1:
                p.grad.data.mul_(clip_coef)
    return norms


def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
    if epoch >= freeze_last_layer:
        return
    for n, p in model.named_parameters():
        if "last_layer" in n:
            p.grad = None
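

# Example (hedged sketch of a DINO-style training step; `model`, `loss`,
# `optimizer`, `epoch`, and the `clip_grad` / `freeze_last_layer` settings are
# assumed to come from the surrounding training script):
#
#   optimizer.zero_grad()
#   loss.backward()
#   if clip_grad:
#       _ = clip_gradients(model, clip_grad)
#   cancel_gradients_last_layer(epoch, model, freeze_last_layer)
#   optimizer.step()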


def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Re-start from checkpoint
    """
    if not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))

    # open checkpoint file
    checkpoint = torch.load(ckp_path, map_location="cpu")

    # key is what to look for in the checkpoint file
    # value is the object to load
    # example: {'state_dict': model}
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            try:
                msg = value.load_state_dict(checkpoint[key], strict=False)
                print("=> loaded '{}' from checkpoint '{}' with msg {}".format(
                    key, ckp_path, msg))
            except TypeError:
                try:
                    msg = value.load_state_dict(checkpoint[key])
                    print("=> loaded '{}' from checkpoint: '{}'".format(
                        key, ckp_path))
                except ValueError:
                    print("=> failed to load '{}' from checkpoint: '{}'".format(
                        key, ckp_path))
        else:
            print("=> key '{}' not found in checkpoint: '{}'".format(
                key, ckp_path))

    # reload variables important for the run
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]
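

# Example (usage sketch; `model` and `optimizer` are assumed). Objects are
# passed as keyword arguments whose names match keys in the checkpoint dict,
# and plain values such as the epoch counter come back via `run_variables`:
#
#   to_restore = {"epoch": 0}
#   restart_from_checkpoint(
#       "checkpoint.pth",
#       run_variables=to_restore,
#       state_dict=model,
#       optimizer=optimizer,
#   )
#   start_epoch = to_restore["epoch"]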


def cosine_scheduler(base_value,
                     final_value,
                     epochs,
                     niter_per_ep,
                     warmup_epochs=0,
                     start_warmup_value=0):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value,
                                      warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (
        1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
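

# Example (sketch): one learning-rate value per training iteration, with a
# linear warmup followed by a cosine decay; `data_loader` is assumed.
#
#   lr_schedule = cosine_scheduler(
#       base_value=5e-4, final_value=1e-6, epochs=100,
#       niter_per_ep=len(data_loader), warmup_epochs=10)
#   # look up the value for the current iteration:
#   #   lr = lr_schedule[epoch * len(data_loader) + it]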


def bool_flag(s):
    """
    Parse boolean arguments from the command line.
    """
    FALSY_STRINGS = {"off", "false", "0"}
    TRUTHY_STRINGS = {"on", "true", "1"}
    if s.lower() in FALSY_STRINGS:
        return False
    elif s.lower() in TRUTHY_STRINGS:
        return True
    else:
        raise argparse.ArgumentTypeError("invalid value for a boolean flag")
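

# Example (sketch): wiring `bool_flag` into an argparse parser so that
# `--use_fp16 off` and `--use_fp16 1` both parse correctly.
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--use_fp16", type=bool_flag, default=True)
#   args = parser.parse_args(["--use_fp16", "off"])  # args.use_fp16 == False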


def fix_random_seeds(seed=31):
    """
    Fix random seeds.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.6f} ({global_avg:.6f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total],
                         dtype=torch.float64,
                         device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    # The statistics below must be properties: `__str__` formats them with
    # specs like "{median:.6f}", which fails on bound methods.
    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(median=self.median,
                               avg=self.avg,
                               global_avg=self.global_avg,
                               max=self.max,
                               value=self.value)
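

# Example (quick sketch): a meter that smooths a noisy loss curve.
#
#   meter = SmoothedValue(window_size=20)
#   for step_loss in [0.9, 0.7, 0.8, 0.6]:
#       meter.update(step_loss)
#   print(meter)  # "0.700000 (0.750000)" -> window median, global average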


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
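

# Example (sketch, meaningful only under torch.distributed; the loss tensors
# are assumed): averaging per-rank loss components for logging.
#
#   loss_dict = {"loss_cls": loss_cls, "loss_box": loss_box}  # 0-dim tensors
#   loss_dict_reduced = reduce_dict(loss_dict)  # same keys, averaged over ranks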


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.6f}')
        data_time = SmoothedValue(fmt='{avg:.6f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        if torch.cuda.is_available():
            log_msg = self.delimiter.join([
                header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}',
                'time: {time}', 'data: {data}', 'max mem: {memory:.0f}'
            ])
        else:
            log_msg = self.delimiter.join([
                header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}',
                'time: {time}', 'data: {data}'
            ])
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(
                        log_msg.format(i,
                                       len(iterable),
                                       eta=eta_string,
                                       meters=str(self),
                                       time=str(iter_time),
                                       data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.6f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
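

# Example (sketch of the typical training-loop pattern; `data_loader`,
# `compute_loss`, and `optimizer` are assumed):
#
#   metric_logger = MetricLogger(delimiter="  ")
#   header = "Epoch: [0]"
#   for images, targets in metric_logger.log_every(data_loader, 10, header):
#       loss = compute_loss(images, targets)
#       ...
#       metric_logger.update(loss=loss.item(),
#                            lr=optimizer.param_groups[0]["lr"])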


def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()

    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def init_distributed_mode(args):
    # launched with torch.distributed.launch
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    # launched with submitit on a slurm cluster
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    # launched naively with `python main_dino.py`
    # we manually add MASTER_ADDR and MASTER_PORT to env variables
    elif torch.cuda.is_available():
        print('Will run the code on one GPU.')
        args.rank, args.gpu, args.world_size = 0, 0, 1
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '29500'
    else:
        print('Does not support training without GPU.')
        sys.exit(1)

    dist.init_process_group(
        backend="nccl",
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    torch.cuda.set_device(args.gpu)
    print('| distributed init (rank {}): {}'.format(args.rank, args.dist_url),
          flush=True)
    dist.barrier()
    setup_for_distributed(args.rank == 0)


def accuracy(output, target, topk=(1, )):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [
        correct[:k].reshape(-1).float().sum(0) * 100. / batch_size
        for k in topk
    ]
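

# Example (sketch; `logits` with shape (batch, num_classes) and integer class
# `targets` are assumed):
#
#   acc1, acc5 = accuracy(logits, targets, topk=(1, 5))
#   print(f"top-1: {acc1.item():.2f}%, top-5: {acc5.item():.2f}%")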


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
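

# Example (sketch): the usual ViT-style weight initialization, drawing from a
# normal with std 0.02 truncated to [-2, 2].
#
#   w = torch.empty(768, 768)
#   trunc_normal_(w, std=.02)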


class LARS(torch.optim.Optimizer):
    """
    Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
    """

    def __init__(self,
                 params,
                 lr=0,
                 weight_decay=0,
                 momentum=0.9,
                 eta=0.001,
                 weight_decay_filter=None,
                 lars_adaptation_filter=None):
        defaults = dict(lr=lr,
                        weight_decay=weight_decay,
                        momentum=momentum,
                        eta=eta,
                        weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    # The no_grad guard is required: without it, the in-place update of leaf
    # parameters that require grad raises a RuntimeError.
    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                # skip weight decay and LARS adaptation for 1-D params
                # (biases and norm parameters)
                if p.ndim != 1:
                    dp = dp.add(p, alpha=g['weight_decay'])

                if p.ndim != 1:
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(
                        param_norm > 0.,
                        torch.where(update_norm > 0,
                                    (g['eta'] * param_norm / update_norm),
                                    one), one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])


class MultiCropWrapper(nn.Module):
    """
    Perform a forward pass separately on each input resolution.
    Inputs that share a resolution are grouped together and passed through the
    backbone in a single forward, so the number of backbone passes equals the
    number of distinct resolutions used. The output features of all passes are
    then concatenated and the head is run once on the concatenated features.
    """

    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        # indices where the crop resolution changes, so that consecutive
        # same-resolution crops can be batched together
        idx_crops = torch.cumsum(
            torch.unique_consecutive(
                torch.tensor([inp.shape[-1] for inp in x]),
                return_counts=True,
            )[1], 0)
        start_idx, output = 0, torch.empty(0).to(x[0].device)
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx:end_idx]))
            # The output is a tuple with XCiT model. See:
            # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
            if isinstance(_out, tuple):
                _out = _out[0]
            # accumulate outputs
            output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        return self.head(output)
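

# Example (hedged sketch; `backbone` is any model exposing `.fc`/`.head`
# attributes and `head` is a projection head such as DINO's `DINOHead`; the
# crop tensors are assumed):
#
#   model = MultiCropWrapper(backbone, head)
#   # two global crops (e.g. 224x224) and several local crops (e.g. 96x96):
#   out = model([global_crop_1, global_crop_2] + local_crops)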


def get_params_groups(model):
    regularized = []
    not_regularized = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # we do not regularize biases nor Norm parameters
        if name.endswith(".bias") or len(param.shape) == 1:
            not_regularized.append(param)
        else:
            regularized.append(param)
    return [{'params': regularized},
            {'params': not_regularized, 'weight_decay': 0.}]
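

# Example (sketch, mirroring how DINO builds its optimizer; the AdamW
# hyperparameters here are assumptions). The second group's explicit
# `weight_decay: 0.` overrides the optimizer default for biases and norms:
#
#   params_groups = get_params_groups(model)
#   optimizer = torch.optim.AdamW(params_groups, lr=5e-4, weight_decay=0.04)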


def has_batchnorms(model):
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                nn.SyncBatchNorm)
    for module in model.modules():
        if isinstance(module, bn_types):
            return True
    return False
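

# Example (sketch, following the common multi-GPU pattern): convert BatchNorm
# layers to SyncBatchNorm before wrapping the model in DistributedDataParallel.
#
#   if has_batchnorms(model):
#       model = nn.SyncBatchNorm.convert_sync_batchnorm(model)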


class PCA():
    """
    Class to compute and apply PCA.
    """

    def __init__(self, dim=256, whit=0.5):
        self.dim = dim
        self.whit = whit
        self.mean = None

    def train_pca(self, cov):
        """
        Takes a covariance matrix (np.ndarray) as input.
        """
        d, v = np.linalg.eigh(cov)
        eps = d.max() * 1e-5
        n_0 = (d < eps).sum()
        if n_0 > 0:
            d[d < eps] = eps

        # total energy
        totenergy = d.sum()

        # sort eigenvectors with eigenvalues order
        idx = np.argsort(d)[::-1][:self.dim]
        d = d[idx]
        v = v[:, idx]

        print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))

        # for the whitening
        d = np.diag(1. / d**self.whit)

        # principal components
        self.dvt = np.dot(d, v.T)

    def apply(self, x):
        # input is a numpy array
        if isinstance(x, np.ndarray):
            if self.mean is not None:
                x -= self.mean
            return np.dot(self.dvt, x.T).T

        # input is a torch tensor on GPU
        if x.is_cuda:
            if self.mean is not None:
                x -= torch.cuda.FloatTensor(self.mean)
            return torch.mm(torch.cuda.FloatTensor(self.dvt),
                            x.transpose(0, 1)).transpose(0, 1)

        # input is a torch tensor on CPU
        if self.mean is not None:
            x -= torch.FloatTensor(self.mean)
        return torch.mm(torch.FloatTensor(self.dvt),
                        x.transpose(0, 1)).transpose(0, 1)
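

# Example (hedged sketch; `features` is an assumed (n_samples, n_dims) float
# numpy array). The covariance is estimated externally, since `train_pca`
# only consumes a covariance matrix:
#
#   pca = PCA(dim=64, whit=0.5)
#   pca.mean = features.mean(axis=0)
#   pca.train_pca(np.cov(features.T))
#   reduced = pca.apply(features.copy())  # copy: apply subtracts mean in place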


def compute_ap(ranks, nres):
    """
    Computes average precision for given ranked indexes.

    Arguments
    ---------
    ranks : zero-based ranks of positive images
    nres  : number of positive images

    Returns
    -------
    ap    : average precision
    """
    # number of images ranked by the system
    nimgranks = len(ranks)

    # accumulate trapezoids in PR-plot
    ap = 0

    recall_step = 1. / nres

    for j in np.arange(nimgranks):
        rank = ranks[j]

        if rank == 0:
            precision_0 = 1.
        else:
            precision_0 = float(j) / rank

        precision_1 = float(j + 1) / (rank + 1)

        ap += (precision_0 + precision_1) * recall_step / 2.

    return ap


def compute_map(ranks, gnd, kappas=[]):
    """
    Computes the mAP for a given set of returned results.

    Usage:
        map = compute_map(ranks, gnd)
            computes mean average precision (map) only

        map, aps, pr, prs = compute_map(ranks, gnd, kappas)
            computes mean average precision (map), average precision (aps) for
            each query, mean precision at kappas (pr), and precision at kappas
            (prs) for each query

    Notes:
    1) ranks starts from 0, ranks.shape = db_size X #queries
    2) The junk results (e.g., the query itself) should be declared in the gnd struct array
    3) If there are no positive images for some query, that query is excluded from the evaluation
    """
    map = 0.
    nq = len(gnd)  # number of queries
    aps = np.zeros(nq)
    pr = np.zeros(len(kappas))
    prs = np.zeros((nq, len(kappas)))
    nempty = 0

    for i in np.arange(nq):
        qgnd = np.array(gnd[i]['ok'])

        # no positive images, skip from the average
        if qgnd.shape[0] == 0:
            aps[i] = float('nan')
            prs[i, :] = float('nan')
            nempty += 1
            continue

        try:
            qgndj = np.array(gnd[i]['junk'])
        except KeyError:
            qgndj = np.empty(0)

        # sorted positions of positive and junk images (0 based)
        pos = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgnd)]
        junk = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgndj)]

        k = 0
        ij = 0
        if len(junk):
            # decrease positions of positives based on the number of
            # junk images appearing before them
            ip = 0
            while ip < len(pos):
                while ij < len(junk) and pos[ip] > junk[ij]:
                    k += 1
                    ij += 1
                pos[ip] = pos[ip] - k
                ip += 1

        # compute ap
        ap = compute_ap(pos, len(qgnd))
        map = map + ap
        aps[i] = ap

        # compute precision @ k
        pos += 1  # get it to 1-based
        for j in np.arange(len(kappas)):
            kq = min(max(pos), kappas[j])
            prs[i, j] = (pos <= kq).sum() / kq
        pr = pr + prs[i, :]

    map = map / (nq - nempty)
    pr = pr / (nq - nempty)

    return map, aps, pr, prs


def multi_scale(samples, model):
    v = None
    for s in [1, 1 / 2**(1 / 2), 1 / 2]:  # we use 3 different scales
        if s == 1:
            inp = samples.clone()
        else:
            inp = nn.functional.interpolate(samples,
                                            scale_factor=s,
                                            mode='bilinear',
                                            align_corners=False)
        feats = model(inp).clone()
        if v is None:
            v = feats
        else:
            v += feats
    v /= 3
    v /= v.norm()
    return v