moldenhof committed
Commit d0f68bc
1 Parent(s): 71c5aac

implementing app

Files changed (2)
  1. Object_Smiles.py +157 -0
  2. app.py +8 -4
Object_Smiles.py ADDED
@@ -0,0 +1,157 @@
+from torch.utils.data import Dataset, DataLoader, Subset
+from robust_detection.data_utils.rcnn_data_utils import *
+import pytorch_lightning as pl
+import robust_detection.transforms as T
+
+DATA_FOLDER = os.path.join(os.path.dirname(__file__))
+def get_transform():
+    transforms = []
+    transforms.append(T.ToTensor())
+    return T.Compose(transforms)
+
+class Objects_Smiles(pl.LightningDataModule):
+    def __init__(self, data_path, **kwargs):
+        super().__init__()
+        self.batch_size = 1
+        self.num_workers = 4
+        self.data_path = data_path
+        self.transforms = get_transform()
+        self.base_class = Objects_Detection_Predictor_Dataset
+    def prepare_data(self):
+        dataset = self.base_class(os.path.join(DATA_FOLDER, self.data_path), self.transforms)
+        self.train = dataset
+        self.test = dataset
+        self.val = dataset
+
+        self.test_ood = dataset
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+            collate_fn=collate_tuple
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+            collate_fn=collate_tuple
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+            collate_fn=collate_tuple
+        )
+
+    def test_ood_dataloader(self, shuffle=False):
+        return DataLoader(
+            self.test_ood,
+            batch_size=self.batch_size,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+            collate_fn=collate_tuple
+        )
+
+    @classmethod
+    def add_dataset_specific_args(cls, parent):
+        import argparse
+        parser = argparse.ArgumentParser(parents=[parent], add_help=False)
+        parser.add_argument('--data_path', type=str,
+                            default="mnist/alldigits/")
+        return parser
+
+
+class Objects_fold_Smiles(pl.LightningDataModule):
+    def __init__(self, data_path, fold, **kwargs):
+        super().__init__()
+        self.batch_size = 1
+        self.num_workers = 4
+        self.data_path = data_path
+        self.fold = fold
+        self.transforms = get_transform()
+        # self.base_class = Objects_Detection_Predictor_Dataset
+        self.base_class = Objects_Detection_Dataset
+    def prepare_data(self):
+        dataset = self.base_class(os.path.join(DATA_FOLDER, self.data_path), self.transforms)
+        if self.fold > -1:
+            train_idx = np.load(os.path.join(DATA_FOLDER, f"{self.data_path}", "../folds", str(self.fold), "train_idx.npy"))
+            self.train = Subset(dataset, train_idx)
+            val_idx = np.load(os.path.join(DATA_FOLDER, f"{self.data_path}", "../folds", str(self.fold), "val_idx.npy"))
+
+            self.val = Subset(dataset, val_idx)
+        else:
+            self.train = dataset
+            self.val = dataset
+        self.test = self.val
+        self.test_ood = self.test
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+            collate_fn=collate_tuple
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+            collate_fn=collate_tuple
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+            collate_fn=collate_tuple
+        )
+
+    def test_ood_dataloader(self, shuffle=False):
+        return DataLoader(
+            self.test_ood,
+            batch_size=self.batch_size,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+            drop_last=False,
+            pin_memory=True,
+            collate_fn=collate_tuple
+        )
+
+    @classmethod
+    def add_dataset_specific_args(cls, parent):
+        import argparse
+        parser = argparse.ArgumentParser(parents=[parent], add_help=False)
+        parser.add_argument('--data_path', type=str,
+                            default="mnist/alldigits/")
+        parser.add_argument('--fold', type=int,
+                            default=0)
+        return parser
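
Object_Smiles.py adds two Lightning data modules: Objects_Smiles points every split (train/val/test/test_ood) at the same folder of images for inference, while Objects_fold_Smiles builds train/val subsets from precomputed ../folds/<fold>/train_idx.npy and val_idx.npy index files via torch.utils.data.Subset. A minimal usage sketch for the inference module follows; it assumes the wildcard import from rcnn_data_utils supplies os, numpy, Objects_Detection_Predictor_Dataset and collate_tuple, and that ./uploads/ contains at least one image.

    # Hypothetical usage sketch, not part of the commit.
    from Object_Smiles import Objects_Smiles

    dm = Objects_Smiles(data_path="./uploads/")   # batch_size=1 and num_workers=4 are hard-coded above
    dm.prepare_data()                             # one dataset instance is reused for every split

    # All four *_dataloader() methods wrap that same dataset; only shuffle differs.
    for batch in dm.test_dataloader():
        images, targets = batch                   # assumes collate_tuple zips samples into (images, targets)
        break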
app.py CHANGED
@@ -7,7 +7,7 @@ import numpy as np
 #import pathlib
 #from AtomLenz import *
 #from utils_graph import *
-#from Object_Smiles import Objects_Smiles
+from Object_Smiles import Objects_Smiles
 
 #from robust_detection import wandb_config
 from robust_detection import utils
@@ -33,8 +33,9 @@ dir_list.sort(key=os.path.getctime, reverse=True)
 checkpoint_file_atoms = [f for f in dir_list if "ckpt" in f][0]
 model_atom = model_cls.load_from_checkpoint(checkpoint_file_atoms)
 model_atom.model.roi_heads.score_thresh = 0.65
-
-
+data_cls = Objects_Smiles
+dataset = data_cls(data_path="./uploads/")
+# dataset.prepare_data()
 st.title("Atom Level Entity Detector")
 
 image_file = st.file_uploader("Upload a chemical structure candidate image",type=['png','jpeg','jpg'])
@@ -45,7 +46,10 @@ if image_file is not None:
     image = Image.open(image_file)
     col1.image(image, use_column_width=True)
     with open(os.path.join("uploads",image_file.name),"wb") as f:
-        f.write(image_file.getbuffer())
+        f.write(image_file.getbuffer())
     st.success("Saved File")
+    dataset.prepare_data()
+    trainer = pl.Trainer(logger=False)
+    atom_preds = trainer.predict(model_atom, dataset.test_dataloader())
     x = st.slider('Select a value')
     st.write(x, 'squared is', x * x)
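
The new lines wire the uploaded file into the data module and run inference with a logger-free pl.Trainer (this assumes pytorch_lightning is already imported as pl earlier in app.py). trainer.predict returns a list with one entry per image in ./uploads/, since the batch size is fixed at 1. How each entry looks depends on the model's predict_step in robust_detection, which this diff does not show; the sketch below assumes the usual torchvision detection dict with boxes, labels and scores.

    # Hypothetical post-processing sketch; the prediction format is an assumption.
    for pred in atom_preds:                                   # one entry per uploaded image
        det = pred[0] if isinstance(pred, (list, tuple)) else pred
        if isinstance(det, dict):                             # torchvision-style detection output (assumed)
            st.write(f"{len(det['boxes'])} atom-level boxes kept above the 0.65 score threshold")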