File size: 2,747 Bytes
9e6cbab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np

import torch
import torch.nn as nn
from torchvision import models

from scipy.optimize import root_scalar
from scipy.special import betainc

# Run on CUDA when a GPU is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def build_backbone(path, name='resnet50'):
    """
    Builds a torchvision backbone (ResNet-50 by default) and loads weights from a checkpoint.

    The model is constructed with ``pretrained=False``, and its ``fc`` and ``head``
    attributes are replaced by ``nn.Identity`` so the forward pass skips those layers.
    Weights are then loaded non-strictly from ``path``.

    Args:
        path: checkpoint readable by ``torch.load``. It is either a bare state dict
            or a dict that wraps one under 'state_dict', 'model_state_dict' or
            'teacher'. When several of these keys are present, the last one in that
            order wins.
        name: name of a model constructor in ``torchvision.models``.
    Returns:
        The model with the checkpoint weights loaded. A warning is emitted when some
        keys could not be matched.
    """
    import warnings

    model = getattr(models, name)(pretrained=False)
    # Drop the classification layers so the model outputs features.
    model.head = nn.Identity()
    model.fc = nn.Identity()
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint
    for ckpt_key in ['state_dict', 'model_state_dict', 'teacher']:
        if ckpt_key in checkpoint:
            state_dict = checkpoint[ckpt_key]

    # Strip wrapper prefixes (e.g. "module." as typically added by DataParallel, or
    # "backbone." from a training wrapper). Only leading occurrences are removed, in
    # any order, so layer names that merely contain these substrings stay intact.
    prefixes = ("module.", "backbone.")
    cleaned = {}
    for key, value in state_dict.items():
        changed = True
        while changed:
            changed = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        cleaned[key] = value

    incompatible = model.load_state_dict(cleaned, strict=False)
    # strict=False tolerates the removed heads. Still surface mismatches, so a wrong
    # checkpoint does not silently leave layers at their random initialisation.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"build_backbone: loading {path} left "
            f"{len(incompatible.missing_keys)} missing key(s) "
            f"(e.g. {list(incompatible.missing_keys)[:3]}) and "
            f"{len(incompatible.unexpected_keys)} unexpected key(s) "
            f"(e.g. {list(incompatible.unexpected_keys)[:3]})",
            stacklevel=2,
        )
    return model

def get_linear_layer(weight, bias):
    """
    Creates an ``nn.Linear`` layer whose parameters are the given tensors.

    Typically used to rebuild a feature whitening or centering layer from saved weights.

    Args:
        weight: 2-D tensor of shape (dim_out, dim_in).
        bias: tensor of shape (dim_out,), or None for a layer without bias.
    Returns:
        nn.Linear computing ``x @ weight.T + bias``. The bias term is omitted when
        ``bias`` is None.
    """
    dim_out, dim_in = weight.shape
    # Only allocate a bias slot when one is supplied. Wrapping None in
    # nn.Parameter yields an empty tensor that breaks forward().
    layer = nn.Linear(dim_in, dim_out, bias=bias is not None)
    layer.weight = nn.Parameter(weight)
    if bias is not None:
        layer.bias = nn.Parameter(bias)
    return layer

def load_normalization_layer(path):
    """
    Loads the normalization layer from a checkpoint and returns it on ``device``.

    The checkpoint must hold 'weight' and 'bias' tensors, used as the parameters of
    a linear layer (weight has shape (dim_out, dim_in)). When the path contains
    'whitening' or 'out', both weight and bias are multiplied by dim_in before the
    layer is built.

    Args:
        path: str or os.PathLike pointing to a checkpoint readable by ``torch.load``.
    Returns:
        An ``nn.Linear`` layer moved to ``device``.
    """
    import os

    checkpoint = torch.load(path, map_location=device)
    # Accept pathlib.Path and other PathLike objects. The substring tests below
    # need a plain string.
    if isinstance(path, os.PathLike):
        path = os.fspath(path)
    weight = checkpoint['weight']
    bias = checkpoint['bias']
    # NOTE(review): this substring test runs on the full path, so any directory
    # containing "out" (e.g. "outputs/") also triggers the scaling. Confirm intended.
    if 'whitening' in path or 'out' in path:
        D = weight.shape[1]
        weight = D * weight
        bias = D * bias
    return get_linear_layer(weight, bias).to(device, non_blocking=True)

class NormLayerWrapper(nn.Module):
    """
    Chains a backbone and a head module, so that ``forward(x) == head(backbone(x))``.

    Both sub-modules are switched to eval mode when the wrapper is built. The
    wrapper's own ``training`` flag is not modified.
    """
    def __init__(self, backbone, head):
        super().__init__()
        # Put both components in inference mode before registering them.
        for component in (backbone, head):
            component.eval()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        features = self.backbone(x)
        projected = self.head(features)
        return projected

def cosine_pvalue(c, d, k=1):
    """
    Returns the probability that the absolute value of the projection
    between random unit vectors is higher than c.

    For a uniformly random unit vector in R^d, the squared norm of its projection
    onto a fixed k-dimensional subspace follows Beta(k/2, (d-k)/2). Hence
    P(|proj| > c) = I_{1-c^2}((d-k)/2, k/2), the regularized incomplete beta function.

    Args:
        c: cosine value (a scalar or an array-like of values).
        d: dimension of the features.
        k: number of dimensions of the projection.
    Returns:
        A float for scalar ``c``, otherwise a numpy array of the same shape.
        Negative ``c`` gives 1.0, and ``c >= 1`` gives 0.0.
    """
    assert k > 0
    a = (d - k) / 2.0
    b = k / 2.0
    c_arr = np.asarray(c, dtype=float)
    # Clip so that |c| > 1 maps to x = 0 (p-value 0) instead of betainc returning NaN.
    x = np.clip(1.0 - c_arr ** 2, 0.0, 1.0)
    # |proj| >= 0 > c always holds for negative c.
    pvalue = np.where(c_arr < 0, 1.0, betainc(a, b, x))
    if pvalue.ndim == 0:
        return float(pvalue)
    return pvalue

def pvalue_angle(dim, k=1, angle=None, proba=None):
    """
    Returns the angle ``a`` (radians, in [0, pi/2]) such that a cosine of cos(a) has p-value ``proba``.

    Solves ``cosine_pvalue(cos(a), dim, k) == proba`` for ``a``. For dim > k the
    p-value rises monotonically from 0 at a = 0 (cos = 1) to 1 at a = pi/2
    (cos = 0), so a root exists in that interval for every proba in [0, 1].

    Args:
        dim: dimension of the features.
        k: number of dimensions of the projection.
        angle: unused; kept for backward compatibility of the signature.
        proba: target p-value, in [0, 1].
    Returns:
        The angle in radians.
    Raises:
        TypeError: if ``proba`` is None.
        ValueError: if ``proba`` lies outside [0, 1].
    """
    if proba is None:
        raise TypeError("pvalue_angle requires `proba`")
    if not 0.0 <= proba <= 1.0:
        raise ValueError(f"proba must be in [0, 1], got {proba}")

    def f(a):
        return cosine_pvalue(np.cos(a), dim, k) - proba

    # When a bracket is given, root_scalar uses Brent's method. A starting point x0
    # would be ignored, so none is passed.
    sol = root_scalar(f, bracket=[0, np.pi / 2])
    return sol.root