# xai-cl/ssl_models/dino.py
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import numpy as np
import pathlib
# Workaround for loading checkpoints pickled on a POSIX system when running on
# Windows: unpickling pathlib.PosixPath would otherwise fail. Note that this
# patch is never reverted.
temp = pathlib.PosixPath
pathlib.PosixPath = pathlib.WindowsPath
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
""" from https://github.com/facebookresearch/dino"""
class DINOHead(nn.Module):
    """DINO projection head: an MLP followed by L2 normalization and a
    weight-normalized linear output layer."""
    def __init__(self, in_dim, out_dim, use_bn, norm_last_layer, nlayers, hidden_dim, bottleneck_dim):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = nn.Linear(in_dim, bottleneck_dim)
else:
layers = [nn.Linear(in_dim, hidden_dim)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = nn.Sequential(*layers)
        # Weight-normalized output layer; weight_g is initialized to 1, and
        # freezing it (norm_last_layer=True) fixes the weight norm, which the
        # DINO authors report stabilizes training.
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1, p=2)  # project onto the unit hypersphere
        x = self.last_layer(x)
        return x
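
# Example (hypothetical tensors; the ResNet-50 loaders below use in_dim=2048):
#   head = DINOHead(in_dim=2048, out_dim=60000, use_bn=True, norm_last_layer=True,
#                   nlayers=2, hidden_dim=4096, bottleneck_dim=256)
#   out = head(torch.randn(8, 2048))  # -> (8, 60000)
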
class MultiCropWrapper(nn.Module):
    def __init__(self, backbone, head):
        super().__init__()
        # Replace the backbone's own classifier layers so it outputs features.
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head
def forward(self, x):
return self.head(self.backbone(x))
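
# Note: the upstream DINO MultiCropWrapper accepts a list of multi-resolution
# crops and batches them by shape; this simplified version assumes a single
# tensor, e.g. (hypothetical input):
#   wrapper = MultiCropWrapper(torchvision.models.resnet50(), head)
#   out = wrapper(torch.randn(8, 3, 224, 224))
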
class DINOLoss(nn.Module):
def __init__(self, out_dim, warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs, nepochs,
student_temp=0.1, center_momentum=0.9):
super().__init__()
self.student_temp = student_temp
self.center_momentum = center_momentum
self.register_buffer("center", torch.zeros(1, out_dim))
self.nepochs = nepochs
        # Linear warmup of the teacher temperature, then constant at teacher_temp.
        self.teacher_temp_schedule = np.concatenate((np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
                                                     np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp))

    def forward(self, student_output, teacher_output):
        student_out = student_output / self.student_temp
        # Always use the final (post-warmup) teacher temperature. Unlike the
        # training-time loss in the DINO repo, this version does not iterate
        # over crops or update the center (it is loaded from the checkpoint).
        temp = self.teacher_temp_schedule[self.nepochs - 1]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach()
        # Cross-entropy between the sharpened, centered teacher distribution
        # and the student log-probabilities, averaged over the batch.
        loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1).mean()
        return loss
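
# Sketch of a standalone call (random tensors, arbitrary hyperparameters; in
# this file the loss is instead rebuilt from the checkpoint's saved args):
#   criterion = DINOLoss(out_dim=60000, warmup_teacher_temp=0.04,
#                        teacher_temp=0.04, warmup_teacher_temp_epochs=0, nepochs=100)
#   l = criterion(torch.randn(8, 60000), torch.randn(8, 60000))
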
class ResNet(nn.Module):
    """ResNet trunk without its avgpool/fc layers; forward global-average-pools
    the spatial dimensions to return an (N, C) feature vector."""
    def __init__(self, backbone):
        super().__init__()
        # Drop the final avgpool and fc modules, keeping only the conv trunk.
        modules = list(backbone.children())[:-2]
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x).mean(dim=[2, 3])  # global average pooling
class RestructuredDINO(nn.Module):
    """Repackages a DINO student/teacher pair into one module. The teacher
    branch uses the unsuffixed names (encoder, contrastive_head); the student
    branch carries the _student suffix."""
    def __init__(self, student, teacher):
        super().__init__()
        self.encoder_student = ResNet(student.backbone)
        self.encoder = ResNet(teacher.backbone)
        self.contrastive_head_student = student.head
        self.contrastive_head = teacher.head
def forward(self, x, run_teacher):
if run_teacher:
x = self.encoder(x)
x = self.contrastive_head(x)
else:
x = self.encoder_student(x)
x = self.contrastive_head_student(x)
return x
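
# The run_teacher flag selects the branch, e.g. (hypothetical input x):
#   feats = model(x, run_teacher=True)   # teacher encoder + head
#   feats = model(x, run_teacher=False)  # student encoder + head
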
def get_dino_model_without_loss(ckpt_path='dino_resnet50_pretrain_full_checkpoint.pth'):
    # weights_only=False: the full checkpoint pickles more than bare tensors.
    state_dict = torch.load('pretrained_models/dino_models/' + ckpt_path, map_location='cpu', weights_only=False)
    state_dict_student = state_dict['student']
    state_dict_teacher = state_dict['teacher']
    # Strip the 'module.' prefix left over from DistributedDataParallel.
    state_dict_student = {k.replace("module.", ""): v for k, v in state_dict_student.items()}
    state_dict_teacher = {k.replace("module.", ""): v for k, v in state_dict_teacher.items()}
    student_backbone = torchvision.models.resnet50()
    teacher_backbone = torchvision.models.resnet50()
    embed_dim = student_backbone.fc.weight.shape[1]  # 2048 for ResNet-50
    student_head = DINOHead(in_dim=embed_dim, out_dim=60000, use_bn=True, norm_last_layer=True, nlayers=2, hidden_dim=4096, bottleneck_dim=256)
    teacher_head = DINOHead(in_dim=embed_dim, out_dim=60000, use_bn=True, norm_last_layer=True, nlayers=2, hidden_dim=4096, bottleneck_dim=256)
    # This checkpoint stores a plain weight tensor for the output layer, so
    # swap the weight-normalized layer for an ordinary Linear before loading.
    student_head.last_layer = nn.Linear(256, 60000, bias=False)
    teacher_head.last_layer = nn.Linear(256, 60000, bias=False)
student = MultiCropWrapper(student_backbone, student_head)
teacher = MultiCropWrapper(teacher_backbone, teacher_head)
student.load_state_dict(state_dict_student)
teacher.load_state_dict(state_dict_teacher)
restructured_model = RestructuredDINO(student, teacher)
return restructured_model.to(device)
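
# Usage sketch (assumes the checkpoint above exists under pretrained_models/dino_models/):
#   model = get_dino_model_without_loss()
#   emb = model.encoder(torch.randn(1, 3, 224, 224).to(device))  # -> (1, 2048)
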
def get_dino_model_with_loss(ckpt_path='dino_rn50_checkpoint.pth'):
    # weights_only=False: the checkpoint pickles the argparse Namespace in 'args'.
    state_dict = torch.load('pretrained_models/dino_models/' + ckpt_path, map_location='cpu', weights_only=False)
    state_dict_student = state_dict['student']
    state_dict_teacher = state_dict['teacher']
    state_dict_args = vars(state_dict['args'])
    state_dict_dino_loss = state_dict['dino_loss']
    # Strip the 'module.' prefix left over from DistributedDataParallel.
    state_dict_student = {k.replace("module.", ""): v for k, v in state_dict_student.items()}
    state_dict_teacher = {k.replace("module.", ""): v for k, v in state_dict_teacher.items()}
    student_backbone = torchvision.models.resnet50()
    teacher_backbone = torchvision.models.resnet50()
    embed_dim = student_backbone.fc.weight.shape[1]  # 2048 for ResNet-50
    student_head = DINOHead(in_dim=embed_dim,
                            out_dim=state_dict_args['out_dim'],
                            use_bn=state_dict_args['use_bn_in_head'],
                            norm_last_layer=state_dict_args['norm_last_layer'],
                            nlayers=3,
                            hidden_dim=2048,
                            bottleneck_dim=256)
    teacher_head = DINOHead(in_dim=embed_dim,
                            out_dim=state_dict_args['out_dim'],
                            use_bn=state_dict_args['use_bn_in_head'],
                            norm_last_layer=state_dict_args['norm_last_layer'],
                            nlayers=3,
                            hidden_dim=2048,
                            bottleneck_dim=256)
    loss = DINOLoss(out_dim=state_dict_args['out_dim'],
                    warmup_teacher_temp=state_dict_args['warmup_teacher_temp'],
                    teacher_temp=state_dict_args['teacher_temp'],
                    warmup_teacher_temp_epochs=state_dict_args['warmup_teacher_temp_epochs'],
                    nepochs=state_dict_args['epochs'])
student = MultiCropWrapper(student_backbone, student_head)
teacher = MultiCropWrapper(teacher_backbone, teacher_head)
student.load_state_dict(state_dict_student)
teacher.load_state_dict(state_dict_teacher)
    loss.load_state_dict(state_dict_dino_loss)
restructured_model = RestructuredDINO(student, teacher)
return restructured_model.to(device), loss.to(device)
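

# Minimal smoke test: a sketch assuming both checkpoints referenced above are
# present under pretrained_models/dino_models/.
if __name__ == "__main__":
    model, dino_loss = get_dino_model_with_loss()
    model.eval()
    x = torch.randn(2, 3, 224, 224, device=device)
    with torch.no_grad():
        teacher_out = model(x, run_teacher=True)
        student_out = model(x, run_teacher=False)
    print(teacher_out.shape, student_out.shape)
    print("loss:", dino_loss(student_out, teacher_out).item())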