hysts (HF staff) committed
Commit fb7a74a
Parent(s): 77f867c
Files changed (4):
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. README.md +1 -29
  4. app.py +33 -64
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.2.0
+  hooks:
+  - id: check-executables-have-shebangs
+  - id: check-json
+  - id: check-merge-conflict
+  - id: check-shebang-scripts-are-executable
+  - id: check-toml
+  - id: check-yaml
+  - id: double-quote-string-fixer
+  - id: end-of-file-fixer
+  - id: mixed-line-ending
+    args: ['--fix=lf']
+  - id: requirements-txt-fixer
+  - id: trailing-whitespace
+- repo: https://github.com/myint/docformatter
+  rev: v1.4
+  hooks:
+  - id: docformatter
+    args: ['--in-place']
+- repo: https://github.com/pycqa/isort
+  rev: 5.12.0
+  hooks:
+  - id: isort
+- repo: https://github.com/pre-commit/mirrors-mypy
+  rev: v0.991
+  hooks:
+  - id: mypy
+    args: ['--ignore-missing-imports']
+- repo: https://github.com/google/yapf
+  rev: v0.32.0
+  hooks:
+  - id: yapf
+    args: ['--parallel', '--in-place']
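
For context: once this config is in place, `pre-commit run --all-files` applies every hook above to the repo. A minimal sketch of what two of the fixers do to Python source (hypothetical snippet, not part of this commit):

# Before the hooks run (hypothetical example, with trailing spaces):
greeting = "hello"
# After `pre-commit run --all-files`:
# double-quote-string-fixer rewrites the quotes and
# trailing-whitespace strips the stray spaces at end of line.
greeting = 'hello'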
.style.yapf ADDED
@@ -0,0 +1,5 @@
+[style]
+based_on_style = pep8
+blank_line_before_nested_class_or_def = false
+spaces_before_comment = 2
+split_before_logical_operator = true
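
A quick sketch of how two of these settings shape yapf's output (the function below is hypothetical, not part of the commit): `split_before_logical_operator = true` breaks long conditions before `and`/`or` rather than after, and `spaces_before_comment = 2` pads inline comments with two spaces.

# Hypothetical snippet formatted under the .style.yapf above:
CROP_SIZE = 128  # two spaces before this inline comment


def is_valid_crop(width: int, height: int) -> bool:
    # the continuation line starts with the logical operator
    return (width >= CROP_SIZE
            and height >= CROP_SIZE)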
README.md CHANGED
@@ -4,35 +4,7 @@ emoji: 🐢
 colorFrom: yellow
 colorTo: indigo
 sdk: gradio
-sdk_version: 3.0.5
+sdk_version: 3.19.1
 app_file: app.py
 pinned: false
 ---
-
-# Configuration
-
-`title`: _string_
-Display title for the Space
-
-`emoji`: _string_
-Space emoji (emoji-only character allowed)
-
-`colorFrom`: _string_
-Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-
-`colorTo`: _string_
-Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-
-`sdk`: _string_
-Can be either `gradio`, `streamlit`, or `static`
-
-`sdk_version`: _string_
-Only applicable for `streamlit` SDK.
-See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
-
-`app_file`: _string_
-Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
-Path is relative to the root of the repository.
-
-`pinned`: _boolean_
-Whether the Space stays on top of your list.
app.py CHANGED
@@ -2,13 +2,12 @@
 
 from __future__ import annotations
 
-import argparse
 import functools
 import os
 import pathlib
 import sys
 import tarfile
-import urllib
+import urllib.request
 from typing import Callable
 
 import cv2
@@ -25,9 +24,8 @@ from CFA import CFA
 
 TITLE = 'kanosawa/anime_face_landmark_detection'
 DESCRIPTION = 'This is an unofficial demo for https://github.com/kanosawa/anime_face_landmark_detection.'
-ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.anime_face_landmark_detection" alt="visitor badge"/></center>'
 
-TOKEN = os.environ['TOKEN']
+HF_TOKEN = os.getenv('HF_TOKEN')
 MODEL_REPO = 'hysts/anime_face_landmark_detection'
 MODEL_FILENAME = 'checkpoint_landmark_191116.pth'
 
@@ -35,20 +33,6 @@ NUM_LANDMARK = 24
 CROP_SIZE = 128
 
 
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--device', type=str, default='cpu')
-    parser.add_argument('--theme', type=str)
-    parser.add_argument('--live', action='store_true')
-    parser.add_argument('--share', action='store_true')
-    parser.add_argument('--port', type=int)
-    parser.add_argument('--disable-queue',
-                        dest='enable_queue',
-                        action='store_false')
-    parser.add_argument('--allow-flagging', type=str, default='never')
-    return parser.parse_args()
-
-
 def load_sample_image_paths() -> list[pathlib.Path]:
     image_dir = pathlib.Path('images')
     if not image_dir.exists():
@@ -56,7 +40,7 @@ def load_sample_image_paths() -> list[pathlib.Path]:
         path = huggingface_hub.hf_hub_download(dataset_repo,
                                                'images.tar.gz',
                                                repo_type='dataset',
-                                               use_auth_token=TOKEN)
+                                               use_auth_token=HF_TOKEN)
     with tarfile.open(path) as f:
         f.extractall()
     return sorted(image_dir.glob('*'))
@@ -73,7 +57,7 @@ def load_face_detector() -> cv2.CascadeClassifier:
 def load_landmark_detector(device: torch.device) -> torch.nn.Module:
     path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                            MODEL_FILENAME,
-                                           use_auth_token=TOKEN)
+                                           use_auth_token=HF_TOKEN)
     model = CFA(output_channel_num=NUM_LANDMARK + 1, checkpoint_name=path)
     model.to(device)
     model.eval()
@@ -81,10 +65,10 @@ def load_landmark_detector(device: torch.device) -> torch.nn.Module:
 
 
 @torch.inference_mode()
-def detect(image, face_detector: cv2.CascadeClassifier, device: torch.device,
-           transform: Callable,
+def detect(image_path: str, face_detector: cv2.CascadeClassifier,
+           device: torch.device, transform: Callable,
            landmark_detector: torch.nn.Module) -> np.ndarray:
-    image = cv2.imread(image.name)
+    image = cv2.imread(image_path)
     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
     preds = face_detector.detectMultiScale(gray,
                                            scaleFactor=1.1,
@@ -126,44 +110,29 @@ def detect(image, face_detector: cv2.CascadeClassifier, device: torch.device,
     return res[:, :, ::-1]
 
 
-def main():
-    args = parse_args()
-    device = torch.device(args.device)
-
-    image_paths = load_sample_image_paths()
-    examples = [[path.as_posix()] for path in image_paths]
-
-    face_detector = load_face_detector()
-    landmark_detector = load_landmark_detector(device)
-    transform = T.Compose([
-        T.ToTensor(),
-        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-    ])
-
-    func = functools.partial(detect,
-                             face_detector=face_detector,
-                             device=device,
-                             transform=transform,
-                             landmark_detector=landmark_detector)
-    func = functools.update_wrapper(func, detect)
-
-    gr.Interface(
-        func,
-        gr.inputs.Image(type='file', label='Input'),
-        gr.outputs.Image(type='numpy', label='Output'),
-        examples=examples,
-        title=TITLE,
-        description=DESCRIPTION,
-        article=ARTICLE,
-        theme=args.theme,
-        allow_flagging=args.allow_flagging,
-        live=args.live,
-    ).launch(
-        enable_queue=args.enable_queue,
-        server_port=args.port,
-        share=args.share,
-    )
-
-
-if __name__ == '__main__':
-    main()
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+image_paths = load_sample_image_paths()
+examples = [[path.as_posix()] for path in image_paths]
+
+face_detector = load_face_detector()
+landmark_detector = load_landmark_detector(device)
+transform = T.Compose([
+    T.ToTensor(),
+    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+])
+
+func = functools.partial(detect,
+                         face_detector=face_detector,
+                         device=device,
+                         transform=transform,
+                         landmark_detector=landmark_detector)
+
+gr.Interface(
+    fn=func,
+    inputs=gr.Image(label='Input', type='filepath'),
+    outputs=gr.Image(label='Output', type='numpy'),
+    examples=examples,
+    title=TITLE,
+    description=DESCRIPTION,
+).queue().launch(show_api=False)
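
The detect() signature change follows how Gradio passes images in 3.x: the old gr.inputs.Image(type='file') handed the callback a tempfile wrapper (hence the old image.name), while gr.Image(type='filepath') passes a plain str path. A minimal standalone sketch of the new convention (hypothetical demo, not part of this commit):

import cv2
import gradio as gr


def image_size(image_path: str) -> str:
    # with type='filepath', Gradio hands the function a str path on disk
    image = cv2.imread(image_path)
    return f'{image.shape[1]}x{image.shape[0]}'


gr.Interface(fn=image_size,
             inputs=gr.Image(label='Input', type='filepath'),
             outputs=gr.Textbox(label='Size')).launch()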