# vumichien's picture
# Update app.py
# a8a8c18
from huggingface_hub import from_pretrained_keras
import gradio as gr
from rdkit import Chem, RDLogger
from rdkit.Chem.Draw import MolsToGridImage
import numpy as np
import tensorflow as tf
from tensorflow import keras
import pandas as pd
# Config
class Featurizer:
    """One-hot encoder driven by a mapping of feature names to allowed values.

    Every key of ``allowable_sets`` must be the name of a method on the
    instance (normally supplied by a subclass); that method is called with the
    object being encoded and returns the value of the feature. Every distinct
    allowed value gets its own slot in the output vector.

    Attributes:
        dim: Total length of the encoded vector.
        features_mapping: ``{feature_name: {allowed_value: slot_index}}``.
    """

    def __init__(self, allowable_sets):
        """Assign a contiguous range of slot indices to each feature.

        Args:
            allowable_sets: Mapping of feature name to a collection of the
                (mutually sortable) values that feature may take.
        """
        self.dim = 0
        self.features_mapping = {}
        for name, allowed_values in allowable_sets.items():
            # Sorting makes the slot assignment deterministic even though the
            # allowed values are typically given as unordered sets.
            ordered = sorted(allowed_values)
            self.features_mapping[name] = {
                value: index for index, value in enumerate(ordered, start=self.dim)
            }
            self.dim += len(ordered)

    def encode(self, inputs):
        """Return a ``(dim,)`` vector with 1.0 in the slot of each feature value.

        A feature whose extracted value is not among its allowed values is
        skipped, leaving all slots of that feature at zero.
        """
        output = np.zeros((self.dim,))
        for name_feature, feature_mapping in self.features_mapping.items():
            # The feature name doubles as the name of the extractor method.
            feature = getattr(self, name_feature)(inputs)
            index = feature_mapping.get(feature)
            if index is not None:
                output[index] = 1.0
        return output
class AtomFeaturizer(Featurizer):
    """Featurizer whose extractor methods read properties off an atom object.

    The method names below are the feature names expected as keys of
    ``allowable_sets``.
    """

    def __init__(self, allowable_sets):
        """Build the slot mapping exactly as the base class does."""
        super().__init__(allowable_sets)

    def symbol(self, atom):
        """Return the atom's symbol as reported by ``GetSymbol()``."""
        atom_symbol = atom.GetSymbol()
        return atom_symbol

    def n_valence(self, atom):
        """Return the atom's total valence as reported by ``GetTotalValence()``."""
        total_valence = atom.GetTotalValence()
        return total_valence

    def n_hydrogens(self, atom):
        """Return the atom's hydrogen count as reported by ``GetTotalNumHs()``."""
        hydrogen_count = atom.GetTotalNumHs()
        return hydrogen_count

    def hybridization(self, atom):
        """Return the lower-cased name of the atom's hybridization value."""
        hybridization_kind = atom.GetHybridization()
        return hybridization_kind.name.lower()
class BondFeaturizer(Featurizer):
    """Featurizer for bond objects, with one extra slot meaning "no bond".

    The method names ``bond_type`` and ``conjugated`` are the feature names
    expected as keys of ``allowable_sets``.
    """

    def __init__(self, allowable_sets):
        """Build the base mapping, then reserve one trailing slot."""
        super().__init__(allowable_sets)
        # The last position of the vector is set only when ``encode`` is
        # given ``None`` instead of a bond.
        self.dim = self.dim + 1

    def encode(self, bond):
        """Encode ``bond``; ``None`` yields a vector with only the last slot set."""
        if bond is not None:
            return super().encode(bond)
        no_bond = np.zeros((self.dim,))
        no_bond[-1] = 1.0
        return no_bond

    def bond_type(self, bond):
        """Return the lower-cased name of the bond's type."""
        kind = bond.GetBondType()
        return kind.name.lower()

    def conjugated(self, bond):
        """Return the value of ``GetIsConjugated()`` for the bond."""
        is_conjugated = bond.GetIsConjugated()
        return is_conjugated
