#!/usr/bin/env python

from __future__ import annotations

import functools
import os
import pathlib
import sys
import tarfile
import urllib.request
from typing import Callable

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T

sys.path.insert(0, 'anime_face_landmark_detection')

from CFA import CFA

DESCRIPTION = '# [kanosawa/anime_face_landmark_detection](https://github.com/kanosawa/anime_face_landmark_detection)'

NUM_LANDMARK = 24
CROP_SIZE = 128


def load_sample_image_paths() -> list[pathlib.Path]:
    image_dir = pathlib.Path('images')
    if not image_dir.exists():
        # Download and extract the sample images on first run.
        dataset_repo = 'hysts/sample-images-TADNE'
        path = huggingface_hub.hf_hub_download(dataset_repo,
                                               'images.tar.gz',
                                               repo_type='dataset')
        with tarfile.open(path) as f:
            f.extractall()
    return sorted(image_dir.glob('*'))


def load_face_detector() -> cv2.CascadeClassifier:
    url = 'https://raw.githubusercontent.com/nagadomi/lbpcascade_animeface/master/lbpcascade_animeface.xml'
    path = pathlib.Path('lbpcascade_animeface.xml')
    if not path.exists():
        urllib.request.urlretrieve(url, path.as_posix())
    return cv2.CascadeClassifier(path.as_posix())


def load_landmark_detector(device: torch.device) -> torch.nn.Module:
    path = huggingface_hub.hf_hub_download(
        'public-data/anime_face_landmark_detection',
        'checkpoint_landmark_191116.pth')
    model = CFA(output_channel_num=NUM_LANDMARK + 1, checkpoint_name=path)
    model.to(device)
    model.eval()
    return model


@torch.inference_mode()
def detect(image_path: str, face_detector: cv2.CascadeClassifier,
           device: torch.device, transform: Callable,
           landmark_detector: torch.nn.Module) -> np.ndarray:
    image = cv2.imread(image_path)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    preds = face_detector.detectMultiScale(gray,
                                           scaleFactor=1.1,
                                           minNeighbors=5,
                                           minSize=(24, 24))

    image_h, image_w = image.shape[:2]
    pil_image = PIL.Image.fromarray(image[:, :, ::-1].copy())

    res = image.copy()
    for x_orig, y_orig, w_orig, h_orig in preds:
        # Expand the detected face box: w/8 on each side, h/4 above.
        x0 = round(max(x_orig - w_orig / 8, 0))
        x1 = round(min(x_orig + w_orig * 9 / 8, image_w))
        y0 = round(max(y_orig - h_orig / 4, 0))
        y1 = y_orig + h_orig
        w = x1 - x0
        h = y1 - y0

        temp = pil_image.crop((x0, y0, x1, y1))
        temp = temp.resize((CROP_SIZE, CROP_SIZE), PIL.Image.BICUBIC)
        data = transform(temp)
        data = data.to(device).unsqueeze(0)

        heatmaps = landmark_detector(data)
        # Keep the last element of the model output and drop the batch dim.
        heatmaps = heatmaps[-1].cpu().numpy()[0]

        cv2.rectangle(res, (x0, y0), (x1, y1), (0, 255, 0), 2)

        for i in range(NUM_LANDMARK):
            heatmap = cv2.resize(heatmaps[i], (CROP_SIZE, CROP_SIZE),
                                 interpolation=cv2.INTER_CUBIC)
            # Take the argmax of each heatmap as the landmark position, then
            # map it from crop coordinates back to the original image.
            pty, ptx = np.unravel_index(np.argmax(heatmap), heatmap.shape)
            pt_crop = np.round(np.array([ptx * w, pty * h]) /
                               CROP_SIZE).astype(int)
            pt = np.array([x0, y0]) + pt_crop
            cv2.circle(res, tuple(pt), 2, (0, 0, 255), cv2.FILLED)

    return res[:, :, ::-1]  # BGR -> RGB for display


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

image_paths = load_sample_image_paths()
examples = [[path.as_posix()] for path in image_paths]

face_detector = load_face_detector()
landmark_detector = load_landmark_detector(device)
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

fn = functools.partial(detect,
                       face_detector=face_detector,
                       device=device,
                       transform=transform,
                       landmark_detector=landmark_detector)

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label='Input', type='filepath')
            run_button = gr.Button('Run')
        with gr.Column():
            result = gr.Image(label='Result')
    gr.Examples(examples=examples,
                inputs=image,
                outputs=result,
                fn=fn,
                cache_examples=os.getenv('CACHE_EXAMPLES') == '1')

    run_button.click(fn=fn, inputs=image, outputs=result, api_name='predict')

demo.queue(max_size=15).launch()
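
# A minimal sketch of invoking the detector directly, bypassing the UI.
# `sample.jpg` is a hypothetical path; `fn` is the partial defined above, and
# `detect` returns an RGB uint8 array with boxes and landmarks drawn on it:
#
#   out = fn('sample.jpg')
#   PIL.Image.fromarray(out).save('result.png')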