edisonlee55 commited on
Commit
906e212
1 Parent(s): 8886be5
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ .vscode/
anime_face_detector/__init__.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pathlib
2
+
3
+ import torch
4
+
5
+ from .detector import LandmarkDetector
6
+
7
+
8
+ def get_config_path(model_name: str) -> pathlib.Path:
9
+ assert model_name in ['faster-rcnn', 'yolov3', 'hrnetv2']
10
+
11
+ package_path = pathlib.Path(__file__).parent.resolve()
12
+ if model_name in ['faster-rcnn', 'yolov3']:
13
+ config_dir = package_path / 'configs' / 'mmdet'
14
+ else:
15
+ config_dir = package_path / 'configs' / 'mmpose'
16
+ return config_dir / f'{model_name}.py'
17
+
18
+
19
+ def get_checkpoint_path(model_name: str) -> pathlib.Path:
20
+ assert model_name in ['faster-rcnn', 'yolov3', 'hrnetv2']
21
+ if model_name in ['faster-rcnn', 'yolov3']:
22
+ file_name = f'mmdet_anime-face_{model_name}.pth'
23
+ else:
24
+ file_name = f'mmpose_anime-face_{model_name}.pth'
25
+
26
+ model_dir = pathlib.Path(torch.hub.get_dir()) / 'checkpoints'
27
+ model_dir.mkdir(exist_ok=True, parents=True)
28
+ model_path = model_dir / file_name
29
+ if not model_path.exists():
30
+ url = f'https://github.com/hysts/anime-face-detector/releases/download/v0.0.1/{file_name}'
31
+ torch.hub.download_url_to_file(url, model_path.as_posix())
32
+
33
+ return model_path
34
+
35
+
36
+ def create_detector(face_detector_name: str = 'yolov3',
37
+ landmark_model_name='hrnetv2',
38
+ device: str = 'cuda:0',
39
+ flip_test: bool = True,
40
+ box_scale_factor: float = 1.1) -> LandmarkDetector:
41
+ assert face_detector_name in ['yolov3', 'faster-rcnn']
42
+ assert landmark_model_name in ['hrnetv2']
43
+ detector_config_path = get_config_path(face_detector_name)
44
+ landmark_config_path = get_config_path(landmark_model_name)
45
+ detector_checkpoint_path = get_checkpoint_path(face_detector_name)
46
+ landmark_checkpoint_path = get_checkpoint_path(landmark_model_name)
47
+ model = LandmarkDetector(landmark_config_path,
48
+ landmark_checkpoint_path,
49
+ detector_config_path,
50
+ detector_checkpoint_path,
51
+ device=device,
52
+ flip_test=flip_test,
53
+ box_scale_factor=box_scale_factor)
54
+ return model
anime_face_detector/configs/mmdet/faster-rcnn.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model = dict(type='FasterRCNN',
2
+ backbone=dict(type='ResNet',
3
+ depth=50,
4
+ num_stages=4,
5
+ out_indices=(0, 1, 2, 3),
6
+ frozen_stages=1,
7
+ norm_cfg=dict(type='BN', requires_grad=True),
8
+ norm_eval=True,
9
+ style='pytorch'),
10
+ neck=dict(type='FPN',
11
+ in_channels=[256, 512, 1024, 2048],
12
+ out_channels=256,
13
+ num_outs=5),
14
+ rpn_head=dict(type='RPNHead',
15
+ in_channels=256,
16
+ feat_channels=256,
17
+ anchor_generator=dict(type='AnchorGenerator',
18
+ scales=[8],
19
+ ratios=[0.5, 1.0, 2.0],
20
+ strides=[4, 8, 16, 32, 64]),
21
+ bbox_coder=dict(type='DeltaXYWHBBoxCoder',
22
+ target_means=[0.0, 0.0, 0.0, 0.0],
23
+ target_stds=[1.0, 1.0, 1.0, 1.0])),
24
+ roi_head=dict(
25
+ type='StandardRoIHead',
26
+ bbox_roi_extractor=dict(type='SingleRoIExtractor',
27
+ roi_layer=dict(type='RoIAlign',
28
+ output_size=7,
29
+ sampling_ratio=0),
30
+ out_channels=256,
31
+ featmap_strides=[4, 8, 16, 32]),
32
+ bbox_head=dict(type='Shared2FCBBoxHead',
33
+ in_channels=256,
34
+ fc_out_channels=1024,
35
+ roi_feat_size=7,
36
+ num_classes=1,
37
+ bbox_coder=dict(
38
+ type='DeltaXYWHBBoxCoder',
39
+ target_means=[0.0, 0.0, 0.0, 0.0],
40
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
41
+ reg_class_agnostic=False)),
42
+ test_cfg=dict(rpn=dict(nms_pre=1000,
43
+ max_per_img=1000,
44
+ nms=dict(type='nms', iou_threshold=0.7),
45
+ min_bbox_size=0),
46
+ rcnn=dict(score_thr=0.05,
47
+ nms=dict(type='nms', iou_threshold=0.5),
48
+ max_per_img=100)))
49
+ test_pipeline = [
50
+ dict(type='LoadImageFromFile'),
51
+ dict(type='MultiScaleFlipAug',
52
+ img_scale=(1333, 800),
53
+ flip=False,
54
+ transforms=[
55
+ dict(type='Resize', keep_ratio=True),
56
+ dict(type='RandomFlip'),
57
+ dict(type='Normalize',
58
+ mean=[123.675, 116.28, 103.53],
59
+ std=[58.395, 57.12, 57.375],
60
+ to_rgb=True),
61
+ dict(type='Pad', size_divisor=32),
62
+ dict(type='DefaultFormatBundle'),
63
+ dict(type='Collect', keys=['img'])
64
+ ])
65
+ ]
66
+ data = dict(test=dict(pipeline=test_pipeline))
anime_face_detector/configs/mmdet/yolov3.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model = dict(type='YOLOV3',
2
+ backbone=dict(type='Darknet', depth=53, out_indices=(3, 4, 5)),
3
+ neck=dict(type='YOLOV3Neck',
4
+ num_scales=3,
5
+ in_channels=[1024, 512, 256],
6
+ out_channels=[512, 256, 128]),
7
+ bbox_head=dict(type='YOLOV3Head',
8
+ num_classes=1,
9
+ in_channels=[512, 256, 128],
10
+ out_channels=[1024, 512, 256],
11
+ anchor_generator=dict(type='YOLOAnchorGenerator',
12
+ base_sizes=[[(116, 90),
13
+ (156, 198),
14
+ (373, 326)],
15
+ [(30, 61),
16
+ (62, 45),
17
+ (59, 119)],
18
+ [(10, 13),
19
+ (16, 30),
20
+ (33, 23)]],
21
+ strides=[32, 16, 8]),
22
+ bbox_coder=dict(type='YOLOBBoxCoder'),
23
+ featmap_strides=[32, 16, 8]),
24
+ test_cfg=dict(nms_pre=1000,
25
+ min_bbox_size=0,
26
+ score_thr=0.05,
27
+ conf_thr=0.005,
28
+ nms=dict(type='nms', iou_threshold=0.45),
29
+ max_per_img=100))
30
+ test_pipeline = [
31
+ dict(type='LoadImageFromFile'),
32
+ dict(type='MultiScaleFlipAug',
33
+ img_scale=(608, 608),
34
+ flip=False,
35
+ transforms=[
36
+ dict(type='Resize', keep_ratio=True),
37
+ dict(type='RandomFlip'),
38
+ dict(type='Normalize',
39
+ mean=[0, 0, 0],
40
+ std=[255.0, 255.0, 255.0],
41
+ to_rgb=True),
42
+ dict(type='Pad', size_divisor=32),
43
+ dict(type='DefaultFormatBundle'),
44
+ dict(type='Collect', keys=['img'])
45
+ ])
46
+ ]
47
+ data = dict(test=dict(pipeline=test_pipeline))
anime_face_detector/configs/mmpose/hrnetv2.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ channel_cfg = dict(num_output_channels=28,
2
+ dataset_joints=28,
3
+ dataset_channel=[
4
+ list(range(28)),
5
+ ],
6
+ inference_channel=list(range(28)))
7
+
8
+ model = dict(
9
+ type='TopDown',
10
+ backbone=dict(type='HRNet',
11
+ in_channels=3,
12
+ extra=dict(stage1=dict(num_modules=1,
13
+ num_branches=1,
14
+ block='BOTTLENECK',
15
+ num_blocks=(4, ),
16
+ num_channels=(64, )),
17
+ stage2=dict(num_modules=1,
18
+ num_branches=2,
19
+ block='BASIC',
20
+ num_blocks=(4, 4),
21
+ num_channels=(18, 36)),
22
+ stage3=dict(num_modules=4,
23
+ num_branches=3,
24
+ block='BASIC',
25
+ num_blocks=(4, 4, 4),
26
+ num_channels=(18, 36, 72)),
27
+ stage4=dict(num_modules=3,
28
+ num_branches=4,
29
+ block='BASIC',
30
+ num_blocks=(4, 4, 4, 4),
31
+ num_channels=(18, 36, 72, 144),
32
+ multiscale_output=True),
33
+ upsample=dict(mode='bilinear',
34
+ align_corners=False))),
35
+ keypoint_head=dict(type='TopdownHeatmapSimpleHead',
36
+ in_channels=[18, 36, 72, 144],
37
+ in_index=(0, 1, 2, 3),
38
+ input_transform='resize_concat',
39
+ out_channels=channel_cfg['num_output_channels'],
40
+ num_deconv_layers=0,
41
+ extra=dict(final_conv_kernel=1,
42
+ num_conv_layers=1,
43
+ num_conv_kernels=(1, )),
44
+ loss_keypoint=dict(type='JointsMSELoss',
45
+ use_target_weight=True)),
46
+ test_cfg=dict(flip_test=True,
47
+ post_process='unbiased',
48
+ shift_heatmap=True,
49
+ modulate_kernel=11))
50
+
51
+ data_cfg = dict(image_size=[256, 256],
52
+ heatmap_size=[64, 64],
53
+ num_output_channels=channel_cfg['num_output_channels'],
54
+ num_joints=channel_cfg['dataset_joints'],
55
+ dataset_channel=channel_cfg['dataset_channel'],
56
+ inference_channel=channel_cfg['inference_channel'])
57
+
58
+ test_pipeline = [
59
+ dict(type='LoadImageFromFile'),
60
+ dict(type='TopDownAffine'),
61
+ dict(type='ToTensor'),
62
+ dict(type='NormalizeTensor',
63
+ mean=[0.485, 0.456, 0.406],
64
+ std=[0.229, 0.224, 0.225]),
65
+ dict(type='Collect',
66
+ keys=['img'],
67
+ meta_keys=['image_file', 'center', 'scale', 'rotation',
68
+ 'flip_pairs']),
69
+ ]
70
+
71
+ dataset_info = dict(dataset_name='anime_face',
72
+ paper_info=dict(),
73
+ keypoint_info={
74
+ 0:
75
+ dict(name='kpt-0',
76
+ id=0,
77
+ color=[255, 255, 255],
78
+ type='',
79
+ swap='kpt-4'),
80
+ 1:
81
+ dict(name='kpt-1',
82
+ id=1,
83
+ color=[255, 255, 255],
84
+ type='',
85
+ swap='kpt-3'),
86
+ 2:
87
+ dict(name='kpt-2',
88
+ id=2,
89
+ color=[255, 255, 255],
90
+ type='',
91
+ swap=''),
92
+ 3:
93
+ dict(name='kpt-3',
94
+ id=3,
95
+ color=[255, 255, 255],
96
+ type='',
97
+ swap='kpt-1'),
98
+ 4:
99
+ dict(name='kpt-4',
100
+ id=4,
101
+ color=[255, 255, 255],
102
+ type='',
103
+ swap='kpt-0'),
104
+ 5:
105
+ dict(name='kpt-5',
106
+ id=5,
107
+ color=[255, 255, 255],
108
+ type='',
109
+ swap='kpt-10'),
110
+ 6:
111
+ dict(name='kpt-6',
112
+ id=6,
113
+ color=[255, 255, 255],
114
+ type='',
115
+ swap='kpt-9'),
116
+ 7:
117
+ dict(name='kpt-7',
118
+ id=7,
119
+ color=[255, 255, 255],
120
+ type='',
121
+ swap='kpt-8'),
122
+ 8:
123
+ dict(name='kpt-8',
124
+ id=8,
125
+ color=[255, 255, 255],
126
+ type='',
127
+ swap='kpt-7'),
128
+ 9:
129
+ dict(name='kpt-9',
130
+ id=9,
131
+ color=[255, 255, 255],
132
+ type='',
133
+ swap='kpt-6'),
134
+ 10:
135
+ dict(name='kpt-10',
136
+ id=10,
137
+ color=[255, 255, 255],
138
+ type='',
139
+ swap='kpt-5'),
140
+ 11:
141
+ dict(name='kpt-11',
142
+ id=11,
143
+ color=[255, 255, 255],
144
+ type='',
145
+ swap='kpt-19'),
146
+ 12:
147
+ dict(name='kpt-12',
148
+ id=12,
149
+ color=[255, 255, 255],
150
+ type='',
151
+ swap='kpt-18'),
152
+ 13:
153
+ dict(name='kpt-13',
154
+ id=13,
155
+ color=[255, 255, 255],
156
+ type='',
157
+ swap='kpt-17'),
158
+ 14:
159
+ dict(name='kpt-14',
160
+ id=14,
161
+ color=[255, 255, 255],
162
+ type='',
163
+ swap='kpt-22'),
164
+ 15:
165
+ dict(name='kpt-15',
166
+ id=15,
167
+ color=[255, 255, 255],
168
+ type='',
169
+ swap='kpt-21'),
170
+ 16:
171
+ dict(name='kpt-16',
172
+ id=16,
173
+ color=[255, 255, 255],
174
+ type='',
175
+ swap='kpt-20'),
176
+ 17:
177
+ dict(name='kpt-17',
178
+ id=17,
179
+ color=[255, 255, 255],
180
+ type='',
181
+ swap='kpt-13'),
182
+ 18:
183
+ dict(name='kpt-18',
184
+ id=18,
185
+ color=[255, 255, 255],
186
+ type='',
187
+ swap='kpt-12'),
188
+ 19:
189
+ dict(name='kpt-19',
190
+ id=19,
191
+ color=[255, 255, 255],
192
+ type='',
193
+ swap='kpt-11'),
194
+ 20:
195
+ dict(name='kpt-20',
196
+ id=20,
197
+ color=[255, 255, 255],
198
+ type='',
199
+ swap='kpt-16'),
200
+ 21:
201
+ dict(name='kpt-21',
202
+ id=21,
203
+ color=[255, 255, 255],
204
+ type='',
205
+ swap='kpt-15'),
206
+ 22:
207
+ dict(name='kpt-22',
208
+ id=22,
209
+ color=[255, 255, 255],
210
+ type='',
211
+ swap='kpt-14'),
212
+ 23:
213
+ dict(name='kpt-23',
214
+ id=23,
215
+ color=[255, 255, 255],
216
+ type='',
217
+ swap=''),
218
+ 24:
219
+ dict(name='kpt-24',
220
+ id=24,
221
+ color=[255, 255, 255],
222
+ type='',
223
+ swap='kpt-26'),
224
+ 25:
225
+ dict(name='kpt-25',
226
+ id=25,
227
+ color=[255, 255, 255],
228
+ type='',
229
+ swap=''),
230
+ 26:
231
+ dict(name='kpt-26',
232
+ id=26,
233
+ color=[255, 255, 255],
234
+ type='',
235
+ swap='kpt-24'),
236
+ 27:
237
+ dict(name='kpt-27',
238
+ id=27,
239
+ color=[255, 255, 255],
240
+ type='',
241
+ swap='')
242
+ },
243
+ skeleton_info={},
244
+ joint_weights=[1.] * 28,
245
+ sigmas=[])
246
+
247
+ data = dict(test=dict(type='',
248
+ data_cfg=data_cfg,
249
+ pipeline=test_pipeline,
250
+ dataset_info=dataset_info), )
anime_face_detector/detector.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+ import warnings
5
+ from typing import Optional, Union
6
+
7
+ import cv2
8
+ import mmcv
9
+ import numpy as np
10
+ import torch.nn as nn
11
+ from mmdet.apis import inference_detector, init_detector
12
+ from mmpose.apis import inference_top_down_pose_model, init_pose_model
13
+ from mmpose.datasets import DatasetInfo
14
+
15
+
16
+ class LandmarkDetector:
17
+ def __init__(
18
+ self,
19
+ landmark_detector_config_or_path: Union[mmcv.Config, str,
20
+ pathlib.Path],
21
+ landmark_detector_checkpoint_path: Union[str, pathlib.Path],
22
+ face_detector_config_or_path: Optional[Union[mmcv.Config, str,
23
+ pathlib.Path]] = None,
24
+ face_detector_checkpoint_path: Optional[Union[
25
+ str, pathlib.Path]] = None,
26
+ device: str = 'cuda:0',
27
+ flip_test: bool = True,
28
+ box_scale_factor: float = 1.1):
29
+ landmark_config = self._load_config(landmark_detector_config_or_path)
30
+ self.dataset_info = DatasetInfo(
31
+ landmark_config.dataset_info) # type: ignore
32
+ face_detector_config = self._load_config(face_detector_config_or_path)
33
+
34
+ self.landmark_detector = self._init_pose_model(
35
+ landmark_config, landmark_detector_checkpoint_path, device,
36
+ flip_test)
37
+ self.face_detector = self._init_face_detector(
38
+ face_detector_config, face_detector_checkpoint_path, device)
39
+
40
+ self.box_scale_factor = box_scale_factor
41
+
42
+ @staticmethod
43
+ def _load_config(
44
+ config_or_path: Optional[Union[mmcv.Config, str, pathlib.Path]]
45
+ ) -> Optional[mmcv.Config]:
46
+ if config_or_path is None or isinstance(config_or_path, mmcv.Config):
47
+ return config_or_path
48
+ return mmcv.Config.fromfile(config_or_path)
49
+
50
+ @staticmethod
51
+ def _init_pose_model(config: mmcv.Config,
52
+ checkpoint_path: Union[str, pathlib.Path],
53
+ device: str, flip_test: bool) -> nn.Module:
54
+ if isinstance(checkpoint_path, pathlib.Path):
55
+ checkpoint_path = checkpoint_path.as_posix()
56
+ model = init_pose_model(config, checkpoint_path, device=device)
57
+ model.cfg.model.test_cfg.flip_test = flip_test
58
+ return model
59
+
60
+ @staticmethod
61
+ def _init_face_detector(config: Optional[mmcv.Config],
62
+ checkpoint_path: Optional[Union[str,
63
+ pathlib.Path]],
64
+ device: str) -> Optional[nn.Module]:
65
+ if config is not None:
66
+ if isinstance(checkpoint_path, pathlib.Path):
67
+ checkpoint_path = checkpoint_path.as_posix()
68
+ model = init_detector(config, checkpoint_path, device=device)
69
+ else:
70
+ model = None
71
+ return model
72
+
73
+ def _detect_faces(self, image: np.ndarray) -> list[np.ndarray]:
74
+ # predicted boxes using mmdet model have the format of
75
+ # [x0, y0, x1, y1, score]
76
+ boxes = inference_detector(self.face_detector, image)[0]
77
+ # scale boxes by `self.box_scale_factor`
78
+ boxes = self._update_pred_box(boxes)
79
+ return boxes
80
+
81
+ def _update_pred_box(self, pred_boxes: np.ndarray) -> list[np.ndarray]:
82
+ boxes = []
83
+ for pred_box in pred_boxes:
84
+ box = pred_box[:4]
85
+ size = box[2:] - box[:2] + 1
86
+ new_size = size * self.box_scale_factor
87
+ center = (box[:2] + box[2:]) / 2
88
+ tl = center - new_size / 2
89
+ br = tl + new_size
90
+ pred_box[:4] = np.concatenate([tl, br])
91
+ boxes.append(pred_box)
92
+ return boxes
93
+
94
+ def _detect_landmarks(
95
+ self, image: np.ndarray,
96
+ boxes: list[dict[str, np.ndarray]]) -> list[dict[str, np.ndarray]]:
97
+ preds, _ = inference_top_down_pose_model(
98
+ self.landmark_detector,
99
+ image,
100
+ boxes,
101
+ format='xyxy',
102
+ dataset_info=self.dataset_info,
103
+ return_heatmap=False)
104
+ return preds
105
+
106
+ @staticmethod
107
+ def _load_image(
108
+ image_or_path: Union[np.ndarray, str, pathlib.Path]) -> np.ndarray:
109
+ if isinstance(image_or_path, np.ndarray):
110
+ image = image_or_path
111
+ elif isinstance(image_or_path, str):
112
+ image = cv2.imread(image_or_path)
113
+ elif isinstance(image_or_path, pathlib.Path):
114
+ image = cv2.imread(image_or_path.as_posix())
115
+ else:
116
+ raise ValueError
117
+ return image
118
+
119
+ def __call__(
120
+ self,
121
+ image_or_path: Union[np.ndarray, str, pathlib.Path],
122
+ boxes: Optional[list[np.ndarray]] = None
123
+ ) -> list[dict[str, np.ndarray]]:
124
+ """Detect face landmarks.
125
+
126
+ Args:
127
+ image_or_path: An image with BGR channel order or an image path.
128
+ boxes: A list of bounding boxes for faces. Each bounding box
129
+ should be of the form [x0, y0, x1, y1, [score]].
130
+
131
+ Returns: A list of detection results. Each detection result has
132
+ bounding box of the form [x0, y0, x1, y1, [score]], and landmarks
133
+ of the form [x, y, score].
134
+ """
135
+ image = self._load_image(image_or_path)
136
+ if boxes is None:
137
+ if self.face_detector is not None:
138
+ boxes = self._detect_faces(image)
139
+ else:
140
+ warnings.warn(
141
+ 'Neither the face detector nor the bounding box is '
142
+ 'specified. So the entire image is treated as the face '
143
+ 'region.')
144
+ h, w = image.shape[:2]
145
+ boxes = [np.array([0, 0, w - 1, h - 1, 1])]
146
+ box_list = [{'bbox': box} for box in boxes]
147
+ return self._detect_landmarks(image, box_list)
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+ import pathlib
4
+
5
+ import cv2
6
+ import gradio as gr
7
+ import numpy as np
8
+ import PIL.Image
9
+ import torch
10
+
11
+ import anime_face_detector
12
+
13
+
14
+ def detect(
15
+ img,
16
+ face_score_threshold: float,
17
+ landmark_score_threshold: float,
18
+ detector: anime_face_detector.LandmarkDetector,
19
+ ) -> PIL.Image.Image:
20
+ if not img:
21
+ return None
22
+
23
+ image = cv2.imread(img)
24
+ preds = detector(image)
25
+
26
+ res = image.copy()
27
+ for pred in preds:
28
+ box = pred["bbox"]
29
+ box, score = box[:4], box[4]
30
+ if score < face_score_threshold:
31
+ continue
32
+ box = np.round(box).astype(int)
33
+
34
+ lt = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
35
+
36
+ cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), lt)
37
+
38
+ pred_pts = pred["keypoints"]
39
+ for *pt, score in pred_pts:
40
+ if score < landmark_score_threshold:
41
+ color = (0, 255, 255)
42
+ else:
43
+ color = (0, 0, 255)
44
+ pt = np.round(pt).astype(int)
45
+ cv2.circle(res, tuple(pt), lt, color, cv2.FILLED)
46
+ res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
47
+
48
+ image_pil = PIL.Image.fromarray(res)
49
+ return image_pil
50
+
51
+
52
+ def main():
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument(
55
+ "--detector", type=str, default="yolov3", choices=["yolov3", "faster-rcnn"]
56
+ )
57
+ parser.add_argument(
58
+ "--device", type=str, default="cuda:0", choices=["cuda:0", "cpu"]
59
+ )
60
+ parser.add_argument("--face-score-threshold", type=float, default=0.5)
61
+ parser.add_argument("--landmark-score-threshold", type=float, default=0.3)
62
+ parser.add_argument("--score-slider-step", type=float, default=0.05)
63
+ parser.add_argument("--port", type=int)
64
+ parser.add_argument("--debug", action="store_true")
65
+ parser.add_argument("--share", action="store_true")
66
+ parser.add_argument("--live", action="store_true")
67
+ args = parser.parse_args()
68
+
69
+ sample_path = pathlib.Path("assets/input.jpg")
70
+ if not sample_path.exists():
71
+ torch.hub.download_url_to_file(
72
+ "https://raw.githubusercontent.com/edisonlee55/hysts-anime-face-detector/main/assets/input.jpg",
73
+ sample_path.as_posix(),
74
+ )
75
+
76
+ detector = anime_face_detector.create_detector(args.detector, device=args.device)
77
+ func = functools.partial(detect, detector=detector)
78
+
79
+ title = "edisonlee55/hysts-anime-face-detector"
80
+ description = "Demo for edisonlee55/hysts-anime-face-detector. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
81
+ article = "<a href='https://github.com/edisonlee55/hysts-anime-face-detector'>GitHub Repo</a>"
82
+
83
+ gr.Interface(
84
+ func,
85
+ [
86
+ gr.Image(type="filepath", label="Input"),
87
+ gr.Slider(
88
+ 0,
89
+ 1,
90
+ step=args.score_slider_step,
91
+ value=args.face_score_threshold,
92
+ label="Face Score Threshold",
93
+ ),
94
+ gr.Slider(
95
+ 0,
96
+ 1,
97
+ step=args.score_slider_step,
98
+ value=args.landmark_score_threshold,
99
+ label="Landmark Score Threshold",
100
+ ),
101
+ ],
102
+ gr.Image(type="pil", label="Output"),
103
+ title=title,
104
+ description=description,
105
+ article=article,
106
+ examples=[
107
+ [
108
+ sample_path.as_posix(),
109
+ args.face_score_threshold,
110
+ args.landmark_score_threshold,
111
+ ],
112
+ ],
113
+ live=args.live,
114
+ ).launch(
115
+ debug=args.debug, share=args.share, server_port=args.port, enable_queue=True
116
+ )
117
+
118
+
119
+ if __name__ == "__main__":
120
+ main()
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python=3.10
2
+ openmim==0.3.7
3
+ mmcv-full==1.6.2
4
+ mmdet==2.28.2
5
+ mmpose==0.29.0
6
+
7
+ numpy==1.24.3
8
+ scipy==1.10.1
9
+
10
+ opencv-python-headless==4.7.0.72
11
+
12
+ torch==2.0.1
13
+ torchvision==0.15.2
14
+
15
+ # for gradio
16
+ # gradio==3.32.0