#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2022-12-16 16:17:14

import os
import torch
import argparse
import numpy as np
import gradio as gr
from pathlib import Path
from einops import rearrange
from omegaconf import OmegaConf
from skimage import img_as_ubyte

from utils import util_opts
from utils import util_image
from utils import util_common

from sampler import DifIRSampler
from ResizeRight.resize_right import resize
from basicsr.utils.download_util import load_file_from_url

def predict(im_path, background_enhance, face_upsample, upscale, started_timesteps):
    cfg_path = 'configs/sample/iddpm_ffhq512_swinir.yaml'

    # setting configurations
    configs = OmegaConf.load(cfg_path)
    configs.aligned = False
    configs.background_enhance = background_enhance
    configs.face_upsample = face_upsample

    started_timesteps = int(started_timesteps)
    assert started_timesteps < int(configs.diffusion.params.timestep_respacing)

    # prepare the checkpoints: download them from the release page if missing
    if not Path(configs.model.ckpt_path).exists():
        load_file_from_url(
            url="https://github.com/zsyOAOA/DifFace/releases/download/V1.0/iddpm_ffhq512_ema500000.pth",
            model_dir=str(Path(configs.model.ckpt_path).parent),
            progress=True,
            file_name=Path(configs.model.ckpt_path).name,
        )
    if not Path(configs.model_ir.ckpt_path).exists():
        load_file_from_url(
            url="https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth",
            model_dir=str(Path(configs.model_ir.ckpt_path).parent),
            progress=True,
            file_name=Path(configs.model_ir.ckpt_path).name,
        )

    # load the low-quality input image
    im_lq = util_image.imread(im_path, chn='bgr', dtype='uint8')
    if upscale > 4:
        upscale = 4     # avoid exceeding memory due to too large an upscale factor
    if upscale > 2 and min(im_lq.shape[:2]) > 1280:
        upscale = 2     # avoid exceeding memory due to too large an image resolution
    configs.detection.upscale = int(upscale)

    # build the sampler for diffusion
    sampler_dist = DifIRSampler(configs)

    image_restored, face_restored, face_cropped = sampler_dist.sample_func_bfr_unaligned(
        y0=im_lq,
        start_timesteps=started_timesteps,
        need_restoration=True,
        draw_box=False,
    )

    restored_image_dir = Path('restored_output')
    if not restored_image_dir.exists():
        restored_image_dir.mkdir()

    # save the whole restored image
    save_path = restored_image_dir / Path(im_path).name
    util_image.imwrite(image_restored, save_path, chn='bgr', dtype_in='uint8')

    return image_restored, str(save_path)

# im_path = './testdata/whole_imgs/00.jpg'
# predict(im_path, True, True, 3, 100)

title = "DifFace: Blind Face Restoration with Diffused Error Contraction"

description = r"""
DifFace logo
Official Gradio demo for DifFace: Blind Face Restoration with Diffused Error Contraction.
🔥 DifFace is a robust face restoration algorithm for old or corrupted photos.
""" article = r""" If DifFace is helpful for your work, please help to ⭐ the Github Repo. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/zsyOAOA/DifFace?affiliations=OWNER&color=green&style=social)](https://github.com/zsyOAOA/DifFace) --- 📝 **Citation** If our work is useful for your research, please consider citing: ```bibtex @article{yue2022difface, title={DifFace: Blind Face Restoration with Diffused Error Contraction}, author={Yue, Zongsheng and Loy, Chen Change}, journal={arXiv preprint arXiv:2212.06512}, year={2022} } ``` 📋 **License** This project is licensed under S-Lab License 1.0. Redistribution and use for non-commercial purposes should follow this license. 📧 **Contact** If you have any questions, please feel free to contact me via zsyzam@gmail.com. ![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/DifFace) """ demo = gr.Interface( inference, inputs=[ gr.inputs.Image(type="filepath", label="Input"), gr.inputs.Checkbox(default=True, label="Background_Enhance"), gr.inputs.Checkbox(default=True, label="Face_Upsample"), gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"), gr.Slider(1, 200, value=100, step=10, label='Realism-Fidelity Trade-off') ], outputs=[ gr.outputs.Image(type="numpy", label="Output"), gr.outputs.File(label="Download the output") ], title=title, description=description, article=article, examples=[ ['./testdata/whole_imgs/00.jpg', True, True, 2, 100], ['./testdata/whole_imgs/01.jpg', True, True, 2, 100], ['./testdata/whole_imgs/04.jpg', True, True, 2, 100], ['./testdata/whole_imgs/Solvay_conference_1927.png', True, True, 2, 100], ] ) demo.queue(concurrency_count=4) demo.launch()