# Module-level atom featurizer; each key names an AtomFeaturizer method and
# each set lists the values that get their own one-hot slot.
atom_featurizer = AtomFeaturizer(
    {
        "symbol": {
            "B", "Br", "C", "Ca", "Cl", "F", "H",
            "I", "N", "Na", "O", "P", "S",
        },
        "n_valence": set(range(7)),
        "n_hydrogens": set(range(5)),
        "hybridization": {"s", "sp", "sp2", "sp3"},
    }
)
# Module-level bond featurizer; each key names a BondFeaturizer method and
# each set lists the values that get their own one-hot slot.
bond_featurizer = BondFeaturizer(
    {
        "bond_type": {"aromatic", "double", "single", "triple"},
        "conjugated": {False, True},
    }
)
def molecule_from_smiles(smiles):
    """Parse a SMILES string into a sanitized RDKit molecule.

    Parsing is done without sanitization first so that a sanitization failure
    can be handled: if a sanitization step fails, sanitization is re-run with
    that step excluded.

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        The parsed molecule with stereochemistry assigned.

    Raises:
        ValueError: If ``smiles`` cannot be parsed into a molecule.
    """
    # MolFromSmiles(m, sanitize=True) should be equivalent to
    # MolFromSmiles(m, sanitize=False) -> SanitizeMol(m) -> AssignStereochemistry(m, ...)
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)
    if molecule is None:
        # Without this check the next call fails with an opaque argument
        # error from RDKit instead of telling the user what went wrong.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    # If sanitization is unsuccessful, catch the error, and try again without
    # the sanitization step that caused the error
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule
def graph_from_molecule(molecule):
    """Turn a molecule into (atom features, bond features, pair indices) arrays.

    For every atom, one self-loop entry is emitted (encoded as "no bond"),
    followed by one entry per neighbouring atom. ``pair_indices[k]`` holds the
    ``[atom index, other atom index]`` that ``bond_features[k]`` belongs to.
    """
    node_rows, edge_rows, edge_pairs = [], [], []

    for atom in molecule.GetAtoms():
        source = atom.GetIdx()
        node_rows.append(atom_featurizer.encode(atom))

        # Self-loop: the atom paired with itself, featurized as a missing bond.
        edge_pairs.append([source, source])
        edge_rows.append(bond_featurizer.encode(None))

        for neighbor in atom.GetNeighbors():
            target = neighbor.GetIdx()
            bond = molecule.GetBondBetweenAtoms(source, target)
            edge_pairs.append([source, target])
            edge_rows.append(bond_featurizer.encode(bond))

    return np.array(node_rows), np.array(edge_rows), np.array(edge_pairs)
def graphs_from_smiles(smiles_list):
    """Build graphs for several SMILES strings and pack them as ragged tensors.

    Returns a 3-tuple of ragged tensors: atom features (float32), bond
    features (float32) and pair indices (int64), one row per input string.
    """
    graphs = [
        graph_from_molecule(molecule_from_smiles(smiles)) for smiles in smiles_list
    ]
    atom_features_list = [atoms for atoms, _, _ in graphs]
    bond_features_list = [bonds for _, bonds, _ in graphs]
    pair_indices_list = [pairs for _, _, pairs in graphs]

    # Ragged tensors, because each molecule contributes a different number of
    # rows; they are consumed by tf.data.Dataset later on.
    atom_tensor = tf.ragged.constant(atom_features_list, dtype=tf.float32)
    bond_tensor = tf.ragged.constant(bond_features_list, dtype=tf.float32)
    pair_tensor = tf.ragged.constant(pair_indices_list, dtype=tf.int64)
    return (atom_tensor, bond_tensor, pair_tensor)
def prepare_batch(x_batch, y_batch):
    """Merges (sub)graphs of batch into a single global (disconnected) graph

    ``x_batch`` is unpacked as (atom features, bond features, pair indices);
    these are expected to be ragged tensors, since they are accessed through
    ``row_lengths`` and ``merge_dims``. ``y_batch`` is passed through as is.

    Returns ``((atom_features, bond_features, pair_indices,
    molecule_indicator), y_batch)``.
    """
    atom_features, bond_features, pair_indices = x_batch

    # How many atoms / bond entries each molecule of the batch contributes.
    atoms_per_molecule = atom_features.row_lengths()
    bonds_per_molecule = bond_features.row_lengths()

    # molecule_indicator[i] is the molecule that atom i of the merged graph
    # belongs to; it is used to gather (sub)graphs back out of the global
    # graph in the model later on.
    molecule_ids = tf.range(len(atoms_per_molecule))
    molecule_indicator = tf.repeat(molecule_ids, atoms_per_molecule)

    # Atom indices in pair_indices are local to each molecule. Shifting them
    # by the number of atoms in all preceding molecules makes them index into
    # the merged atom tensor.
    atom_offsets = tf.cumsum(atoms_per_molecule[:-1])
    offset_lookup = tf.repeat(molecule_ids[:-1], bonds_per_molecule[1:])
    per_bond_offset = tf.gather(atom_offsets, offset_lookup)
    # The first molecule needs no shift: prepend one zero per bond entry.
    per_bond_offset = tf.pad(per_bond_offset, [(bonds_per_molecule[0], 0)])

    flat_pairs = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    flat_pairs = flat_pairs + per_bond_offset[:, tf.newaxis]

    flat_atoms = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    flat_bonds = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()

    return (flat_atoms, flat_bonds, flat_pairs, molecule_indicator), y_batch
