File size: 4,473 Bytes
afdeeca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import pytorch_lightning as pl
import torch
from transformers.optimization import AdamW
import torchmetrics


class DualEncoderModule(pl.LightningModule):
    """Lightning module that trains a classifier on positive/negative examples.

    Every batch is a 4-tuple ``(pos_ids, pos_mask, neg_ids, neg_mask)``.
    Positive examples are scored against label ``1`` and negative examples
    against label ``0``; the two losses are summed and a per-stage accuracy
    metric is updated with the argmax predictions.

    The wrapped ``model`` must expose ``num_labels`` and, when called as
    ``model(input_ids, attention_mask=..., labels=...)``, return an object
    with ``.loss`` and ``.logits`` attributes.
    """

    # Weight applied to the negative-example loss before it is added to the
    # positive-example loss.
    _neg_loss_scale = 1.0

    def __init__(self, tokenizer, model, learning_rate=1e-3):
        """Store the collaborators and create one accuracy metric per stage.

        Args:
            tokenizer: Kept on the module for callers; not used by the
                training logic in this class.
            model: Classifier to optimise (see class docstring for the
                interface it must provide).
            learning_rate: Learning rate handed to the optimizer.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate

        # Separate metric objects so train/val/test statistics never mix.
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.val_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )

    def forward(self, input_ids, **kwargs):
        """Delegate to the wrapped model, forwarding all keyword arguments."""
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters of this module."""
        # NOTE(review): transformers' own AdamW has been deprecated upstream in
        # favour of torch.optim.AdamW (different default eps / weight_decay) --
        # confirm before migrating.
        return AdamW(self.parameters(), lr=self.learning_rate)

    def _classify_pairs(self, batch, metric):
        """Run the model on the positives and negatives of ``batch``.

        Args:
            batch: Tuple ``(pos_ids, pos_mask, neg_ids, neg_mask)``. The
                negative tensors may carry extra leading dimensions
                (presumably several negatives per positive -- confirm against
                the data module); they are flattened to 2-D before the
                forward pass.
            metric: Accuracy metric of the current stage; updated in place.

        Returns:
            Tuple ``(pos_outputs, neg_outputs)`` of raw model outputs.
        """
        pos_ids, pos_mask, neg_ids, neg_mask = batch

        # Collapse every leading dimension of the negatives into the batch axis.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])

        # Build the labels directly on the inputs' device. Using ``.device``
        # (rather than ``get_device()``, which yields -1 for CPU tensors) keeps
        # the step working on CPU as well as on accelerators. Class indices
        # are created as int64, the conventional dtype for classification
        # targets.
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )

        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)

        # One combined metric update on the compute device: no host round-trip
        # and the per-step value reflects positives and negatives together.
        preds = torch.cat(
            [
                torch.argmax(pos_outputs.logits, dim=1),
                torch.argmax(neg_outputs.logits, dim=1),
            ]
        )
        targets = torch.cat([pos_labels, neg_labels]).to(preds.device)
        metric(preds, targets)

        return pos_outputs, neg_outputs

    def _pair_loss(self, pos_outputs, neg_outputs):
        """Return positive loss plus the scaled negative loss."""
        return pos_outputs.loss + self._neg_loss_scale * neg_outputs.loss

    def training_step(self, batch, batch_idx):
        """Compute the training loss and update the training accuracy.

        Returns:
            ``{"loss": loss}`` where ``loss`` is the combined pair loss.
        """
        pos_outputs, neg_outputs = self._classify_pairs(batch, self.train_acc)

        # Logging the metric object lets Lightning compute and reset it once
        # per epoch, mirroring how val_acc / test_acc are reported.
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)

        return {"loss": self._pair_loss(pos_outputs, neg_outputs)}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and log the validation accuracy.

        Returns:
            ``{"loss": loss}`` where ``loss`` is the combined pair loss.
        """
        pos_outputs, neg_outputs = self._classify_pairs(batch, self.val_acc)

        self.log("val_acc", self.val_acc)

        return {"loss": self._pair_loss(pos_outputs, neg_outputs)}

    def test_step(self, batch, batch_idx):
        """Update and log the test accuracy; no loss is returned."""
        self._classify_pairs(batch, self.test_acc)

        self.log("test_acc", self.test_acc)