import time

import torch
import numpy as np
import pandas as pd
from datasets import load_metric
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger

# Free any cached GPU memory and fix the random seed for reproducibility
torch.cuda.empty_cache()
pl.seed_everything(42)


class DataModule(Dataset):
    """
    PyTorch Dataset that tokenizes the source/target text pairs.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: T5Tokenizer,
        source_max_token_len: int = 512,