ElenaRyumina commited on
Commit
031ec86
β€’
1 Parent(s): c601827
Files changed (10) hide show
  1. .gitignore +171 -0
  2. LICENSE +21 -0
  3. app.py +191 -0
  4. images/fig1.jpg +0 -0
  5. images/fig2.jpg +0 -0
  6. images/fig3.jpg +0 -0
  7. images/fig4.jpg +0 -0
  8. images/fig5.jpg +0 -0
  9. images/fig6.jpg +0 -0
  10. images/fig7.jpg +0 -0
.gitignore ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Compiled source #
2
+ ###################
3
+ *.com
4
+ *.class
5
+ *.dll
6
+ *.exe
7
+ *.o
8
+ *.so
9
+ *.pyc
10
+
11
+ # Packages #
12
+ ############
13
+ # it's better to unpack these files and commit the raw source
14
+ # git has its own built in compression methods
15
+ *.7z
16
+ *.dmg
17
+ *.gz
18
+ *.iso
19
+ *.rar
20
+ #*.tar
21
+ *.zip
22
+
23
+ # Logs and databases #
24
+ ######################
25
+ *.log
26
+ *.sqlite
27
+
28
+ # OS generated files #
29
+ ######################
30
+ .DS_Store
31
+ ehthumbs.db
32
+ Icon
33
+ Thumbs.db
34
+ .tmtags
35
+ .idea
36
+ .vscode
37
+ tags
38
+ vendor.tags
39
+ tmtagsHistory
40
+ *.sublime-project
41
+ *.sublime-workspace
42
+ .bundle
43
+
44
+ # Byte-compiled / optimized / DLL files
45
+ __pycache__/
46
+ *.py[cod]
47
+ *$py.class
48
+
49
+ # C extensions
50
+ *.so
51
+
52
+ # Distribution / packaging
53
+ .Python
54
+ build/
55
+ develop-eggs/
56
+ dist/
57
+ downloads/
58
+ eggs/
59
+ .eggs/
60
+ lib/
61
+ lib64/
62
+ parts/
63
+ sdist/
64
+ var/
65
+ wheels/
66
+ pip-wheel-metadata/
67
+ share/python-wheels/
68
+ *.egg-info/
69
+ .installed.cfg
70
+ *.egg
71
+ MANIFEST
72
+ node_modules/
73
+
74
+ # PyInstaller
75
+ # Usually these files are written by a python script from a template
76
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
77
+ *.manifest
78
+ *.spec
79
+
80
+ # Installer logs
81
+ pip-log.txt
82
+ pip-delete-this-directory.txt
83
+
84
+ # Unit test / coverage reports
85
+ htmlcov/
86
+ .tox/
87
+ .nox/
88
+ .coverage
89
+ .coverage.*
90
+ .cache
91
+ nosetests.xml
92
+ coverage.xml
93
+ *.cover
94
+ .hypothesis/
95
+ .pytest_cache/
96
+
97
+ # Translations
98
+ *.mo
99
+ *.pot
100
+
101
+ # Django stuff:
102
+ *.log
103
+ local_settings.py
104
+ db.sqlite3
105
+ db.sqlite3-journal
106
+
107
+ # Flask stuff:
108
+ instance/
109
+ .webassets-cache
110
+
111
+ # Scrapy stuff:
112
+ .scrapy
113
+
114
+ # Sphinx documentation
115
+ docs/_build/
116
+
117
+ # PyBuilder
118
+ target/
119
+
120
+ # Jupyter Notebook
121
+ .ipynb_checkpoints
122
+
123
+ # IPython
124
+ profile_default/
125
+ ipython_config.py
126
+
127
+ # pyenv
128
+ .python-version
129
+
130
+ # pipenv
131
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
132
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
133
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
134
+ # install all needed dependencies.
135
+ #Pipfile.lock
136
+
137
+ # celery beat schedule file
138
+ celerybeat-schedule
139
+
140
+ # SageMath parsed files
141
+ *.sage.py
142
+
143
+ # Environments
144
+ .env
145
+ .venv
146
+ env/
147
+ venv/
148
+ ENV/
149
+ env.bak/
150
+ venv.bak/
151
+
152
+ # Spyder project settings
153
+ .spyderproject
154
+ .spyproject
155
+
156
+ # Rope project settings
157
+ .ropeproject
158
+
159
+ # mkdocs documentation
160
+ /site
161
+
162
+ # mypy
163
+ .mypy_cache/
164
+ .dmypy.json
165
+ dmypy.json
166
+
167
+ # Pyre type checker
168
+ .pyre/
169
+
170
+ # Custom
171
+ *.pth
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Elena Ryumina
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from torchvision import transforms
4
+ import mediapipe as mp
5
+ import numpy as np
6
+ import math
7
+ import requests
8
+
9
+ import gradio as gr
10
+
11
+ model_url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth"
12
+ model_path = "FER_static_ResNet50_AffectNet.pth"
13
+
14
+ response = requests.get(model_url, stream=True)
15
+ with open(model_path, 'wb') as file:
16
+ for chunk in response.iter_content(chunk_size=8192):
17
+ file.write(chunk)
18
+
19
+ pth_model = torch.jit.load(model_path).to('cuda')
20
+ pth_model.eval()
21
+
22
+ DICT_EMO = {0: 'Neutral', 1: 'Happiness', 2: 'Sadness', 3: 'Surprise', 4: 'Fear', 5: 'Disgust', 6: 'Anger'}
23
+
24
+ mp_face_mesh = mp.solutions.face_mesh
25
+
26
+ def pth_processing(fp):
27
+ class PreprocessInput(torch.nn.Module):
28
+ def init(self):
29
+ super(PreprocessInput, self).init()
30
+
31
+ def forward(self, x):
32
+ x = x.to(torch.float32)
33
+ x = torch.flip(x, dims=(0,))
34
+ x[0, :, :] -= 91.4953
35
+ x[1, :, :] -= 103.8827
36
+ x[2, :, :] -= 131.0912
37
+ return x
38
+
39
+ def get_img_torch(img):
40
+
41
+ ttransform = transforms.Compose([
42
+ transforms.PILToTensor(),
43
+ PreprocessInput()
44
+ ])
45
+ img = img.resize((224, 224), Image.Resampling.NEAREST)
46
+ img = ttransform(img)
47
+ img = torch.unsqueeze(img, 0).to('cuda')
48
+ return img
49
+ return get_img_torch(fp)
50
+
51
+ def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
52
+
53
+ x_px = min(math.floor(normalized_x * image_width), image_width - 1)
54
+ y_px = min(math.floor(normalized_y * image_height), image_height - 1)
55
+
56
+ return x_px, y_px
57
+
58
+ def get_box(fl, w, h):
59
+ idx_to_coors = {}
60
+ for idx, landmark in enumerate(fl.landmark):
61
+ landmark_px = norm_coordinates(landmark.x, landmark.y, w, h)
62
+
63
+ if landmark_px:
64
+ idx_to_coors[idx] = landmark_px
65
+
66
+ x_min = np.min(np.asarray(list(idx_to_coors.values()))[:,0])
67
+ y_min = np.min(np.asarray(list(idx_to_coors.values()))[:,1])
68
+ endX = np.max(np.asarray(list(idx_to_coors.values()))[:,0])
69
+ endY = np.max(np.asarray(list(idx_to_coors.values()))[:,1])
70
+
71
+ (startX, startY) = (max(0, x_min), max(0, y_min))
72
+ (endX, endY) = (min(w - 1, endX), min(h - 1, endY))
73
+
74
+ return startX, startY, endX, endY
75
+
76
+ def predict(inp):
77
+
78
+ inp = np.array(inp)
79
+ h, w = inp.shape[:2]
80
+
81
+ with mp_face_mesh.FaceMesh(
82
+ max_num_faces=1,
83
+ refine_landmarks=False,
84
+ min_detection_confidence=0.5,
85
+ min_tracking_confidence=0.5) as face_mesh:
86
+ results = face_mesh.process(inp)
87
+ if results.multi_face_landmarks:
88
+ for fl in results.multi_face_landmarks:
89
+ startX, startY, endX, endY = get_box(fl, w, h)
90
+ cur_face = inp[startY:endY, startX: endX]
91
+ cur_face_n = pth_processing(Image.fromarray(cur_face))
92
+ prediction = torch.nn.functional.softmax(pth_model(cur_face_n), dim=1).cpu().detach().numpy()[0]
93
+ confidences = {DICT_EMO[i]: float(prediction[i]) for i in range(7)}
94
+
95
+ return cur_face, confidences
96
+
97
+ def clear():
98
+ return (
99
+ gr.Image(value=None, type="pil"),
100
+ gr.Image(value=None,scale=1, elem_classes="dl2"),
101
+ gr.Label(value=None,num_top_classes=3, scale=1, elem_classes="dl3")
102
+ )
103
+
104
+ style = """
105
+ div.dl1 div.upload-container {
106
+ height: 350px;
107
+ max-height: 350px;
108
+ }
109
+
110
+ div.dl2 {
111
+ max-height: 200px;
112
+ }
113
+
114
+ div.dl2 img {
115
+ max-height: 200px;
116
+ }
117
+
118
+ .submit {
119
+ display: inline-block;
120
+ padding: 10px 20px;
121
+ font-size: 16px;
122
+ font-weight: bold;
123
+ text-align: center;
124
+ text-decoration: none;
125
+ cursor: pointer;
126
+ border: var(--button-border-width) solid var(--button-primary-border-color);
127
+ background: var(--button-primary-background-fill);
128
+ color: var(--button-primary-text-color);
129
+ border-radius: 8px;
130
+ transition: all 0.3s ease;
131
+ }
132
+
133
+ .submit[disabled] {
134
+ cursor: not-allowed;
135
+ opacity: 0.6;
136
+ }
137
+
138
+ .submit:hover:not([disabled]) {
139
+ border-color: var(--button-primary-border-color-hover);
140
+ background: var(--button-primary-background-fill-hover);
141
+ color: var(--button-primary-text-color-hover);
142
+ }
143
+
144
+ .submit:active:not([disabled]) {
145
+ transform: scale(0.98);
146
+ }
147
+ """
148
+
149
+ with gr.Blocks(css=style) as demo:
150
+ with gr.Row():
151
+ with gr.Column(scale=2, elem_classes="dl1"):
152
+ input_image = gr.Image(type="pil")
153
+ with gr.Row():
154
+ submit = gr.Button(
155
+ value="Submit", interactive=True, scale=1, elem_classes="submit"
156
+ )
157
+ clear_btn = gr.Button(
158
+ value="Clear", interactive=True, scale=1
159
+ )
160
+ with gr.Column(scale=1, elem_classes="dl4"):
161
+ output_image = gr.Image(scale=1, elem_classes="dl2")
162
+ output_label = gr.Label(num_top_classes=3, scale=1, elem_classes="dl3")
163
+ gr.Examples(
164
+ ["images/fig7.jpg", "images/fig1.jpg", "images/fig2.jpg","images/fig3.jpg",
165
+ "images/fig4.jpg", "images/fig5.jpg", "images/fig6.jpg"],
166
+ [input_image],
167
+ )
168
+
169
+
170
+ submit.click(
171
+ fn=predict,
172
+ inputs=[input_image],
173
+ outputs=[
174
+ output_image,
175
+ output_label
176
+ ],
177
+ queue=True,
178
+ )
179
+ clear_btn.click(
180
+ fn=clear,
181
+ inputs=[],
182
+ outputs=[
183
+ input_image,
184
+ output_image,
185
+ output_label,
186
+ ],
187
+ queue=True,
188
+ )
189
+
190
+ if __name__ == "__main__":
191
+ demo.queue(api_open=False).launch(share=False)
images/fig1.jpg ADDED
images/fig2.jpg ADDED
images/fig3.jpg ADDED
images/fig4.jpg ADDED
images/fig5.jpg ADDED
images/fig6.jpg ADDED
images/fig7.jpg ADDED