|
import pathlib |
|
|
|
import torch |
|
|
|
from .detector import LandmarkDetector |
|
|
|
|
|
def get_config_path(model_name: str) -> pathlib.Path:
    """Return the path of the bundled config file for ``model_name``.

    Face-detector configs ('faster-rcnn', 'yolov3') are located under
    ``configs/mmdet`` and the landmark config ('hrnetv2') under
    ``configs/mmpose``, both relative to the directory containing this
    module.

    Args:
        model_name: One of 'faster-rcnn', 'yolov3' or 'hrnetv2'.

    Returns:
        ``<module dir>/configs/<mmdet|mmpose>/<model_name>.py``. The path is
        only constructed; whether the file exists is not checked here.

    Raises:
        AssertionError: If ``model_name`` is not a supported model.
    """
    detector_names = ('faster-rcnn', 'yolov3')
    landmark_names = ('hrnetv2',)
    if model_name not in detector_names + landmark_names:
        # Raised explicitly instead of using an `assert` statement so the
        # check is not stripped under `python -O`. AssertionError is kept so
        # the exception type seen by existing callers does not change.
        raise AssertionError(
            f'Unsupported model name {model_name!r}; expected one of '
            f'{detector_names + landmark_names}')

    package_path = pathlib.Path(__file__).parent.resolve()
    framework = 'mmdet' if model_name in detector_names else 'mmpose'
    return package_path / 'configs' / framework / f'{model_name}.py'
|
|
|
|
|
def get_checkpoint_path(model_name: str) -> pathlib.Path:
    """Return the local checkpoint path for ``model_name``, downloading it if absent.

    Checkpoints are stored in ``<torch.hub.get_dir()>/checkpoints`` (created
    if missing). If the checkpoint file is not already there, it is fetched
    from the v0.0.1 release of ``hysts/anime-face-detector`` on GitHub.

    Args:
        model_name: One of 'faster-rcnn', 'yolov3' or 'hrnetv2'.

    Returns:
        Path of the checkpoint file inside the torch hub checkpoint directory.

    Raises:
        AssertionError: If ``model_name`` is not a supported model.
    """
    detector_names = ('faster-rcnn', 'yolov3')
    landmark_names = ('hrnetv2',)
    if model_name not in detector_names + landmark_names:
        # Raised explicitly instead of using an `assert` statement so an
        # unsupported name cannot reach the download step under `python -O`.
        # AssertionError is kept so the exception type stays the same.
        raise AssertionError(
            f'Unsupported model name {model_name!r}; expected one of '
            f'{detector_names + landmark_names}')

    # Detector and landmark checkpoints are published with different prefixes.
    if model_name in detector_names:
        file_name = f'mmdet_anime-face_{model_name}.pth'
    else:
        file_name = f'mmpose_anime-face_{model_name}.pth'

    model_dir = pathlib.Path(torch.hub.get_dir()) / 'checkpoints'
    model_dir.mkdir(exist_ok=True, parents=True)
    model_path = model_dir / file_name

    # Only hit the network when the checkpoint is not present locally.
    if not model_path.exists():
        url = f'https://github.com/hysts/anime-face-detector/releases/download/v0.0.1/{file_name}'
        torch.hub.download_url_to_file(url, model_path.as_posix())

    return model_path
|
|
|
|
|
def create_detector(face_detector_name: str = 'yolov3',
                    landmark_model_name: str = 'hrnetv2',
                    device: str = 'cuda:0',
                    flip_test: bool = True,
                    box_scale_factor: float = 1.1) -> LandmarkDetector:
    """Build a ``LandmarkDetector`` from the bundled configs and checkpoints.

    Resolves the config and checkpoint paths for both the face detector and
    the landmark model via ``get_config_path`` / ``get_checkpoint_path`` (the
    latter may download missing checkpoints) and passes them to
    ``LandmarkDetector``.

    Args:
        face_detector_name: 'yolov3' or 'faster-rcnn'.
        landmark_model_name: Currently only 'hrnetv2' is accepted.
        device: Forwarded to ``LandmarkDetector`` as ``device``.
        flip_test: Forwarded to ``LandmarkDetector`` as ``flip_test``.
        box_scale_factor: Forwarded to ``LandmarkDetector`` as
            ``box_scale_factor``.

    Returns:
        The constructed ``LandmarkDetector`` instance.

    Raises:
        AssertionError: If either model name is not supported.
    """
    print("loading model...")

    # Explicit raises instead of `assert` statements so the validation is
    # not stripped under `python -O`; AssertionError is kept so the exception
    # type seen by existing callers does not change.
    supported_detectors = ('yolov3', 'faster-rcnn')
    if face_detector_name not in supported_detectors:
        raise AssertionError(
            f'Unsupported face detector {face_detector_name!r}; expected '
            f'one of {supported_detectors}')
    supported_landmark_models = ('hrnetv2',)
    if landmark_model_name not in supported_landmark_models:
        raise AssertionError(
            f'Unsupported landmark model {landmark_model_name!r}; expected '
            f'one of {supported_landmark_models}')

    detector_config_path = get_config_path(face_detector_name)
    landmark_config_path = get_config_path(landmark_model_name)
    detector_checkpoint_path = get_checkpoint_path(face_detector_name)
    landmark_checkpoint_path = get_checkpoint_path(landmark_model_name)

    # LandmarkDetector takes the landmark config/checkpoint first, then the
    # detector config/checkpoint.
    return LandmarkDetector(landmark_config_path,
                            landmark_checkpoint_path,
                            detector_config_path,
                            detector_checkpoint_path,
                            device=device,
                            flip_test=flip_test,
                            box_scale_factor=box_scale_factor)
|
|