import os

import torch
import pytorch_lightning as pl
from torch import nn
from torch.optim import AdamW  # transformers.AdamW is deprecated; use the PyTorch optimizer instead
from transformers import T5ForConditionalGeneration
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
class SummarizerModel(pl.LightningModule):
    def __init__(self, model_name: str):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(model_name, return_dict=True)
    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        # When labels are passed, the model computes the cross-entropy loss internally.
        output = self.model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return output.loss, output.logits
    def training_step(self, batch, batch_idx):
        # Lightning calls this once per training batch; the logged loss shows up
        # in the progress bar and in TensorBoard.
        input_ids = batch['text_input_ids']
        attention_mask = batch['text_attention_mask']
        labels = batch['labels']
        decoder_attention_mask = batch['labels_attention_mask']
        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss
    def validation_step(self, batch, batch_idx):
        input_ids = batch['text_input_ids']
        attention_mask = batch['text_attention_mask']
        labels = batch['labels']
        decoder_attention_mask = batch['labels_attention_mask']
        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss
    def test_step(self, batch, batch_idx):
        input_ids = batch['text_input_ids']
        attention_mask = batch['text_attention_mask']
        labels = batch['labels']
        decoder_attention_mask = batch['labels_attention_mask']
        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
        )
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss
    def configure_optimizers(self):
        return AdamW(self.model.parameters(), lr=0.0001)
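The ModelCheckpoint callback and TensorBoardLogger imported above can be wired into a Trainer as follows. This is a minimal sketch rather than a full training script: the "t5-base" checkpoint, the hyperparameters, and the data_module name are assumptions; any LightningDataModule whose batches expose the text_input_ids, text_attention_mask, labels, and labels_attention_mask keys used in the steps above will work.

# Minimal training sketch; checkpoint name, epochs, and data module are assumptions.
model = SummarizerModel(model_name="t5-base")

checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

logger = TensorBoardLogger("lightning_logs", name="summarizer")

trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    max_epochs=3,
    accelerator="auto",
    devices=1,
)

# `data_module` is a hypothetical LightningDataModule whose batches contain the
# 'text_input_ids', 'text_attention_mask', 'labels', and 'labels_attention_mask'
# keys consumed by the training/validation/test steps above.
# trainer.fit(model, datamodule=data_module)
# trainer.test(model, datamodule=data_module)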
