File size: 2,280 Bytes
7f4f2d3
 
 
 
 
 
d09e211
7f4f2d3
 
 
d09e211
7f4f2d3
 
 
 
d09e211
 
7f4f2d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d09e211
7f4f2d3
 
 
 
 
d09e211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f4f2d3
 
 
 
 
 
d09e211
7f4f2d3
 
 
 
 
 
 
 
 
d09e211
7f4f2d3
d09e211
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn as nn
import torch.nn.functional as F
import config


class CL_loss(nn.Module):
    """Supervised contrastive loss without weighting.

    Rows of ``feature_vectors`` are grouped by ``labels``:

    * 0 marks the anchor,
    * 1 marks positives,
    * 2 marks negatives.

    Rows with any other label are ignored. Let ``s_k`` be the cosine
    similarity between the anchor and row ``k``, and ``tau`` the temperature.
    The loss is::

        -1/|P| * sum_{p in P} log( exp(s_p/tau) / sum_{k in P u N} exp(s_k/tau) )

    The log-softmax is evaluated with ``torch.logsumexp``. Computing
    ``exp(s/tau)`` directly overflows float32 for small temperatures, and the
    resulting ``inf / inf`` turns the loss into NaN.
    """

    def __init__(self, temperature=None):
        """Create the loss.

        Args:
            temperature: Softmax temperature ``tau``. When omitted, falls back
                to ``config.temperature``.
        """
        super().__init__()
        # Only read the project config when no explicit value is supplied.
        if temperature is None:
            temperature = config.temperature
        self.temperature = temperature

    def forward(self, feature_vectors, labels):
        """Compute the contrastive loss of the anchor against positives/negatives.

        Args:
            feature_vectors: Feature tensor, L2-normalized along dim=1 here.
                Assumed to be (num_samples, dim) with one sample per row.
                NOTE(review): confirm with callers.
            labels: 1-D tensor aligned with the rows of ``feature_vectors``
                (0 = anchor, 1 = positive, 2 = negative). Exactly one anchor
                row is assumed; it is broadcast against every
                positive/negative row.

        Returns:
            A 0-dim tensor with the loss. When there are no positive rows the
            loss is undefined. A zero is returned in that case, still attached
            to the autograd graph, instead of dividing by zero.
        """
        # Normalize each row to unit Euclidean length.
        normalized = F.normalize(feature_vectors, p=2, dim=1)

        anchor = normalized[labels == 0]
        positives = normalized[labels == 1]
        negatives = normalized[labels == 2]

        num_positives = positives.shape[0]
        if num_positives == 0:
            # No positive pairs: avoid the 1/|P| division by zero and return
            # a zero loss that still participates in backward().
            return normalized.sum() * 0.0

        # Positives come first, so the first `num_positives` logits are the
        # numerator terms and all logits together form the denominator.
        candidates = torch.cat([positives, negatives])
        logits = F.cosine_similarity(anchor, candidates, dim=1) / self.temperature

        # log(exp(s_p) / sum_k exp(s_k)) == s_p - logsumexp(s); this form
        # never materializes exp(s), so it cannot overflow.
        log_prob_pos = logits[:num_positives] - torch.logsumexp(logits, dim=0)

        return -log_prob_pos.sum() / num_positives