File size: 2,476 Bytes
d60982d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import torch
import numpy as np
class NTXentLoss(torch.nn.Module):
    """Contrastive cross-entropy loss over two batches of paired embeddings.

    Row ``k`` of ``zis`` and row ``k`` of ``zjs`` form a positive pair. Both
    batches are stacked into one ``(2N, C)`` matrix, all pairwise similarities
    are computed, and for every row a cross-entropy is taken over
    ``[positive, negatives...] / temperature`` with the positive as the target
    class. A row's negatives are all other rows except itself and its
    positive partner. The summed loss is divided by ``2N``.

    Args:
        device: Device on which the pre-computed negatives mask is created.
        batch_size: Expected number of pairs ``N`` per call; only used to
            pre-compute the mask. ``forward`` also accepts other batch sizes.
        temperature: Divisor applied to the similarity logits.
        use_cosine_similarity: If true, use cosine similarity; otherwise use
            the plain dot product.
    """

    def __init__(self, device, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        # Not used by any method of this class; kept so the attribute stays
        # available to external code that may read it.
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask()
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the bound similarity callable selected by the flag."""
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            # NOTE: the method name below differs from the attribute above only
            # by its (historical) spelling. Do not "fix" the spelling without
            # also renaming the attribute, or one would shadow the other.
            return self._cosine_simililarity
        return self._dot_simililarity

    def _get_correlated_mask(self, batch_size=None, device=None):
        """Build the boolean ``(2N, 2N)`` mask that selects negative pairs.

        An entry ``(i, j)`` is ``False`` when ``i == j`` (self-similarity) or
        ``|i - j| == N`` (the positive partner), and ``True`` otherwise.

        Args:
            batch_size: Number of pairs ``N``; defaults to ``self.batch_size``.
            device: Target device; defaults to ``self.device``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        if device is None:
            device = self.device
        index = torch.arange(2 * batch_size, device=device)
        # offset[i, j] = |j - i|: distance of each entry from the main diagonal.
        offset = (index.unsqueeze(0) - index.unsqueeze(1)).abs()
        return (offset != 0) & (offset != batch_size)

    @staticmethod
    def _dot_simililarity(x, y):
        """Return all pairwise dot products: ``(N, C) x (M, C) -> (N, M)``."""
        return torch.matmul(x, y.t())

    def _cosine_simililarity(self, x, y):
        """Return all pairwise cosine similarities: ``(N, C) x (M, C) -> (N, M)``."""
        # Broadcasting (N, 1, C) against (1, M, C) compares every row of x
        # with every row of y along the last dimension.
        return self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))

    def forward(self, zis: torch.Tensor, zjs: torch.Tensor) -> torch.Tensor:
        """Compute the loss for two aligned batches of embeddings.

        Args:
            zis: Tensor whose first dimension is the batch; row ``k`` is paired
                with row ``k`` of ``zjs``.
            zjs: Tensor with the same leading dimension as ``zis``.

        Returns:
            Scalar tensor: summed cross-entropy divided by ``2N``.

        Raises:
            ValueError: If ``zis`` and ``zjs`` have different batch sizes.
        """
        # Use the actual batch size so a batch smaller than the configured
        # one (e.g. the last batch of an epoch) does not break the reshapes.
        batch_size = zis.shape[0]
        if zjs.shape[0] != batch_size:
            raise ValueError(
                "zis and zjs must have the same batch size, got {} and {}".format(
                    zis.shape[0], zjs.shape[0]
                )
            )
        total = 2 * batch_size

        representations = torch.cat([zjs, zis], dim=0)
        device = representations.device

        # Reuse the pre-computed mask when it fits; rebuild it otherwise.
        mask = self.mask_samples_from_same_repr
        if mask.shape[0] != total or mask.device != device:
            mask = self._get_correlated_mask(batch_size, device)

        similarity_matrix = self.similarity_function(representations, representations)

        # The positives sit on the two diagonals offset by +/- batch_size.
        l_pos = torch.diag(similarity_matrix, batch_size)
        r_pos = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([l_pos, r_pos]).view(total, 1)

        # Every row keeps all entries except itself and its positive partner.
        # The explicit width stays well defined even with zero negatives.
        negatives = similarity_matrix[mask].view(total, total - 2)

        logits = torch.cat((positives, negatives), dim=1) / self.temperature
        # The positive is always column 0, so the target class is 0 everywhere.
        labels = torch.zeros(total, dtype=torch.long, device=device)
        loss = self.criterion(logits, labels)
        return loss / total
|