File size: 3,800 Bytes
dc15506
 
 
 
 
 
fca6e54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e1828f
12f5af2
 
2e1828f
12f5af2
2e1828f
fca6e54
 
 
dc15506
fca6e54
dc15506
fca6e54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59c537d
 
 
fca6e54
 
59c537d
fca6e54
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import argparse
import functools
import os
import shutil

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from skimage import img_as_ubyte

# Example image files offered below the demo, one single-item row per file.
examples = [[fname] for fname in ('sample1.png', 'sample2.png')]

# Choices for the "Inference on" dropdown of the interface.
inference_on = [
    'Full Resolution Image',
    'Downsampled Image',
]

# Text passed to gr.Interface as its title, description (HTML) and article
# (HTML footer linking to the paper and the code).
title = "Restormer"
description = """
Gradio demo for <b>Restormer: Efficient Transformer for High-Resolution Image Restoration</b>, CVPR 2022--ORAL. <a href='https://arxiv.org/abs/2111.09881'>[Paper]</a><a href='https://github.com/swz30/Restormer'>[Github Code]</a>\n 
<b> Note:</b> Since this demo uses CPU, by default it will run on the downsampled version of the input image (for speedup). But if you want to perform inference on the original input, then choose the option "Full Resolution Image" from the dropdown menu. 
"""
# NOTE(review): leftover description text, currently disabled:
#   "With Restormer, you can perform: (1) Image Denoising, (2) Defocus Deblurring,
#    (3) Motion Deblurring, and (4) Image Deraining.
#    To use it, simply upload your own image, or click one of the examples provided below."
# The task selector in this script only offers "Deraining"; update this text before re-enabling it.

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.09881'>Restormer: Efficient Transformer for High-Resolution Image Restoration </a> | <a href='https://github.com/swz30/Restormer'>Github Repo</a></p>"


@functools.lru_cache(maxsize=None)
def _load_model(path, device):
    """Load the TorchScript model at ``path`` onto ``device``, caching the result.

    ``torch.jit.load`` is comparatively slow, so it runs only on the first request
    for a given (path, device) pair instead of on every call to ``inference``.
    """
    # ``map_location`` lets an archive saved with GPU tensors load on a CPU-only host.
    model = torch.jit.load(path, map_location=device)
    model = model.to(device)
    model.eval()
    return model


def inference(img, task, run_on):
    """Restore a single image with the TorchScript model stored in ``deshabby.pt``.

    Args:
        img: Input ``PIL.Image`` from the Gradio ``Image(type="pil")`` input.
            It is converted to 3-channel RGB before inference.
        task: Value of the task radio button. Unused: the same model is applied
            whatever task is selected.
        run_on: ``'Full Resolution Image'`` processes the image at its original
            size. Any other value (the UI offers ``'Downsampled Image'``) first
            shrinks the image so its longer edge is at most 512 px.

    Returns:
        The restored image as an RGB ``PIL.Image`` with the same size as the
        (possibly downsampled) input.

    Raises:
        ValueError: if no image was supplied.
    """
    if img is None:
        raise ValueError('No input image was provided.')

    # Scratch directory kept for compatibility. ``exist_ok`` avoids both the
    # check-then-create race and spawning a shell via ``os.system``.
    os.makedirs('temp', exist_ok=True)

    # Uploaded PNGs may be RGBA, greyscale or palette images; the model is fed
    # exactly three (RGB) channels.
    img = img.convert('RGB')

    if run_on != 'Full Resolution Image':  # 'Downsampled Image'
        # Resize so the longer edge is at most ``max_res`` pixels (CPU speed-up).
        max_res = 512
        width, height = img.size
        if max(width, height) > max_res:
            scale = max_res / max(width, height)
            # Clamp to 1 px so very thin images never get a 0-size edge.
            new_size = (max(1, int(scale * width)), max(1, int(scale * height)))
            # LANCZOS is the filter the removed ``Image.ANTIALIAS`` alias referred to.
            img = img.resize(new_size, Image.LANCZOS)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model('deshabby.pt', device)

    # The input is padded so H and W are multiples of 8 (presumably required by
    # the model's internal down/up-sampling), and the output is cropped back.
    img_multiple_of = 8

    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        # HxWx3 uint8 RGB -> 1x3xHxW float in [0, 1]. ``np.array`` makes a
        # writable copy, as ``torch.from_numpy`` expects.
        input_ = (
            torch.from_numpy(np.array(img))
            .float()
            .div(255.)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device)
        )

        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % img_multiple_of
        padw = (-w) % img_multiple_of
        # Reflect padding needs each pad to be smaller than the padded dimension;
        # fall back to replicate padding for tiny (< 8 px) images.
        mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_ = F.pad(input_, (0, padw, 0, padh), mode)

        restored = torch.clamp(model(input_), 0, 1)

        # Crop the padding away and convert to an HxWx3 uint8 array.
        restored = img_as_ubyte(restored[:, :, :h, :w].permute(0, 2, 3, 1).cpu().numpy()[0])

    # ``restored`` keeps the RGB channel order of the input, which is what PIL
    # expects, so no colour conversion is needed.
    return Image.fromarray(restored)
    
# Build and launch the Gradio UI. Inputs: an image upload, a task selector and
# a dropdown choosing full-resolution or downsampled inference.
# Gradio 3 components take their initial selection via ``value=``; the old
# ``default=`` keyword is not a component parameter there and was ignored.
demo = gr.Interface(
    inference,
    [
        gr.Image(type="pil", label="Input"),
        # The initial value must be one of the offered choices; "Deraining" is
        # the only task available.
        gr.Radio(["Deraining"], value="Deraining", label='task'),
        gr.Dropdown(
            choices=inference_on,
            type="value",
            value='Downsampled Image',
            label='Inference on',
        ),
    ],
    gr.Image(type="pil", label="cleaned and restored"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(debug=False, enable_queue=True)