Annonymous committed on
Commit
b157c29
1 Parent(s): 2c29c28

Upload 4 files

Browse files
ssl_models/barlow_twins.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+
6
+ """from https://github.com/facebookresearch/barlowtwins"""
7
+
8
def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a flat tensor.

    Dropping the last element of the flattened n*n matrix and viewing the
    remainder as (n - 1, n + 1) lines every diagonal entry up in column 0;
    slicing that column away leaves the n*(n - 1) off-diagonal values in
    row-major order.
    """
    rows, cols = x.shape
    assert rows == cols
    without_last = x.flatten()[:-1]
    shifted = without_last.view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()
13
+
14
class BarlowTwins(nn.Module):
    """Barlow Twins network: ResNet-50 backbone, MLP projector and loss.

    Adapted from https://github.com/facebookresearch/barlowtwins. The attribute
    names ``backbone``, ``projector`` and ``bn`` determine the state-dict keys,
    so they must stay as they are for checkpoint loading to keep working.
    """

    def __init__(self):
        super().__init__()

        # ResNet-50 trunk; the classification layer is replaced by a no-op so
        # the backbone outputs the pooled features directly.
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()

        # projector: Linear -> BN -> ReLU blocks followed by a final Linear
        sizes = [2048] + list(map(int, '8192-8192-8192'.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    def forward(self, y1, y2):
        """Return the Barlow Twins loss for two batches of views.

        Args:
            y1: first batch of inputs fed to the backbone.
            y2: second batch of inputs, same batch size as ``y1``.

        Returns:
            Scalar tensor ``on_diag + 0.0051 * off_diag``.
        """
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # Empirical cross-correlation matrix. The product of the two
        # batch-normalised embeddings sums over the batch, so it has to be
        # divided by the batch size; otherwise the diagonal grows with N and
        # comparing it against 1 below makes no sense.
        c = self.bn(z1).T @ self.bn(z2)
        c = c / z1.shape[0]

        # Out-of-place arithmetic keeps ``c`` itself untouched.
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = off_diagonal(c).pow(2).sum()
        loss = on_diag + 0.0051 * off_diag
        return loss
46
+
47
class ResNet(nn.Module):
    """Feature extractor made of all but the last two children of ``backbone``.

    The remaining children are chained in order and their output is averaged
    over dimensions 2 and 3.
    """

    def __init__(self, backbone):
        super().__init__()

        kept = list(backbone.children())
        del kept[-2:]  # drop the two trailing children
        self.net = nn.Sequential(*kept)

    def forward(self, x):
        """Run the trimmed network and average over dims 2 and 3."""
        feature_map = self.net(x)
        return feature_map.mean(dim=[2, 3])
56
+
57
class RestructuredBarlowTwins(nn.Module):
    """Expose a ``BarlowTwins`` model as ``encoder`` + ``contrastive_head``.

    ``encoder`` wraps ``model.backbone`` in the local ``ResNet`` feature
    extractor and ``contrastive_head`` is ``model.projector``.
    """

    def __init__(self, model):
        super().__init__()

        self.encoder = ResNet(model.backbone)
        self.contrastive_head = model.projector

    def forward(self, x):
        """Return the projector output for ``x``."""
        return self.contrastive_head(self.encoder(x))
68
+
69
+
70
def get_barlow_twins_model(ckpt_path = 'barlow_twins.pth'):
    """Load a Barlow Twins checkpoint and return the restructured model.

    Args:
        ckpt_path: file name inside ``pretrained_models/barlow_models/``.

    Returns:
        A ``RestructuredBarlowTwins`` instance moved to the module-level ``device``.
    """
    model = BarlowTwins()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load('pretrained_models/barlow_models/' + ckpt_path, map_location='cpu')
    weights = {}
    for key, value in checkpoint['model'].items():
        # drop every "module." fragment from the parameter names
        weights[key.replace("module.", "")] = value
    model.load_state_dict(weights)
    return RestructuredBarlowTwins(model).to(device)
ssl_models/dino.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import pathlib
7
import os

# Handle on the genuine PosixPath class, taken before any patching.
temp = pathlib.PosixPath
# Aliasing PosixPath to WindowsPath is only meaningful on Windows (presumably
# to unpickle checkpoints that contain PosixPath objects -- confirm). Done
# unconditionally it makes pathlib.Path(...) resolve to WindowsPath on POSIX
# hosts, which cannot be instantiated there, so the patch is guarded.
if os.name == "nt":
    pathlib.PosixPath = pathlib.WindowsPath
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ """ from https://github.com/facebookresearch/dino"""
12
+
13
class DINOHead(nn.Module):
    """Projection head: an MLP, L2 normalisation, then a weight-normalised linear layer.

    Args:
        in_dim: input feature size.
        out_dim: output size of the last layer.
        use_bn: insert ``BatchNorm1d`` after each hidden linear layer.
        norm_last_layer: freeze the magnitude (``weight_g``) of the last layer.
        nlayers: number of linear layers in the MLP (values below 1 count as 1).
        hidden_dim: width of the hidden MLP layers.
        bottleneck_dim: MLP output size, i.e. the input size of the last layer.
    """

    def __init__(self, in_dim, out_dim, use_bn, norm_last_layer, nlayers, hidden_dim, bottleneck_dim):
        super().__init__()

        depth = max(nlayers, 1)
        if depth == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            stack = []
            width_in = in_dim
            # depth - 1 hidden blocks: Linear [-> BatchNorm] -> GELU
            for _ in range(depth - 1):
                stack.append(nn.Linear(width_in, hidden_dim))
                if use_bn:
                    stack.append(nn.BatchNorm1d(hidden_dim))
                stack.append(nn.GELU())
                width_in = hidden_dim
            stack.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*stack)

        final = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        final.weight_g.data.fill_(1)
        if norm_last_layer:
            final.weight_g.requires_grad = False
        self.last_layer = final

    def forward(self, x):
        """Apply the MLP, L2-normalise along the last dim, then the last layer."""
        bottleneck = F.normalize(self.mlp(x), dim=-1, p=2)
        return self.last_layer(bottleneck)
