moldenhof commited on
Commit
0f3c07e
1 Parent(s): 7cfdbf2

test with menu

Browse files
Files changed (1) hide show
  1. app.py +200 -234
app.py CHANGED
@@ -26,59 +26,12 @@ from rdkit import DataStructs
26
  from PIL import Image
27
  import matplotlib.pyplot as plt
28
 
29
- def main_page(top_n, model_path):
 
 
30
  st.markdown(
31
- """test """
32
  )
33
-
34
- #### TRYOUT MENU #####
35
-
36
- page_names_to_funcs = {
37
- # "Microscopy images from a molecule": images_from_molecule,
38
- # "Molecules from a microscopy image": molecules_from_image,
39
- "About AtomLenz": main_page,
40
-
41
- }
42
-
43
- selected_page = st.sidebar.selectbox("What would you like to retrieve?", page_names_to_funcs.keys())
44
- st.sidebar.markdown('')
45
-
46
-
47
- selected_model = st.sidebar.selectbox(
48
- "Select a AtomLenz model to load",
49
- ("AtomLenz trained on synthetic data (default)", "AtomLenz for hand-drawn images", "ChemExpert (not available yet)"))
50
-
51
- model_dict = {
52
- "AtomLenz trained on synthetic data (default)" : "atomlenz_default.pt",
53
- "AtomLenz for hand-drawn images" : "atomlenz_handdrawn.pt",
54
- "ChemExpert (not available yet)" : "atomlenz_default.pt"
55
-
56
- }
57
-
58
- model_file = model_dict[selected_model]
59
- #model_path = os.path.join(datapath, model_file)
60
-
61
- #if model_path.endswith("320).pt"):
62
- # image_resolution = 320
63
- #else:
64
- # image_resolution = 520
65
-
66
-
67
- #page_names_to_funcs[selected_page](n_objects, model_path)
68
-
69
-
70
-
71
-
72
- ######################
73
-
74
-
75
-
76
-
77
-
78
-
79
-
80
-
81
-
82
  colors = ["magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum","magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum"]
83
  def plot_bbox(bbox_XYXY, label):
84
  xmin, ymin, xmax, ymax =bbox_XYXY
@@ -88,213 +41,226 @@ def plot_bbox(bbox_XYXY, label):
88
  color=colors[label],
89
  label=str(label))
90
 
91
- model_cls = RCNN
92
- experiment_path_atoms="./models/atoms_model/"
93
- dir_list = os.listdir(experiment_path_atoms)
94
- dir_list = [os.path.join(experiment_path_atoms,f) for f in dir_list]
95
- dir_list.sort(key=os.path.getctime, reverse=True)
96
- checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
97
- model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
98
- model_atom.model.roi_heads.score_thresh = 0.65
99
- experiment_path_bonds = "./models/bonds_model/"
100
- dir_list = os.listdir(experiment_path_bonds)
101
- dir_list = [os.path.join(experiment_path_bonds,f) for f in dir_list]
102
- dir_list.sort(key=os.path.getctime, reverse=True)
103
- checkpoint_file_bonds = [f for f in dir_list if "ckpt" in f][0]
104
- model_bond = model_cls.load_from_checkpoint(checkpoint_file_bonds)
105
- model_bond.model.roi_heads.score_thresh = 0.65
106
- experiment_path_stereo = "./models/stereos_model/"
107
- dir_list = os.listdir(experiment_path_stereo)
108
- dir_list = [os.path.join(experiment_path_stereo,f) for f in dir_list]
109
- dir_list.sort(key=os.path.getctime, reverse=True)
110
- checkpoint_file_stereo = [f for f in dir_list if "ckpt" in f][0]
111
- model_stereo = model_cls.load_from_checkpoint(checkpoint_file_stereo)
112
- model_stereo.model.roi_heads.score_thresh = 0.65
113
- experiment_path_charges = "./models/charges_model/"
114
- dir_list = os.listdir(experiment_path_charges)
115
- dir_list = [os.path.join(experiment_path_charges,f) for f in dir_list]
116
- dir_list.sort(key=os.path.getctime, reverse=True)
117
- checkpoint_file_charges = [f for f in dir_list if "ckpt" in f][0]
118
- model_charge = model_cls.load_from_checkpoint(checkpoint_file_charges)
119
- model_charge.model.roi_heads.score_thresh = 0.65
 
