import torch import torch.nn as nn import torchvision device = torch.device("cuda" if torch.cuda.is_available() else "cpu") """from https://github.com/facebookresearch/barlowtwins""" def off_diagonal(x): # return a flattened view of the off-diagonal elements of a square matrix n, m = x.shape assert n == m return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() class BarlowTwins(nn.Module): def __init__(self): super().__init__() self.backbone = torchvision.models.resnet50(zero_init_residual=True) self.backbone.fc = nn.Identity() # projector sizes = [2048] + list(map(int, '8192-8192-8192'.split('-'))) layers = [] for i in range(len(sizes) - 2): layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False)) layers.append(nn.BatchNorm1d(sizes[i + 1])) layers.append(nn.ReLU(inplace=True)) layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False)) self.projector = nn.Sequential(*layers) # normalization layer for the representations z1 and z2 self.bn = nn.BatchNorm1d(sizes[-1], affine=False) def forward(self, y1, y2): z1 = self.projector(self.backbone(y1)) z2 = self.projector(self.backbone(y2)) # empirical cross-correlation matrix c = self.bn(z1).T @ self.bn(z2) on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() off_diag = off_diagonal(c).pow_(2).sum() loss = on_diag + 0.0051 * off_diag return loss class ResNet(nn.Module): def __init__(self, backbone): super().__init__() modules = list(backbone.children())[:-2] self.net = nn.Sequential(*modules) def forward(self, x): return self.net(x).mean(dim=[2, 3]) class RestructuredBarlowTwins(nn.Module): def __init__(self, model): super().__init__() self.encoder = ResNet(model.backbone) self.contrastive_head = model.projector def forward(self, x): x = self.encoder(x) x = self.contrastive_head(x) return x def get_barlow_twins_model(ckpt_path = 'barlow_twins.pth'): model = BarlowTwins() state_dict = torch.load('pretrained_models/barlow_models/' + ckpt_path, map_location='cpu') state_dict = state_dict['model'] state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} model.load_state_dict(state_dict) restructured_model = RestructuredBarlowTwins(model) return restructured_model.to(device)