| from torch.utils.data import DataLoader | |
| import torch | |
| from torch.utils.data import DataLoader, DistributedSampler | |
| from src.interface.data_interface import DInterface_base | |
| from src.model.pretrain_model_interface import PretrainModelInterface | |
| from src.data.proteingym_dataset import ProteinGYMDataset | |
| from src.data.msa_dataset import MSADataset | |
| class DInterface(DInterface_base): | |
| def __init__(self, **kwargs): | |
| super().__init__(**kwargs) | |
| self.save_hyperparameters() | |
| def setup(self, stage=None): | |
| pass | |
| def data_setup(self, type="proteingym"): | |
| if type == "proteingym": | |
| self.mut_dataset = ProteinGYMDataset( | |
| dms_csv_dir = self.hparams.dms_csv_dir, | |
| dms_pdb_dir = self.hparams.dms_pdb_dir, | |
| dms_reference_csv_path = self.hparams.dms_reference_csv_path, | |
| ) | |
| elif type == "msa": | |
| self.msa_dataset = MSADataset( | |
| msa_csv_path = msa_csv_path | |
| ) | |
| def train_dataloader(self): | |
| return DataLoader(self.mut_dataset, batch_size=1, shuffle=True, num_workers=self.hparams.num_workers, pin_memory=True) | |
| def val_dataloader(self): | |
| return DataLoader(self.mut_dataset, batch_size=1, shuffle=False, num_workers=self.hparams.num_workers, pin_memory=True) | |
| def test_dataloader(self): | |
| return DataLoader(self.mut_dataset, batch_size=1, shuffle=False, num_workers=self.hparams.num_workers, pin_memory=True) | |