44
+
45
class MultiCropWrapper(nn.Module):
    """Chain a backbone and a head after disabling the backbone's own heads.

    The ``fc`` and ``head`` attributes of ``backbone`` are both overwritten
    with ``nn.Identity`` (the passed-in object is modified in place).
    """

    def __init__(self, backbone, head):
        super().__init__()
        backbone.fc = nn.Identity()
        backbone.head = nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        """Return ``head(backbone(x))``."""
        features = self.backbone(x)
        return self.head(features)
54
+
55
class DINOLoss(nn.Module):
    """Cross-entropy between centred teacher probabilities and student log-probabilities.

    A per-epoch teacher temperature schedule is built (linear warm-up followed
    by a constant value), but ``forward`` always uses its entry at index
    ``nepochs - 1``. ``center`` is a registered buffer initialised to zeros;
    ``center_momentum`` is stored but not used in this class.
    """

    def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs, nepochs,
                 student_temp=0.1, center_momentum=0.9):
        super().__init__()

        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))
        self.nepochs = nepochs
        warmup = np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs)
        plateau = np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        self.teacher_temp_schedule = np.concatenate((warmup, plateau))

    def forward(self, student_output, teacher_output):
        """Return the mean over the batch of ``-sum(teacher_prob * log student_prob)``."""
        final_temp = self.teacher_temp_schedule[self.nepochs - 1]  # last one
        targets = F.softmax((teacher_output - self.center) / final_temp, dim=-1).detach()
        log_probs = F.log_softmax(student_output / self.student_temp, dim=-1)
        return torch.sum(-targets * log_probs, dim=-1).mean()
74
+
75
+
76
class ResNet(nn.Module):
    """Wrap ``backbone`` minus its last two children as a pooled feature extractor."""

    def __init__(self, backbone):
        super().__init__()

        children = list(backbone.children())
        trunk = children[:len(children) - 2] if len(children) > 2 else []
        self.net = nn.Sequential(*trunk)

    def forward(self, x):
        """Return the trunk output averaged over dimensions 2 and 3."""
        out = self.net(x)
        return out.mean(dim=[2, 3])
