"""Gradio demo: facial expression recognition on a single still image.

Pipeline: MediaPipe FaceMesh locates one face, the landmark bounding box is
cropped from the image, the crop is resized to 224x224 and passed to a
TorchScript model whose seven outputs are mapped to the labels in
``DICT_EMO``.
"""

import math
import os

import gradio as gr
import mediapipe as mp
import numpy as np
import requests
import torch
from PIL import Image
from torchvision import transforms

model_url = "https://huggingface.co/ElenaRyumina/face_emotion_recognition/resolve/main/FER_static_ResNet50_AffectNet.pth"
model_path = "FER_static_ResNet50_AffectNet.pth"


def _download_model(url, path, timeout=60):
    """Download ``url`` to ``path`` unless a non-empty file is already there.

    The payload is streamed to ``path + ".part"`` and renamed only after the
    transfer finished, so an interrupted download never leaves a truncated
    file under the final name.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    if os.path.isfile(path) and os.path.getsize(path) > 0:
        return

    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an HTML error page would be saved as the model.
        response.raise_for_status()
        with open(tmp_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
    os.replace(tmp_path, path)


_download_model(model_url, model_path)

pth_model = torch.jit.load(model_path)
pth_model.eval()

# Model output index -> emotion label.
DICT_EMO = {
    0: "Neutral",
    1: "Happiness",
    2: "Sadness",
    3: "Surprise",
    4: "Fear",
    5: "Disgust",
    6: "Anger",
}

mp_face_mesh = mp.solutions.face_mesh


class PreprocessInput(torch.nn.Module):
    """Reverse the channel order and subtract fixed per-channel offsets."""

    def forward(self, x):
        """Return a float32 copy of ``x`` with channels flipped and shifted.

        ``x`` is indexed as ``x[channel, :, :]``; at least three channels
        are required.
        """
        x = x.to(torch.float32)
        # Flip along the channel axis (RGB -> BGR for an RGB PIL image).
        # torch.flip returns a copy, so the in-place ops below do not touch
        # the caller's tensor.
        x = torch.flip(x, dims=(0,))
        # NOTE(review): presumably the per-channel means of the model's
        # training data -- confirm against the training pipeline.
        x[0, :, :] -= 91.4953
        x[1, :, :] -= 103.8827
        x[2, :, :] -= 131.0912
        return x


# Built once at import time instead of on every prediction.
_FACE_TRANSFORM = transforms.Compose(
    [transforms.PILToTensor(), PreprocessInput()]
)


def pth_processing(fp):
    """Convert a PIL face crop into a model input batch.

    Args:
        fp: PIL image of the cropped face.

    Returns:
        Tensor with a leading batch dimension of size 1, produced from the
        image resized to 224x224 (nearest-neighbour) and preprocessed by
        ``PreprocessInput``.
    """
    resized = fp.resize((224, 224), Image.Resampling.NEAREST)
    return torch.unsqueeze(_FACE_TRANSFORM(resized), 0)


def norm_coordinates(normalized_x, normalized_y, image_width, image_height):
    """Scale normalized coordinates to pixel indices.

    Each value is multiplied by the image size, floored, and capped at the
    last valid index. Negative inputs are not clamped here.

    Returns:
        ``(x_px, y_px)`` tuple of ints.
    """
    x_px = min(math.floor(normalized_x * image_width), image_width - 1)
    y_px = min(math.floor(normalized_y * image_height), image_height - 1)
    return x_px, y_px


def get_box(fl, w, h):
    """Return the pixel bounding box enclosing all landmarks of one face.

    Args:
        fl: object with a ``landmark`` sequence whose items expose
            normalized ``x`` and ``y`` attributes.
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        ``(startX, startY, endX, endY)`` as ints, clamped to
        ``[0, w - 1]`` / ``[0, h - 1]``.
    """
    coords = np.asarray(
        [norm_coordinates(lm.x, lm.y, w, h) for lm in fl.landmark]
    )
    x_min, y_min = coords.min(axis=0)
    x_max, y_max = coords.max(axis=0)

    startX, startY = max(0, int(x_min)), max(0, int(y_min))
    endX, endY = min(w - 1, int(x_max)), min(h - 1, int(y_max))
    return startX, startY, endX, endY


def predict(inp):
    """Detect a face in ``inp`` and classify its expression.

    Args:
        inp: PIL image from the Gradio input component, or ``None`` when
            nothing was uploaded.

    Returns:
        ``(face_crop, confidences)`` where ``face_crop`` is the cropped
        numpy image and ``confidences`` maps each ``DICT_EMO`` label to its
        softmax probability. ``(None, None)`` is returned when there is no
        input image, no face is detected, or the detected box is empty.
    """
    if inp is None:
        return None, None

    inp = np.array(inp)
    h, w = inp.shape[:2]

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        results = face_mesh.process(inp)
        faces = results.multi_face_landmarks
        if not faces:
            # Previously this path fell through to an UnboundLocalError.
            return None, None

        # max_num_faces=1, so there is at most one entry.
        startX, startY, endX, endY = get_box(faces[0], w, h)

    if endX <= startX or endY <= startY:
        # Degenerate box: the crop would be empty and cannot be resized.
        return None, None

    cur_face = inp[startY:endY, startX:endX]
    cur_face_n = pth_processing(Image.fromarray(cur_face))

    with torch.no_grad():
        logits = pth_model(cur_face_n)
    prediction = (
        torch.nn.functional.softmax(logits, dim=1).detach().numpy()[0]
    )

    confidences = {
        label: float(prediction[idx]) for idx, label in DICT_EMO.items()
    }
    return cur_face, confidences


def clear():
    """Return fresh, empty components to reset the input and both outputs."""
    return (
        gr.Image(value=None, type="pil"),
        gr.Image(value=None, scale=1, elem_classes="dl2"),
        gr.Label(value=None, num_top_classes=3, scale=1, elem_classes="dl3"),
    )


# Custom CSS for the upload area, the cropped-face preview and the Submit
# button.
style = """
div.dl1 div.upload-container {
    height: 350px;
    max-height: 350px;
}

