hysts HF staff commited on
Commit
5664eba
1 Parent(s): df372a2
Files changed (5) hide show
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -1
  4. app.py +23 -56
  5. sample_images/README.md +0 -1
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.12.0
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.991
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🦀
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.19.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -1,6 +1,5 @@
1
  #!/usr/bin/env python
2
 
3
- import argparse
4
  import functools
5
  import os
6
  import pathlib
@@ -15,29 +14,14 @@ import torch
15
  import torch.nn as nn
16
  import torch.nn.functional as F
17
 
18
- TITLE = 'yu4u/age-estimation-pytorch'
19
  DESCRIPTION = 'This is an unofficial demo for https://github.com/yu4u/age-estimation-pytorch.'
20
- ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.age-estimation-appa-real" alt="visitor badge"/></center>'
21
 
22
- TOKEN = os.environ['TOKEN']
23
  MODEL_REPO = 'hysts/yu4u-age-estimation-pytorch'
24
  MODEL_FILENAME = 'pretrained.pth'
25
 
26
 
27
- def parse_args() -> argparse.Namespace:
28
- parser = argparse.ArgumentParser()
29
- parser.add_argument('--device', type=str, default='cpu')
30
- parser.add_argument('--theme', type=str)
31
- parser.add_argument('--live', action='store_true')
32
- parser.add_argument('--share', action='store_true')
33
- parser.add_argument('--port', type=int)
34
- parser.add_argument('--disable-queue',
35
- dest='enable_queue',
36
- action='store_false')
37
- parser.add_argument('--allow-flagging', type=str, default='never')
38
- return parser.parse_args()
39
-
40
-
41
  def get_model(model_name='se_resnext50_32x4d',
42
  num_classes=101,
43
  pretrained='imagenet'):
@@ -52,7 +36,7 @@ def load_model(device):
52
  model = get_model(model_name='se_resnext50_32x4d', pretrained=None)
53
  path = huggingface_hub.hf_hub_download(MODEL_REPO,
54
  MODEL_FILENAME,
55
- use_auth_token=TOKEN)
56
  model.load_state_dict(torch.load(path))
57
  model = model.to(device)
58
  model.eval()
@@ -90,7 +74,7 @@ def draw_label(image,
90
 
91
  @torch.inference_mode()
92
  def predict(image, model, face_detector, device, margin=0.4, input_size=224):
93
- image = cv2.imread(image.name, cv2.IMREAD_COLOR)[:, :, ::-1].copy()
94
  image_h, image_w = image.shape[:2]
95
 
96
  # detect faces using dlib detector
@@ -124,39 +108,22 @@ def predict(image, model, face_detector, device, margin=0.4, input_size=224):
124
  return image
125
 
126
 
127
- def main():
128
- args = parse_args()
129
- device = torch.device(args.device)
130
-
131
- model = load_model(device)
132
- face_detector = dlib.get_frontal_face_detector()
133
-
134
- func = functools.partial(predict,
135
- model=model,
136
- face_detector=face_detector,
137
- device=device)
138
- func = functools.update_wrapper(func, predict)
139
-
140
- image_dir = pathlib.Path('sample_images')
141
- examples = [path.as_posix() for path in sorted(image_dir.glob('*.jpg'))]
142
-
143
- gr.Interface(
144
- func,
145
- gr.inputs.Image(type='file', label='Input'),
146
- gr.outputs.Image(label='Output'),
147
- examples=examples,
148
- title=TITLE,
149
- description=DESCRIPTION,
150
- article=ARTICLE,
151
- theme=args.theme,
152
- allow_flagging=args.allow_flagging,
153
- live=args.live,
154
- ).launch(
155
- enable_queue=args.enable_queue,
156
- server_port=args.port,
157
- share=args.share,
158
- )
159
-
160
-
161
- if __name__ == '__main__':
162
- main()
 
1
  #!/usr/bin/env python
2
 
 
3
  import functools
4
  import os
5
  import pathlib
 
14
  import torch.nn as nn
15
  import torch.nn.functional as F
16
 
17
+ TITLE = 'Age Estimation'
18
  DESCRIPTION = 'This is an unofficial demo for https://github.com/yu4u/age-estimation-pytorch.'
 
19
 
20
+ HF_TOKEN = os.getenv('HF_TOKEN')
21
  MODEL_REPO = 'hysts/yu4u-age-estimation-pytorch'
22
  MODEL_FILENAME = 'pretrained.pth'
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def get_model(model_name='se_resnext50_32x4d',
26
  num_classes=101,
27
  pretrained='imagenet'):
 
36
  model = get_model(model_name='se_resnext50_32x4d', pretrained=None)
37
  path = huggingface_hub.hf_hub_download(MODEL_REPO,
38
  MODEL_FILENAME,
39
+ use_auth_token=HF_TOKEN)
40
  model.load_state_dict(torch.load(path))
41
  model = model.to(device)
42
  model.eval()
 
74
 
75
  @torch.inference_mode()
76
  def predict(image, model, face_detector, device, margin=0.4, input_size=224):
77
+ image = cv2.imread(image, cv2.IMREAD_COLOR)[:, :, ::-1].copy()
78
  image_h, image_w = image.shape[:2]
79
 
80
  # detect faces using dlib detector
 
108
  return image
109
 
110
 
111
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
112
+ model = load_model(device)
113
+ face_detector = dlib.get_frontal_face_detector()
114
+ func = functools.partial(predict,
115
+ model=model,
116
+ face_detector=face_detector,
117
+ device=device)
118
+
119
+ image_dir = pathlib.Path('sample_images')
120
+ examples = [path.as_posix() for path in sorted(image_dir.glob('*.jpg'))]
121
+
122
+ gr.Interface(
123
+ fn=func,
124
+ inputs=gr.Image(label='Input', type='filepath'),
125
+ outputs=gr.Image(label='Output'),
126
+ examples=examples,
127
+ title=TITLE,
128
+ description=DESCRIPTION,
129
+ ).launch(show_api=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sample_images/README.md CHANGED
@@ -6,4 +6,3 @@ These images are from the following public domain:
6
  - https://www.pexels.com/photo/man-wearing-white-dress-shirt-and-black-blazer-2182970/
7
  - https://www.pexels.com/photo/shallow-focus-photography-of-woman-in-white-shirt-and-blue-denim-shorts-on-street-near-green-trees-937416/
8
  - https://www.pexels.com/photo/woman-in-collared-shirt-774909/
9
-
 
6
  - https://www.pexels.com/photo/man-wearing-white-dress-shirt-and-black-blazer-2182970/
7
  - https://www.pexels.com/photo/shallow-focus-photography-of-woman-in-white-shirt-and-blue-denim-shorts-on-street-near-green-trees-937416/
8
  - https://www.pexels.com/photo/woman-in-collared-shirt-774909/