import torch
import torch.nn.functional as F
import os
from skimage import img_as_ubyte
import gradio as gr
from PIL import Image
from runpy import run_path
import numpy as np
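
# Gradio demo for cleaning up noisy scanned or photocopied documents with
# Restormer (https://github.com/swz30/Restormer). Expects restormer_arch.py
# and the checkpoint net_g_92000.pth next to this script.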
examples = [
    ['./sample1.png'],
    ['./sample2.png'],
    ['./Sample3.png'],
    ['./Sample4.png'],
    ['./Sample5.png'],
    ['./Sample6.png'],
]
title = "Restormer"
description = """
Gradio demo for reconstruction of noisy scanned, photocopied documents\n
using <b>Restormer: Efficient Transformer for High-Resolution Image Restoration</b>, CVPR 2022--ORAL. <a href='https://arxiv.org/abs/2111.09881'>[Paper]</a><a href='https://github.com/swz30/Restormer'>[Github Code]</a>\n
<a href='https://toon-beerten.medium.com/denoising-and-reconstructing-dirty-documents-for-optimal-digitalization-ed3a186aa3d6'>[See my article for more details]</a>\n
<b> Note:</b> Since this demo uses CPU, by default it will run on the downsampled version of the input image (for speedup).
"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.09881'>Restormer: Efficient Transformer for High-Resolution Image Restoration</a> | <a href='https://github.com/swz30/Restormer'>GitHub Repo</a></p>"
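
# Inference pipeline: downscale large inputs, build the Restormer model from
# restormer_arch.py, load the fine-tuned checkpoint, pad the input tensor to a
# multiple of 8, run the model, then crop the padding away and return a PIL image.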
def inference(img):
    # Scratch directory for outputs.
    os.makedirs('temp', exist_ok=True)

    # Resize the longer edge of the input image so the demo stays fast on CPU.
    max_res = 400
    width, height = img.size
    if max(width, height) > max_res:
        scale = max_res / max(width, height)
        width = int(scale * width)
        height = int(scale * height)
        img = img.resize((width, height))

    # Restormer-base configuration, matching the defaults in the official repo.
    parameters = {'inp_channels': 3, 'out_channels': 3, 'dim': 48,
                  'num_blocks': [4, 6, 6, 8], 'num_refinement_blocks': 4,
                  'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66,
                  'bias': False, 'LayerNorm_type': 'WithBias',
                  'dual_pixel_task': False}
    # The model is rebuilt and the checkpoint reloaded on every call, which
    # keeps the demo simple.
    load_arch = run_path('restormer_arch.py')
    model = load_arch['Restormer'](**parameters)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load('net_g_92000.pth', map_location=device)
    model.load_state_dict(checkpoint['params'])
    model = model.to(device)
    model.eval()

    # The encoder downsamples three times, so both sides must be multiples of 8.
    img_multiple_of = 8

    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        # PIL images are already RGB, so no colour-space conversion is needed.
        img = np.array(img)
        input_ = torch.from_numpy(img).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

        # Pad the input if its size is not a multiple of 8.
        h, w = input_.shape[2], input_.shape[3]
        H = ((h + img_multiple_of) // img_multiple_of) * img_multiple_of
        W = ((w + img_multiple_of) // img_multiple_of) * img_multiple_of
        padh = H - h if h % img_multiple_of != 0 else 0
        padw = W - w if w % img_multiple_of != 0 else 0
        input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

        restored = torch.clamp(model(input_), 0, 1)

        # Unpad the output and convert back to uint8.
        restored = img_as_ubyte(restored[:, :, :h, :w].permute(0, 2, 3, 1).cpu().numpy()[0])

    # Convert to PIL when returning.
    return Image.fromarray(restored)
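
# Standalone sanity check (a minimal sketch; assumes one of the bundled example
# images, e.g. ./sample1.png, is present):
#
#   out = inference(Image.open('./sample1.png').convert('RGB'))
#   out.save('temp/restored.png')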
gr.Interface(
    inference,
    [
        gr.Image(type="pil", label="Input"),
    ],
    gr.Image(type="pil", label="Cleaned and restored"),
    title=title,
    description=description,
    article=article,
    examples=examples,
).queue().launch(debug=False)
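
# To run locally: `python app.py` (requires torch, gradio, scikit-image and
# numpy, plus restormer_arch.py and net_g_92000.pth in the working directory),
# then open the local URL that Gradio prints.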