Spaces:
Sleeping
Sleeping
Annonymous
commited on
Commit
•
b157c29
1
Parent(s):
2c29c28
Upload 4 files
Browse files- ssl_models/barlow_twins.py +77 -0
- ssl_models/dino.py +184 -0
- ssl_models/simclr2.py +214 -0
- ssl_models/simsiam.py +91 -0
ssl_models/barlow_twins.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision
|
4 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
5 |
+
|
6 |
+
"""from https://github.com/facebookresearch/barlowtwins"""
|
7 |
+
|
8 |
+
def off_diagonal(x):
|
9 |
+
# return a flattened view of the off-diagonal elements of a square matrix
|
10 |
+
n, m = x.shape
|
11 |
+
assert n == m
|
12 |
+
return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
|
13 |
+
|
14 |
+
class BarlowTwins(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
self.backbone = torchvision.models.resnet50(zero_init_residual=True)
|
19 |
+
self.backbone.fc = nn.Identity()
|
20 |
+
|
21 |
+
# projector
|
22 |
+
sizes = [2048] + list(map(int, '8192-8192-8192'.split('-')))
|
23 |
+
layers = []
|
24 |
+
for i in range(len(sizes) - 2):
|
25 |
+
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
|
26 |
+
layers.append(nn.BatchNorm1d(sizes[i + 1]))
|
27 |
+
layers.append(nn.ReLU(inplace=True))
|
28 |
+
|
29 |
+
layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
|
30 |
+
self.projector = nn.Sequential(*layers)
|
31 |
+
|
32 |
+
# normalization layer for the representations z1 and z2
|
33 |
+
self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
|
34 |
+
|
35 |
+
def forward(self, y1, y2):
|
36 |
+
z1 = self.projector(self.backbone(y1))
|
37 |
+
z2 = self.projector(self.backbone(y2))
|
38 |
+
|
39 |
+
# empirical cross-correlation matrix
|
40 |
+
c = self.bn(z1).T @ self.bn(z2)
|
41 |
+
|
42 |
+
on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
|
43 |
+
off_diag = off_diagonal(c).pow_(2).sum()
|
44 |
+
loss = on_diag + 0.0051 * off_diag
|
45 |
+
return loss
|
46 |
+
|
47 |
+
class ResNet(nn.Module):
|
48 |
+
def __init__(self, backbone):
|
49 |
+
super().__init__()
|
50 |
+
|
51 |
+
modules = list(backbone.children())[:-2]
|
52 |
+
self.net = nn.Sequential(*modules)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
return self.net(x).mean(dim=[2, 3])
|
56 |
+
|
57 |
+
class RestructuredBarlowTwins(nn.Module):
|
58 |
+
def __init__(self, model):
|
59 |
+
super().__init__()
|
60 |
+
|
61 |
+
self.encoder = ResNet(model.backbone)
|
62 |
+
self.contrastive_head = model.projector
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
x = self.encoder(x)
|
66 |
+
x = self.contrastive_head(x)
|
67 |
+
return x
|
68 |
+
|
69 |
+
|
70 |
+
def get_barlow_twins_model(ckpt_path = 'barlow_twins.pth'):
|
71 |
+
model = BarlowTwins()
|
72 |
+
state_dict = torch.load('pretrained_models/barlow_models/' + ckpt_path, map_location='cpu')
|
73 |
+
state_dict = state_dict['model']
|
74 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
75 |
+
model.load_state_dict(state_dict)
|
76 |
+
restructured_model = RestructuredBarlowTwins(model)
|
77 |
+
return restructured_model.to(device)
|
ssl_models/dino.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
import pathlib
|
7 |
+
temp = pathlib.PosixPath
|
8 |
+
pathlib.PosixPath = pathlib.WindowsPath
|
9 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
+
|
11 |
+
""" from https://github.com/facebookresearch/dino"""
|
12 |
+
|
13 |
+
class DINOHead(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self, in_dim, out_dim, use_bn, norm_last_layer, nlayers, hidden_dim, bottleneck_dim):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
nlayers = max(nlayers, 1)
|
19 |
+
if nlayers == 1:
|
20 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
21 |
+
else:
|
22 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
23 |
+
if use_bn:
|
24 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
25 |
+
layers.append(nn.GELU())
|
26 |
+
for _ in range(nlayers - 2):
|
27 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
28 |
+
if use_bn:
|
29 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
30 |
+
layers.append(nn.GELU())
|
31 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
32 |
+
self.mlp = nn.Sequential(*layers)
|
33 |
+
|
34 |
+
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
35 |
+
self.last_layer.weight_g.data.fill_(1)
|
36 |
+
if norm_last_layer:
|
37 |
+
self.last_layer.weight_g.requires_grad = False
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
x = self.mlp(x)
|
41 |
+
x = F.normalize(x, dim=-1, p=2)
|
42 |
+
x = self.last_layer(x)
|
43 |
+
return x
|
44 |
+
|
45 |
+
class MultiCropWrapper(nn.Module):
|
46 |
+
def __init__(self, backbone, head):
|
47 |
+
super(MultiCropWrapper, self).__init__()
|
48 |
+
backbone.fc, backbone.head = nn.Identity(), nn.Identity()
|
49 |
+
self.backbone = backbone
|
50 |
+
self.head = head
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
return self.head(self.backbone(x))
|
54 |
+
|
55 |
+
class DINOLoss(nn.Module):
|
56 |
+
def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs, nepochs,
|
57 |
+
student_temp=0.1, center_momentum=0.9):
|
58 |
+
super().__init__()
|
59 |
+
|
60 |
+
self.student_temp = student_temp
|
61 |
+
self.center_momentum = center_momentum
|
62 |
+
self.register_buffer("center", torch.zeros(1, out_dim))
|
63 |
+
self.nepochs = nepochs
|
64 |
+
self.teacher_temp_schedule = np.concatenate((np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
|
65 |
+
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp))
|
66 |
+
|
67 |
+
def forward(self, student_output, teacher_output):
|
68 |
+
student_out = student_output / self.student_temp
|
69 |
+
temp = self.teacher_temp_schedule[self.nepochs - 1] # last one
|
70 |
+
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
|
71 |
+
teacher_out = teacher_out.detach()
|
72 |
+
loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1).mean()
|
73 |
+
return loss
|
74 |
+
|
75 |
+
|
76 |
+
class ResNet(nn.Module):
|
77 |
+
def __init__(self, backbone):
|
78 |
+
super().__init__()
|
79 |
+
|
80 |
+
modules = list(backbone.children())[:-2]
|
81 |
+
self.net = nn.Sequential(*modules)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
return self.net(x).mean(dim=[2, 3])
|
85 |
+
|
86 |
+
class RestructuredDINO(nn.Module):
|
87 |
+
|
88 |
+
def __init__(self, student, teacher):
|
89 |
+
super().__init__()
|
90 |
+
|
91 |
+
self.encoder_student = ResNet(student.backbone)
|
92 |
+
self.encoder = ResNet(teacher.backbone)
|
93 |
+
|
94 |
+
self.contrastive_head_student = student.head
|
95 |
+
self.contrastive_head = teacher.head
|
96 |
+
|
97 |
+
|
98 |
+
def forward(self, x, run_teacher):
|
99 |
+
|
100 |
+
if run_teacher:
|
101 |
+
x = self.encoder(x)
|
102 |
+
x = self.contrastive_head(x)
|
103 |
+
else:
|
104 |
+
x = self.encoder_student(x)
|
105 |
+
x = self.contrastive_head_student(x)
|
106 |
+
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
def get_dino_model_without_loss(ckpt_path = 'dino_resnet50_pretrain_full_checkpoint.pth'):
|
111 |
+
state_dict = torch.load('pretrained_models/dino_models/' + ckpt_path, map_location='cpu')
|
112 |
+
state_dict_student = state_dict['student']
|
113 |
+
state_dict_teacher = state_dict['teacher']
|
114 |
+
|
115 |
+
state_dict_student = {k.replace("module.", ""): v for k, v in state_dict_student.items()}
|
116 |
+
state_dict_teacher = {k.replace("module.", ""): v for k, v in state_dict_teacher.items()}
|
117 |
+
|
118 |
+
student_backbone = torchvision.models.resnet50()
|
119 |
+
teacher_backbone = torchvision.models.resnet50()
|
120 |
+
embed_dim = student_backbone.fc.weight.shape[1]
|
121 |
+
|
122 |
+
student_head = DINOHead(in_dim = embed_dim, out_dim = 60000, use_bn=True, norm_last_layer=True, nlayers=2, hidden_dim=4096, bottleneck_dim=256)
|
123 |
+
teacher_head = DINOHead(in_dim = embed_dim, out_dim = 60000, use_bn =True, norm_last_layer=True, nlayers=2, hidden_dim=4096, bottleneck_dim=256)
|
124 |
+
student_head.last_layer = nn.Linear(256, 60000, bias = False)
|
125 |
+
teacher_head.last_layer = nn.Linear(256, 60000, bias = False)
|
126 |
+
|
127 |
+
student = MultiCropWrapper(student_backbone, student_head)
|
128 |
+
teacher = MultiCropWrapper(teacher_backbone, teacher_head)
|
129 |
+
|
130 |
+
student.load_state_dict(state_dict_student)
|
131 |
+
teacher.load_state_dict(state_dict_teacher)
|
132 |
+
|
133 |
+
restructured_model = RestructuredDINO(student, teacher)
|
134 |
+
|
135 |
+
return restructured_model.to(device)
|
136 |
+
|
137 |
+
|
138 |
+
def get_dino_model_with_loss(ckpt_path = 'dino_rn50_checkpoint.pth'):
|
139 |
+
state_dict = torch.load('pretrained_models/dino_models/' + ckpt_path, map_location='cpu')
|
140 |
+
|
141 |
+
state_dict_student = state_dict['student']
|
142 |
+
state_dict_teacher = state_dict['teacher']
|
143 |
+
state_dict_args = vars(state_dict['args'])
|
144 |
+
state_dic_dino_loss = state_dict['dino_loss']
|
145 |
+
|
146 |
+
state_dict_student = {k.replace("module.", ""): v for k, v in state_dict_student.items()}
|
147 |
+
state_dict_teacher = {k.replace("module.", ""): v for k, v in state_dict_teacher.items()}
|
148 |
+
|
149 |
+
student_backbone = torchvision.models.resnet50()
|
150 |
+
teacher_backbone = torchvision.models.resnet50()
|
151 |
+
embed_dim = student_backbone.fc.weight.shape[1]
|
152 |
+
|
153 |
+
student_head = DINOHead(in_dim = embed_dim,
|
154 |
+
out_dim = state_dict_args['out_dim'],
|
155 |
+
use_bn = state_dict_args['use_bn_in_head'],
|
156 |
+
norm_last_layer = state_dict_args['norm_last_layer'],
|
157 |
+
nlayers = 3,
|
158 |
+
hidden_dim = 2048,
|
159 |
+
bottleneck_dim = 256)
|
160 |
+
|
161 |
+
teacher_head = DINOHead(in_dim = embed_dim,
|
162 |
+
out_dim = state_dict_args['out_dim'],
|
163 |
+
use_bn = state_dict_args['use_bn_in_head'],
|
164 |
+
norm_last_layer = state_dict_args['norm_last_layer'],
|
165 |
+
nlayers = 3,
|
166 |
+
hidden_dim = 2048,
|
167 |
+
bottleneck_dim = 256)
|
168 |
+
|
169 |
+
loss = DINOLoss(out_dim = state_dict_args['out_dim'],
|
170 |
+
warmup_teacher_temp = state_dict_args['warmup_teacher_temp'],
|
171 |
+
teacher_temp = state_dict_args['teacher_temp'],
|
172 |
+
warmup_teacher_temp_epochs = state_dict_args['warmup_teacher_temp_epochs'],
|
173 |
+
nepochs = state_dict_args['epochs'])
|
174 |
+
|
175 |
+
student = MultiCropWrapper(student_backbone, student_head)
|
176 |
+
teacher = MultiCropWrapper(teacher_backbone, teacher_head)
|
177 |
+
|
178 |
+
student.load_state_dict(state_dict_student)
|
179 |
+
teacher.load_state_dict(state_dict_teacher)
|
180 |
+
loss.load_state_dict(state_dic_dino_loss)
|
181 |
+
|
182 |
+
restructured_model = RestructuredDINO(student, teacher)
|
183 |
+
|
184 |
+
return restructured_model.to(device), loss.to(device)
|
ssl_models/simclr2.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
5 |
+
|
6 |
+
"""
|
7 |
+
from https://github.com/Separius/SimCLRv2-Pytorch
|
8 |
+
"""
|
9 |
+
|
10 |
+
BATCH_NORM_EPSILON = 1e-5
|
11 |
+
BATCH_NORM_DECAY = 0.9 # == pytorch's default value as well
|
12 |
+
|
13 |
+
class BatchNormRelu(nn.Sequential):
|
14 |
+
|
15 |
+
def __init__(self, num_channels, relu=True):
|
16 |
+
super().__init__(nn.BatchNorm2d(num_channels, eps=BATCH_NORM_EPSILON),
|
17 |
+
nn.ReLU() if relu else nn.Identity())
|
18 |
+
|
19 |
+
|
20 |
+
def conv(in_channels, out_channels, kernel_size=3, stride=1, bias=False):
|
21 |
+
return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
|
22 |
+
stride=stride, padding=(kernel_size - 1) // 2, bias=bias)
|
23 |
+
|
24 |
+
|
25 |
+
class SelectiveKernel(nn.Module):
|
26 |
+
|
27 |
+
def __init__(self, in_channels, out_channels, stride, sk_ratio, min_dim=32):
|
28 |
+
super().__init__()
|
29 |
+
assert sk_ratio > 0.0
|
30 |
+
self.main_conv = nn.Sequential(conv(in_channels, 2 * out_channels, stride=stride),
|
31 |
+
BatchNormRelu(2 * out_channels))
|
32 |
+
mid_dim = max(int(out_channels * sk_ratio), min_dim)
|
33 |
+
self.mixing_conv = nn.Sequential(conv(out_channels, mid_dim, kernel_size=1),
|
34 |
+
BatchNormRelu(mid_dim),
|
35 |
+
conv(mid_dim, 2 * out_channels, kernel_size=1))
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
x = self.main_conv(x)
|
39 |
+
x = torch.stack(torch.chunk(x, 2, dim=1), dim=0) # 2, B, C, H, W
|
40 |
+
g = x.sum(dim=0).mean(dim=[2, 3], keepdim=True)
|
41 |
+
m = self.mixing_conv(g)
|
42 |
+
m = torch.stack(torch.chunk(m, 2, dim=1), dim=0) # 2, B, C, 1, 1
|
43 |
+
return (x * F.softmax(m, dim=0)).sum(dim=0)
|
44 |
+
|
45 |
+
|
46 |
+
class Projection(nn.Module):
|
47 |
+
def __init__(self, in_channels, out_channels, stride, sk_ratio=0):
|
48 |
+
super().__init__()
|
49 |
+
if sk_ratio > 0:
|
50 |
+
self.shortcut = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)),
|
51 |
+
nn.AvgPool2d(kernel_size=2, stride=stride, padding=0),
|
52 |
+
conv(in_channels, out_channels, kernel_size=1))
|
53 |
+
else:
|
54 |
+
self.shortcut = conv(in_channels, out_channels, kernel_size=1, stride=stride)
|
55 |
+
|
56 |
+
self.bn = BatchNormRelu(out_channels, relu=False)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
return self.bn(self.shortcut(x))
|
60 |
+
|
61 |
+
|
62 |
+
class BottleneckBlock(nn.Module):
|
63 |
+
expansion = 4
|
64 |
+
|
65 |
+
def __init__(self, in_channels, out_channels, stride, sk_ratio=0, use_projection=False):
|
66 |
+
super().__init__()
|
67 |
+
if use_projection:
|
68 |
+
self.projection = Projection(in_channels, out_channels * 4, stride, sk_ratio)
|
69 |
+
else:
|
70 |
+
self.projection = nn.Identity()
|
71 |
+
|
72 |
+
ops = [conv(in_channels, out_channels, kernel_size=1), BatchNormRelu(out_channels)]
|
73 |
+
if sk_ratio > 0:
|
74 |
+
ops.append(SelectiveKernel(out_channels, out_channels, stride, sk_ratio))
|
75 |
+
else:
|
76 |
+
ops.append(conv(out_channels, out_channels, stride=stride))
|
77 |
+
ops.append(BatchNormRelu(out_channels))
|
78 |
+
|
79 |
+
ops.append(conv(out_channels, out_channels * 4, kernel_size=1))
|
80 |
+
ops.append(BatchNormRelu(out_channels * 4, relu=False))
|
81 |
+
self.net = nn.Sequential(*ops)
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
shortcut = self.projection(x)
|
85 |
+
return F.relu(shortcut + self.net(x))
|
86 |
+
|
87 |
+
|
88 |
+
class Blocks(nn.Module):
|
89 |
+
def __init__(self, num_blocks, in_channels, out_channels, stride, sk_ratio=0):
|
90 |
+
super().__init__()
|
91 |
+
self.blocks = nn.ModuleList([BottleneckBlock(in_channels, out_channels, stride, sk_ratio, True)])
|
92 |
+
self.channels_out = out_channels * BottleneckBlock.expansion
|
93 |
+
for _ in range(num_blocks - 1):
|
94 |
+
self.blocks.append(BottleneckBlock(self.channels_out, out_channels, 1, sk_ratio))
|
95 |
+
|
96 |
+
def forward(self, x):
|
97 |
+
for b in self.blocks:
|
98 |
+
x = b(x)
|
99 |
+
return x
|
100 |
+
|
101 |
+
|
102 |
+
class Stem(nn.Sequential):
|
103 |
+
def __init__(self, sk_ratio, width_multiplier):
|
104 |
+
ops = []
|
105 |
+
channels = 64 * width_multiplier // 2
|
106 |
+
if sk_ratio > 0:
|
107 |
+
ops.append(conv(3, channels, stride=2))
|
108 |
+
ops.append(BatchNormRelu(channels))
|
109 |
+
ops.append(conv(channels, channels))
|
110 |
+
ops.append(BatchNormRelu(channels))
|
111 |
+
ops.append(conv(channels, channels * 2))
|
112 |
+
else:
|
113 |
+
ops.append(conv(3, channels * 2, kernel_size=7, stride=2))
|
114 |
+
ops.append(BatchNormRelu(channels * 2))
|
115 |
+
ops.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
|
116 |
+
super().__init__(*ops)
|
117 |
+
|
118 |
+
|
119 |
+
class ResNet(nn.Module):
|
120 |
+
def __init__(self, layers, width_multiplier, sk_ratio):
|
121 |
+
super().__init__()
|
122 |
+
ops = [Stem(sk_ratio, width_multiplier)]
|
123 |
+
channels_in = 64 * width_multiplier
|
124 |
+
ops.append(Blocks(layers[0], channels_in, 64 * width_multiplier, 1, sk_ratio))
|
125 |
+
channels_in = ops[-1].channels_out
|
126 |
+
ops.append(Blocks(layers[1], channels_in, 128 * width_multiplier, 2, sk_ratio))
|
127 |
+
channels_in = ops[-1].channels_out
|
128 |
+
ops.append(Blocks(layers[2], channels_in, 256 * width_multiplier, 2, sk_ratio))
|
129 |
+
channels_in = ops[-1].channels_out
|
130 |
+
ops.append(Blocks(layers[3], channels_in, 512 * width_multiplier, 2, sk_ratio))
|
131 |
+
channels_in = ops[-1].channels_out
|
132 |
+
self.channels_out = channels_in
|
133 |
+
self.net = nn.Sequential(*ops)
|
134 |
+
self.fc = nn.Linear(channels_in, 1000)
|
135 |
+
|
136 |
+
def forward(self, x, apply_fc=False):
|
137 |
+
h = self.net(x).mean(dim=[2, 3])
|
138 |
+
if apply_fc:
|
139 |
+
h = self.fc(h)
|
140 |
+
return h
|
141 |
+
|
142 |
+
|
143 |
+
class ContrastiveHead(nn.Module):
|
144 |
+
def __init__(self, channels_in, out_dim=128, num_layers=3):
|
145 |
+
super().__init__()
|
146 |
+
self.layers = nn.ModuleList()
|
147 |
+
for i in range(num_layers):
|
148 |
+
if i != num_layers - 1:
|
149 |
+
dim, relu = channels_in, True
|
150 |
+
else:
|
151 |
+
dim, relu = out_dim, False
|
152 |
+
self.layers.append(nn.Linear(channels_in, dim, bias=False))
|
153 |
+
bn = nn.BatchNorm1d(dim, eps=BATCH_NORM_EPSILON, affine=True)
|
154 |
+
if i == num_layers - 1:
|
155 |
+
nn.init.zeros_(bn.bias)
|
156 |
+
self.layers.append(bn)
|
157 |
+
if relu:
|
158 |
+
self.layers.append(nn.ReLU())
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
for b in self.layers:
|
162 |
+
x = b(x)
|
163 |
+
return x
|
164 |
+
|
165 |
+
|
166 |
+
def get_resnet(depth=50, width_multiplier=1, sk_ratio=0): # sk_ratio=0.0625 is recommended
|
167 |
+
layers = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}[depth]
|
168 |
+
resnet = ResNet(layers, width_multiplier, sk_ratio)
|
169 |
+
return resnet, ContrastiveHead(resnet.channels_out)
|
170 |
+
|
171 |
+
|
172 |
+
def name_to_params(checkpoint):
|
173 |
+
sk_ratio = 0.0625 if '_sk1' in checkpoint else 0
|
174 |
+
if 'r50_' in checkpoint:
|
175 |
+
depth = 50
|
176 |
+
elif 'r101_' in checkpoint:
|
177 |
+
depth = 101
|
178 |
+
elif 'r152_' in checkpoint:
|
179 |
+
depth = 152
|
180 |
+
else:
|
181 |
+
raise NotImplementedError
|
182 |
+
|
183 |
+
if '_1x_' in checkpoint:
|
184 |
+
width = 1
|
185 |
+
elif '_2x_' in checkpoint:
|
186 |
+
width = 2
|
187 |
+
elif '_3x_' in checkpoint:
|
188 |
+
width = 3
|
189 |
+
else:
|
190 |
+
raise NotImplementedError
|
191 |
+
|
192 |
+
return depth, width, sk_ratio
|
193 |
+
|
194 |
+
class SimCLRv2(nn.Module):
|
195 |
+
def __init__(self, model, head):
|
196 |
+
super(SimCLRv2, self).__init__()
|
197 |
+
|
198 |
+
self.encoder = model
|
199 |
+
self.contrastive_head = head
|
200 |
+
|
201 |
+
def forward(self, x):
|
202 |
+
x = self.encoder(x)
|
203 |
+
x = self.contrastive_head(x)
|
204 |
+
return x
|
205 |
+
|
206 |
+
def get_simclr2_model(ckpt_path):
|
207 |
+
depth, width, sk_ratio = name_to_params(ckpt_path)
|
208 |
+
model, head = get_resnet(depth, width, sk_ratio)
|
209 |
+
checkpoint = torch.load('pretrained_models/simclr2_models/' + ckpt_path)
|
210 |
+
model.load_state_dict(checkpoint['resnet'])
|
211 |
+
head.load_state_dict(checkpoint['head'])
|
212 |
+
del model.fc
|
213 |
+
simclr2 = SimCLRv2(model, head)
|
214 |
+
return simclr2.to(device)
|
ssl_models/simsiam.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision
|
4 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
5 |
+
|
6 |
+
"""from https://github.com/facebookresearch/simsiam"""
|
7 |
+
|
8 |
+
class SimSiam(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, base_encoder, dim, pred_dim):
|
11 |
+
"""
|
12 |
+
dim: feature dimension (default: 2048)
|
13 |
+
pred_dim: hidden dimension of the predictor (default: 512)
|
14 |
+
symetric is True only when training
|
15 |
+
"""
|
16 |
+
super(SimSiam, self).__init__()
|
17 |
+
|
18 |
+
# create the encoder
|
19 |
+
# num_classes is the output fc dimension, zero-initialize last BNs
|
20 |
+
self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)
|
21 |
+
|
22 |
+
# build a 3-layer projector
|
23 |
+
prev_dim = self.encoder.fc.weight.shape[1]
|
24 |
+
self.encoder.fc = nn.Sequential(nn.Linear(prev_dim, prev_dim, bias=False),
|
25 |
+
nn.BatchNorm1d(prev_dim),
|
26 |
+
nn.ReLU(inplace=True), # first layer
|
27 |
+
nn.Linear(prev_dim, prev_dim, bias=False),
|
28 |
+
nn.BatchNorm1d(prev_dim),
|
29 |
+
nn.ReLU(inplace=True), # second layer
|
30 |
+
self.encoder.fc,
|
31 |
+
nn.BatchNorm1d(dim, affine=False)) # output layer
|
32 |
+
self.encoder.fc[6].bias.requires_grad = False # hack: not use bias as it is followed by BN
|
33 |
+
|
34 |
+
# build a 2-layer predictor
|
35 |
+
self.predictor = nn.Sequential(nn.Linear(dim, pred_dim, bias=False),
|
36 |
+
nn.BatchNorm1d(pred_dim),
|
37 |
+
nn.ReLU(inplace=True), # hidden layer
|
38 |
+
nn.Linear(pred_dim, dim)) # output layer
|
39 |
+
|
40 |
+
def forward(self, x1, x2):
|
41 |
+
z1 = self.encoder(x1).detach() # NxC
|
42 |
+
z2 = self.encoder(x2).detach() # NxC
|
43 |
+
|
44 |
+
p1 = self.predictor(z1) # NxC
|
45 |
+
p2 = self.predictor(z2) # NxC
|
46 |
+
|
47 |
+
loss = -(nn.CosineSimilarity(dim=1)(p1, z2).mean() + nn.CosineSimilarity(dim=1)(p2, z1).mean()) * 0.5
|
48 |
+
|
49 |
+
return loss
|
50 |
+
|
51 |
+
class ResNet(nn.Module):
|
52 |
+
def __init__(self, backbone):
|
53 |
+
super().__init__()
|
54 |
+
|
55 |
+
modules = list(backbone.children())[:-2]
|
56 |
+
self.net = nn.Sequential(*modules)
|
57 |
+
|
58 |
+
def forward(self, x):
|
59 |
+
return self.net(x).mean(dim=[2, 3])
|
60 |
+
|
61 |
+
class RestructuredSimSiam(nn.Module):
|
62 |
+
def __init__(self, model):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.encoder = ResNet(model.encoder)
|
66 |
+
self.mlp_encoder = model.encoder.fc
|
67 |
+
self.mlp_encoder[6].bias.requires_grad = False
|
68 |
+
self.contrastive_head = model.predictor
|
69 |
+
|
70 |
+
def forward(self, x, run_head = True):
|
71 |
+
|
72 |
+
x = self.mlp_encoder(self.encoder(x)) # don't detach since we will do backprop for explainability
|
73 |
+
|
74 |
+
if run_head:
|
75 |
+
x = self.contrastive_head(x)
|
76 |
+
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
def get_simsiam(ckpt_path = 'checkpoint_0099.pth.tar'):
|
81 |
+
|
82 |
+
model = SimSiam(base_encoder = torchvision.models.resnet50,
|
83 |
+
dim = 2048,
|
84 |
+
pred_dim = 512)
|
85 |
+
|
86 |
+
checkpoint = torch.load('pretrained_models/simsiam_models/'+ ckpt_path, map_location='cpu')
|
87 |
+
state_dic = checkpoint['state_dict']
|
88 |
+
state_dic = {k.replace("module.", ""): v for k, v in state_dic.items()}
|
89 |
+
model.load_state_dict(state_dic)
|
90 |
+
restructured_model = RestructuredSimSiam(model)
|
91 |
+
return restructured_model.to(device)
|