import torch
import torch.nn as nn
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# from https://github.com/facebookresearch/simsiam
class SimSiam(nn.Module):
    def __init__(self, base_encoder, dim, pred_dim):
        """
        dim: feature dimension (default: 2048)
        pred_dim: hidden dimension of the predictor (default: 512)
        The forward pass computes the symmetric SimSiam loss (used during training).
        """
        super(SimSiam, self).__init__()
        # create the encoder
        # num_classes is the output fc dimension, zero-initialize last BNs
        self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)
        # build a 3-layer projector
        prev_dim = self.encoder.fc.weight.shape[1]
        self.encoder.fc = nn.Sequential(nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True),  # first layer
                                        nn.Linear(prev_dim, prev_dim, bias=False),
                                        nn.BatchNorm1d(prev_dim),
                                        nn.ReLU(inplace=True),  # second layer
                                        self.encoder.fc,
                                        nn.BatchNorm1d(dim, affine=False))  # output layer
        self.encoder.fc[6].bias.requires_grad = False  # hack: do not use bias as it is followed by BN
        # build a 2-layer predictor
        self.predictor = nn.Sequential(nn.Linear(dim, pred_dim, bias=False),
                                       nn.BatchNorm1d(pred_dim),
                                       nn.ReLU(inplace=True),  # hidden layer
                                       nn.Linear(pred_dim, dim))  # output layer

    def forward(self, x1, x2):
        z1 = self.encoder(x1)  # NxC
        z2 = self.encoder(x2)  # NxC
        p1 = self.predictor(z1)  # NxC
        p2 = self.predictor(z2)  # NxC
        # symmetric negative cosine similarity; stop-gradient is applied to the
        # target branch only (detaching before the predictor would block all
        # gradients to the encoder)
        loss = -(nn.CosineSimilarity(dim=1)(p1, z2.detach()).mean()
                 + nn.CosineSimilarity(dim=1)(p2, z1.detach()).mean()) * 0.5
        return loss
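

# Illustrative usage sketch (an addition, not part of the original repo code):
# compute the symmetric SimSiam loss on two "views" of the same batch. The random
# tensors below stand in for actual augmented crops; batch size and resolution
# are arbitrary placeholders.
def example_simsiam_loss():
    model = SimSiam(torchvision.models.resnet50, dim=2048, pred_dim=512).to(device)
    x1 = torch.randn(4, 3, 224, 224, device=device)
    x2 = torch.randn(4, 3, 224, 224, device=device)
    loss = model(x1, x2)  # scalar; -1.0 would mean the two views agree perfectly
    return loss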


class ResNet(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        # drop the backbone's avgpool and fc layers, keep only the convolutional trunk
        modules = list(backbone.children())[:-2]
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        # global average pooling over the spatial dimensions -> NxC features
        return self.net(x).mean(dim=[2, 3])
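

# Quick shape check (illustrative sketch, an addition): the wrapper above keeps
# only the convolutional trunk of a torchvision ResNet-50 and pools spatially,
# so each image maps to a 2048-dimensional feature vector.
def example_backbone_features():
    backbone = ResNet(torchvision.models.resnet50())
    x = torch.randn(2, 3, 224, 224)
    return backbone(x).shape  # expected: torch.Size([2, 2048])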


class RestructuredSimSiam(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.encoder = ResNet(model.encoder)
        self.mlp_encoder = model.encoder.fc
        self.mlp_encoder[6].bias.requires_grad = False
        self.contrastive_head = model.predictor

    def forward(self, x, run_head=True):
        # don't detach since we will do backprop for explainability
        x = self.mlp_encoder(self.encoder(x))
        if run_head:
            x = self.contrastive_head(x)
        return x
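

# Illustrative sketch (an addition): the restructured model exposes the projector
# output when run_head=False and the prediction-head output when run_head=True;
# both are 2048-dimensional here. Random weights are used since no checkpoint is loaded.
def example_restructured_outputs():
    base = SimSiam(torchvision.models.resnet50, dim=2048, pred_dim=512)
    model = RestructuredSimSiam(base).to(device).eval()
    x = torch.randn(2, 3, 224, 224, device=device)
    z = model(x, run_head=False)  # backbone + projector output, shape (2, 2048)
    p = model(x)                  # plus the contrastive head, shape (2, 2048)
    return z, p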


def get_simsiam(ckpt_path='checkpoint_0099.pth.tar'):
    model = SimSiam(base_encoder=torchvision.models.resnet50,
                    dim=2048,
                    pred_dim=512)
    checkpoint = torch.load('pretrained_models/simsiam_models/' + ckpt_path, map_location='cpu')
    # strip the DataParallel/DistributedDataParallel "module." prefix from parameter names
    state_dict = {k.replace("module.", ""): v for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    restructured_model = RestructuredSimSiam(model)
    return restructured_model.to(device)
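

# Minimal smoke test (an addition, not from the original file): load the
# restructured model from the local checkpoint assumed above and extract
# backbone + projector features for a dummy batch. Adjust the checkpoint path
# in get_simsiam() to wherever the pretrained weights actually live.
if __name__ == "__main__":
    simsiam = get_simsiam()
    simsiam.eval()
    with torch.no_grad():
        feats = simsiam(torch.randn(2, 3, 224, 224, device=device), run_head=False)
    print(feats.shape)  # expected: torch.Size([2, 2048])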