Add apple
- app.py +36 -0
- face_module.py +24 -0
- meter.py +109 -0
- model.py +25 -0
- predict_image.py +103 -0
- requirements.txt +10 -0
- utils.py +31 -0
app.py
ADDED
@@ -0,0 +1,36 @@
import cv2
import gradio as gr
from predict_image import load_model, predict

def predict_fn(image, model_name):
    # Gradio hands over the uploaded image as an RGB NumPy array; swapping the
    # channels puts it in OpenCV's BGR order so cv2.imwrite stores the colours correctly
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    cv2.imwrite('./myimage.jpg', image)

    # model for emotion classification
    if model_name == 'EfficientNetB0':
        model_name = 'effb0'
    elif model_name == 'ResNet18':
        model_name = 'res18'
    else:
        raise ValueError('Enter correct model_name')

    model = load_model(model_name)

    # run detection + classification, then convert the annotated result back to RGB for display
    out = predict('./myimage.jpg', './result.jpg', model)
    out = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    return out


demo = gr.Interface(
    fn=predict_fn,
    inputs=[
        gr.Image(label="Input Image"),
        gr.Radio(['EfficientNetB0', 'ResNet18'], value='EfficientNetB0', label='Model Name')
    ],
    outputs=[
        gr.Image(label="Prediction"),
    ],
    title="Emotion Recognition Demo",
    description="Emotion Classification Model trained on FER Dataset"
)

demo.launch(debug=True)
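
Note (not part of the commit): predict_fn receives an RGB array from the Image component; cv2.COLOR_BGR2RGB performs the same channel swap as COLOR_RGB2BGR, so the round trip through cv2.imwrite and back still yields correct colours. A minimal sketch of that symmetry:

import cv2
import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255                           # pure red in RGB order
bgr = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)  # same swap as COLOR_RGB2BGR
assert bgr[0, 0, 2] == 255                  # red now sits in OpenCV's R channel (index 2)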
face_module.py
ADDED
@@ -0,0 +1,24 @@
import mediapipe as mp
mp_face_detection = mp.solutions.face_detection

def get_face_coords(image):
    with mp_face_detection.FaceDetection(
        model_selection=1, min_detection_confidence=0.5) as face_detection:
        #image = cv2.imread(file)
        # Convert the BGR image to RGB and process it with MediaPipe Face Detection.
        results = face_detection.process(image)
        # Draw face detections of each face.
        if not results.detections:
            return False

        # shape of image
        h, w, _ = image.shape

        t = results.detections[0].location_data.relative_bounding_box
        height = t.height * h
        ymin = t.ymin * h
        width = t.width * w
        xmin = t.xmin * w
        xmax = xmin + width
        ymax = ymin + height
        return int(xmin), int(ymin), int(xmax), int(ymax)
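
Note (a usage sketch, not part of the commit; './sample.jpg' is a placeholder path): get_face_coords expects an RGB array and returns either pixel coordinates of the first detected face or False.

import cv2
from face_module import get_face_coords

image = cv2.cvtColor(cv2.imread('./sample.jpg'), cv2.COLOR_BGR2RGB)  # MediaPipe works on RGB
coords = get_face_coords(image)
if coords:
    xmin, ymin, xmax, ymax = coords
    face_crop = image[ymin:ymax, xmin:xmax]
else:
    print('no face detected')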
meter.py
ADDED
@@ -0,0 +1,109 @@
import cv2
import numpy as np

# storing settings for semicircle
class SemiCircle:
    def __init__(
        self, thickness=10, color=(255, 0, 0), radius=100,
        center=(250, 250), angle=0, start_angle=180, end_angle=360
    ):
        self.thickness = thickness
        self.color = color
        self.radius = (radius, radius)
        self.center = center
        self.angle = angle
        self.start_angle = start_angle
        self.end_angle = end_angle

class Line:
    def __init__(self, thickness=2, color=(0, 0, 0)):
        self.thickness = thickness
        self.color = color

def generate_points(radius, length, center, num_points):
    # center points
    cx, cy = center

    # generating points on circle
    outer_circle_points = [(radius * np.cos(i), radius * np.sin(i)) for i in np.linspace(np.pi, 2*np.pi, num_points)]

    inner_radius = radius - length
    inner_circle_points = [(inner_radius * np.cos(i), inner_radius * np.sin(i)) for i in np.linspace(np.pi, 2*np.pi, num_points)]

    # generating points for drawing lines with cv2: start_points and end_points
    start_points = [(int(cx + i), int(cy + j)) for i, j in outer_circle_points]
    end_points = [(int(cx + i), int(cy + j)) for i, j in inner_circle_points]
    return zip(start_points, end_points)

class Meter:
    def __init__(self, center, radius, circle_color):
        self.center = center
        self.radius = radius
        self.circle_color = circle_color

    def draw_meter(self, image, idx):
        # drawing semicircle
        circle = SemiCircle(center=self.center, radius=self.radius, color=self.circle_color)

        cv2.ellipse(
            image, circle.center, circle.radius,
            circle.angle, circle.start_angle, circle.end_angle,
            circle.color, circle.thickness
        )

        # drawing smaller fine lines
        line = Line(thickness=circle.thickness//3)

        for start, end in generate_points(self.radius - 10, self.radius * 0.05, self.center, 50):
            cv2.line(image, start, end, line.color, line.thickness)

        # drawing bigger lines
        line2 = Line(thickness=circle.thickness//2, color=(238, 222, 23))

        for start, end in generate_points(self.radius - 10, self.radius * 0.15, self.center, 10):
            cv2.line(image, start, end, line2.color, line2.thickness)

        # drawing needle anchor point
        cv2.circle(image, circle.center, 15, (0, 0, 0), -1)

        # emotion classes
        emotions = ['neutral', 'happy', 'surprise', 'sad', 'angry',
                    "disgust", 'fear', 'contempt', 'unknown', 'NotFace']

        # points where text will be written
        pts = generate_points(self.radius*1.55, self.radius * 0.15, self.center, 10)
        pts = list(pts)

        for i, emot in enumerate(emotions):
            x, y = pts[i][0]
            x -= 20

            color = (98, 65, 255) if (i == idx) else (144,238,144)

            cv2.putText(
                image, emot, (x, y),
                cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2,
                cv2.LINE_AA
            )

        # needle 12, 178, 33
        line3 = Line(thickness=circle.thickness//2, color=(0, 0, 255))
        pts2 = generate_points(self.radius*0.7, self.radius*0.7, self.center, 10)
        pts2 = list(pts2)
        start, end = pts2[idx]

        cv2.line(image, start, end, line3.color, line3.thickness)

if __name__ == '__main__':
    image = np.ones((500, 500, 3))

    meter = Meter((250, 250), 150, (80, 127, 255))

    meter.draw_meter(image, 4)

    cv2.imshow('image', image)

    if cv2.waitKey(0) & 0xFF == 27:
        pass
    cv2.destroyAllWindows()
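
Note (not part of the commit): the __main__ block uses cv2.imshow, which needs a display; on a headless host such as a Space the gauge can be rendered to a file instead. A minimal sketch:

import cv2
import numpy as np
from meter import Meter

canvas = np.full((500, 500, 3), 255, dtype=np.uint8)           # white uint8 canvas so imwrite saves it correctly
Meter((250, 250), 150, (80, 127, 255)).draw_meter(canvas, 2)   # idx 2 highlights 'surprise'
cv2.imwrite('meter_preview.jpg', canvas)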
model.py
ADDED
@@ -0,0 +1,25 @@
import timm
import torch.nn as nn
from torchvision import models

class ResnetModel(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        model = models.resnet18()
        # replace the ImageNet head with one sized for the emotion classes
        model.fc = nn.Linear(512, num_classes)
        self.model = model

    def forward(self, x):
        out = self.model(x)
        return out

class EffnetModel(nn.Module):
    def __init__(self, num_classes=10) -> None:
        super().__init__()
        model = timm.create_model('efficientnet_b0', num_classes=num_classes)
        self.model = model

    def forward(self, x):
        out = self.model(x)
        return out
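
Note (a quick sanity check, not part of the commit): both wrappers map a 48x48 RGB crop, the size produced by val_transform in predict_image.py, to 10 emotion logits.

import torch
from model import ResnetModel, EffnetModel

x = torch.randn(1, 3, 48, 48)    # dummy batch of one face crop
print(ResnetModel()(x).shape)    # torch.Size([1, 10])
print(EffnetModel()(x).shape)    # torch.Size([1, 10])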
predict_image.py
ADDED
@@ -0,0 +1,103 @@
import cv2
import torch
from torchvision import transforms
from model import ResnetModel, EffnetModel
from face_module import get_face_coords
from meter import Meter
from utils import download_weights

# statistics of imagenet dataset
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# model weights urls
effb0_net_url = 'https://github.com/yMayanand/Emotion-Recognition/releases/download/v1.0.0/eff_b0.pt'
res18_net_url = 'https://github.com/yMayanand/Emotion-Recognition/releases/download/v1.0.0/res18.pt'

# transforms for image
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((48, 48)),
    transforms.Normalize(mean, std)
])

def load_model(model_name):

    # model for emotion classification
    if model_name == 'effb0':
        model = EffnetModel()
        fname = download_weights(effb0_net_url)
    elif model_name == 'res18':
        model = ResnetModel()
        fname = download_weights(res18_net_url)
    else:
        raise ValueError('Enter correct model_name')

    # loading pretrained weights
    state_dict = torch.load(fname)
    model.load_state_dict(state_dict['weights'])
    return model

# emotion classes
emotions = ['neutral', 'happy :-)', 'surprise :-O', 'sad', 'angry >:(',
            "disgust D-':", 'fear', 'contempt', 'unknown', 'NF']

# colors for text for each emotion class
colors = [(0, 128, 255), (255, 0, 255), (0, 255, 255), (255, 191, 0), (0, 0, 255),
          (255, 255, 0), (0, 191, 255), (255, 0, 191), (255, 0, 191), (255, 0, 191)]

def predict(image, save_path, model):
    image = cv2.imread(image)
    h, w, c = image.shape

    # meter
    m = Meter((w//2, h), w//5, (255, 0, 0))

    # storing original image in bgr mode
    orig_image = image
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    coords = get_face_coords(image)
    if coords:
        # getting bounding box coordinates for face
        xmin, ymin, xmax, ymax = coords
        model.eval()

        image = image[ymin:ymax, xmin:xmax, :]

        # check that the detected face is not cut off at the edge of the screen
        h, w, c = image.shape
        if h and w:
            image_tensor = val_transform(image).unsqueeze(0)
            out = model(image_tensor)

            # predicting emotion for detected face
            pred = torch.argmax(out, dim=1)
            idx = pred.item()
        else:
            # empty crop, fall back to the 'NF' (no face) class
            idx = 9

        pred_emot = emotions[idx]
        color = colors[idx]

        # drawing annotations on original bgr image
        orig_image = cv2.rectangle(orig_image, (xmin, ymin), (xmax, ymax), (255, 0, 0), 1)
    else:
        idx = 9
        pred_emot = 'Face Not Detected'
        color = colors[-1]

    orig_image = cv2.flip(orig_image, 1)
    m.draw_meter(orig_image, idx)

    cv2.imwrite(save_path, orig_image)
    return orig_image

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, help='path to image location')
    parser.add_argument('--model_name', type=str, default='effb0', help='name of the model')
    parser.add_argument('--save_path', type=str, default='./result.jpg', help='path to save image')
    args = parser.parse_args()
    model = load_model(args.model_name)

    predict(args.image_path, args.save_path, model)
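
Note (a usage sketch, not part of the commit; './sample.jpg' is a placeholder path): besides the CLI defined above (python predict_image.py --image_path ./sample.jpg --model_name res18), the same entry points can be called directly; the weight file is downloaded from the GitHub release on the first load_model call.

from predict_image import load_model, predict

model = load_model('res18')                          # fetches res18.pt via utils.download_weights
annotated = predict('./sample.jpg', './result.jpg', model)
print(annotated.shape)                               # annotated BGR image as a NumPy array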
requirements.txt
ADDED
@@ -0,0 +1,10 @@
###### Requirements without Version Specifiers ######

gradio
opencv-python
mediapipe
numpy
torch
torchvision
torchaudio
timm
utils.py
ADDED
@@ -0,0 +1,31 @@
import cv2
import torch
from urllib.request import urlretrieve

def read_image(file):
    """Reads the image file

    Returns the numpy array.

    Args:
        file : path to the image

    Returns:
        (numpy.ndarray): image read as numpy array
    """
    image = cv2.imread(file)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image

def accuracy(predictions, ground_truth):
    """Function to calculate accuracy of the model.
    """

    _, preds = torch.max(predictions, dim=1)
    score = (preds == ground_truth).float().mean()
    return score.item()

def download_weights(url):
    # derive the local filename from the release URL and fetch the checkpoint
    fname = url.split('/')[-1]
    urlretrieve(url, fname)
    return fname
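
Note (a usage sketch, not part of the commit): accuracy takes raw logits and integer class labels and returns the fraction of correct argmax predictions.

import torch
from utils import accuracy

logits = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1]])
labels = torch.tensor([1, 0])
print(accuracy(logits, labels))   # 1.0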