Commit 18eea93
hysts (HF staff) committed
1 parent: c7fd838
.gitattributes CHANGED
@@ -26,3 +26,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore DELETED
@@ -1 +0,0 @@
-images
.pre-commit-config.yaml ADDED
@@ -0,0 +1,60 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0
+    hooks:
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-shebang-scripts-are-executable
+      - id: check-toml
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+        args: ["--fix=lf"]
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/myint/docformatter
+    rev: v1.7.5
+    hooks:
+      - id: docformatter
+        args: ["--in-place"]
+  - repo: https://github.com/pycqa/isort
+    rev: 5.13.2
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.8.0
+    hooks:
+      - id: mypy
+        args: ["--ignore-missing-imports"]
+        additional_dependencies:
+          [
+            "types-python-slugify",
+            "types-requests",
+            "types-PyYAML",
+            "types-pytz",
+          ]
+  - repo: https://github.com/psf/black
+    rev: 24.2.0
+    hooks:
+      - id: black
+        language_version: python3.10
+        args: ["--line-length", "119"]
+  - repo: https://github.com/kynan/nbstripout
+    rev: 0.7.1
+    hooks:
+      - id: nbstripout
+        args:
+          [
+            "--extra-keys",
+            "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
+          ]
+  - repo: https://github.com/nbQA-dev/nbQA
+    rev: 1.7.1
+    hooks:
+      - id: nbqa-black
+      - id: nbqa-pyupgrade
+        args: ["--py37-plus"]
+      - id: nbqa-isort
+        args: ["--float-to-top"]
.vscode/settings.json ADDED
@@ -0,0 +1,30 @@
+{
+  "editor.formatOnSave": true,
+  "files.insertFinalNewline": false,
+  "[python]": {
+    "editor.defaultFormatter": "ms-python.black-formatter",
+    "editor.formatOnType": true,
+    "editor.codeActionsOnSave": {
+      "source.organizeImports": "explicit"
+    }
+  },
+  "[jupyter]": {
+    "files.insertFinalNewline": false
+  },
+  "black-formatter.args": [
+    "--line-length=119"
+  ],
+  "isort.args": ["--profile", "black"],
+  "flake8.args": [
+    "--max-line-length=119"
+  ],
+  "ruff.lint.args": [
+    "--line-length=119"
+  ],
+  "notebook.output.scrolling": true,
+  "notebook.formatOnCellExecution": true,
+  "notebook.formatOnSave.enabled": true,
+  "notebook.codeActionsOnSave": {
+    "source.organizeImports": "explicit"
+  }
+}
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🔥
 colorFrom: blue
 colorTo: gray
 sdk: gradio
-sdk_version: 3.0.5
+sdk_version: 4.19.2
 app_file: app.py
 pinned: false
 ---
app.py CHANGED
@@ -2,90 +2,50 @@
 
 from __future__ import annotations
 
-import argparse
-import functools
 import os
 import pathlib
 import sys
-import tarfile
 
 import cv2
 import gradio as gr
-import huggingface_hub
 import numpy as np
 import torch
 
-sys.path.insert(0, 'face_detection')
-sys.path.insert(0, 'face_alignment')
-sys.path.insert(0, 'emotion_recognition')
+sys.path.insert(0, "face_detection")
+sys.path.insert(0, "face_alignment")
+sys.path.insert(0, "emotion_recognition")
 
 from ibug.emotion_recognition import EmoNetPredictor
 from ibug.face_alignment import FANPredictor
 from ibug.face_detection import RetinaFacePredictor
 
