File size: 11,420 Bytes
7a223e7
6adb0d1
 
 
 
 
 
b2c3eed
6adb0d1
d0f68bc
4986573
7ab2ec8
6adb0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
67828bb
648d955
 
0f3c07e
 
 
5417a1b
0f3c07e
5417a1b
4ea7e69
648d955
 
 
 
 
 
 
7a223e7
0f3c07e
 
 
663db57
0f3c07e
 
 
d88e063
0f3c07e
 
 
d88e063
0f3c07e
 
 
d88e063
0f3c07e
 
09fa344
0f3c07e
 
67828bb
0f3c07e
 
67828bb
0f3c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1ca99b
0f3c07e
 
 
 
 
 
 
d1ca99b
0f3c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c3eed
0f3c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c3eed
0f3c07e
b2c3eed
0f3c07e
 
 
 
 
 
4986573
 
 
0fac33f
4986573
7ab2ec8
a51d22b
0f3c07e
 
4986573
0f3c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663db57
 
 
0f3c07e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import streamlit as st
import os
import torch
#import math
import numpy as np
#import matplotlib.pyplot as plt
#import pathlib
from AtomLenz import *
#from utils_graph import *
from Object_Smiles import Objects_Smiles 
from rdkit import Chem
from rdkit.Chem import Draw
#from robust_detection import wandb_config
from robust_detection import utils
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN, COCO_RCNN

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from PIL import Image
import matplotlib.pyplot as plt

# App-wide title; emitted at import time, before any page-specific content is rendered.
st.title("Atom Level Entity Detector")

def main_page(model_file):
    """Render the "About AtomLenz" page as a single markdown paragraph.

    Args:
        model_file: Accepted but not used by this page.
    """
    about_text = """Identifying the chemical structure from a graphical representation, or image, of a molecule is a challenging pattern recognition task that would greatly benefit drug development. Yet, existing methods for chemical structure recognition do not typically generalize well, and show diminished effectiveness when confronted with domains where data is sparse, or costly to generate, such as hand-drawn molecule images. To address this limitation, we propose a new chemical structure recognition tool that delivers state-of-the-art performance and can adapt to new domains with a limited number of data samples and supervision. Unlike previous approaches, our method provides atom-level localization, and can therefore segment the image into the different atoms and bonds. Our model is the first model to perform OCSR with atom-level entity detection with only SMILES supervision. Through rigorous and extensive benchmarking, we demonstrate the preeminence of our chemical structure recognition approach in terms of data efficiency, accuracy, and atom-level entity prediction."""
    st.markdown(about_text)
# Palette indexed by predicted class label: ten colour names repeated twice (20 entries).
colors = [
    "magenta", "green", "blue", "red", "orange",
    "magenta", "peru", "azure", "slateblue", "plum",
] * 2


def plot_bbox(bbox_XYXY, label):
    """Draw the outline of one bounding box with pyplot.

    Args:
        bbox_XYXY: Four values unpacked as (xmin, ymin, xmax, ymax).
        label: Class label; used as an index into ``colors`` and, stringified,
            as the line's legend label.
    """
    x0, y0, x1, y1 = bbox_XYXY
    # Walk the four corners and return to the start so the outline is closed.
    outline_x = [x0, x0, x1, x1, x0]
    outline_y = [y0, y1, y1, y0, y0]
    plt.plot(outline_x, outline_y, color=colors[label], label=str(label))

def _load_detector(model_dir, modelfile, score_thresh=0.65):
    """Load one RCNN checkpoint from ``model_dir`` and set its ROI-head score threshold."""
    detector = RCNN.load_from_checkpoint(os.path.join(model_dir, modelfile))
    detector.model.roi_heads.score_thresh = score_thresh
    return detector


def _show_detections(column, image, boxes, labels, caption):
    """Overlay ``boxes`` on ``image`` with pyplot and display the result in ``column``."""
    # pyplot keeps one global "current figure"; start from a clean one so that
    # drawings from an earlier call (or an earlier Streamlit rerun) do not
    # bleed into this rendering.
    plt.clf()
    plt.imshow(image, cmap="gray")
    for bbox, label in zip(boxes, labels):
        plot_bbox(bbox, label)
    plt.axis('off')
    plt.savefig("example_image.png", bbox_inches='tight', pad_inches=0)
    column.image(Image.open("example_image.png"), caption=caption, use_column_width=True)
    plt.clf()


