hysts HF staff commited on
Commit
2071150
·
1 Parent(s): bdf18e3
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import argparse
4
+ import functools
5
+ import os
6
+ import pathlib
7
+
8
+ import cv2
9
+ import dlib
10
+ import gradio as gr
11
+ import huggingface_hub
12
+ import numpy as np
13
+ import pretrainedmodels
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ TOKEN = os.environ['TOKEN']
19
+
20
+ MODEL_REPO = 'hysts/yu4u-age-estimation-pytorch'
21
+ MODEL_FILENAME = 'pretrained.pth'
22
+
23
+
24
+ def parse_args() -> argparse.Namespace:
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument('--device', type=str, default='cpu')
27
+ parser.add_argument('--theme', type=str)
28
+ parser.add_argument('--live', action='store_true')
29
+ parser.add_argument('--share', action='store_true')
30
+ parser.add_argument('--port', type=int)
31
+ parser.add_argument('--disable-queue',
32
+ dest='enable_queue',
33
+ action='store_false')
34
+ parser.add_argument('--allow-flagging', type=str, default='never')
35
+ parser.add_argument('--allow-screenshot', action='store_true')
36
+ return parser.parse_args()
37
+
38
+
39
+ def get_model(model_name='se_resnext50_32x4d',
40
+ num_classes=101,
41
+ pretrained='imagenet'):
42
+ model = pretrainedmodels.__dict__[model_name](pretrained=pretrained)
43
+ dim_feats = model.last_linear.in_features
44
+ model.last_linear = nn.Linear(dim_feats, num_classes)
45
+ model.avg_pool = nn.AdaptiveAvgPool2d(1)
46
+ return model
47
+
48
+
49
+ def load_model(device):
50
+ model = get_model(model_name='se_resnext50_32x4d', pretrained=None)
51
+ path = huggingface_hub.hf_hub_download(MODEL_REPO,
52
+ MODEL_FILENAME,
53
+ use_auth_token=TOKEN)
54
+ model.load_state_dict(torch.load(path))
55
+ model = model.to(device)
56
+ model.eval()
57
+ return model
58
+
59
+
60
+ def load_image(path):
61
+ image = cv2.imread(path)
62
+ h_orig, w_orig = image.shape[:2]
63
+ size = max(h_orig, w_orig)
64
+ scale = 640 / size
65
+ w, h = int(w_orig * scale), int(h_orig * scale)
66
+ image = cv2.resize(image, (w, h))
67
+ return image
68
+
69
+
70
+ def draw_label(image,
71
+ point,
72
+ label,
73
+ font=cv2.FONT_HERSHEY_SIMPLEX,
74
+ font_scale=0.8,
75
+ thickness=1):
76
+ size = cv2.getTextSize(label, font, font_scale, thickness)[0]
77
+ x, y = point
78
+ cv2.rectangle(image, (x, y - size[1]), (x + size[0], y), (255, 0, 0),
79
+ cv2.FILLED)
80
+ cv2.putText(image,
81
+ label,
82
+ point,
83
+ font,
84
+ font_scale, (255, 255, 255),
85
+ thickness,
86
+ lineType=cv2.LINE_AA)
87
+
88
+
89
+ @torch.inference_mode()
90
+ def predict(image, model, face_detector, device, margin=0.4, input_size=224):
91
+ image = cv2.imread(image.name, cv2.IMREAD_COLOR)[:, :, ::-1].copy()
92
+ image_h, image_w = image.shape[:2]
93
+
94
+ # detect faces using dlib detector
95
+ detected = face_detector(image, 1)
96
+ faces = np.empty((len(detected), input_size, input_size, 3))
97
+
98
+ if len(detected) > 0:
99
+ for i, d in enumerate(detected):
100
+ x1, y1, x2, y2, w, h = d.left(), d.top(
101
+ ), d.right() + 1, d.bottom() + 1, d.width(), d.height()
102
+ xw1 = max(int(x1 - margin * w), 0)
103
+ yw1 = max(int(y1 - margin * h), 0)
104
+ xw2 = min(int(x2 + margin * w), image_w - 1)
105
+ yw2 = min(int(y2 + margin * h), image_h - 1)
106
+ faces[i] = cv2.resize(image[yw1:yw2 + 1, xw1:xw2 + 1],
107
+ (input_size, input_size))
108
+
109
+ cv2.rectangle(image, (x1, y1), (x2, y2), (255, 255, 255), 2)
110
+ cv2.rectangle(image, (xw1, yw1), (xw2, yw2), (255, 0, 0), 2)
111
+
112
+ # predict ages
113
+ inputs = torch.from_numpy(
114
+ np.transpose(faces.astype(np.float32), (0, 3, 1, 2))).to(device)
115
+ outputs = F.softmax(model(inputs), dim=-1).cpu().numpy()
116
+ ages = np.arange(0, 101)
117
+ predicted_ages = (outputs * ages).sum(axis=-1)
118
+
119
+ # draw results
120
+ for age, d in zip(predicted_ages, detected):
121
+ draw_label(image, (d.left(), d.top()), f'{int(age)}')
122
+ return image
123
+
124
+
125
+ def main():
126
+ gr.close_all()
127
+
128
+ args = parse_args()
129
+ device = torch.device(args.device)
130
+
131
+ model = load_model(device)
132
+ face_detector = dlib.get_frontal_face_detector()
133
+
134
+ func = functools.partial(predict,
135
+ model=model,
136
+ face_detector=face_detector,
137
+ device=device)
138
+ func = functools.update_wrapper(func, predict)
139
+
140
+ image_dir = pathlib.Path('sample_images')
141
+ examples = [path.as_posix() for path in sorted(image_dir.glob('*.jpg'))]
142
+
143
+ repo_url = 'https://github.com/yu4u/age-estimation-pytorch'
144
+ title = 'yu4u/age-estimation-pytorch'
145
+ description = f'A demo for {repo_url}'
146
+ article = None
147
+
148
+ gr.Interface(
149
+ func,
150
+ gr.inputs.Image(type='file', label='Input'),
151
+ gr.outputs.Image(label='Output'),
152
+ theme=args.theme,
153
+ title=title,
154
+ description=description,
155
+ article=article,
156
+ examples=examples,
157
+ allow_screenshot=args.allow_screenshot,
158
+ allow_flagging=args.allow_flagging,
159
+ live=args.live,
160
+ ).launch(
161
+ enable_queue=args.enable_queue,
162
+ server_port=args.port,
163
+ share=args.share,
164
+ )
165
+
166
+
167
+ if __name__ == '__main__':
168
+ main()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ cmake
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ dlib>=19.23
2
+ numpy>=1.22.2
3
+ opencv-python-headless>=4.5.5.62
4
+ pretrainedmodels>=0.7.4
5
+ torch>=1.10.2
6
+ torchvision>=0.11.3
sample_images/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ These images are from the following public domain:
2
+ - https://www.pexels.com/photo/2-women-sitting-on-rock-during-daytime-214576/
3
+ - https://www.pexels.com/photo/boy-in-yellow-crew-neck-t-shirt-and-gray-bottoms-929436/
4
+ - https://www.pexels.com/photo/group-of-people-standing-beside-body-of-water-2672979/
5
+ - https://www.pexels.com/photo/man-sitting-on-chair-beside-table-834863/
6
+ - https://www.pexels.com/photo/man-wearing-white-dress-shirt-and-black-blazer-2182970/
7
+ - https://www.pexels.com/photo/shallow-focus-photography-of-woman-in-white-shirt-and-blue-denim-shorts-on-street-near-green-trees-937416/
8
+ - https://www.pexels.com/photo/woman-in-collared-shirt-774909/
9
+
sample_images/pexels-alexey-makhinko-929436.jpg ADDED
sample_images/pexels-andrea-piacquadio-2672979.jpg ADDED
sample_images/pexels-andrea-piacquadio-774909.jpg ADDED
sample_images/pexels-andrea-piacquadio-834863.jpg ADDED
sample_images/pexels-linkedin-sales-navigator-2182970.jpg ADDED
sample_images/pexels-mentatdgt-937416.jpg ADDED
sample_images/pexels-sebastian-voortman-214576.jpg ADDED