|
import streamlit as st |
|
import os |
|
import torch |
|
|
|
import numpy as np |
|
|
|
|
|
from AtomLenz import * |
|
|
|
from Object_Smiles import Objects_Smiles |
|
from rdkit import Chem |
|
from rdkit.Chem import Draw |
|
|
|
from robust_detection import utils |
|
from robust_detection.models.rcnn import RCNN |
|
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN |
|
|
|
import pytorch_lightning as pl |
|
from pytorch_lightning.loggers import WandbLogger |
|
from pytorch_lightning.loggers import CSVLogger |
|
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint |
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping |
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
from rdkit import Chem |
|
from rdkit.Chem import AllChem |
|
from rdkit import DataStructs |
|
from PIL import Image |
|
import matplotlib.pyplot as plt |
|
|
|
# App-wide heading, rendered once at the top of every page before the sidebar routing runs.
st.title(body="Atom Level Entity Detector")
|
|
|
def main_page(model_file):
    """Render the "About AtomLenz" page: a single markdown abstract.

    Args:
        model_file: Unused. Accepted only so that every page function in the
            sidebar dispatcher can be called with the selected checkpoint name.
    """
    about_text = """Identifying the chemical structure from a graphical representation, or image, of a molecule is a challenging pattern recognition task that would greatly benefit drug development. Yet, existing methods for chemical structure recognition do not typically generalize well, and show diminished effectiveness when confronted with domains where data is sparse, or costly to generate, such as hand-drawn molecule images. To address this limitation, we propose a new chemical structure recognition tool that delivers state-of-the-art performance and can adapt to new domains with a limited number of data samples and supervision. Unlike previous approaches, our method provides atom-level localization, and can therefore segment the image into the different atoms and bonds. Our model is the first model to perform OCSR with atom-level entity detection with only SMILES supervision. Through rigorous and extensive benchmarking, we demonstrate the preeminence of our chemical structure recognition approach in terms of data efficiency, accuracy, and atom-level entity prediction."""
    st.markdown(about_text)
|
# Outline colours for detection boxes, looked up by predicted class id in
# plot_bbox. A ten-colour palette is repeated twice to give 20 entries.
colors = [
    "magenta", "green", "blue", "red", "orange",
    "magenta", "peru", "azure", "slateblue", "plum",
] * 2
|
def plot_bbox(bbox_XYXY, label):
    """Draw one bounding box as a closed outline on the current pyplot axes.

    Args:
        bbox_XYXY: Box corners ``(xmin, ymin, xmax, ymax)``.
        label: Integer class id. It may be a Python int or, as passed by the
            detector outputs, presumably a 0-d integer tensor. It selects the
            outline colour from ``colors`` and names the plotted line.
    """
    # Normalise tensor labels so the line label reads "3" rather than "tensor(3)".
    label = int(label)
    xmin, ymin, xmax, ymax = bbox_XYXY
    plt.plot(
        [xmin, xmin, xmax, xmax, xmin],
        [ymin, ymax, ymax, ymin, ymin],
        # Wrap around the palette so class ids >= len(colors) do not raise IndexError.
        color=colors[label % len(colors)],
        label=str(label),
    )
|
|
|
# Detection confidence threshold applied to every detector's ROI heads.
_SCORE_THRESH = 0.65

# Checkpoint directories, in the order _load_models returns the models:
# atoms, bonds, stereo markers, charges.
_MODEL_DIRS = (
    "./models/atoms_model/",
    "./models/bonds_model/",
    "./models/stereos_model/",
    "./models/charges_model/",
)


@st.cache_resource(show_spinner="Loading AtomLenz models...")
def _load_models(modelfile):
    """Load the atom, bond, stereo and charge RCNN detectors for ``modelfile``.

    Cached with ``st.cache_resource`` so the four checkpoints are read once per
    checkpoint name instead of on every Streamlit rerun.

    NOTE(review): cached models are shared by all sessions. Concurrent
    ``trainer.predict`` calls on the same module are not guaranteed to be
    safe; confirm before serving many simultaneous users.

    Returns:
        Tuple ``(model_atom, model_bond, model_stereo, model_charge)``.
    """
    models = []
    for experiment_path in _MODEL_DIRS:
        model = RCNN.load_from_checkpoint(os.path.join(experiment_path, modelfile))
        model.model.roi_heads.score_thresh = _SCORE_THRESH
        models.append(model)
    return tuple(models)