120
 
121
- data_cls = Objects_Smiles
122
- dataset = data_cls(data_path="./uploads/", batch_size=1)
123
  # dataset.prepare_data()
124
- st.title("Atom Level Entity Detector")
125
 
126
- image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png'])
127
  #st.write('filename is', file_name)
128
- if image_file is not None:
129
  #col1, col2 = st.columns(2)
130
 
131
- image = Image.open(image_file)
132
  #col1.image(image, use_column_width=True)
133
- st.image(image, use_column_width=True)
134
- col1, col2 = st.columns(2)
135
- if not os.path.exists("uploads/images"):
136
- os.makedirs("uploads/images")
137
- with open(os.path.join("uploads/images/","0.png"),"wb") as f:
138
- f.write(image_file.getbuffer())
139
  #st.success("Saved File")
140
- dataset.prepare_data()
141
- trainer = pl.Trainer(logger=False)
142
- st.toast('Predicting atoms,bonds,charges,..., please wait')
143
- atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
144
- bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
145
- stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
146
- charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
147
- st.toast('Done')
148
  #st.write(atom_preds)
149
- plt.imshow(image, cmap="gray")
150
- for bbox, label in zip(atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0]):
151
  # st.write(bbox)
152
  # st.write(label)
153
- plot_bbox(bbox, label)
154
- plt.axis('off')
155
- plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
156
- image_vis = Image.open("example_image.png")
157
- col1.image(image_vis, use_column_width=True)
158
- plt.clf()
159
- plt.imshow(image, cmap="gray")
160
- for bbox, label in zip(bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0]):
161
  # st.write(bbox)
162
  # st.write(label)
163
- plot_bbox(bbox, label)
164
- plt.axis('off')
165
- plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
166
- image_vis = Image.open("example_image.png")
167
- col2.image(image_vis, use_column_width=True)
168
- mol_graphs = []
169
- count_bonds_preds = np.zeros(4)
170
- count_atoms_preds = np.zeros(15)
171
- correct=0
172
- correct_objects=0
173
- correct_both=0
174
- predictions=0
175
- tanimoto_dists=[]
176
- predictions_list = []
177
- for image_idx, bonds in enumerate(bond_preds):
178
- count_bonds_preds = np.zeros(8)
179
- count_atoms_preds = np.zeros(18)
180
- atom_boxes = atom_preds[image_idx]['boxes'][0]
181
- atom_labels = atom_preds[image_idx]['preds'][0]
182
- atom_scores = atom_preds[image_idx]['scores'][0]
183
- charge_boxes = charges_preds[image_idx]['boxes'][0]
184
- charge_labels = charges_preds[image_idx]['preds'][0]
185
- charge_mask=torch.where(charge_labels>1)
186
- filtered_ch_labels=charge_labels[charge_mask]
187
- filtered_ch_boxes=charge_boxes[charge_mask]
188
  #import ipdb; ipdb.set_trace()
189
- filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)
190
  #for atom_label in filtered_labels:
191
  # count_atoms_preds[atom_label] += 1
192
  #import ipdb; ipdb.set_trace()
193
- mol_graph = np.zeros((len(filtered_bboxes),len(filtered_bboxes)))
194
- stereo_atoms = np.zeros(len(filtered_bboxes))
195
- charge_atoms = np.ones(len(filtered_bboxes))
196
- for index,box_atom in enumerate(filtered_bboxes):
197
- for box_charge,label_charge in zip(filtered_ch_boxes,filtered_ch_labels):
198
- if bb_box_intersects(box_atom,box_charge) == 1:
199
- charge_atoms[index]=label_charge
200
-
201
- for bond_idx, bond_box in enumerate(bonds['boxes'][0]):
202
- label_bond = bonds['preds'][0][bond_idx]
203
- if label_bond > 1:
204
- try:
205
- count_bonds_preds[label_bond] += 1
206
- except:
207
- count_bonds_preds=count_bonds_preds
208
  #import ipdb; ipdb.set_trace()
209
- result = []
210
- limit = 0
211
  #TODO: values of 50 and 5 should be made dependent of mean size of atom_boxes
