File size: 5,535 Bytes
e99c825
 
 
 
 
a4a060a
e99c825
 
a4a060a
 
e99c825
 
 
421c177
 
e99c825
 
 
 
 
 
 
 
 
 
 
a4a060a
 
 
 
 
0a74034
a4a060a
44f4bf1
 
e99c825
 
 
 
 
0a74034
e99c825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a74034
e99c825
0a74034
e99c825
0a74034
e99c825
0a74034
e99c825
0a74034
e99c825
0a74034
e99c825
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a74034
bf09c3f
e99c825
0a74034
84474bd
e99c825
 
 
 
 
 
 
 
 
 
0a74034
e99c825
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import subprocess
import time

import spaces
import cv2
import gradio as gr
import torch

from gfpgan.utils import GFPGANer
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from realesrgan.utils import RealESRGANer

os.system("pip freeze")  # dump the installed package versions to the log

# Weight files required at startup: local filename -> release download URL.
WEIGHT_URLS = {
    'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
    'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
    'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
    'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
    'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
    'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth',
}


def _download_weight(filename, url):
    """Download ``url`` to ``filename`` in the working directory unless it already exists.

    The payload is first written to ``<filename>.part`` and only renamed into
    place once wget exits successfully. An interrupted or failed download
    therefore never leaves a truncated file that the existence check would
    later accept as a complete checkpoint. Failures are reported and then
    skipped (best-effort), as before.
    """
    if os.path.exists(filename):
        return
    partial = f'{filename}.part'
    try:
        # List arguments, no shell: nothing here is interpreted by /bin/sh.
        result = subprocess.run(['wget', url, '-O', partial], check=False)
    except FileNotFoundError:
        print(f'wget is not available; could not download {filename}')
        return
    if result.returncode == 0:
        os.replace(partial, filename)
    else:
        print(f'Failed to download {filename} (wget exit code {result.returncode})')
        if os.path.exists(partial):
            os.remove(partial)


# download weights
for _weight_name, _weight_url in WEIGHT_URLS.items():
    _download_weight(_weight_name, _weight_url)

# background enhancer with RealESRGAN
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'realesr-general-x4v3.pth'
half = torch.cuda.is_available()  # half precision only when CUDA is available
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
# NOTE(review): the upsampler built above is immediately replaced by None, so
# face restoration runs without background upsampling and the RealESRGAN load
# is wasted startup work. Confirm whether the construction should stay.
upsampler = None

os.makedirs('output', exist_ok=True)  # enhance() writes its results under output/

# Restoration models selectable via ``version``: label -> (weights file, GFPGANer arch).
# NOTE(review): the UI only offers the v1.x entries; the remaining archs are
# passed straight to GFPGANer. Confirm the installed gfpgan accepts them.
_FACE_MODELS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
    'CodeFormer': ('CodeFormer.pth', 'CodeFormer'),
    'RealESR-General-x4v3': ('realesr-general-x4v3.pth', 'realesr-general'),
}


def _resolve_face_model(version):
    """Return ``(weights file, arch)`` for ``version``; a bare ``'1.4'`` maps to ``'v1.4'``."""
    key = version if version in _FACE_MODELS else f'v{version}'
    try:
        return _FACE_MODELS[key]
    except KeyError:
        raise gr.Error(f'Unsupported model version: {version!r}') from None