-TITLE = 'ibug-group/emotion_recognition'
-DESCRIPTION = 'This is an unofficial demo for https://github.com/ibug-group/emotion_recognition.'
-ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.ibug-emotion_recognition" alt="visitor badge"/></center>'
-
-TOKEN = os.environ['TOKEN']
-
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--device', type=str, default='cpu')
-    parser.add_argument('--theme', type=str)
-    parser.add_argument('--live', action='store_true')
-    parser.add_argument('--share', action='store_true')
-    parser.add_argument('--port', type=int)
-    parser.add_argument('--disable-queue',
-                        dest='enable_queue',
-                        action='store_false')
-    parser.add_argument('--allow-flagging', type=str, default='never')
-    return parser.parse_args()
-
-
-def load_sample_images() -> list[pathlib.Path]:
-    image_dir = pathlib.Path('images')
-    if not image_dir.exists():
-        image_dir.mkdir()
-        dataset_repo = 'hysts/input-images'
-        filenames = ['004.tar']
-        for name in filenames:
-            path = huggingface_hub.hf_hub_download(dataset_repo,
-                                                   name,
-                                                   repo_type='dataset',
-                                                   use_auth_token=TOKEN)
-            with tarfile.open(path) as f:
-                f.extractall(image_dir.as_posix())
-    return sorted(image_dir.rglob('*.jpg'))
-
-
-def load_face_detector(device: torch.device) -> RetinaFacePredictor:
-    model = RetinaFacePredictor(
-        threshold=0.8,
-        device=device,
-        model=RetinaFacePredictor.get_model('mobilenet0.25'))
-    return model
+DESCRIPTION = "# [ibug-group/emotion_recognition](https://github.com/ibug-group/emotion_recognition)"
 
 
-def load_landmark_detector(device: torch.device) -> FANPredictor:
-    model = FANPredictor(device=device, model=FANPredictor.get_model('2dfan2'))
-    return model
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+face_detector = RetinaFacePredictor(threshold=0.8, device=device, model=RetinaFacePredictor.get_model("mobilenet0.25"))
+landmark_detector = FANPredictor(device=device, model=FANPredictor.get_model("2dfan2"))
 
 
 def load_model(model_name: str, device: torch.device) -> EmoNetPredictor:
-    model = EmoNetPredictor(device=device,
-                            model=EmoNetPredictor.get_model(model_name))
+    model = EmoNetPredictor(device=device, model=EmoNetPredictor.get_model(model_name))
     return model
 
 
-def predict(image: np.ndarray, model_name: str, max_num_faces: int,
-            face_detector: RetinaFacePredictor,
-            landmark_detector: FANPredictor,
-            models: dict[str, EmoNetPredictor]) -> np.ndarray:
+model_names = [
+    "emonet248",
+    "emonet245",
+    "emonet248_alt",
+    "emonet245_alt",
+]
+models = {name: load_model(name, device=device) for name in model_names}
+
+
+def predict(image: np.ndarray, model_name: str, max_num_faces: int) -> np.ndarray:
     model = models[model_name]
     if len(model.config.emotion_labels) == 8:
