"""Tokenizing dataset and Lightning data module for sentence DataFrames."""

from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer


class PyTorchDataModule(Dataset):
    """Dataset that tokenizes the ``sentence`` column of a DataFrame on access.

    Despite its name, this class is a ``torch.utils.data.Dataset``; the
    Lightning data module in this file is ``DataModule``.
    """

    def __init__(
        self,
        model_name_or_path: str,
        data: pd.DataFrame,
        max_seq_length: int = 256,
        tokenizer: Optional[Any] = None,
    ):
        """Store the data and obtain a tokenizer.

        Args:
            model_name_or_path: Identifier passed to
                ``AutoTokenizer.from_pretrained`` when no tokenizer is given.
            data: Frame with a ``sentence`` column and, optionally, a
                ``label`` column.
            max_seq_length: Length every sentence is padded/truncated to.
            tokenizer: Optional already-loaded tokenizer to reuse. When
                ``None`` (the default) a fast tokenizer is loaded from
                ``model_name_or_path``.
        """
        self.model_name_or_path = model_name_or_path
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_name_or_path, use_fast=True
            )
        self.tokenizer = tokenizer
        self.data = data
        self.max_seq_length = max_seq_length

    def __len__(self) -> int:
        """Return the number of rows in the underlying frame."""
        return len(self.data)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return the model inputs for the row at position ``index``.

        The result holds the raw ``sentence``, the flattened ``input_ids``
        and ``attention_mask`` produced by the tokenizer, and, when the frame
        has a ``label`` column, a flattened ``labels`` entry.
        """
        row = self.data.iloc[index]
        sentence = row["sentence"]
        encoding = self.tokenizer(
            sentence,
            max_length=self.max_seq_length,
            padding="max_length",
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        item = {
            "sentence": sentence,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
        }
        if "label" in self.data.columns:
            label = row["label"]
            # Values that already expose ``flatten`` (NumPy scalars/arrays,
            # tensors) are used as-is; plain Python numbers or lists are
            # wrapped first so they no longer raise AttributeError.
            if not hasattr(label, "flatten"):
                label = np.asarray(label)
            item["labels"] = label.flatten()
        return item


class DataModule(LightningDataModule):
    """Lightning data module serving train, eval and (optional) test frames."""

    def __init__(
        self,
        model_name_or_path: str,
        train_df: pd.DataFrame,
        eval_df: pd.DataFrame,
        test_df: Optional[pd.DataFrame] = None,
        max_seq_length: int = 256,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 4,
        **kwargs
    ):
        """Record the frames and loader settings; datasets are built in ``setup``.

        Args:
            model_name_or_path: Identifier used to load the tokenizer.
            train_df: Training frame.
            eval_df: Validation frame.
            test_df: Frame used for prediction. A test dataset is only built
                when this is a ``DataFrame``.
            max_seq_length: Length every sentence is padded/truncated to.
            train_batch_size: Batch size of the training loader.
            eval_batch_size: Batch size of the validation loader.
            num_workers: Worker count passed to every ``DataLoader``.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.train_df = train_df
        self.eval_df = eval_df
        self.test_df = test_df
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets; ``stage`` is accepted but not used."""
        # Load the tokenizer once and share it instead of once per dataset.
        tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path, use_fast=True
        )
        self.train_dataset = PyTorchDataModule(
            self.model_name_or_path,
            self.train_df,
            self.max_seq_length,
            tokenizer=tokenizer,
        )
        self.eval_dataset = PyTorchDataModule(
            self.model_name_or_path,
            self.eval_df,
            self.max_seq_length,
            tokenizer=tokenizer,
        )
        if isinstance(self.test_df, pd.DataFrame):
            self.test_dataset = PyTorchDataModule(
                self.model_name_or_path,
                self.test_df,
                self.max_seq_length,
                tokenizer=tokenizer,
            )

    def train_dataloader(self) -> DataLoader:
        """Return the loader over the training dataset."""
        # NOTE(review): training data is served unshuffled; confirm this is
        # intentional before changing it.
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the validation dataset."""
        return DataLoader(
            self.eval_dataset,
            batch_size=self.eval_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def predict_dataloader(self) -> DataLoader:
        """Return a loader that serves the whole test dataset as one batch.

        Raises:
            AttributeError: If no test dataset exists, i.e. ``setup`` has not
                run or ``test_df`` was not a ``DataFrame``.
        """
        test_dataset = getattr(self, "test_dataset", None)
        if test_dataset is None:
            # Same exception type as the bare attribute lookup used to raise,
            # but with an actionable message.
            raise AttributeError(
                "test_dataset is not available: call setup() with test_df "
                "set to a pandas DataFrame before requesting predictions"
            )
        # NOTE(review): one batch holds the entire test set; presumably fine
        # for small frames only - confirm memory use for large ones.
        # max(1, ...) keeps DataLoader from rejecting batch_size=0 when the
        # test frame is empty.
        return DataLoader(
            test_dataset,
            batch_size=max(1, len(test_dataset)),
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
        )