"""Dataset and LightningDataModule for enzyme/polymer degradation prediction.

Enzyme sequences are embedded once with a protein language model (PLM) and
cached on disk under ``cache/protein/<plm>/`` so repeated runs skip the
expensive encoding step.
"""

from argparse import Namespace as Args
from pathlib import Path

import lightning as L
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from tqdm import tqdm

from models.plm import get_model
from models.polybert import PolyEncoder, polymer2psmiles


class EnzymeDataset(Dataset):
    """In-memory rows of (category, sequence, degradation, sequence_id,
    polymer_id) read from a CSV file.

    On construction, any enzyme sequence that has no cached PLM embedding
    at ``cache/protein/<plm>/<sequence_id>.pt`` is encoded and saved.
    """

    def __init__(self, csv_file: str, plm: str):
        """Load rows from ``csv_file`` and ensure PLM embeddings are cached.

        Args:
            csv_file: path to a CSV with columns ``category``, ``sequence``,
                ``degradation``, ``sequence_id`` and ``polymer_id``.
            plm: name of the protein language model, passed to
                :func:`models.plm.get_model` and used as the cache subdir.
        """
        self.data_list = []
        for _, row in pd.read_csv(csv_file).iterrows():
            self.data_list.append((
                row['category'],
                row['sequence'].upper(),
                row['degradation'],
                row['sequence_id'],
                row['polymer_id'],
            ))

        # parents=True creates cache/ and cache/protein/ implicitly, so two
        # mkdir calls cover all four directories the original created.
        cache_dir = Path('cache')
        Path(cache_dir, 'protein', plm).mkdir(parents=True, exist_ok=True)
        Path(cache_dir, 'polymer').mkdir(parents=True, exist_ok=True)

        # Only load the (heavy) PLM when at least one embedding is missing.
        if not all(
                Path(cache_dir, 'protein', plm, f'{seqid}.pt').exists()
                for _, _, _, seqid, _ in self.data_list):
            # NOTE(review): device is hard-coded to 'cuda' — confirm this is
            # intended for CPU-only environments.
            plm_func = get_model(plm, 'cuda')
            for _, seq, _, seqid, _ in tqdm(
                    self.data_list, desc='Encoding enzyme sequences'):
                seq_path = Path(cache_dir, 'protein', plm, f'{seqid}.pt')
                # Per-file re-check so partially filled caches are completed.
                if not seq_path.exists():
                    seq_tensor = plm_func([seq])
                    torch.save(seq_tensor, seq_path)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]


class EnzymeDataModule(L.LightningDataModule):
    """K-fold cross-validation data module over :class:`EnzymeDataset`.

    Call :meth:`setup_k_fold` with a fold index before requesting the
    train/val dataloaders. The training loader draws samples with a
    weighted sampler that balances the ``degradation`` values.
    """

    def __init__(self, args: Args):
        """Build train/val and test datasets and precompute the fold splits.

        Args:
            args: namespace with ``train_csv``, ``test_csv``, ``batch_size``,
                ``num_workers``, ``plm``, ``nfolds`` and ``seed``.
        """
        super().__init__()
        self.args = args
        self.train_csv = args.train_csv
        self.test_csv = args.test_csv
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.plm = args.plm
        self.train_val_set = EnzymeDataset(self.train_csv, self.plm)
        self.test_set = EnzymeDataset(self.test_csv, self.plm)
        self.kfold = KFold(
            n_splits=args.nfolds, shuffle=True, random_state=self.args.seed)
        # Splits are materialized once so every fold sees the same partition.
        self.indices = list(range(len(self.train_val_set)))
        self.splits = list(self.kfold.split(self.indices))

    def setup_k_fold(self, fold_idx):
        """Select the train/val subsets (and sampler) for fold ``fold_idx``."""
        train_idx, val_idx = self.splits[fold_idx]
        self.train_set = torch.utils.data.Subset(self.train_val_set, train_idx)
        self.val_set = torch.utils.data.Subset(self.train_val_set, val_idx)
        self.sampler = self.data_sampler()

    def data_sampler(self):
        """Build a WeightedRandomSampler balancing the ``degradation`` field.

        Each sample is weighted by the inverse frequency of its label so
        rare classes are drawn as often as common ones.

        Returns:
            torch.utils.data.WeightedRandomSampler over the current train set.

        Raises:
            AttributeError: if :meth:`setup_k_fold` has not been called yet.
        """
        # Guard clause replaces the original if/else pyramid.
        if not hasattr(self, 'train_set'):
            raise AttributeError(
                'train_set not initialized. Call setup_k_fold first.')
        # train_set is a Subset; map its indices back to the full dataset.
        indices = self.train_set.indices if hasattr(
            self.train_set, 'indices') else range(len(self.train_set))
        # Tuple index 2 is the 'degradation' column (see EnzymeDataset).
        labels = [self.train_val_set[i][2] for i in indices]
        label_counts = pd.Series(labels).value_counts()
        weights = [1.0 / label_counts[label] for label in labels]
        return WeightedRandomSampler(
            weights, num_samples=len(weights), replacement=True)

    def train_dataloader(self):
        # shuffle must stay unset: DataLoader forbids combining it with a
        # custom sampler, and the sampler already randomizes draw order.
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=self.sampler,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )