from huggingface_hub import from_pretrained_keras
import gradio as gr
from rdkit import Chem, RDLogger
from rdkit.Chem.Draw import MolsToGridImage
import numpy as np
import tensorflow as tf
from tensorflow import keras
import pandas as pd

# Config: one-hot featurizers for atoms and bonds


class Featurizer:
    """One-hot encode an object from a fixed set of named categorical features.

    Each key of ``allowable_sets`` must name a method on the subclass that
    extracts that feature from the input; every allowed value of every feature
    gets its own slot in the output vector. Values outside the allowed set
    leave that feature's slots all zero.
    """

    def __init__(self, allowable_sets):
        self.dim = 0
        self.features_mapping = {}
        for k, s in allowable_sets.items():
            # Sort so the slot assignment is deterministic across runs.
            s = sorted(list(s))
            self.features_mapping[k] = dict(zip(s, range(self.dim, len(s) + self.dim)))
            self.dim += len(s)

    def encode(self, inputs):
        """Return a float vector of length ``self.dim`` one-hot encoding ``inputs``."""
        output = np.zeros((self.dim,))
        for name_feature, feature_mapping in self.features_mapping.items():
            feature = getattr(self, name_feature)(inputs)
            if feature not in feature_mapping:
                # Unknown value: leave this feature's slots at zero.
                continue
            output[feature_mapping[feature]] = 1.0
        return output


class AtomFeaturizer(Featurizer):
    """Featurizer for RDKit atoms (symbol, valence, hydrogens, hybridization)."""

    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)

    def symbol(self, atom):
        return atom.GetSymbol()

    def n_valence(self, atom):
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        return atom.GetHybridization().name.lower()


class BondFeaturizer(Featurizer):
    """Featurizer for RDKit bonds (bond type, conjugation).

    One extra trailing slot is reserved to mark "no bond"; it is used for the
    self-loop edges added in ``graph_from_molecule``.
    """

    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)
        # Extra slot for the "no bond" (self-loop) indicator.
        self.dim += 1

    def encode(self, bond):
        """Encode ``bond``; ``None`` yields a vector with only the last slot set."""
        output = np.zeros((self.dim,))
        if bond is None:
            output[-1] = 1.0
            return output
        output = super().encode(bond)
        return output

    def bond_type(self, bond):
        return bond.GetBondType().name.lower()

    def conjugated(self, bond):
        return bond.GetIsConjugated()


atom_featurizer = AtomFeaturizer(
    allowable_sets={
        "symbol": {"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
        "n_valence": {0, 1, 2, 3, 4, 5, 6},
        "n_hydrogens": {0, 1, 2, 3, 4},
        "hybridization": {"s", "sp", "sp2", "sp3"},
    }
)

bond_featurizer = BondFeaturizer(
    allowable_sets={
        "bond_type": {"single", "double", "triple", "aromatic"},
        "conjugated": {True, False},
    }
)


def molecule_from_smiles(smiles):
    """Parse ``smiles`` into an RDKit molecule, sanitizing as much as possible.

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` at all.
    """
    # MolFromSmiles(m, sanitize=True) should be equivalent to
    # MolFromSmiles(m, sanitize=False) -> SanitizeMol(m) -> AssignStereochemistry(m, ...)
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)
    if molecule is None:
        # RDKit signals a parse failure by returning None; passing that on to
        # SanitizeMol would fail with an opaque Boost.Python argument error.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    # If sanitization is unsuccessful, catch the error, and try again without
    # the sanitization step that caused the error
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule


def graph_from_molecule(molecule):
    """Build (atom_features, bond_features, pair_indices) arrays for one molecule.

    Every atom gets a self-loop edge (encoded as "no bond"). Every bond appears
    twice, once per direction, because neighbors are visited from both atoms.
    """
    # Initialize graph
    atom_features = []
    bond_features = []
    pair_indices = []

    for atom in molecule.GetAtoms():
        atom_features.append(atom_featurizer.encode(atom))

        # Add self-loops
        pair_indices.append([atom.GetIdx(), atom.GetIdx()])
        bond_features.append(bond_featurizer.encode(None))

        for neighbor in atom.GetNeighbors():
            bond = molecule.GetBondBetweenAtoms(atom.GetIdx(), neighbor.GetIdx())
            pair_indices.append([atom.GetIdx(), neighbor.GetIdx()])
            bond_features.append(bond_featurizer.encode(bond))

    return np.array(atom_features), np.array(bond_features), np.array(pair_indices)


def graphs_from_smiles(smiles_list):
    """Featurize each SMILES string and stack the graphs into ragged tensors."""
    # Initialize graphs
    atom_features_list = []
    bond_features_list = []
    pair_indices_list = []

    for smiles in smiles_list:
        molecule = molecule_from_smiles(smiles)
        atom_features, bond_features, pair_indices = graph_from_molecule(molecule)

        atom_features_list.append(atom_features)
        bond_features_list.append(bond_features)
        pair_indices_list.append(pair_indices)

    # Convert lists to ragged tensors for tf.data.Dataset later on
    return (
        tf.ragged.constant(atom_features_list, dtype=tf.float32),
        tf.ragged.constant(bond_features_list, dtype=tf.float32),
        tf.ragged.constant(pair_indices_list, dtype=tf.int64),
    )