-        colors = (
+        colors: tuple[tuple[int, int, int], ...] = (
             (192, 192, 192),
             (0, 255, 0),
             (255, 0, 0),
@@ -109,13 +69,10 @@ def predict(image: np.ndarray, model_name: str, max_num_faces: int,
 
     faces = face_detector(image, rgb=False)
     if len(faces) == 0:
-        raise RuntimeError('No face was found.')
+        raise gr.Error("No face was found.")
     faces = sorted(list(faces), key=lambda x: -x[4])[:max_num_faces]
     faces = np.asarray(faces)
-    _, _, features = landmark_detector(image,
-                                       faces,
-                                       rgb=False,
-                                       return_features=True)
+    _, _, features = landmark_detector(image, faces, rgb=False, return_features=True)
     emotions = model(features)
 
     res = image.copy()
@@ -123,71 +80,54 @@
         box = np.round(face[:4]).astype(int)
         cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), 2)
 
-        emotion = emotions['emotion'][index]
-        valence = emotions['valence'][index]
-        arousal = emotions['arousal'][index]
+        emotion = emotions["emotion"][index]
+        valence = emotions["valence"][index]
+        arousal = emotions["arousal"][index]
         emotion_label = model.config.emotion_labels[emotion].title()
 
-        text_content = f'{emotion_label} ({valence: .01f}, {arousal: .01f})'
-        cv2.putText(res,
-                    text_content, (box[0], box[1] - 10),
-                    cv2.FONT_HERSHEY_DUPLEX,
-                    1,
-                    colors[emotion],
-                    lineType=cv2.LINE_AA)
+        text_content = f"{emotion_label} ({valence: .01f}, {arousal: .01f})"
+        cv2.putText(
+            res, text_content, (box[0], box[1] - 10), cv2.FONT_HERSHEY_DUPLEX, 1, colors[emotion], lineType=cv2.LINE_AA
+        )
 
     return res[:, :, ::-1]
 
 
-def main():
-    args = parse_args()
-    device = torch.device(args.device)
-
-    face_detector = load_face_detector(device)
-    landmark_detector = load_landmark_detector(device)
-
-    model_names = [
-        'emonet248',
-        'emonet245',
-        'emonet248_alt',
-        'emonet245_alt',
-    ]
-    models = {name: load_model(name, device=device) for name in model_names}
-
-    func = functools.partial(predict,
-                             face_detector=face_detector,
-                             landmark_detector=landmark_detector,
-                             models=models)
-    func = functools.update_wrapper(func, predict)
-
-    image_paths = load_sample_images()
-    examples = [[path.as_posix(), model_names[0], 30] for path in image_paths]
-
-    gr.Interface(
-        func,
-        [
-            gr.inputs.Image(type='numpy', label='Input'),
-            gr.inputs.Radio(model_names,
-                            type='value',
-                            default=model_names[0],
-                            label='Model'),
-            gr.inputs.Slider(
-                1, 30, step=1, default=30, label='Max Number of Faces'),
-        ],
-        gr.outputs.Image(type='numpy', label='Output'),
-        examples=examples,
-        title=TITLE,
-        description=DESCRIPTION,
-        article=ARTICLE,
-        theme=args.theme,
-        allow_flagging=args.allow_flagging,
-        live=args.live,
-    ).launch(
-        enable_queue=args.enable_queue,
-        server_port=args.port,
-        share=args.share,
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+    with gr.Row():
+        with gr.Column():
+            image = gr.Image(label="Input", type="numpy")
+            model_name = gr.Radio(
+                label="Model",
+                choices=model_names,
+                value=model_names[0],
+                type="value",
+            )
+            max_num_of_faces = gr.Slider(
+                label="Max Number of Faces",
+                minimum=1,
+                maximum=30,
+                step=1,
+                value=30,
+            )
+            run_button = gr.Button()
+        with gr.Column():
+            result = gr.Image(label="Output")
+    gr.Examples(
+        examples=[[path.as_posix(), model_names[0], 30] for path in sorted(pathlib.Path("images").rglob("*.jpg"))],
+        inputs=[image, model_name, max_num_of_faces],
+        outputs=result,
+        fn=predict,
+        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
+    )
+    run_button.click(
+        fn=predict,
+        inputs=[image, model_name, max_num_of_faces],
+        outputs=result,
+        api_name="predict",
     )
 
 
-if __name__ == '__main__':
-    main()
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
images/README.md ADDED
@@ -0,0 +1,3 @@
+These images are from the following public-domain sources:
+- https://www.pexels.com/photo/collage-photo-of-woman-3812743/
+- https://www.pexels.com/photo/collage-of-portraits-of-cheerful-woman-3807758/
images/pexels-andrea-piacquadio-3807758.jpg ADDED

Git LFS Details
• SHA256: a2e5b4281b1ab26f11d3908e1a953edce35942e20a8c9a427fbf26e494c78a7a
• Pointer size: 132 Bytes
• Size of remote file: 1.05 MB
images/pexels-andrea-piacquadio-3812743.jpg ADDED

Git LFS Details
• SHA256: e9ca4821fe880d3d5362b86082e547aad7efb43dab0f27cd4989128f141f9b3f
• Pointer size: 132 Bytes
• Size of remote file: 2.15 MB
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-numpy==1.22.3
-opencv-python-headless==4.5.5.64
-torch==1.11.0
-torchvision==0.12.0
+numpy==1.26.4
+opencv-python-headless==4.9.0.80
+torch==2.0.1
+torchvision==0.15.2
style.css ADDED
@@ -0,0 +1,11 @@
+h1 {
+  text-align: center;
+  display: block;
+}
+
+#duplicate-button {
+  margin: auto;
+  color: #fff;
+  background: #1565c0;
+  border-radius: 100vh;
+}