div.dl2 {
    max-height: 200px;
}

div.dl2 img {
    max-height: 200px;
}

.submit {
    display: inline-block;
    padding: 10px 20px;
    font-size: 16px;
    font-weight: bold;
    text-align: center;
    text-decoration: none;
    cursor: pointer;
    border: var(--button-border-width) solid var(--button-primary-border-color);
    background: var(--button-primary-background-fill);
    color: var(--button-primary-text-color);
    border-radius: 8px;
    transition: all 0.3s ease;
}

.submit[disabled] {
    cursor: not-allowed;
    opacity: 0.6;
}

.submit:hover:not([disabled]) {
    border-color: var(--button-primary-border-color-hover);
    background: var(--button-primary-background-fill-hover);
    color: var(--button-primary-text-color-hover);
}

.submit:active:not([disabled]) {
    transform: scale(0.98);
}
"""

with gr.Blocks(css=style) as demo:
    with gr.Row():
        # Left column: image upload plus the Submit / Clear buttons.
        with gr.Column(scale=2, elem_classes="dl1"):
            input_image = gr.Image(type="pil")
            with gr.Row():
                submit = gr.Button(
                    value="Submit",
                    interactive=True,
                    scale=1,
                    elem_classes="submit",
                )
                clear_btn = gr.Button(value="Clear", interactive=True, scale=1)
        # Right column: cropped face and the top-3 emotion scores.
        with gr.Column(scale=1, elem_classes="dl4"):
            output_image = gr.Image(scale=1, elem_classes="dl2")
            output_label = gr.Label(
                num_top_classes=3, scale=1, elem_classes="dl3"
            )
    gr.Examples(
        [
            "images/fig7.jpg",
            "images/fig1.jpg",
            "images/fig2.jpg",
            "images/fig3.jpg",
            "images/fig4.jpg",
            "images/fig5.jpg",
            "images/fig6.jpg",
        ],
        [input_image],
    )

    submit.click(
        fn=predict,
        inputs=[input_image],
        outputs=[output_image, output_label],
        queue=True,
    )
    clear_btn.click(
        fn=clear,
        inputs=[],
        outputs=[
            input_image,
            output_image,
            output_label,
        ],
        queue=True,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(share=False)