File size: 3,525 Bytes
dc15506
 
 
 
 
 
fca6e54
 
 
ddad5e5
196a89d
 
fca6e54
8cf25d5
41ae325
8cf25d5
fca6e54
 
 
 
9741eb4
 
a90a2a6
ceafe2f
fca6e54
 
 
 
 
0d48277
fca6e54
 
 
a1caa30
fca6e54
d085511
fca6e54
 
 
 
 
edbecfe
fca6e54
ddad5e5
 
67ea445
ddad5e5
2e1828f
a642a77
ddad5e5
fca6e54
 
 
dc15506
fca6e54
dc15506
fca6e54
 
 
 
196a89d
 
6ac1620
fca6e54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59c537d
fca6e54
59c537d
fca6e54
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn.functional as F
import os
from skimage import img_as_ubyte
import cv2
import argparse
import shutil
import gradio as gr
from PIL import Image
from runpy import run_path
import numpy as np


# Example inputs listed under the Gradio interface. Each entry is a
# one-element list because the Interface has a single input component.
# The mixed "sample"/"Sample" capitalisation is kept verbatim; presumably it
# matches the files shipped next to this script (verify on case-sensitive
# filesystems).
examples = [
    [path]
    for path in (
        './sample1.png',
        './sample2.png',
        './Sample3.png',
        './Sample4.png',
        './Sample5.png',
        './Sample6.png',
    )
]



# Texts passed to the gr.Interface constructed at the end of this script
# (as `title`, `description` and `article`).
title = "Restormer"
# HTML description of the demo.
# NOTE(review): the note below says the demo "uses CPU", but `inference`
# selects CUDA when torch.cuda.is_available(); the wording assumes a CPU-only
# host such as a free Hugging Face Space — confirm before deploying on GPU.
description = """
Gradio demo for reconstruction of noisy scanned, photocopied documents\n
using <b>Restormer: Efficient Transformer for High-Resolution Image Restoration</b>, CVPR 2022--ORAL. <a href='https://arxiv.org/abs/2111.09881'>[Paper]</a><a href='https://github.com/swz30/Restormer'>[Github Code]</a>\n
<a href='https://toon-beerten.medium.com/denoising-and-reconstructing-dirty-documents-for-optimal-digitalization-ed3a186aa3d6'>[See my article for more details]</a>\n 
<b> Note:</b> Since this demo uses CPU, by default it will run on the downsampled version of the input image (for speedup). 
"""

# HTML block linking the Restormer paper and upstream repository.
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.09881'>Restormer: Efficient Transformer for High-Resolution Image Restoration </a> | <a href='https://github.com/swz30/Restormer'>Github Repo</a></p>"


# Longest image edge (pixels) allowed before the input is downsampled; keeps
# inference time bounded, as the demo description promises.
MAX_RES = 400
# The padded model input has height and width divisible by this value.
IMG_MULTIPLE_OF = 8
# Architecture hyper-parameters of the Restormer instance the checkpoint is
# loaded into (load_state_dict is strict, so these must match the weights).
_RESTORMER_PARAMETERS = {
    'inp_channels': 3,
    'out_channels': 3,
    'dim': 48,
    'num_blocks': [4, 6, 6, 8],
    'num_refinement_blocks': 4,
    'heads': [1, 2, 4, 8],
    'ffn_expansion_factor': 2.66,
    'bias': False,
    'LayerNorm_type': 'WithBias',
    'dual_pixel_task': False,
}

# Models already built, keyed by device string, so the architecture file and
# checkpoint are read once per process instead of on every request.
_MODEL_CACHE = {}


def _load_model(device):
    """Return the Restormer model on *device*, building and caching it on first use.

    Reads 'restormer_arch.py' and 'net_g_92000.pth' from the current working
    directory the first time it is called for a given device.
    """
    key = str(device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        load_arch = run_path('restormer_arch.py')
        model = load_arch['Restormer'](**_RESTORMER_PARAMETERS)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        checkpoint = torch.load('net_g_92000.pth', map_location=device)
        model.load_state_dict(checkpoint['params'])
        model = model.to(device)
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def _downsample(img, max_res=MAX_RES):
    """Shrink a PIL image so its longer edge is at most *max_res* pixels.

    Images already within the limit are returned unchanged; otherwise the
    aspect ratio is preserved (up to integer truncation).
    """
    width, height = img.size
    longest = max(width, height)
    if longest <= max_res:
        return img
    scale = max_res / longest
    # max(1, ...) stops very thin images from collapsing to a zero-sized edge.
    new_size = (max(1, int(scale * width)), max(1, int(scale * height)))
    return img.resize(new_size)


def inference(img):
    """Restore a noisy scanned/photocopied document image with Restormer.

    Args:
        img: input PIL image from the Gradio ``Image(type="pil")`` component.

    Returns:
        The restored image as a PIL ``Image``, with the size of the
        (possibly downsampled) input.
    """
    # Kept from the original demo; nothing in this function reads or writes it.
    os.makedirs('temp', exist_ok=True)

    # Downsample on every request (longer edge capped at MAX_RES).
    img = _downsample(img)
    if img.mode != 'RGB':
        # Drop alpha / expand grayscale so np.array(img) below is always HxWx3.
        img = img.convert('RGB')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model(device)

    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        # PIL yields RGB; this swap hands the model B,G,R-ordered channels, and
        # the inverse swap before returning restores RGB for display.
        # NOTE(review): if the checkpoint was trained on RGB input, both swaps
        # could be dropped — confirm against the training pipeline.
        swapped = cv2.cvtColor(np.array(img), cv2.COLOR_BGR2RGB)
        input_ = torch.from_numpy(swapped).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

        # Pad bottom/right so H and W become multiples of IMG_MULTIPLE_OF.
        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % IMG_MULTIPLE_OF
        padw = (-w) % IMG_MULTIPLE_OF
        # 'reflect' requires each pad to be smaller than that dimension; fall
        # back to 'replicate' for tiny images where it is not.
        pad_mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_ = F.pad(input_, (0, padw, 0, padh), pad_mode)

        restored = torch.clamp(model(input_), 0, 1)

        # Crop the padding away and convert to an HxWx3 uint8 array.
        restored = img_as_ubyte(restored[:, :, :h, :w].permute(0, 2, 3, 1).cpu().numpy()[0])

    return Image.fromarray(cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))
    
# Build the Gradio UI around `inference` and serve it. Queuing is enabled via
# `.queue()`: the `enable_queue` argument of `launch()` was deprecated in
# Gradio 3.x and removed in Gradio 4.0, where passing it raises a TypeError.
demo = gr.Interface(
    inference,
    [
        gr.Image(type="pil", label="Input"),
    ],
    gr.Image(type="pil", label="cleaned and restored"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.queue()
demo.launch(debug=False)