from typing import Literal, List, Tuple, Optional, Dict
from collections import defaultdict
import random
import logging
from .data_utils import (
get_fingerprint,
is_active,
load_cell2embedding,
load_protein2embedding,
)
from torch.utils.data import Dataset, DataLoader
from imblearn.over_sampling import SMOTE, ADASYN
from sklearn.preprocessing import StandardScaler, OrdinalEncoder
import numpy as np
import pandas as pd
import pytorch_lightning as pl
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
class PROTAC_Dataset(Dataset):
def __init__(
self,
protac_df: pd.DataFrame,
protein2embedding: Dict[str, np.ndarray],
cell2embedding: Dict[str, np.ndarray],
smiles2fp: Dict[str, np.ndarray],
use_smote: bool = False,
oversampler: Optional[SMOTE | ADASYN] = None,
active_label: str = 'Active',
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
use_single_scaler: Optional[bool] = None,
shuffle_embedding_prob: float = 0.0,
):
""" Initialize the PROTAC dataset
Args:
protac_df (pd.DataFrame): The PROTAC dataframe
protein2embedding (dict): Dictionary of protein embeddings
cell2embedding (dict): Dictionary of cell line embeddings
smiles2fp (dict): Dictionary of SMILES to fingerprint
use_smote (bool): Whether to use SMOTE for oversampling
oversampler (SMOTE | ADASYN): The oversampler to use
active_label (str): The column containing the active/inactive information
disabled_embeddings (list): The list of embeddings to disable, i.e., return a zero vector
scaler (StandardScaler | dict): The scaler to use for the embeddings
use_single_scaler (bool): Whether to use a single scaler for all features
shuffle_embedding_prob (float): The probability of shuffling the embeddings. Used for testing whether embeddings act as "barcodes". Defaults to 0.0, i.e., no shuffling.
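        Example (illustrative sketch; `train_df` and the embedding dictionaries are
        assumed to be loaded elsewhere, e.g., via the helpers in data_utils):

            ds = PROTAC_Dataset(
                protac_df=train_df,
                protein2embedding=protein2embedding,
                cell2embedding=cell2embedding,
                smiles2fp=smiles2fp,
                use_smote=True,
                active_label='Active',
            )
            batch = ds[0]  # dict with 'smiles_emb', 'poi_emb', 'e3_emb', 'cell_emb', 'active'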
"""
        # NOTE: Examples with NaN in the active_label column are expected to be
        # filtered out upstream; the original filter is kept below for reference.
        self.data = protac_df  # [~protac_df[active_label].isna()]
self.protein2embedding = protein2embedding
self.cell2embedding = cell2embedding
self.smiles2fp = smiles2fp
self.active_label = active_label
self.disabled_embeddings = disabled_embeddings
# Scaling parameters
self.scaler = scaler
self.use_single_scaler = use_single_scaler
        self.smiles_emb_dim = next(iter(smiles2fp.values())).shape[0]
        self.protein_emb_dim = next(iter(protein2embedding.values())).shape[0]
        self.cell_emb_dim = next(iter(cell2embedding.values())).shape[0]
self.default_smiles_emb = np.zeros(self.smiles_emb_dim)
self.default_protein_emb = np.zeros(self.protein_emb_dim)
self.default_cell_emb = np.zeros(self.cell_emb_dim)
# Look up the embeddings
self.data = pd.DataFrame({
'Smiles': self.data['Smiles'].apply(lambda x: smiles2fp.get(x, self.default_smiles_emb).astype(np.float32)).tolist(),
'Uniprot': self.data['Uniprot'].apply(lambda x: protein2embedding.get(x, self.default_protein_emb).astype(np.float32)).tolist(),
'E3 Ligase Uniprot': self.data['E3 Ligase Uniprot'].apply(lambda x: protein2embedding.get(x, self.default_protein_emb).astype(np.float32)).tolist(),
'Cell Line Identifier': self.data['Cell Line Identifier'].apply(lambda x: cell2embedding.get(x, self.default_cell_emb).astype(np.float32)).tolist(),
self.active_label: self.data[self.active_label].astype(np.float32).tolist(),
})
# Apply SMOTE
self.use_smote = use_smote
self.oversampler = oversampler
if self.use_smote:
self.apply_smote()
self.shuffle_embedding_prob = shuffle_embedding_prob
if shuffle_embedding_prob > 0.0:
# Set random seed
random.seed(42)
if self.protein_emb_dim != self.cell_emb_dim:
logging.warning('Protein and cell embeddings have different dimensions. Shuffling will be on POI and E3 embeddings only.')
def get_smiles_emb_dim(self):
return self.smiles_emb_dim
def get_protein_emb_dim(self):
return self.protein_emb_dim
def get_cell_emb_dim(self):
return self.cell_emb_dim
def apply_smote(self):
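        """ Oversample the dataset with the configured oversampler (SMOTE by default).
        The SMILES, POI, E3, and cell embeddings of each example are concatenated
        into a single flat feature vector, resampled together with the labels, and
        then split back into their original columns.
        """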
# Prepare the dataset for SMOTE
features = []
labels = []
for _, row in self.data.iterrows():
features.append(np.hstack([
row['Smiles'],
row['Uniprot'],
row['E3 Ligase Uniprot'],
row['Cell Line Identifier'],
]))
labels.append(row[self.active_label])
# Convert to numpy array
features = np.array(features).astype(np.float32)
labels = np.array(labels).astype(np.float32)
# Initialize SMOTE and fit
if self.oversampler is None:
oversampler = SMOTE(random_state=42)
else:
oversampler = self.oversampler
features_smote, labels_smote = oversampler.fit_resample(features, labels)
# Separate the features back into their respective embeddings
        smiles_embs = features_smote[:, :self.smiles_emb_dim]
        poi_embs = features_smote[:, self.smiles_emb_dim:self.smiles_emb_dim + self.protein_emb_dim]
        e3_embs = features_smote[:, self.smiles_emb_dim + self.protein_emb_dim:self.smiles_emb_dim + 2 * self.protein_emb_dim]
        cell_embs = features_smote[:, -self.cell_emb_dim:]
# Reconstruct the dataframe with oversampled data
df_smote = pd.DataFrame({
'Smiles': list(smiles_embs),
'Uniprot': list(poi_embs),
'E3 Ligase Uniprot': list(e3_embs),
'Cell Line Identifier': list(cell_embs),
self.active_label: labels_smote
})
self.data = df_smote
def fit_scaling(self, use_single_scaler: bool = False, **scaler_kwargs) -> dict:
""" Fit the scalers for the data and save them in the dataset class.
Args:
use_single_scaler (bool): Whether to use a single scaler for all features.
scaler_kwargs: Keyword arguments for the StandardScaler.
        Returns:
            StandardScaler | dict: The fitted scaler (when use_single_scaler is True) or a dict of per-feature scalers.
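        Example (illustrative sketch; `train_ds` and `val_ds` are assumed to be
        PROTAC_Dataset instances built from the same embedding dictionaries):

            scalers = train_ds.fit_scaling(use_single_scaler=False)
            train_ds.apply_scaling(scalers, use_single_scaler=False)
            val_ds.apply_scaling(scalers, use_single_scaler=False)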
"""
if use_single_scaler:
self.use_single_scaler = True
self.scaler = StandardScaler(**scaler_kwargs)
embeddings = np.hstack([
np.array(self.data['Smiles'].tolist()),
np.array(self.data['Uniprot'].tolist()),
np.array(self.data['E3 Ligase Uniprot'].tolist()),
np.array(self.data['Cell Line Identifier'].tolist()),
])
self.scaler.fit(embeddings)
return self.scaler
else:
self.use_single_scaler = False
scalers = {}
scalers['Smiles'] = StandardScaler(**scaler_kwargs)
scalers['Uniprot'] = StandardScaler(**scaler_kwargs)
scalers['E3 Ligase Uniprot'] = StandardScaler(**scaler_kwargs)
scalers['Cell Line Identifier'] = StandardScaler(**scaler_kwargs)
scalers['Smiles'].fit(np.stack(self.data['Smiles'].to_numpy()))
scalers['Uniprot'].fit(np.stack(self.data['Uniprot'].to_numpy()))
scalers['E3 Ligase Uniprot'].fit(np.stack(self.data['E3 Ligase Uniprot'].to_numpy()))
scalers['Cell Line Identifier'].fit(np.stack(self.data['Cell Line Identifier'].to_numpy()))
self.scaler = scalers
return scalers
    def apply_scaling(self, scalers: StandardScaler | Dict[str, StandardScaler], use_single_scaler: bool = False):
        """ Apply scaling to the data.
        Args:
            scalers (StandardScaler | dict): A single fitted scaler (when use_single_scaler is True) or a dict of per-feature scalers.
            use_single_scaler (bool): Whether to use a single scaler for all features.
"""
if use_single_scaler:
embeddings = np.hstack([
np.array(self.data['Smiles'].tolist()),
np.array(self.data['Uniprot'].tolist()),
np.array(self.data['E3 Ligase Uniprot'].tolist()),
np.array(self.data['Cell Line Identifier'].tolist()),
])
scaled_embeddings = scalers.transform(embeddings)
self.data = pd.DataFrame({
'Smiles': list(scaled_embeddings[:, :self.smiles_emb_dim]),
'Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim:self.smiles_emb_dim+self.protein_emb_dim]),
'E3 Ligase Uniprot': list(scaled_embeddings[:, self.smiles_emb_dim+self.protein_emb_dim:self.smiles_emb_dim+2*self.protein_emb_dim]),
'Cell Line Identifier': list(scaled_embeddings[:, -self.cell_emb_dim:]),
self.active_label: self.data[self.active_label]
})
else:
            # Skip scaling for features whose values are all binary (0 or 1),
            # e.g., bit-vector fingerprints; scale the remaining features per column.
for feature in ['Smiles', 'Uniprot', 'E3 Ligase Uniprot', 'Cell Line Identifier']:
feature_array = np.array(self.data[feature].tolist())
if np.all(np.isin(feature_array, [0, 1])):
continue
self.data[feature] = self.data[feature].apply(lambda x: scalers[feature].transform(x[np.newaxis, :])[0])
def get_numpy_arrays(self, component: Optional[str] = None) -> Tuple[np.ndarray, np.ndarray]:
""" Get the numpy arrays for the dataset.
Args:
component (str): The component to get the numpy arrays for. Defaults to None, i.e., get a single stacked array.
Returns:
tuple: The numpy arrays for the dataset. The first element is the input array, and the second element is the output array.
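        Example (illustrative sketch, e.g., for feeding a scikit-learn model):

            X, y = dataset.get_numpy_arrays()             # all features stacked horizontally
            X_fp, _ = dataset.get_numpy_arrays('Smiles')  # fingerprint features only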
"""
if component is not None:
X = np.array(self.data[component].tolist()).copy()
else:
X = np.hstack([
np.array(self.data['Smiles'].tolist()),
np.array(self.data['Uniprot'].tolist()),
np.array(self.data['E3 Ligase Uniprot'].tolist()),
np.array(self.data['Cell Line Identifier'].tolist()),
]).copy()
y = self.data[self.active_label].values.copy()
return X, y
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
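        """ Return one example as a dict with keys 'smiles_emb', 'poi_emb',
        'e3_emb', 'cell_emb', and 'active'.
        Disabled embeddings are replaced by zero vectors. With probability
        `shuffle_embedding_prob`, the POI, E3, and cell embeddings are shuffled
        among each other (or only POI and E3 are swapped when the cell embedding
        has a different dimension).
        """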
if 'smiles' in self.disabled_embeddings:
# Get a zero vector for the fingerprint
smiles_emb = np.zeros(self.smiles_emb_dim).astype(np.float32)
# TODO: Remove random sampling in the future
# # Uniformly sample a binary vector for the fingerprint
# smiles_emb = np.random.randint(0, 2, size=self.smiles_emb_dim).astype(np.float32)
# if not self.use_single_scaler and self.scaler is not None:
# smiles_emb = smiles_emb[np.newaxis, :]
# smiles_emb = self.scaler['Smiles'].transform(smiles_emb).flatten()
else:
smiles_emb = self.data['Smiles'].iloc[idx]
if 'poi' in self.disabled_embeddings:
poi_emb = np.zeros(self.protein_emb_dim).astype(np.float32)
# TODO: Remove random sampling in the future
# # Uniformly sample a vector for the protein
# poi_emb = np.random.rand(self.protein_emb_dim).astype(np.float32)
# if not self.use_single_scaler and self.scaler is not None:
# poi_emb = poi_emb[np.newaxis, :]
# poi_emb = self.scaler['Uniprot'].transform(poi_emb).flatten()
else:
poi_emb = self.data['Uniprot'].iloc[idx]
if 'e3' in self.disabled_embeddings:
e3_emb = np.zeros(self.protein_emb_dim).astype(np.float32)
# TODO: Remove random sampling in the future
# # Uniformly sample a vector for the E3 ligase
# e3_emb = np.random.rand(self.protein_emb_dim).astype(np.float32)
# if not self.use_single_scaler and self.scaler is not None:
# # Add extra dimension for compatibility with the scaler
# e3_emb = e3_emb[np.newaxis, :]
# e3_emb = self.scaler['E3 Ligase Uniprot'].transform(e3_emb)
# e3_emb = e3_emb.flatten()
else:
e3_emb = self.data['E3 Ligase Uniprot'].iloc[idx]
if 'cell' in self.disabled_embeddings:
cell_emb = np.zeros(self.cell_emb_dim).astype(np.float32)
# TODO: Remove random sampling in the future
# # Uniformly sample a vector for the cell line
# cell_emb = np.random.rand(self.cell_emb_dim).astype(np.float32)
# if not self.use_single_scaler and self.scaler is not None:
# cell_emb = cell_emb[np.newaxis, :]
# cell_emb = self.scaler['Cell Line Identifier'].transform(cell_emb).flatten()
else:
cell_emb = self.data['Cell Line Identifier'].iloc[idx]
# Shuffle the embeddings if the probability is met
if random.random() < self.shuffle_embedding_prob:
if self.protein_emb_dim == self.cell_emb_dim:
# Randomly shuffle the embeddings for POI, cell, and E3
embeddings = np.vstack([poi_emb, e3_emb, cell_emb])
np.random.shuffle(embeddings)
poi_emb, e3_emb, cell_emb = embeddings
else:
# Swap POI and E3 embeddings only, because of different dimensions
poi_emb, e3_emb = e3_emb, poi_emb
elem = {
'smiles_emb': smiles_emb,
'poi_emb': poi_emb,
'e3_emb': e3_emb,
'cell_emb': cell_emb,
'active': self.data[self.active_label].iloc[idx],
}
return elem
def get_datasets(
train_df: pd.DataFrame,
val_df: pd.DataFrame,
test_df: Optional[pd.DataFrame] = None,
    protein2embedding: Optional[Dict[str, np.ndarray]] = None,
    cell2embedding: Optional[Dict[str, np.ndarray]] = None,
    smiles2fp: Optional[Dict[str, np.ndarray]] = None,
smote_k_neighbors: int = 5,
active_label: str = 'Active',
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
use_single_scaler: Optional[bool] = None,
apply_scaling: bool = False,
shuffle_embedding_prob: float = 0.0,
) -> Tuple[PROTAC_Dataset, PROTAC_Dataset, Optional[PROTAC_Dataset]]:
""" Get the datasets for training the PROTAC model.
Args:
train_df (pd.DataFrame): The training data.
val_df (pd.DataFrame): The validation data.
test_df (pd.DataFrame): The test data.
protein2embedding (dict): Dictionary of protein embeddings.
cell2embedding (dict): Dictionary of cell line embeddings.
smiles2fp (dict): Dictionary of SMILES to fingerprint.
        smote_k_neighbors (int): The number of neighbors to use for SMOTE. If zero or None, SMOTE oversampling of the training set is disabled.
        active_label (str): The active label column.
        disabled_embeddings (list): The list of embeddings to disable.
        scaler (StandardScaler | dict): The scaler to use for the embeddings.
        use_single_scaler (bool): Whether to use a single scaler for all features.
        apply_scaling (bool): Whether to fit the scalers on the training set and apply them now. Defaults to False (the PyTorch Lightning model does that).
        shuffle_embedding_prob (float): The probability of shuffling the POI/E3/cell embeddings; applied to the training set only. Defaults to 0.0.
    Returns:
        tuple: The training, validation, and (optional) test PROTAC_Dataset objects.
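    Example (illustrative sketch; the DataFrames and embedding dictionaries are
    assumed to be prepared as in PROTAC_DataModule):

        train_ds, val_ds, test_ds = get_datasets(
            train_df, val_df, test_df,
            protein2embedding, cell2embedding, smiles2fp,
            smote_k_neighbors=5,
            apply_scaling=True,
        )
        train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)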
"""
if smote_k_neighbors:
oversampler = SMOTE(k_neighbors=smote_k_neighbors, random_state=42)
else:
oversampler = None
train_ds = PROTAC_Dataset(
train_df,
protein2embedding,
cell2embedding,
smiles2fp,
        use_smote=bool(smote_k_neighbors),
oversampler=oversampler,
active_label=active_label,
disabled_embeddings=disabled_embeddings,
scaler=scaler,
use_single_scaler=use_single_scaler,
shuffle_embedding_prob=shuffle_embedding_prob,
)
val_ds = PROTAC_Dataset(
val_df,
protein2embedding,
cell2embedding,
smiles2fp,
active_label=active_label,
disabled_embeddings=disabled_embeddings,
scaler=train_ds.scaler if train_ds.scaler is not None else scaler,
use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
)
train_scalers = None
if apply_scaling:
train_scalers = train_ds.fit_scaling(use_single_scaler=use_single_scaler)
val_ds.apply_scaling(train_scalers, use_single_scaler=use_single_scaler)
if test_df is not None:
test_ds = PROTAC_Dataset(
test_df,
protein2embedding,
cell2embedding,
smiles2fp,
active_label=active_label,
disabled_embeddings=disabled_embeddings,
scaler=train_scalers if apply_scaling else scaler,
use_single_scaler=train_ds.use_single_scaler if train_ds.use_single_scaler is not None else use_single_scaler,
)
if apply_scaling:
test_ds.apply_scaling(train_ds.scaler, use_single_scaler=use_single_scaler)
else:
test_ds = None
return train_ds, val_ds, test_ds
class PROTAC_DataModule(pl.LightningDataModule):
""" PyTorch Lightning DataModule for the PROTAC dataset.
TODO: Work in progress. It would be nice to wrap all information into a
single class, but it is not clear how to do it yet due to cross-validation
and the need to split the data into training, validation, and test sets
accordingly.
Args:
protac_csv_filepath (str): The path to the PROTAC CSV file.
protein2embedding_filepath (str): The path to the protein to embedding dictionary.
cell2embedding_filepath (str): The path to the cell line to embedding dictionary.
pDC50_threshold (float): The threshold for the pDC50 value to consider a PROTAC active.
Dmax_threshold (float): The threshold for the Dmax value to consider a PROTAC active.
use_smote (bool): Whether to use SMOTE for oversampling.
smote_k_neighbors (int): The number of neighbors to use for SMOTE.
active_label (str): The column containing the active/inactive information.
disabled_embeddings (list): The list of embeddings to disable.
scaler (StandardScaler | dict): The scaler to use for the embeddings.
use_single_scaler (bool): Whether to use a single scaler for all features.
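    Example (illustrative sketch; the file paths are placeholders and the
    train/validation/test DataFrames are assumed to be assigned before calling
    `setup`, e.g., using the split helpers below):

        dm = PROTAC_DataModule(
            protac_csv_filepath='data/PROTAC-Degradation-DB.csv',
            protein2embedding_filepath='data/protein2embedding.pkl',
            cell2embedding_filepath='data/cell2embedding.pkl',
        )
        test_idx = dm.get_random_split_indices(dm.active_df, test_split=0.2)
        dm.test_df = dm.active_df.loc[test_idx]
        train_val_df = dm.active_df[~dm.active_df.index.isin(test_idx)]
        dm.val_df = train_val_df.sample(frac=0.1, random_state=42)
        dm.train_df = train_val_df.drop(dm.val_df.index)
        dm.setup('fit')
        train_loader = dm.train_dataloader()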
"""
def __init__(
self,
protac_csv_filepath: str,
protein2embedding_filepath: str,
cell2embedding_filepath: str,
pDC50_threshold: float = 6.0,
Dmax_threshold: float = 0.6,
use_smote: bool = True,
smote_k_neighbors: int = 5,
active_label: str = 'Active',
disabled_embeddings: List[Literal['smiles', 'poi', 'e3', 'cell']] = [],
scaler: Optional[StandardScaler | Dict[str, StandardScaler]] = None,
use_single_scaler: Optional[bool] = None,
):
        super().__init__()
        # Load the PROTAC dataset
        self.protac_df = pd.read_csv(protac_csv_filepath)
# Map E3 Ligase Iap to IAP
self.protac_df['E3 Ligase'] = self.protac_df['E3 Ligase'].str.replace('Iap', 'IAP')
self.protac_df[active_label] = self.protac_df.apply(
lambda x: is_active(
x['DC50 (nM)'],
x['Dmax (%)'],
pDC50_threshold=pDC50_threshold,
Dmax_threshold=Dmax_threshold,
),
axis=1,
)
self.smiles2fp, self.protac_df = self.get_smiles2fp_and_avg_tanimoto(self.protac_df)
self.active_df = self.protac_df[self.protac_df[active_label].notna()].copy()
        # Store the remaining hyperparameters needed by setup()
        self.use_smote = use_smote
        self.smote_k_neighbors = smote_k_neighbors
        self.active_label = active_label
        self.disabled_embeddings = disabled_embeddings
        self.scaler = scaler
        self.use_single_scaler = use_single_scaler
        # Load the embedding dictionaries
        self.protein2embedding = load_protein2embedding(protein2embedding_filepath)
        self.cell2embedding = load_cell2embedding(cell2embedding_filepath)
    def setup(self, stage: str):
        # NOTE: self.train_df, self.val_df, and self.test_df are expected to be
        # assigned beforehand (e.g., via the split helpers below); see the class
        # docstring regarding the work-in-progress status of this module.
        self.train_ds, self.val_ds, self.test_ds = get_datasets(
self.train_df,
self.val_df,
self.test_df,
self.protein2embedding,
self.cell2embedding,
self.smiles2fp,
            smote_k_neighbors=self.smote_k_neighbors if self.use_smote else 0,
active_label=self.active_label,
disabled_embeddings=self.disabled_embeddings,
scaler=self.scaler,
use_single_scaler=self.use_single_scaler,
)
def train_dataloader(self):
return DataLoader(self.train_ds, batch_size=32, shuffle=True)
def val_dataloader(self):
return DataLoader(self.val_ds, batch_size=32)
def test_dataloader(self):
return DataLoader(self.test_ds, batch_size=32)
@staticmethod
def get_random_split_indices(active_df: pd.DataFrame, test_split: float) -> pd.Index:
""" Get the indices of the test set using a random split.
Args:
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
            test_split (float): The fraction of the active PROTACs to use as the test set.
Returns:
pd.Index: The indices of the test set.
"""
return active_df.sample(frac=test_split, random_state=42).index
@staticmethod
def get_e3_ligase_split_indices(active_df: pd.DataFrame) -> pd.Index:
""" Get the indices of the test set using the E3 ligase split.
Args:
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
Returns:
pd.Index: The indices of the test set.
"""
encoder = OrdinalEncoder()
active_df['E3 Group'] = encoder.fit_transform(active_df[['E3 Ligase']]).astype(int)
test_df = active_df[(active_df['E3 Ligase'] != 'VHL') & (active_df['E3 Ligase'] != 'CRBN')]
return test_df.index
@staticmethod
def get_smiles2fp_and_avg_tanimoto(protac_df: pd.DataFrame) -> tuple:
""" Get the SMILES to fingerprint dictionary and the average Tanimoto similarity.
Args:
protac_df (pd.DataFrame): The DataFrame containing the PROTACs.
Returns:
tuple: The SMILES to fingerprint dictionary and the average Tanimoto similarity.
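        Example (illustrative sketch; `protac_df` must contain a 'Smiles' column):

            smiles2fp, protac_df = PROTAC_DataModule.get_smiles2fp_and_avg_tanimoto(protac_df)
            protac_df['Avg Tanimoto'].describe()  # average distance to the other compounds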
"""
unique_smiles = protac_df['Smiles'].unique().tolist()
smiles2fp = {}
for smiles in unique_smiles:
smiles2fp[smiles] = get_fingerprint(smiles)
tanimoto_matrix = defaultdict(list)
fps = list(smiles2fp.values())
# Compute all-against-all Tanimoto similarity using BulkTanimotoSimilarity
for i, (smiles1, fp1) in enumerate(zip(unique_smiles, fps)):
similarities = DataStructs.BulkTanimotoSimilarity(fp1, fps[i:]) # Only compute for i to end, avoiding duplicates
for j, similarity in enumerate(similarities):
distance = 1 - similarity
tanimoto_matrix[smiles1].append(distance) # Store as distance
                if j != 0:  # Skip the self pair when filling the symmetric entry
                    tanimoto_matrix[unique_smiles[i + j]].append(distance)  # Symmetric filling
# Calculate average Tanimoto distance for each unique SMILES
avg_tanimoto = {k: np.mean(v) for k, v in tanimoto_matrix.items()}
protac_df['Avg Tanimoto'] = protac_df['Smiles'].map(avg_tanimoto)
smiles2fp = {s: np.array(fp) for s, fp in smiles2fp.items()}
return smiles2fp, protac_df
@staticmethod
def get_tanimoto_split_indices(
active_df: pd.DataFrame,
active_label: str,
test_split: float,
n_bins_tanimoto: int = 200,
) -> pd.Index:
""" Get the indices of the test set using the Tanimoto-based split.
Args:
            active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
            active_label (str): The column containing the active/inactive information.
            test_split (float): The fraction of the active PROTACs to use as the test set.
            n_bins_tanimoto (int): The number of bins for grouping compounds by average Tanimoto distance.
Returns:
pd.Index: The indices of the test set.
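        Example (illustrative sketch; `active_df` is assumed to carry the
        'Avg Tanimoto' column produced by get_smiles2fp_and_avg_tanimoto):

            test_idx = PROTAC_DataModule.get_tanimoto_split_indices(active_df, 'Active', test_split=0.2)
            test_df = active_df.loc[test_idx]
            train_val_df = active_df[~active_df.index.isin(test_idx)]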
"""
tanimoto_groups = pd.cut(active_df['Avg Tanimoto'], bins=n_bins_tanimoto).copy()
encoder = OrdinalEncoder()
        active_df['Tanimoto Group'] = encoder.fit_transform(tanimoto_groups.to_numpy().reshape(-1, 1)).astype(int)
        # Sort the groups so that samples with the highest average Tanimoto distance,
        # i.e., the least similar ones, are placed in the test set first
tanimoto_groups = active_df.groupby('Tanimoto Group')['Avg Tanimoto'].mean().sort_values(ascending=False).index
test_df = []
        # For each group, count the active and inactive entries. Add the group to
        # test_df only if: 1) the test_df length plus the group entries stays below
        # test_split (e.g., 20%) of the active_df length, and 2) the fraction of True
        # and False entries in the active_label column of test_df stays roughly 50%.
for group in tanimoto_groups:
group_df = active_df[active_df['Tanimoto Group'] == group]
if test_df == []:
test_df.append(group_df)
continue
num_entries = len(group_df)
num_active_group = group_df[active_label].sum()
num_inactive_group = num_entries - num_active_group
tmp_test_df = pd.concat(test_df)
num_entries_test = len(tmp_test_df)
num_active_test = tmp_test_df[active_label].sum()
num_inactive_test = num_entries_test - num_active_test
# Check if the group entries can be added to the test_df
if num_entries_test + num_entries < test_split * len(active_df):
                # Add anything at the beginning
if num_entries_test + num_entries < test_split / 2 * len(active_df):
test_df.append(group_df)
continue
# Be more selective and make sure that the percentage of active and
# inactive is balanced
if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
test_df.append(group_df)
test_df = pd.concat(test_df)
return test_df.index
@staticmethod
def get_target_split_indices(active_df: pd.DataFrame, active_label: str, test_split: float) -> pd.Index:
""" Get the indices of the test set using the target-based split.
Args:
active_df (pd.DataFrame): The DataFrame containing the active PROTACs.
active_label (str): The column containing the active/inactive information.
            test_split (float): The fraction of the active PROTACs to use as the test set.
Returns:
pd.Index: The indices of the test set.
"""
encoder = OrdinalEncoder()
active_df['Uniprot Group'] = encoder.fit_transform(active_df[['Uniprot']]).astype(int)
test_df = []
        # For each group, count the active and inactive entries. Add the group to
        # test_df only if: 1) the test_df length plus the group entries stays below
        # test_split (e.g., 20%) of the active_df length, and 2) the fraction of True
        # and False entries in the active_label column of test_df stays roughly 50%.
        # Start the loop from the groups containing the smallest number of entries.
for group in reversed(active_df['Uniprot'].value_counts().index):
group_df = active_df[active_df['Uniprot'] == group]
if test_df == []:
test_df.append(group_df)
continue
num_entries = len(group_df)
num_active_group = group_df[active_label].sum()
num_inactive_group = num_entries - num_active_group
tmp_test_df = pd.concat(test_df)
num_entries_test = len(tmp_test_df)
num_active_test = tmp_test_df[active_label].sum()
num_inactive_test = num_entries_test - num_active_test
# Check if the group entries can be added to the test_df
if num_entries_test + num_entries < test_split * len(active_df):
                # Add anything at the beginning
if num_entries_test + num_entries < test_split / 2 * len(active_df):
test_df.append(group_df)
continue
# Be more selective and make sure that the percentage of active and
# inactive is balanced
if (num_active_group + num_active_test) / (num_entries_test + num_entries) < 0.6:
if (num_inactive_group + num_inactive_test) / (num_entries_test + num_entries) < 0.6:
test_df.append(group_df)
test_df = pd.concat(test_df)
return test_df.index