thinkin-machine committed on
Commit edc8afb
1 Parent(s): 146bc14
Files changed (7)
  1. app.py +36 -0
  2. face_module.py +24 -0
  3. meter.py +109 -0
  4. model.py +25 -0
  5. predict_image.py +103 -0
  6. requirements.txt +10 -0
  7. utils.py +31 -0
app.py ADDED
@@ -0,0 +1,36 @@
+ import cv2
+ import gradio as gr
+ from predict_image import load_model, predict
+
+ def predict_fn(image, model_name):
+     # Gradio supplies an RGB array; swap to BGR so cv2.imwrite stores the colors correctly
+     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+     cv2.imwrite('./myimage.jpg', image)
+
+     # model for emotion classification
+     if model_name == 'EfficientNetB0':
+         model_name = 'effb0'
+     elif model_name == 'ResNet18':
+         model_name = 'res18'
+     else:
+         raise ValueError(f'Unknown model_name: {model_name}')
+
+     model = load_model(model_name)
+
+     # predict returns the annotated image in BGR; convert back to RGB for display
+     out = predict('./myimage.jpg', './result.jpg', model)
+     out = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
+     return out
+
+
+ demo = gr.Interface(
+     fn=predict_fn,
+     inputs=[
+         gr.Image(label="Input Image"),
+         gr.Radio(['EfficientNetB0', 'ResNet18'], value='EfficientNetB0', label='Model Name')
+     ],
+     outputs=[
+         gr.Image(label="Prediction"),
+     ],
+     title="Emotion Recognition Demo",
+     description="Emotion Classification Model trained on FER Dataset"
+ )
+
+ demo.launch(debug=True)
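For a quick local check of the same pipeline without launching the Gradio UI, load_model and predict can be called directly — a minimal sketch, assuming a test image at ./sample.jpg (a hypothetical path, not part of this commit):

    import cv2
    from predict_image import load_model, predict

    model = load_model('effb0')  # or 'res18'
    out = predict('./sample.jpg', './result.jpg', model)  # BGR array with face box and meter drawn
    print(out.shape)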
face_module.py ADDED
@@ -0,0 +1,24 @@
+ import mediapipe as mp
+
+ mp_face_detection = mp.solutions.face_detection
+
+ def get_face_coords(image):
+     """Return (xmin, ymin, xmax, ymax) of the first detected face, or False if none is found."""
+     with mp_face_detection.FaceDetection(
+             model_selection=1, min_detection_confidence=0.5) as face_detection:
+         # the image is expected in RGB; process it with MediaPipe Face Detection
+         results = face_detection.process(image)
+
+         if not results.detections:
+             return False
+
+         # shape of image
+         h, w, _ = image.shape
+
+         # scale the relative bounding box of the first detection to pixel coordinates
+         t = results.detections[0].location_data.relative_bounding_box
+         height = t.height * h
+         ymin = t.ymin * h
+         width = t.width * w
+         xmin = t.xmin * w
+         xmax = xmin + width
+         ymax = ymin + height
+         return int(xmin), int(ymin), int(xmax), int(ymax)
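A small usage sketch for get_face_coords; the image path below is hypothetical, and note that the function expects an RGB array and returns False when no face is found:

    import cv2
    from face_module import get_face_coords

    image = cv2.cvtColor(cv2.imread('./face.jpg'), cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB
    coords = get_face_coords(image)
    if coords:
        xmin, ymin, xmax, ymax = coords
        face_crop = image[ymin:ymax, xmin:xmax]
    else:
        print('no face detected')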
meter.py ADDED
@@ -0,0 +1,109 @@
+ import cv2
+ import numpy as np
+
+ # storing settings for the semicircular dial
+ class SemiCircle:
+     def __init__(
+         self, thickness=10, color=(255, 0, 0), radius=100,
+         center=(250, 250), angle=0, start_angle=180, end_angle=360
+     ):
+         self.thickness = thickness
+         self.color = color
+         self.radius = (radius, radius)
+         self.center = center
+         self.angle = angle
+         self.start_angle = start_angle
+         self.end_angle = end_angle
+
+ class Line:
+     def __init__(self, thickness=2, color=(0, 0, 0)):
+         self.thickness = thickness
+         self.color = color
+
+ def generate_points(radius, length, center, num_points):
+     # center point
+     cx, cy = center
+
+     # generating points on the outer arc (upper semicircle, pi to 2*pi)
+     outer_circle_points = [(radius * np.cos(i), radius * np.sin(i)) for i in np.linspace(np.pi, 2 * np.pi, num_points)]
+
+     # points on the inner arc, offset inward by `length`
+     inner_radius = radius - length
+     inner_circle_points = [(inner_radius * np.cos(i), inner_radius * np.sin(i)) for i in np.linspace(np.pi, 2 * np.pi, num_points)]
+
+     # generating start and end points for drawing tick lines with cv2
+     start_points = [(int(cx + i), int(cy + j)) for i, j in outer_circle_points]
+     end_points = [(int(cx + i), int(cy + j)) for i, j in inner_circle_points]
+     return zip(start_points, end_points)
+
+ class Meter:
+     def __init__(self, center, radius, circle_color):
+         self.center = center
+         self.radius = radius
+         self.circle_color = circle_color
+
+     def draw_meter(self, image, idx):
+         # drawing the semicircular dial
+         circle = SemiCircle(center=self.center, radius=self.radius, color=self.circle_color)
+
+         cv2.ellipse(
+             image, circle.center, circle.radius,
+             circle.angle, circle.start_angle, circle.end_angle,
+             circle.color, circle.thickness
+         )
+
+         # drawing the small fine tick marks
+         line = Line(thickness=circle.thickness // 3)
+
+         for start, end in generate_points(self.radius - 10, self.radius * 0.05, self.center, 50):
+             cv2.line(image, start, end, line.color, line.thickness)
+
+         # drawing the larger tick marks
+         line2 = Line(thickness=circle.thickness // 2, color=(238, 222, 23))
+
+         for start, end in generate_points(self.radius - 10, self.radius * 0.15, self.center, 10):
+             cv2.line(image, start, end, line2.color, line2.thickness)
+
+         # drawing the needle anchor point
+         cv2.circle(image, circle.center, 15, (0, 0, 0), -1)
+
+         # emotion classes
+         emotions = ['neutral', 'happy', 'surprise', 'sad', 'angry',
+                     'disgust', 'fear', 'contempt', 'unknown', 'NotFace']
+
+         # points where the class labels will be written
+         pts = list(generate_points(self.radius * 1.55, self.radius * 0.15, self.center, 10))
+
+         for i, emot in enumerate(emotions):
+             x, y = pts[i][0]
+             x -= 20
+
+             # highlight the predicted class
+             color = (98, 65, 255) if (i == idx) else (144, 238, 144)
+
+             cv2.putText(
+                 image, emot, (x, y),
+                 cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2,
+                 cv2.LINE_AA
+             )
+
+         # needle pointing at the predicted class
+         line3 = Line(thickness=circle.thickness // 2, color=(0, 0, 255))
+         pts2 = list(generate_points(self.radius * 0.7, self.radius * 0.7, self.center, 10))
+         start, end = pts2[idx]
+
+         cv2.line(image, start, end, line3.color, line3.thickness)
+
+ if __name__ == '__main__':
+     # uint8 canvas so the 0-255 drawing colors display correctly
+     image = np.full((500, 500, 3), 255, dtype=np.uint8)
+
+     meter = Meter((250, 250), 150, (80, 127, 255))
+     meter.draw_meter(image, 4)
+
+     cv2.imshow('image', image)
+     cv2.waitKey(0)
+     cv2.destroyAllWindows()
model.py ADDED
@@ -0,0 +1,25 @@
+ import timm
+ import torch.nn as nn
+ from torchvision import models
+
+ class ResnetModel(nn.Module):
+     def __init__(self, num_classes=10):
+         super().__init__()
+         model = models.resnet18()
+         # replace the final fully connected layer for emotion classification
+         model.fc = nn.Linear(512, num_classes)
+         self.model = model
+
+     def forward(self, x):
+         out = self.model(x)
+         return out
+
+ class EffnetModel(nn.Module):
+     def __init__(self, num_classes=10) -> None:
+         super().__init__()
+         model = timm.create_model('efficientnet_b0', num_classes=num_classes)
+         self.model = model
+
+     def forward(self, x):
+         out = self.model(x)
+         return out
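A quick shape check for the two backbones — a sketch assuming the 48x48 input size used by the transforms in predict_image.py:

    import torch
    from model import ResnetModel, EffnetModel

    x = torch.randn(1, 3, 48, 48)  # dummy batch matching the 48x48 resize in predict_image.py
    for net in (ResnetModel(), EffnetModel()):
        net.eval()
        with torch.no_grad():
            print(type(net).__name__, net(x).shape)  # expected: torch.Size([1, 10])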
predict_image.py ADDED
@@ -0,0 +1,103 @@
+ import cv2
+ import torch
+ from torchvision import transforms
+ from model import ResnetModel, EffnetModel
+ from face_module import get_face_coords
+ from meter import Meter
+ from utils import download_weights
+
+ # statistics of the ImageNet dataset
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+
+ # model weight URLs
+ effb0_net_url = 'https://github.com/yMayanand/Emotion-Recognition/releases/download/v1.0.0/eff_b0.pt'
+ res18_net_url = 'https://github.com/yMayanand/Emotion-Recognition/releases/download/v1.0.0/res18.pt'
+
+ # transforms for the input image
+ val_transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Resize((48, 48)),
+     transforms.Normalize(mean, std)
+ ])
+
+ def load_model(model_name):
+
+     # model for emotion classification
+     if model_name == 'effb0':
+         model = EffnetModel()
+         fname = download_weights(effb0_net_url)
+     elif model_name == 'res18':
+         model = ResnetModel()
+         fname = download_weights(res18_net_url)
+     else:
+         raise ValueError(f'Unknown model_name: {model_name}')
+
+     # loading pretrained weights (map to CPU so the demo works without a GPU)
+     state_dict = torch.load(fname, map_location='cpu')
+     model.load_state_dict(state_dict['weights'])
+     return model
+
+ # emotion classes
+ emotions = ['neutral', 'happy :-)', 'surprise :-O', 'sad', 'angry >:(',
+             "disgust D-':", 'fear', 'contempt', 'unknown', 'NF']
+
+ # text colors for each emotion class
+ colors = [(0, 128, 255), (255, 0, 255), (0, 255, 255), (255, 191, 0), (0, 0, 255),
+           (255, 255, 0), (0, 191, 255), (255, 0, 191), (255, 0, 191), (255, 0, 191)]
+
+ def predict(image, save_path, model):
+     image = cv2.imread(image)
+     h, w, c = image.shape
+
+     # meter drawn at the bottom centre of the frame
+     m = Meter((w // 2, h), w // 5, (255, 0, 0))
+
+     # storing the original image in BGR mode
+     orig_image = image
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     coords = get_face_coords(image)
+     if coords:
+         # getting bounding box coordinates for the face
+         xmin, ymin, xmax, ymax = coords
+         model.eval()
+
+         image = image[ymin:ymax, xmin:xmax, :]
+
+         # check that the detected face crop is not empty (face on the edge of the screen)
+         h, w, c = image.shape
+         if not (h and w):
+             # empty crop: fall back to the "NF" class
+             idx = 9
+         else:
+             image_tensor = val_transform(image).unsqueeze(0)
+             with torch.no_grad():
+                 out = model(image_tensor)
+
+             # predicted emotion for the detected face
+             pred = torch.argmax(out, dim=1)
+             idx = pred.item()
+             pred_emot = emotions[idx]
+             color = colors[idx]
+
+         # drawing annotations on the original BGR image
+         orig_image = cv2.rectangle(orig_image, (xmin, ymin), (xmax, ymax), (255, 0, 0), 1)
+     else:
+         idx = 9
+         pred_emot = 'Face Not Detected'
+         color = colors[-1]
+
+     orig_image = cv2.flip(orig_image, 1)
+     m.draw_meter(orig_image, idx)
+
+     cv2.imwrite(save_path, orig_image)
+     return orig_image
+
+ if __name__ == '__main__':
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--image_path', type=str, help='path to image location')
+     parser.add_argument('--model_name', type=str, default='effb0', help='name of the model')
+     parser.add_argument('--save_path', type=str, default='./result.jpg', help='path to save image')
+     args = parser.parse_args()
+     model = load_model(args.model_name)
+
+     predict(args.image_path, args.save_path, model)
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ ###### Requirements without Version Specifiers ######
+
+ gradio
+ opencv-python
+ mediapipe
+ numpy
+ torch
+ torchvision
+ torchaudio
+ timm
utils.py ADDED
@@ -0,0 +1,31 @@
+ import cv2
+ import torch
+ from urllib.request import urlretrieve
+
+ def read_image(file):
+     """Reads an image file.
+
+     Args:
+         file : path to the image
+
+     Returns:
+         (numpy.ndarray): image read as a numpy array in RGB order
+     """
+     image = cv2.imread(file)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     return image
+
+ def accuracy(predictions, ground_truth):
+     """Function to calculate the accuracy of the model."""
+
+     _, preds = torch.max(predictions, dim=1)
+     score = (preds == ground_truth).float().mean()
+     return score.item()
+
+ def download_weights(url):
+     # download the file to the current directory and return its local filename
+     fname = url.split('/')[-1]
+     urlretrieve(url, fname)
+     return fname
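A small sanity check for accuracy with illustrative dummy logits and labels:

    import torch
    from utils import accuracy

    logits = torch.tensor([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]])  # argmax: class 1, then class 0
    labels = torch.tensor([1, 2])                               # second prediction is wrong
    print(accuracy(logits, labels))                             # 0.5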