Johannes commited on
Commit
4efaea0
1 Parent(s): 86cae26
Files changed (1) hide show
  1. app.py +70 -29
app.py CHANGED
@@ -12,13 +12,48 @@ import gradio as gr
12
  import face_detection
13
 
14
 
15
- def detect_faces(img: np.ndarray):
16
- frame = np.array(img)
 
 
 
 
 
 
17
 
18
- kornia_detections = kornia_detect(frame)
19
- retina_detections = retina_detect(frame)
20
- retina_mobile_detections = retina_mobilenet_detect(frame)
21
- dsfd_detections = dsfd_detect(frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  return kornia_detections, retina_detections, retina_mobile_detections, dsfd_detections
24
 
@@ -43,27 +78,27 @@ def base_detect(detector, img):
43
  return img_vis
44
 
45
 
46
- def retina_detect(img):
47
  detector = face_detection.build_detector(
48
- "RetinaNetResNet50", confidence_threshold=.5, nms_iou_threshold=.3)
49
 
50
  img_vis = base_detect(detector, img)
51
 
52
  return img_vis
53
 
54
 
55
- def retina_mobilenet_detect(img):
56
  detector = face_detection.build_detector(
57
- "RetinaNetMobileNetV1", confidence_threshold=.5, nms_iou_threshold=.3)
58
 
59
  img_vis = base_detect(detector, img)
60
 
61
  return img_vis
62
 
63
 
64
- def dsfd_detect(img):
65
  detector = face_detection.build_detector(
66
- "DSFDDetector", confidence_threshold=.5, nms_iou_threshold=.3)
67
 
68
  img_vis = base_detect(detector, img)
69
 
@@ -71,10 +106,9 @@ def dsfd_detect(img):
71
 
72
 
73
 
74
- def kornia_detect(img):
75
  # select the device
76
  device = torch.device('cpu')
77
- vis_threshold = 0.6
78
 
79
  # load the image and scale
80
  img_raw = scale_image(img, 400)
@@ -84,7 +118,8 @@ def kornia_detect(img):
84
  img = K.color.bgr_to_rgb(img.float())
85
 
86
  # create the detector and find the faces !
87
- face_detection = FaceDetector().to(device)
 
88
 
89
  with torch.no_grad():
90
  dets = face_detection(img)
@@ -95,8 +130,6 @@ def kornia_detect(img):
95
  img_vis = img_raw.copy()
96
 
97
  for b in dets:
98
- if b.score < vis_threshold:
99
- continue
100
 
101
  # draw face bounding box
102
  img_vis = cv2.rectangle(img_vis,
@@ -107,29 +140,37 @@ def kornia_detect(img):
107
 
108
  return img_vis
109
 
110
-
111
  input_image = gr.components.Image()
112
-
113
  image_kornia = gr.components.Image(label="Kornia YuNet")
114
  image_retina = gr.components.Image(label="RetinaFace")
115
  image_retina_mobile = gr.components.Image(label="Retina Mobilenet")
116
  image_dsfd = gr.components.Image(label="DSFD")
117
 
118
-
119
- confidence_slider = gr.components.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Confidence Threshold")
120
- nms_slider = gr.components.Slider(minimum=0.1, maximum=0.9, value=0.5, label="Min Number of Neighbours")
121
- # scale_slider = gr.components.Slider(minimum=1.1, maximum=2.0, value=1.3, step=0.1, label="Scale Factor")
122
- # classifier_radio = gr.components.Radio(s)
 
 
 
123
 
124
  #methods_dropdown = gr.components.Dropdown(["Kornia YuNet", "RetinaFace", "RetinaMobile", "DSFD"], value="Kornia YuNet", label="Choose a method")
125
 
126
- description = """Face Detection"""
127
 
 
 
 
 
 
 
128
 
129
- Iface = gr.Interface(
130
- fn=detect_faces,
131
- inputs=[input_image],#, size_slider, neighbour_slider, scale_slider],
132
  outputs=[image_kornia, image_retina, image_retina_mobile, image_dsfd],
133
  examples=[["data/9_Press_Conference_Press_Conference_9_86.jpg"], ["data/12_Group_Group_12_Group_Group_12_39.jpg"], ["data/31_Waiter_Waitress_Waiter_Waitress_31_55.jpg"]],
134
- title="Face Detection",
 
135
  ).launch()
 
12
  import face_detection
13
 
14
 
15
+ def compare_detect_faces(img: np.ndarray,
16
+ confidence_threshold,
17
+ nms_threshold,
18
+ kornia_toggle,
19
+ retina_toggle,
20
+ retina_mobile_toggle,
21
+ dsfd_toggle
22
+ ):
23
 
24
+ detections = []
25
+
26
+ if kornia_toggle=="On":
27
+ kornia_detections = kornia_detect(img,
28
+ confidence_threshold=confidence_threshold,
29
+ nms_threshold=nms_threshold)
30
+ else:
31
+ kornia_detections = None
32
+
33
+ if retina_toggle=="On":
34
+ retina_detections = retina_detect(img,
35
+ confidence_threshold=confidence_threshold,
36
+ nms_threshold=nms_threshold)
37
+ detections.append(retina_detections)
38
+ else:
39
+ retina_detections = None
40
+
41
+ if retina_mobile_toggle=="On":
42
+ retina_mobile_detections = retina_mobilenet_detect(img,
43
+ confidence_threshold=confidence_threshold,
44
+ nms_threshold=nms_threshold)
45
+ detections.append(retina_mobile_detections)
46
+ else:
47
+ retina_mobile_detections = None
48
+
49
+ if dsfd_toggle=="On":
50
+ dsfd_detections = dsfd_detect(img,
51
+ confidence_threshold=confidence_threshold,
52
+ nms_threshold=nms_threshold)
53
+ detections.append(dsfd_detections)
54
+ else:
55
+ dsfd_detections = None
56
+
57
 
58
  return kornia_detections, retina_detections, retina_mobile_detections, dsfd_detections
59
 
 
78
  return img_vis
79
 
80
 
81
+ def retina_detect(img, confidence_threshold, nms_threshold):
82
  detector = face_detection.build_detector(
83
+ "RetinaNetResNet50", confidence_threshold=confidence_threshold, nms_iou_threshold=nms_threshold)
84
 
85
  img_vis = base_detect(detector, img)
86
 
87
  return img_vis
88
 
89
 
90
+ def retina_mobilenet_detect(img, confidence_threshold, nms_threshold):
91
  detector = face_detection.build_detector(
92
+ "RetinaNetMobileNetV1", confidence_threshold=confidence_threshold, nms_iou_threshold=nms_threshold)
93
 
94
  img_vis = base_detect(detector, img)
95
 
96
  return img_vis
97
 
98
 
99
+ def dsfd_detect(img, confidence_threshold, nms_threshold):
100
  detector = face_detection.build_detector(
101
+ "DSFDDetector", confidence_threshold=confidence_threshold, nms_iou_threshold=nms_threshold)
102
 
103
  img_vis = base_detect(detector, img)
104
 
 
106
 
107
 
108
 
109
+ def kornia_detect(img, confidence_threshold, nms_threshold):
110
  # select the device
111
  device = torch.device('cpu')
 
112
 
113
  # load the image and scale
114
  img_raw = scale_image(img, 400)
 
118
  img = K.color.bgr_to_rgb(img.float())
119
 
120
  # create the detector and find the faces !
121
+ face_detection = FaceDetector(confidence_threshold=confidence_threshold,
122
+ nms_threshold=nms_threshold).to(device)
123
 
124
  with torch.no_grad():
125
  dets = face_detection(img)
 
130
  img_vis = img_raw.copy()
131
 
132
  for b in dets:
 
 
133
 
134
  # draw face bounding box
135
  img_vis = cv2.rectangle(img_vis,
 
140
 
141
  return img_vis
142
 
 
143
  input_image = gr.components.Image()
 
144
  image_kornia = gr.components.Image(label="Kornia YuNet")
145
  image_retina = gr.components.Image(label="RetinaFace")
146
  image_retina_mobile = gr.components.Image(label="Retina Mobilenet")
147
  image_dsfd = gr.components.Image(label="DSFD")
148
 
149
+ confidence_slider = gr.components.Slider(minimum=0.1, maximum=0.95, value=0.5, step=0.05, label="Confidence Threshold")
150
+ nms_slider = gr.components.Slider(minimum=0.1, maximum=0.95, value=0.3, step=0.05, label="Non Maximum Supression (NMS) Threshold")
151
+
152
+
153
+ kornia_radio = gr.Radio(["On", "Off"], value="On", label="Kornia YuNet")
154
+ retinanet_radio = gr.Radio(["On", "Off"], value="On", label="RetinaFace")
155
+ retina_mobile_radio = gr.Radio(["On", "Off"], value="On", label="Retina Mobilenets")
156
+ dsfd_radio = gr.Radio(["On", "Off"], value="On", label="DSFD")
157
 
158
  #methods_dropdown = gr.components.Dropdown(["Kornia YuNet", "RetinaFace", "RetinaMobile", "DSFD"], value="Kornia YuNet", label="Choose a method")
159
 
160
+ description = """This space let's you compare different face detection algorithms, based on Convolutional Neural Networks (CNNs).
161
 
162
+ The models used here are:
163
+ * Kornia YuNet:
164
+ * RetinaFace:
165
+ * RetinaMobileNet:
166
+ * DSFD:
167
+ """
168
 
169
+ compare_iface = gr.Interface(
170
+ fn=compare_detect_faces,
171
+ inputs=[input_image, confidence_slider, nms_slider, kornia_radio, retinanet_radio, retina_mobile_radio, dsfd_radio],#, size_slider, neighbour_slider, scale_slider],
172
  outputs=[image_kornia, image_retina, image_retina_mobile, image_dsfd],
173
  examples=[["data/9_Press_Conference_Press_Conference_9_86.jpg"], ["data/12_Group_Group_12_Group_Group_12_39.jpg"], ["data/31_Waiter_Waitress_Waiter_Waitress_31_55.jpg"]],
174
+ title="Face Detections",
175
+ description=description
176
  ).launch()