212
- while result.count(1) < 2 and limit < 80:
213
- result=[]
214
- bigger_bond_box = [bond_box[0]-limit,bond_box[1]-limit,bond_box[2]+limit,bond_box[3]+limit]
215
- for atom_box in filtered_bboxes:
216
- result.append(bb_box_intersects(atom_box,bigger_bond_box))
217
- limit+=5
218
- indices = [i for i, x in enumerate(result) if x == 1]
219
- if len(indices) == 2:
220
  #import ipdb; ipdb.set_trace()
221
- mol_graph[indices[0],indices[1]]=label_bond
222
- mol_graph[indices[1],indices[0]]=label_bond
223
- if len(indices) > 2:
224
  #we have more then two canidate atoms for one bond, we filter ...
225
- cand_bboxes = filtered_bboxes[indices,:]
226
- cand_indices = dist_filter_bboxes(cand_bboxes)
227
  #import ipdb; ipdb.set_trace()
228
- mol_graph[indices[cand_indices[0]],indices[cand_indices[1]]]=label_bond
229
- mol_graph[indices[cand_indices[1]],indices[cand_indices[0]]]=label_bond
230
- #print("more than 2 indices")
231
- #if len(indices) < 2:
232
- # print("less than 2 indices")
233
- #import ipdb; ipdb.set_trace()
234
- # else:
235
- # result=[]
236
- # for atom_box in filtered_bboxes:
237
- # result.append(bb_box_intersects(atom_box,bond_box))
238
- # indices = [i for i, x in enumerate(result) if x == 1]
239
- # if len(indices) == 1:
240
- # stereo_atoms[indices[0]]=label_bond
241
- stereo_bonds = np.where(mol_graph>4, True, False)
242
- if np.any(stereo_bonds):
243
- stereo_boxes = stereo_preds[image_idx]['boxes'][0]
244
- stereo_labels= stereo_preds[image_idx]['preds'][0]
245
- for stereo_box in stereo_boxes:
246
- result=[]
247
- for atom_box in filtered_bboxes:
248
- result.append(bb_box_intersects(atom_box,stereo_box))
249
- indices = [i for i, x in enumerate(result) if x == 1]
250
- if len(indices) == 1:
251
- stereo_atoms[indices[0]]=1
252
-
253
- molecule = dict()
254
- molecule['graph'] = mol_graph
255
  #molecule['atom_labels'] = atom_preds[image_idx]['preds'][0]
256
- molecule['atom_labels'] = filtered_labels
257
- molecule['atom_boxes'] = filtered_bboxes
258
- molecule['stereo_atoms'] = stereo_atoms
259
- molecule['charge_atoms'] = charge_atoms
260
- mol_graphs.append(molecule)
261
- #base_path="./"
262
- #base_path = pathlib.Path(args.data_path)
263
- #image_dir = base_path.joinpath("images")
264
- #smiles_dir = base_path.joinpath("smiles")
265
- #impath = image_dir.joinpath(f"{image_idx}.png")
266
- #smilespath = smiles_dir.joinpath(f"{image_idx}.txt")
267
- save_mol_to_file(molecule,'molfile')
268
- mol = Chem.MolFromMolFile('molfile',sanitize=False)
269
- problematic = 0
270
- try:
271
- problems = Chem.DetectChemistryProblems(mol)
272
- if len(problems) > 0:
273
- mol = solve_mol_problems(mol,problems)
274
- problematic = 1
275
  #import ipdb; ipdb.set_trace()
276
- try:
277
- Chem.SanitizeMol(mol)
278
- except:
279
- problems = Chem.DetectChemistryProblems(mol)
280
- if len(problems) > 0:
281
- mol = solve_mol_problems(mol,problems)
 
 
 
 
 
 
282
  try:
283
- Chem.SanitizeMol(mol)
284
  except:
285
- pass
286
- except:
287
- problematic = 1
288
- try:
289
- pred_smiles = Chem.MolToSmiles(mol)
290
- except:
291
- pred_smiles = ""
292
- problematic = 1
293
- predictions+=1
294
- predictions_list.append([image_idx,pred_smiles,problematic])
295
  #import ipdb; ipdb.set_trace()
