import torch
import pytorch_lightning as pl
# Note: transformers' AdamW has been deprecated and removed in recent
# releases; torch.optim.AdamW is the drop-in replacement.
from torch.optim import AdamW
from transformers import T5ForConditionalGeneration
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


class SummarizerModel(pl.LightningModule):
    def __init__(self, model_name):
        # model_name is required (e.g. a T5 checkpoint name); the original
        # default of None would fail inside from_pretrained().
        super().__init__()
        # return_dict=True makes the model return a ModelOutput with named
        # fields (.loss, .logits) rather than a plain tuple.
        self.model = T5ForConditionalGeneration.from_pretrained(
            model_name, return_dict=True
        )

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        # When labels are supplied, the model computes the cross-entropy
        # loss internally; otherwise output.loss is None.
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        decoder_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        decoder_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        decoder_attention_mask = batch["labels_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return AdamW(self.model.parameters(), lr=1e-4)
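
# --- Usage sketch (an assumption, not part of the original snippet) ---
# A minimal end-to-end run wiring SummarizerModel into a pl.Trainer, using
# the ModelCheckpoint and TensorBoardLogger imports above. The toy dataset,
# the "t5-small" checkpoint, batch size, and epoch count are all
# illustrative placeholders, not values taken from the original code.
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer


class ToyDataset(Dataset):
    """Synthetic dataset producing the batch keys the module expects."""

    def __init__(self, tokenizer, n=8):
        enc = tokenizer(
            "summarize: a long document about model training",
            max_length=64, padding="max_length", truncation=True,
            return_tensors="pt",
        )
        dec = tokenizer(
            "a short summary",
            max_length=16, padding="max_length", truncation=True,
            return_tensors="pt",
        )
        labels = dec.input_ids.clone()
        # T5's loss ignores positions set to -100, so mask out padding.
        labels[labels == tokenizer.pad_token_id] = -100
        self.item = {
            "text_input_ids": enc.input_ids.squeeze(0),
            "text_attention_mask": enc.attention_mask.squeeze(0),
            "labels": labels.squeeze(0),
            "labels_attention_mask": dec.attention_mask.squeeze(0),
        }
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.item


if __name__ == "__main__":
    tokenizer = T5Tokenizer.from_pretrained("t5-small")  # assumed checkpoint
    train_loader = DataLoader(ToyDataset(tokenizer), batch_size=4)
    val_loader = DataLoader(ToyDataset(tokenizer), batch_size=4)

    # Keep the single best checkpoint by validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints", filename="best-checkpoint",
        save_top_k=1, monitor="val_loss", mode="min",
    )
    logger = TensorBoardLogger("lightning_logs", name="summarizer")

    model = SummarizerModel(model_name="t5-small")
    trainer = pl.Trainer(
        logger=logger, callbacks=[checkpoint_callback], max_epochs=1
    )
    trainer.fit(model, train_loader, val_loader)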