#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue, 2022-12-16 16:17:14

import numpy as np
import gradio as gr
from pathlib import Path
from omegaconf import OmegaConf

from utils import util_image

from sampler import DifIRSampler
from basicsr.utils.download_util import load_file_from_url

# load and adjust the sampling configuration
cfg_path = 'configs/sample/iddpm_ffhq512_swinir.yaml'
configs = OmegaConf.load(cfg_path)
configs.aligned = False
configs.diffusion.params.timestep_respacing = '250'
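# ('250' respaces the model's full diffusion schedule down to 250 sampling steps
# to speed up inference)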

# build the diffusion sampler once at startup; it is shared by all requests
sampler_dist = DifIRSampler(configs)

def predict(im_path, background_enhance, face_upsample, upscale, started_timesteps):
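    """Restore a single degraded face image and save the result to disk.

    Args:
        im_path: path of the input image on disk.
        background_enhance: if True, also restore the image background.
        face_upsample: if True, upsample the restored face region.
        upscale: output rescaling factor (clamped to at most 4 below).
        started_timesteps: diffusion timestep at which restoration starts,
            i.e. the realism-fidelity trade-off exposed by the slider.

    Returns:
        The restored image (h x w x c, uint8, RGB) and the path of the saved file.
    """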
    assert isinstance(im_path, str)
    print(f'Processing image: {im_path}...')

    configs.background_enhance = background_enhance
    configs.face_upsample = face_upsample
    started_timesteps = int(started_timesteps)
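    # the starting timestep must lie within the respaced sampling schedule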
    assert started_timesteps < int(configs.diffusion.params.timestep_respacing)

    # download the pretrained diffusion and restoration checkpoints if missing
    if not Path(configs.model.ckpt_path).exists():
        load_file_from_url(
            url="https://github.com/zsyOAOA/DifFace/releases/download/V1.0/iddpm_ffhq512_ema500000.pth",
            model_dir=str(Path(configs.model.ckpt_path).parent),
            progress=True,
            file_name=Path(configs.model.ckpt_path).name,
            )
    if not Path(configs.model_ir.ckpt_path).exists():
        load_file_from_url(
            url="https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth",
            model_dir=str(Path(configs.model_ir.ckpt_path).parent),
            progress=True,
            file_name=Path(configs.model_ir.ckpt_path).name,
            )

    # load the low-quality input image (h x w x c, uint8, BGR)
    im_lq = util_image.imread(im_path, chn='bgr', dtype='uint8')
    if upscale > 4:
        upscale = 4  # avoid running out of memory with too large an upscale factor
    if upscale > 2 and min(im_lq.shape[:2]) > 1280:
        upscale = 2  # avoid running out of memory with too large an input resolution
    configs.detection.upscale = int(upscale)

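    # two restoration paths: with background enhancement the sampler returns a
    # uint8 BGR numpy array; without it, a [0, 1] RGB tensor that is converted
    # back to a uint8 array below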
    if background_enhance:
        image_restored, face_restored, face_cropped = sampler_dist.sample_func_bfr_unaligned(
                y0=im_lq,
                start_timesteps=started_timesteps,
                need_restoration=True,
                draw_box=False,
                )   # h x w x c, numpy array, [0, 255], uint8, BGR
        image_restored = util_image.bgr2rgb(image_restored)
    else:
        image_restored = sampler_dist.sample_func_ir_aligned(
                y0=im_lq,
                start_timesteps=started_timesteps,
                need_restoration=True,
                )[0]  # b x c x h x w, [0, 1], torch tensor, RGB
        image_restored = util_image.tensor2img(
                image_restored.cpu(),
                rgb2bgr=False,
                out_type=np.uint8,
                min_max=(0, 1),
                )     # h x w x c, [0, 255], uint8, RGB, numpy array

    restored_image_dir = Path('restored_output')
    restored_image_dir.mkdir(exist_ok=True)
    # save the whole image
    save_path = restored_image_dir / Path(im_path).name
    util_image.imwrite(image_restored, save_path, chn='rgb', dtype_in='uint8')

    return image_restored, str(save_path)

title = "DifFace: Blind Face Restoration with Diffused Error Contraction"
description = r"""
<b>Official Gradio demo</b> for <a href='https://github.com/zsyOAOA/DifFace' target='_blank'><b>DifFace: Blind Face Restoration with Diffused Error Contraction</b></a>.<br>
🔥 DifFace is a robust face restoration algorithm for old or corrupted photos.<br>
"""
article = r"""
If DifFace is helpful for your work, please ⭐ the <a href='https://github.com/zsyOAOA/DifFace' target='_blank'>GitHub Repo</a>. Thanks!
[![GitHub Stars](https://img.shields.io/github/stars/zsyOAOA/DifFace?affiliations=OWNER&color=green&style=social)](https://github.com/zsyOAOA/DifFace)

---

πŸ“ **Citation**

If our work is useful for your research, please consider citing:
```bibtex
@article{yue2022difface,
  title={DifFace: Blind Face Restoration with Diffused Error Contraction},
  author={Yue, Zongsheng and Loy, Chen Change},
  journal={arXiv preprint arXiv:2212.06512},
  year={2022}
}
```

📋 **License**

This project is licensed under <a rel="license" href="https://github.com/zsyOAOA/DifFace/blob/master/LICENSE">S-Lab License 1.0</a>.
Redistribution and use for non-commercial purposes should follow this license.

📧 **Contact**

If you have any questions, please feel free to contact me at <b>zsyzam@gmail.com</b>.
![visitors](https://visitor-badge.laobi.icu/badge?page_id=zsyOAOA/DifFace)
"""

demo = gr.Interface(
    predict,
    inputs=[
        gr.Image(type="filepath", label="Input"),
        gr.Checkbox(value=True, label="Background_Enhance"),
        gr.Checkbox(value=True, label="Face_Upsample"),
        gr.Number(value=2, label="Rescaling_Factor (up to 4)"),
        gr.Slider(1, 160, value=100, step=10, label='Realism-Fidelity Trade-off')
    ],
    outputs=[
        gr.Image(type="numpy", label="Output"),
        gr.File(label="Download the output")
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ['./testdata/whole_imgs/00.jpg', True, True, 2, 100],
        ['./testdata/whole_imgs/01.jpg', True, True, 2, 100],
        ['./testdata/whole_imgs/04.jpg', True, True, 2, 100],
        ['./testdata/whole_imgs/05.jpg', True, True, 2, 100],
      ]
    )

demo.queue(concurrency_count=4)
demo.launch()
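
# A minimal local run, assuming the dependencies from the DifFace repo and a
# Gradio 3.x environment (this script uses the 3.x `concurrency_count` queue
# argument):
#   python app.py
# Gradio then prints a local URL to open in a browser. On Gradio 4+, the queue
# call becomes `demo.queue(default_concurrency_limit=4)`.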