DeepLearning101 committed
Commit fdc4786
Parent(s): 4cda815
Upload 6 files
- loss/contrastive_loss.py +88 -0
- loss/focal_loss.py +28 -0
- loss/label_smoothing.py +21 -0
- loss/rl_loss.py +122 -0
- loss/similarity_loss.py +70 -0
- loss/triplet_loss.py +103 -0
loss/contrastive_loss.py
ADDED
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
# @Time    : 2022/03/23 14:50
# @Author  : Jianing Wang
# @Email   : lygwjn@gmail.com
# @File    : ContrastiveLoss.py
# !/usr/bin/env python
# coding=utf-8

from enum import Enum
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers.models.bert.modeling_bert import BertModel
from transformers import BertTokenizer, BertConfig


class SiameseDistanceMetric(Enum):
    """
    The metric for the contrastive loss
    """
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)


class ContrastiveLoss(nn.Module):
    """
    Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, the distance between the
    two embeddings is reduced. If the label == 0, the distance between the embeddings is increased.

    @:param distance_metric: The distance metric function
    @:param margin: (float) The margin distance
    @:param size_average: (bool) Whether to return the averaged loss (otherwise the sum)

    Input example of forward function:
        rep_anchor: [[0.2, -0.1, ..., 0.6], [0.2, -0.1, ..., 0.6], ..., [0.2, -0.1, ..., 0.6]]
        rep_candidate: [[0.3, 0.1, ..., -0.3], [-0.8, 1.2, ..., 0.7], ..., [-0.9, 0.1, ..., 0.4]]
        label: [0, 1, ..., 1]

    Return example of forward function:
        0.015 (averaged)
        2.672 (sum)
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average: bool = False):
        super(ContrastiveLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def forward(self, rep_anchor, rep_candidate, label: Tensor):
        # rep_anchor: [batch_size, hidden_dim] denotes the representations of anchors
        # rep_candidate: [batch_size, hidden_dim] denotes the representations of positives / negatives
        # label: [batch_size] denotes the label of each anchor - candidate pair
        distances = self.distance_metric(rep_anchor, rep_candidate)
        losses = 0.5 * (label.float() * distances.pow(2) + (1 - label).float() * F.relu(self.margin - distances).pow(2))
        return losses.mean() if self.size_average else losses.sum()


if __name__ == "__main__":
    # configuration for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # obtain two batches of examples; corresponding entries form anchor - candidate pairs
    examples1 = ["This is the sentence anchor 1.", "It is the second sentence in this article named Section D."]
    examples2 = ["It is the same as anchor 1.", "I think it is different from Section D."]
    label = [1, 0]
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    features1 = tokenizer(examples1, add_special_tokens=True, padding=True)
    features2 = tokenizer(examples2, add_special_tokens=True, padding=True)
    # pad and convert to feature batches
    max_seq_len = 16
    features1 = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in features1.items()}
    features2 = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in features2.items()}
    label = torch.Tensor(label).long()
    # obtain sentence embeddings by mean pooling over the sequence dimension
    rep_anchor = model(**features1)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_candidate = model(**features2)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_anchor = torch.mean(rep_anchor, dim=1)  # [batch_size, hidden_dim]
    rep_candidate = torch.mean(rep_candidate, dim=1)  # [batch_size, hidden_dim]
    # obtain contrastive loss
    loss_fn = ContrastiveLoss()
    loss = loss_fn(rep_anchor=rep_anchor, rep_candidate=rep_candidate, label=label)
    print(loss)  # e.g. tensor(0.0869, grad_fn=<SumBackward0>)
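A minimal sanity check of the contrastive formula 0.5 * (y * d^2 + (1 - y) * relu(margin - d)^2) on toy vectors, so the math can be verified without downloading BERT; the import path assumes the repository root is on PYTHONPATH:

import torch
from loss.contrastive_loss import ContrastiveLoss

anchors    = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
candidates = torch.tensor([[1.0, 0.0], [0.8, 0.6]])
labels     = torch.tensor([1, 0])  # first pair similar, second pair dissimilar
# cosine distances are 0.0 and 0.4; only the dissimilar pair falls inside the 0.5 margin
print(ContrastiveLoss()(anchors, candidates, labels))  # tensor(0.0050)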
loss/focal_loss.py
ADDED
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# @Time    : 2022/2/17 6:05 PM
# @Author  : JianingWang
# @File    : loss
import torch
from torch import nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class Focal loss implementation"""

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """
        input: [N, C]
        target: [N, ]
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        logpt = (1 - pt) ** self.gamma * logpt  # down-weight well-classified examples
        loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
        return loss
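focal_loss.py ships no demo block; a minimal usage sketch (shapes and the import path are assumptions):

import torch
from loss.focal_loss import FocalLoss  # assumes the repository root is on PYTHONPATH

logits = torch.randn(4, 3, requires_grad=True)  # [N, C] raw scores for 4 samples over 3 classes
targets = torch.tensor([0, 2, 1, 2])            # [N] gold class indices
loss = FocalLoss(gamma=2)(logits, targets)      # behaves like cross-entropy but down-weights easy examples
loss.backward()
print(loss)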
loss/label_smoothing.py
ADDED
@@ -0,0 +1,21 @@
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction="mean", ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        c = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)
        if self.reduction == "sum":
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction == "mean":
                loss = loss.mean()
        return loss * self.eps / c + (1 - self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction,
                                                                 ignore_index=self.ignore_index)
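A minimal usage sketch for the label smoothing loss (shapes and the import path are assumptions):

import torch
from loss.label_smoothing import LabelSmoothingCrossEntropy  # assumes the repository root is on PYTHONPATH

logits = torch.randn(8, 5)           # [N, C]
targets = torch.randint(0, 5, (8,))  # [N]
criterion = LabelSmoothingCrossEntropy(eps=0.1)
# eps of the probability mass is spread uniformly over the C classes, 1 - eps stays on the gold class
print(criterion(logits, targets))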
loss/rl_loss.py
ADDED
@@ -0,0 +1,122 @@
from typing import Optional

import torch
import torch.nn as nn


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # mean over the masked (mask == 1) positions only
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)
    return mean


class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return loss


class ValueLoss(nn.Module):
    """
    Value Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward)**2
        surr2 = (values - reward)**2
        loss = torch.max(surr1, surr2)
        loss = loss.mean()
        return 0.5 * loss


class PPOPtxActorLoss(nn.Module):
    """
    PPO-ptx Actor Loss (TODO)
    """

    def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        self.pretrain_loss_fn = pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss


class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(chosen_reward - reject_reward)
        log_probs = torch.log(probs)
        loss = -log_probs.mean()
        return loss


class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
        return loss
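A minimal sketch of how these pieces could be wired in a single PPO step, on random tensors; the shapes, variable names and import path are illustrative assumptions:

import torch
from loss.rl_loss import PolicyLoss, ValueLoss, LogSigLoss  # assumes the repository root is on PYTHONPATH

batch, num_actions = 2, 5
log_probs = torch.randn(batch, num_actions)                      # per-token log-probs under the current policy
old_log_probs = log_probs.detach() + 0.05 * torch.randn(batch, num_actions)
advantages = torch.randn(batch, num_actions)
action_mask = torch.ones(batch, num_actions)                     # 1 for generated tokens, 0 for padding

actor_loss = PolicyLoss(clip_eps=0.2)(log_probs, old_log_probs, advantages, action_mask)
critic_loss = ValueLoss(clip_eps=0.4)(torch.randn(batch, num_actions),   # values
                                      torch.randn(batch, num_actions),   # old_values
                                      torch.randn(batch, num_actions))   # reward
rm_loss = LogSigLoss()(chosen_reward=torch.tensor([1.2, 0.3]), reject_reward=torch.tensor([0.4, 0.9]))
print(actor_loss, critic_loss, rm_loss)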
loss/similarity_loss.py
ADDED
@@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Time    : 2022/03/23 16:55
# @Author  : Jianing Wang
# @Email   : lygwjn@gmail.com
# @File    : SimilarityLoss.py
# !/usr/bin/env python
# coding=utf-8

import torch
from torch import nn, Tensor
from transformers.models.bert.modeling_bert import BertModel
from transformers import BertTokenizer, BertConfig


class CosineSimilarityLoss(nn.Module):
    """
    CosineSimilarityLoss expects that the input examples consist of two texts and a float label.

    It computes the vectors u = model(input_text[0]) and v = model(input_text[1]) and measures the cosine similarity between the two.
    By default, it minimizes the following loss: ||input_label - cos_score_transformation(cosine_sim(u, v))||_2.

    :param loss_fct: Which pytorch loss function should be used to compare cosine_similarity(u, v) with the input_label? By default, MSE: ||input_label - cosine_sim(u, v)||_2
    :param cos_score_transformation: The cos_score_transformation function is applied on top of cosine_similarity. By default, the identity function is used (i.e. no change).
    """
    def __init__(self, loss_fct=nn.MSELoss(), cos_score_transformation=nn.Identity()):
        super(CosineSimilarityLoss, self).__init__()
        self.loss_fct = loss_fct
        self.cos_score_transformation = cos_score_transformation

    def forward(self, rep_a, rep_b, label: Tensor):
        # rep_a: [batch_size, hidden_dim]
        # rep_b: [batch_size, hidden_dim]
        output = self.cos_score_transformation(torch.cosine_similarity(rep_a, rep_b))
        # print(output) # tensor([0.9925, 0.5846], grad_fn=<DivBackward0>), tensor(0.1709, grad_fn=<MseLossBackward0>)
        return self.loss_fct(output, label.view(-1))


if __name__ == "__main__":
    # configuration for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # obtain two batches of examples; corresponding entries form sentence pairs
    examples1 = ["Beijing is one of the biggest cities in China.", "The Disney film is well received by us."]
    examples2 = ["Shanghai is the largest city in the east of China.", "ACL 2021 will be held online due to COVID-19."]
    label = [1, 0]
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    features1 = tokenizer(examples1, add_special_tokens=True, padding=True)
    features2 = tokenizer(examples2, add_special_tokens=True, padding=True)
    # pad and convert to feature batches
    max_seq_len = 24
    features1 = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in features1.items()}
    features2 = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in features2.items()}
    label = torch.Tensor(label).float()  # MSELoss expects float targets
    # obtain sentence embeddings by mean pooling over the sequence dimension
    rep_a = model(**features1)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_b = model(**features2)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_a = torch.mean(rep_a, dim=1)  # [batch_size, hidden_dim]
    rep_b = torch.mean(rep_b, dim=1)  # [batch_size, hidden_dim]
    # obtain cosine similarity loss
    loss_fn = CosineSimilarityLoss()
    loss = loss_fn(rep_a=rep_a, rep_b=rep_b, label=label)
    print(loss)  # e.g. tensor(0.1709, grad_fn=<MseLossBackward0>)
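Unlike ContrastiveLoss, the label here can be a graded float score; a toy check on hand-picked vectors (illustrative; the import path assumes the repository root is on PYTHONPATH):

import torch
from loss.similarity_loss import CosineSimilarityLoss

u = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
v = torch.tensor([[1.0, 0.0], [1.0, 0.0]])
gold = torch.tensor([1.0, 0.0])            # float similarity targets, not class indices
print(CosineSimilarityLoss()(u, v, gold))  # MSE between cos(u, v) = [1.0, 0.0] and gold -> tensor(0.)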
loss/triplet_loss.py
ADDED
@@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
# @Time    : 2022/03/23 15:25
# @Author  : Jianing Wang
# @Email   : lygwjn@gmail.com
# @File    : TripletLoss.py
# !/usr/bin/env python
# coding=utf-8

from enum import Enum
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertModel
from transformers import BertTokenizer, BertConfig


class TripletDistanceMetric(Enum):
    """
    The metric for the triplet loss
    """
    COSINE = lambda x, y: 1 - F.cosine_similarity(x, y)
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)


class TripletLoss(nn.Module):
    """
    This class implements triplet loss. Given a triplet of (anchor, positive, negative),
    the loss minimizes the distance between anchor and positive while it maximizes the distance
    between anchor and negative. It computes the following loss function:

    loss = max(||anchor - positive|| - ||anchor - negative|| + margin, 0).

    Margin is an important hyperparameter and needs to be tuned accordingly.

    @:param distance_metric: The distance metric function
    @:param triplet_margin: (float) The margin distance

    Input example of forward function:
        rep_anchor: [[0.2, -0.1, ..., 0.6], [0.2, -0.1, ..., 0.6], ..., [0.2, -0.1, ..., 0.6]]
        rep_positive: [[0.3, 0.1, ..., -0.3], [-0.8, 1.2, ..., 0.7], ..., [-0.9, 0.1, ..., 0.4]]
        rep_negative: [[-0.3, 0.2, ..., 0.5], [0.6, -0.4, ..., 0.1], ..., [0.4, -0.2, ..., 0.8]]

    Return example of forward function:
        0.015 (averaged)
    """
    def __init__(self, distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin: float = 0.5):
        super(TripletLoss, self).__init__()
        self.distance_metric = distance_metric
        self.triplet_margin = triplet_margin

    def forward(self, rep_anchor, rep_positive, rep_negative):
        # rep_anchor: [batch_size, hidden_dim] denotes the representations of anchors
        # rep_positive: [batch_size, hidden_dim] denotes the representations of positives; these can also be dropout-augmented anchors
        # rep_negative: [batch_size, hidden_dim] denotes the representations of negatives
        distance_pos = self.distance_metric(rep_anchor, rep_positive)
        distance_neg = self.distance_metric(rep_anchor, rep_negative)

        losses = F.relu(distance_pos - distance_neg + self.triplet_margin)
        return losses.mean()


if __name__ == "__main__":
    # configuration for huggingface pre-trained language models
    config = BertConfig.from_pretrained("bert-base-cased")
    # tokenizer for huggingface pre-trained language models
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    # pytorch_model.bin for huggingface pre-trained language models
    model = BertModel.from_pretrained("bert-base-cased")
    # obtain the anchor and batches of positive / negative examples
    anchor_example = ["I am an anchor, which is the source example sampled from corpora."]  # anchor sentence
    positive_example = [
        "I am an anchor, which is the source example.",
        "I am the source example sampled from corpora."
    ]  # positives, e.g. obtained by random dropout or noising of the anchor
    negative_example = [
        "It is different from the anchor.",
        "My name is Jianing Wang, please give me some stars, thank you!"
    ]  # negatives, e.g. randomly sampled from corpora
    # convert each example to features
    # {"input_ids": xxx, "attention_mask": xxx, "token_type_ids": xxx}
    anchor_feature = tokenizer(anchor_example, add_special_tokens=True, padding=True)
    positive_feature = tokenizer(positive_example, add_special_tokens=True, padding=True)
    negative_feature = tokenizer(negative_example, add_special_tokens=True, padding=True)
    # pad and convert to feature batches
    max_seq_len = 24
    anchor_feature = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in anchor_feature.items()}
    positive_feature = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in positive_feature.items()}
    negative_feature = {key: torch.Tensor([value + [0] * (max_seq_len - len(value)) for value in values]).long() for key, values in negative_feature.items()}
    # obtain sentence embeddings by mean pooling over the sequence dimension
    rep_anchor = model(**anchor_feature)[0]  # [1, max_seq_len, hidden_dim]
    rep_positive = model(**positive_feature)[0]  # [batch_size, max_seq_len, hidden_dim]
    rep_negative = model(**negative_feature)[0]  # [batch_size, max_seq_len, hidden_dim]
    # mean pooling; the single anchor broadcasts against the positive / negative batch
    rep_anchor = torch.mean(rep_anchor, dim=1)  # [1, hidden_dim]
    rep_positive = torch.mean(rep_positive, dim=1)  # [batch_size, hidden_dim]
    rep_negative = torch.mean(rep_negative, dim=1)  # [batch_size, hidden_dim]
    # obtain triplet loss
    loss_fn = TripletLoss()
    loss = loss_fn(rep_anchor=rep_anchor, rep_positive=rep_positive, rep_negative=rep_negative)
    print(loss)  # e.g. tensor(0.5001, grad_fn=<MeanBackward0>)
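A toy check of the triplet margin on hand-picked vectors (illustrative; the import path assumes the repository root is on PYTHONPATH):

import torch
from loss.triplet_loss import TripletLoss, TripletDistanceMetric

anchor   = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[0.3, 0.0]])  # close to the anchor
negative = torch.tensor([[1.0, 0.0]])  # farther from the anchor
loss = TripletLoss(distance_metric=TripletDistanceMetric.EUCLIDEAN, triplet_margin=0.5)(anchor, positive, negative)
print(loss)  # relu(0.3 - 1.0 + 0.5) = 0: the negative is already margin-farther than the positive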