|
import copy |
|
|
|
import torch |
|
from pytorch_lightning import LightningModule |
|
from torch import Tensor |
|
from torch.optim import SGD |
|
from torch.nn import Identity |
|
from torchvision.models import resnet50 |
|
|
|
from lightly.loss import DINOLoss |
|
from lightly.models.modules import DINOProjectionHead |
|
from lightly.models.utils import ( |
|
activate_requires_grad, |
|
deactivate_requires_grad, |
|
get_weight_decay_parameters, |
|
update_momentum, |
|
) |
|
from lightly.utils.benchmarking import OnlineLinearClassifier |
|
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule |
|
|
|
from typing import Union, Tuple, List |
|
|
|
|
|
class DINO(LightningModule):
    """DINO self-distillation (Caron et al., 2021) with a ResNet-50 backbone.

    Naming convention in this module:

    * ``self.backbone`` / ``self.projection_head`` form the **teacher**: they
      are updated only via exponential moving average (``update_momentum``)
      and never receive gradients (see ``on_train_start``).
    * ``self.student_backbone`` / ``self.student_projection_head`` form the
      **student**, trained by backpropagation against the teacher's outputs.

    An :class:`OnlineLinearClassifier` is trained on detached teacher features
    to monitor representation quality during pretraining.
    """

    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        """Build teacher/student networks, the DINO loss, and the online probe.

        Args:
            batch_size_per_device: Per-device batch size; used to linearly
                scale the base learning rate in ``configure_optimizers``.
            num_classes: Number of target classes for the online classifier.
        """
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Drop the supervised classification head.
        # Teacher (EMA) networks: updated only through update_momentum().
        self.backbone = resnet
        self.projection_head = DINOProjectionHead()
        # Student networks: trained by gradient descent. BUGFIX:
        # ``freeze_last_layer=1`` must be set on the *student* head so that
        # cancel_last_layer_gradients() (called in
        # configure_gradient_clipping) zeroes the last-layer gradients during
        # the first epoch, as prescribed by the DINO paper. Previously the
        # argument was passed to the teacher head, where it had no effect
        # because the teacher never receives gradients, and the student head
        # kept the default ``freeze_last_layer=-1`` (never freeze).
        self.student_backbone = copy.deepcopy(self.backbone)
        self.student_projection_head = DINOProjectionHead(freeze_last_layer=1)
        # 65536 matches the default output dimension of DINOProjectionHead.
        self.criterion = DINOLoss(output_dim=65536)

        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return teacher backbone features for images ``x``."""
        return self.backbone(x)

    def forward_student(self, x: Tensor) -> Tensor:
        """Return student projections for images ``x``.

        Features are flattened to (batch, dim) before the projection head.
        """
        features = self.student_backbone(x).flatten(start_dim=1)
        projections = self.student_projection_head(features)
        return projections

    def on_train_start(self) -> None:
        # The teacher is only ever updated via EMA; disable its gradients so
        # the optimizer and autograd skip it entirely.
        deactivate_requires_grad(self.backbone)
        deactivate_requires_grad(self.projection_head)

    def on_train_end(self) -> None:
        # Re-enable gradients so the teacher weights can be fine-tuned or
        # checkpointed as a normal trainable model after pretraining.
        activate_requires_grad(self.backbone)
        activate_requires_grad(self.projection_head)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """One DINO training step.

        ``batch`` is ``(views, targets, filenames)`` where ``views`` holds the
        multi-crop augmentations: the first two entries are global views, the
        rest are local views.
        """
        # EMA momentum follows a cosine schedule from 0.996 to 1.0 over
        # training, per the DINO paper.
        momentum = cosine_schedule(
            step=self.trainer.global_step,
            max_steps=self.trainer.estimated_stepping_batches,
            start_value=0.996,
            end_value=1.0,
        )
        update_momentum(self.student_backbone, self.backbone, m=momentum)
        update_momentum(self.student_projection_head, self.projection_head, m=momentum)

        views, targets = batch[0], batch[1]
        global_views = torch.cat(views[:2])
        local_views = torch.cat(views[2:])

        # Teacher sees only the global views; student sees all views.
        teacher_features = self.forward(global_views).flatten(start_dim=1)
        teacher_projections = self.projection_head(teacher_features)
        student_projections = torch.cat(
            [self.forward_student(global_views), self.forward_student(local_views)]
        )

        loss = self.criterion(
            teacher_out=teacher_projections.chunk(2),
            student_out=student_projections.chunk(len(views)),
            # epoch drives the teacher-temperature warmup inside DINOLoss.
            epoch=self.current_epoch,
        )
        self.log_dict(
            {"train_loss": loss, "ema_momentum": momentum},
            prog_bar=True,
            sync_dist=True,
            batch_size=len(targets),
        )

        # Online linear probe on detached features of the first global view;
        # detach() keeps the probe from backpropagating into the teacher.
        cls_loss, cls_log = self.online_classifier.training_step(
            (teacher_features.chunk(2)[0].detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Evaluate the online classifier on teacher features."""
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        """SGD with linear LR scaling and a cosine schedule with warmup.

        Only student parameters (plus the online classifier) are optimized;
        biases/norm parameters and the classifier get no weight decay.
        """
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.student_backbone, self.student_projection_head]
        )
        optimizer = SGD(
            [
                {"name": "dino", "params": params},
                {
                    "name": "dino_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            # Linear scaling rule: base LR 0.03 at a global batch size of 256.
            lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=1e-4,
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                # The scheduler steps per batch ("interval": "step"), so both
                # warmup and horizon are expressed in steps: warm up for the
                # first 10 epochs' worth of steps.
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def configure_gradient_clipping(
        self,
        optimizer,
        gradient_clip_val: Union[int, float, None] = None,
        gradient_clip_algorithm: Union[str, None] = None,
    ) -> None:
        """Clip gradients and apply DINO's first-epoch last-layer freeze.

        The Trainer-provided clip value/algorithm are intentionally ignored;
        DINO always clips to a fixed norm of 3.0.
        """
        self.clip_gradients(
            optimizer=optimizer,
            gradient_clip_val=3.0,
            gradient_clip_algorithm="norm",
        )
        # No-op once current_epoch >= freeze_last_layer (set to 1 in __init__
        # on the student head), i.e. only active during the first epoch.
        self.student_projection_head.cancel_last_layer_gradients(self.current_epoch)
|