# mhnfs/src/data_preprocessing/create_descriptors.py
# Author: Tschoui
# Commit cf004a6: move project from private to public space
"""
This file includes all necessary code to preprocess molecules (assumed to be in SMILES
format) and create descriptors which can be fed into MHNfs.
"""
#---------------------------------------------------------------------------------------
# Dependencies
import pickle
from pathlib import Path
from typing import List, Union

import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, rdFingerprintGenerator
from rdkit.Chem.rdchem import Mol

from src.data_preprocessing.constants import USED_200_DESCR
from src.data_preprocessing.utils import Standardizer
#---------------------------------------------------------------------------------------
# Define main function
def preprocess_molecules(input_molecules: Union[str, List[str], pd.DataFrame]):
    """
    Preprocess molecules (assumed to be in SMILES format) and create the descriptors
    which can be fed into MHNfs.

    Args:
        input_molecules: A comma-separated SMILES string, a list of SMILES strings,
            or a pandas DataFrame with a SMILES column (see `handle_inputs`).

    Returns:
        The result of `scaler.transform` applied to the ECFP count fingerprints
        concatenated (along axis 1) with the quantile-transformed RDKit descriptors,
        one row per input molecule.
    """
    # Locate the project root relative to this file. Using pathlib instead of
    # splitting on "/" keeps this working with Windows path separators and when
    # `__file__` is a relative path.
    assets_dir = (
        Path(__file__).resolve().parents[2] / "assets" / "data_preprocessing_objects"
    )

    # Load needed objects
    # NOTE: pickle.load executes arbitrary code from the file; these must remain
    # trusted, project-shipped assets and never user-supplied files.
    with open(assets_dir / "scaler_fitted.pkl", "rb") as fl:
        scaler = pickle.load(fl)  # presumably a fitted scikit-learn-style scaler
    with open(assets_dir / "ecdfs.pkl", "rb") as fl:
        ecdfs = pickle.load(fl)  # indexed per descriptor column by create_quantils

    # Ensure that the input is an iterable of SMILES strings
    input_smiles = handle_inputs(input_molecules)

    # Create cleaned RDKit mol objects
    cleaned_mols = create_cleaned_mol_objects(input_smiles)

    # Create fingerprints and descriptors
    ecfps = create_ecfp_fps(cleaned_mols)
    rdkit_descrs = create_rdkit_descriptors(cleaned_mols)

    # Create quantiles
    rdkit_descr_quantils = create_quantils(rdkit_descrs, ecdfs)

    # Concatenate features
    raw_features = np.concatenate((ecfps, rdkit_descr_quantils), axis=1)

    # Normalize feature vectors
    normalized_features = scaler.transform(raw_features)

    # Return feature vectors
    return normalized_features
#---------------------------------------------------------------------------------------
# Define helper functions
def handle_inputs(input_molecules: Union[str, List[str], pd.DataFrame]):
    """
    Normalize the supported input formats into a list of SMILES strings.

    Args:
        input_molecules: One of
            * a list, which is returned unchanged (its elements are not validated),
            * a pandas DataFrame containing a column whose name equals "smiles"
              when lower-cased,
            * a string of comma-separated SMILES.

    Returns:
        A list of SMILES strings.

    Raises:
        ValueError: If a DataFrame without a SMILES column is passed.
        TypeError: If the input is none of the supported types.
    """
    if isinstance(input_molecules, list):
        return input_molecules

    if isinstance(input_molecules, pd.DataFrame):
        # Look the column up case-insensitively instead of renaming the columns, so
        # the caller's DataFrame is left untouched. str() keeps non-string column
        # labels (e.g. integers) from raising AttributeError.
        for position, column in enumerate(input_molecules.columns):
            if str(column).lower() == "smiles":
                return list(input_molecules.iloc[:, position].values)
        raise ValueError(("Input DataFrame must have a column named 'Smiles'."))

    if isinstance(input_molecules, str):
        # Comma-separated SMILES: trim surrounding whitespace, drop empty entries.
        candidates = (smiles.strip() for smiles in input_molecules.split(","))
        return [smiles for smiles in candidates if smiles != ""]

    raise TypeError(("Input molecules must be a string,a list of strings or a "
                     "pandas DataFrame."))
