DmitryRyumin's picture
Update app.py
09cbd66
raw history blame
No virus
5.78 kB
import torch
from PIL import Image
from torchvision import transforms
import mediapipe as mp
import numpy as np
import math
import requests
import gradio as gr
# Download the TorchScript emotion model (ResNet-50 / AffectNet according to its
# file name) and load it once at import time for inference.
model_url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth"
model_path = "FER_static_ResNet50_AffectNet.pth"
# timeout: don't hang app startup forever on a stalled connection.
# raise_for_status: fail loudly instead of writing an HTTP error page to
# model_path and then handing it to torch.jit.load.
# The `with` closes the streamed response (and its connection) when done.
with requests.get(model_url, stream=True, timeout=60) as response:
    response.raise_for_status()
    with open(model_path, "wb") as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
# pth_processing() builds CPU tensors, so keep the model on the CPU regardless
# of the device its weights were saved from.
pth_model = torch.jit.load(model_path, map_location="cpu")
pth_model.eval()
# Classifier output index -> emotion name. predict() reads softmax column i as
# DICT_EMO[i], so this order must match the model's output order.
DICT_EMO = dict(
    enumerate(
        (
            "Neutral",
            "Happiness",
            "Sadness",
            "Surprise",
            "Fear",
            "Disgust",
            "Anger",
        )
    )
)
# Alias for MediaPipe's Face Mesh solution module; predict() opens a FaceMesh
# detector from it on every call.
mp_face_mesh = mp.solutions.face_mesh
class _PreprocessInput(torch.nn.Module):
    """Float-convert a CHW image tensor, reverse its channels, subtract channel means."""

    # No __init__ override needed: nn.Module's default constructor suffices.

    def forward(self, x):
        x = x.to(torch.float32)
        # Reverse channel order (RGB -> BGR for an RGB input) ...
        x = torch.flip(x, dims=(0,))
        # ... then subtract hard-coded per-channel means. Because they are applied
        # after the flip, they are presumably BGR-ordered training means -- confirm
        # against the model's training pipeline before changing them.
        x[0, :, :] -= 91.4953
        x[1, :, :] -= 103.8827
        x[2, :, :] -= 131.0912
        return x


def pth_processing(fp):
    """Turn a PIL face crop into a ``(1, 3, 224, 224)`` float32 batch for ``pth_model``.

    Args:
        fp: PIL image of the face crop. It is converted to RGB first so that
            grayscale or RGBA crops still produce exactly three channels, which
            the channel flip and per-channel mean subtraction rely on.

    Returns:
        A ``torch.Tensor`` of shape ``(1, 3, 224, 224)`` and dtype float32.
    """
    transform = transforms.Compose([transforms.PILToTensor(), _PreprocessInput()])
    # NEAREST resampling is kept exactly as before: changing the filter would
    # change the pixels the model sees.
    img = fp.convert("RGB").resize((224, 224), Image.Resampling.NEAREST)
    return torch.unsqueeze(transform(img), 0)
def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
    """Map normalized landmark coordinates to integer pixel indices.

    Args:
        normalized_x: x position as a fraction of the image width.
        normalized_y: y position as a fraction of the image height.
        image_width: image width in pixels.
        image_height: image height in pixels.

    Returns:
        ``(x_px, y_px)``: floored pixel coordinates, clamped to
        ``[0, image_width - 1]`` and ``[0, image_height - 1]``. Inputs outside
        ``[0, 1]`` therefore map to the nearest edge pixel instead of an
        out-of-range (e.g. negative) index.
    """
    x_px = min(max(math.floor(normalized_x * image_width), 0), image_width - 1)
    y_px = min(max(math.floor(normalized_y * image_height), 0), image_height - 1)
    return x_px, y_px
def get_box(fl, w, h):
    """Compute the pixel bounding box enclosing all landmarks of one face.

    Args:
        fl: a single face entry from ``results.multi_face_landmarks``; its
            ``fl.landmark`` items carry normalized ``.x`` / ``.y`` coordinates.
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        ``(startX, startY, endX, endY)`` as Python ints: the min/max landmark
        pixel coordinates, each clamped into ``[0, w - 1]`` (x) or
        ``[0, h - 1]`` (y).

    Raises:
        ValueError: if ``fl`` has no landmarks.
    """
    # One (x_px, y_px) row per landmark. norm_coordinates always returns a
    # tuple, so no per-landmark filtering is needed.
    coords = np.array([norm_coordinates(lm.x, lm.y, w, h) for lm in fl.landmark])
    if coords.size == 0:
        raise ValueError("get_box() requires at least one landmark")
    x_min, y_min = coords.min(axis=0)
    x_max, y_max = coords.max(axis=0)
    # Clamp every coordinate on both sides. In particular, a negative end index
    # would wrap around when the box is used as a slice (inp[startY:endY, ...]).
    startX = int(np.clip(x_min, 0, w - 1))
    startY = int(np.clip(y_min, 0, h - 1))
    endX = int(np.clip(x_max, 0, w - 1))
    endY = int(np.clip(y_max, 0, h - 1))
    return startX, startY, endX, endY