296
- file_preds = open('preds_atomlenz','w')
297
- for pred in predictions_list:
298
- print(pred)
299
- #x = st.slider('Select a value')
300
- #st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  from PIL import Image
27
  import matplotlib.pyplot as plt
28
 
29
+ st.title("Atom Level Entity Detector")
30
+
31
+ def main_page(model_file):
32
  st.markdown(
33
+ """Identifying the chemical structure from a graphical representation, or image, of a molecule is a challenging pattern recognition task that would greatly benefit drug development. Yet, existing methods for chemical structure recognition do not typically generalize well, and show diminished effectiveness when confronted with domains where data is sparse, or costly to generate, such as hand-drawn molecule images. To address this limitation, we propose a new chemical structure recognition tool that delivers state-of-the-art performance and can adapt to new domains with a limited number of data samples and supervision. Unlike previous approaches, our method provides atom-level localization, and can therefore segment the image into the different atoms and bonds. Our model is the first model to perform OCSR with atom-level entity detection with only SMILES supervision. Through rigorous and extensive benchmarking, we demonstrate the preeminence of our chemical structure recognition approach in terms of data efficiency, accuracy, and atom-level entity prediction."""
34
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  colors = ["magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum","magenta", "green", "blue", "red", "orange", "magenta", "peru", "azure", "slateblue", "plum"]
36
  def plot_bbox(bbox_XYXY, label):
37
  xmin, ymin, xmax, ymax =bbox_XYXY
 
41
  color=colors[label],
42
  label=str(label))
43
 
44
+ def atomlenz(modelfile):
45
+ model_cls = RCNN
46
+ experiment_path_atoms="./models/atoms_model/"
47
+ dir_list = os.listdir(experiment_path_atoms)
48
+ dir_list = [os.path.join(experiment_path_atoms,f) for f in dir_list]
49
+ dir_list.sort(key=os.path.getctime, reverse=True)
50
+ checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
51
+ model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
52
+ model_atom.model.roi_heads.score_thresh = 0.65
53
+ experiment_path_bonds = "./models/bonds_model/"
54
+ dir_list = os.listdir(experiment_path_bonds)
55
+ dir_list = [os.path.join(experiment_path_bonds,f) for f in dir_list]
56
+ dir_list.sort(key=os.path.getctime, reverse=True)
57
+ checkpoint_file_bonds = [f for f in dir_list if "ckpt" in f][0]
58
+ model_bond = model_cls.load_from_checkpoint(checkpoint_file_bonds)
59
+ model_bond.model.roi_heads.score_thresh = 0.65
60
+ experiment_path_stereo = "./models/stereos_model/"
61
+ dir_list = os.listdir(experiment_path_stereo)
62
+ dir_list = [os.path.join(experiment_path_stereo,f) for f in dir_list]
63
+ dir_list.sort(key=os.path.getctime, reverse=True)
64
+ checkpoint_file_stereo = [f for f in dir_list if "ckpt" in f][0]
65
+ model_stereo = model_cls.load_from_checkpoint(checkpoint_file_stereo)
66
+ model_stereo.model.roi_heads.score_thresh = 0.65
67
+ experiment_path_charges = "./models/charges_model/"
68
+ dir_list = os.listdir(experiment_path_charges)
69
+ dir_list = [os.path.join(experiment_path_charges,f) for f in dir_list]
70
+ dir_list.sort(key=os.path.getctime, reverse=True)
71
+ checkpoint_file_charges = [f for f in dir_list if "ckpt" in f][0]
72
+ model_charge = model_cls.load_from_checkpoint(checkpoint_file_charges)
73
+ model_charge.model.roi_heads.score_thresh = 0.65
74
 
75
+ data_cls = Objects_Smiles
76
+ dataset = data_cls(data_path="./uploads/", batch_size=1)
77
  # dataset.prepare_data()
 
78
 
79
+ image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png'])
80
  #st.write('filename is', file_name)
81
+ if image_file is not None:
82
  #col1, col2 = st.columns(2)
83
 
84
+ image = Image.open(image_file)
85
  #col1.image(image, use_column_width=True)
