sunwaee committed on
Commit
a35c8ad
1 Parent(s): 1ba39f5

added streamlit app

Files changed (1)
  1. app.py +230 -0
app.py ADDED
@@ -0,0 +1,230 @@
+ import csv
+ import os.path
+ import time
+
+ import cv2
+ import gdown
+ import numpy as np
+ import streamlit as st
+ import torch
+
+
+ def load_classes(csv_reader):
+     """
+     Load classes from csv.
+
+     :param csv_reader: csv reader over 'class_name,class_id' rows
+     :return: dict mapping each class name to its class id
+     """
+     result = {}
+
+     for line, row in enumerate(csv_reader):
+         line += 1
+
+         try:
+             class_name, class_id = row
+         except ValueError:
+             raise ValueError('line {}: format should be \'class_name,class_id\''.format(line))
+         class_id = int(class_id)
+
+         if class_name in result:
+             raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
+         result[class_name] = class_id
+     return result
+
+
+ def draw_caption(image, box, caption):
+     """
+     Draw a caption above a bounding box on the image (in place).
+
+     :param image: image to draw on
+     :param box: bounding box as (x1, y1, x2, y2)
+     :param caption: caption text
+     :return:
+     """
+
+     b = np.array(box).astype(int)
+     # Black outline then white text so the caption stays readable on any background
+     cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 2)
+     cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1)
+
+
+ @st.cache
+ def load_labels():
+     """
+     Load labels.
+
+     :return: dict mapping each class id to its class name
+     """
+
+     with open("dataset/labels.csv", 'r') as f:
+         classes = load_classes(csv.reader(f, delimiter=','))
+
+     # Invert the {name: id} mapping into {id: name}
+     labels = {}
+     for key, value in classes.items():
+         labels[value] = key
+
+     return labels
+
+
+ def download_models(ids):
+     """
+     Download all models.
+
+     :param ids: dict mapping model names to their Google Drive file ids
+     :return:
+     """
+
+     # Download each model from drive if not stored locally
+     with st.spinner('Downloading models, this may take a minute...'):
+         for key in ids:
+             if not os.path.isfile(f"model/{key}.pt"):
+                 url = f"https://drive.google.com/uc?id={ids[key]}"
+                 gdown.download(url=url, output=f"model/{key}.pt")
+
+
+ @st.cache(suppress_st_warning=True)
+ def load_model(model_path, prefix: str = 'model/'):
+     """
+     Load model.
+
+     :param model_path: name of the inference model (without extension)
+     :param prefix: directory prefix prepended to the model name
+     :return: model in eval mode, on GPU if available
+     """
+
+     # Load model on GPU if available, otherwise unwrap DataParallel and keep it on CPU
+     if torch.cuda.is_available():
+         model = torch.load(f"{prefix}{model_path}.pt").to('cuda')
+     else:
+         model = torch.load(f"{prefix}{model_path}.pt", map_location=torch.device('cpu'))
+         model = model.module.cpu()
+     model.training = False
+     model.eval()
+
+     return model
+
+
+ def process_img(model, image, labels, caption: bool = True):
+     """
+     Run the model on an image and draw the detections on a copy of the original.
+
+     :param caption: whether to draw label captions or not
+     :param image: image to process
+     :param model: inference model
+     :param labels: dict mapping class ids to class names
+     :return: original-resolution image with detections drawn on it
+     """
+
+     image_orig = image.copy()
+     rows, cols, cns = image.shape
+
+     smallest_side = min(rows, cols)
+
+     # Rescale the image so its smallest side is min_side
+     min_side = 608
+     max_side = 1024
+     scale = min_side / smallest_side
+
+     # Check if the largest side is now greater than max_side
+     largest_side = max(rows, cols)
+
+     if largest_side * scale > max_side:
+         scale = max_side / largest_side
+
+     # Resize the image with the computed scale
+     image = cv2.resize(image, (int(round(cols * scale)), int(round(rows * scale))))
+     rows, cols, cns = image.shape
+
+     # Pad each side up to a multiple of 32, as expected by the network
+     pad_w = 32 - rows % 32
+     pad_h = 32 - cols % 32
+
+     new_image = np.zeros((rows + pad_w, cols + pad_h, cns)).astype(np.float32)
+     new_image[:rows, :cols, :] = image.astype(np.float32)
+     image = new_image.astype(np.float32)
+
+     # Normalize with ImageNet mean/std and move channels first
+     image /= 255
+     image -= [0.485, 0.456, 0.406]
+     image /= [0.229, 0.224, 0.225]
+     image = np.expand_dims(image, 0)
+     image = np.transpose(image, (0, 3, 1, 2))
+
+     # Per-label box colors (image is in RGB at this point)
+     colors = {
+         'with_mask': (0, 255, 0),
+         'without_mask': (255, 0, 0),
+         'mask_weared_incorrect': (190, 100, 20)
+     }
+
+     with torch.no_grad():
+
+         image = torch.from_numpy(image)
+         if torch.cuda.is_available():
+             image = image.cuda()
+
+         start = time.time()
+         scores, classification, transformed_anchors = model(image.float())
+         elapsed_time = time.time() - start
+         idxs = np.where(scores.cpu() > 0.5)
+
+         # Draw every detection above the 0.5 score threshold, scaled back to the original image
+         for j in range(idxs[0].shape[0]):
+             bbox = transformed_anchors[idxs[0][j], :]
+
+             x1 = int(bbox[0] / scale)
+             y1 = int(bbox[1] / scale)
+             x2 = int(bbox[2] / scale)
+             y2 = int(bbox[3] / scale)
+             label_name = labels[int(classification[idxs[0][j]])]
+             cap = label_name if caption else ''
+             draw_caption(image_orig, (x1, y1, x2, y2), cap)
+             cv2.rectangle(image_orig, (x1, y1), (x2, y2), color=colors[label_name], thickness=2)
+
+         # Overlay inference speed and device info in the top-left corner
+         cv2.putText(image_orig,
+                     f"{1 / elapsed_time:.1f} cuda:{str(torch.cuda.is_available()).lower()}",
+                     fontScale=1, fontFace=cv2.FONT_HERSHEY_PLAIN, org=(10, 20), color=(0, 255, 0))
+     return image_orig
+
+
+ # Page config
+ st.set_page_config(layout="centered")
+ st.sidebar.title("Face Mask Detection")
+
+ # Models drive ids
+ ids = {
+     'resnet50_20': st.secrets['resnet50'],
+     # 'resnet50_29': '1E_IOIuE5OpO4tQgTbXjdAmXR-9BCxxmT',
+     'resnet152_20': st.secrets['resnet152'],
+ }
+
+ # Download all models from drive
+ download_models(ids)
+
+ # Model selection
+ labels = load_labels()
+ model_path = st.selectbox('Choose a model', options=list(ids), index=0)
+ model = load_model(model_path=model_path) if model_path != '' else None
+
+ # Content
+ st.title('Face Mask Detection')
+ st.write('ResNet[18~152] trained for Face Mask Detection. ')
+ st.markdown("__Labels:__ with_mask, without_mask, mask_weared_incorrect")
+
+ # Display example selection
+ index = st.number_input('', min_value=0, max_value=852, value=495, help='Choose an image. ')
+
+ left, right = st.columns([3, 1])
+
+ # Get the corresponding validation image and convert it to RGB
+ image = cv2.imread(f'dataset/validation/image/maksssksksss{index}.jpg')
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+ # Process img
+ with st.spinner('Please wait while the image is being processed... This may take a while. '):
+     image = process_img(model, image, labels, caption=False)
+
+ left.image(image)
+
+ # Show the color legend and inference device on the right
+ right.write({
+     'green': 'with_mask',
+     'orange': 'mask_weared_incorrect',
+     'red': 'without_mask'
+ })
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ right.write(device)