Files changed (1)
  1. app.py +51 -36
app.py CHANGED
@@ -12,17 +12,26 @@ model_url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolv
 model_path = "FER_static_ResNet50_AffectNet.pth"
 
 response = requests.get(model_url, stream=True)
-with open(model_path, 'wb') as file:
+with open(model_path, "wb") as file:
     for chunk in response.iter_content(chunk_size=8192):
         file.write(chunk)
 
 pth_model = torch.jit.load(model_path)
 pth_model.eval()
 
-DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}
+DICT_EMO = {
+    0: "Neutral",
+    1: "Happiness",
+    2: "Sadness",
+    3: "Surprise",
+    4: "Fear",
+    5: "Disgust",
+    6: "Anger",
+}
 
 mp_face_mesh = mp.solutions.face_mesh
 
+
 def pth_processing(fp):
     class PreprocessInput(torch.nn.Module):
         def __init__(self):
@@ -37,24 +46,22 @@ def pth_processing(fp):
             return x
 
     def get_img_torch(img):
-
-        ttransform = transforms.Compose([
-            transforms.PILToTensor(),
-            PreprocessInput()
-        ])
+        ttransform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()])
         img = img.resize((224, 224), Image.Resampling.NEAREST)
         img = ttransform(img)
         img = torch.unsqueeze(img, 0)
         return img
+
     return get_img_torch(fp)
 
+
 def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
-
     x_px = min(math.floor(normalized_x * image_width), image_width - 1)
     y_px = min(math.floor(normalized_y * image_height), image_height - 1)
-
+
     return x_px, y_px
 
+
 def get_box(fl, w, h):
     idx_to_coors = {}
     for idx, landmark in enumerate(fl.landmark):
@@ -63,44 +70,51 @@ def get_box(fl, w, h):
         if landmark_px:
             idx_to_coors[idx] = landmark_px
 
-    x_min = np.min(np.asarray(list(idx_to_coors.values()))[:,0])
-    y_min = np.min(np.asarray(list(idx_to_coors.values()))[:,1])
-    endX = np.max(np.asarray(list(idx_to_coors.values()))[:,0])
-    endY = np.max(np.asarray(list(idx_to_coors.values()))[:,1])
+    x_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 0])
+    y_min = np.min(np.asarray(list(idx_to_coors.values()))[:, 1])
+    endX = np.max(np.asarray(list(idx_to_coors.values()))[:, 0])
+    endY = np.max(np.asarray(list(idx_to_coors.values()))[:, 1])
 
     (startX, startY) = (max(0, x_min), max(0, y_min))
     (endX, endY) = (min(w - 1, endX), min(h - 1, endY))
-
+
     return startX, startY, endX, endY
 
-def predict(inp):
-
+
+def predict(inp):
     inp = np.array(inp)
     h, w = inp.shape[:2]
 
     with mp_face_mesh.FaceMesh(
-        max_num_faces=1,
-        refine_landmarks=False,
-        min_detection_confidence=0.5,
-        min_tracking_confidence=0.5) as face_mesh:
+        max_num_faces=1,
+        refine_landmarks=False,
+        min_detection_confidence=0.5,
+        min_tracking_confidence=0.5,
+    ) as face_mesh:
         results = face_mesh.process(inp)
         if results.multi_face_landmarks:
             for fl in results.multi_face_landmarks:
-                startX, startY, endX, endY = get_box(fl, w, h)
-                cur_face = inp[startY:endY, startX: endX]
+                startX, startY, endX, endY = get_box(fl, w, h)
+                cur_face = inp[startY:endY, startX:endX]
                 cur_face_n = pth_processing(Image.fromarray(cur_face))
-                prediction = torch.nn.functional.softmax(pth_model(cur_face_n), dim=1).detach().numpy()[0]
+                prediction = (
+                    torch.nn.functional.softmax(pth_model(cur_face_n), dim=1)
+                    .detach()
+                    .numpy()[0]
+                )
                 confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
-
+
     return cur_face, confidences
 
+
 def clear():
     return (
         gr.Image(value=None, type="pil"),
-        gr.Image(value=None,scale=1, elem_classes="dl2"),
-        gr.Label(value=None,num_top_classes=3, scale=1, elem_classes="dl3")
+        gr.Image(value=None, scale=1, elem_classes="dl2"),
+        gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3"),
     )
 
+
 style = """
 div.dl1 div.upload-container {
     height: 350px;
@@ -154,26 +168,27 @@ with gr.Blocks(css=style) as demo:
             submit = gr.Button(
                 value="Submit", interactive=True, scale=1, elem_classes="submit"
             )
-            clear_btn = gr.Button(
-                value="Clear", interactive=True, scale=1
-            )
+            clear_btn = gr.Button(value="Clear", interactive=True, scale=1)
         with gr.Column(scale=1, elem_classes="dl4"):
             output_image = gr.Image(scale=1, elem_classes="dl2")
             output_label = gr.Label(num_top_classes=3, scale=1, elem_classes="dl3")
     gr.Examples(
-        ["images/fig7.jpg", "images/fig1.jpg", "images/fig2.jpg","images/fig3.jpg",
-         "images/fig4.jpg", "images/fig5.jpg", "images/fig6.jpg"],
+        [
+            "images/fig7.jpg",
+            "images/fig1.jpg",
+            "images/fig2.jpg",
+            "images/fig3.jpg",
+            "images/fig4.jpg",
+            "images/fig5.jpg",
+            "images/fig6.jpg",
+        ],
         [input_image],
     )
-
 
     submit.click(
        fn=predict,
        inputs=[input_image],
-        outputs=[
-            output_image,
-            output_label
-        ],
+        outputs=[output_image, output_label],
        queue=True,
    )
    clear_btn.click(
@@ -188,4 +203,4 @@ with gr.Blocks(css=style) as demo:
     )
 
 if __name__ == "__main__":
-    demo.queue(api_open=False).launch(share=False)
+    demo.queue(api_open=False).launch(share=False)
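Every hunk above changes only quoting, whitespace, or line wrapping, so the app's behavior is unchanged. As a quick way to exercise the reformatted predict() outside the Gradio UI, here is a minimal sketch. It assumes the snippet runs next to app.py with the same dependencies installed; "face.jpg" is a hypothetical stand-in for any photo containing one detectable face, not a file from this repo.

    # Hypothetical smoke test for predict(); not part of the commit above.
    # Importing app runs its top level, which downloads the model weights
    # and builds the Gradio Blocks before predict() becomes callable.
    from PIL import Image

    from app import predict

    face_crop, confidences = predict(Image.open("face.jpg").convert("RGB"))
    # confidences maps the seven DICT_EMO labels to softmax scores
    print(max(confidences, key=confidences.get))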
 
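The pattern of the edits (double quotes, trailing commas in exploded literals, two blank lines between top-level defs, short calls collapsed onto one line) is the signature of an autoformatter pass. The commit does not name the tool, so this is an assumption, but black with default settings produces this style and would keep later edits consistent:

    # Assumption: the formatting above matches black's defaults; the diff
    # itself does not name the tool. Shell equivalent: black app.py
    import subprocess

    subprocess.run(["black", "app.py"], check=True)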