import pytorch_lightning as pl
import torch
from torch.optim import AdamW
import torchmetrics


class DualEncoderModule(pl.LightningModule):
    def __init__(self, tokenizer, model, learning_rate=1e-3):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate

        # One accuracy metric per stage so that each keeps its own state.
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.val_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )

    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch

        # Negatives may carry an extra per-example dimension; flatten them
        # into a single batch of sequences.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])

        # Positives are class 1, negatives class 0. Labels must be torch.long
        # for the model's cross-entropy loss and are created directly on the
        # device of the input tensors.
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )

        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)

        loss_scale = 1.0
        loss = pos_outputs.loss + loss_scale * neg_outputs.loss

        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        self.train_acc(pos_preds, pos_labels)

        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        self.train_acc(neg_preds, neg_labels)

        self.log("train_acc", self.train_acc)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch

        # Flatten the extra per-example dimension on the negatives.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])

        # Class 1 for positives, class 0 for negatives (torch.long, on device).
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )

        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)

        loss_scale = 1.0
        loss = pos_outputs.loss + loss_scale * neg_outputs.loss

        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        self.val_acc(pos_preds, pos_labels)

        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        self.val_acc(neg_preds, neg_labels)

        self.log("val_acc", self.val_acc)

        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        pos_ids, pos_mask, neg_ids, neg_mask = batch

        # Flatten the extra per-example dimension on the negatives.
        neg_ids = neg_ids.view(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.view(-1, neg_mask.shape[-1])

        # Class 1 for positives, class 0 for negatives (torch.long, on device).
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )

        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)

        pos_preds = torch.argmax(pos_outputs.logits, dim=1)
        self.test_acc(pos_preds, pos_labels)

        neg_preds = torch.argmax(neg_outputs.logits, dim=1)
        self.test_acc(neg_preds, neg_labels)

        self.log("test_acc", self.test_acc)
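

if __name__ == "__main__":
    # Minimal usage sketch. Everything below is illustrative only: the
    # checkpoint name, num_labels=2, and the placeholder dataloaders are
    # assumptions, not part of the module above. Any Hugging Face
    # sequence-classification model exposing `.num_labels`, `.loss`, and
    # `.logits` on its outputs should work.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    )
    module = DualEncoderModule(tokenizer, model, learning_rate=1e-3)

    # Each batch is expected to be a (pos_ids, pos_mask, neg_ids, neg_mask)
    # tuple; supply DataLoaders that yield tuples of that shape, e.g.:
    # trainer = pl.Trainer(max_epochs=1, accelerator="auto")
    # trainer.fit(module, train_dataloaders=..., val_dataloaders=...)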