def _render_detections(image, boxes, labels, column, caption):
    """Overlay ``boxes`` on ``image`` and show the result in a Streamlit column.

    A dedicated figure is created and closed for each call. This stops boxes
    from an earlier plot (or an earlier rerun) from leaking into this one. The
    PNG is rendered in memory rather than through a shared file on disk.
    """
    import io

    fig = plt.figure()
    plt.imshow(image, cmap="gray")
    for bbox, label in zip(boxes, labels):
        plot_bbox(bbox, label)
    plt.axis('off')
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    buffer.seek(0)
    column.image(Image.open(buffer), caption=caption, use_column_width=True)


def _build_molecule(atom_pred, bond_pred, stereo_pred, charge_pred):
    """Assemble the molecule description for one image from the detector outputs.

    Each ``*_pred`` is one element of the corresponding ``trainer.predict``
    result, read through its ``'boxes'``, ``'preds'`` (and, for atoms,
    ``'scores'``) entries.

    Returns:
        Dict with keys ``graph`` (atom x atom matrix of bond labels),
        ``atom_labels``, ``atom_boxes``, ``stereo_atoms`` and
        ``charge_atoms``, as passed to ``save_mol_to_file``.
    """
    charge_boxes = charge_pred['boxes'][0]
    charge_labels = charge_pred['preds'][0]
    # Only charge detections with label > 1 are considered.
    charge_mask = torch.where(charge_labels > 1)
    filtered_ch_labels = charge_labels[charge_mask]
    filtered_ch_boxes = charge_boxes[charge_mask]

    filtered_bboxes, filtered_labels = iou_filter_bboxes(
        atom_pred['boxes'][0], atom_pred['preds'][0], atom_pred['scores'][0]
    )
    n_atoms = len(filtered_bboxes)
    mol_graph = np.zeros((n_atoms, n_atoms))
    stereo_atoms = np.zeros(n_atoms)
    # Every atom defaults to charge label 1 (presumably the uncharged class)
    # unless an overlapping charge box assigns another label.
    charge_atoms = np.ones(n_atoms)

    for index, box_atom in enumerate(filtered_bboxes):
        for box_charge, label_charge in zip(filtered_ch_boxes, filtered_ch_labels):
            if bb_box_intersects(box_atom, box_charge) == 1:
                charge_atoms[index] = label_charge

    for bond_box, label_bond in zip(bond_pred['boxes'][0], bond_pred['preds'][0]):
        # Bond labels <= 1 are not treated as bonds.
        if label_bond <= 1:
            continue
        # Grow the bond box in steps of 5 (below a margin of 80) until it
        # touches at least two atom boxes.
        result = []
        limit = 0
        while result.count(1) < 2 and limit < 80:
            grown_box = [bond_box[0] - limit, bond_box[1] - limit,
                         bond_box[2] + limit, bond_box[3] + limit]
            result = [bb_box_intersects(atom_box, grown_box) for atom_box in filtered_bboxes]
            limit += 5
        indices = [i for i, hit in enumerate(result) if hit == 1]
        if len(indices) == 2:
            first, second = indices
        elif len(indices) > 2:
            # More than two candidate atoms: let dist_filter_bboxes pick the pair.
            cand_indices = dist_filter_bboxes(filtered_bboxes[indices, :])
            first, second = indices[cand_indices[0]], indices[cand_indices[1]]
        else:
            continue
        mol_graph[first, second] = label_bond
        mol_graph[second, first] = label_bond

    # Stereo detections are only resolved when some bond label exceeds 4.
    if np.any(mol_graph > 4):
        for stereo_box in stereo_pred['boxes'][0]:
            hits = [i for i, atom_box in enumerate(filtered_bboxes)
                    if bb_box_intersects(atom_box, stereo_box) == 1]
            if len(hits) == 1:
                stereo_atoms[hits[0]] = 1

    return {
        'graph': mol_graph,
        'atom_labels': filtered_labels,
        'atom_boxes': filtered_bboxes,
        'stereo_atoms': stereo_atoms,
        'charge_atoms': charge_atoms,
    }


