import torch.nn as nn import torch.nn.functional as F import torchvision.models as models class ResNetSimCLR(nn.Module): def __init__(self, base_model, out_dim): super(ResNetSimCLR, self).__init__() self.resnet_dict = {"resnet18": models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d), "resnet50": models.resnet50(pretrained=False)} resnet = self._get_basemodel(base_model) num_ftrs = resnet.fc.in_features self.features = nn.Sequential(*list(resnet.children())[:-1]) # projection MLP self.l1 = nn.Linear(num_ftrs, num_ftrs) self.l2 = nn.Linear(num_ftrs, out_dim) def _get_basemodel(self, model_name): try: model = self.resnet_dict[model_name] print("Feature extractor:", model_name) return model except: raise ("Invalid model name. Check the config file and pass one of: resnet18 or resnet50") def forward(self, x): h = self.features(x) h = h.squeeze() x = self.l1(h) x = F.relu(x) x = self.l2(x) return h, x