85
+
86
class RestructuredDINO(nn.Module):
    """Hold the student and teacher DINO networks behind one ``forward``.

    The teacher is exposed as ``encoder`` / ``contrastive_head`` and the
    student as ``encoder_student`` / ``contrastive_head_student``.
    """

    def __init__(self, student, teacher):
        super().__init__()

        self.encoder_student = ResNet(student.backbone)
        self.encoder = ResNet(teacher.backbone)

        self.contrastive_head_student = student.head
        self.contrastive_head = teacher.head

    def forward(self, x, run_teacher):
        """Run ``x`` through the teacher if ``run_teacher`` is truthy, else the student."""
        if run_teacher:
            trunk, head = self.encoder, self.contrastive_head
        else:
            trunk, head = self.encoder_student, self.contrastive_head_student
        return head(trunk(x))
108
+
109
+
110
def get_dino_model_without_loss(ckpt_path = 'dino_resnet50_pretrain_full_checkpoint.pth'):
    """Load student and teacher DINO ResNet-50 networks from a checkpoint.

    Args:
        ckpt_path: file name inside ``pretrained_models/dino_models/``.

    Returns:
        A ``RestructuredDINO`` instance moved to the module-level ``device``.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load('pretrained_models/dino_models/' + ckpt_path, map_location='cpu')
    student_weights = checkpoint['student']
    teacher_weights = checkpoint['teacher']

    def _strip_module(weights):
        # remove every "module." fragment from the parameter names
        return {name.replace("module.", ""): tensor for name, tensor in weights.items()}

    student_weights = _strip_module(student_weights)
    teacher_weights = _strip_module(teacher_weights)

    student_backbone = torchvision.models.resnet50()
    teacher_backbone = torchvision.models.resnet50()
    embed_dim = student_backbone.fc.weight.shape[1]

    head_kwargs = dict(in_dim=embed_dim, out_dim=60000, use_bn=True, norm_last_layer=True,
                       nlayers=2, hidden_dim=4096, bottleneck_dim=256)
    student_head = DINOHead(**head_kwargs)
    teacher_head = DINOHead(**head_kwargs)
    # Swap the weight-normalised last layer for a plain Linear, presumably so
    # the parameter names match this checkpoint -- verify against the file.
    student_head.last_layer = nn.Linear(256, 60000, bias = False)
    teacher_head.last_layer = nn.Linear(256, 60000, bias = False)

    student = MultiCropWrapper(student_backbone, student_head)
    teacher = MultiCropWrapper(teacher_backbone, teacher_head)

    student.load_state_dict(student_weights)
    teacher.load_state_dict(teacher_weights)

    return RestructuredDINO(student, teacher).to(device)
136
+
137
+
138
def get_dino_model_with_loss(ckpt_path = 'dino_rn50_checkpoint.pth'):
    """Load student/teacher DINO networks and the DINO loss from a checkpoint.

    Head and loss hyper-parameters are read from the ``args`` entry stored in
    the checkpoint; the head depth and widths are fixed here.

    Args:
        ckpt_path: file name inside ``pretrained_models/dino_models/``.

    Returns:
        Tuple ``(RestructuredDINO, DINOLoss)``, both moved to the module-level ``device``.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load('pretrained_models/dino_models/' + ckpt_path, map_location='cpu')

    student_weights = checkpoint['student']
    teacher_weights = checkpoint['teacher']
    train_args = vars(checkpoint['args'])
    loss_state = checkpoint['dino_loss']

    def _strip_module(weights):
        # remove every "module." fragment from the parameter names
        return {name.replace("module.", ""): tensor for name, tensor in weights.items()}

    student_weights = _strip_module(student_weights)
    teacher_weights = _strip_module(teacher_weights)

    student_backbone = torchvision.models.resnet50()
    teacher_backbone = torchvision.models.resnet50()
    embed_dim = student_backbone.fc.weight.shape[1]

    # Both heads share one configuration.
    head_kwargs = dict(in_dim=embed_dim,
                       out_dim=train_args['out_dim'],
                       use_bn=train_args['use_bn_in_head'],
                       norm_last_layer=train_args['norm_last_layer'],
                       nlayers=3,
                       hidden_dim=2048,
                       bottleneck_dim=256)
    student_head = DINOHead(**head_kwargs)
    teacher_head = DINOHead(**head_kwargs)

    loss = DINOLoss(out_dim=train_args['out_dim'],
                    warmup_teacher_temp=train_args['warmup_teacher_temp'],
                    teacher_temp=train_args['teacher_temp'],
                    warmup_teacher_temp_epochs=train_args['warmup_teacher_temp_epochs'],
                    nepochs=train_args['epochs'])

    student = MultiCropWrapper(student_backbone, student_head)
    teacher = MultiCropWrapper(teacher_backbone, teacher_head)

    student.load_state_dict(student_weights)
    teacher.load_state_dict(teacher_weights)
    loss.load_state_dict(loss_state)

    restructured_model = RestructuredDINO(student, teacher)

    return restructured_model.to(device), loss.to(device)
