import torch
import torchmetrics
from torch.optim import AdamW
from pytorch_lightning import LightningModule
from transformers import AutoConfig, AutoModelForSequenceClassification, get_linear_schedule_with_warmup


class LightningModel(LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int = 2,
        lr: float = 5e-6,
        train_batch_size: int = 32,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.num_labels = num_labels
        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=self.num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        self.model.gradient_checkpointing_enable()

        self.lr = lr
        self.train_batch_size = train_batch_size

        self.accuracy = torchmetrics.Accuracy()
        self.f1score = torchmetrics.F1Score(num_classes=2)
        self.mcc = torchmetrics.MatthewsCorrCoef(num_classes=2)

    def forward(self, input_ids, attention_mask, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        val_loss, logits = outputs[:2]
        preds = torch.argmax(logits, dim=1)
        labels = batch["labels"]
        return {"loss": val_loss, "preds": preds, "labels": labels}

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        batch = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
        outputs = self(**batch)
        # Return the probability of the positive class for each example
        return torch.nn.functional.softmax(outputs.logits, dim=1)[:, 1]

    def validation_epoch_end(self, outputs):
        preds = torch.cat([x["preds"] for x in outputs])
        labels = torch.cat([x["labels"] for x in outputs])
        loss = torch.stack([x["loss"] for x in outputs]).mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_accuracy", self.accuracy(preds, labels.squeeze()), prog_bar=True)
        self.log("val_f1", self.f1score(preds, labels.squeeze()), prog_bar=True)
        self.log("val_mcc", self.mcc(preds, labels.squeeze()), prog_bar=True)
        return loss

    def setup(self, stage=None):
        if stage != "fit":
            return None

        # Get dataloader by calling it - train_dataloader() is called after setup() by default
        train_loader = self.trainer.datamodule.train_dataloader()

        # Calculate total steps
        tb_size = self.train_batch_size * max(1, self.trainer.gpus)
        ab_size = tb_size * self.trainer.accumulate_grad_batches
        self.total_steps = int((len(train_loader.dataset) / ab_size) * float(self.trainer.max_epochs))

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)."""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.lr,
            eps=self.hparams.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.total_steps,
        )
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
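

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one way this LightningModel
# might be trained. It assumes a hypothetical LightningDataModule named
# `MyDataModule` whose batches contain "input_ids", "attention_mask", and
# "labels" (and a predict_dataloader() for prediction), plus a PyTorch
# Lightning 1.x Trainer, since the module above relies on `trainer.gpus` and
# the `validation_epoch_end` hook.
#
# from pytorch_lightning import Trainer
#
# datamodule = MyDataModule(model_name_or_path="bert-base-uncased", batch_size=32)  # hypothetical
# model = LightningModel(model_name_or_path="bert-base-uncased", num_labels=2, lr=5e-6)
#
# trainer = Trainer(max_epochs=3, gpus=1, accumulate_grad_batches=1)
# trainer.fit(model, datamodule=datamodule)
#
# # `predict_step` returns the positive-class probability per example.
# probs = trainer.predict(model, datamodule=datamodule)
# ---------------------------------------------------------------------------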