def _assemble_molecule(atoms, bonds, stereos, charges):
    """Combine the per-image atom/bond/stereo/charge detections into one molecule dict.

    Each argument is the prediction dict of one image; the first element of its
    'boxes' / 'preds' (and, for atoms, 'scores') entries is used.
    ``iou_filter_bboxes``, ``bb_box_intersects`` and ``dist_filter_bboxes`` are
    not defined in this file (presumably provided by ``from AtomLenz import *``
    -- confirm).

    Returns:
        dict with keys 'graph' (square adjacency matrix holding bond labels),
        'atom_labels', 'atom_boxes', 'stereo_atoms' and 'charge_atoms'.
    """
    atom_boxes = atoms['boxes'][0]
    atom_labels = atoms['preds'][0]
    atom_scores = atoms['scores'][0]
    charge_boxes = charges['boxes'][0]
    charge_labels = charges['preds'][0]

    # Only charge detections with a label above 1 are kept
    # (NOTE(review): presumably lower labels mean "no charge" -- confirm).
    charge_mask = torch.where(charge_labels > 1)
    filtered_ch_labels = charge_labels[charge_mask]
    filtered_ch_boxes = charge_boxes[charge_mask]

    filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)
    n_atoms = len(filtered_bboxes)
    mol_graph = np.zeros((n_atoms, n_atoms))
    stereo_atoms = np.zeros(n_atoms)
    charge_atoms = np.ones(n_atoms)

    # An atom takes the label of a charge box it intersects; when several
    # intersect, the last one in iteration order wins.
    for index, box_atom in enumerate(filtered_bboxes):
        for box_charge, label_charge in zip(filtered_ch_boxes, filtered_ch_labels):
            if bb_box_intersects(box_atom, box_charge) == 1:
                charge_atoms[index] = label_charge

    for bond_box, label_bond in zip(bonds['boxes'][0], bonds['preds'][0]):
        if not label_bond > 1:
            continue
        # Grow the bond box in steps of 5 until it touches at least two atom
        # boxes, or give up once the margin would reach 80.
        result = []
        limit = 0
        while result.count(1) < 2 and limit < 80:
            bigger_bond_box = [bond_box[0] - limit, bond_box[1] - limit,
                               bond_box[2] + limit, bond_box[3] + limit]
            result = [bb_box_intersects(atom_box, bigger_bond_box)
                      for atom_box in filtered_bboxes]
            limit += 5
        indices = [i for i, x in enumerate(result) if x == 1]
        if len(indices) == 2:
            first, second = indices
        elif len(indices) > 2:
            # More than two candidate atoms for one bond: let
            # dist_filter_bboxes choose the pair to connect.
            cand_bboxes = filtered_bboxes[indices, :]
            cand_indices = dist_filter_bboxes(cand_bboxes)
            first = indices[cand_indices[0]]
            second = indices[cand_indices[1]]
        else:
            continue
        mol_graph[first, second] = label_bond
        mol_graph[second, first] = label_bond

    # Stereo detections are only consulted when the graph holds a bond label
    # above 4 (the original code called these "stereo_bonds").
    if np.any(mol_graph > 4):
        for stereo_box in stereos['boxes'][0]:
            hits = [i for i, atom_box in enumerate(filtered_bboxes)
                    if bb_box_intersects(atom_box, stereo_box) == 1]
            # Ambiguous hits (zero or several atoms) are ignored.
            if len(hits) == 1:
                stereo_atoms[hits[0]] = 1

    return {
        'graph': mol_graph,
        'atom_labels': filtered_labels,
        'atom_boxes': filtered_bboxes,
        'stereo_atoms': stereo_atoms,
        'charge_atoms': charge_atoms,
    }


def _molecule_to_smiles(molecule):
    """Write ``molecule`` to 'molfile', repair/sanitize it with RDKit and return (smiles, problematic).

    ``problematic`` is 1 when chemistry problems were detected, when
    sanitization or SMILES generation failed, and 0 otherwise. ``smiles`` is
    "" when no SMILES could be produced. ``save_mol_to_file`` and
    ``solve_mol_problems`` are not defined in this file (presumably from
    ``from AtomLenz import *`` -- confirm).
    """
    save_mol_to_file(molecule, 'molfile')
    mol = Chem.MolFromMolFile('molfile', sanitize=False)
    problematic = 0
    try:
        problems = Chem.DetectChemistryProblems(mol)
        if len(problems) > 0:
            mol = solve_mol_problems(mol, problems)
            problematic = 1
        try:
            Chem.SanitizeMol(mol)
        except Exception:
            # First sanitization failed: try one more repair round.
            problems = Chem.DetectChemistryProblems(mol)
            if len(problems) > 0:
                mol = solve_mol_problems(mol, problems)
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Best effort: keep the unsanitized molecule, but flag it.
                problematic = 1
    except Exception:
        problematic = 1
    try:
        pred_smiles = Chem.MolToSmiles(mol)
    except Exception:
        pred_smiles = ""
        problematic = 1
    return pred_smiles, problematic


