sxtforreal commited on
Commit
7f4f2d3
1 Parent(s): 504db9e

Create loss.py

Browse files

This file holds 4 loss functions for the 4 models respectively.

Files changed (1) hide show
  1. loss.py +228 -0
loss.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import config
5
+
6
+
7
+ class ContrastiveLoss_simcse(nn.Module):
8
+ """SimCSE loss"""
9
+
10
+ def __init__(self):
11
+ super(ContrastiveLoss_simcse, self).__init__()
12
+ self.temperature = config.temperature
13
+
14
+ def forward(self, feature_vectors, labels):
15
+ normalized_features = F.normalize(
16
+ feature_vectors, p=2, dim=0
17
+ ) # normalize along columns
18
+
19
+ # Identify indices for each label
20
+ anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
21
+ positive_indices = (labels == 1).nonzero().squeeze(dim=1)
22
+ negative_indices = (labels == 2).nonzero().squeeze(dim=1)
23
+
24
+ # Extract tensors based on labels
25
+ anchor = normalized_features[anchor_indices]
26
+ positives = normalized_features[positive_indices]
27
+ negatives = normalized_features[negative_indices]
28
+ pos_and_neg = torch.cat([positives, negatives])
29
+
30
+ denominator = torch.sum(
31
+ torch.exp(
32
+ torch.div(
33
+ torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
34
+ self.temperature,
35
+ )
36
+ )
37
+ )
38
+
39
+ numerator = torch.exp(
40
+ torch.div(
41
+ torch.matmul(anchor, torch.transpose(positives, 0, 1)),
42
+ self.temperature,
43
+ )
44
+ )
45
+
46
+ loss = -torch.log(
47
+ torch.div(
48
+ numerator,
49
+ denominator,
50
+ )
51
+ )
52
+
53
+ return loss
54
+
55
+
56
+ class ContrastiveLoss_simcse_w(nn.Module):
57
+ """SimCSE loss with weighting."""
58
+
59
+ def __init__(self):
60
+ super(ContrastiveLoss_simcse_w, self).__init__()
61
+ self.temperature = config.temperature
62
+
63
+ def forward(self, feature_vectors, labels, scores):
64
+ normalized_features = F.normalize(
65
+ feature_vectors, p=2, dim=0
66
+ ) # normalize along columns
67
+
68
+ # Identify indices for each label
69
+ anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
70
+ positive_indices = (labels == 1).nonzero().squeeze(dim=1)
71
+ negative_indices = (labels == 2).nonzero().squeeze(dim=1)
72
+
73
+ pos_scores = scores[positive_indices].float()
74
+ normalized_neg_scores = F.normalize(
75
+ scores[negative_indices].float(), p=2, dim=0
76
+ ) # l2-norm
77
+ normalized_neg_scores += 1
78
+ scores = torch.cat([pos_scores, normalized_neg_scores])
79
+
80
+ # Extract tensors based on labels
81
+ anchor = normalized_features[anchor_indices]
82
+ positives = normalized_features[positive_indices]
83
+ negatives = normalized_features[negative_indices]
84
+ pos_and_neg = torch.cat([positives, negatives])
85
+
86
+ denominator = torch.sum(
87
+ torch.exp(
88
+ scores
89
+ * torch.div(
90
+ torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
91
+ self.temperature,
92
+ )
93
+ )
94
+ )
95
+
96
+ numerator = torch.exp(
97
+ torch.div(
98
+ torch.matmul(anchor, torch.transpose(positives, 0, 1)),
99
+ self.temperature,
100
+ )
101
+ )
102
+
103
+ loss = -torch.log(
104
+ torch.div(
105
+ numerator,
106
+ denominator,
107
+ )
108
+ )
109
+
110
+ return loss
111
+
112
+
113
+ class ContrastiveLoss_samp(nn.Module):
114
+ """Supervised contrastive loss without weighting."""
115
+
116
+ def __init__(self):
117
+ super(ContrastiveLoss_samp, self).__init__()
118
+ self.temperature = config.temperature
119
+
120
+ def forward(self, feature_vectors, labels):
121
+ # Normalize feature vectors
122
+ normalized_features = F.normalize(
123
+ feature_vectors, p=2, dim=0
124
+ ) # normalize along columns
125
+
126
+ # Identify indices for each label
127
+ anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
128
+ positive_indices = (labels == 1).nonzero().squeeze(dim=1)
129
+ negative_indices = (labels == 2).nonzero().squeeze(dim=1)
130
+
131
+ # Extract tensors based on labels
132
+ anchor = normalized_features[anchor_indices]
133
+ positives = normalized_features[positive_indices]
134
+ negatives = normalized_features[negative_indices]
135
+ pos_and_neg = torch.cat([positives, negatives])
136
+
137
+ pos_cardinal = positives.shape[0]
138
+
139
+ denominator = torch.sum(
140
+ torch.exp(
141
+ torch.div(
142
+ torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
143
+ self.temperature,
144
+ )
145
+ )
146
+ )
147
+
148
+ sum_log_ent = torch.sum(
149
+ torch.log(
150
+ torch.div(
151
+ torch.exp(
152
+ torch.div(
153
+ torch.matmul(anchor, torch.transpose(positives, 0, 1)),
154
+ self.temperature,
155
+ )
156
+ ),
157
+ denominator,
158
+ )
159
+ )
160
+ )
161
+
162
+ scale = -1 / pos_cardinal
163
+
164
+ return scale * sum_log_ent
165
+
166
+
167
+ class ContrastiveLoss_samp_w(nn.Module):
168
+ """Supervised contrastive loss with weighting."""
169
+
170
+ def __init__(self):
171
+ super(ContrastiveLoss_samp_w, self).__init__()
172
+ self.temperature = config.temperature
173
+
174
+ def forward(self, feature_vectors, labels, scores):
175
+ # Normalize feature vectors
176
+ normalized_features = F.normalize(
177
+ feature_vectors, p=2, dim=0
178
+ ) # normalize along columns
179
+
180
+ # Identify indices for each label
181
+ anchor_indices = (labels == 0).nonzero().squeeze(dim=1)
182
+ positive_indices = (labels == 1).nonzero().squeeze(dim=1)
183
+ negative_indices = (labels == 2).nonzero().squeeze(dim=1)
184
+
185
+ # Normalize score vector
186
+ num_skip = len(positive_indices) + 1
187
+ pos_scores = scores[: (num_skip - 1)].float() # exclude anchor
188
+ normalized_neg_scores = F.normalize(
189
+ scores[num_skip:].float(), p=2, dim=0
190
+ ) # l2-norm
191
+ normalized_neg_scores += 1
192
+ scores = torch.cat([pos_scores, normalized_neg_scores])
193
+
194
+ # Extract tensors based on labels
195
+ anchor = normalized_features[anchor_indices]
196
+ positives = normalized_features[positive_indices]
197
+ negatives = normalized_features[negative_indices]
198
+ pos_and_neg = torch.cat([positives, negatives])
199
+
200
+ pos_cardinal = positives.shape[0]
201
+
202
+ denominator = torch.sum(
203
+ torch.exp(
204
+ scores
205
+ * torch.div(
206
+ torch.matmul(anchor, torch.transpose(pos_and_neg, 0, 1)),
207
+ self.temperature,
208
+ )
209
+ )
210
+ )
211
+
212
+ sum_log_ent = torch.sum(
213
+ torch.log(
214
+ torch.div(
215
+ torch.exp(
216
+ torch.div(
217
+ torch.matmul(anchor, torch.transpose(positives, 0, 1)),
218
+ self.temperature,
219
+ )
220
+ ),
221
+ denominator,
222
+ )
223
+ )
224
+ )
225
+
226
+ scale = -1 / pos_cardinal
227
+
228
+ return scale * sum_log_ent