ssl_models/simclr2.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+
6
+ """
7
+ from https://github.com/Separius/SimCLRv2-Pytorch
8
+ """
9
+
10
# Epsilon passed to the BatchNorm layers built in this file.
BATCH_NORM_EPSILON = 1e-5
# Not referenced by the code visible in this file; the BatchNorm layers keep
# PyTorch's default momentum (0.1, i.e. a decay of 0.9).
BATCH_NORM_DECAY = 0.9  # == pytorch's default value as well
12
+
13
class BatchNormRelu(nn.Sequential):
    """``BatchNorm2d`` followed by ``ReLU``, or by ``Identity`` when ``relu`` is false."""

    def __init__(self, num_channels, relu=True):
        norm = nn.BatchNorm2d(num_channels, eps=BATCH_NORM_EPSILON)
        if relu:
            activation = nn.ReLU()
        else:
            activation = nn.Identity()
        super().__init__(norm, activation)
18
+
19
+
20
def conv(in_channels, out_channels, kernel_size=3, stride=1, bias=False):
    """Build a ``Conv2d`` padded by ``(kernel_size - 1) // 2`` on each side."""
    half_kernel = (kernel_size - 1) // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
        bias=bias,
    )
23
+
24
+
25
class SelectiveKernel(nn.Module):
    """Two-branch convolution whose branches are mixed by a learned softmax gate.

    ``main_conv`` produces ``2 * out_channels`` channels that are split into
    two branches; ``mixing_conv`` turns their globally pooled sum into one
    logit per branch and channel.
    """

    def __init__(self, in_channels, out_channels, stride, sk_ratio, min_dim=32):
        super().__init__()
        assert sk_ratio > 0.0
        doubled = 2 * out_channels
        self.main_conv = nn.Sequential(conv(in_channels, doubled, stride=stride),
                                       BatchNormRelu(doubled))
        mid_dim = max(int(out_channels * sk_ratio), min_dim)
        self.mixing_conv = nn.Sequential(conv(out_channels, mid_dim, kernel_size=1),
                                         BatchNormRelu(mid_dim),
                                         conv(mid_dim, doubled, kernel_size=1))

    def forward(self, x):
        """Return the softmax-weighted sum of the two branches."""
        branches = torch.stack(torch.chunk(self.main_conv(x), 2, dim=1), dim=0)  # 2, B, C, H, W
        pooled = branches.sum(dim=0).mean(dim=[2, 3], keepdim=True)
        logits = torch.stack(torch.chunk(self.mixing_conv(pooled), 2, dim=1), dim=0)  # 2, B, C, 1, 1
        gate = F.softmax(logits, dim=0)
        return (branches * gate).sum(dim=0)
