import copy
from typing import List, Tuple, Union

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torch.optim import SGD
from torchvision.models import resnet50

from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import (
    activate_requires_grad,
    deactivate_requires_grad,
    get_weight_decay_parameters,
    update_momentum,
)
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule


class DINO(LightningModule):
    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head
        # Teacher: updated by EMA of the student, never by gradients.
        self.backbone = resnet
        self.projection_head = DINOProjectionHead()
        # Student: trained by gradient descent. The last layer of its head is
        # frozen during the first epoch (see configure_gradient_clipping).
        self.student_backbone = copy.deepcopy(self.backbone)
        self.student_projection_head = DINOProjectionHead(freeze_last_layer=1)
        self.criterion = DINOLoss(output_dim=65536)

        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.backbone(x)

    def forward_student(self, x: Tensor) -> Tensor:
        features = self.student_backbone(x).flatten(start_dim=1)
        projections = self.student_projection_head(features)
        return projections

    def on_train_start(self) -> None:
        deactivate_requires_grad(self.backbone)
        deactivate_requires_grad(self.projection_head)

    def on_train_end(self) -> None:
        activate_requires_grad(self.backbone)
        activate_requires_grad(self.projection_head)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        # Momentum update teacher.
        momentum = cosine_schedule(
            step=self.trainer.global_step,
            max_steps=self.trainer.estimated_stepping_batches,
            start_value=0.996,
            end_value=1.0,
        )
        update_momentum(self.student_backbone, self.backbone, m=momentum)
        update_momentum(self.student_projection_head, self.projection_head, m=momentum)

        views, targets = batch[0], batch[1]
        global_views = torch.cat(views[:2])
        local_views = torch.cat(views[2:])

        teacher_features = self.forward(global_views).flatten(start_dim=1)
        teacher_projections = self.projection_head(teacher_features)
        student_projections = torch.cat(
            [self.forward_student(global_views), self.forward_student(local_views)]
        )

        loss = self.criterion(
            teacher_out=teacher_projections.chunk(2),
            student_out=student_projections.chunk(len(views)),
            epoch=self.current_epoch,
        )
        self.log_dict(
            {"train_loss": loss, "ema_momentum": momentum},
            prog_bar=True,
            sync_dist=True,
            batch_size=len(targets),
        )

        # Online classification.
        cls_loss, cls_log = self.online_classifier.training_step(
            (teacher_features.chunk(2)[0].detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.student_backbone, self.student_projection_head]
        )
        # For ResNet50 we use SGD instead of AdamW/LARS as recommended by the authors:
        # https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings
        optimizer = SGD(
            [
                {"name": "dino", "params": params},
                {
                    "name": "dino_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=1e-4,
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def configure_gradient_clipping(
        self,
        optimizer,
        gradient_clip_val: Union[int, float, None] = None,
        gradient_clip_algorithm: Union[str, None] = None,
    ) -> None:
        self.clip_gradients(
            optimizer=optimizer,
            gradient_clip_val=3.0,
            gradient_clip_algorithm="norm",
        )
        self.student_projection_head.cancel_last_layer_gradients(self.current_epoch)
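

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): one possible way
# to train this LightningModule. DINOTransform produces the two global and
# several local crops that training_step expects; the dataset path, batch size,
# and Trainer arguments below are placeholders for your own setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader
    from torchvision.datasets import ImageFolder

    from lightly.transforms import DINOTransform

    batch_size_per_device = 32  # Placeholder; the learning rate scales with it.
    model = DINO(batch_size_per_device=batch_size_per_device, num_classes=1000)

    # "path/to/train" is a placeholder for an ImageFolder-style dataset.
    train_dataset = ImageFolder("path/to/train", transform=DINOTransform())
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size_per_device,
        shuffle=True,
        drop_last=True,
        num_workers=8,
    )

    trainer = Trainer(max_epochs=100, accelerator="gpu", devices=1)
    trainer.fit(model, train_dataloaders=train_dataloader)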