implementing app
Browse files
app.py
CHANGED
@@ -43,6 +43,28 @@ dir_list.sort(key=os.path.getctime, reverse=True)
|
|
43 |
checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
|
44 |
model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
|
45 |
model_atom.model.roi_heads.score_thresh = 0.65
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
data_cls = Objects_Smiles
|
47 |
dataset = data_cls(data_path="./uploads/", batch_size=1)
|
48 |
# dataset.prepare_data()
|
@@ -65,6 +87,9 @@ if image_file is not None:
|
|
65 |
dataset.prepare_data()
|
66 |
trainer = pl.Trainer(logger=False)
|
67 |
atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
|
|
|
|
|
|
|
68 |
#st.write(atom_preds)
|
69 |
plt.imshow(image, cmap="gray")
|
70 |
for bbox, label in zip(atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0]):
|
@@ -75,5 +100,15 @@ if image_file is not None:
|
|
75 |
plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
|
76 |
image_vis = Image.open("example_image.png")
|
77 |
col1.image(image_vis, use_column_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
#x = st.slider('Select a value')
|
79 |
#st.write(x, 'squared is', x * x)
|
|
|
43 |
checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
|
44 |
model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
|
45 |
model_atom.model.roi_heads.score_thresh = 0.65
|
46 |
+
experiment_path_bonds = "./models/bonds_model/"
|
47 |
+
dir_list = os.listdir(experiment_path_bonds)
|
48 |
+
dir_list = [os.path.join(experiment_path_bonds,f) for f in dir_list]
|
49 |
+
dir_list.sort(key=os.path.getctime, reverse=True)
|
50 |
+
checkpoint_file_bonds = [f for f in dir_list if "ckpt" in f][0]
|
51 |
+
model_bond = model_cls.load_from_checkpoint(checkpoint_file_bonds)
|
52 |
+
model_bond.model.roi_heads.score_thresh = 0.65
|
53 |
+
experiment_path_stereo = "./models/stereos_model/"
|
54 |
+
dir_list = os.listdir(experiment_path_stereo)
|
55 |
+
dir_list = [os.path.join(experiment_path_stereo,f) for f in dir_list]
|
56 |
+
dir_list.sort(key=os.path.getctime, reverse=True)
|
57 |
+
checkpoint_file_stereo = [f for f in dir_list if "ckpt" in f][0]
|
58 |
+
model_stereo = model_cls.load_from_checkpoint(checkpoint_file_stereo)
|
59 |
+
model_stereo.model.roi_heads.score_thresh = 0.65
|
60 |
+
experiment_path_charges = "./models/charges_model/"
|
61 |
+
dir_list = os.listdir(experiment_path_charges)
|
62 |
+
dir_list = [os.path.join(experiment_path_charges,f) for f in dir_list]
|
63 |
+
dir_list.sort(key=os.path.getctime, reverse=True)
|
64 |
+
checkpoint_file_charges = [f for f in dir_list if "ckpt" in f][0]
|
65 |
+
model_charge = model_cls.load_from_checkpoint(checkpoint_file_charges)
|
66 |
+
model_charge.model.roi_heads.score_thresh = 0.65
|
67 |
+
|
68 |
data_cls = Objects_Smiles
|
69 |
dataset = data_cls(data_path="./uploads/", batch_size=1)
|
70 |
# dataset.prepare_data()
|
|
|
87 |
dataset.prepare_data()
|
88 |
trainer = pl.Trainer(logger=False)
|
89 |
atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
|
90 |
+
bond_preds = trainer.predict(model_bond, dataset.test_dataloader())
|
91 |
+
stereo_preds = trainer.predict(model_stereo, dataset.test_dataloader())
|
92 |
+
charge_preds = trainer.predict(model_charge, dataset.test_dataloader())
|
93 |
#st.write(atom_preds)
|
94 |
plt.imshow(image, cmap="gray")
|
95 |
for bbox, label in zip(atom_preds[0]['boxes'][0], atom_preds[0]['preds'][0]):
|
|
|
100 |
plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
|
101 |
image_vis = Image.open("example_image.png")
|
102 |
col1.image(image_vis, use_column_width=True)
|
103 |
+
plt.clf()
|
104 |
+
plt.imshow(image, cmap="gray")
|
105 |
+
for bbox, label in zip(bond_preds[0]['boxes'][0], bond_preds[0]['preds'][0]):
|
106 |
+
# st.write(bbox)
|
107 |
+
# st.write(label)
|
108 |
+
plot_bbox(bbox, label)
|
109 |
+
plt.axis('off')
|
110 |
+
plt.savefig("example_image.png",bbox_inches='tight', pad_inches=0)
|
111 |
+
image_vis = Image.open("example_image.png")
|
112 |
+
col2.image(image_vis, use_column_width=True)
|
113 |
#x = st.slider('Select a value')
|
114 |
#st.write(x, 'squared is', x * x)
|