@spaces.GPU(duration=10)
def enhance(
    img_path:str,
    version:str='1.4',
    scale:int=2,
    upscale:int=2,
):
    """Restore the faces in one image and save the result.

    Args:
        img_path: Path of the input image (the UI passes a Gradio filepath).
        version: Model label, a key of ``_FACE_MODELS``. A bare number such as
            the default ``'1.4'`` is treated as ``'v1.4'``.
        scale: Rescaling factor. Any value other than 2 resizes the restored
            image to ``scale / 2`` times the (possibly pre-doubled) input size.
        upscale: Upscale factor handed to ``GFPGANer``.

    Returns:
        ``(rgb_image, save_path, time_cost_str)``: the restored image converted
        to RGB, the path it was written to under ``output/``, and the
        per-step timing trace produced by ``get_time_cost``.

    Raises:
        gr.Error: if no image is given, it cannot be decoded, or ``version``
            is not a known model.
    """
    run_task_time, time_cost_str = get_time_cost(0, '')

    if not img_path:
        raise gr.Error('Please upload an input image.')
    model_file, arch = _resolve_face_model(version)

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:  # cv2.imread signals an unreadable/unsupported file with None
        raise gr.Error(f'Could not read the input image: {os.path.basename(img_path)}')

    img_mode = 'RGBA' if len(img.shape) == 3 and img.shape[2] == 4 else None
    if len(img.shape) == 2:  # gray input: expand to 3-channel BGR
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    h, w = img.shape[0:2]
    if h < 300:
        # Small inputs are doubled before restoration.
        img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

    face_enhancer = GFPGANer(model_path=model_file, upscale=upscale, arch=arch, channel_multiplier=2, bg_upsampler=upsampler)
    _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)

    if scale != 2:
        # NOTE(review): the divisor 2 matches the default ``upscale`` and is not
        # adjusted when ``upscale`` differs. Confirm the intended semantics.
        interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
        h, w = img.shape[0:2]
        output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)

    # RGBA images should be saved in png format.
    extension = 'png' if img_mode == 'RGBA' else 'jpg'
    save_path = f'output/out.{extension}'
    cv2.imwrite(save_path, output)

    output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)  # OpenCV is BGR; the UI expects RGB
    _, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    return output, save_path, time_cost_str


def get_time_cost(run_task_time, time_cost_str):
    """Record a timing checkpoint and extend the trace.

    Args:
        run_task_time: Millisecond timestamp of the previous checkpoint, or 0
            to start a new trace.
        time_cost_str: Trace accumulated so far.

    Returns:
        ``(now_ms, trace)``, where ``now_ms`` is the current wall-clock time in
        integer milliseconds. Starting a new trace yields ``'start'``.
        Otherwise the milliseconds elapsed since ``run_task_time`` are
        appended to the trace, preceded by ``'-->'`` when the trace is
        non-empty.
    """
    current_ms = int(time.time() * 1000)
    if run_task_time == 0:
        return current_ms, 'start'
    separator = '' if time_cost_str == '' else '-->'
    return current_ms, f'{time_cost_str}{separator}{current_ms - run_task_time}'

def create_demo() -> gr.Blocks:
    """Build the Gradio UI and wire its controls to ``enhance``.

    The layout has two rows:

    - Settings: model version, rescaling factor, upscale factor and the
      Enhance button.
    - Images: the input image on one side; the restored image, its download
      file and the per-step timing text on the other.

    Returns:
        The assembled (not launched) ``gr.Blocks`` app.
    """
    model_choices = ['v1.2', 'v1.3', 'v1.4']

    with gr.Blocks() as app:
        # Settings row.
        with gr.Row():
            with gr.Column():
                model_version = gr.Radio(model_choices, type="value", value='v1.4', label='version')
                rescale_factor = gr.Number(label="Rescaling factor", value=2)
            with gr.Column():
                upscale_factor = gr.Number(label="Upscale factor", value=2)
                run_button = gr.Button("Enhance")

        # Input / output row.
        with gr.Row():
            with gr.Column():
                source_image = gr.Image(label="Input Image", type="filepath")
            with gr.Column():
                result_image = gr.Image(label="Restored Image", type="numpy", interactive=False)
                result_file = gr.File(label="Download the output image", interactive=False)
                timing_box = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)

        # Argument order must match enhance(img_path, version, scale, upscale).
        click_inputs = [source_image, model_version, rescale_factor, upscale_factor]
        click_outputs = [result_image, result_file, timing_box]
        run_button.click(fn=enhance, inputs=click_inputs, outputs=click_outputs)

    return app