#!/usr/bin/env python from __future__ import annotations import argparse import functools import os import pathlib import sys from typing import Callable if os.environ.get('SYSTEM') == 'spaces': os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py") os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py") sys.path.insert(0, 'DualStyleGAN') import dlib import gradio as gr import huggingface_hub import numpy as np import PIL.Image import torch import torch.nn as nn import torchvision.transforms as T from model.dualstylegan import DualStyleGAN from model.encoder.align_all_parallel import align_face from model.encoder.psp import pSp from util import load_image, visualize ORIGINAL_REPO_URL = 'https://github.com/williamyang1991/DualStyleGAN' TITLE = 'williamyang1991/DualStyleGAN' DESCRIPTION = f"""A demo for {ORIGINAL_REPO_URL} You can select style images for cartoon from the table below. The style image index should be in the following range: (cartoon: 0-316, caricature: 0-198, anime: 0-173, arcane: 0-99, comic: 0-100, pixar: 0-121, slamdunk: 0-119) """ ARTICLE = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)' TOKEN = os.environ['TOKEN'] MODEL_REPO = 'hysts/DualStyleGAN' def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cpu') parser.add_argument('--theme', type=str) parser.add_argument('--live', action='store_true') parser.add_argument('--share', action='store_true') parser.add_argument('--port', type=int) parser.add_argument('--disable-queue', dest='enable_queue', action='store_false') parser.add_argument('--allow-flagging', type=str, default='never') parser.add_argument('--allow-screenshot', action='store_true') return parser.parse_args() def load_encoder(device: torch.device) -> nn.Module: ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'models/encoder.pt', use_auth_token=TOKEN) ckpt = torch.load(ckpt_path, map_location='cpu') opts = ckpt['opts'] opts['device'] = device.type opts['checkpoint_path'] = ckpt_path opts = argparse.Namespace(**opts) model = pSp(opts) model.to(device) model.eval() return model def load_generator(style_type: str, device: torch.device) -> nn.Module: model = DualStyleGAN(1024, 512, 8, 2, res_index=6) ckpt_path = huggingface_hub.hf_hub_download( MODEL_REPO, f'models/{style_type}/generator.pt', use_auth_token=TOKEN) ckpt = torch.load(ckpt_path, map_location='cpu') model.load_state_dict(ckpt['g_ema']) model.to(device) model.eval() return model def load_exstylecode(style_type: str) -> dict[str, np.ndarray]: if style_type in ['cartoon', 'caricature', 'anime']: filename = 'refined_exstyle_code.npy' else: filename = 'exstyle_code.npy' path = huggingface_hub.hf_hub_download(MODEL_REPO, f'models/{style_type}/{filename}', use_auth_token=TOKEN) exstyles = np.load(path, allow_pickle=True).item() return exstyles def create_transform() -> Callable: transform = T.Compose([ T.Resize(256), T.CenterCrop(256), T.ToTensor(), T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), ]) return transform def create_dlib_landmark_model(): path = huggingface_hub.hf_hub_download( 'hysts/dlib_face_landmark_model', 'shape_predictor_68_face_landmarks.dat', use_auth_token=TOKEN) return dlib.shape_predictor(path) def denormalize(tensor: torch.Tensor) -> torch.Tensor: return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8) def postprocess(tensor: torch.Tensor) -> PIL.Image.Image: tensor = denormalize(tensor) image = tensor.cpu().numpy().transpose(1, 2, 0) return PIL.Image.fromarray(image) @torch.inference_mode() def run( image, style_type: str, style_id: float, dlib_landmark_model, encoder: nn.Module, generator_dict: dict[str, nn.Module], exstyle_dict: dict[str, dict[str, np.ndarray]], transform: Callable, device: torch.device, ) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image, PIL.Image]: generator = generator_dict[style_type] exstyles = exstyle_dict[style_type] style_id = int(style_id) style_id = min(max(0, style_id), len(exstyles) - 1) stylename = list(exstyles.keys())[style_id] image = align_face(filepath=image.name, predictor=dlib_landmark_model) input_data = transform(image).unsqueeze(0).to(device) img_rec, instyle = encoder(input_data, randomize_noise=False, return_latents=True, z_plus_latent=True, return_z_plus_latent=True, resize=False) img_rec = torch.clamp(img_rec.detach(), -1, 1) latent = torch.tensor(exstyles[stylename]).repeat(2, 1, 1).to(device) # latent[0] for both color and structrue transfer and latent[1] for only structrue transfer latent[1, 7:18] = instyle[0, 7:18] exstyle = generator.generator.style( latent.reshape(latent.shape[0] * latent.shape[1], latent.shape[2])).reshape(latent.shape) img_gen, _ = generator([instyle.repeat(2, 1, 1)], exstyle, z_plus_latent=True, truncation=0.7, truncation_latent=0, use_res=True, interp_weights=[0.6] * 7 + [1] * 11) img_gen = torch.clamp(img_gen.detach(), -1, 1) # deactivate color-related layers by setting w_c = 0 img_gen2, _ = generator([instyle], exstyle[0:1], z_plus_latent=True, truncation=0.7, truncation_latent=0, use_res=True, interp_weights=[0.6] * 7 + [0] * 11) img_gen2 = torch.clamp(img_gen2.detach(), -1, 1) img_rec = postprocess(img_rec[0]) img_gen0 = postprocess(img_gen[0]) img_gen1 = postprocess(img_gen[1]) img_gen2 = postprocess(img_gen2[0]) return image, img_rec, img_gen0, img_gen1, img_gen2 def main(): gr.close_all() args = parse_args() device = torch.device(args.device) style_types = [ 'cartoon', 'caricature', 'anime', 'arcane', 'comic', 'pixar', 'slamdunk', ] generator_dict = { style_type: load_generator(style_type, device) for style_type in style_types } exstyle_dict = { style_type: load_exstylecode(style_type) for style_type in style_types } dlib_landmark_model = create_dlib_landmark_model() encoder = load_encoder(device) transform = create_transform() func = functools.partial(run, dlib_landmark_model=dlib_landmark_model, encoder=encoder, generator_dict=generator_dict, exstyle_dict=exstyle_dict, transform=transform, device=device) func = functools.update_wrapper(func, run) image_paths = sorted(pathlib.Path('images').glob('*')) examples = [[path.as_posix(), 'cartoon', 26] for path in image_paths] gr.Interface( func, [ gr.inputs.Image(type='file', label='Input Image'), gr.inputs.Radio( style_types, type='value', default='cartoon', label='Style Type', ), gr.inputs.Number(default=26, label='Style Image Index'), ], [ gr.outputs.Image(type='pil', label='Aligned Face'), gr.outputs.Image(type='pil', label='Reconstructed'), gr.outputs.Image(type='pil', label='Result 1'), gr.outputs.Image(type='pil', label='Result 2'), gr.outputs.Image(type='pil', label='Result 3'), ], examples=examples, theme=args.theme, title=TITLE, description=DESCRIPTION, article=ARTICLE, allow_screenshot=args.allow_screenshot, allow_flagging=args.allow_flagging, live=args.live, ).launch( enable_queue=args.enable_queue, server_port=args.port, share=args.share, ) if __name__ == '__main__': main()