kdhht2334's picture
Upload 8 files (#1)
134050f
raw
history blame
No virus
4.81 kB
from pathlib import Path
import gradio as gr
import torch
from transformers import AutoModelForImageClassification
import shutil
from optimum.pipelines import pipeline
import os, cv2, numpy as np
from PIL import Image
import itertools
# GPU ordinal for the transformers/optimum ``device`` argument: 0 is the first
# CUDA device. (An ordinal of 1 would require a second GPU and fails with
# "invalid device ordinal" on single-GPU hosts.)
device = 0 if torch.cuda.is_available() else "cpu"
chk_point = "kdhht2334/autotrain-diffusion-emotion-facial-expression-recognition-40429105176"
# NOTE(review): ``model`` is only referenced by commented-out code in this
# file; the pipeline built below loads its own copy from the same checkpoint,
# so this keeps a second set of weights in memory. Kept so the module-level
# name stays available.
model = AutoModelForImageClassification.from_pretrained(chk_point)

## Add face detector
# Caffe SSD face detector files, resolved relative to the current working
# directory.
protoPath = os.path.join("face_detector", "deploy.prototxt")
modelPath = os.path.join("face_detector", "res10_300x300_ssd_iter_140000.caffemodel")
net = cv2.dnn.readNetFromCaffe(protoPath, modelPath)
def draw_results_ssd(detected, input_img, ad, conf_threshold=0.5):
    """Crop every confidently detected face out of ``input_img``.

    Args:
        detected: raw SSD detector output, indexed as ``detected[0, 0, i, k]``
            where ``k == 2`` is the confidence of detection ``i`` and
            ``k in 3..6`` are its (startX, startY, endX, endY) corners as
            fractions of the image width/height.
        input_img: image array whose first two axes are (height, width).
        ad: padding added on each side of a box, as a fraction of the box
            width (horizontally) and height (vertically).
        conf_threshold: detections with confidence at or below this value
            are discarded.

    Returns:
        List of non-empty face crops (slices of ``input_img``) in detection
        order.
    """
    # numpy images are (height, width, ...): read them in that order so x is
    # clamped by the width and y by the height.
    img_h, img_w = input_img.shape[:2]
    num_detections = detected.shape[2]
    cropped_output = []
    if num_detections > 0:
        print("detected face: {}".format(num_detections))
    scale = np.array([img_w, img_h, img_w, img_h])
    for i in range(num_detections):
        confidence = detected[0, 0, i, 2]
        # filter out weak detections
        if confidence <= conf_threshold:
            continue
        start_x, start_y, end_x, end_y = (detected[0, 0, i, 3:7] * scale).astype("int")
        w = end_x - start_x
        h = end_y - start_y
        # Pad the box and clamp it into the image. Slice ends are exclusive,
        # so the upper bound is img_w / img_h (not ``- 1``); ends are also
        # kept >= 0 so a box left/above the image cannot wrap around via
        # negative indexing.
        xw1 = max(int(start_x - ad * w), 0)
        yw1 = max(int(start_y - ad * h), 0)
        xw2 = max(min(int(end_x + ad * w), img_w), 0)
        yw2 = max(min(int(end_y + ad * h), img_h), 0)
        face_crop = input_img[yw1:yw2, xw1:xw2]
        if face_crop.size != 0:
            cropped_output.append(face_crop)
    return cropped_output
# Emotion label -> class id as a string, ids assigned in list order (0..8).
# NOTE(review): not referenced elsewhere in the visible code.
emotion_dict = {
    label: str(class_id)
    for class_id, label in enumerate(
        (
            "neutral",
            "happy",
            "sad",
            "surprise",
            "fear",
            "disgust",
            "angry",
            "uncertain",
            "nonface",
        )
    )
}
# Build the image-classification pipeline. First try optimum's pipeline with
# the BetterTransformer accelerator; if that raises NotImplementedError, fall
# back to the plain transformers pipeline on the same device.
try:
    pipe = pipeline(
        task="image-classification",
        model=chk_point,
        device=device,
        accelerator="bettertransformer",
    )
except NotImplementedError:
    # This import rebinds the module-level name ``pipeline`` to the
    # transformers implementation.
    from transformers import pipeline

    pipe = pipeline(task="image-classification", model=chk_point, device=device)
# def make_label_folders():
# folders = model.config.label2id.keys()
# for folder in folders:
# folder = Path(folder)
# if not folder.exists():
# folder.mkdir()
# return folders
# def predictions_into_folders(files):
# files = [file.name for file in files]
# files = [
# file for file in files if not file.startswith(".") and "DS_Store" not in file
# ]
# folders = make_label_folders()
# predictions = pipe(files)
# for file, prediction in zip(files, predictions):
# label = prediction[0]["label"]
# file_name = Path(file).name
# shutil.copy(file, f"{label}/{file_name}")
# for folder in folders:
# shutil.make_archive(folder, "zip", ".", folder)
# return [f"{folder}.zip" for folder in folders]
# demo = gr.Interface(
# predictions_into_folders,
# gr.Files(file_count="directory", file_types=["image"]),
# gr.Files(),
# cache_examples=True,
# )
# demo.launch(enable_queue=True)
def face_detector(input_img, ad=0.1):
    """Detect faces in ``input_img`` and return padded face crops.

    The image is resized to 300x300, converted to a blob with the mean
    ``(104.0, 177.0, 123.0)`` subtracted, and run through the module-level
    OpenCV DNN detector ``net``. The raw detections are then thresholded and
    cropped by ``draw_results_ssd``.

    Args:
        input_img: H x W x 3 image array. Presumably RGB, as delivered by the
            Gradio ``Image`` input and as ``predict`` treats it
            (``Image.fromarray``); confirm if called from elsewhere.
        ad: padding around each detected box, as a fraction of the box size.

    Returns:
        List of face crops (slices of ``input_img``).
    """
    blob = cv2.dnn.blobFromImage(
        cv2.resize(input_img, (300, 300)),
        1.0,
        (300, 300),
        (104.0, 177.0, 123.0),
        # NOTE: the mean above is in BGR order, the channel order this Caffe
        # detector expects, while the input is RGB. swapRB=True converts the
        # blob to BGR and applies the mean in that swapped order.
        swapRB=True,
    )
    net.setInput(blob)
    detected = net.forward()  # For example, (1, 1, 200, 7)
    return draw_results_ssd(detected, input_img, ad)
def predict(image, max_faces=10, top_k=7):
    """Detect faces in ``image`` and classify the expression of each one.

    Args:
        image: input image array from the Gradio ``Image`` component, or
            ``None`` when nothing was uploaded.
        max_faces: number of (crop, scores) output slots; must match the
            number of paired output components in the interface.
        top_k: number of class scores requested per face from the pipeline
            (the output ``Label`` components display up to 7).

    Returns:
        Flat list of ``2 * max_faces`` values laid out as
        ``[crop_0, scores_0, crop_1, scores_1, ...]``, where each scores
        entry maps label -> score. Unused slots are ``None``. Faces beyond
        ``max_faces`` are ignored so the fixed-size output cannot overflow.
    """
    output = [None] * (2 * max_faces)
    if image is None:
        return output
    cropped_list = face_detector(image)
    for i, face in enumerate(cropped_list[:max_faces]):
        predictions = pipe(Image.fromarray(face), top_k=top_k)
        output[2 * i] = face
        output[2 * i + 1] = {p["label"]: p["score"] for p in predictions}
    return output
# One (face crop, emotion scores) pair of output components per face slot,
# in the same order ``predict`` fills them (10 slots -> 20 outputs).
# Uses gr.Image / gr.Label rather than the deprecated gr.inputs / gr.outputs
# namespaces, which were removed in Gradio 4.
demo = gr.Interface(
    predict,
    inputs=gr.Image(label="Upload image"),
    outputs=[
        component
        for ind in range(10)
        for component in (
            gr.Image(label="image {}".format(ind)),
            gr.Label(num_top_classes=7, label="Result {}".format(ind)),
        )
    ],
    examples=[["examples/9_peoples.jpg"], ["examples/sad.jpg"], ["examples/angry.jpg"], ["examples/surprise.jpg"]],
    title="FER trained on DiffusionFER dataset",
)
demo.launch()