def _molecule_to_smiles(molecule):
    """Convert a molecule description to SMILES through a temporary mol file.

    Detected chemistry problems are repaired with ``solve_mol_problems`` and
    sanitization is retried once. A second sanitization failure is tolerated
    on a best-effort basis.

    NOTE(review): the fixed 'molfile' path is shared by every session.

    Returns:
        ``(pred_smiles, problematic)``. ``pred_smiles`` is ``""`` when RDKit
        cannot produce a SMILES string. ``problematic`` is 1 when repairs were
        needed or something failed, otherwise 0.
    """
    save_mol_to_file(molecule, 'molfile')
    mol = Chem.MolFromMolFile('molfile', sanitize=False)
    if mol is None:
        return "", 1
    problematic = 0
    try:
        problems = Chem.DetectChemistryProblems(mol)
        if len(problems) > 0:
            mol = solve_mol_problems(mol, problems)
            problematic = 1
        try:
            Chem.SanitizeMol(mol)
        except Exception:
            problems = Chem.DetectChemistryProblems(mol)
            if len(problems) > 0:
                mol = solve_mol_problems(mol, problems)
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                pass  # best effort: keep the unsanitized molecule
    except Exception:
        problematic = 1
    try:
        pred_smiles = Chem.MolToSmiles(mol)
    except Exception:
        pred_smiles = ""
        problematic = 1
    return pred_smiles, problematic


def atomlenz(modelfile):
    """Render the "Predict Atom-Level Entities" page.

    Loads the four detectors for ``modelfile`` (cached) and asks for a PNG
    upload. It then runs atom, bond, stereo and charge detection, shows the
    atom and bond boxes side by side, and rebuilds the molecule through RDKit
    to display the recognized structure and its SMILES.

    Args:
        modelfile: Checkpoint file name looked up in each model directory.
    """
    model_atom, model_bond, model_stereo, model_charge = _load_models(modelfile)
    dataset = Objects_Smiles(data_path="./uploads/", batch_size=1)

    image_file = st.file_uploader("Upload a chemical structure candidate image", type=['png'])
    if image_file is None:
        return

    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # The dataset reads from ./uploads/; the upload is stored as images/0.png.
    # NOTE(review): this fixed path is shared by every session.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    dataset.prepare_data()

    trainer = pl.Trainer(logger=False)
    st.toast('Predicting atoms,bonds,charges,..., please wait')
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
    st.toast('Done')

    _render_detections(image, atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0],
                       col1, "Atom entities")
    _render_detections(image, bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0],
                       col2, "Bond entities")

    predictions_list = []
    for image_idx, bonds in enumerate(bond_preds):
        molecule = _build_molecule(atom_preds[image_idx], bonds,
                                   stereo_preds[image_idx], charges_preds[image_idx])
        pred_smiles, problematic = _molecule_to_smiles(molecule)
        predictions_list.append([image_idx, pred_smiles, problematic])

    for pred in predictions_list:
        smiles = pred[1]
        # Chem.MolFromSmiles("") returns an empty molecule rather than None, so
        # a failed conversion ("") must be treated as "no structure" explicitly.
        m = Chem.MolFromSmiles(smiles) if smiles else None
        if m is None:
            st.write("No valid chemical structure recognized")
        else:
            img = Draw.MolToImage(m)
            st.image(img, caption=f"Recognized structure: {smiles}", use_column_width=False, width=400)
        print(pred)
|
|
|
|
|
|
|
|
|
# Sidebar task label -> page function. Every page function is called with the
# selected checkpoint file name.
page_to_funcs = {
    "Predict Atom-Level Entities": atomlenz,
    "About AtomLenz": main_page,
}

# Sidebar model label -> checkpoint file name inside each model directory.
# The selectbox options below are taken from these keys, in this order.
model_dict = {
    "AtomLenz trained on synthetic data (default)": "synthetic.ckpt",
    "AtomLenz for hand-drawn images": "real.ckpt",
    # No dedicated checkpoint yet: falls back to the synthetic model.
    "ChemExpert (not available yet)": "synthetic.ckpt",
}

sel_page = st.sidebar.selectbox("Select task", page_to_funcs.keys())
st.sidebar.markdown('')

selected_model = st.sidebar.selectbox(
    "Select the AtomLenz model to load",
    tuple(model_dict),
)
model_file = model_dict[selected_model]

page_to_funcs[sel_page](model_file)
|
|
|
|
|
|
|
|