File size: 3,584 Bytes
1ccc202
 
 
 
 
180681d
 
 
 
 
 
1ccc202
 
 
 
 
 
 
 
 
 
 
 
8dc3889
1ccc202
6d1b6c6
1ccc202
180681d
1ccc202
 
 
 
 
 
180681d
1ccc202
 
8dc3889
1ccc202
 
 
180681d
 
 
 
 
 
1ccc202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8dc3889
1ccc202
6d1b6c6
1ccc202
 
 
 
 
 
 
8dc3889
1ccc202
 
 
180681d
 
 
 
 
 
1ccc202
 
 
 
 
 
 
 
 
 
 
 
 
 
180681d
1ccc202
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
from torch import nn
import torch.nn.functional as F


def get_similarity_matrix(
    image_features: torch.Tensor, text_features: torch.Tensor
) -> torch.Tensor:
    """Return the pairwise dot-product matrix between image and text features.

    Entry ``[i, j]`` is the dot product of image row ``i`` with text row ``j``.
    For 2-D inputs the result therefore has shape ``(num_images, num_texts)``.
    It equals cosine similarity only if the features are already
    L2-normalised. That is assumed, not enforced here (TODO confirm with
    callers).
    """
    transposed_text = text_features.T
    return torch.matmul(image_features, transposed_text)


def contrastive_loss(logits, dim):
    """Cross-entropy loss where the i-th item's correct match is index i.

    ``logits`` is log-softmax-normalised along ``dim``. The diagonal of the
    result is then read as the log-probability of each matching pair, so
    ``logits`` is expected to be square (2-D).

    Returns:
        The negated mean of the diagonal log-probabilities (a scalar tensor).
    """
    log_probs = F.log_softmax(logits, dim=dim)
    matched_log_probs = torch.diag(log_probs)
    return matched_log_probs.mean().neg()


def contrastive_sigmoid_loss(logits):
    """Pairwise sigmoid (binary cross-entropy) loss over a square logit matrix.

    Every entry is treated as an independent binary classification. Diagonal
    entries are positives (target 1) and all off-diagonal entries are
    negatives (target 0). The loss is averaged over all ``n * n`` entries.

    Args:
        logits: Square 2-D tensor of pair logits.

    Returns:
        Scalar tensor with the mean binary cross-entropy.
    """
    # Build the identity targets on the same device and dtype as the logits.
    # A bare torch.eye(n) is float32 on the CPU, which breaks CUDA or non-float32 inputs.
    targets = torch.eye(logits.shape[0], device=logits.device, dtype=logits.dtype)
    return F.binary_cross_entropy_with_logits(logits, targets, reduction="mean")