86
+ st.image(image, use_column_width=True)
87
+ col1, col2 = st.columns(2)
88
+ if not os.path.exists("uploads/images"):
89
+ os.makedirs("uploads/images")
90
+ with open(os.path.join("uploads/images/","0.png"),"wb") as f:
91
+ f.write(image_file.getbuffer())
92
  #st.success("Saved File")
93
+ dataset.prepare_data()
94
+ trainer = pl.Trainer(logger=False)
95
+ st.toast('Predicting atoms,bonds,charges,..., please wait')
96
+ atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
97
+ bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
98
+ stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
99
+ charges_preds = trainer.predict(model_charge, dataset.test_dataloader())
100
+ st.toast('Done')
101
  #st.write(atom_preds)
102
+ plt.imshow(image, cmap="gray")
103
+ for bbox, label in zip(atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0]):
104
  # st.write(bbox)
105
  # st.write(label)
106
+ plot_bbox(bbox, label)
107
+ plt.axis('off')
108
+ plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
109
+ image_vis = Image.open("example_image.png")
110
+ col1.image(image_vis, use_column_width=True)
111
+ plt.clf()
112
+ plt.imshow(image, cmap="gray")
113
+ for bbox, label in zip(bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0]):
114
  # st.write(bbox)
115
  # st.write(label)
116
+ plot_bbox(bbox, label)
117
+ plt.axis('off')
118
+ plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
119
+ image_vis = Image.open("example_image.png")
120
+ col2.image(image_vis, use_column_width=True)
121
+ mol_graphs = []
122
+ count_bonds_preds = np.zeros(4)
123
+ count_atoms_preds = np.zeros(15)
124
+ correct=0
125
+ correct_objects=0
126
+ correct_both=0
127
+ predictions=0
128
+ tanimoto_dists=[]
129
+ predictions_list = []
130
+ for image_idx, bonds in enumerate(bond_preds):
131
+ count_bonds_preds = np.zeros(8)
132
+ count_atoms_preds = np.zeros(18)
133
+ atom_boxes = atom_preds[image_idx]['boxes'][0]
134
+ atom_labels = atom_preds[image_idx]['preds'][0]
135
+ atom_scores = atom_preds[image_idx]['scores'][0]
136
+ charge_boxes = charges_preds[image_idx]['boxes'][0]
137
+ charge_labels = charges_preds[image_idx]['preds'][0]
138
+ charge_mask=torch.where(charge_labels>1)
139
+ filtered_ch_labels=charge_labels[charge_mask]
140
+ filtered_ch_boxes=charge_boxes[charge_mask]
141
  #import ipdb; ipdb.set_trace()
142
+ filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)
143
  #for atom_label in filtered_labels:
144
  # count_atoms_preds[atom_label] += 1
145
  #import ipdb; ipdb.set_trace()
146
+ mol_graph = np.zeros((len(filtered_bboxes),len(filtered_bboxes)))
147
+ stereo_atoms = np.zeros(len(filtered_bboxes))
148
+ charge_atoms = np.ones(len(filtered_bboxes))
149
+ for index,box_atom in enumerate(filtered_bboxes):
150
+ for box_charge,label_charge in zip(filtered_ch_boxes,filtered_ch_labels):
151
+ if bb_box_intersects(box_atom,box_charge) == 1:
152
+ charge_atoms[index]=label_charge
153
+
154
+ for bond_idx, bond_box in enumerate(bonds['boxes'][0]):
155
+ label_bond = bonds['preds'][0][bond_idx]
156
+ if label_bond > 1:
157
+ try:
158
+ count_bonds_preds[label_bond] += 1
159
+ except:
160
+ count_bonds_preds=count_bonds_preds
161
  #import ipdb; ipdb.set_trace()
162
+ result = []
163
+ limit = 0
164
  #TODO: values of 50 and 5 should be made dependent of mean size of atom_boxes
165
+ while result.count(1) < 2 and limit < 80:
166
+ result=[]
167
+ bigger_bond_box = [bond_box[0]-limit,bond_box[1]-limit,bond_box[2]+limit,bond_box[3]+limit]
168
+ for atom_box in filtered_bboxes:
169
+ result.append(bb_box_intersects(atom_box,bigger_bond_box))
170
+ limit+=5
171
+ indices = [i for i, x in enumerate(result) if x == 1]
172
+ if len(indices) == 2:
173
  #import ipdb; ipdb.set_trace()
