File size: 5,103 Bytes
1e6fe0a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
"""
π― core-dino | DINO-style Loss Functions π₯
Defines the cross-view contrastive loss used in DINO setups,
including temperature scaling, centering, and teacher-student divergence.
Includes:
- DinoSpatialLoss: Temp-scaled CE loss with center momentum π
- DinoSinkhornSpatialLoss: Sinkhorn-based balanced assignment loss βοΈ
Author: Gajesh Ladhar
π LinkedIn: https://www.linkedin.com/in/gajeshladhar/
π€ Hugging Face: https://huggingface.co/gajeshladhar
"""
import torch
from torch import nn
import torch.nn.functional as F
class DinoSpatialLoss(nn.Module):
    """
    DINO loss using temperature-scaled cross-entropy over spatial tokens.

    Aligns teacher & student spatial features (B, C, H, W). The teacher
    logits are centered with an EMA "center" (computed over teacher
    logits, as in the reference DINO implementation) for stability.

    Args:
        teacher_temp (float): Temperature for teacher softmax.
        student_temp (float): Temperature for student softmax.
        center_momentum (float): EMA factor for center update.
    """

    def __init__(self, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        # Lazily resized to (1, C) on the first forward, once the feature
        # dimension is known.
        self.register_buffer("center", torch.zeros(1, 1))
        # Explicit flag instead of a shape check, so C == 1 features do
        # not re-zero the center on every forward.
        self._center_initialized = False

    def forward(self, student_feat, teacher_feat):
        """
        Compute loss over (B, C, H, W) features.

        Args:
            student_feat (Tensor): Student output, shape (B, C, Hs, Ws)
            teacher_feat (Tensor): Teacher output, shape (B, C, Ht, Wt)

        Returns:
            Tensor: Scalar DINO loss
        """
        # Lazy init: size the center to the teacher feature dim, on the
        # teacher's device/dtype.
        if not self._center_initialized:
            self.center = teacher_feat.new_zeros(1, teacher_feat.shape[1])
            self._center_initialized = True

        # Resize student to teacher resolution
        student_resized = F.interpolate(
            student_feat, size=teacher_feat.shape[2:],
            mode='bilinear', align_corners=False,
        )

        # Flatten spatial dims: (B, C, H, W) -> (B*H*W, C)
        B, C, H, W = student_resized.shape
        student_flat = student_resized.permute(0, 2, 3, 1).reshape(-1, C)  # (BHW, C)
        teacher_flat = teacher_feat.permute(0, 2, 3, 1).reshape(-1, C)     # (BHW, C)

        # Teacher is centered (in logit space) and sharpened; no gradient
        # flows through the teacher targets.
        student_log_probs = F.log_softmax(student_flat / self.student_temp, dim=-1)
        teacher_probs = F.softmax(
            (teacher_flat - self.center) / self.teacher_temp, dim=-1
        ).detach()

        # Cross-entropy between teacher targets and student predictions
        loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()

        # EMA center update. Bug fix: the center is subtracted from the raw
        # teacher LOGITS above, so it must be an EMA of the logits themselves
        # (as in the reference DINO implementation), not of the post-softmax
        # probabilities — mixing the two spaces makes the centering
        # ineffective.
        with torch.no_grad():
            batch_center = teacher_flat.mean(dim=0, keepdim=True)
            self.center = (
                self.center * self.center_momentum
                + batch_center * (1 - self.center_momentum)
            )
        return loss
class SinkhornKnopp(nn.Module):
    """
    Sinkhorn-Knopp normalization for balanced soft assignments.

    Alternates row and column normalization of exp(logits) so the mass is
    balanced across columns (prototypes), then finishes with a row
    normalization so every row is a valid probability distribution —
    required because callers use the output as a cross-entropy target.

    Args:
        num_iters (int): Number of row/column normalization iterations
        eps (float): Stabilizer to avoid div-by-zero
    """

    def __init__(self, num_iters: int = 3, eps: float = 1e-6):
        super().__init__()
        self.num_iters = num_iters
        self.eps = eps

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Args:
            logits (Tensor): (N, C) unnormalized scores.

        Returns:
            Tensor: (N, C) non-negative matrix whose rows sum to ~1 and
            whose columns are approximately balanced.
        """
        # Subtract per-row max before exp for numerical stability.
        logits = logits - logits.max(dim=1, keepdim=True)[0]
        Q = torch.exp(logits)  # exp already allocates a new tensor; no clone needed
        Q = Q / Q.sum()
        for _ in range(self.num_iters):
            Q = Q / (Q.sum(dim=1, keepdim=True) + self.eps)  # row normalization
            Q = Q / (Q.sum(dim=0, keepdim=True) + self.eps)  # column normalization
        # Bug fix: finish on the ROW step (as SwAV/DINO finish on the sample
        # dimension). Previously the loop ended on the column step, leaving
        # each row summing to ~C/N instead of 1, so the output was not a
        # per-token probability distribution.
        Q = Q / (Q.sum(dim=1, keepdim=True) + self.eps)
        return Q
class DinoSinkhornSpatialLoss(nn.Module):
    """
    DINO loss with Sinkhorn-balanced teacher targets — no centering needed,
    since the Sinkhorn step already balances the assignments.

    Args:
        student_temp (float): Temperature for student softmax
        sinkhorn_iters (int): Iterations for Sinkhorn normalization
    """

    def __init__(self, student_temp=0.1, sinkhorn_iters=3):
        super().__init__()
        self.student_temp = student_temp
        self.sinkhorn = SinkhornKnopp(sinkhorn_iters)

    def forward(self, student_feat, teacher_feat):
        """
        Args:
            student_feat (Tensor): (B, C, Hs, Ws) student features.
            teacher_feat (Tensor): (B, C, Ht, Wt) teacher features.

        Returns:
            Tensor: Scalar cross-entropy loss.
        """
        # Bring the student map onto the teacher's spatial grid first.
        upsampled = F.interpolate(
            student_feat, size=teacher_feat.shape[2:],
            mode='bilinear', align_corners=False
        )

        # (B, C, H, W) -> (B*H*W, C): one row per spatial token.
        channels = upsampled.shape[1]
        tokens_student = upsampled.permute(0, 2, 3, 1).reshape(-1, channels)
        tokens_teacher = teacher_feat.permute(0, 2, 3, 1).reshape(-1, channels)

        # Balanced teacher targets via Sinkhorn (no temperature, no center);
        # gradients never flow into the teacher.
        targets = self.sinkhorn(tokens_teacher).detach()

        # Temperature-sharpened student log-probabilities.
        log_preds = F.log_softmax(tokens_student / self.student_temp, dim=-1)

        # Per-token cross-entropy, averaged over all tokens.
        return -(targets * log_preds).sum(dim=-1).mean()