# fact-checking-api / averitec / models / DualEncoderModule.py
# Commit afdeeca ("add files") by zhenyundeng; file size 4.47 kB.
import pytorch_lightning as pl
import torch
from transformers.optimization import AdamW
import torchmetrics
class DualEncoderModule(pl.LightningModule):
    """Lightning wrapper that trains ``model`` to separate positives from negatives.

    Every batch carries positive examples (target label 1) and negative examples
    (target label 0). Both are scored by the same wrapped ``model``. The training
    loss is ``pos_loss + neg_loss_scale * neg_loss``. Accuracy is tracked separately
    for the train, validation and test stages.

    Args:
        tokenizer: Stored as ``self.tokenizer``; not used by the methods below.
        model: Wrapped classifier. It must expose ``num_labels``, and when called
            with ``attention_mask=`` and ``labels=`` it must return an object with
            ``.loss`` and ``.logits``. Presumably a Hugging Face
            sequence-classification model -- TODO confirm against the caller.
        learning_rate: Learning rate handed to the optimizer.
        neg_loss_scale: Weight applied to the negative-example loss. The default
            of 1.0 matches the previous hard-coded value.
    """

    def __init__(self, tokenizer, model, learning_rate=1e-3, neg_loss_scale=1.0):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate
        self.neg_loss_scale = neg_loss_scale

        # One metric object per stage so their accumulated states never mix.
        self.train_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.val_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )
        self.test_acc = torchmetrics.Accuracy(
            task="multiclass", num_classes=model.num_labels
        )

    def forward(self, input_ids, **kwargs):
        """Call the wrapped model, passing keyword arguments through unchanged."""
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters at ``self.learning_rate``."""
        # NOTE(review): transformers' AdamW is deprecated upstream in favour of
        # torch.optim.AdamW. Their defaults differ (eps, weight_decay), so a swap
        # is not a drop-in change -- verify before migrating.
        return AdamW(self.parameters(), lr=self.learning_rate)

    def _shared_step(self, batch):
        """Score the positives and negatives of ``batch`` (shared by all stages).

        Args:
            batch: 4-tuple ``(pos_ids, pos_mask, neg_ids, neg_mask)``.

        Returns:
            ``(loss, preds, targets)``. ``preds`` and ``targets`` hold the
            positive examples followed by the negative examples.
        """
        pos_ids, pos_mask, neg_ids, neg_mask = batch

        # Collapse every leading dim of the negatives into one batch dim, e.g.
        # (batch, num_neg, seq_len) -> (batch * num_neg, seq_len); the exact
        # input shape is an assumption -- TODO confirm against the dataloader.
        # reshape (unlike view) also accepts non-contiguous tensors.
        neg_ids = neg_ids.reshape(-1, neg_ids.shape[-1])
        neg_mask = neg_mask.reshape(-1, neg_mask.shape[-1])

        # Create labels directly on the input device. The old
        # `.to(t.get_device())` broke on CPU, where get_device() returns -1.
        # torch.long is the dtype cross-entropy expects for class indices.
        pos_labels = torch.ones(
            pos_ids.shape[0], dtype=torch.long, device=pos_ids.device
        )
        neg_labels = torch.zeros(
            neg_ids.shape[0], dtype=torch.long, device=neg_ids.device
        )

        pos_outputs = self(pos_ids, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = self(neg_ids, attention_mask=neg_mask, labels=neg_labels)
        loss = pos_outputs.loss + self.neg_loss_scale * neg_outputs.loss

        preds = torch.cat(
            [pos_outputs.logits.argmax(dim=1), neg_outputs.logits.argmax(dim=1)]
        )
        targets = torch.cat([pos_labels, neg_labels])
        return loss, preds, targets

    def training_step(self, batch, batch_idx):
        """Compute the training loss and update/log training accuracy."""
        loss, preds, targets = self._shared_step(batch)
        # A single update covering the whole batch keeps the step-level value
        # meaningful. Inputs stay on the device the metric module lives on.
        self.train_acc(preds, targets)
        # The accuracy used to be accumulated but never reported.
        self.log("train_acc", self.train_acc)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and update/log validation accuracy."""
        loss, preds, targets = self._shared_step(batch)
        self.val_acc(preds, targets)
        self.log("val_acc", self.val_acc)
        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        """Update and log test accuracy; returns nothing, as before."""
        _, preds, targets = self._shared_step(batch)
        self.test_acc(preds, targets)
        self.log("test_acc", self.test_acc)