File size: 4,673 Bytes
302bb69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import os
import numpy as np
import torch
from models.network_swinir import SwinIR


# Global cuDNN switches: keep cuDNN enabled and turn on its benchmark mode
# (per the PyTorch docs, benchmark mode lets cuDNN time and pick its
# convolution algorithms).
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device: %s" % device)

# Checkpoint file used for each task, keyed by task name.
default_models = dict(
    sr="weights/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth",
    denoise="weights/005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth",
)


# Denoising network: SwinIR built with upscale=1 and no upsampler head.
# The weights are loaded with strict=True, so the constructor arguments must
# agree with the checkpoint's layer names and shapes.
denoise_model = SwinIR(upscale=1, in_chans=3, img_size=128, window_size=8,
        img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
        mlp_ratio=2, upsampler='', resi_connection='1conv').to(device)
# Key under which the checkpoint may nest its state dict; when the key is
# absent the checkpoint object itself is used as the state dict.
param_key_g = 'params'
try:
    # map_location keeps a checkpoint that was saved from a GPU loadable on a
    # CPU-only host.
    pretrained_model = torch.load(default_models["denoise"], map_location=device)
    denoise_model.load_state_dict(
        pretrained_model[param_key_g] if param_key_g in pretrained_model else pretrained_model,
        strict=True)
except Exception as err:
    # Best effort: the app still starts (with untrained weights), but the
    # reason for the failure is reported instead of being swallowed.
    print("Loading model failed: %s (%s)" % (default_models["denoise"], err))
denoise_model.eval()

