#!/usr/bin/env python

from __future__ import annotations

import functools
import os
import pathlib
import sys
import tarfile
import urllib.request
from typing import Callable

import cv2
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torchvision.transforms as T

# The upstream repo is expected to be checked out next to this script;
# make its modules (CFA) importable.
sys.path.insert(0, "anime_face_landmark_detection")

from CFA import CFA

DESCRIPTION = "# [kanosawa/anime_face_landmark_detection](https://github.com/kanosawa/anime_face_landmark_detection)"

NUM_LANDMARK = 24
CROP_SIZE = 128


def load_sample_image_paths() -> list[pathlib.Path]:
    image_dir = pathlib.Path("images")
    if not image_dir.exists():
        dataset_repo = "hysts/sample-images-TADNE"
        path = huggingface_hub.hf_hub_download(dataset_repo, "images.tar.gz", repo_type="dataset")
        with tarfile.open(path) as f:
            f.extractall()
    return sorted(image_dir.glob("*"))


def load_face_detector() -> cv2.CascadeClassifier:
    url = "https://raw.githubusercontent.com/nagadomi/lbpcascade_animeface/master/lbpcascade_animeface.xml"
    path = pathlib.Path("lbpcascade_animeface.xml")
    if not path.exists():
        urllib.request.urlretrieve(url, path.as_posix())
    return cv2.CascadeClassifier(path.as_posix())


def load_landmark_detector(device: torch.device) -> torch.nn.Module:
    path = huggingface_hub.hf_hub_download(
        "public-data/anime_face_landmark_detection", "checkpoint_landmark_191116.pth"
    )
    model = CFA(output_channel_num=NUM_LANDMARK + 1, checkpoint_name=path)
    model.to(device)
    model.eval()
    return model


@torch.inference_mode()
def detect(
    image_path: str,
    face_detector: cv2.CascadeClassifier,
    device: torch.device,
    transform: Callable,
    landmark_detector: torch.nn.Module,
) -> np.ndarray:
    image = cv2.imread(image_path)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    preds = face_detector.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(24, 24))

    image_h, image_w = image.shape[:2]
    # OpenCV loads images as BGR; PIL expects RGB.
    pil_image = PIL.Image.fromarray(image[:, :, ::-1].copy())

    res = image.copy()
    for x_orig, y_orig, w_orig, h_orig in preds:
        # Expand the detected box by w/8 on each side and h/4 upward so the
        # crop covers the full face region, clamped to the image bounds.
        x0 = round(max(x_orig - w_orig / 8, 0))
        x1 = round(min(x_orig + w_orig * 9 / 8, image_w))
        y0 = round(max(y_orig - h_orig / 4, 0))
        y1 = y_orig + h_orig
        w = x1 - x0
        h = y1 - y0

        # Crop the face and resize to the model's fixed input size.
        temp = pil_image.crop((x0, y0, x1, y1))
        temp = temp.resize((CROP_SIZE, CROP_SIZE), PIL.Image.BICUBIC)

        # Run the landmark detector; use the final stage's heatmaps (batch size 1).
        data = transform(temp)
        data = data.to(device).unsqueeze(0)
        heatmaps = landmark_detector(data)
        heatmaps = heatmaps[-1].cpu().numpy()[0]

        cv2.rectangle(res, (x0, y0), (x1, y1), (0, 255, 0), 2)
        for i in range(NUM_LANDMARK):
            # Take each heatmap's argmax and scale it from crop coordinates
            # back to original image coordinates.
            heatmap = cv2.resize(heatmaps[i], (CROP_SIZE, CROP_SIZE), interpolation=cv2.INTER_CUBIC)
            pty, ptx = np.unravel_index(np.argmax(heatmap), heatmap.shape)
            pt_crop = np.round(np.array([ptx * w, pty * h]) / CROP_SIZE).astype(int)
            pt = np.array([x0, y0]) + pt_crop
            cv2.circle(res, tuple(pt), 2, (0, 0, 255), cv2.FILLED)
    # BGR -> RGB for display in Gradio.
    return res[:, :, ::-1]


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

image_paths = load_sample_image_paths()
examples = [[path.as_posix()] for path in image_paths]

face_detector = load_face_detector()
landmark_detector = load_landmark_detector(device)
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

fn = functools.partial(
    detect,
    face_detector=face_detector,
    device=device,
    transform=transform,
    landmark_detector=landmark_detector,
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Input", type="filepath")
            run_button = gr.Button("Run")
        with gr.Column():
            result = gr.Image(label="Result")
    gr.Examples(
        examples=examples,
        inputs=image,
        outputs=result,
        fn=fn,
        cache_examples=os.getenv("CACHE_EXAMPLES") == "1",
    )
    run_button.click(
        fn=fn,
        inputs=image,
        outputs=result,
        api_name="predict",
    )

if __name__ == "__main__":
    demo.queue(max_size=15).launch()