vijul.shah committed
Commit f0adec0 · 1 Parent(s): 0f2d9f6

Added video upload support. Need to optimize and add new features
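
For context, the helpers this commit adds can be exercised outside Streamlit roughly as follows. A minimal sketch, assuming `app.py` is importable (with `cv2` and PIL installed) and using a hypothetical local `sample.mp4` as a stand-in for the uploaded file:

```python
# Sketch only: extract_frames / is_video / resize_frame come from this commit;
# "sample.mp4" is a hypothetical stand-in for the Streamlit upload.
from app import extract_frames, is_video, resize_frame

video_path = "sample.mp4"
ext = video_path.split(".")[-1]

if is_video(ext):
    frames = extract_frames(video_path)  # list of RGB numpy arrays, one per frame
    # Note: resize_frame ignores its max_width/max_height arguments; it applies
    # the same fixed thresholds as the image branch (640x480 cap, 256x256 for
    # large square frames) via PIL's thumbnail(), which only ever shrinks.
    input_imgs = [resize_frame(frame, 640, 480) for frame in frames]
    print(f"Extracted {len(frames)} frames; first resized to {input_imgs[0].size}")
```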

Files changed (1)
  1. app.py +273 -207
app.py CHANGED
@@ -9,6 +9,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 import streamlit as st
 import torch
+import tempfile
 from PIL import Image
 from torchvision import models
 from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
@@ -62,6 +63,53 @@ def _load_model(model_configs, device="cpu"):
     return model
 
 
+def extract_frames(video_path):
+    vidcap = cv2.VideoCapture(video_path)
+    frames = []
+    success, image = vidcap.read()
+    count = 0
+    while success:
+        # Convert the frame to RGB (cv2 uses BGR by default)
+        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        frames.append(image_rgb)
+        success, image = vidcap.read()
+        count += 1
+    vidcap.release()
+    return frames
+
+
+# Function to check if a file is an image
+def is_image(file_extension):
+    return file_extension.lower() in ["png", "jpeg", "jpg"]
+
+
+# Function to check if a file is a video
+def is_video(file_extension):
+    return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]
+
+
+def resize_frame(frame, max_width, max_height):
+    image = Image.fromarray(frame)
+    original_size = image.size
+
+    # Resize the frame similarly to the image resizing logic
+    if original_size[0] == original_size[1] and original_size[0] >= 256:
+        max_size = (256, 256)
+    else:
+        max_size = list(original_size)
+        if original_size[0] >= 640:
+            max_size[0] = 640
+        elif original_size[0] < 64:
+            max_size[0] = 64
+        if original_size[1] >= 480:
+            max_size[1] = 480
+        elif original_size[1] < 32:
+            max_size[1] = 32
+
+    image.thumbnail(max_size)
+    return image
+
+
 def main():
     # Wide mode
     st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
@@ -84,37 +132,69 @@ def main():
     st.set_option("deprecation.showfileUploaderEncoding", False)
     # Choose your own image
     uploaded_file = st.sidebar.file_uploader(
-        "Upload Image", type=["png", "jpeg", "jpg"]
+        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
     )
     if uploaded_file is not None:
-        input_img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
-        # print("input_img before = ", input_img.size)
-        max_size = [input_img.size[0], input_img.size[1]]
-        cols[0].text(f"Input Image: {max_size[0]} x {max_size[1]}")
-        if input_img.size[0] == input_img.size[1] and input_img.size[0] >= 256:
-            max_size[0] = 256
-            max_size[1] = 256
-        else:
-            if input_img.size[0] >= 640:
-                max_size[0] = 640
-            elif input_img.size[0] < 64:
-                max_size[0] = 64
-            if input_img.size[1] >= 480:
-                max_size[1] = 480
-            elif input_img.size[1] < 32:
-                max_size[1] = 32
-        input_img.thumbnail((max_size[0], max_size[1]))  # Bicubic resampling
-        # print("input_img after = ", input_img.size)
-        # cols[0].image(input_img)
-        fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
-        # Display the input image
-        axs0.imshow(input_img)
-        axs0.axis("off")
-        axs0.set_title("Input Image")
-
-        # Display the plot
-        cols[0].pyplot(fig0)
-        cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
+        # Get file extension
+        file_extension = uploaded_file.name.split(".")[-1]
+        input_imgs = []
+
+        if is_image(file_extension):
+            input_img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
+            # print("input_img before = ", input_img.size)
+            max_size = [input_img.size[0], input_img.size[1]]
+            cols[0].text(f"Input Image: {max_size[0]} x {max_size[1]}")
+            if input_img.size[0] == input_img.size[1] and input_img.size[0] >= 256:
+                max_size[0] = 256
+                max_size[1] = 256
+            else:
+                if input_img.size[0] >= 640:
+                    max_size[0] = 640
+                elif input_img.size[0] < 64:
+                    max_size[0] = 64
+                if input_img.size[1] >= 480:
+                    max_size[1] = 480
+                elif input_img.size[1] < 32:
+                    max_size[1] = 32
+            input_img.thumbnail((max_size[0], max_size[1]))  # Bicubic resampling
+            input_imgs.append(input_img)
+            # print("input_img after = ", input_img.size)
+            # cols[0].image(input_img)
+            fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
+            # Display the input image
+            axs0.imshow(input_imgs[0])
+            axs0.axis("off")
+            axs0.set_title("Input Image")
+
+            # Display the plot
+            cols[0].pyplot(fig0)
+            cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
+
+            # TODO: show the face features extracted from the image under 'input image' column
+        elif is_video(file_extension):
+            tfile = tempfile.NamedTemporaryFile(delete=False)
+            tfile.write(uploaded_file.read())
+            video_path = tfile.name
+
+            # Extract frames from the video
+            frames = extract_frames(video_path)
+            print(f"Extracted {len(frames)} frames from the video")
+
+            # Process the frames
+            for i, frame in enumerate(frames):
+                input_imgs.append(resize_frame(frame, 640, 480))
+
+            os.remove(video_path)
+
+            fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
+            # Display the input image
+            axs0.imshow(input_imgs[0])
+            axs0.axis("off")
+            axs0.set_title("Input Image")
+
+            # Display the plot
+            cols[0].pyplot(fig0)
+            # cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
 
     st.sidebar.title("Setup")
 
@@ -170,196 +250,182 @@ def main():
 
     else:
         with st.spinner("Analyzing..."):
-            if upscale == "-":
-                sr_configs = None
-            else:
-                sr_configs = {
-                    "method": upscale_method_or_model,
-                    "params": {"upscale": upscale},
-                }
-            config_file = {
-                "sr_configs": sr_configs,
-                "feature_extraction_configs": {
-                    "blink_detection": False,
-                    "upscale": upscale,
-                    "extraction_library": "mediapipe",
-                },
-            }
-
-            img = np.array(input_img)
-            # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-            # if img.shape[0] > max_size or img.shape[1] > max_size:
-            #     img = cv2.resize(img, (max_size, max_size))
-
-            ds_results = EyeDentityDatasetCreation(
-                feature_extraction_configs=config_file[
-                    "feature_extraction_configs"
-                ],
-                sr_configs=config_file["sr_configs"],
-            )(img)
-            # if ds_results is not None:
-            #     print("ds_results = ", ds_results.keys())
-
-            preprocess_steps = [
-                transforms.ToTensor(),
-                transforms.Resize(
-                    [32, 64],
-                    # interpolation=transforms.InterpolationMode.BILINEAR,
-                    interpolation=transforms.InterpolationMode.BICUBIC,
-                    antialias=True,
-                ),
-            ]
-            preprocess_function = transforms.Compose(preprocess_steps)
-
-            left_eye = None
-            right_eye = None
-
-            if ds_results is None:
-                # print("type of input_img = ", type(input_img))
-                input_img = preprocess_function(input_img)
-                input_img = input_img.unsqueeze(0)
-                if pupil_selection == "left_pupil":
-                    left_eye = input_img
-                elif pupil_selection == "right_pupil":
-                    right_eye = input_img
-                else:
-                    left_eye = input_img
-                    right_eye = input_img
-                # print("type of left_eye = ", type(left_eye))
-                # print("type of right_eye = ", type(right_eye))
-            elif "eyes" in ds_results.keys():
-                if (
-                    "left_eye" in ds_results["eyes"].keys()
-                    and ds_results["eyes"]["left_eye"] is not None
-                ):
-                    left_eye = ds_results["eyes"]["left_eye"]
-                    # print("type of left_eye = ", type(left_eye))
-                    left_eye = to_pil_image(left_eye).convert("RGB")
-                    # print("type of left_eye = ", type(left_eye))
-
-                    left_eye = preprocess_function(left_eye)
-                    # print("type of left_eye = ", type(left_eye))
-
-                    left_eye = left_eye.unsqueeze(0)
-                if (
-                    "right_eye" in ds_results["eyes"].keys()
-                    and ds_results["eyes"]["right_eye"] is not None
-                ):
-                    right_eye = ds_results["eyes"]["right_eye"]
-                    # print("type of right_eye = ", type(right_eye))
-                    right_eye = to_pil_image(right_eye).convert("RGB")
-                    # print("type of right_eye = ", type(right_eye))
-
-                    right_eye = preprocess_function(right_eye)
-                    # print("type of right_eye = ", type(right_eye))
-
-                    right_eye = right_eye.unsqueeze(0)
-            else:
-                # print("type of input_img = ", type(input_img))
-                input_img = preprocess_function(input_img)
-                input_img = input_img.unsqueeze(0)
-                if pupil_selection == "left_pupil":
-                    left_eye = input_img
-                elif pupil_selection == "right_pupil":
-                    right_eye = input_img
-                else:
-                    left_eye = input_img
-                    right_eye = input_img
-                # print("type of left_eye = ", type(left_eye))
-                # print("type of right_eye = ", type(right_eye))
-
-            # print("left_eye = ", left_eye.shape)
-            # print("right_eye = ", right_eye.shape)
-
-            if pupil_selection == "-":
-                selected_eyes = ["left_eye", "right_eye"]
-            elif pupil_selection == "left_pupil":
-                selected_eyes = ["left_eye"]
-            elif pupil_selection == "right_pupil":
-                selected_eyes = ["right_eye"]
-
-            for eye_type in selected_eyes:
-
-                model_configs = {
-                    "model_path": root_path
-                    + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
-                    "registered_model_name": tv_model,
-                    "num_classes": 1,
-                }
-                registered_model_name = model_configs["registered_model_name"]
-                model = _load_model(model_configs)
-
-                if registered_model_name == "ResNet18":
-                    target_layer = model.resnet.layer4[-1].conv2
-                elif registered_model_name == "ResNet50":
-                    target_layer = model.resnet.layer4[-1].conv3
-                else:
-                    raise Exception(
-                        f"No target layer available for selected model: {registered_model_name}"
-                    )
-
-                if left_eye is not None and eye_type == "left_eye":
-                    input_img = left_eye
-                elif right_eye is not None and eye_type == "right_eye":
-                    input_img = right_eye
-                else:
-                    raise Exception("Wrong Data")
-
-                if cam_method is not None:
-                    cam_extractor = torchcam_methods.__dict__[cam_method](
-                        model,
-                        target_layer=target_layer,
-                        fc_layer=model.resnet.fc,
-                        input_shape=input_img.shape,
-                    )
-
-                # with torch.no_grad():
-                out = model(input_img)
-                cols[-1].markdown(
-                    f"<h3>Predicted Pupil Diameter: {out[0].item():.2f} mm</h3>",
-                    unsafe_allow_html=True,
-                )
-                # cols[-1].text(f"Predicted Pupil Diameter: {out[0].item():.2f}")
-
-                # Retrieve the CAM
-                act_maps = cam_extractor(0, out)
-
-                # Fuse the CAMs if there are several
-                activation_map = (
-                    act_maps[0]
-                    if len(act_maps) == 1
-                    else cam_extractor.fuse_cams(act_maps)
-                )
-
-                # Convert input image and activation map to PIL images
-                input_image_pil = to_pil_image(input_img.squeeze(0))
-                activation_map_pil = to_pil_image(activation_map, mode="F")
-
-                # Create the overlayed CAM result
-                result = overlay_mask(
-                    input_image_pil,
-                    activation_map_pil,
-                    alpha=0.5,
-                )
-
-                # Create a subplot with 1 row and 2 columns
-                fig, axs = plt.subplots(1, 2, figsize=(10, 5))
-
-                # Display the input image
-                axs[0].imshow(input_image_pil)
-                axs[0].axis("off")
-                axs[0].set_title("Input Image")
-
-                # Display the overlayed CAM result
-                axs[1].imshow(result)
-                axs[1].axis("off")
-                axs[1].set_title("Overlayed CAM")
-
-                # Display the plot
-                cols[-1].pyplot(fig)
-                cols[-1].text(
-                    f"eye image size: {input_img.shape[-1]} x {input_img.shape[-2]}"
-                )
+            model = None
+            for input_img in input_imgs:
+                if upscale == "-":
+                    sr_configs = None
+                else:
+                    sr_configs = {
+                        "method": upscale_method_or_model,
+                        "params": {"upscale": upscale},
+                    }
+                config_file = {
+                    "sr_configs": sr_configs,
+                    "feature_extraction_configs": {
+                        "blink_detection": False,
+                        "upscale": upscale,
+                        "extraction_library": "mediapipe",
+                    },
+                }
+
+                img = np.array(input_img)
+                # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+                # if img.shape[0] > max_size or img.shape[1] > max_size:
+                #     img = cv2.resize(img, (max_size, max_size))
+
+                ds_results = EyeDentityDatasetCreation(
+                    feature_extraction_configs=config_file["feature_extraction_configs"],
+                    sr_configs=config_file["sr_configs"],
+                )(img)
+                # if ds_results is not None:
+                #     print("ds_results = ", ds_results.keys())
+
+                preprocess_steps = [
+                    transforms.ToTensor(),
+                    transforms.Resize(
+                        [32, 64],
+                        # interpolation=transforms.InterpolationMode.BILINEAR,
+                        interpolation=transforms.InterpolationMode.BICUBIC,
+                        antialias=True,
+                    ),
+                ]
+                preprocess_function = transforms.Compose(preprocess_steps)
+
+                left_eye = None
+                right_eye = None
+
+                if ds_results is None:
+                    # print("type of input_img = ", type(input_img))
+                    input_img = preprocess_function(input_img)
+                    input_img = input_img.unsqueeze(0)
+                    if pupil_selection == "left_pupil":
+                        left_eye = input_img
+                    elif pupil_selection == "right_pupil":
+                        right_eye = input_img
+                    else:
+                        left_eye = input_img
+                        right_eye = input_img
+                    # print("type of left_eye = ", type(left_eye))
+                    # print("type of right_eye = ", type(right_eye))
+                elif "eyes" in ds_results.keys():
+                    if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
+                        left_eye = ds_results["eyes"]["left_eye"]
+                        # print("type of left_eye = ", type(left_eye))
+                        left_eye = to_pil_image(left_eye).convert("RGB")
+                        # print("type of left_eye = ", type(left_eye))
+
+                        left_eye = preprocess_function(left_eye)
+                        # print("type of left_eye = ", type(left_eye))
+
+                        left_eye = left_eye.unsqueeze(0)
+                    if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
+                        right_eye = ds_results["eyes"]["right_eye"]
+                        # print("type of right_eye = ", type(right_eye))
+                        right_eye = to_pil_image(right_eye).convert("RGB")
+                        # print("type of right_eye = ", type(right_eye))
+
+                        right_eye = preprocess_function(right_eye)
+                        # print("type of right_eye = ", type(right_eye))
+
+                        right_eye = right_eye.unsqueeze(0)
+                else:
+                    # print("type of input_img = ", type(input_img))
+                    input_img = preprocess_function(input_img)
+                    input_img = input_img.unsqueeze(0)
+                    if pupil_selection == "left_pupil":
+                        left_eye = input_img
+                    elif pupil_selection == "right_pupil":
+                        right_eye = input_img
+                    else:
+                        left_eye = input_img
+                        right_eye = input_img
+                    # print("type of left_eye = ", type(left_eye))
+                    # print("type of right_eye = ", type(right_eye))
+
+                # print("left_eye = ", left_eye.shape)
+                # print("right_eye = ", right_eye.shape)
+
+                if pupil_selection == "-":
+                    selected_eyes = ["left_eye", "right_eye"]
+                elif pupil_selection == "left_pupil":
+                    selected_eyes = ["left_eye"]
+                elif pupil_selection == "right_pupil":
+                    selected_eyes = ["right_eye"]
+
+                for eye_type in selected_eyes:
+
+                    if model is None:
+                        model_configs = {
+                            "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
+                            "registered_model_name": tv_model,
+                            "num_classes": 1,
+                        }
+                        registered_model_name = model_configs["registered_model_name"]
+                        model = _load_model(model_configs)
+
+                    if registered_model_name == "ResNet18":
+                        target_layer = model.resnet.layer4[-1].conv2
+                    elif registered_model_name == "ResNet50":
+                        target_layer = model.resnet.layer4[-1].conv3
+                    else:
+                        raise Exception(f"No target layer available for selected model: {registered_model_name}")
+
+                    if left_eye is not None and eye_type == "left_eye":
+                        input_img = left_eye
+                    elif right_eye is not None and eye_type == "right_eye":
+                        input_img = right_eye
+                    else:
+                        raise Exception("Wrong Data")
+
+                    if cam_method is not None:
+                        cam_extractor = torchcam_methods.__dict__[cam_method](
+                            model,
+                            target_layer=target_layer,
+                            fc_layer=model.resnet.fc,
+                            input_shape=input_img.shape,
+                        )
+
+                    # with torch.no_grad():
+                    out = model(input_img)
+                    cols[-1].markdown(
+                        f"<h3>Predicted Pupil Diameter: {out[0].item():.2f} mm</h3>",
+                        unsafe_allow_html=True,
+                    )
+                    # cols[-1].text(f"Predicted Pupil Diameter: {out[0].item():.2f}")
+
+                    # Retrieve the CAM
+                    act_maps = cam_extractor(0, out)
+
+                    # Fuse the CAMs if there are several
+                    activation_map = act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps)
+
+                    # Convert input image and activation map to PIL images
+                    input_image_pil = to_pil_image(input_img.squeeze(0))
+                    activation_map_pil = to_pil_image(activation_map, mode="F")
+
+                    # Create the overlayed CAM result
+                    result = overlay_mask(
+                        input_image_pil,
+                        activation_map_pil,
+                        alpha=0.5,
+                    )
+
+                    # Create a subplot with 1 row and 2 columns
+                    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
+
+                    # Display the input image
+                    axs[0].imshow(input_image_pil)
+                    axs[0].axis("off")
+                    axs[0].set_title("Input Image")
+
+                    # Display the overlayed CAM result
+                    axs[1].imshow(result)
+                    axs[1].axis("off")
+                    axs[1].set_title("Overlayed CAM")
+
+                    # Display the plot
+                    cols[-1].pyplot(fig)
+                    cols[-1].text(f"eye image size: {input_img.shape[-1]} x {input_img.shape[-2]}")
 
 
 if __name__ == "__main__":
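
The analysis block now wraps the old single-image pipeline in a loop over `input_imgs` and lazily loads the model on the first iteration, reusing it for every later frame. Note that as written the cache is also shared across `eye_type` values, so when both pupils are selected the right-eye pass reuses the weights loaded for the left eye. A self-contained sketch of the caching pattern, with a stub standing in for `_load_model`:

```python
from typing import Callable, List, Optional

def analyze_frames(frames: List[str], load_model: Callable[[], Callable]) -> List[object]:
    """Mirror of the commit's pattern: load the model once on the first
    frame, then reuse it for all remaining frames (and eye types)."""
    model: Optional[Callable] = None
    results = []
    for frame in frames:
        if model is None:  # only true on the first iteration
            model = load_model()
        results.append(model(frame))
    return results

# Stub usage: a fake "model" that just measures frame length.
print(analyze_frames(["frame0", "frame1"], lambda: len))  # -> [6, 6]
```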