asrs777's picture
Update app.py
1b35d13
raw
history blame contribute delete
No virus
2.11 kB
'''from pathlib import Path
import shutil
import itertools
import os, cv2, numpy as np'''
import gradio as gr
import torch
from transformers import AutoModelForImageClassification
from optimum.pipelines import pipeline
from PIL import Image
import numpy as np
# Pipelines accept an integer CUDA device ordinal or "cpu". Ordinal 0 is the
# first GPU (the only one on single-GPU hosts); ordinal 1 would require a
# second GPU and fail with an invalid-device error otherwise.
device = 0 if torch.cuda.is_available() else "cpu"
# chk_point = "kdhht2334/autotrain-diffusion-emotion-facial-expression-recognition-40429105176"
# Expression classifier loaded from a local directory (presumably a checkout of
# the Hub checkpoint named in the comment above — confirm).
model = AutoModelForImageClassification.from_pretrained("./autotrain-diffusion-emotion-facial-expression-recognition-40429105176")
# Face detector (MTCNN) used to crop the face before classification
from facenet_pytorch import MTCNN, InceptionResnetV1
# Face detector consumed by face_detector(); settings are passed straight
# through to facenet-pytorch's MTCNN constructor.
mtcnn = MTCNN(
    image_size=300,
    margin=0,
    min_face_size=20,
    thresholds=[0.6, 0.7, 0.7],
    factor=0.709,
    post_process=True,
)
# Face-embedding network with 'vggface2' pretrained weights, switched to
# eval (inference) mode. NOTE(review): not referenced anywhere else in this
# file — consider dropping it if it is truly unused.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet.eval()
# Expression label -> class-index string; the tuple order fixes the indices
# ('neutral' -> '0' ... 'nonface' -> '8').
# NOTE(review): not referenced elsewhere in this file; confirm it matches the
# model's own id2label mapping before relying on it.
emotion_dict = {
    label: str(index)
    for index, label in enumerate(
        (
            "neutral",
            "happy",
            "sad",
            "surprise",
            "fear",
            "disgust",
            "angry",
            "uncertain",
            "nonface",
        )
    )
}
# (width, height) in pixels of the RGBA canvas predict() draws the face onto.
output_img_size = (2100, 700)
# Build the image-classification pipeline, preferring Optimum's
# "bettertransformer" accelerator. If Optimum raises NotImplementedError
# (presumably: accelerator unsupported for this model — confirm), fall back
# to the stock transformers pipeline. The fallback import rebinds the
# module-level name `pipeline`.
try:
    pipe = pipeline("image-classification", model, accelerator="bettertransformer", device=device)
except NotImplementedError:
    from transformers import pipeline

    pipe = pipeline("image-classification", model, device=device)
def face_detector(input_img):
    """Crop the detected face out of an uploaded image.

    Args:
        input_img: image array as delivered by the Gradio ``Image`` input;
            it is passed to ``PIL.Image.fromarray``.

    Returns:
        A ``PIL.Image`` in RGB mode. It is cropped to the first bounding box
        reported by ``mtcnn.detect``. If no face is detected, the whole image
        is returned uncropped so the caller can still run the classifier
        instead of crashing.
    """
    # MTCNN expects 3-channel input; normalise RGBA / greyscale uploads.
    img = Image.fromarray(input_img).convert("RGB")
    boxes, _ = mtcnn.detect(img)
    # ``detect`` reports ``None`` for the boxes when no face is found.
    if boxes is None or len(boxes) == 0:
        return img
    # With several faces ``boxes`` holds one row per face, and squeezing it
    # would give a nested list that ``Image.crop`` rejects. Use the first row
    # (presumably the largest face under facenet-pytorch's default
    # select_largest=True — confirm).
    return img.crop(tuple(boxes[0].tolist()))
def predict(image):
    """Detect a face, classify its expression, and render it on a canvas.

    Args:
        image: image array from the Gradio ``Image`` input.

    Returns:
        A tuple ``(output_img, scores)``:
        - ``output_img`` is a transparent RGBA canvas of size
          ``output_img_size``. The detected face is scaled to the canvas
          height (aspect ratio kept) and centred horizontally.
        - ``scores`` maps each label returned by ``pipe`` to its score.
    """
    cropped_face = face_detector(image)
    # Derive layout from output_img_size instead of repeating 700 / 1050.
    canvas_w, canvas_h = output_img_size

    face_w, face_h = cropped_face.size
    # Clamp to 1 px so a degenerate crop cannot divide by zero or request
    # an empty resize.
    face_re_w = max(1, int(face_w * (canvas_h / max(face_h, 1))))
    resized_face = cropped_face.resize((face_re_w, canvas_h))

    output_img = Image.new("RGBA", output_img_size)
    # Centre horizontally (canvas_w // 2 is 1050 for the default 2100 px width).
    output_img.paste(resized_face, (canvas_w // 2 - face_re_w // 2, 0))

    # Classify the un-resized crop; the canvas is for display only.
    predictions = pipe(cropped_face)
    return output_img, {p["label"]: p["score"] for p in predictions}
# Web UI: one uploaded image in; the rendered face canvas plus the label
# scores out. `gr.Image` replaces the deprecated `gr.inputs.Image` (removed
# in Gradio 4) and keeps the same default numpy input type.
demo = gr.Interface(
    predict,
    inputs=gr.Image(label="Upload image"),
    outputs=["image", "label"],
    examples=[["examples/happy.png"], ["examples/angry.png"], ["examples/surprise.png"]],
    title="Demo - DiffusionFER",
)
demo.launch()