"""
This files includes functions to create molecular descriptors.
As an input it takes a list of SMILES and it outputs a numpy array of descriptors.
"""
import json
import argparse
import numpy as np
from datasets import load_dataset
from rdkit import Chem, DataStructs
from rdkit.Chem import Descriptors, rdFingerprintGenerator, MACCSkeys
from rdkit.Chem.rdchem import Mol
from .utils import (
TASKS,
KNOWN_DESCR,
HF_TOKEN,
USED_200_DESCR,
Standardizer,
)
parser = argparse.ArgumentParser(
description="Data preprocessing script for the Tox21 dataset"
)
parser.add_argument(
    "--save_folder",
    type=str,
    default="data/",
    help="Folder to write the preprocessed .npz file to",
)
parser.add_argument(
    "--use_hf",
    type=int,
    default=0,
    help="If nonzero, convert dataset splits to pandas before extracting labels",
)
parser.add_argument(
    "--tox_smarts_filepath",
    type=str,
    default="assets/tox_smarts.json",
    help="Path to the JSON file with the tox SMARTS patterns",
)
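# Example invocation (sketch; the actual module path depends on the package layout,
# since this file imports .utils relatively -- "preprocessing" here is hypothetical):
#   python -m preprocessing.create_descriptors --save_folder data/ --use_hf 0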
def create_cleaned_mol_objects(smiles: list[str]) -> tuple[list[Mol], np.ndarray]:
    """This function creates cleaned RDKit mol objects from a list of SMILES.
    Args:
        smiles (list[str]): list of SMILES
    Returns:
        list[Mol]: list of cleaned molecules
        np.ndarray[bool]: mask that contains False at index `i` if the molecule in
            `smiles` at index `i` could not be parsed or cleaned and was removed.
    """
    sm = Standardizer(canon_taut=True)
    clean_mol_mask = list()
    mols = list()
    for smile in smiles:
        # MolFromSmiles returns None for unparseable SMILES
        mol = Chem.MolFromSmiles(smile)
        standardized_mol = None
        if mol is not None:
            standardized_mol, _ = sm.standardize_mol(mol)
        is_cleaned = standardized_mol is not None
        clean_mol_mask.append(is_cleaned)
        if not is_cleaned:
            continue
        # Round-trip through canonical SMILES to get a canonical mol object
        can_mol = Chem.MolFromSmiles(Chem.MolToSmiles(standardized_mol))
        mols.append(can_mol)
    return mols, np.array(clean_mol_mask)
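# Example (sketch): an unparseable SMILES is dropped and flagged in the mask, e.g.
#   mols, mask = create_cleaned_mol_objects(["CCO", "not-a-smiles"])
#   len(mols) == 1 and mask.tolist() == [True, False]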
def create_ecfp_fps(mols: list[Mol], radius=None, fpsize=None) -> np.ndarray:
    """This function creates ECFP fingerprints for a list of molecules.
    Args:
        mols (list[Mol]): list of molecules
        radius (int, optional): Morgan radius; RDKit's default is used if None
        fpsize (int, optional): fingerprint length; RDKit's default is used if None
    Returns:
        np.ndarray: ECFP fingerprints of molecules
    """
    ecfps = list()
    kwargs = {}
    if fpsize is not None:
        kwargs["fpSize"] = fpsize
    if radius is not None:
        kwargs["radius"] = radius
    # The generator is stateless across molecules, so create it once
    gen = rdFingerprintGenerator.GetMorganGenerator(countSimulation=True, **kwargs)
    for mol in mols:
        fp_sparse_vec = gen.GetCountFingerprint(mol)
        fp = np.zeros((0,), np.int8)
        DataStructs.ConvertToNumpyArray(fp_sparse_vec, fp)
        ecfps.append(fp)
    return np.array(ecfps)
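# Example (sketch): with fpsize=8192 the result has one row per molecule, e.g.
#   fps = create_ecfp_fps(mols, radius=3, fpsize=8192)  # fps.shape == (len(mols), 8192)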
def create_maccs_keys(mols: list[Mol]) -> np.ndarray:
    """This function creates the 167-bit MACCS keys for a list of molecules."""
    maccs = [MACCSkeys.GenMACCSKeys(x) for x in mols]
    return np.array(maccs)
def get_tox_patterns(filepath: str) -> list:
    """This function parses the tox SMARTS patterns defined in tox_smarts.json.
    Args:
        filepath (str): path to the JSON file with the SMARTS definitions
    Returns:
        list: one tuple per pattern: (parsed sub-patterns, whether to negate each
            match, whether to join the matches with any instead of all)
    """
    # load patterns
    with open(filepath) as f:
        smarts_list = [s[1] for s in json.load(f)]
    # Code does not work for patterns that contain both "AND" and "OR"
    assert len([s for s in smarts_list if ("AND" in s) and ("OR" in s)]) == 0
    # Chem.MolFromSmarts takes a long time, so it pays off to parse all the SMARTS
    # first and then reuse them for all molecules. This gives a huge speedup over
    # existing code.
    # Each entry: a list of patterns, whether to negate each match result, and how
    # to join them to obtain one boolean value.
all_patterns = []
for smarts in smarts_list:
        patterns = []  # list of parsed SMARTS patterns
        # one flag per sub-SMARTS; True means the match result is negated later
        negations = []
if " AND " in smarts:
smarts = smarts.split(" AND ")
merge_any = False # If an ' AND ' is found all 'subsmarts' have to match
else:
            # If there is an ' OR ' present it's enough if any of the 'subsmarts' match.
            # This branch also handles SMARTS where neither ' OR ' nor ' AND ' occurs.
smarts = smarts.split(" OR ")
merge_any = True
# for all subsmarts check if they are preceded by 'NOT '
for s in smarts:
neg = s.startswith("NOT ")
if neg:
s = s[4:]
patterns.append(Chem.MolFromSmarts(s))
negations.append(neg)
all_patterns.append((patterns, negations, merge_any))
return all_patterns
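# Sketch of the assumed tox_smarts.json layout: only the second element of each entry
# is read, so entries are taken to be [name, smarts] pairs, e.g. (hypothetical values)
#   [["halide", "[Cl,Br,I]C"], ["nitro_not_amide", "NOT C(=O)N AND [N+](=O)[O-]"]]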
def create_tox_features(mols: list[Mol], patterns: list) -> np.ndarray:
    """Matches the tox patterns against each molecule. Returns a boolean array
    of shape (n_mols, n_patterns)."""
    tox_data = []
    for mol in mols:
        mol_features = []
        for patts, negations, merge_any in patterns:
            matches = [mol.HasSubstructMatch(p) for p in patts]
            # XOR with the negation flag: a negated sub-pattern counts as matched
            # when the substructure is absent
            matches = [m != n for m, n in zip(matches, negations)]
            # ' OR ' patterns need any sub-pattern to hit, ' AND ' patterns need all
            if merge_any:
                pres = any(matches)
            else:
                pres = all(matches)
            mol_features.append(pres)
        tox_data.append(np.array(mol_features))
    return np.array(tox_data)
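# Worked example (sketch): a hypothetical pattern "NOT c1ccccc1 AND [OH]" parses into
# patterns [c1ccccc1, [OH]] with negations [True, False] and merge_any=False, so a
# molecule scores True iff it has no benzene ring and has a hydroxyl group.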
def create_rdkit_descriptors(mols: list[Mol]) -> np.ndarray:
    """This function creates RDKit descriptors for a list of molecules.
    Args:
        mols (list[Mol]): list of molecules
    Returns:
        np.ndarray: RDKit descriptors of molecules
    """
    rdkit_descriptors = list()
    for mol in mols:
        descrs = []
        # Descriptors._descList holds (name, function) pairs for all RDKit descriptors
        for _, descr_calc_fn in Descriptors._descList:
            descrs.append(descr_calc_fn(mol))
        descrs = np.array(descrs)
        # keep only the 200 descriptors selected via utils.USED_200_DESCR
        descrs = descrs[USED_200_DESCR]
        rdkit_descriptors.append(descrs)
    return np.array(rdkit_descriptors)
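# Example (sketch): assuming USED_200_DESCR selects 200 descriptor indices,
#   create_rdkit_descriptors(mols).shape == (len(mols), 200)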
def create_descriptors(
    smiles: list[str],
    tox_smarts_filepath: str = "assets/tox_smarts.json",
):
    print(f"Preprocess {len(smiles)} molecules")
    # Create cleaned rdkit mol objects
    mols, clean_mol_mask = create_cleaned_mol_objects(smiles)
    print("Cleaned molecules")
    tox_patterns = get_tox_patterns(tox_smarts_filepath)
# Create fingerprints and descriptors
ecfps = create_ecfp_fps(mols, radius=3, fpsize=8192)
print("Created ECFP fingerprints")
tox = create_tox_features(mols, tox_patterns)
print("Created Tox features")
maccs = create_maccs_keys(mols)
print("Created MACCS keys")
rdkit_descrs = create_rdkit_descriptors(mols)
print("Created RDKit descriptors")
features = np.concatenate((ecfps, tox, maccs, rdkit_descrs), axis=1)
return features, clean_mol_mask
def fill(features, mask, value=np.nan):
    """Expands `features` back to one row per input molecule; rows whose molecule
    was removed during cleaning (mask == False) are filled with `value`."""
    n_mols = len(mask)
    n_features = features.shape[1]
    data = np.full((n_mols, n_features), value)
    # features only has rows for cleaned molecules, i.e. where mask is True
    data[mask] = features
    return data
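# Example (sketch): with mask.tolist() == [True, False, True] and a (2, n) feature
# matrix, fill(features, mask) returns a (3, n) matrix whose middle row is all NaN.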
def preprocess_tox21():
    splits = ["train", "validation"]
    ds = load_dataset("tschouis/tox21", token=HF_TOKEN)
    all_features, all_labels, all_split = [], [], []
    for split in splits:
        print(f"Preprocess {split} molecules")
        smiles = list(ds[split]["smiles"])
        features, mol_mask = create_descriptors(
            smiles,
            tox_smarts_filepath=args.tox_smarts_filepath,
        )
        print(f"Created {features.shape[1]} descriptors for {len(smiles)} molecules.")
        print(f"{len(mol_mask) - sum(mol_mask)} molecules removed during cleaning.")
        # Re-expand to one row per input molecule so the features stay aligned with
        # the labels; molecules removed during cleaning get NaN rows.
        features = fill(features, mol_mask)
        datasplit = ds[split].to_pandas() if args.use_hf else ds[split]
        labels = []
        for task in TASKS:
            # np.asarray handles both pandas Series and plain dataset columns
            labels.append(np.asarray(datasplit[task]))
        labels = np.stack(labels, axis=1)
        all_features.append(features)
        all_labels.append(labels)
        all_split.append([split] * len(smiles))
    save_path = f"{args.save_folder}/tox21_data.npz"
    with open(save_path, "wb") as f:
        np.savez_compressed(
            f,
            features=np.concatenate(all_features),
            labels=np.concatenate(all_labels),
            splits=np.concatenate(all_split),
        )
print(f"Saved preprocessed data to {save_path}")
if __name__ == "__main__":
args = parser.parse_args()
preprocess_tox21()