174
+ mol_graph[indices[0],indices[1]]=label_bond
175
+ mol_graph[indices[1],indices[0]]=label_bond
176
+ if len(indices) > 2:
177
  #we have more then two canidate atoms for one bond, we filter ...
178
+ cand_bboxes = filtered_bboxes[indices,:]
179
+ cand_indices = dist_filter_bboxes(cand_bboxes)
180
  #import ipdb; ipdb.set_trace()
181
+ mol_graph[indices[cand_indices[0]],indices[cand_indices[1]]]=label_bond
182
+ mol_graph[indices[cand_indices[1]],indices[cand_indices[0]]]=label_bond
183
+ stereo_bonds = np.where(mol_graph>4, True, False)
184
+ if np.any(stereo_bonds):
185
+ stereo_boxes = stereo_preds[image_idx]['boxes'][0]
186
+ stereo_labels= stereo_preds[image_idx]['preds'][0]
187
+ for stereo_box in stereo_boxes:
188
+ result=[]
189
+ for atom_box in filtered_bboxes:
190
+ result.append(bb_box_intersects(atom_box,stereo_box))
191
+ indices = [i for i, x in enumerate(result) if x == 1]
192
+ if len(indices) == 1:
193
+ stereo_atoms[indices[0]]=1
194
+
195
+ molecule = dict()
196
+ molecule['graph'] = mol_graph
 
 
 
 
 
 
 
 
 
 
 
197
  #molecule['atom_labels'] = atom_preds[image_idx]['preds'][0]
198
+ molecule['atom_labels'] = filtered_labels
199
+ molecule['atom_boxes'] = filtered_bboxes
200
+ molecule['stereo_atoms'] = stereo_atoms
201
+ molecule['charge_atoms'] = charge_atoms
202
+ mol_graphs.append(molecule)
203
+ save_mol_to_file(molecule,'molfile')
204
+ mol = Chem.MolFromMolFile('molfile',sanitize=False)
205
+ problematic = 0
206
+ try:
207
+ problems = Chem.DetectChemistryProblems(mol)
208
+ if len(problems) > 0:
209
+ mol = solve_mol_problems(mol,problems)
210
+ problematic = 1
 
 
 
 
 
 
211
  #import ipdb; ipdb.set_trace()
212
+ try:
213
+ Chem.SanitizeMol(mol)
214
+ except:
215
+ problems = Chem.DetectChemistryProblems(mol)
216
+ if len(problems) > 0:
217
+ mol = solve_mol_problems(mol,problems)
218
+ try:
219
+ Chem.SanitizeMol(mol)
220
+ except:
221
+ pass
222
+ except:
223
+ problematic = 1
224
  try:
225
+ pred_smiles = Chem.MolToSmiles(mol)
226
  except:
227
+ pred_smiles = ""
228
+ problematic = 1
229
+ predictions+=1
230
+ predictions_list.append([image_idx,pred_smiles,problematic])
 
 
 
 
 
 
231
  #import ipdb; ipdb.set_trace()
232
+ file_preds = open('preds_atomlenz','w')
233
+ for pred in predictions_list:
234
+ print(pred)
235
+
236
+ #### TRYOUT MENU #####
237
+
238
+ page_to_funcs = {
239
+ "Predict Atom-Level Entities": atomlenz,
240
+ "About AtomLenz": main_page,
241
+
242
+ }
243
+
244
+ sel_page = st.sidebar.selectbox("Select task", page_to_funcs.keys())
245
+ st.sidebar.markdown('')
246
+
247
+
248
+ selected_model = st.sidebar.selectbox(
249
+ "Select the AtomLenz model to load",
250
+ ("AtomLenz trained on synthetic data (default)", "AtomLenz for hand-drawn images", "ChemExpert (not available yet)"))
251
+
252
+ model_dict = {
253
+ "AtomLenz trained on synthetic data (default)" : "synthetic",
254
+ "AtomLenz for hand-drawn images" : "real",
255
+ "ChemExpert (not available yet)" : "synthetic"
256
+
257
+ }
258
+
259
+ model_file = model_dict[selected_model]
260
+
261
+ page_to_funcs[sel_page](model_file)
262
+
263
+
264
+
265
+
266
+ ######################