def atomlenz(modelfile):
    """Streamlit page: detect atom-level entities in an uploaded image and show the recognized structure.

    Loads the atom, bond, stereo and charge RCNN checkpoints named
    ``modelfile`` from ``./models/{atoms,bonds,stereos,charges}_model/``, asks
    the user for a PNG, stores it as ``uploads/images/0.png``, runs the four
    detectors, shows the atom and bond boxes side by side, and finally draws
    the molecule reconstructed from the detections together with its SMILES.

    Args:
        modelfile: Checkpoint file name, looked up in each of the four model
            directories.

    Side effects:
        Writes ``uploads/images/0.png``, ``example_image.png`` and ``molfile``
        in the working directory and prints each prediction row to stdout.
    """
    model_atom = _load_detector("./models/atoms_model/", modelfile)
    model_bond = _load_detector("./models/bonds_model/", modelfile)
    model_stereo = _load_detector("./models/stereos_model/", modelfile)
    model_charge = _load_detector("./models/charges_model/", modelfile)

    dataset = Objects_Smiles(data_path="./uploads/", batch_size=1)

    image_file = st.file_uploader("Upload a chemical structure candidate image", type=['png'])
    if image_file is None:
        return

    image = Image.open(image_file)
    st.image(image, use_column_width=True)
    col1, col2 = st.columns(2)

    # The dataset reads from ./uploads/, so persist the upload there first.
    os.makedirs("uploads/images", exist_ok=True)
    with open(os.path.join("uploads/images/", "0.png"), "wb") as f:
        f.write(image_file.getbuffer())
    dataset.prepare_data()

    trainer = pl.Trainer(logger=False)
    st.toast('Predicting atoms,bonds,charges,..., please wait')
    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
    bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
    stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
    charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
    st.toast('Done')

    _show_detections(col1, image, atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0],
                     "Atom entities")
    _show_detections(col2, image, bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0],
                     "Bond entities")

    predictions_list = []
    for image_idx, bonds in enumerate(bond_preds):
        molecule = _assemble_molecule(
            atom_preds[image_idx], bonds, stereo_preds[image_idx], charges_preds[image_idx])
        pred_smiles, problematic = _molecule_to_smiles(molecule)
        predictions_list.append([image_idx, pred_smiles, problematic])

    for pred in predictions_list:
        smiles = pred[1]
        # An empty SMILES means no structure was produced; do not hand it to
        # RDKit, which would parse it into an empty molecule instead of None.
        m = Chem.MolFromSmiles(smiles) if smiles else None
        if m is None:
            st.write("No valid chemical structure recognized.")
        else:
            img = Draw.MolToImage(m)
            st.image(img, caption=f"Recognized structure: {smiles}", use_column_width=False, width=400)
        print(pred)


#### TRYOUT MENU #####

# Sidebar entry -> page callable. Every callable is invoked with the chosen
# checkpoint file name as its only argument.
page_to_funcs = {
    "Predict Atom-Level Entities": atomlenz,
    "About AtomLenz": main_page,
}

# Sidebar entry -> checkpoint file name. The "ChemExpert" entry maps to the
# same checkpoint file as the default entry.
model_dict = {
    "AtomLenz trained on synthetic data (default)": "synthetic.ckpt",
    "AtomLenz for hand-drawn images": "real.ckpt",
    "ChemExpert (not available yet)": "synthetic.ckpt",
}

# Widgets are created in the order they should appear in the sidebar:
# task selector, an empty markdown element, then the model selector.
sel_page = st.sidebar.selectbox("Select task", tuple(page_to_funcs))
st.sidebar.markdown('')
selected_model = st.sidebar.selectbox(
    "Select the AtomLenz model to load",
    tuple(model_dict),
)

model_file = model_dict[selected_model]

_render_page = page_to_funcs[sel_page]
_render_page(model_file)


######################