| import torch |
| import pandas as pd |
| import lightning.pytorch as pl |
|
|
| from omegaconf import OmegaConf |
| from datasets import load_from_disk |
| from torch.utils.data import DataLoader |
| from torch.nn.utils.rnn import pad_sequence |
| from functools import partial |
| from src.utils.model_utils import _print |
|
|
# Module-level config, loaded once at import time.
# NOTE(review): hard-coded absolute path into a user scratch directory —
# this will fail on any other machine/user. Consider taking the config
# path from a CLI arg or env var instead. TODO confirm whether anything
# imports `config` from this module before changing it.
config = OmegaConf.load('/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml')
|
|
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
|
|
|
|
def collate_fn(batch, pad_id=None):
    """Unpack a single pre-batched example into model-ready tensors.

    The DataModule enforces ``batch_size == 1`` (dynamic batching), so
    ``batch`` holds exactly one element whose fields are already batched;
    only ``batch[0]`` is read. ``pad_id`` is accepted for interface
    compatibility but is not used — no padding happens here.
    """
    example = batch[0]
    return {
        'input_ids': torch.tensor(example['input_ids']),
        'attention_mask': torch.tensor(example['attention_mask']),
    }
|
|
|
|
class PeptideDataModule(pl.LightningDataModule):
    """Lightning data module serving pre-tokenized peptide datasets.

    Each split is wrapped in a ``DataLoader`` with ``batch_size == 1``
    because examples on disk are already batched (dynamic batching); the
    collate function unpacks that single pre-batched example.
    """

    def __init__(self, config, train_dataset, val_dataset, test_dataset, tokenizer, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size
        # Each dataset row already holds a full batch, so the DataLoader
        # batch size must stay at 1.
        assert self.batch_size == 1, f'Batch size = {self.batch_size}. Needs to be 1 for dynamic batching'

    def _make_dataloader(self, dataset):
        """Build a DataLoader with the settings shared by all three splits.

        The original methods wrapped ``self.collate_fn`` in a no-arg
        ``partial``, which is a behavioral no-op; the callable is passed
        directly instead.
        """
        # NOTE(review): shuffle=False on the train split too — presumably
        # intentional for pre-batched data, but worth confirming.
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          collate_fn=self.collate_fn,
                          num_workers=8,
                          shuffle=False,
                          pin_memory=True)

    def train_dataloader(self):
        """Return the DataLoader for the training split."""
        return self._make_dataloader(self.train_dataset)

    def val_dataloader(self):
        """Return the DataLoader for the validation split."""
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        """Return the DataLoader for the test split."""
        return self._make_dataloader(self.test_dataset)
| |
|
|
def get_datasets(config):
    """Helper method to grab datasets to quickly init data module in main.py

    Loads each split from the on-disk path configured under
    ``config.data`` and returns them keyed by split name.
    """
    # Load in the same order as before: train, test, val.
    loaded = {
        split: load_from_disk(getattr(config.data, split))
        for split in ("train", "test", "val")
    }
    return {
        "train": loaded["train"],
        "val": loaded["val"],
        "test": loaded["test"],
    }
|
|
|
|