import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import numpy as np
import pathlib

# Workaround for loading checkpoints pickled on a POSIX machine onto Windows:
# torch.load unpickles pathlib.PosixPath objects, which cannot be instantiated
# on Windows, so PosixPath is temporarily aliased to WindowsPath.
temp = pathlib.PosixPath
pathlib.PosixPath = pathlib.WindowsPath

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
""" from https://github.com/facebookresearch/dino""" | |
class DINOHead(nn.Module):
    """DINO projection head: an MLP bottleneck followed by an L2-normalized,
    weight-normed linear output layer."""

    def __init__(self, in_dim, out_dim, use_bn, norm_last_layer, nlayers, hidden_dim, bottleneck_dim):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            # Freeze the weight-norm magnitude so the last layer stays unit-norm.
            self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1, p=2)  # L2-normalize the bottleneck features
        x = self.last_layer(x)
        return x
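
# Illustrative shape check (a sketch, not part of the original module; the
# dimensions below are the ResNet-50 head settings assumed later in this file):
#   head = DINOHead(in_dim=2048, out_dim=60000, use_bn=True, norm_last_layer=True,
#                   nlayers=2, hidden_dim=4096, bottleneck_dim=256)
#   head(torch.randn(4, 2048)).shape  # -> torch.Size([4, 60000])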


class MultiCropWrapper(nn.Module):
    """Pairs a backbone with a projection head; the backbone's own classifier
    heads are discarded."""

    def __init__(self, backbone, head):
        super().__init__()
        # Disable the built-in classifiers (ResNets use .fc, ViTs use .head).
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        return self.head(self.backbone(x))


class DINOLoss(nn.Module):
    """Cross-entropy between sharpened, centered teacher probabilities and
    student log-probabilities."""

    def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs, nepochs,
                 student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))
        self.nepochs = nepochs
        # Linear warmup of the teacher temperature, then constant.
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp,
        ))

    def forward(self, student_output, teacher_output):
        student_out = student_output / self.student_temp
        temp = self.teacher_temp_schedule[self.nepochs - 1]  # final (post-warmup) temperature
        # Center and sharpen the teacher output; no gradients flow through the teacher.
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach()
        loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1).mean()
        return loss
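
# Illustrative use in a training step (a sketch only; `student`, `teacher`,
# `dino_loss`, and `views` are hypothetical names standing in for a real
# training loop — the upstream DINO recipe additionally averages the loss
# over multiple crop pairs):
#   with torch.no_grad():
#       teacher_out = teacher(views)   # the teacher is never backpropagated through
#   student_out = student(views)
#   loss = dino_loss(student_out, teacher_out)
#   loss.backward()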


class ResNet(nn.Module):
    """ResNet trunk with the average-pool and fc layers removed; pooling is
    done explicitly in forward."""

    def __init__(self, backbone):
        super().__init__()
        modules = list(backbone.children())[:-2]  # drop avgpool and fc
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x).mean(dim=[2, 3])  # global average pooling over spatial dims


class RestructuredDINO(nn.Module):
    """Repackages a DINO student/teacher pair behind a single forward pass
    with a run_teacher switch."""

    def __init__(self, student, teacher):
        super().__init__()
        self.encoder_student = ResNet(student.backbone)
        self.encoder = ResNet(teacher.backbone)
        self.contrastive_head_student = student.head
        self.contrastive_head = teacher.head

    def forward(self, x, run_teacher):
        if run_teacher:
            x = self.encoder(x)
            x = self.contrastive_head(x)
        else:
            x = self.encoder_student(x)
            x = self.contrastive_head_student(x)
        return x


def get_dino_model_without_loss(ckpt_path='dino_resnet50_pretrain_full_checkpoint.pth'):
    state_dict = torch.load('pretrained_models/dino_models/' + ckpt_path, map_location='cpu')
    state_dict_student = state_dict['student']
    state_dict_teacher = state_dict['teacher']
    # The checkpoint was saved from a DistributedDataParallel model, so strip the "module." prefix.
    state_dict_student = {k.replace("module.", ""): v for k, v in state_dict_student.items()}
    state_dict_teacher = {k.replace("module.", ""): v for k, v in state_dict_teacher.items()}
    student_backbone = torchvision.models.resnet50()
    teacher_backbone = torchvision.models.resnet50()
    embed_dim = student_backbone.fc.weight.shape[1]  # 2048 for ResNet-50
    student_head = DINOHead(in_dim=embed_dim, out_dim=60000, use_bn=True, norm_last_layer=True,
                            nlayers=2, hidden_dim=4096, bottleneck_dim=256)
    teacher_head = DINOHead(in_dim=embed_dim, out_dim=60000, use_bn=True, norm_last_layer=True,
                            nlayers=2, hidden_dim=4096, bottleneck_dim=256)
    # Swap the weight-normed last layer for a plain Linear so the parameter
    # names match this checkpoint's state dict.
    student_head.last_layer = nn.Linear(256, 60000, bias=False)
    teacher_head.last_layer = nn.Linear(256, 60000, bias=False)
    student = MultiCropWrapper(student_backbone, student_head)
    teacher = MultiCropWrapper(teacher_backbone, teacher_head)
    student.load_state_dict(state_dict_student)
    teacher.load_state_dict(state_dict_teacher)
    restructured_model = RestructuredDINO(student, teacher)
    return restructured_model.to(device)


def get_dino_model_with_loss(ckpt_path='dino_rn50_checkpoint.pth'):
    state_dict = torch.load('pretrained_models/dino_models/' + ckpt_path, map_location='cpu')
    state_dict_student = state_dict['student']
    state_dict_teacher = state_dict['teacher']
    state_dict_args = vars(state_dict['args'])  # training args saved as an argparse.Namespace
    state_dict_dino_loss = state_dict['dino_loss']
    # The checkpoint was saved from a DistributedDataParallel model, so strip the "module." prefix.
    state_dict_student = {k.replace("module.", ""): v for k, v in state_dict_student.items()}
    state_dict_teacher = {k.replace("module.", ""): v for k, v in state_dict_teacher.items()}
    student_backbone = torchvision.models.resnet50()
    teacher_backbone = torchvision.models.resnet50()
    embed_dim = student_backbone.fc.weight.shape[1]  # 2048 for ResNet-50
    student_head = DINOHead(in_dim=embed_dim,
                            out_dim=state_dict_args['out_dim'],
                            use_bn=state_dict_args['use_bn_in_head'],
                            norm_last_layer=state_dict_args['norm_last_layer'],
                            nlayers=3,
                            hidden_dim=2048,
                            bottleneck_dim=256)
    teacher_head = DINOHead(in_dim=embed_dim,
                            out_dim=state_dict_args['out_dim'],
                            use_bn=state_dict_args['use_bn_in_head'],
                            norm_last_layer=state_dict_args['norm_last_layer'],
                            nlayers=3,
                            hidden_dim=2048,
                            bottleneck_dim=256)
    loss = DINOLoss(out_dim=state_dict_args['out_dim'],
                    warmup_teacher_temp=state_dict_args['warmup_teacher_temp'],
                    teacher_temp=state_dict_args['teacher_temp'],
                    warmup_teacher_temp_epochs=state_dict_args['warmup_teacher_temp_epochs'],
                    nepochs=state_dict_args['epochs'])
    student = MultiCropWrapper(student_backbone, student_head)
    teacher = MultiCropWrapper(teacher_backbone, teacher_head)
    student.load_state_dict(state_dict_student)
    teacher.load_state_dict(state_dict_teacher)
    loss.load_state_dict(state_dict_dino_loss)
    restructured_model = RestructuredDINO(student, teacher)
    return restructured_model.to(device), loss.to(device)
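

if __name__ == "__main__":
    # Minimal smoke test (a sketch; it assumes the checkpoint referenced by
    # get_dino_model_without_loss exists under pretrained_models/dino_models/).
    model = get_dino_model_without_loss()
    model.eval()
    dummy = torch.randn(2, 3, 224, 224, device=device)
    with torch.no_grad():
        teacher_logits = model(dummy, run_teacher=True)
        student_logits = model(dummy, run_teacher=False)
    print(teacher_logits.shape, student_logits.shape)  # both torch.Size([2, 60000])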