"""
🎯 core-dino | DINO-style Loss Functions πŸ’₯

Defines the cross-view contrastive losses used in DINO-style setups,
including temperature scaling, teacher centering, and teacher-student divergence.

Includes:
- DinoSpatialLoss: Temp-scaled CE loss with center momentum πŸŒ€
- DinoSinkhornSpatialLoss: Sinkhorn-based balanced assignment loss βš–οΈ

Author: Gajesh Ladhar  
πŸ”— LinkedIn: https://www.linkedin.com/in/gajeshladhar/  
πŸ€— Hugging Face: https://huggingface.co/gajeshladhar
"""

import torch
from torch import nn 
import torch.nn.functional as F


class DinoSpatialLoss(nn.Module):
    """
    πŸŒ€ DINO loss using temperature-scaled cross-entropy over spatial tokens.

    - Aligns teacher & student spatial features (B, C, H, W)
    - Applies center momentum for teacher stability

    Args:
        teacher_temp (float): Temperature for teacher softmax
        student_temp (float): Temperature for student softmax
        center_momentum (float): EMA factor for center update
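
    Example (a minimal sketch; random tensors stand in for real features):
        >>> criterion = DinoSpatialLoss()
        >>> s = torch.randn(2, 64, 16, 16)   # student map (B, C, Hs, Ws)
        >>> t = torch.randn(2, 64, 8, 8)     # teacher map (B, C, Ht, Wt)
        >>> loss = criterion(s, t)           # scalar tensor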
    """
    def __init__(self, teacher_temp=0.04, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.teacher_temp = teacher_temp
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, 1))  # lazy init

    def forward(self, student_feat, teacher_feat):
        """
        Compute loss over (B, C, H, W) features.

        Args:
            student_feat (Tensor): Student output, shape (B, C, Hs, Ws)
            teacher_feat (Tensor): Teacher output, shape (B, C, Ht, Wt)

        Returns:
            Tensor: Scalar DINO loss
        """

        # Lazily initialize the center to match the teacher channel dim
        # (checking against the actual dim also handles the C == 1 case,
        # which the old `== 1` test would re-zero on every forward)
        if self.center.shape[1] != teacher_feat.shape[1]:
            self.center = self.center.new_zeros(1, teacher_feat.shape[1])

        # Resize student to teacher resolution
        student_resized = F.interpolate(student_feat, size=teacher_feat.shape[2:], mode='bilinear', align_corners=False)

        # Flatten spatial dims: (B, C, H, W) β†’ (B*H*W, C)
        B, C, H, W = student_resized.shape
        student_flat = student_resized.permute(0, 2, 3, 1).reshape(-1, C)  # (BHW, C)
        teacher_flat = teacher_feat.permute(0, 2, 3, 1).reshape(-1, C)     # (BHW, C)

        # Apply softmax (teacher uses center)
        student_logits = student_flat / self.student_temp
        teacher_logits = (teacher_flat - self.center) / self.teacher_temp

        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_probs = F.softmax(teacher_logits, dim=-1).detach()

        # Cross-entropy loss
        loss = - (teacher_probs * student_log_probs).sum(dim=-1).mean()

        # Update center with an EMA over raw teacher features: DINO centers
        # the teacher logits, so the running mean is taken pre-softmax
        # (averaging post-softmax probs would make centering a near no-op)
        with torch.no_grad():
            batch_center = teacher_flat.mean(dim=0, keepdim=True)
            self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

        return loss


class SinkhornKnopp(nn.Module):
    """
    βš–οΈ Sinkhorn-Knopp normalization for balanced assignments.

    Args:
        num_iters (int): Number of normalization iterations
        eps (float): Stabilizer to avoid div-by-zero
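
    Example (illustrative; random logits of shape (N, C)):
        >>> sk = SinkhornKnopp(num_iters=3)
        >>> Q = sk(torch.randn(128, 64))  # each row sums to ~1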
    """
    def __init__(self, num_iters: int = 3, eps: float = 1e-6):
        super().__init__()
        self.num_iters = num_iters
        self.eps = eps

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        # Subtract the row-wise max before exponentiating for stability
        logits = logits - logits.max(dim=1, keepdim=True)[0]
        Q = torch.exp(logits)
        Q = Q / Q.sum()

        for _ in range(self.num_iters):
            Q = Q / (Q.sum(dim=1, keepdim=True) + self.eps)  # row normalization
            Q = Q / (Q.sum(dim=0, keepdim=True) + self.eps)  # column normalization

        # End on a row normalization so each sample's assignment is a
        # valid probability distribution (rows sum to 1), which the
        # cross-entropy target in the loss below assumes
        Q = Q / (Q.sum(dim=1, keepdim=True) + self.eps)
        return Q

class DinoSinkhornSpatialLoss(nn.Module):
    """
    πŸŒ€ DINO loss with Sinkhorn assignment β€” no center, balanced targets.

    Args:
        student_temp (float): Temperature for student softmax
        sinkhorn_iters (int): Iterations for Sinkhorn normalization
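
    Example (a minimal sketch; random tensors stand in for real features):
        >>> criterion = DinoSinkhornSpatialLoss()
        >>> s = torch.randn(2, 64, 16, 16)   # student map (B, C, Hs, Ws)
        >>> t = torch.randn(2, 64, 8, 8)     # teacher map (B, C, Ht, Wt)
        >>> loss = criterion(s, t)           # scalar tensor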
    """
    def __init__(self, student_temp=0.1, sinkhorn_iters=3):
        super().__init__()
        self.student_temp = student_temp
        self.sinkhorn = SinkhornKnopp(sinkhorn_iters)

    def forward(self, student_feat, teacher_feat):
        """
        student_feat: (B, C, Hs, Ws)
        teacher_feat: (B, C, Ht, Wt)
        """

        # Resize student to teacher resolution
        student_resized = F.interpolate(
            student_feat, size=teacher_feat.shape[2:], mode='bilinear', align_corners=False
        )

        # Flatten spatial dims: (B, C, H, W) β†’ (BHW, C)
        B, C, H, W = student_resized.shape
        student_flat = student_resized.permute(0, 2, 3, 1).reshape(-1, C)
        teacher_flat = teacher_feat.permute(0, 2, 3, 1).reshape(-1, C)

        # Teacher: apply Sinkhorn (no temp, no center)
        teacher_probs = self.sinkhorn(teacher_flat).detach()

        # Student: softmax with temp
        student_log_probs = F.log_softmax(student_flat / self.student_temp, dim=-1)

        # Cross-entropy loss
        loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()

        return loss
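

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of any training pipeline):
    # random feature maps stand in for real backbone outputs, just to
    # confirm both losses produce finite scalars.
    torch.manual_seed(0)
    student = torch.randn(2, 64, 16, 16)  # (B, C, Hs, Ws)
    teacher = torch.randn(2, 64, 8, 8)    # (B, C, Ht, Wt)

    print("DinoSpatialLoss:        ", DinoSpatialLoss()(student, teacher).item())
    print("DinoSinkhornSpatialLoss:", DinoSinkhornSpatialLoss()(student, teacher).item())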