def MPNNDataset(X, y, batch_size=32, shuffle=False):
    """Wrap features ``X`` and targets ``y`` in a batched ``tf.data`` pipeline.

    Elements are optionally shuffled (buffer of 1024), grouped into batches
    of ``batch_size``, and each batch is merged by ``prepare_batch``.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        pipeline = pipeline.shuffle(1024)
    pipeline = pipeline.batch(batch_size)
    # NOTE(review): -1 is presumably the tf.data AUTOTUNE sentinel, letting
    # TensorFlow choose the parallelism and buffer size - confirm.
    pipeline = pipeline.map(prepare_batch, num_parallel_calls=-1)
    return pipeline.prefetch(buffer_size=-1)
# Pretrained Keras model fetched through huggingface_hub when this module is
# executed; `predict` below uses it for inference.
model = from_pretrained_keras("keras-io/MPNN-for-molecular-property-prediction")
def predict(smiles, label):
    """Run the model on one molecule and render it with its prediction.

    Args:
        smiles: SMILES string of the molecule to evaluate.
        label: Reference value supplied by the user; it is only shown next to
            the prediction in the image legend.

    Returns:
        Path of the PNG file (``'img.png'``) the rendered molecule was saved to.
    """
    molecules = [molecule_from_smiles(smiles)]
    # Named `graphs` rather than `input` so the builtin is not shadowed.
    graphs = graphs_from_smiles([smiles])
    label = pd.Series([label])
    test_dataset = MPNNDataset(graphs, label)
    y_pred = tf.squeeze(model.predict(test_dataset), axis=1)

    legends = [
        f"y_true/y_pred = {y_true}/{y_pred[i]:.2f}" for i, y_true in enumerate(label)
    ]
    image = MolsToGridImage(
        molecules,
        molsPerRow=1,
        legends=legends,
        returnPNG=False,
        subImgSize=(650, 650),
    )
    image.save("img.png")
    return 'img.png'
# Input widgets, in the order `predict` receives its arguments.
inputs = [
    gr.Textbox(label="Smiles of molecular"),
    gr.Textbox(label="Molecular permeability"),
]

# Example rows offered in the UI: [SMILES string, permeability label].
examples = [
    [
        "CO/N=C(C(=O)N[C@H]1[C@H]2SCC(=C(N2C1=O)C(O)=O)C)/c3csc(N)n3",
        0,
    ],
    [
        "[C@H]37[C@H]2[C@@]([C@](C(COC(C1=CC(=CC=C1)[S](O)(=O)=O)=O)=O)(O)[C@@H](C2)C)(C[C@@H]([C@@H]3[C@@]4(C(=CC5=C(C4)C=N[N]5C6=CC=CC=C6)C(=C7)C)C)O)C",
        1,
    ],
    [
        "CNCCCC2(C)C(=O)N(c1ccccc1)c3ccccc23",
        1,
    ],
    [
        "O.N[C@@H](C(=O)NC1C2CCC(=C(N2C1=O)C(O)=O)Cl)c3ccccc3",
        0,
    ],
    [
        "[C@@]4([C@@]3([C@H]([C@H]2[C@@H]([C@@]1(C(=CC(=O)CC1)CC2)C)[C@H](C3)O)CC4)C)(C(COC(C)=O)=O)OC(CC)=O",
        1,
    ],
    [
        "[C@]34([C@H](C2[C@@](F)([C@@]1(C(=CC(=O)C=C1)[C@@H](F)C2)C)[C@@H](O)C3)C[C@H]5OC(O[C@@]45C(=O)COC(=O)C6CC6)(C)C)C",
        1,
    ],
]
# Assemble the Gradio UI around `predict` and start serving it.
gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs="image",
    examples=examples,
    title="Predict blood-brain barrier permeability of molecular",
    description="Message-passing neural network (MPNN) for molecular property prediction",
    article="Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the keras example from <a href=\"https://keras.io/examples/graph/mpnn-molecular-graphs/\">Alexander Kensert</a>",
).launch(enable_queue=True, debug=False)