Spaces:
Sleeping
Sleeping
| # Importing Libaries | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| def get_network(name, pretrained=False): | |
| network = { | |
| "VGG16": torchvision.models.vgg16(pretrained=pretrained), | |
| "VGG16_bn": torchvision.models.vgg16_bn(pretrained=pretrained), | |
| "resnet18": torchvision.models.resnet18(pretrained=pretrained), | |
| "resnet34": torchvision.models.resnet34(pretrained=pretrained), | |
| "resnet50": torchvision.models.resnet50(pretrained=pretrained), | |
| "resnet101": torchvision.models.resnet101(pretrained=pretrained), | |
| "resnet152": torchvision.models.resnet152(pretrained=pretrained), | |
| } | |
| if name not in network.keys(): | |
| raise KeyError(f"{name} is not a valid network architecture") | |
| return network[name] | |
| class CONTRIQUE_Model(nn.Module): | |
| def __init__(self, | |
| encoder, | |
| n_features, | |
| patch_dim = (2,2), | |
| normalize = True, | |
| projection_dim = 128 | |
| ): | |
| """ | |
| ResNet50 model for Projector | |
| """ | |
| super(CONTRIQUE_Model, self).__init__() | |
| self.normalize = normalize | |
| self.encoder = nn.Sequential(*list(encoder.children())[:-2]) | |
| self.n_features = n_features | |
| self.patch_dim = patch_dim | |
| self.avgpool = nn.AdaptiveAvgPool2d((1,1)) | |
| self.avgpool_patch = nn.AdaptiveAvgPool2d(patch_dim) | |
| # MLP for Projector | |
| self.projector = nn.Sequential( | |
| nn.Linear(self.n_features, self.n_features, bias=False), | |
| nn.BatchNorm1d(self.n_features), | |
| nn.ReLU(), | |
| nn.Linear(self.n_features, projection_dim, bias=False), | |
| nn.BatchNorm1d(projection_dim), | |
| ) | |
| def forward(self, x_i, x_j): | |
| # global features | |
| h_i = self.encoder(x_i) | |
| h_j = self.encoder(x_j) | |
| # local features | |
| h_i_patch = self.avgpool_patch(h_i) | |
| h_j_patch = self.avgpool_patch(h_j) | |
| h_i_patch = h_i_patch.reshape(-1,self.n_features, self.patch_dim[0]*self.patch_dim[1]) | |
| h_j_patch = h_j_patch.reshape(-1,self.n_features, self.patch_dim[0]*self.patch_dim[1]) | |
| h_i_patch = torch.transpose(h_i_patch,2,1) | |
| h_i_patch = h_i_patch.reshape(-1, self.n_features) | |
| h_j_patch = torch.transpose(h_j_patch,2,1) | |
| h_j_patch = h_j_patch.reshape(-1, self.n_features) | |
| h_i = self.avgpool(h_i) | |
| h_j = self.avgpool(h_j) | |
| h_i = h_i.view(-1, self.n_features) | |
| h_j = h_j.view(-1, self.n_features) | |
| if self.normalize: | |
| h_i = nn.functional.normalize(h_i, dim=1) | |
| h_j = nn.functional.normalize(h_j, dim=1) | |
| h_i_patch = nn.functional.normalize(h_i_patch, dim=1) | |
| h_j_patch = nn.functional.normalize(h_j_patch, dim=1) | |
| # Global Projections | |
| z_i = self.projector(h_i) | |
| z_j = self.projector(h_j) | |
| # Local Projections | |
| z_i_patch = self.projector(h_i_patch) | |
| z_j_patch = self.projector(h_j_patch) | |
| return z_i, z_j, z_i_patch, z_j_patch, h_i, h_j, h_i_patch, h_j_patch |