sachin commited on
Commit
180681d
1 Parent(s): 8dc3889

refactoring loss functions

Browse files
Files changed (1) hide show
  1. src/loss.py +34 -10
src/loss.py CHANGED
@@ -3,6 +3,23 @@ from torch import nn
3
  import torch.nn.functional as F
4
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def contrastive_loss(logits, dim):
7
  neg_ce = torch.diag(F.log_softmax(logits, dim=dim))
8
  return -neg_ce.mean()
@@ -17,25 +34,29 @@ class CLIPLoss(nn.Module):
17
  super().__init__()
18
  self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
19
 
20
- def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
21
  temperature = self.logit_temperature.sigmoid()
22
- similarity_matrix = image_features @ text_features.T
23
  caption_loss = contrastive_loss(similarity_matrix / temperature, dim=0)
24
  image_loss = contrastive_loss(similarity_matrix / temperature, dim=1)
25
 
26
  return 0.5 * (caption_loss + image_loss)
27
 
28
 
29
- class CyCLIP(nn.Module):
30
  def __init__(self, logit_temperature: float = -1.0):
31
  super().__init__()
32
  self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
33
  self.lambda_1: float = 1.0
34
  self.lambda_2: float = 1.0
35
 
36
- def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
 
 
 
 
 
37
  temperature = self.logit_temperature.sigmoid()
38
- similarity_matrix = image_features @ text_features.T
39
  caption_loss = contrastive_loss(similarity_matrix / temperature, dim=0)
40
  image_loss = contrastive_loss(similarity_matrix / temperature, dim=1)
41
 
@@ -56,9 +77,8 @@ class SigLIPLoss(nn.Module):
56
  super().__init__()
57
  self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
58
 
59
- def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
60
  temperature = self.logit_temperature.sigmoid()
61
- similarity_matrix = image_features @ text_features.T
62
  return contrastive_sigmoid_loss(similarity_matrix / temperature)
63
 
64
 
@@ -69,9 +89,13 @@ class CySigLIPLoss(nn.Module):
69
  self.lambda_1: float = 1.0
70
  self.lambda_2: float = 1.0
71
 
72
- def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
 
 
 
 
 
73
  temperature = self.logit_temperature.sigmoid()
74
- similarity_matrix = image_features @ text_features.T
75
  loss = contrastive_sigmoid_loss(similarity_matrix / temperature)
76
 
77
  symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
@@ -85,7 +109,7 @@ class CySigLIPLoss(nn.Module):
85
  def get_loss(loss_type: str):
86
  loss_functions = {
87
  "clip": CLIPLoss(),
88
- "cyclip": CyCLIP(),
89
  "sigmoid": SigLIPLoss(),
90
  "cyclic_sigmoid": CySigLIPLoss(),
91
  }
 
3
  import torch.nn.functional as F
4
 
5
 
6
+ def metrics(similarity: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
7
+ y = torch.arange(len(similarity)).to(similarity.device)
8
+ img2cap_match_idx = similarity.argmax(dim=1)
9
+ cap2img_match_idx = similarity.argmax(dim=0)
10
+
11
+ img_acc = (img2cap_match_idx == y).float().mean()
12
+ cap_acc = (cap2img_match_idx == y).float().mean()
13
+
14
+ return img_acc, cap_acc
15
+
16
+
17
+ def get_similarity_matrix(
18
+ image_features: torch.Tensor, text_features: torch.Tensor
19
+ ) -> torch.Tensor:
20
+ return image_features @ text_features.T
21
+
22
+
23
  def contrastive_loss(logits, dim):
24
  neg_ce = torch.diag(F.log_softmax(logits, dim=dim))
25
  return -neg_ce.mean()
 
34
  super().__init__()
35
  self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
36
 
37
+ def forward(self, similarity_matrix: torch.Tensor):
38
  temperature = self.logit_temperature.sigmoid()
39
+
40
  caption_loss = contrastive_loss(similarity_matrix / temperature, dim=0)
41
  image_loss = contrastive_loss(similarity_matrix / temperature, dim=1)
42
 
43
  return 0.5 * (caption_loss + image_loss)
44
 
45
 
46
+ class CyCLIPLoss(nn.Module):
47
  def __init__(self, logit_temperature: float = -1.0):
48
  super().__init__()
49
  self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
50
  self.lambda_1: float = 1.0
51
  self.lambda_2: float = 1.0
52
 
53
+ def forward(
54
+ self,
55
+ similarity_matrix: torch.Tensor,
56
+ image_features: torch.Tensor,
57
+ text_features: torch.Tensor,
58
+ ):
59
  temperature = self.logit_temperature.sigmoid()
 
60
  caption_loss = contrastive_loss(similarity_matrix / temperature, dim=0)
61
  image_loss = contrastive_loss(similarity_matrix / temperature, dim=1)
62
 
 
77
  super().__init__()
78
  self.logit_temperature = nn.Parameter(torch.tensor(logit_temperature))
79
 
80
+ def forward(self, similarity_matrix: torch.Tensor):
81
  temperature = self.logit_temperature.sigmoid()
 
82
  return contrastive_sigmoid_loss(similarity_matrix / temperature)
83
 
84
 
 
89
  self.lambda_1: float = 1.0
90
  self.lambda_2: float = 1.0
91
 
92
+ def forward(
93
+ self,
94
+ similarity_matrix: torch.Tensor,
95
+ image_features: torch.Tensor,
96
+ text_features: torch.Tensor,
97
+ ):
98
  temperature = self.logit_temperature.sigmoid()
 
99
  loss = contrastive_sigmoid_loss(similarity_matrix / temperature)
100
 
101
  symmetry_loss = F.mse_loss(similarity_matrix, similarity_matrix.T)
 
109
  def get_loss(loss_type: str):
110
  loss_functions = {
111
  "clip": CLIPLoss(),
112
+ "cyclip": CyCLIPLoss(),
113
  "sigmoid": SigLIPLoss(),
114
  "cyclic_sigmoid": CySigLIPLoss(),
115
  }