44
+
45
+
46
class Projection(nn.Module):
    """Shortcut branch that changes channel count/stride, followed by BatchNorm.

    With ``sk_ratio > 0`` the shortcut is zero-pad -> average-pool -> 1x1 conv;
    otherwise it is a single strided 1x1 conv.
    """

    def __init__(self, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        if sk_ratio > 0:
            pad = nn.ZeroPad2d((0, 1, 0, 1))
            pool = nn.AvgPool2d(kernel_size=2, stride=stride, padding=0)
            pointwise = conv(in_channels, out_channels, kernel_size=1)
            self.shortcut = nn.Sequential(pad, pool, pointwise)
        else:
            self.shortcut = conv(in_channels, out_channels, kernel_size=1, stride=stride)

        self.bn = BatchNormRelu(out_channels, relu=False)

    def forward(self, x):
        """Apply the shortcut and then batch normalisation (no ReLU)."""
        projected = self.shortcut(x)
        return self.bn(projected)
60
+
61
+
62
class BottleneckBlock(nn.Module):
    """Bottleneck residual block: 1x1 conv, 3x3 conv (or selective kernel), 1x1 conv.

    The residual path outputs ``out_channels * 4`` channels. The shortcut is a
    ``Projection`` when ``use_projection`` is set and the identity otherwise.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride, sk_ratio=0, use_projection=False):
        super().__init__()
        if use_projection:
            self.projection = Projection(in_channels, out_channels * 4, stride, sk_ratio)
        else:
            self.projection = nn.Identity()

        # reduce -> spatial -> expand
        stages = [conv(in_channels, out_channels, kernel_size=1), BatchNormRelu(out_channels)]
        if sk_ratio > 0:
            stages.append(SelectiveKernel(out_channels, out_channels, stride, sk_ratio))
        else:
            stages += [conv(out_channels, out_channels, stride=stride),
                       BatchNormRelu(out_channels)]
        stages += [conv(out_channels, out_channels * 4, kernel_size=1),
                   BatchNormRelu(out_channels * 4, relu=False)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``relu(shortcut(x) + net(x))``."""
        residual_input = self.projection(x)
        return F.relu(residual_input + self.net(x))
86
+
87
+
88
class Blocks(nn.Module):
    """A stage of ``num_blocks`` bottleneck blocks.

    Only the first block uses the given ``stride`` and a projection shortcut;
    the remaining ones use stride 1 and take ``channels_out`` input channels.
    """

    def __init__(self, num_blocks, in_channels, out_channels, stride, sk_ratio=0):
        super().__init__()
        self.channels_out = out_channels * BottleneckBlock.expansion
        stage = [BottleneckBlock(in_channels, out_channels, stride, sk_ratio, True)]
        stage.extend(
            BottleneckBlock(self.channels_out, out_channels, 1, sk_ratio)
            for _ in range(num_blocks - 1)
        )
        self.blocks = nn.ModuleList(stage)

    def forward(self, x):
        """Apply the blocks one after another."""
        out = x
        for block in self.blocks:
            out = block(out)
        return out
100
+
101
+
102
class Stem(nn.Sequential):
    """Input stem: three 3x3 convs when ``sk_ratio > 0``, else one 7x7 conv.

    Both variants end with BatchNorm+ReLU on ``64 * width_multiplier`` channels
    (``channels * 2`` below) and a stride-2 max-pool.
    """

    def __init__(self, sk_ratio, width_multiplier):
        channels = 64 * width_multiplier // 2
        if sk_ratio > 0:
            head = [
                conv(3, channels, stride=2),
                BatchNormRelu(channels),
                conv(channels, channels),
                BatchNormRelu(channels),
                conv(channels, channels * 2),
            ]
        else:
            head = [conv(3, channels * 2, kernel_size=7, stride=2)]
        tail = [
            BatchNormRelu(channels * 2),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        super().__init__(*head, *tail)
117
+
118
+
119
class ResNet(nn.Module):
    """Stem plus four ``Blocks`` stages, with an optional 1000-way linear layer.

    Args:
        layers: number of bottleneck blocks in each of the four stages.
        width_multiplier: factor applied to the base widths 64/128/256/512.
        sk_ratio: selective-kernel ratio forwarded to the stem and the blocks.
    """

    def __init__(self, layers, width_multiplier, sk_ratio):
        super().__init__()
        stages = [Stem(sk_ratio, width_multiplier)]
        channels_in = 64 * width_multiplier
        base_widths = (64, 128, 256, 512)
        strides = (1, 2, 2, 2)
        for index, (base_width, stride) in enumerate(zip(base_widths, strides)):
            stage = Blocks(layers[index], channels_in, base_width * width_multiplier, stride, sk_ratio)
            stages.append(stage)
            channels_in = stage.channels_out
        self.channels_out = channels_in
        self.net = nn.Sequential(*stages)
        self.fc = nn.Linear(channels_in, 1000)

    def forward(self, x, apply_fc=False):
        """Return features averaged over dims 2 and 3; apply ``fc`` if ``apply_fc``."""
        pooled = self.net(x).mean(dim=[2, 3])
        return self.fc(pooled) if apply_fc else pooled
141
+
142
+
143
class ContrastiveHead(nn.Module):
    """MLP head of ``num_layers`` bias-free Linear + BatchNorm1d pairs.

    Every pair but the last keeps ``channels_in`` features and is followed by a
    ReLU; the last pair maps to ``out_dim`` and has no ReLU.
    """

    def __init__(self, channels_in, out_dim=128, num_layers=3):
        super().__init__()
        modules = []
        final_index = num_layers - 1
        for index in range(num_layers):
            is_final = index == final_index
            width = out_dim if is_final else channels_in
            modules.append(nn.Linear(channels_in, width, bias=False))
            norm = nn.BatchNorm1d(width, eps=BATCH_NORM_EPSILON, affine=True)
            if is_final:
                nn.init.zeros_(norm.bias)
            modules.append(norm)
            if not is_final:
                modules.append(nn.ReLU())
        self.layers = nn.ModuleList(modules)

    def forward(self, x):
        """Apply every layer in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
164
+
165
+
166
def get_resnet(depth=50, width_multiplier=1, sk_ratio=0):  # sk_ratio=0.0625 is recommended
    """Build a ``ResNet`` of the given depth and a matching ``ContrastiveHead``.

    Raises:
        KeyError: if ``depth`` is not one of 50, 101, 152 or 200.
    """
    blocks_per_stage = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    backbone = ResNet(blocks_per_stage[depth], width_multiplier, sk_ratio)
    head = ContrastiveHead(backbone.channels_out)
    return backbone, head
170
+
171
+
172
def name_to_params(checkpoint):
    """Derive ``(depth, width, sk_ratio)`` from substrings of a checkpoint name.

    Raises:
        NotImplementedError: if no known depth tag (``r50_``/``r101_``/``r152_``)
            or width tag (``_1x_``/``_2x_``/``_3x_``) occurs in the name.
    """
    sk_ratio = 0.0625 if '_sk1' in checkpoint else 0

    def _first_match(table):
        # table entries are checked in order; the first tag found wins
        for tag, value in table:
            if tag in checkpoint:
                return value
        raise NotImplementedError

    depth = _first_match((('r50_', 50), ('r101_', 101), ('r152_', 152)))
    width = _first_match((('_1x_', 1), ('_2x_', 2), ('_3x_', 3)))
    return depth, width, sk_ratio
193
+
194
class SimCLRv2(nn.Module):
    """Pair an encoder with a contrastive head and apply them in sequence."""

    def __init__(self, model, head):
        super().__init__()

        self.encoder = model
        self.contrastive_head = head

    def forward(self, x):
        """Return ``contrastive_head(encoder(x))``."""
        features = self.encoder(x)
        return self.contrastive_head(features)
205
+
206
def get_simclr2_model(ckpt_path):
    """Load a SimCLRv2 checkpoint and return encoder + contrastive head.

    The architecture (depth, width, selective-kernel ratio) is derived from
    the checkpoint file name via ``name_to_params``.

    Args:
        ckpt_path: file name inside ``pretrained_models/simclr2_models/``.

    Returns:
        A ``SimCLRv2`` instance moved to the module-level ``device``.

    Raises:
        NotImplementedError: if the file name carries no known depth/width tag.
    """
    depth, width, sk_ratio = name_to_params(ckpt_path)
    model, head = get_resnet(depth, width, sk_ratio)
    # Deserialize onto the CPU first, as the other loaders in this project do,
    # so a checkpoint saved from CUDA tensors still loads on a CPU-only host;
    # the final .to(device) moves the model where it belongs.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load('pretrained_models/simclr2_models/' + ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['resnet'])
    head.load_state_dict(checkpoint['head'])
    # The classification layer is only needed to load the state dict.
    del model.fc
    simclr2 = SimCLRv2(model, head)
    return simclr2.to(device)
ssl_models/simsiam.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+
6
+ """from https://github.com/facebookresearch/simsiam"""
7
+
8
class SimSiam(nn.Module):
    """SimSiam model: encoder with a 3-layer projector plus a 2-layer predictor.

    Adapted from https://github.com/facebookresearch/simsiam.
    """

    def __init__(self, base_encoder, dim, pred_dim):
        """
        base_encoder: callable accepting ``num_classes`` and
            ``zero_init_residual`` and returning a module whose ``fc``
            attribute is a linear layer
        dim: feature dimension (default: 2048)
        pred_dim: hidden dimension of the predictor (default: 512)
        """
        super().__init__()

        # create the encoder
        # num_classes is the output fc dimension, zero-initialize last BNs
        self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)

        # build a 3-layer projector; the encoder's original fc becomes its
        # last linear layer (index 6)
        prev_dim = self.encoder.fc.weight.shape[1]
        self.encoder.fc = nn.Sequential(nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True),  # first layer
                                        nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True),  # second layer
                                        self.encoder.fc,
                                        nn.BatchNorm1d(dim, affine=False))  # output layer
        self.encoder.fc[6].bias.requires_grad = False  # hack: not use bias as it is followed by BN

        # build a 2-layer predictor
        self.predictor = nn.Sequential(nn.Linear(dim, pred_dim, bias=False),
                                       nn.BatchNorm1d(pred_dim),
                                       nn.ReLU(inplace=True),  # hidden layer
                                       nn.Linear(pred_dim, dim))  # output layer

    def forward(self, x1, x2):
        """Return the symmetrised negative cosine-similarity loss for two views.

        The stop-gradient is applied to the targets only. Detaching the
        embeddings before they enter the predictor would leave the loss value
        unchanged but cut the encoder out of the graph, so it could never be
        trained and no gradient would reach the inputs.
        """
        z1 = self.encoder(x1)  # NxC
        z2 = self.encoder(x2)  # NxC

        p1 = self.predictor(z1)  # NxC
        p2 = self.predictor(z2)  # NxC

        similarity = nn.CosineSimilarity(dim=1)
        loss = -(similarity(p1, z2.detach()).mean() + similarity(p2, z1.detach()).mean()) * 0.5

        return loss
