moldenhof commited on
Commit
b2c3eed
1 Parent(s): 08c8b06

implementing predict smiles

Browse files
Files changed (1) hide show
  1. app.py +122 -1
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import numpy as np
6
  #import matplotlib.pyplot as plt
7
  #import pathlib
8
- #from AtomLenz import *
9
  #from utils_graph import *
10
  from Object_Smiles import Objects_Smiles
11
 
@@ -112,5 +112,126 @@ if image_file is not None:
112
  plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
113
  image_vis = Image.open("example_image.png")
114
  col2.image(image_vis, use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  #x = st.slider('Select a value')
116
  #st.write(x, 'squared is', x * x)
 
5
  import numpy as np
6
  #import matplotlib.pyplot as plt
7
  #import pathlib
8
+ from AtomLenz import *
9
  #from utils_graph import *
10
  from Object_Smiles import Objects_Smiles
11
 
 
112
  plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
113
  image_vis = Image.open("example_image.png")
114
  col2.image(image_vis, use_column_width=True)
115
+ for image_idx, bonds in enumerate(bond_preds):
116
+ count_bonds_preds = np.zeros(8)
117
+ count_atoms_preds = np.zeros(18)
118
+ atom_boxes = atom_preds[image_idx]['boxes'][0]
119
+ atom_labels = atom_preds[image_idx]['preds'][0]
120
+ atom_scores = atom_preds[image_idx]['scores'][0]
121
+ charge_boxes = charges_preds[image_idx]['boxes'][0]
122
+ charge_labels = charges_preds[image_idx]['preds'][0]
123
+ charge_mask=torch.where(charge_labels>1)
124
+ filtered_ch_labels=charge_labels[charge_mask]
125
+ filtered_ch_boxes=charge_boxes[charge_mask]
126
+ #import ipdb; ipdb.set_trace()
127
+ filtered_bboxes, filtered_labels = iou_filter_bboxes(atom_boxes, atom_labels, atom_scores)
128
+ #for atom_label in filtered_labels:
129
+ # count_atoms_preds[atom_label] += 1
130
+ #import ipdb; ipdb.set_trace()
131
+ mol_graph = np.zeros((len(filtered_bboxes),len(filtered_bboxes)))
132
+ stereo_atoms = np.zeros(len(filtered_bboxes))
133
+ charge_atoms = np.ones(len(filtered_bboxes))
134
+ for index,box_atom in enumerate(filtered_bboxes):
135
+ for box_charge,label_charge in zip(filtered_ch_boxes,filtered_ch_labels):
136
+ if bb_box_intersects(box_atom,box_charge) == 1:
137
+ charge_atoms[index]=label_charge
138
+
139
+ for bond_idx, bond_box in enumerate(bonds['boxes'][0]):
140
+ label_bond = bonds['preds'][0][bond_idx]
141
+ if label_bond > 1:
142
+ try:
143
+ count_bonds_preds[label_bond] += 1
144
+ except:
145
+ count_bonds_preds=count_bonds_preds
146
+ #import ipdb; ipdb.set_trace()
147
+ result = []
148
+ limit = 0
149
+ #TODO: values of 50 and 5 should be made dependent of mean size of atom_boxes
150
+ while result.count(1) < 2 and limit < 80:
151
+ result=[]
152
+ bigger_bond_box = [bond_box[0]-limit,bond_box[1]-limit,bond_box[2]+limit,bond_box[3]+limit]
153
+ for atom_box in filtered_bboxes:
154
+ result.append(bb_box_intersects(atom_box,bigger_bond_box))
155
+ limit+=5
156
+ indices = [i for i, x in enumerate(result) if x == 1]
157
+ if len(indices) == 2:
158
+ #import ipdb; ipdb.set_trace()
159
+ mol_graph[indices[0],indices[1]]=label_bond
160
+ mol_graph[indices[1],indices[0]]=label_bond
161
+ if len(indices) > 2:
162
+ #we have more then two canidate atoms for one bond, we filter ...
163
+ cand_bboxes = filtered_bboxes[indices,:]
164
+ cand_indices = dist_filter_bboxes(cand_bboxes)
165
+ #import ipdb; ipdb.set_trace()
166
+ mol_graph[indices[cand_indices[0]],indices[cand_indices[1]]]=label_bond
167
+ mol_graph[indices[cand_indices[1]],indices[cand_indices[0]]]=label_bond
168
+ #print("more than 2 indices")
169
+ #if len(indices) < 2:
170
+ # print("less than 2 indices")
171
+ #import ipdb; ipdb.set_trace()
172
+ # else:
173
+ # result=[]
174
+ # for atom_box in filtered_bboxes:
175
+ # result.append(bb_box_intersects(atom_box,bond_box))
176
+ # indices = [i for i, x in enumerate(result) if x == 1]
177
+ # if len(indices) == 1:
178
+ # stereo_atoms[indices[0]]=label_bond
179
+ stereo_bonds = np.where(mol_graph>4, True, False)
180
+ if np.any(stereo_bonds):
181
+ stereo_boxes = stereo_preds[image_idx]['boxes'][0]
182
+ stereo_labels= stereo_preds[image_idx]['preds'][0]
183
+ for stereo_box in stereo_boxes:
184
+ result=[]
185
+ for atom_box in filtered_bboxes:
186
+ result.append(bb_box_intersects(atom_box,stereo_box))
187
+ indices = [i for i, x in enumerate(result) if x == 1]
188
+ if len(indices) == 1:
189
+ stereo_atoms[indices[0]]=1
190
+
191
+ molecule = dict()
192
+ molecule['graph'] = mol_graph
193
+ #molecule['atom_labels'] = atom_preds[image_idx]['preds'][0]
194
+ molecule['atom_labels'] = filtered_labels
195
+ molecule['atom_boxes'] = filtered_bboxes
196
+ molecule['stereo_atoms'] = stereo_atoms
197
+ molecule['charge_atoms'] = charge_atoms
198
+ mol_graphs.append(molecule)
199
+ base_path = pathlib.Path(args.data_path)
200
+ image_dir = base_path.joinpath("images")
201
+ smiles_dir = base_path.joinpath("smiles")
202
+ impath = image_dir.joinpath(f"{image_idx}.png")
203
+ smilespath = smiles_dir.joinpath(f"{image_idx}.txt")
204
+ save_mol_to_file(molecule,'molfile')
205
+ mol = Chem.MolFromMolFile('molfile',sanitize=False)
206
+ problematic = 0
207
+ try:
208
+ problems = Chem.DetectChemistryProblems(mol)
209
+ if len(problems) > 0:
210
+ mol = solve_mol_problems(mol,problems)
211
+ problematic = 1
212
+ #import ipdb; ipdb.set_trace()
213
+ try:
214
+ Chem.SanitizeMol(mol)
215
+ except:
216
+ problems = Chem.DetectChemistryProblems(mol)
217
+ if len(problems) > 0:
218
+ mol = solve_mol_problems(mol,problems)
219
+ try:
220
+ Chem.SanitizeMol(mol)
221
+ except:
222
+ pass
223
+ except:
224
+ problematic = 1
225
+ try:
226
+ pred_smiles = Chem.MolToSmiles(mol)
227
+ except:
228
+ pred_smiles = ""
229
+ problematic = 1
230
+ predictions+=1
231
+ predictions_list.append([image_idx,pred_smiles,problematic])
232
+ #import ipdb; ipdb.set_trace()
233
+ file_preds = open('preds_atomlenz','w')
234
+ for pred in predictions_list:
235
+ print(pred)
236
  #x = st.slider('Select a value')
237
  #st.write(x, 'squared is', x * x)