DeepLearning101 commited on
Commit
fdc4786
1 Parent(s): 4cda815

Upload 6 files

Browse files
loss/contrastive_loss.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2022/03/23 14:50
3
+ # @Author : Jianing Wang
4
+ # @Email : lygwjn@gmail.com
5
+ # @File : ContrastiveLoss.py
6
+ # !/usr/bin/env python
7
+ # coding=utf-8
8
+
9
+ from enum import Enum
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn, Tensor
13
+ from transformers.models.bert.modeling_bert import BertModel
14
+ from transformers import BertTokenizer, BertConfig
15
+
16
class SiameseDistanceMetric(Enum):
    """
    Namespace of distance functions usable by the contrastive loss.

    Each attribute is a plain callable ``(x, y) -> Tensor``. Because functions
    are descriptors, ``Enum`` does not turn them into members: they are
    accessed and called directly, e.g. ``SiameseDistanceMetric.EUCLIDEAN(x, y)``.
    """

    @staticmethod
    def EUCLIDEAN(x, y):
        # L2 distance between corresponding rows of x and y
        return F.pairwise_distance(x, y, p=2)

    @staticmethod
    def MANHATTAN(x, y):
        # L1 distance between corresponding rows of x and y
        return F.pairwise_distance(x, y, p=1)

    @staticmethod
    def COSINE_DISTANCE(x, y):
        # 1 - cosine similarity, so identical directions give 0
        return 1 - F.cosine_similarity(x, y)
