File size: 11,420 Bytes
7a223e7 6adb0d1 b2c3eed 6adb0d1 d0f68bc 4986573 7ab2ec8 6adb0d1 67828bb 648d955 0f3c07e 5417a1b 0f3c07e 5417a1b 4ea7e69 648d955 7a223e7 0f3c07e 663db57 0f3c07e d88e063 0f3c07e d88e063 0f3c07e d88e063 0f3c07e 09fa344 0f3c07e 67828bb 0f3c07e 67828bb 0f3c07e d1ca99b 0f3c07e d1ca99b 0f3c07e b2c3eed 0f3c07e b2c3eed 0f3c07e b2c3eed 0f3c07e 4986573 0fac33f 4986573 7ab2ec8 a51d22b 0f3c07e 4986573 0f3c07e 663db57 0f3c07e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles
from rdkit import Chem
from rdkit.Chem import Draw
#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt
st.title("Atom Level Entity Detector")


def main_page(model_file):
    """Render the "About AtomLenz" page: a short description of the method.

    Args:
        model_file: ignored here; every page in the sidebar dispatch table is
            called with the selected checkpoint filename, so this page accepts
            it to keep a uniform signature.
    """
    about_text = """Identifying the chemical structure from a graphical representation, or image, of a molecule is a challenging pattern recognition task that would greatly benefit drug development. Yet, existing methods for chemical structure recognition do not typically generalize well, and show diminished effectiveness when confronted with domains where data is sparse, or costly to generate, such as hand-drawn molecule images. To address this limitation, we propose a new chemical structure recognition tool that delivers state-of-the-art performance and can adapt to new domains with a limited number of data samples and supervision. Unlike previous approaches, our method provides atom-level localization, and can therefore segment the image into the different atoms and bonds. Our model is the first model to perform OCSR with atom-level entity detection with only SMILES supervision. Through rigorous and extensive benchmarking, we demonstrate the preeminence of our chemical structure recognition approach in terms of data efficiency, accuracy, and atom-level entity prediction."""
    st.markdown(about_text)
colors = ["magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum","magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum"]
def plot_bbox(bbox_XYXY, label):
xmin, ymin, xmax, ymax =bbox_XYXY
plt.plot(
[xmin, xmin, xmax, xmax, xmin],
[ymin, ymax, ymax, ymin, ymin],
color=colors[label],
label=str(label))
def _load_rcnn(experiment_path, modelfile, score_thresh=0.65):
    """Load an RCNN checkpoint from ``experiment_path`` and set its ROI score threshold."""
    model = RCNN.load_from_checkpoint(os.path.join(experiment_path, modelfile))
    model.model.roi_heads.score_thresh = score_thresh
    return model


def _show_detections(image, preds, column, caption):
    """Overlay the first prediction's boxes on ``image`` and show it in a Streamlit column."""
    # pyplot keeps one global figure across Streamlit reruns; clear it first so
    # boxes from a previous plot (or a previous run) are not drawn again.
    plt.clf()
    plt.imshow(image, cmap="gray")
    for bbox, label in zip(preds[0]['boxes'][0], preds[0]['preds'][0]):
        plot_bbox(bbox, label)
    plt.axis('off')
    plt.savefig("example_image.png", bbox_inches='tight', pad_inches=0)
    column.image(Image.open("example_image.png"), caption=caption, use_column_width=True)


def _atoms_near_bond(bond_box, atom_boxes):
    """Return indices of atom boxes touching ``bond_box``.

    The bond box is grown by 5 units per side per step (margins 0..75) until
    it intersects at least two atom boxes or the margin limit is reached.
    """
    hits = []
    margin = 0
    while hits.count(1) < 2 and margin < 80:
        grown = [bond_box[0] - margin, bond_box[1] - margin,
                 bond_box[2] + margin, bond_box[3] + margin]
        hits = [bb_box_intersects(atom_box, grown) for atom_box in atom_boxes]
        margin += 5
    return [i for i, hit in enumerate(hits) if hit == 1]


def _build_molecule(atom_pred, bond_pred, stereo_pred, charge_pred):
    """Assemble a molecule description from the four detectors' outputs for one batch.

    Each ``*_pred`` is one element of ``trainer.predict`` output; ``[0]`` picks
    the first image of the batch (presumably the only one, batch_size=1).

    Returns:
        dict with ``graph`` (atom x atom matrix of bond labels, 0 = no bond),
        ``atom_labels``/``atom_boxes`` (atoms kept by ``iou_filter_bboxes``),
        ``stereo_atoms`` (1 where a stereo box touches exactly one atom; only
        evaluated when some bond label is > 4) and ``charge_atoms`` (charge
        label per atom, default 1).
    """
    atom_boxes = atom_pred['boxes'][0]
    atom_labels = atom_pred['preds'][0]
    atom_scores = atom_pred['scores'][0]
    charge_boxes = charge_pred['boxes'][0]
    charge_labels = charge_pred['preds'][0]
    # charge_atoms defaults to 1 below, so only detections with label > 1
    # can change an atom's charge (presumably label 1 = neutral).
    charge_mask = torch.where(charge_labels > 1)
    filtered_ch_labels = charge_labels[charge_mask]
    filtered_ch_boxes = charge_boxes[charge_mask]

    filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)
    n_atoms = len(filtered_bboxes)
    mol_graph = np.zeros((n_atoms, n_atoms))
    stereo_atoms = np.zeros(n_atoms)
    charge_atoms = np.ones(n_atoms)

    for index, box_atom in enumerate(filtered_bboxes):
        for box_charge, label_charge in zip(filtered_ch_boxes, filtered_ch_labels):
            if bb_box_intersects(box_atom, box_charge) == 1:
                charge_atoms[index] = label_charge

    for bond_idx, bond_box in enumerate(bond_pred['boxes'][0]):
        label_bond = bond_pred['preds'][0][bond_idx]
        if label_bond <= 1:
            continue
        indices = _atoms_near_bond(bond_box, filtered_bboxes)
        if len(indices) == 2:
            first, second = indices
        elif len(indices) > 2:
            # More than two candidate atoms for one bond: let the distance
            # filter pick the pair.
            cand_indices = dist_filter_bboxes(filtered_bboxes[indices, :])
            first, second = indices[cand_indices[0]], indices[cand_indices[1]]
        else:
            continue
        mol_graph[first, second] = label_bond
        mol_graph[second, first] = label_bond

    # Bond labels above 4 are treated as stereo bonds; only then are the
    # stereo detections consulted.
    if np.any(mol_graph > 4):
        for stereo_box in stereo_pred['boxes'][0]:
            touching = [i for i, atom_box in enumerate(filtered_bboxes)
                        if bb_box_intersects(atom_box, stereo_box) == 1]
            if len(touching) == 1:
                stereo_atoms[touching[0]] = 1

    return {
        'graph': mol_graph,
        'atom_labels': filtered_labels,
        'atom_boxes': filtered_bboxes,
        'stereo_atoms': stereo_atoms,
        'charge_atoms': charge_atoms,
    }