def prepare_batch(x_batch, y_batch):
    """Merges (sub)graphs of batch into a single global (disconnected) graph
    """
    atom_features, bond_features, pair_indices = x_batch

    # Obtain number of atoms and bonds for each graph (molecule)
    num_atoms = atom_features.row_lengths()
    num_bonds = bond_features.row_lengths()

    # Obtain partition indices (molecule_indicator), which will be used to
    # gather (sub)graphs from global graph in model later on
    molecule_indices = tf.range(len(num_atoms))
    molecule_indicator = tf.repeat(molecule_indices, num_atoms)

    # Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to
    # 'pair_indices' (and merging ragged tensors) actualizes the global graph
    gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
    increment = tf.cumsum(num_atoms[:-1])
    increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
    pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    pair_indices = pair_indices + increment[:, tf.newaxis]
    atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()

    return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch


def MPNNDataset(X, y, batch_size=32, shuffle=False):
    """Wrap featurized graphs ``X`` and targets ``y`` in a batched tf.data pipeline."""
    dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
    if shuffle:
        dataset = dataset.shuffle(1024)
    # -1 is tf.data.AUTOTUNE for both map parallelism and prefetch depth.
    return dataset.batch(batch_size).map(prepare_batch, -1).prefetch(-1)


model = from_pretrained_keras("keras-io/MPNN-for-molecular-property-prediction")


def predict(smiles, label):
    """Draw the molecule with its user-supplied label and the model prediction.

    Args:
        smiles: SMILES string typed by the user.
        label: "true" value typed by the user; it is only shown in the legend,
            because ``model.predict`` ignores the targets of the dataset.

    Returns:
        The file name ``'img.png'`` of the rendered image.

    Raises:
        ValueError: if ``smiles`` is blank or cannot be parsed by RDKit.
    """
    # Textbox input may carry stray whitespace (or be None if left untouched).
    smiles = (smiles or "").strip()
    molecule = molecule_from_smiles(smiles)
    if molecule.GetNumAtoms() == 0:
        # An empty molecule cannot be featurized into a graph the model accepts.
        raise ValueError("Please enter a non-empty SMILES string.")

    graph_inputs = graphs_from_smiles([smiles])
    labels = pd.Series([label])
    test_dataset = MPNNDataset(graph_inputs, labels)
    y_pred = tf.squeeze(model.predict(test_dataset), axis=1)

    legends = [f"y_true/y_pred = {labels[i]}/{y_pred[i]:.2f}" for i in range(len(labels))]
    MolsToGridImage(
        [molecule], molsPerRow=1, legends=legends, returnPNG=False, subImgSize=(650, 650)
    ).save("img.png")
    return 'img.png'


inputs = [
    gr.Textbox(label='Smiles of molecular'),
    gr.Textbox(label='Molecular permeability')
]

examples = [
    ["CO/N=C(C(=O)N[C@H]1[C@H]2SCC(=C(N2C1=O)C(O)=O)C)/c3csc(N)n3", 0],
    ["[C@H]37[C@H]2[C@@]([C@](C(COC(C1=CC(=CC=C1)[S](O)(=O)=O)=O)=O)(O)[C@@H](C2)C)(C[C@@H]([C@@H]3[C@@]4(C(=CC5=C(C4)C=N[N]5C6=CC=CC=C6)C(=C7)C)C)O)C", 1],
    ["CNCCCC2(C)C(=O)N(c1ccccc1)c3ccccc23", 1],
    ["O.N[C@@H](C(=O)NC1C2CCC(=C(N2C1=O)C(O)=O)Cl)c3ccccc3", 0],
    ["[C@@]4([C@@]3([C@H]([C@H]2[C@@H]([C@@]1(C(=CC(=O)CC1)CC2)C)[C@H](C3)O)CC4)C)(C(COC(C)=O)=O)OC(CC)=O", 1],
    ["[C@]34([C@H](C2[C@@](F)([C@@]1(C(=CC(=O)C=C1)[C@@H](F)C2)C)[C@@H](O)C3)C[C@H]5OC(O[C@@]45C(=O)COC(=O)C6CC6)(C)C)C", 1]
]

gr.Interface(
    fn=predict,
    title="Predict blood-brain barrier permeability of molecular",
    description = "Message-passing neural network (MPNN) for molecular property prediction",
    inputs=inputs,
    examples=examples,
    outputs="image",
    article = "Author: Vu Minh Chien. Based on the keras example from Alexander Kensert",
).launch(debug=False, enable_queue=True)