#!/usr/bin/env python

from __future__ import annotations

import argparse
import functools
import os
import pathlib
import sys
import tarfile
import urllib.request
from typing import Callable

# Make the bundled CFA implementation importable.
sys.path.insert(0, 'anime_face_landmark_detection')

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T

from CFA import CFA

TOKEN = os.environ['TOKEN']

MODEL_REPO = 'hysts/anime_face_landmark_detection'
MODEL_FILENAME = 'checkpoint_landmark_191116.pth'

NUM_LANDMARK = 24
CROP_SIZE = 128


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args()


def load_sample_image_paths() -> list[pathlib.Path]:
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        dataset_repo = 'hysts/sample-images-TADNE'
        path = huggingface_hub.hf_hub_download(dataset_repo,
                                               'images.tar.gz',
                                               repo_type='dataset',
                                               use_auth_token=TOKEN)
        with tarfile.open(path) as f:
            f.extractall()
    return sorted(image_dir.glob('*'))


def load_face_detector() -> cv2.CascadeClassifier:
    # Download nagadomi's anime face cascade on first run.
    url = 'https://raw.githubusercontent.com/nagadomi/lbpcascade_animeface/master/lbpcascade_animeface.xml'
    path = pathlib.Path('lbpcascade_animeface.xml')
    if not path.exists():
        urllib.request.urlretrieve(url, path.as_posix())
    return cv2.CascadeClassifier(path.as_posix())


def load_landmark_detector(device: torch.device) -> torch.nn.Module:
    path = huggingface_hub.hf_hub_download(MODEL_REPO,
                                           MODEL_FILENAME,
                                           use_auth_token=TOKEN)
    model = CFA(output_channel_num=NUM_LANDMARK + 1, checkpoint_name=path)
    model.to(device)
    model.eval()
    return model


@torch.inference_mode()
def detect(image, face_detector: cv2.CascadeClassifier, device: torch.device,
           transform: Callable,
           landmark_detector: torch.nn.Module) -> np.ndarray:
    image = cv2.imread(image.name)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    preds = face_detector.detectMultiScale(gray,
                                           scaleFactor=1.1,
                                           minNeighbors=5,
                                           minSize=(24, 24))

    image_h, image_w = image.shape[:2]
    pil_image = PIL.Image.fromarray(image[:, :, ::-1].copy())

    res = image.copy()
    for x_orig, y_orig, w_orig, h_orig in preds:
        # Expand the detected face box slightly before cropping, clamping
        # to the image bounds.
        x0 = round(max(x_orig - w_orig / 8, 0))
        x1 = round(min(x_orig + w_orig * 9 / 8, image_w))
        y0 = round(max(y_orig - h_orig / 4, 0))
        y1 = y_orig + h_orig
        w = x1 - x0
        h = y1 - y0

        temp = pil_image.crop((x0, y0, x1, y1))
        temp = temp.resize((CROP_SIZE, CROP_SIZE), PIL.Image.BICUBIC)
        data = transform(temp)
        data = data.to(device).unsqueeze(0)

        heatmaps = landmark_detector(data)
        heatmaps = heatmaps[-1].cpu().numpy()[0]

        cv2.rectangle(res, (x0, y0), (x1, y1), (0, 255, 0), 2)

        for i in range(NUM_LANDMARK):
            heatmap = cv2.resize(heatmaps[i], (CROP_SIZE, CROP_SIZE),
                                 interpolation=cv2.INTER_CUBIC)
            # Each landmark is the argmax of its heatmap, mapped back from
            # crop coordinates to the original image.
            pty, ptx = np.unravel_index(np.argmax(heatmap), heatmap.shape)
            pt_crop = np.round(np.array([ptx * w, pty * h]) /
                               CROP_SIZE).astype(int)
            pt = np.array([x0, y0]) + pt_crop
            cv2.circle(res, tuple(pt), 2, (0, 0, 255), cv2.FILLED)

    res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
    return res


def main():
    gr.close_all()

    args = parse_args()
    device = torch.device(args.device)

    image_paths = load_sample_image_paths()
    examples = [[path.as_posix()] for path in image_paths]
    face_detector = load_face_detector()
    landmark_detector = load_landmark_detector(device)
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Bind the detectors and transform so gradio only has to pass the image.
    func = functools.partial(detect,
                             face_detector=face_detector,
                             device=device,
                             transform=transform,
                             landmark_detector=landmark_detector)
    func = functools.update_wrapper(func, detect)

    repo_url = 'https://github.com/kanosawa/anime_face_landmark_detection'
    title = 'kanosawa/anime_face_landmark_detection'
    description = f'A demo for {repo_url}'
    article = None

    gr.Interface(
        func,
        gr.inputs.Image(type='file', label='Input'),
        gr.outputs.Image(label='Output'),
        theme=args.theme,
        title=title,
        description=description,
        article=article,
        examples=examples,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    ).launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )


if __name__ == '__main__':
    main()
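
# Example invocation (a sketch, not part of the original script): the model
# checkpoint and sample images are fetched with `use_auth_token=TOKEN`, so a
# `TOKEN` environment variable holding a Hugging Face access token must be
# set. The filename `app.py` below is an assumption; substitute this
# script's actual name.
#
#   TOKEN=<your-hf-token> python app.py --device cuda --share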