moldenhof commited on
Commit
09fa344
1 Parent(s): c4f48f9

implementing app

Browse files
Files changed (1) hide show
  1. app.py +35 -0
app.py CHANGED
@@ -43,6 +43,28 @@ dir_list.sort(key=os.path.getctime, reverse=True)
43
  checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
44
  model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
45
  model_atom.model.roi_heads.score_thresh = 0.65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  data_cls = Objects_Smiles
47
  dataset = data_cls(data_path="./uploads/", batch_size=1)
48
  # dataset.prepare_data()
@@ -65,6 +87,9 @@ if image_file is not None:
65
  dataset.prepare_data()
66
  trainer = pl.Trainer(logger=False)
67
  atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
 
 
 
68
  #st.write(atom_preds)
69
  plt.imshow(image, cmap="gray")
70
  for bbox, label in zip(atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0]):
@@ -75,5 +100,15 @@ if image_file is not None:
75
  plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
76
  image_vis = Image.open("example_image.png")
77
  col1.image(image_vis, use_column_width=True)
 
 
 
 
 
 
 
 
 
 
78
  #x = st.slider('Select a value')
79
  #st.write(x, 'squared is', x * x)
 
43
  checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
44
  model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
45
  model_atom.model.roi_heads.score_thresh = 0.65
46
+ experiment_path_bonds = "./models/bonds_model/"
47
+ dir_list = os.listdir(experiment_path_bonds)
48
+ dir_list = [os.path.join(experiment_path_bonds,f) for f in dir_list]
49
+ dir_list.sort(key=os.path.getctime, reverse=True)
50
+ checkpoint_file_bonds = [f for f in dir_list if "ckpt" in f][0]
51
+ model_bond = model_cls.load_from_checkpoint(checkpoint_file_bonds)
52
+ model_bond.model.roi_heads.score_thresh = 0.65
53
+ experiment_path_stereo = "./models/stereos_model/"
54
+ dir_list = os.listdir(experiment_path_stereo)
55
+ dir_list = [os.path.join(experiment_path_stereo,f) for f in dir_list]
56
+ dir_list.sort(key=os.path.getctime, reverse=True)
57
+ checkpoint_file_stereo = [f for f in dir_list if "ckpt" in f][0]
58
+ model_stereo = model_cls.load_from_checkpoint(checkpoint_file_stereo)
59
+ model_stereo.model.roi_heads.score_thresh = 0.65
60
+ experiment_path_charges = "./models/charges_model/"
61
+ dir_list = os.listdir(experiment_path_charges)
62
+ dir_list = [os.path.join(experiment_path_charges,f) for f in dir_list]
63
+ dir_list.sort(key=os.path.getctime, reverse=True)
64
+ checkpoint_file_charges = [f for f in dir_list if "ckpt" in f][0]
65
+ model_charge = model_cls.load_from_checkpoint(checkpoint_file_charges)
66
+ model_charge.model.roi_heads.score_thresh = 0.65
67
+
68
  data_cls = Objects_Smiles
69
  dataset = data_cls(data_path="./uploads/", batch_size=1)
70
  # dataset.prepare_data()
 
87
  dataset.prepare_data()
88
  trainer = pl.Trainer(logger=False)
89
  atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
90
+ bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
91
+ stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
92
+ charge_preds = trainer.predict(model_charge, dataset.test_dataloader())
93
  #st.write(atom_preds)
94
  plt.imshow(image, cmap="gray")
95
  for bbox, label in zip(atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0]):
 
100
  plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
101
  image_vis = Image.open("example_image.png")
102
  col1.image(image_vis, use_column_width=True)
103
+ plt.clf()
104
+ plt.imshow(image, cmap="gray")
105
+ for bbox, label in zip(bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0]):
106
+ # st.write(bbox)
107
+ # st.write(label)
108
+ plot_bbox(bbox, label)
109
+ plt.axis('off')
110
+ plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
111
+ image_vis = Image.open("example_image.png")
112
+ col2.image(image_vis, use_column_width=True)
113
  #x = st.slider('Select a value')
114
  #st.write(x, 'squared is', x * x)