# Super-resolution network: SwinIR built with upscale=4 and the
# 'nearest+conv' upsampler. The weights are loaded with strict=True, so the
# constructor arguments must agree with the checkpoint's layer names and shapes.
sr_model = SwinIR(upscale=4, in_chans=3, img_size=64, window_size=8,
            img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
            mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv').to(device)
# Key under which the checkpoint may nest its state dict; when the key is
# absent the checkpoint object itself is used as the state dict.
param_key_g = 'params_ema'
try:
    # map_location keeps a checkpoint that was saved from a GPU loadable on a
    # CPU-only host.
    pretrained_model = torch.load(default_models["sr"], map_location=device)
    sr_model.load_state_dict(
        pretrained_model[param_key_g] if param_key_g in pretrained_model else pretrained_model,
        strict=True)
except Exception as err:
    # Best effort: the app still starts (with untrained weights), but the
    # reason for the failure is reported instead of being swallowed.
    print("Loading model failed: %s (%s)" % (default_models["sr"], err))
sr_model.eval()


def sr(input_img, *, model=None, scale=4):
    """Upscale an image with a SwinIR super-resolution network.

    Args:
        input_img: H x W x C array with values in 0..255 (it is divided by
            255 before inference), or ``None`` when no image was supplied.
        model: ``torch.nn.Module`` to run. Defaults to the module-level
            ``sr_model``.
        scale: Upscaling factor of ``model``; used to crop the padding from
            the output. The default matches ``sr_model`` (``upscale=4``).

    Returns:
        The restored image as a ``uint8`` array of size
        (H * scale) x (W * scale), or ``None`` when ``input_img`` is ``None``.
    """
    if input_img is None:
        # Nothing to process (e.g. the action was triggered without an image).
        return None

    net = sr_model if model is None else model
    # Run on the device the network lives on; modules without parameters
    # fall back to the CPU.
    first_param = next(iter(net.parameters()), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    window_size = 8
    # read image
    img_lq = input_img.astype(np.float32) / 255.
    # NOTE(review): the channel reversal assumes BGR input (OpenCV order).
    # If the caller passes RGB (presumably what gr.Image yields - confirm),
    # the network sees reversed channels. The reversal is undone on output,
    # so it is left unchanged here.
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  # HWC to CHW
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(run_device)  # CHW to NCHW

    # inference
    with torch.no_grad():
        # Pad height and width up to the next multiple of window_size by
        # mirroring the image onto itself.
        _, _, h_old, w_old = img_lq.size()
        target_h = (h_old // window_size + 1) * window_size
        target_w = (w_old // window_size + 1) * window_size
        # One mirror pass is enough unless a side is shorter than the padding
        # it needs, hence the loops.
        while img_lq.size(2) < target_h:
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)
        img_lq = img_lq[:, :, :target_h, :]
        while img_lq.size(3) < target_w:
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)
        img_lq = img_lq[:, :, :, :target_w]
        output = net(img_lq)
        # Drop the part of the output that corresponds to the padding.
        output = output[..., :h_old * scale, :w_old * scale]

    # Back to a uint8 image. Only the batch axis is squeezed so that images
    # that are one pixel wide or high keep their shape.
    output = output.detach().squeeze(0).float().cpu().clamp(0, 1).numpy()
    if output.shape[0] == 1:
        output = output[0]  # single channel: H x W
    else:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW to HWC, channels reversed back
    output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8

    return output

def denoise(input_img, *, model=None):
    """Denoise an image with a SwinIR denoising network.

    Args:
        input_img: H x W x C array with values in 0..255 (it is divided by
            255 before inference), or ``None`` when no image was supplied.
        model: ``torch.nn.Module`` to run; it must keep the spatial size of
            its input. Defaults to the module-level ``denoise_model``.

    Returns:
        The restored image as a ``uint8`` array with the same height and
        width as the input, or ``None`` when ``input_img`` is ``None``.
    """
    if input_img is None:
        # Nothing to process (e.g. the action was triggered without an image).
        return None

    net = denoise_model if model is None else model
    # Run on the device the network lives on; modules without parameters
    # fall back to the CPU.
    first_param = next(iter(net.parameters()), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    window_size = 8
    # read image
    img_lq = input_img.astype(np.float32) / 255.
    # NOTE(review): the channel reversal assumes BGR input (OpenCV order).
    # If the caller passes RGB (presumably what gr.Image yields - confirm),
    # the network sees reversed channels. The reversal is undone on output,
    # so it is left unchanged here.
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))  # HWC to CHW
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(run_device)  # CHW to NCHW

    # inference
    with torch.no_grad():
        # Pad height and width up to the next multiple of window_size by
        # mirroring the image onto itself.
        _, _, h_old, w_old = img_lq.size()
        target_h = (h_old // window_size + 1) * window_size
        target_w = (w_old // window_size + 1) * window_size
        # One mirror pass is enough unless a side is shorter than the padding
        # it needs, hence the loops.
        while img_lq.size(2) < target_h:
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)
        img_lq = img_lq[:, :, :target_h, :]
        while img_lq.size(3) < target_w:
            img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)
        img_lq = img_lq[:, :, :, :target_w]
        output = net(img_lq)
        # The denoiser does not upscale (denoise_model uses upscale=1), so the
        # padding is removed by cropping back to the ORIGINAL size. Cropping
        # to 4x the size, as the super-resolution path does, would leave the
        # mirrored border in the result.
        output = output[..., :h_old, :w_old]

    # Back to a uint8 image. Only the batch axis is squeezed so that images
    # that are one pixel wide or high keep their shape.
    output = output.detach().squeeze(0).float().cpu().clamp(0, 1).numpy()
    if output.shape[0] == 1:
        output = output[0]  # single channel: H x W
    else:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW to HWC, channels reversed back
    output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8

    return output

# Browser-tab title and Markdown heading of the demo page.
title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Image Restoration"
# One single-element row per entry of ./examples (one value per row because
# the examples feed a single input component).
# sorted() makes the order deterministic, since os.listdir returns entries in
# arbitrary order. A missing directory yields no examples instead of aborting
# startup.
example_list = (
    [["examples/" + example] for example in sorted(os.listdir("examples"))]
    if os.path.isdir("examples")
    else []
)

# Assemble the demo page: a heading, then one row with two columns.
with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)

    with gr.Row():
        # First column: the input image and the enhanced result.
        with gr.Column():
            im = gr.Image(label="Input Image")
            im_2 = gr.Image(label="Enhanced Image")

        # Second column: one button per restoration task, then the examples.
        # Both buttons read from `im` and write to `im_2`.
        with gr.Column():
            btn1 = gr.Button("Enhance Resolution")
            btn1.click(fn=sr, inputs=[im], outputs=[im_2])

            btn2 = gr.Button("Denoise")
            btn2.click(fn=denoise, inputs=[im], outputs=[im_2])

            gr.Examples(example_list, inputs=[im], outputs=[im_2])


# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()