50
+
51
class ResNet(nn.Module):
    """Pooled feature extractor: ``backbone`` without its final two children."""

    def __init__(self, backbone):
        super().__init__()

        retained = [child for child in backbone.children()][:-2]
        self.net = nn.Sequential(*retained)

    def forward(self, x):
        """Return the network output averaged over dimensions 2 and 3."""
        activations = self.net(x)
        return activations.mean(dim=[2, 3])
60
+
61
class RestructuredSimSiam(nn.Module):
    """Expose a ``SimSiam`` model as encoder, projector MLP and predictor head.

    ``encoder`` wraps ``model.encoder`` in the local ``ResNet`` feature
    extractor, ``mlp_encoder`` is ``model.encoder.fc`` and
    ``contrastive_head`` is ``model.predictor``.
    """

    def __init__(self, model):
        super().__init__()

        self.encoder = ResNet(model.encoder)
        projector = model.encoder.fc
        self.mlp_encoder = projector
        # keep the bias of the projector's last linear layer frozen
        self.mlp_encoder[6].bias.requires_grad = False
        self.contrastive_head = model.predictor

    def forward(self, x, run_head = True):
        """Return the projector output, passed through the predictor if ``run_head``."""
        # don't detach since we will do backprop for explainability
        embedding = self.mlp_encoder(self.encoder(x))
        return self.contrastive_head(embedding) if run_head else embedding
78
+
79
+
80
def get_simsiam(ckpt_path = 'checkpoint_0099.pth.tar'):
    """Load a SimSiam ResNet-50 checkpoint and return the restructured model.

    Args:
        ckpt_path: file name inside ``pretrained_models/simsiam_models/``.

    Returns:
        A ``RestructuredSimSiam`` instance moved to the module-level ``device``.
    """
    model = SimSiam(base_encoder=torchvision.models.resnet50, dim=2048, pred_dim=512)

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load('pretrained_models/simsiam_models/'+ ckpt_path, map_location='cpu')
    weights = {}
    for key, value in checkpoint['state_dict'].items():
        # drop every "module." fragment from the parameter names
        weights[key.replace("module.", "")] = value
    model.load_state_dict(weights)
    return RestructuredSimSiam(model).to(device)