|
import argparse
import functools
import os
import shutil
from runpy import run_path

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from skimage import img_as_ubyte
|
|
|
|
|
# Example inputs shown under the Gradio UI: one single-element list per example
# (one entry per input component). File-name casing differs between samples;
# NOTE(review): this presumably matches the files as committed, since paths are
# case-sensitive on Linux — confirm before normalizing.
examples = [
    [f"./{stem}.png"]
    for stem in ("sample1", "sample2", "Sample3", "Sample4", "Sample5", "Sample6")
]
|
|
|
|
|
# Text shown by the Gradio interface: page title, the HTML/markdown blurb
# above the widgets, and the footer article below them.
title = "Restormer"

# Assembled line by line; the literal "\n" at the end of each paragraph plus the
# joining newline yields a blank line between paragraphs.
description = "\n".join([
    "",
    "Gradio demo for reconstruction of noisy scanned, photocopied documents\n",
    "using <b>Restormer: Efficient Transformer for High-Resolution Image Restoration</b>, CVPR 2022--ORAL. "
    "<a href='https://arxiv.org/abs/2111.09881'>[Paper]</a>"
    "<a href='https://github.com/swz30/Restormer'>[Github Code]</a>\n",
    "<a href='https://medium.com/towards-data-science/effective-data-augmentation-for-ocr-8013080aa9fa'>"
    "[See my article for more details]</a>\n",
    "<b> Note:</b> Since this demo uses CPU, by default it will run on the downsampled version "
    "of the input image (for speedup).",
    "",
])

article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.09881'>"
    "Restormer: Efficient Transformer for High-Resolution Image Restoration </a>"
    " | "
    "<a href='https://github.com/swz30/Restormer'>Github Repo</a>"
    "</p>"
)
|
|
|
|
|
@functools.lru_cache(maxsize=None)
def _load_model(device):
    """Build Restormer, load the ``net_g_92000.pth`` weights and move it to *device*.

    Cached per device so ``restormer_arch.py`` and the checkpoint are read once
    per process instead of on every request.
    """
    parameters = {
        'inp_channels': 3,
        'out_channels': 3,
        'dim': 48,
        'num_blocks': [4, 6, 6, 8],
        'num_refinement_blocks': 4,
        'heads': [1, 2, 4, 8],
        'ffn_expansion_factor': 2.66,
        'bias': False,
        'LayerNorm_type': 'WithBias',
        'dual_pixel_task': False,
    }
    load_arch = run_path('restormer_arch.py')
    model = load_arch['Restormer'](**parameters)
    # Load tensors onto the CPU first so a checkpoint saved from a GPU run also
    # loads on a CPU-only host; the model is moved to *device* afterwards.
    checkpoint = torch.load('net_g_92000.pth', map_location='cpu')
    model.load_state_dict(checkpoint['params'])
    model = model.to(device)
    model.eval()
    return model


def _pad_to_multiple(tensor, multiple):
    """Pad an (N, C, H, W) tensor at the bottom/right so H and W are multiples of *multiple*."""
    h, w = tensor.shape[2], tensor.shape[3]
    padh = (-h) % multiple
    padw = (-w) % multiple
    # 'reflect' requires every pad to be smaller than the dimension it pads;
    # very small images fall back to 'replicate' instead of raising.
    mode = 'reflect' if padh < h and padw < w else 'replicate'
    return F.pad(tensor, (0, padw, 0, padh), mode)


def inference(img, *, max_res: int = 200):
    """Restore a noisy scanned/photocopied document image with Restormer.

    Args:
        img: PIL image supplied by the Gradio ``Image(type="pil")`` input. Any
            PIL mode is accepted; it is converted to RGB before inference.
        max_res: Longest allowed side in pixels. Larger inputs are downscaled
            (aspect ratio kept) before inference to keep CPU runs fast.

    Returns:
        The restored image as an RGB ``PIL.Image`` at the (possibly
        downscaled) input size.

    Side effects:
        Creates a ``temp`` directory if missing, and on first call per device
        loads ``restormer_arch.py`` and ``net_g_92000.pth`` from the CWD.
    """
    # Kept for compatibility with the original script; nothing visible here
    # writes into 'temp'.
    os.makedirs('temp', exist_ok=True)

    width, height = img.size
    longest = max(width, height)
    if longest > max_res:
        scale = max_res / longest
        # max(1, ...) stops extremely thin images collapsing to a 0-pixel side.
        width = max(1, int(scale * width))
        height = max(1, int(scale * height))
        img = img.resize((width, height))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model(device)

    # NOTE(review): presumably 2**3 because the 4-level encoder downsamples
    # three times — confirm against restormer_arch.py.
    img_multiple_of = 8

    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        # PIL already yields RGB channel order, which is what the network is fed
        # (the previous BGR2RGB call actually swapped it to BGR).
        rgb = np.array(img.convert('RGB'))
        input_ = (
            torch.from_numpy(rgb).float().div(255.)
            .permute(2, 0, 1).unsqueeze(0).to(device)
        )

        h, w = input_.shape[2], input_.shape[3]
        input_ = _pad_to_multiple(input_, img_multiple_of)

        restored = torch.clamp(model(input_), 0, 1)
        # Crop the padding back off and convert CHW float -> HWC uint8.
        restored = img_as_ubyte(restored[0, :, :h, :w].permute(1, 2, 0).cpu().numpy())

    return Image.fromarray(restored)
|
|
|
# Build the single-image-in / single-image-out web UI around `inference` and
# start the server at import time, exactly as the original script does.
# NOTE(review): `enable_queue` on launch() is deprecated in later Gradio 3.x
# releases (use `.queue()`) — confirm against the pinned Gradio version.
demo = gr.Interface(
    fn=inference,
    inputs=[gr.Image(type="pil", label="Input")],
    outputs=gr.Image(type="pil", label="cleaned and restored"),
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(debug=False, enable_queue=True)
|
|