|
import torch |
|
import torch.nn.functional as F |
|
import os |
|
from skimage import img_as_ubyte |
|
import cv2 |
|
import argparse |
|
import shutil |
|
import gradio as gr |
|
from PIL import Image |
|
from runpy import run_path |
|
import numpy as np |
|
|
|
|
|
# Example inputs shown under the demo; each inner list holds the arguments for
# one example row (here a single image path).
examples = [[name] for name in ("sample1.png", "sample2.png")]
|
|
|
|
|
# Text shown by the Gradio Interface: page title, header description (HTML)
# and footer article (HTML).
title = "Restormer"

# The note describes what `inference` actually does: the Interface has no
# resolution dropdown, and inputs are always downscaled to a 200 px longest side.
description = """
Gradio demo for <b>Restormer: Efficient Transformer for High-Resolution Image Restoration</b>, CVPR 2022--ORAL. <a href='https://arxiv.org/abs/2111.09881'>[Paper]</a><a href='https://github.com/swz30/Restormer'>[Github Code]</a>\n
<b> Note:</b> Since this demo runs on CPU, the input image is downsampled so that its longest side is at most 200 pixels before restoration (for speedup); larger inputs therefore produce a smaller output image.
"""

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.09881'>Restormer: Efficient Transformer for High-Resolution Image Restoration </a> | <a href='https://github.com/swz30/Restormer'>Github Repo</a></p>"
|
|
|
|
|
# Process-wide cache so the architecture file and checkpoint are loaded once,
# not on every request.
_model_cache = {}


def _load_model():
    """Return ``(model, device)``, building and loading Restormer on first use."""
    if not _model_cache:
        parameters = {'inp_channels': 3, 'out_channels': 3, 'dim': 48,
                      'num_blocks': [4, 6, 6, 8], 'num_refinement_blocks': 4,
                      'heads': [1, 2, 4, 8], 'ffn_expansion_factor': 2.66,
                      'bias': False, 'LayerNorm_type': 'WithBias',
                      'dual_pixel_task': False}

        load_arch = run_path('restormer_arch.py')
        model = load_arch['Restormer'](**parameters)

        # map_location='cpu' lets the checkpoint load on CPU-only hosts even if
        # it was saved with CUDA tensors; the model is moved to `device` below.
        checkpoint = torch.load('net_g_92000.pth', map_location='cpu')
        model.load_state_dict(checkpoint['params'])

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = model.to(device)
        model.eval()

        _model_cache['model'] = model
        _model_cache['device'] = device
    return _model_cache['model'], _model_cache['device']


def inference(img, max_res=200):
    """Restore one image with the Restormer model.

    Args:
        img: Input ``PIL.Image`` (the Gradio ``Image`` component uses
            ``type="pil"``). Any PIL mode is accepted; it is converted to RGB.
        max_res: Longest-side limit in pixels. Larger inputs are downscaled
            (aspect ratio kept) before inference to keep CPU runtime low.
            Pass ``None`` to process the image at full resolution.

    Returns:
        The restored image as an RGB ``PIL.Image`` with the size of the
        (possibly downscaled) input.
    """
    # Preserve the original side effect of ensuring a `temp` directory exists,
    # without shelling out.
    os.makedirs('temp', exist_ok=True)

    # The model has 3 input channels; normalise grayscale/RGBA/palette images.
    img = img.convert('RGB')

    width, height = img.size
    if max_res is not None and max(width, height) > max_res:
        scale = max_res / max(width, height)
        # Clamp to 1 px so extremely elongated images never get a zero side.
        width = max(1, int(scale * width))
        height = max(1, int(scale * height))
        # LANCZOS is the filter the removed `Image.ANTIALIAS` alias referred to.
        img = img.resize((width, height), Image.LANCZOS)

    model, device = _load_model()

    img_multiple_of = 8

    with torch.inference_mode():
        if torch.cuda.is_available():
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

        # PIL yields RGB already, so it is fed to the model as-is. The previous
        # cv2 BGR2RGB call actually swapped it to BGR, and the swap was only
        # undone at the output.
        rgb = np.array(img)
        input_ = torch.from_numpy(rgb).float().div(255.).permute(2, 0, 1).unsqueeze(0).to(device)

        # Pad H and W up to the next multiple of `img_multiple_of`; the original
        # code did the same, presumably because the architecture needs
        # divisible spatial sizes.
        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % img_multiple_of
        padw = (-w) % img_multiple_of
        # 'reflect' requires the pad to be smaller than the dimension; fall back
        # to 'replicate' for tiny images instead of raising.
        pad_mode = 'reflect' if padh < h and padw < w else 'replicate'
        input_ = F.pad(input_, (0, padw, 0, padh), pad_mode)

        restored = torch.clamp(model(input_), 0, 1)

        # Drop the padding and convert NCHW float [0, 1] -> HWC uint8.
        restored = img_as_ubyte(restored[:, :, :h, :w].permute(0, 2, 3, 1).cpu().numpy()[0])

    return Image.fromarray(restored)
|
|
|
# Build the single-image-in / single-image-out UI around `inference` and start
# the web server.
# NOTE(review): `enable_queue` as a `launch()` argument is deprecated in later
# Gradio 3.x releases and removed in 4.x; when upgrading, switch to
# `demo.queue().launch(debug=False)` — confirm against the pinned Gradio version.
input_image = gr.Image(type="pil", label="Input")
output_image = gr.Image(type="pil", label="cleaned and restored")

demo = gr.Interface(
    fn=inference,
    inputs=[input_image],
    outputs=output_image,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(debug=False, enable_queue=True)
|
|