def _molecule_to_smiles(molecule):
    """Convert a molecule dict to SMILES via an RDKit mol file, repairing problems.

    Returns:
        ``(smiles, problematic)``. ``smiles`` is ``""`` when RDKit cannot
        export the molecule. ``problematic`` is 1 if the initial check found
        problems, if detection/repair raised, or if SMILES export failed;
        a failure of the second sanitize attempt is ignored.
    """
    save_mol_to_file(molecule, 'molfile')
    mol = Chem.MolFromMolFile('molfile', sanitize=False)
    problematic = 0
    try:
        problems = Chem.DetectChemistryProblems(mol)
        if len(problems) > 0:
            mol = solve_mol_problems(mol, problems)
            problematic = 1
        try:
            Chem.SanitizeMol(mol)
        except Exception:
            # Best effort: repair what sanitization still trips on, retry once.
            problems = Chem.DetectChemistryProblems(mol)
            if len(problems) > 0:
                mol = solve_mol_problems(mol, problems)
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                pass
    except Exception:
        problematic = 1
    try:
        pred_smiles = Chem.MolToSmiles(mol)
    except Exception:
        return "", 1
    return pred_smiles, problematic


def atomlenz(modelfile):
    """Streamlit page: recognize a chemical structure from an uploaded PNG.

    Runs the atom, bond, stereo and charge RCNN detectors loaded from
    ``./models/<kind>_model/<modelfile>``, shows the atom and bond detections
    side by side, assembles a molecule graph and renders the resulting RDKit
    structure (or a message when no structure was recognized).

    Args:
        modelfile: checkpoint filename looked up in each model directory.
    """
    dataset = Objects_Smiles(data_path="./uploads/", batch_size=1)
    image_file = st.file_uploader("Upload a chemical structure candidate image", type=['png'])
    if image_file is None:
        return

    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    dataset.prepare_data()

    # Loaded only once an image is present, so the uploader renders without
    # first waiting on four checkpoint loads.
    model_atom = _load_rcnn("./models/atoms_model/", modelfile)
    model_bond = _load_rcnn("./models/bonds_model/", modelfile)
    model_stereo = _load_rcnn("./models/stereos_model/", modelfile)
    model_charge = _load_rcnn("./models/charges_model/", modelfile)

    trainer = pl.Trainer(logger=False)
    st.toast('Predicting atoms,bonds,charges,..., please wait')
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
    st.toast('Done')

    _show_detections(image, atom_preds, col1, "Atom entities")
    _show_detections(image, bond_preds, col2, "Bond entities")

    predictions_list = []
    for image_idx, bonds in enumerate(bond_preds):
        molecule = _build_molecule(atom_preds[image_idx], bonds,
                                   stereo_preds[image_idx], charges_preds[image_idx])
        pred_smiles, problematic = _molecule_to_smiles(molecule)
        predictions_list.append([image_idx, pred_smiles, problematic])

    for pred in predictions_list:
        smiles = pred[1]
        m = Chem.MolFromSmiles(smiles)
        # MolFromSmiles("") yields an empty Mol rather than None, so a failed
        # prediction must also be caught by its atom count.
        if m is None or m.GetNumAtoms() == 0:
            st.write("No valid chemical structure recognized.")
        else:
            img = Draw.MolToImage(m)
            st.image(img, caption=f"Recognized structure: {smiles}", use_column_width=False, width=400)
        print(pred)
# ---- Sidebar navigation -----------------------------------------------------
# Each page callable receives the checkpoint filename chosen in the sidebar.
page_to_funcs = {
    "Predict Atom-Level Entities": atomlenz,
    "About AtomLenz": main_page,
}
sel_page = st.sidebar.selectbox("Select task", page_to_funcs.keys())
st.sidebar.markdown('')

# Sidebar label -> checkpoint filename (joined onto each model directory by atomlenz).
model_dict = {
    "AtomLenz trained on synthetic data (default)": "synthetic.ckpt",
    "AtomLenz for hand-drawn images": "real.ckpt",
    # No dedicated checkpoint yet; falls back to the synthetic model.
    "ChemExpert (not available yet)": "synthetic.ckpt",
}
selected_model = st.sidebar.selectbox(
    "Select the AtomLenz model to load",
    tuple(model_dict),
)
model_file = model_dict[selected_model]
page_to_funcs[sel_page](model_file)
# -----------------------------------------------------------------------------
|