import torch
import torch.nn as nn
import torchvision
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

"""from https://github.com/facebookresearch/simsiam"""

class SimSiam(nn.Module):

    def __init__(self, base_encoder, dim, pred_dim):
        """
        dim: feature dimension (default: 2048)
        pred_dim: hidden dimension of the predictor (default: 512)
        symetric is True only when training
        """
        super().__init__()

        # create the encoder
        # num_classes is the output fc dimension, zero-initialize last BNs
        self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)

        # build a 3-layer projector
        prev_dim = self.encoder.fc.weight.shape[1]
        self.encoder.fc = nn.Sequential(nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True), # first layer
                                        nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True), # second layer
                                        self.encoder.fc,
                                        nn.BatchNorm1d(dim, affine=False)) # output layer
        self.encoder.fc[6].bias.requires_grad = False # hack: not use bias as it is followed by BN

        # build a 2-layer predictor
        self.predictor = nn.Sequential(nn.Linear(dim, pred_dim, bias=False),
                                        nn.BatchNorm1d(pred_dim),
                                        nn.ReLU(inplace=True), # hidden layer
                                        nn.Linear(pred_dim, dim)) # output layer

    def forward(self, x1, x2):
        z1 = self.encoder(x1) # NxC
        z2 = self.encoder(x2) # NxC

        p1 = self.predictor(z1) # NxC
        p2 = self.predictor(z2) # NxC

        # symmetric negative cosine loss; the stop-gradient is applied to the
        # targets (z) only, as in the original SimSiam, so the encoder still
        # receives gradients through the predictor branch
        loss = -(nn.CosineSimilarity(dim=1)(p1, z2.detach()).mean()
                 + nn.CosineSimilarity(dim=1)(p2, z1.detach()).mean()) * 0.5

        return loss
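
# Minimal sanity-check sketch (illustrative, not from the original repo):
# resnet18 with small dims keeps it light; the checkpoint loaded below
# assumes resnet50 with dim=2048, pred_dim=512. The symmetric loss lands in
# [-1, 1] and gradients reach both the encoder and the predictor.
#
#   model = SimSiam(torchvision.models.resnet18, dim=512, pred_dim=128)
#   x1 = torch.randn(4, 3, 224, 224)  # first augmented views
#   x2 = torch.randn(4, 3, 224, 224)  # second augmented views
#   loss = model(x1, x2)              # scalar in [-1, 1]
#   loss.backward()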
    
class ResNet(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        # drop the final avgpool and fc layers, keeping only the
        # convolutional trunk
        modules = list(backbone.children())[:-2]
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        # global average pooling over the spatial dimensions replaces the
        # removed avgpool layer: NxCxHxW -> NxC
        return self.net(x).mean(dim=[2, 3])
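
# Shape sketch (illustrative): for a resnet50 backbone, the trunk maps
# Nx3x224x224 -> Nx2048x7x7 and the pooling returns Nx2048, matching the
# features the original avgpool + flatten would have produced.
#
#   feats = ResNet(torchvision.models.resnet50())(torch.randn(2, 3, 224, 224))
#   print(feats.shape)  # torch.Size([2, 2048])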
    
class RestructuredSimSiam(nn.Module):
    def __init__(self, model):
        super().__init__()

        self.encoder = ResNet(model.encoder)
        self.mlp_encoder = model.encoder.fc
        self.mlp_encoder[6].bias.requires_grad = False # same bias-before-BN hack as in SimSiam.__init__
        self.contrastive_head = model.predictor

    def forward(self, x, run_head=True):
        # no detach here: gradients must flow for explainability backprop
        x = self.mlp_encoder(self.encoder(x))

        if run_head:
            x = self.contrastive_head(x)

        return x

    
def get_simsiam(ckpt_path='checkpoint_0099.pth.tar'):

    model = SimSiam(base_encoder=torchvision.models.resnet50,
                    dim=2048,
                    pred_dim=512)

    checkpoint = torch.load('pretrained_models/simsiam_models/' + ckpt_path, map_location='cpu')
    # strip the 'module.' prefix left over from DistributedDataParallel training
    state_dict = {k.replace("module.", ""): v for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)

    restructured_model = RestructuredSimSiam(model)
    return restructured_model.to(device)
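
# Example usage (illustrative; checkpoint path and directory layout follow
# the defaults above):
#
#   model = get_simsiam()               # loads checkpoint_0099.pth.tar
#   model.eval()
#   images = torch.randn(2, 3, 224, 224).to(device)
#   z = model(images, run_head=False)   # Nx2048 projector features
#   p = model(images)                   # Nx2048 predictor outputs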