implementing app
Browse files- Object_Smiles.py +157 -0
- app.py +8 -4
Object_Smiles.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import Dataset, DataLoader, Subset
|
2 |
+
from robust_detection.data_utils.rcnn_data_utils import *
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import robust_detection.transforms as T
|
5 |
+
|
6 |
+
DATA_FOLDER = os.path.join(os.path.dirname(__file__))
|
7 |
+
def get_transform():
|
8 |
+
transforms = []
|
9 |
+
transforms.append(T.ToTensor())
|
10 |
+
return T.Compose(transforms)
|
11 |
+
|
12 |
+
class Objects_Smiles(pl.LightningDataModule):
|
13 |
+
def __init__(self, data_path, **kwargs):
|
14 |
+
super().__init__()
|
15 |
+
self.batch_size = 1
|
16 |
+
self.num_workers = 4
|
17 |
+
self.data_path = data_path
|
18 |
+
self.transforms = get_transform()
|
19 |
+
self.base_class = Objects_Detection_Predictor_Dataset
|
20 |
+
def prepare_data(self):
|
21 |
+
dataset = self.base_class(os.path.join(DATA_FOLDER, self.data_path), self.transforms)
|
22 |
+
self.train = dataset
|
23 |
+
self.test = dataset
|
24 |
+
self.val = dataset
|
25 |
+
|
26 |
+
self.test_ood = dataset
|
27 |
+
|
28 |
+
def train_dataloader(self):
|
29 |
+
return DataLoader(
|
30 |
+
self.train,
|
31 |
+
batch_size=self.batch_size,
|
32 |
+
shuffle=True,
|
33 |
+
num_workers=self.num_workers,
|
34 |
+
drop_last=False,
|
35 |
+
pin_memory=True,
|
36 |
+
collate_fn=collate_tuple
|
37 |
+
)
|
38 |
+
|
39 |
+
def val_dataloader(self):
|
40 |
+
return DataLoader(
|
41 |
+
self.val,
|
42 |
+
batch_size=self.batch_size,
|
43 |
+
shuffle=False,
|
44 |
+
num_workers=self.num_workers,
|
45 |
+
drop_last=False,
|
46 |
+
pin_memory=True,
|
47 |
+
collate_fn=collate_tuple
|
48 |
+
)
|
49 |
+
|
50 |
+
def test_dataloader(self):
|
51 |
+
return DataLoader(
|
52 |
+
self.test,
|
53 |
+
batch_size=self.batch_size,
|
54 |
+
shuffle=False,
|
55 |
+
num_workers=self.num_workers,
|
56 |
+
drop_last=False,
|
57 |
+
pin_memory=True,
|
58 |
+
collate_fn=collate_tuple
|
59 |
+
)
|
60 |
+
|
61 |
+
def test_ood_dataloader(self, shuffle=False):
|
62 |
+
return DataLoader(
|
63 |
+
self.test_ood,
|
64 |
+
batch_size=self.batch_size,
|
65 |
+
shuffle=shuffle,
|
66 |
+
num_workers=self.num_workers,
|
67 |
+
drop_last=False,
|
68 |
+
pin_memory=True,
|
69 |
+
collate_fn=collate_tuple
|
70 |
+
)
|
71 |
+
|
72 |
+
@classmethod
|
73 |
+
def add_dataset_specific_args(cls, parent):
|
74 |
+
import argparse
|
75 |
+
parser = argparse.ArgumentParser(parents=[parent], add_help=False)
|
76 |
+
parser.add_argument('--data_path', type=str,
|
77 |
+
default="mnist/alldigits/")
|
78 |
+
return parser
|
79 |
+
|
80 |
+
|
81 |
+
class Objects_fold_Smiles(pl.LightningDataModule):
|
82 |
+
def __init__(self, data_path, fold, **kwargs):
|
83 |
+
super().__init__()
|
84 |
+
self.batch_size = 1
|
85 |
+
self.num_workers = 4
|
86 |
+
self.data_path = data_path
|
87 |
+
self.fold = fold
|
88 |
+
self.transforms = get_transform()
|
89 |
+
# self.base_class = Objects_Detection_Predictor_Dataset
|
90 |
+
self.base_class = Objects_Detection_Dataset
|
91 |
+
def prepare_data(self):
|
92 |
+
dataset = self.base_class(os.path.join(DATA_FOLDER, self.data_path), self.transforms)
|
93 |
+
if self.fold > -1:
|
94 |
+
train_idx = np.load(os.path.join(DATA_FOLDER, f"{self.data_path}", "../folds", str(self.fold), "train_idx.npy"))
|
95 |
+
self.train = Subset(dataset, train_idx)
|
96 |
+
val_idx = np.load(os.path.join(DATA_FOLDER, f"{self.data_path}", "../folds", str(self.fold), "val_idx.npy"))
|
97 |
+
|
98 |
+
self.val = Subset(dataset, val_idx)
|
99 |
+
else:
|
100 |
+
self.train = dataset
|
101 |
+
self.val = dataset
|
102 |
+
self.test = self.val
|
103 |
+
self.test_ood = self.test
|
104 |
+
|
105 |
+
def train_dataloader(self):
|
106 |
+
return DataLoader(
|
107 |
+
self.train,
|
108 |
+
batch_size=self.batch_size,
|
109 |
+
shuffle=True,
|
110 |
+
num_workers=self.num_workers,
|
111 |
+
drop_last=False,
|
112 |
+
pin_memory=True,
|
113 |
+
collate_fn=collate_tuple
|
114 |
+
)
|
115 |
+
|
116 |
+
def val_dataloader(self):
|
117 |
+
return DataLoader(
|
118 |
+
self.val,
|
119 |
+
batch_size=self.batch_size,
|
120 |
+
shuffle=False,
|
121 |
+
num_workers=self.num_workers,
|
122 |
+
drop_last=False,
|
123 |
+
pin_memory=True,
|
124 |
+
collate_fn=collate_tuple
|
125 |
+
)
|
126 |
+
|
127 |
+
def test_dataloader(self):
|
128 |
+
return DataLoader(
|
129 |
+
self.test,
|
130 |
+
batch_size=self.batch_size,
|
131 |
+
shuffle=False,
|
132 |
+
num_workers=self.num_workers,
|
133 |
+
drop_last=False,
|
134 |
+
pin_memory=True,
|
135 |
+
collate_fn=collate_tuple
|
136 |
+
)
|
137 |
+
|
138 |
+
def test_ood_dataloader(self, shuffle=False):
|
139 |
+
return DataLoader(
|
140 |
+
self.test_ood,
|
141 |
+
batch_size=self.batch_size,
|
142 |
+
shuffle=shuffle,
|
143 |
+
num_workers=self.num_workers,
|
144 |
+
drop_last=False,
|
145 |
+
pin_memory=True,
|
146 |
+
collate_fn=collate_tuple
|
147 |
+
)
|
148 |
+
|
149 |
+
@classmethod
|
150 |
+
def add_dataset_specific_args(cls, parent):
|
151 |
+
import argparse
|
152 |
+
parser = argparse.ArgumentParser(parents=[parent], add_help=False)
|
153 |
+
parser.add_argument('--data_path', type=str,
|
154 |
+
default="mnist/alldigits/")
|
155 |
+
parser.add_argument('--fold', type=int,
|
156 |
+
default=0)
|
157 |
+
return parser
|
app.py
CHANGED
@@ -7,7 +7,7 @@ import numpy as np
|
|
7 |
#import pathlib
|
8 |
#from AtomLenz import *
|
9 |
#from utils_graph import *
|
10 |
-
|
11 |
|
12 |
#from robust_detection import wandb_config
|
13 |
from robust_detection import utils
|
@@ -33,8 +33,9 @@ dir_list.sort(key=os.path.getctime, reverse=True)
|
|
33 |
checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
|
34 |
model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
|
35 |
model_atom.model.roi_heads.score_thresh = 0.65
|
36 |
-
|
37 |
-
|
|
|
38 |
st.title("Atom Level Entity Detector")
|
39 |
|
40 |
image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png','jpeg','jpg'])
|
@@ -45,7 +46,10 @@ if image_file is not None:
|
|
45 |
image = Image.open(image_file)
|
46 |
col1.image(image, use_column_width=True)
|
47 |
with open(os.path.join("uploads",image_file.name),"wb") as f:
|
48 |
-
f.write(image_file.getbuffer())
|
49 |
st.success("Saved File")
|
|
|
|
|
|
|
50 |
x = st.slider('Select a value')
|
51 |
st.write(x, 'squared is', x * x)
|
|
|
7 |
#import pathlib
|
8 |
#from AtomLenz import *
|
9 |
#from utils_graph import *
|
10 |
+
from Object_Smiles import Objects_Smiles
|
11 |
|
12 |
#from robust_detection import wandb_config
|
13 |
from robust_detection import utils
|
|
|
33 |
checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
|
34 |
model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
|
35 |
model_atom.model.roi_heads.score_thresh = 0.65
|
36 |
+
data_cls = Objects_Smiles
|
37 |
+
dataset = data_cls(datapath="./uploads/")
|
38 |
+
# dataset.prepare_data()
|
39 |
st.title("Atom Level Entity Detector")
|
40 |
|
41 |
image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png','jpeg','jpg'])
|
|
|
46 |
image = Image.open(image_file)
|
47 |
col1.image(image, use_column_width=True)
|
48 |
with open(os.path.join("uploads",image_file.name),"wb") as f:
|
49 |
+
f.write(image_file.getbuffer())
|
50 |
st.success("Saved File")
|
51 |
+
dataset.prepare_data()
|
52 |
+
trainer = pl.Trainer(logger=False)
|
53 |
+
atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
|
54 |
x = st.slider('Select a value')
|
55 |
st.write(x, 'squared is', x * x)
|