class CLIPLoss(nn.Module):
    """Symmetric softmax contrastive (CLIP / InfoNCE) loss with a learnable temperature.

    The temperature is ``sigmoid(logit_temperature)``, which keeps it in
    ``(0, 1)``. Dividing the similarities by it therefore always sharpens the
    logits.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        # Stored pre-sigmoid so that optimisation is unconstrained.
        initial_value = torch.tensor(logit_temperature)
        self.logit_temperature = nn.Parameter(initial_value)

    def forward(self, similarity_matrix: torch.Tensor, *args):
        """Return the average of the two directional contrastive losses.

        Extra positional arguments are accepted and ignored. This lets callers
        pass the same ``(similarity_matrix, image_features, text_features)``
        arguments as CyCLIPLoss.
        """
        logits = similarity_matrix / torch.sigmoid(self.logit_temperature)
        # Assuming rows index images and columns captions (as produced by
        # get_similarity_matrix): dim=0 normalises over images per caption,
        # dim=1 over captions per image.
        caption_loss = contrastive_loss(logits, dim=0)
        image_loss = contrastive_loss(logits, dim=1)
        return (caption_loss + image_loss) / 2


class CyCLIPLoss(nn.Module):
    """CLIP loss plus two cyclic-consistency regularisers (CyCLIP-style).

    Total loss = mean of the two directional softmax contrastive losses
    + ``lambda_1`` * symmetry loss (MSE between the similarity matrix and its
    transpose)
    + ``lambda_2`` * modality-difference loss (MSE between the image-image and
    text-text Gram matrices).

    Args:
        logit_temperature: Initial pre-sigmoid value of the learnable
            temperature. The effective temperature is
            ``sigmoid(logit_temperature)`` and lies in ``(0, 1)``.
        lambda_1: Weight of the symmetry regulariser (default 1.0).
        lambda_2: Weight of the modality-difference regulariser (default 1.0).
    """

    def __init__(
        self,
        logit_temperature: float = -1.0,
        lambda_1: float = 1.0,
        lambda_2: float = 1.0,
    ):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Plain Python floats: fixed hyperparameters, not learned.
        self.lambda_1: float = lambda_1
        self.lambda_2: float = lambda_2

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        """Compute the combined contrastive and cyclic-consistency loss.

        ``similarity_matrix`` is presumably ``image_features @ text_features.T``
        (see get_similarity_matrix). That is not checked here; TODO confirm
        with callers.
        """
        temperature = self.logit_temperature.sigmoid()
        # Scale once and reuse for both directions.
        logits = similarity_matrix / temperature
        caption_loss = contrastive_loss(logits, dim=0)
        image_loss = contrastive_loss(logits, dim=1)

        # Both regularisers use the unscaled similarities (no temperature).
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )

        return (
            0.5 * (caption_loss + image_loss)
            + self.lambda_1 * symmetry_loss
            + self.lambda_2 * modality_difference_loss
        )


class SigLIPLoss(nn.Module):
    """Pairwise sigmoid contrastive loss with a learnable temperature.

    The effective temperature is ``sigmoid(logit_temperature)``, which is
    always in ``(0, 1)``.
    """

    def __init__(self, logit_temperature: float = -1.0):
        super().__init__()
        initial_value = torch.tensor(logit_temperature)
        self.logit_temperature = nn.Parameter(initial_value)

    def forward(self, similarity_matrix: torch.Tensor, *args):
        """Scale the similarities by the temperature and apply the sigmoid loss.

        Extra positional arguments are accepted and ignored. This lets callers
        pass the same ``(similarity_matrix, image_features, text_features)``
        arguments as CySigLIPLoss.
        """
        scaled = similarity_matrix / torch.sigmoid(self.logit_temperature)
        return contrastive_sigmoid_loss(scaled)


class CySigLIPLoss(nn.Module):
    """Sigmoid contrastive loss plus two cyclic-consistency regularisers.

    Total loss = pairwise sigmoid loss on temperature-scaled similarities
    + ``lambda_1`` * symmetry loss (MSE between the similarity matrix and its
    transpose)
    + ``lambda_2`` * modality-difference loss (MSE between the image-image and
    text-text Gram matrices).

    Args:
        logit_temperature: Initial pre-sigmoid value of the learnable
            temperature. The effective temperature is
            ``sigmoid(logit_temperature)`` and lies in ``(0, 1)``.
        lambda_1: Weight of the symmetry regulariser (default 1.0).
        lambda_2: Weight of the modality-difference regulariser (default 1.0).
    """

    def __init__(
        self,
        logit_temperature: float = -1.0,
        lambda_1: float = 1.0,
        lambda_2: float = 1.0,
    ):
        super().__init__()
        self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
        # Plain Python floats: fixed hyperparameters, not learned.
        self.lambda_1: float = lambda_1
        self.lambda_2: float = lambda_2

    def forward(
        self,
        similarity_matrix: torch.Tensor,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
    ):
        """Compute the combined sigmoid-contrastive and cyclic-consistency loss.

        ``similarity_matrix`` is presumably ``image_features @ text_features.T``
        (see get_similarity_matrix). That is not checked here; TODO confirm
        with callers.
        """
        temperature = self.logit_temperature.sigmoid()
        loss = contrastive_sigmoid_loss(similarity_matrix / temperature)

        # Both regularisers use the unscaled similarities (no temperature).
        symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
        modality_difference_loss = F.mse_loss(
            image_features @ image_features.T, text_features @ text_features.T
        )

        return loss + self.lambda_1 * symmetry_loss + self.lambda_2 * modality_difference_loss


def get_loss(loss_type: str):
    """Return a newly constructed loss module for ``loss_type``.

    Supported names are ``"clip"``, ``"cyclip"``, ``"sigmoid"`` and
    ``"cyclic_sigmoid"``. Each call returns a fresh instance built with the
    loss class's default constructor arguments.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """
    # Map names to classes, not instances, so that only the requested module
    # (and its learnable temperature parameter) is ever constructed.
    loss_classes = {
        "clip": CLIPLoss,
        "cyclip": CyCLIPLoss,
        "sigmoid": SigLIPLoss,
        "cyclic_sigmoid": CySigLIPLoss,
    }
    try:
        loss_class = loss_classes[loss_type]
    except KeyError:
        raise ValueError(
            f"Invalid loss type {loss_type!r}; expected one of {sorted(loss_classes)}"
        ) from None
    return loss_class()