def create_ecfp_fps(mols: List[Mol]) -> np.ndarray:
    """
    Create Morgan (ECFP) count fingerprints for a list of molecules.

    Each molecule is fingerprinted on its own via
    `rdFingerprintGenerator.GetCountFPs` and converted into an int8 numpy vector;
    the vectors are then stacked into one array, one entry per molecule.
    """
    def _count_fingerprint(single_mol):
        # GetCountFPs works on a sequence, so wrap the molecule and unwrap the result
        sparse_counts = rdFingerprintGenerator.GetCountFPs(
            [single_mol], fpType=rdFingerprintGenerator.MorganFP
        )[0]
        # Target buffer handed to RDKit, which fills it in place
        dense_counts = np.zeros((0,), np.int8)
        DataStructs.ConvertToNumpyArray(sparse_counts, dense_counts)
        return dense_counts

    return np.array([_count_fingerprint(mol) for mol in mols])
def create_rdkit_descriptors(mols: List[Mol]) -> np.ndarray:
    """
    Compute RDKit descriptors for a list of molecules.

    For every molecule all descriptor functions listed in `Descriptors._descList`
    are evaluated, and the resulting vector is then indexed with `USED_200_DESCR`
    to keep only the descriptors used by this project. One row per molecule is
    returned.
    """
    selected_rows = []
    for mol in mols:
        # Evaluate every registered descriptor function on this molecule
        all_values = np.array(
            [descr_calc_fn(mol) for _, descr_calc_fn in Descriptors._descList]
        )
        # Keep only the subset selected by the project-level index constant
        selected_rows.append(all_values[USED_200_DESCR])
    return np.array(selected_rows)
def create_quantils(raw_features: np.ndarray, ecdfs: list) -> np.ndarray:
    """
    Map every feature column to quantiles using its column-specific ECDF.

    Args:
        raw_features: 2D array; columns are features, rows are samples.
        ecdfs: Sequence of callables. `ecdfs[i]` is called with the 1D values of
            column `i` and must return one value per input value.

    Returns:
        Array with the same shape as `raw_features` holding the ECDF outputs. The
        result is always floating point: a floating input dtype is kept, any other
        dtype (e.g. integers) is promoted to float64 so that the quantiles are not
        truncated on assignment.
    """
    # np.zeros_like would inherit an integer dtype and silently truncate the
    # ECDF outputs, so pick a floating dtype explicitly.
    if np.issubdtype(raw_features.dtype, np.floating):
        out_dtype = raw_features.dtype
    else:
        out_dtype = np.float64
    quantils = np.zeros(raw_features.shape, dtype=out_dtype)

    for column in range(raw_features.shape[1]):
        raw_values = raw_features[:, column].reshape(-1)
        ecdf = ecdfs[column]
        quantils[:, column] = ecdf(raw_values)

    return quantils
def create_cleaned_mol_objects(smiles: List[str]) -> List[Mol]:
    """
    Create cleaned RDKit mol objects from a list of SMILES.

    Every SMILES is parsed, passed through the project `Standardizer`
    (constructed with `canon_taut=True`), written back to SMILES and parsed again,
    so that the returned molecule is built from the SMILES RDKit emits for the
    standardized structure.

    Args:
        smiles: SMILES strings to convert.

    Returns:
        One RDKit mol object per input SMILES, in input order.

    Raises:
        ValueError: If a SMILES string cannot be parsed by RDKit.
    """
    sm = Standardizer(canon_taut=True)

    mols = list()
    for smile in smiles:
        mol = Chem.MolFromSmiles(smile)
        # RDKit signals an unparsable SMILES by returning None; fail here with the
        # offending input instead of passing None on to the standardizer.
        if mol is None:
            raise ValueError(f"Could not parse SMILES string: {smile!r}")

        # The second return value of standardize_mol is not used here
        standardized_mol, _ = sm.standardize_mol(mol)

        # Round-trip through SMILES to rebuild the molecule from RDKit's output
        can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
        mols.append(can_mol)
    return mols
#---------------------------------------------------------------------------------------