23
+
24
+
25
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss over pairs of embeddings.

    Expects two batches of representations and a label of either 0 or 1 per
    pair. For label == 1 the squared distance between the two embeddings is
    penalised (the pair is pulled together). For label == 0 the pair is
    penalised only while its distance is smaller than ``margin`` (the pair is
    pushed apart).

    @:param distance_metric: callable ``(x, y) -> distances`` applied to the two batches
    @:param margin: (float) The margin distance
    @:param size_average: (bool) If True return the mean of the per-pair losses, else their sum

    Input example of forward function:
        rep_anchor: [[0.2, -0.1, ..., 0.6], [0.2, -0.1, ..., 0.6], ..., [0.2, -0.1, ..., 0.6]]
        rep_candidate: [[0.3, 0.1, ..., -0.3], [-0.8, 1.2, ..., 0.7], ..., [-0.9, 0.1, ..., 0.4]]
        label: [0, 1, ..., 1]

    Return example of forward function:
        0.015 (averaged)
        2.672 (sum)
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average: bool = False):
        super(ContrastiveLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def forward(self, rep_anchor, rep_candidate, label: Tensor):
        # rep_anchor: [batch_size, hidden_dim] denotes the representations of anchors
        # rep_candidate: [batch_size, hidden_dim] denotes the representations of positive / negative
        # label: one 0/1 entry per anchor - candidate pair (it is multiplied with
        #        the per-pair distances, so it must broadcast against them)

        distances = self.distance_metric(rep_anchor, rep_candidate)
        # Convert to float *before* computing (1 - label): subtracting a boolean
        # tensor from a number raises in PyTorch, and integer labels give the
        # same values either way.
        is_positive = label.float()
        pull_term = is_positive * distances.pow(2)
        push_term = (1 - is_positive) * F.relu(self.margin - distances).pow(2)
        losses = 0.5 * (pull_term + push_term)
        return losses.mean() if self.size_average else losses.sum()
58
+
59
+
60
if __name__ == "__main__":
    # Demo: encode two batches of sentences with BERT, pool each sentence into
    # a single vector and compute the contrastive loss for the pairs.

    def _mean_pool(hidden_states, attention_mask):
        # Average the token vectors over the sequence axis (dim=1), skipping
        # padded positions. Averaging over the last axis instead would collapse
        # the hidden dimension and yield [batch_size, max_seq_len].
        weights = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        return (hidden_states * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

    # configure for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # obtain two batches of examples, each corresponding example is a pair
    examples1 = ["This is the sentence anchor 1.", "It is the second sentence in this article named Section D."]
    examples2 = ["It is the same as anchor 1.", "I think it is different with Section D."]
    label = [1, 0]
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    features1 = tokenizer(examples1, add_special_tokens=True, padding=True)
    features2 = tokenizer(examples2, add_special_tokens=True, padding=True)
    # right-pad every field with zeros up to max_seq_len and convert to tensors
    max_seq_len = 16
    features1 = {
        key: torch.tensor([value + [0] * (max_seq_len - len(value)) for value in values], dtype=torch.long)
        for key, values in features1.items()
    }
    features2 = {
        key: torch.tensor([value + [0] * (max_seq_len - len(value)) for value in values], dtype=torch.long)
        for key, values in features2.items()
    }
    label = torch.tensor(label, dtype=torch.long)
    # obtain sentence embeddings by averaged pooling
    rep_anchor = model(**features1)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_candidate = model(**features2)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_anchor = _mean_pool(rep_anchor, features1["attention_mask"])  # [batch_size, hidden_dim]
    rep_candidate = _mean_pool(rep_candidate, features2["attention_mask"])  # [batch_size, hidden_dim]
    # obtain contrastive loss
    loss_fn = ContrastiveLoss()
    loss = loss_fn(rep_anchor=rep_anchor, rep_candidate=rep_candidate, label=label)
    print(loss)  # scalar tensor with grad_fn=<SumBackward0>
loss/focal_loss.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2022/2/17 6:05 下午
3
+ # @Author : JianingWang
4
+ # @File : loss
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+
10
class FocalLoss(nn.Module):
    """
    Multi-class focal loss.

    The log-probability of every class is scaled by ``(1 - p) ** gamma`` before
    being passed to the negative log-likelihood loss, so confidently predicted
    classes contribute less. With ``gamma == 0`` the scaling factor is 1.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super().__init__()
        self.ignore_index = ignore_index
        self.weight = weight
        self.gamma = gamma

    def forward(self, input, target):
        """
        input: [N, C] unnormalised class scores
        target: [N, ] class indices
        """
        log_probs = F.log_softmax(input, dim=1)
        probs = log_probs.exp()
        # down-weight each class log-probability by its modulating factor
        modulated = torch.pow(1 - probs, self.gamma) * log_probs
        return F.nll_loss(modulated, target, weight=self.weight, ignore_index=self.ignore_index)
loss/label_smoothing.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross entropy with label smoothing.

    The loss is ``eps / C * smoothing + (1 - eps) * nll`` where ``smoothing`` is
    the negative sum of the log-probabilities over all ``C`` classes and
    ``nll`` is the usual negative log-likelihood of the target class.

    Positions whose target equals ``ignore_index`` are excluded from BOTH
    terms: they contribute 0 to the smoothing term and, for
    ``reduction="mean"``, the smoothing term is averaged over the non-ignored
    positions only (the same denominator ``F.nll_loss`` uses when no class
    weights are given).

    :param eps: smoothing factor
    :param reduction: "mean", "sum", or any other value accepted by
        ``F.nll_loss`` (in which case the smoothing term is left unreduced)
    :param ignore_index: target value marking positions to skip
    """

    def __init__(self, eps=0.1, reduction="mean", ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        """
        output: class scores with the class dimension last
        target: class indices, one per row of ``output``
        """
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        # 1.0 for positions that count, 0.0 for positions to ignore
        keep = (target != self.ignore_index).to(log_preds.dtype)
        smoothing = -log_preds.sum(dim=-1) * keep
        if self.reduction == "sum":
            smoothing = smoothing.sum()
        elif self.reduction == "mean":
            # If every position is ignored this is 0 / 0 = nan, which matches
            # what F.nll_loss returns for the same input.
            smoothing = smoothing.sum() / keep.sum()
        nll = F.nll_loss(log_preds, target, reduction=self.reduction,
                         ignore_index=self.ignore_index)
        return smoothing * self.eps / c + (1 - self.eps) * nll
loss/rl_loss.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Mean of ``tensor`` along ``dim`` weighted by ``mask``.

    Elements are multiplied by ``mask`` before summing and the sum is divided
    by the sum of ``mask`` along the same dimension. A small constant is added
    to the denominator so a row whose mask is all zeros yields 0 rather than
    a division by zero.
    """
    masked_total = (tensor * mask).sum(dim=dim)
    mask_total = mask.sum(dim=dim)
    return masked_total / (mask_total + 1e-8)
12
+
13
+
14
class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss.

    Next-token cross entropy: the logits at position ``t`` are scored against
    the label at position ``t + 1``.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        vocab_size = logits.size(-1)
        # drop the last position of the logits and the first position of the
        # labels so that each prediction lines up with the following token
        predictions = logits[..., :-1, :].reshape(-1, vocab_size)
        targets = labels[..., 1:].reshape(-1)
        return self.loss(predictions, targets)
28
+
29
+
30
class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO.

    Clipped surrogate objective: the probability ratio between the new and
    the old policy is clamped to ``[1 - clip_eps, 1 + clip_eps]`` and the more
    pessimistic of the clipped and unclipped objectives is used.
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        lower, upper = 1 - self.clip_eps, 1 + self.clip_eps
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, lower, upper) * advantages
        per_element = -torch.min(unclipped, clipped)
        if action_mask is None:
            return per_element.mean()
        # average over the masked positions first, then over the result
        return masked_mean(per_element, action_mask).mean()
52
+
53
+
54
class ValueLoss(nn.Module):
    """
    Value Loss for PPO.

    The new value estimate is clipped to stay within ``clip_eps`` of the old
    estimate; the loss is half the mean of the element-wise maximum of the
    clipped and unclipped squared errors against ``reward``.

    ``action_mask`` is honoured when it has the same shape as the per-element
    loss and that loss has at least two dimensions: the loss is then averaged
    over the unmasked positions along dim 1 before the final mean. In every
    other case (no mask, or a mask whose shape does not match, e.g. one value
    per sequence) the plain mean over all elements is used.
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward) ** 2
        surr2 = (values - reward) ** 2
        loss = torch.max(surr1, surr2)
        if action_mask is not None and loss.dim() > 1 and action_mask.shape == loss.shape:
            # masked mean along dim 1; the epsilon avoids dividing by zero for
            # rows with no unmasked positions
            loss = (loss * action_mask).sum(dim=1) / (action_mask.sum(dim=1) + 1e-8)
        loss = loss.mean()
        return 0.5 * loss
74
+
75
+
76
class PPOPtxActorLoss(nn.Module):
    """
    PPO-ptx Actor Loss.

    Sum of the clipped PPO policy loss and a pretraining language-model loss
    scaled by ``pretrain_coef``.

    :param policy_clip_eps: clip range passed to ``PolicyLoss``
    :param pretrain_coef: weight of the language-model loss
    :param pretrain_loss_fn: callable ``(lm_logits, lm_input_ids) -> loss``;
        when omitted a fresh ``GPTLMLoss`` is created for this instance
    """

    def __init__(self,
                 policy_clip_eps: float = 0.2,
                 pretrain_coef: float = 0.0,
                 pretrain_loss_fn: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        # Build the default here rather than in the signature: a default
        # evaluated at definition time would be one module object shared by
        # every PPOPtxActorLoss instance.
        self.pretrain_loss_fn = GPTLMLoss() if pretrain_loss_fn is None else pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss
99
+
100
+
101
class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155

    Computes ``-log(sigmoid(chosen_reward - reject_reward))`` averaged over
    all elements.
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # logsigmoid is the fused, numerically stable form of log(sigmoid(x));
        # computing the two steps separately underflows to log(0) = -inf when
        # the rejected reward is far above the chosen one.
        log_probs = torch.nn.functional.logsigmoid(chosen_reward - reject_reward)
        loss = -log_probs.mean()
        return loss
112
+
113
+
114
class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862

    Computes ``log(1 + exp(reject_reward - chosen_reward))`` averaged over all
    elements.
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # softplus(x) == log(1 + exp(x)) but does not overflow to inf when x
        # is large (exp(x) itself overflows in float32 for x of roughly 89+).
        loss = torch.nn.functional.softplus(reject_reward - chosen_reward).mean()
        return loss
loss/similarity_loss.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2022/03/23 16:55
3
+ # @Author : Jianing Wang
4
+ # @Email : lygwjn@gmail.com
5
+ # @File : SimilarityLoss.py
6
+ # !/usr/bin/env python
7
+ # coding=utf-8
8
+
9
+ import torch
10
+ from torch import nn, Tensor
11
+ from transformers.models.bert.modeling_bert import BertModel
12
+ from transformers import BertTokenizer, BertConfig
13
+
14
+
15
class CosineSimilarityLoss(nn.Module):
    """
    Loss between the cosine similarity of two representations and a label.

    Given two batches of vectors u and v, it computes
    cos_score_transformation(cosine_sim(u, v)) and compares it with the label
    using loss_fct. By default, it minimizes the following loss:
    ||input_label - cos_score_transformation(cosine_sim(u,v))||_2.

    :param loss_fct: Which pytorch loss function should be used to compare the cosine_similarity(u,v) with the input_label? By default (None), MSE: ||input_label - cosine_sim(u,v)||_2
    :param cos_score_transformation: The cos_score_transformation function is applied on top of cosine_similarity. By default (None), the identity function is used (i.e. no change).
    """

    def __init__(self, loss_fct=None, cos_score_transformation=None):
        super(CosineSimilarityLoss, self).__init__()
        # Defaults are created per instance; module objects placed in the
        # signature would be shared by every CosineSimilarityLoss.
        self.loss_fct = nn.MSELoss() if loss_fct is None else loss_fct
        self.cos_score_transformation = (
            nn.Identity() if cos_score_transformation is None else cos_score_transformation
        )

    def forward(self, rep_a, rep_b, label: Tensor):
        # rep_a: [batch_size, hidden_dim]
        # rep_b: [batch_size, hidden_dim]
        # label: one target similarity per pair; flattened before comparison
        output = self.cos_score_transformation(torch.cosine_similarity(rep_a, rep_b))
        # Match the label dtype to the score so integer 0/1 labels can be used
        # with a regression loss in both the forward and the backward pass.
        target = label.view(-1).to(output.dtype)
        return self.loss_fct(output, target)
39
+
40
+
41
+
42
if __name__ == "__main__":
    # Demo: encode two batches of sentences with BERT, pool each sentence into
    # a single vector and compute the cosine-similarity loss for the pairs.

    def _mean_pool(hidden_states, attention_mask):
        # Average the token vectors over the sequence axis (dim=1), skipping
        # padded positions. Averaging over the last axis instead would collapse
        # the hidden dimension and yield [batch_size, max_seq_len].
        weights = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        return (hidden_states * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

    # configure for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # obtain two batches of examples, each corresponding example is a pair
    examples1 = ["Beijing is one of the biggest city in China.", "Disney film is well seeing for us."]
    examples2 = ["Shanghai is the largest city in east of China.", "ACL 2021 will be held in line due to COVID-19."]
    label = [1, 0]
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    features1 = tokenizer(examples1, add_special_tokens=True, padding=True)
    features2 = tokenizer(examples2, add_special_tokens=True, padding=True)
    # right-pad every field with zeros up to max_seq_len and convert to tensors
    max_seq_len = 24
    features1 = {
        key: torch.tensor([value + [0] * (max_seq_len - len(value)) for value in values], dtype=torch.long)
        for key, values in features1.items()
    }
    features2 = {
        key: torch.tensor([value + [0] * (max_seq_len - len(value)) for value in values], dtype=torch.long)
        for key, values in features2.items()
    }
    # float labels, because the default loss is a regression (MSE) loss
    label = torch.tensor(label, dtype=torch.float)
    # obtain sentence embeddings by averaged pooling
    rep_a = model(**features1)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_b = model(**features2)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_a = _mean_pool(rep_a, features1["attention_mask"])  # [batch_size, hidden_dim]
    rep_b = _mean_pool(rep_b, features2["attention_mask"])  # [batch_size, hidden_dim]
    # obtain cosine similarity loss
    loss_fn = CosineSimilarityLoss()
    loss = loss_fn(rep_a=rep_a, rep_b=rep_b, label=label)
    print(loss)  # scalar tensor with grad_fn=<MseLossBackward0>
loss/triplet_loss.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2022/03/23 15:25
3
+ # @Author : Jianing Wang
4
+ # @Email : lygwjn@gmail.com
5
+ # @File : TripletLoss.py
6
+ # !/usr/bin/env python
7
+ # coding=utf-8
8
+
9
+ from enum import Enum
10
+ import torch
11
+ from torch import nn, Tensor
12
+ import torch.nn.functional as F
13
+ from transformers.models.bert.modeling_bert import BertModel
14
+ from transformers import BertTokenizer, BertConfig
15
+
16
class TripletDistanceMetric(Enum):
    """
    Namespace of distance functions usable by the triplet loss.

    Each attribute is a plain callable ``(x, y) -> Tensor``. Because functions
    are descriptors, ``Enum`` does not turn them into members: they are
    accessed and called directly, e.g. ``TripletDistanceMetric.COSINE(x, y)``.
    """

    @staticmethod
    def COSINE(x, y):
        # 1 - cosine similarity, so identical directions give 0
        return 1 - F.cosine_similarity(x, y)

    @staticmethod
    def EUCLIDEAN(x, y):
        # L2 distance between corresponding rows of x and y
        return F.pairwise_distance(x, y, p=2)

    @staticmethod
    def MANHATTAN(x, y):
        # L1 distance between corresponding rows of x and y
        return F.pairwise_distance(x, y, p=1)
23
+
24
class TripletLoss(nn.Module):
    """
    Triplet loss over (anchor, positive, negative) representations.

    The loss shrinks the anchor-positive distance and grows the
    anchor-negative distance until they differ by at least the margin:

    loss = max(d(anchor, positive) - d(anchor, negative) + margin, 0),

    averaged over all triplets. The margin is a hyperparameter that needs to
    be tuned for the task.

    @:param distance_metric: callable ``(x, y) -> distances`` applied to the anchor and each candidate batch
    @:param triplet_margin: (float) The margin distance

    Input example of forward function:
        rep_anchor: [[0.2, -0.1, ..., 0.6], [0.2, -0.1, ..., 0.6], ..., [0.2, -0.1, ..., 0.6]]
        rep_positive: [[0.3, 0.1, ..., 0.5], [0.1, -0.2, ..., 0.7], ..., [0.2, 0.0, ..., 0.4]]
        rep_negative: [[0.3, 0.1, ..., -0.3], [-0.8, 1.2, ..., 0.7], ..., [-0.9, 0.1, ..., 0.4]]

    Return example of forward function:
        0.015 (averaged)
    """

    def __init__(self, distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin: float = 0.5):
        super().__init__()
        self.triplet_margin = triplet_margin
        self.distance_metric = distance_metric

    def forward(self, rep_anchor, rep_positive, rep_negative):
        # rep_anchor: [batch_size, hidden_dim] representations of the anchors
        # rep_positive: [batch_size, hidden_dim] representations of the positives
        #               (for example a dropout / noised view of the anchor)
        # rep_negative: [batch_size, hidden_dim] representations of the negatives
        metric = self.distance_metric
        to_positive = metric(rep_anchor, rep_positive)
        to_negative = metric(rep_anchor, rep_negative)
        # hinge: triplets already separated by the margin contribute zero
        violation = to_positive - to_negative + self.triplet_margin
        return F.relu(violation).mean()
63
+
64
+
65
if __name__ == "__main__":
    # Demo: encode an anchor, two positives and two negatives with BERT, pool
    # each sentence into a single vector and compute the triplet loss.

    def _mean_pool(hidden_states, attention_mask):
        # Average the token vectors over the sequence axis (dim=1), skipping
        # padded positions. Averaging over the last axis instead would collapse
        # the hidden dimension and yield [batch_size, max_seq_len].
        weights = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        return (hidden_states * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

    def _pad_to_tensors(features, length):
        # Right-pad every field with zeros up to `length` and convert to long tensors.
        return {
            key: torch.tensor([value + [0] * (length - len(value)) for value in values], dtype=torch.long)
            for key, values in features.items()
        }

    # configure for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # one anchor sentence, compared against every positive / negative below
    anchor_example = ["I am an anchor, which is the source example sampled from corpora."]  # anchor sentence
    positive_example = [
        "I am an anchor, which is the source example.",
        "I am the source example sampled from corpora."
    ]  # positive, which randomly dropout or noise from anchor
    negative_example = [
        "It is different with the anchor.",
        "My name is Jianing Wang, please give me some stars, thank you!"
    ]  # negative, which randomly sampled from corpora
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    anchor_feature = tokenizer(anchor_example, add_special_tokens=True, padding=True)
    positive_feature = tokenizer(positive_example, add_special_tokens=True, padding=True)
    negative_feature = tokenizer(negative_example, add_special_tokens=True, padding=True)
    # padding and convert to feature batch
    max_seq_len = 24
    anchor_feature = _pad_to_tensors(anchor_feature, max_seq_len)
    positive_feature = _pad_to_tensors(positive_feature, max_seq_len)
    negative_feature = _pad_to_tensors(negative_feature, max_seq_len)
    # obtain token representations
    rep_anchor = model(**anchor_feature)[0]  # [1, max_seq_len, hidden_dim]
    rep_positive = model(**positive_feature)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_negative = model(**negative_feature)[0]  # [batch_size, max_seq_len, hidden_dim]
    # obtain sentence embeddings by averaged pooling; the single anchor row is
    # broadcast against the positive / negative batches by the distance metric
    rep_anchor = _mean_pool(rep_anchor, anchor_feature["attention_mask"])  # [1, hidden_dim]
    rep_positive = _mean_pool(rep_positive, positive_feature["attention_mask"])  # [batch_size, hidden_dim]
    rep_negative = _mean_pool(rep_negative, negative_feature["attention_mask"])  # [batch_size, hidden_dim]
    # obtain triplet loss
    loss_fn = TripletLoss()
    loss = loss_fn(rep_anchor=rep_anchor, rep_positive=rep_positive, rep_negative=rep_negative)
    print(loss)  # scalar tensor with grad_fn=<MeanBackward0>