from typing import Literal, List, Tuple, Optional, Dict from collections import defaultdict import random import logging from .data_utils import ( get_fingerprint, is_active, load_cell2embedding, load_protein2embedding, ) from torch.utils.data import Dataset, DataLoader from imblearn.over_sampling import SMOTE, ADASYN from sklearn.preprocessing import StandardScaler, OrdinalEncoder import numpy as np import pandas as pd import pytorch_lightning as pl from rdkit import Chem from rdkit.Chem import AllChem from rdkit import DataStructs class PROTAC_Dataset(Dataset): def __init__( self, protac_df: pd.DataFrame, protein2embedding: Dict[str, np.ndarray], cell2embedding: Dict[str, np.ndarray], smiles2fp: Dict[str, np.ndarray], use_smote: bool = False, oversampler: Optional[SMOTE | ADASYN] = None, active_label: str = 'Active', disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None, use_single_scaler: Optional[bool] = None, shuffle_embedding_prob: float = 0.0, ): """ Initialize the PROTAC dataset Args: protac_df (pd.DataFrame): The PROTAC dataframe protein2embedding (dict): Dictionary of protein embeddings cell2embedding (dict): Dictionary of cell line embeddings smiles2fp (dict): Dictionary of SMILES to fingerprint use_smote (bool): Whether to use SMOTE for oversampling oversampler (SMOTE | ADASYN): The oversampler to use active_label (str): The column containing the active/inactive information disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector scaler (StandardScaler | dict): The scaler to use for the embeddings use_single_scaler (bool): Whether to use a single scaler for all features shuffle_embedding_prob (float): The probability of shuffling the embeddings. Used for testing whether embeddings act as "barcodes". Defaults to 0.0, i.e., no shuffling. """ # Filter out examples with NaN in active_label column self.data = protac_df # [~protac_df[active_label].isna()] self.protein2embedding = protein2embedding self.cell2embedding = cell2embedding self.smiles2fp = smiles2fp self.active_label = active_label self.disabled_embeddings = disabled_embeddings # Scaling parameters self.scaler = scaler self.use_single_scaler = use_single_scaler self.smiles_emb_dim = smiles2fp[list(smiles2fp.keys())[0]].shape[0] self.protein_emb_dim = protein2embedding[list( protein2embedding.keys())[0]].shape[0] self.cell_emb_dim = cell2embedding[list( cell2embedding.keys())[0]].shape[0] self.default_smiles_emb = np.zeros(self.smiles_emb_dim) self.default_protein_emb = np.zeros(self.protein_emb_dim) self.default_cell_emb = np.zeros(self.cell_emb_dim) # Look up the embeddings self.data = pd.DataFrame({ 'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp.get(x, self.default_smiles_emb).astype(np.float32)).tolist(), 'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding.get(x, self.default_protein_emb).astype(np.float32)).tolist(), 'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding.get(x, self.default_protein_emb).astype(np.float32)).tolist(), 'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding.get(x, self.default_cell_emb).astype(np.float32)).tolist(), self.active_label: self.data[self.active_label].astype(np.float32).tolist(), }) # Apply SMOTE self.use_smote = use_smote self.oversampler = oversampler if self.use_smote: self.apply_smote() self.shuffle_embedding_prob = shuffle_embedding_prob if shuffle_embedding_prob > 0.0: # Set random seed random.seed(42) if self.protein_emb_dim != self.cell_emb_dim: logging.warning('Protein and cell embeddings have different dimensions. Shuffling will be on POI and E3 embeddings only.') def get_smiles_emb_dim(self): return self.smiles_emb_dim def get_protein_emb_dim(self): return self.protein_emb_dim def get_cell_emb_dim(self): return self.cell_emb_dim def apply_smote(self): # Prepare the dataset for SMOTE features = [] labels = [] for _, row in self.data.iterrows(): features.append(np.hstack([ row['Smiles'], row['Uniprot'], row['E3 Ligase Uniprot'], row['Cell Line Identifier'], ])) labels.append(row[self.active_label]) # Convert to numpy array features = np.array(features).astype(np.float32) labels = np.array(labels).astype(np.float32) # Initialize SMOTE and fit if self.oversampler is None: oversampler = SMOTE(random_state=42) else: oversampler = self.oversampler features_smote, labels_smote = oversampler.fit_resample(features, labels) # Separate the features back into their respective embeddings smiles_embs = features_smote[:, :self.smiles_emb_dim] poi_embs = features_smote[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim] e3_embs = features_smote[:, self.smiles_emb_dim + self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim] cell_embs = features_smote[:, -self.cell_emb_dim:] # Reconstruct the dataframe with oversampled data df_smote = pd.DataFrame({ 'Smiles': list(smiles_embs), 'Uniprot': list(poi_embs), 'E3 Ligase Uniprot': list(e3_embs), 'Cell Line Identifier': list(cell_embs), self.active_label: labels_smote }) self.data = df_smote def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict: """ Fit the scalers for the data and save them in the dataset class. Args: use_single_scaler (bool): Whether to use a single scaler for all features. scaler_kwargs: Keyword arguments for the StandardScaler. Returns: dict: The fitted scalers. """ if use_single_scaler: self.use_single_scaler = True self.scaler = StandardScaler(**scaler_kwargs) embeddings = np.hstack([ np.array(self.data['Smiles'].tolist()), np.array(self.data['Uniprot'].tolist()), np.array(self.data['E3 Ligase Uniprot'].tolist()), np.array(self.data['Cell Line Identifier'].tolist()), ]) self.scaler.fit(embeddings) return self.scaler else: self.use_single_scaler = False scalers = {} scalers['Smiles'] = StandardScaler(**scaler_kwargs) scalers['Uniprot'] = StandardScaler(**scaler_kwargs) scalers['E3 Ligase Uniprot'] = StandardScaler(**scaler_kwargs) scalers['Cell Line Identifier'] = StandardScaler(**scaler_kwargs) scalers['Smiles'].fit(np.stack(self.data['Smiles'].to_numpy())) scalers['Uniprot'].fit(np.stack(self.data['Uniprot'].to_numpy())) scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy())) scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy())) self.scaler = scalers return scalers def apply_scaling(self, scalers: dict, use_single_scaler: bool = False): """ Apply scaling to the data. Args: scalers (dict): The scalers for each feature. use_single_scaler (bool): Whether to use a single scaler for all features. """ if use_single_scaler: embeddings = np.hstack([ np.array(self.data['Smiles'].tolist()), np.array(self.data['Uniprot'].tolist()), np.array(self.data['E3 Ligase Uniprot'].tolist()), np.array(self.data['Cell Line Identifier'].tolist()), ]) scaled_embeddings = scalers.transform(embeddings) self.data = pd.DataFrame({ 'Smiles': list(scaled_embeddings[:, :self.smiles_emb_dim]), 'Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]), 'E3 Ligase Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim+self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]), 'Cell Line Identifier': list(scaled_embeddings[:, -self.cell_emb_dim:]), self.active_label: self.data[self.active_label] }) else: # Check if the self.data[] data contains only binary values # (0 or 1). If so, do not apply scaling. for feature in ['Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier']: feature_array = np.array(self.data[feature].tolist()) if np.all(np.isin(feature_array, [0, 1])): continue self.data[feature] = self.data[feature].apply(lambda x: scalers[feature].transform(x[np.newaxis, :])[0]) def get_numpy_arrays(self, component: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]: """ Get the numpy arrays for the dataset. Args: component (str): The component to get the numpy arrays for. Defaults to None, i.e., get a single stacked array. Returns: tuple: The numpy arrays for the dataset. The first element is the input array, and the second element is the output array. """ if component is not None: X = np.array(self.data[component].tolist()).copy() else: X = np.hstack([ np.array(self.data['Smiles'].tolist()), np.array(self.data['Uniprot'].tolist()), np.array(self.data['E3 Ligase Uniprot'].tolist()), np.array(self.data['Cell Line Identifier'].tolist()), ]).copy() y = self.data[self.active_label].values.copy() return X, y def __len__(self): return len(self.data) def __getitem__(self, idx): if 'smiles' in self.disabled_embeddings: # Get a zero vector for the fingerprint smiles_emb = np.zeros(self.smiles_emb_dim).astype(np.float32) # TODO: Remove random sampling in the future # # Uniformly sample a binary vector for the fingerprint # smiles_emb = np.random.randint(0, 2, size=self.smiles_emb_dim).astype(np.float32) # if not self.use_single_scaler and self.scaler is not None: # smiles_emb = smiles_emb[np.newaxis, :] # smiles_emb = self.scaler['Smiles'].transform(smiles_emb).flatten() else: smiles_emb = self.data['Smiles'].iloc[idx] if 'poi' in self.disabled_embeddings: poi_emb = np.zeros(self.protein_emb_dim).astype(np.float32) # TODO: Remove random sampling in the future # # Uniformly sample a vector for the protein # poi_emb = np.random.rand(self.protein_emb_dim).astype(np.float32) # if not self.use_single_scaler and self.scaler is not None: # poi_emb = poi_emb[np.newaxis, :] # poi_emb = self.scaler['Uniprot'].transform(poi_emb).flatten() else: poi_emb = self.data['Uniprot'].iloc[idx] if 'e3' in self.disabled_embeddings: e3_emb = np.zeros(self.protein_emb_dim).astype(np.float32) # TODO: Remove random sampling in the future # # Uniformly sample a vector for the E3 ligase # e3_emb = np.random.rand(self.protein_emb_dim).astype(np.float32) # if not self.use_single_scaler and self.scaler is not None: # # Add extra dimension for compatibility with the scaler # e3_emb = e3_emb[np.newaxis, :] # e3_emb = self.scaler['E3 Ligase Uniprot'].transform(e3_emb) # e3_emb = e3_emb.flatten() else: e3_emb = self.data['E3 Ligase Uniprot'].iloc[idx] if 'cell' in self.disabled_embeddings: cell_emb = np.zeros(self.cell_emb_dim).astype(np.float32) # TODO: Remove random sampling in the future # # Uniformly sample a vector for the cell line # cell_emb = np.random.rand(self.cell_emb_dim).astype(np.float32) # if not self.use_single_scaler and self.scaler is not None: # cell_emb = cell_emb[np.newaxis, :] # cell_emb = self.scaler['Cell Line Identifier'].transform(cell_emb).flatten() else: cell_emb = self.data['Cell Line Identifier'].iloc[idx] # Shuffle the embeddings if the probability is met if random.random() < self.shuffle_embedding_prob: if self.protein_emb_dim == self.cell_emb_dim: # Randomly shuffle the embeddings for POI, cell, and E3 embeddings = np.vstack([poi_emb, e3_emb, cell_emb]) np.random.shuffle(embeddings) poi_emb, e3_emb, cell_emb = embeddings else: # Swap POI and E3 embeddings only, because of different dimensions poi_emb, e3_emb = e3_emb, poi_emb elem = { 'smiles_emb': smiles_emb, 'poi_emb': poi_emb, 'e3_emb': e3_emb, 'cell_emb': cell_emb, 'active': self.data[self.active_label].iloc[idx], } return elem def get_datasets( train_df: pd.DataFrame, val_df: pd.DataFrame, test_df: Optional[pd.DataFrame] = None, protein2embedding: Dict = None, cell2embedding: Dict = None, smiles2fp: Dict = None, smote_k_neighbors: int = 5, active_label: str = 'Active', disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None, use_single_scaler: Optional[bool] = None, apply_scaling: bool = False, shuffle_embedding_prob: float = 0.0, ) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]: """ Get the datasets for training the PROTAC model. Args: train_df (pd.DataFrame): The training data. val_df (pd.DataFrame): The validation data. test_df (pd.DataFrame): The test data. protein2embedding (dict): Dictionary of protein embeddings. cell2embedding (dict): Dictionary of cell line embeddings. smiles2fp (dict): Dictionary of SMILES to fingerprint. use_smote (bool): Whether to use SMOTE for oversampling. smote_k_neighbors (int): The number of neighbors to use for SMOTE. active_label (str): The active label column. disabled_embeddings (list): The list of embeddings to disable. scaler (StandardScaler | dict): The scaler to use for the embeddings. use_single_scaler (bool): Whether to use a single scaler for all features. apply_scaling (bool): Whether to apply scaling to the data now. Defaults to False (the Pytorch Lightning model does that). """ if smote_k_neighbors: oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42) else: oversampler = None train_ds = PROTAC_Dataset( train_df, protein2embedding, cell2embedding, smiles2fp, use_smote=True if smote_k_neighbors else False, oversampler=oversampler, active_label=active_label, disabled_embeddings=disabled_embeddings, scaler=scaler, use_single_scaler=use_single_scaler, shuffle_embedding_prob=shuffle_embedding_prob, ) val_ds = PROTAC_Dataset( val_df, protein2embedding, cell2embedding, smiles2fp, active_label=active_label, disabled_embeddings=disabled_embeddings, scaler=train_ds.scaler if train_ds.scaler is not None else scaler, use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler, ) train_scalers = None if apply_scaling: train_scalers = train_ds.fit_scaling(use_single_scaler=use_single_scaler) val_ds.apply_scaling(train_scalers, use_single_scaler=use_single_scaler) if test_df is not None: test_ds = PROTAC_Dataset( test_df, protein2embedding, cell2embedding, smiles2fp, active_label=active_label, disabled_embeddings=disabled_embeddings, scaler=train_scalers if apply_scaling else scaler, use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler, ) if apply_scaling: test_ds.apply_scaling(train_ds.scaler, use_single_scaler=use_single_scaler) else: test_ds = None return train_ds, val_ds, test_ds class PROTAC_DataModule(pl.LightningDataModule): """ PyTorch Lightning DataModule for the PROTAC dataset. TODO: Work in progress. It would be nice to wrap all information into a single class, but it is not clear how to do it yet due to cross-validation and the need to split the data into training, validation, and test sets accordingly. Args: protac_csv_filepath (str): The path to the PROTAC CSV file. protein2embedding_filepath (str): The path to the protein to embedding dictionary. cell2embedding_filepath (str): The path to the cell line to embedding dictionary. pDC50_threshold (float): The threshold for the pDC50 value to consider a PROTAC active. Dmax_threshold (float): The threshold for the Dmax value to consider a PROTAC active. use_smote (bool): Whether to use SMOTE for oversampling. smote_k_neighbors (int): The number of neighbors to use for SMOTE. active_label (str): The column containing the active/inactive information. disabled_embeddings (list): The list of embeddings to disable. scaler (StandardScaler | dict): The scaler to use for the embeddings. use_single_scaler (bool): Whether to use a single scaler for all features. """ def __init__( self, protac_csv_filepath: str, protein2embedding_filepath: str, cell2embedding_filepath: str, pDC50_threshold: float = 6.0, Dmax_threshold: float = 0.6, use_smote: bool = True, smote_k_neighbors: int = 5, active_label: str = 'Active', disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [], scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None, use_single_scaler: Optional[bool] = None, ): super(PROTAC_DataModule, self).__init__() # Load the PROTAC dataset self.protac_df = pd.read_csv('../data/PROTAC-Degradation-DB.csv') # Map E3 Ligase Iap to IAP self.protac_df['E3 Ligase'] = self.protac_df['E3 Ligase'].str.replace('Iap', 'IAP') self.protac_df[active_label] = self.protac_df.apply( lambda x: is_active( x['DC50 (nM)'], x['Dmax (%)'], pDC50_threshold=pDC50_threshold, Dmax_threshold=Dmax_threshold, ), axis=1, ) self.smiles2fp, self.protac_df = self.get_smiles2fp_and_avg_tanimoto(self.protac_df) self.active_df = self.protac_df[self.protac_df[active_label].notna()].copy() # Load embedding dictionaries self.protein2embedding = load_protein2embedding(protein2embedding_filepath) self.cell2embedding = load_cell2embedding(cell2embedding_filepath) def setup(self, stage: str): self.train_ds, self.val_ds, self.test_ds = get_datasets( self.train_df, self.val_df, self.test_df, self.protein2embedding, self.cell2embedding, self.smiles2fp, use_smote=self.use_smote, smote_k_neighbors=self.smote_k_neighbors, active_label=self.active_label, disabled_embeddings=self.disabled_embeddings, scaler=self.scaler, use_single_scaler=self.use_single_scaler, ) def train_dataloader(self): return DataLoader(self.train_ds, batch_size=32, shuffle=True) def val_dataloader(self): return DataLoader(self.val_ds, batch_size=32) def test_dataloader(self): return DataLoader(self.test_ds, batch_size=32) @staticmethod def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index: """ Get the indices of the test set using a random split. Args: active_df (pd.DataFrame): The DataFrame containing the active PROTACs. test_split (float): The percentage of the active PROTACs to use as the test set. Returns: pd.Index: The indices of the test set. """ return active_df.sample(frac=test_split, random_state=42).index @staticmethod def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index: """ Get the indices of the test set using the E3 ligase split. Args: active_df (pd.DataFrame): The DataFrame containing the active PROTACs. Returns: pd.Index: The indices of the test set. """ encoder = OrdinalEncoder() active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int) test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')] return test_df.index @staticmethod def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple: """ Get the SMILES to fingerprint dictionary and the average Tanimoto similarity. Args: protac_df (pd.DataFrame): The DataFrame containing the PROTACs. Returns: tuple: The SMILES to fingerprint dictionary and the average Tanimoto similarity. """ unique_smiles = protac_df['Smiles'].unique().tolist() smiles2fp = {} for smiles in unique_smiles: smiles2fp[smiles] = get_fingerprint(smiles) tanimoto_matrix = defaultdict(list) fps = list(smiles2fp.values()) # Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity for i, (smiles1, fp1) in enumerate(zip(unique_smiles, fps)): similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:]) # Only compute for i to end, avoiding duplicates for j, similarity in enumerate(similarities): distance = 1 - similarity tanimoto_matrix[smiles1].append(distance) # Store as distance if i != i + j: tanimoto_matrix[unique_smiles[i + j]].append(distance) # Symmetric filling # Calculate average Tanimoto distance for each unique SMILES avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()} protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto) smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()} return smiles2fp, protac_df @staticmethod def get_tanimoto_split_indices( active_df: pd.DataFrame, active_label: str, test_split: float, n_bins_tanimoto: int = 200, ) -> pd.Index: """ Get the indices of the test set using the Tanimoto-based split. Args: active_df (pd.DataFrame): The DataFrame containing the active PROTACs. n_bins_tanimoto (int): The number of bins to use for the Tanimoto similarity. Returns: pd.Index: The indices of the test set. """ tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy() encoder = OrdinalEncoder() active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.values.reshape(-1, 1)).astype(int) # Sort the groups so that samples with the highest tanimoto similarity, # i.e., the "less similar" ones, are placed in the test set first tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index test_df = [] # For each group, get the number of active and inactive entries. Then, add those # entries to the test_df if: 1) the test_df lenght + the group entries is less # 20% of the active_df lenght, and 2) the percentage of True and False entries # in the active_label in test_df is roughly 50%. for group in tanimoto_groups: group_df = active_df[active_df['Tanimoto Group'] == group] if test_df == []: test_df.append(group_df) continue num_entries = len(group_df) num_active_group = group_df[active_label].sum() num_inactive_group = num_entries - num_active_group tmp_test_df = pd.concat(test_df) num_entries_test = len(tmp_test_df) num_active_test = tmp_test_df[active_label].sum() num_inactive_test = num_entries_test - num_active_test # Check if the group entries can be added to the test_df if num_entries_test + num_entries < test_split * len(active_df): # Add anything at the beggining if num_entries_test + num_entries < test_split / 2 * len(active_df): test_df.append(group_df) continue # Be more selective and make sure that the percentage of active and # inactive is balanced if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6: if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6: test_df.append(group_df) test_df = pd.concat(test_df) return test_df.index @staticmethod def get_target_split_indices(active_df: pd.DataFrame, active_label: str, test_split: float) -> pd.Index: """ Get the indices of the test set using the target-based split. Args: active_df (pd.DataFrame): The DataFrame containing the active PROTACs. active_label (str): The column containing the active/inactive information. test_split (float): The percentage of the active PROTACs to use as the test set. Returns: pd.Index: The indices of the test set. """ encoder = OrdinalEncoder() active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int) test_df = [] # For each group, get the number of active and inactive entries. Then, add those # entries to the test_df if: 1) the test_df lenght + the group entries is less # 20% of the active_df lenght, and 2) the percentage of True and False entries # in the active_label in test_df is roughly 50%. # Start the loop from the groups containing the smallest number of entries. for group in reversed(active_df['Uniprot'].value_counts().index): group_df = active_df[active_df['Uniprot'] == group] if test_df == []: test_df.append(group_df) continue num_entries = len(group_df) num_active_group = group_df[active_label].sum() num_inactive_group = num_entries - num_active_group tmp_test_df = pd.concat(test_df) num_entries_test = len(tmp_test_df) num_active_test = tmp_test_df[active_label].sum() num_inactive_test = num_entries_test - num_active_test # Check if the group entries can be added to the test_df if num_entries_test + num_entries < test_split * len(active_df): # Add anything at the beggining if num_entries_test + num_entries < test_split / 2 * len(active_df): test_df.append(group_df) continue # Be more selective and make sure that the percentage of active and # inactive is balanced if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6: if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6: test_df.append(group_df) test_df = pd.concat(test_df) return test_df.index