import torch
import torchmetrics
from torch.optim import AdamW
from pytorch_lightning import LightningModule
from transformers import AutoConfig, AutoModelForSequenceClassification, get_linear_schedule_with_warmup


class LightningModel(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int = 2,
        lr: float = 5e-6,
        train_batch_size: int = 32,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()
        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=self.num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        # Trade compute for memory: recompute activations during the backward pass.
        self.model.gradient_checkpointing_enable()
        self.lr = lr
        self.train_batch_size = train_batch_size
        # Validation metrics; F1 and MCC are configured for binary classification.
        self.accuracy = torchmetrics.Accuracy()
        self.f1score = torchmetrics.F1Score(num_classes=2)
        self.mcc = torchmetrics.MatthewsCorrCoef(num_classes=2)

    def forward(self, input_ids, attention_mask, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        # With labels provided, the HF model output starts with (loss, logits, ...).
        outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        val_loss, logits = outputs[:2]
        preds = torch.argmax(logits, dim=1)
        labels = batch["labels"]
        return {"loss": val_loss, "preds": preds, "labels": labels}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Drop everything except the model inputs and return the positive-class probability.
        batch = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
        outputs = self(**batch)
        return torch.nn.functional.softmax(outputs.logits, dim=1)[:, 1]
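        # Illustrative hand-computed example: logits of [[2.0, 0.5]] softmax to roughly
        # [[0.82, 0.18]], so the returned positive-class probability for that row is ~0.18.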

    def validation_epoch_end(self, outputs):
        # Aggregate the per-batch dicts collected by validation_step.
        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", self.accuracy(preds, labels.squeeze()), prog_bar=True)
        self.log("val_f1", self.f1score(preds, labels.squeeze()), prog_bar=True)
        self.log("val_mcc", self.mcc(preds, labels.squeeze()), prog_bar=True)
        return loss

    def setup(self, stage=None):
        if stage != "fit":
            return None
        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.trainer.datamodule.train_dataloader()
        # Calculate the total number of optimization steps for the LR scheduler.
        tb_size = self.train_batch_size * max(1, self.trainer.gpus)
        ab_size = tb_size * self.trainer.accumulate_grad_batches
        self.total_steps = int((len(train_loader.dataset) / ab_size) * float(self.trainer.max_epochs))
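        # Illustrative arithmetic (hypothetical numbers, not taken from this Space):
        # 10,000 samples, train_batch_size=32, 2 GPUs, accumulate_grad_batches=2, max_epochs=3
        # -> effective batch size = 32 * 2 * 2 = 128, total_steps = int((10000 / 128) * 3) = 234.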

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)."""
        model = self.model
        # Exclude biases and LayerNorm weights from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.lr,
            eps=self.hparams.adam_epsilon,
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.total_steps,
        )
        # Step the scheduler every optimization step rather than every epoch.
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
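

# Minimal usage sketch: wiring this module to a pytorch_lightning Trainer with a toy,
# hypothetical LightningDataModule that yields random token ids. The model name, sizes,
# and the ToyDataModule itself are placeholders for illustration, not the Space's real
# data pipeline; swap in an actual tokenized dataset in practice.
if __name__ == "__main__":
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader

    class ToyDataModule(pl.LightningDataModule):
        """Hypothetical datamodule yielding dicts of input_ids / attention_mask / labels."""

        def __init__(self, n_samples=64, seq_len=16, batch_size=8):
            super().__init__()
            self.n_samples, self.seq_len, self.batch_size = n_samples, seq_len, batch_size

        def setup(self, stage=None):
            # Random binary-classification examples; default_collate stacks them into batches.
            self.data = [
                {
                    "input_ids": torch.randint(0, 1000, (self.seq_len,)),
                    "attention_mask": torch.ones(self.seq_len, dtype=torch.long),
                    "labels": torch.randint(0, 2, (1,)).squeeze(),
                }
                for _ in range(self.n_samples)
            ]

        def train_dataloader(self):
            return DataLoader(self.data, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.data, batch_size=self.batch_size)

    # Placeholder checkpoint; any sequence-classification-capable model name works here.
    model = LightningModel(model_name_or_path="prajjwal1/bert-tiny", num_labels=2)
    # Pass an integer to gpus because setup() reads self.trainer.gpus.
    trainer = pl.Trainer(max_epochs=1, gpus=0, limit_train_batches=2, limit_val_batches=2)
    trainer.fit(model, datamodule=ToyDataModule())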