"""Gradio demo: restore noisy scanned / photocopied documents with Restormer."""

import argparse
import os
import shutil
from functools import lru_cache
from runpy import run_path

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from skimage import img_as_ubyte

# Example inputs offered in the UI (paths are relative to the working directory).
examples = [
    ['./sample1.png'],
    ['./sample2.png'],
    ['./Sample3.png'],
    ['./Sample4.png'],
    ['./Sample5.png'],
    ['./Sample6.png'],
]

title = "Restormer"

description = (
    " Gradio demo for reconstruction of noisy scanned, photocopied documents\n"
    " using Restormer: Efficient Transformer for High-Resolution Image"
    " Restoration, CVPR 2022--ORAL. [Paper][Github Code]\n"
    " [See my article for more details]\n"
    " Note: Since this demo uses CPU, by default it will run on the"
    " downsampled version of the input image (for speedup). "
)

# NOTE(review): the hyperlink markup that presumably wrapped these labels is
# not present in this copy of the file -- confirm against the deployed app.
article = """

Restormer: Efficient Transformer for High-Resolution Image Restoration | Github Repo

"""

# Images whose longer edge exceeds this many pixels are downsampled first.
MAX_RES = 400

# Height and width are padded up to a multiple of this before the forward pass.
IMG_MULTIPLE_OF = 8

# Keyword arguments passed to the Restormer constructor.
MODEL_PARAMETERS = {
    'inp_channels': 3,
    'out_channels': 3,
    'dim': 48,
    'num_blocks': [4, 6, 6, 8],
    'num_refinement_blocks': 4,
    'heads': [1, 2, 4, 8],
    'ffn_expansion_factor': 2.66,
    'bias': False,
    'LayerNorm_type': 'WithBias',
    'dual_pixel_task': False,
}


@lru_cache(maxsize=1)
def _load_model():
    """Build the Restormer network, load its weights and put it in eval mode.

    The result is cached so the architecture file and the checkpoint are read
    once per process instead of on every request.

    Returns:
        A ``(model, device)`` tuple; ``model`` already lives on ``device``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    load_arch = run_path('restormer_arch.py')
    model = load_arch['Restormer'](**MODEL_PARAMETERS)

    # Load onto the CPU first so a checkpoint saved from a GPU run can still
    # be read on a CPU-only host; the model is moved to `device` afterwards.
    checkpoint = torch.load('net_g_92000.pth', map_location='cpu')
    model.load_state_dict(checkpoint['params'])

    model = model.to(device)
    model.eval()
    return model, device


def inference(img):
    """Run Restormer on a single image and return the restored result.

    Args:
        img: Input PIL image, as delivered by ``gr.Image(type="pil")``.

    Returns:
        A PIL image with the same size as the (possibly downsampled) input.
    """
    # Side effect kept from the original script: ensure a 'temp' directory
    # exists. os.makedirs avoids spawning a shell and is race-free.
    os.makedirs('temp', exist_ok=True)

    # The tensor layout below assumes three channels.
    img = img.convert('RGB')

    # Downsample so the longer edge is at most MAX_RES pixels (for speed).
    width, height = img.size
    longest_edge = max(width, height)
    if longest_edge > MAX_RES:
        scale = MAX_RES / longest_edge
        # Clamp to 1 so extreme aspect ratios never yield a zero-sized image.
        width = max(1, int(scale * width))
        height = max(1, int(scale * height))
        img = img.resize((width, height))

    model, device = _load_model()

    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        # NOTE(review): np.array(PIL image) is already RGB, so this conversion
        # actually feeds the network channel-reversed data; the matching swap
        # on the way out restores the order. Kept as-is to preserve outputs --
        # confirm which channel order the checkpoint was trained with.
        pixels = cv2.cvtColor(np.array(img), cv2.COLOR_BGR2RGB)
        input_ = (
            torch.from_numpy(pixels)
            .float()
            .div(255.)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .to(device)
        )

        # Pad bottom/right so both spatial dims are multiples of IMG_MULTIPLE_OF.
        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % IMG_MULTIPLE_OF
        padw = (-w) % IMG_MULTIPLE_OF
        # Reflect padding needs the pad to be smaller than the dimension;
        # fall back to edge replication for very small inputs.
        pad_mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_ = F.pad(input_, (0, padw, 0, padh), pad_mode)

        restored = torch.clamp(model(input_), 0, 1)

        # Crop the padding away and convert NCHW float -> HWC uint8.
        restored = restored[:, :, :h, :w].permute(0, 2, 3, 1)
        restored = img_as_ubyte(restored.cpu().detach().numpy()[0])

    # Undo the channel swap applied before inference and hand back a PIL image.
    return Image.fromarray(cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))


gr.Interface(
    inference,
    [
        gr.Image(type="pil", label="Input"),
    ],
    gr.Image(type="pil", label="cleaned and restored"),
    title=title,
    description=description,
    article=article,
    examples=examples,
).launch(debug=False, enable_queue=True)