def predict(inp):
    """Detect one face in ``inp`` and classify its facial expression.

    Args:
        inp: the image from the ``gr.Image(type="pil")`` input, or ``None``
            when Submit is clicked before an image is uploaded.

    Returns:
        ``(cur_face, confidences)``: the face crop (a slice of the RGB input
        array) and a ``{emotion: probability}`` dict covering every label in
        ``DICT_EMO``.

    Raises:
        gr.Error: if no image was given, no face was detected, or the detected
            face box is empty. Using gr.Error gives the user a readable message
            instead of the generic failure caused by returning nothing.
    """
    if inp is None:
        raise gr.Error("Please upload an image first.")
    if isinstance(inp, Image.Image):
        # FaceMesh.process expects 3-channel RGB data; convert in case a
        # grayscale/RGBA image reaches this function.
        inp = inp.convert("RGB")
    inp = np.array(inp)
    h, w = inp.shape[:2]
    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(inp)
        if not results.multi_face_landmarks:
            raise gr.Error("No face detected in the image.")
        # max_num_faces=1, so the first entry is the only face.
        startX, startY, endX, endY = get_box(results.multi_face_landmarks[0], w, h)
    cur_face = inp[startY:endY, startX:endX]
    if cur_face.size == 0:
        # Degenerate box: there is nothing meaningful to classify.
        raise gr.Error("The detected face region is empty.")
    cur_face_n = pth_processing(Image.fromarray(cur_face))
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = pth_model(cur_face_n)
    prediction = torch.nn.functional.softmax(logits, dim=1).numpy()[0]
    # Softmax column i corresponds to DICT_EMO[i].
    confidences = {label: float(prediction[i]) for i, label in DICT_EMO.items()}
    return cur_face, confidences
def clear():
    """Reset the UI to its empty state.

    Returns three fresh components, matching the order of the outputs wired to
    ``clear_btn.click``: the input image, the face-crop image and the label.
    """
    blank_input = gr.Image(value=None, type="pil")
    blank_face = gr.Image(value=None, scale=1, elem_classes="dl2")
    blank_label = gr.Label(
        value=None,
        num_top_classes=3,
        scale=1,
        elem_classes="dl3",
    )
    return blank_input, blank_face, blank_label
# Custom CSS passed to gr.Blocks(css=style):
#   - div.dl1 ... .upload-container: fixed 350px height for the upload area in
#     the left column (elem_classes="dl1").
#   - div.dl2 / div.dl2 img: cap the face-crop output image (elem_classes="dl2")
#     at 200px.
#   - .submit rules: style the Submit button with Gradio's primary-button theme
#     variables, plus disabled/hover/active states.
# The "dl3" and "dl4" classes used in the layout have no rules here.
style = """
div.dl1 div.upload-container {
height: 350px;
max-height: 350px;
}
div.dl2 {
max-height: 200px;
}
div.dl2 img {
max-height: 200px;
}
.submit {
display: inline-block;
padding: 10px 20px;
font-size: 16px;
font-weight: bold;
text-align: center;
text-decoration: none;
cursor: pointer;
border: var(--button-border-width) solid var(--button-primary-border-color);
background: var(--button-primary-background-fill);
color: var(--button-primary-text-color);
border-radius: 8px;
transition: all 0.3s ease;
}
.submit[disabled] {
cursor: not-allowed;
opacity: 0.6;
}
.submit:hover:not([disabled]) {
border-color: var(--button-primary-border-color-hover);
background: var(--button-primary-background-fill-hover);
color: var(--button-primary-text-color-hover);
}
.submit:active:not([disabled]) {
transform: scale(0.98);
}
"""
# UI layout: upload area and buttons on the left, the face crop and top-3
# emotion label on the right, example images underneath. The buttons are wired
# to predict() and clear().
with gr.Blocks(css=style) as demo:
    with gr.Row():
        # Left (wider) column: the input image and the two action buttons.
        with gr.Column(scale=2, elem_classes="dl1"):
            input_image = gr.Image(type="pil")
            with gr.Row():
                submit = gr.Button(
                    value="Submit",
                    interactive=True,
                    scale=1,
                    elem_classes="submit",
                )
                clear_btn = gr.Button(value="Clear", interactive=True, scale=1)
        # Right column: the outputs filled by predict().
        with gr.Column(scale=1, elem_classes="dl4"):
            output_image = gr.Image(scale=1, elem_classes="dl2")
            output_label = gr.Label(num_top_classes=3, scale=1, elem_classes="dl3")

    # No fn/outputs are given, so choosing an example only populates input_image.
    gr.Examples(
        examples=[
            "images/fig7.jpg",
            "images/fig1.jpg",
            "images/fig2.jpg",
            "images/fig3.jpg",
            "images/fig4.jpg",
            "images/fig5.jpg",
            "images/fig6.jpg",
        ],
        inputs=[input_image],
    )

    submit.click(
        fn=predict,
        inputs=[input_image],
        outputs=[output_image, output_label],
        queue=True,
    )
    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[input_image, output_image, output_label],
        queue=True,
    )
if __name__ == "__main__":
    # Enable the request queue (api_open=False), then start the local server
    # without a public share link. Blocks.queue() returns the same Blocks
    # object, so configuring it first and launching afterwards is equivalent
    # to chaining the two calls.
    demo.queue(api_open=False)
    demo.launch(share=False)