|
import os |
|
import io |
|
import gradio as gr |
|
from PIL import Image |
|
import tempfile |
|
from pathlib import Path |
|
import argparse |
|
import shutil |
|
import cv2 |
|
import glob |
|
import torch |
|
from collections import OrderedDict |
|
import numpy as np |
|
from main_test_swinir import define_model, setup, get_image_pair |
|
|
|
|
|
# SwinIR release checkpoints used by this demo, fetched at import time.
# Each file is downloaded only when it is not already on disk: `wget -P DIR URL`
# never overwrites an existing file but stores a fresh duplicate as
# `<name>.1`, `<name>.2`, ... on every start-up.
_SWINIR_RELEASE_URL = 'https://github.com/JingyunLiang/SwinIR/releases/download/v0.0'
_WEIGHTS_DIR = 'model_zoo/swinir'
_WEIGHT_FILES = [
    '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth',
    '004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth',
    '004_grayDN_DFWB_s128w8_SwinIR-M_noise25.pth',
    '004_grayDN_DFWB_s128w8_SwinIR-M_noise50.pth',
    '005_colorDN_DFWB_s128w8_SwinIR-M_noise15.pth',
    '005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth',
    '005_colorDN_DFWB_s128w8_SwinIR-M_noise50.pth',
    '006_CAR_DFWB_s126w7_SwinIR-M_jpeg10.pth',
    '006_CAR_DFWB_s126w7_SwinIR-M_jpeg20.pth',
    '006_CAR_DFWB_s126w7_SwinIR-M_jpeg30.pth',
    '006_CAR_DFWB_s126w7_SwinIR-M_jpeg40.pth',
]

for _weight_file in _WEIGHT_FILES:
    if not os.path.exists(os.path.join(_WEIGHTS_DIR, _weight_file)):
        # Best-effort: os.system does not raise on a failed download, so a
        # missing checkpoint presumably only surfaces later, when a request
        # tries to load it.
        os.system(f'wget {_SWINIR_RELEASE_URL}/{_weight_file} -P {_WEIGHTS_DIR}')
|
|
|
def _build_args(task_type, noise, jpeg):
    """Build the option namespace consumed by the ``main_test_swinir`` helpers.

    An empty argument list is parsed so every option starts at the parser
    defaults below; the task, degradation level and checkpoint path are then
    overridden for this request.

    Args:
        task_type: Task label exactly as shown in the Gradio task dropdown.
        noise: Noise level (int) selecting the denoising checkpoint.
        jpeg: JPEG quality (int) selecting the artifact-reduction checkpoint.

    Returns:
        argparse.Namespace with ``task``, ``noise``, ``jpeg``, ``scale`` and
        ``model_path`` set for the request.

    Raises:
        KeyError: for an unknown task label, or a noise / JPEG level that
            has no matching checkpoint.
    """
    model_dir = 'model_zoo/swinir'

    # Checkpoint file per task, keyed by scale (real_sr), noise level
    # (gray_dn / color_dn) or JPEG quality (jpeg_car).
    model_zoo = {
        'real_sr': {
            4: os.path.join(model_dir, '003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth')
        },
        'gray_dn': {
            15: os.path.join(model_dir, '004_grayDN_DFWB_s128w8_SwinIR-M_noise15.pth'),
            25: os.path.join(model_dir, '004_grayDN_DFWB_s128w8_SwinIR-M_noise25.pth'),
            50: os.path.join(model_dir, '004_grayDN_DFWB_s128w8_SwinIR-M_noise50.pth')
        },
        'color_dn': {
            15: os.path.join(model_dir, '005_colorDN_DFWB_s128w8_SwinIR-M_noise15.pth'),
            25: os.path.join(model_dir, '005_colorDN_DFWB_s128w8_SwinIR-M_noise25.pth'),
            50: os.path.join(model_dir, '005_colorDN_DFWB_s128w8_SwinIR-M_noise50.pth')
        },
        'jpeg_car': {
            10: os.path.join(model_dir, '006_CAR_DFWB_s126w7_SwinIR-M_jpeg10.pth'),
            20: os.path.join(model_dir, '006_CAR_DFWB_s126w7_SwinIR-M_jpeg20.pth'),
            30: os.path.join(model_dir, '006_CAR_DFWB_s126w7_SwinIR-M_jpeg30.pth'),
            40: os.path.join(model_dir, '006_CAR_DFWB_s126w7_SwinIR-M_jpeg40.pth')
        }
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='real_sr', help='classical_sr, lightweight_sr, real_sr, '
                                                                     'gray_dn, color_dn, jpeg_car')
    parser.add_argument('--scale', type=int, default=1, help='scale factor: 1, 2, 3, 4, 8')
    parser.add_argument('--noise', type=int, default=15, help='noise level: 15, 25, 50')
    parser.add_argument('--jpeg', type=int, default=40, help='scale factor: 10, 20, 30, 40')
    parser.add_argument('--training_patch_size', type=int, default=128, help='patch size used in training SwinIR. '
                                                                             'Just used to differentiate two different settings in Table 2 of the paper. '
                                                                             'Images are NOT tested patch by patch.')
    parser.add_argument('--large_model', action='store_true',
                        help='use large model, only provided for real image sr')
    parser.add_argument('--model_path', type=str,
                        default=model_zoo['real_sr'][4])
    parser.add_argument('--folder_lq', type=str, default=None, help='input low-quality test image folder')
    parser.add_argument('--folder_gt', type=str, default=None, help='input ground-truth test image folder')
    args = parser.parse_args([])

    # UI label -> internal task name.
    tasks = {
        'Real-World Image Super-Resolution': 'real_sr',
        'Grayscale Image Denoising': 'gray_dn',
        'Color Image Denoising': 'color_dn',
        'JPEG Compression Artifact Reduction': 'jpeg_car'
    }

    args.task = tasks[task_type]
    args.noise = noise
    args.jpeg = jpeg

    if args.task == 'real_sr':
        args.scale = 4
        args.model_path = model_zoo[args.task][4]
    elif args.task in ('gray_dn', 'color_dn'):
        args.model_path = model_zoo[args.task][noise]
    else:
        args.model_path = model_zoo[args.task][jpeg]
    return args


def _mirror_pad(tensor, dim, target):
    """Extend ``tensor`` along ``dim`` to length ``target`` by mirroring it.

    Flipped copies are appended until the tensor is long enough, then the
    result is trimmed. When a single flipped copy suffices this equals
    ``cat([t, flip(t)], dim)`` cut to ``target``; the loop also covers
    dimensions shorter than the requested padding (tiny images), which one
    flip alone cannot reach.

    Raises:
        ValueError: if ``tensor`` is empty along ``dim`` (nothing to mirror).
    """
    if tensor.size(dim) == 0:
        raise ValueError('cannot mirror-pad an empty dimension')
    while tensor.size(dim) < target:
        tensor = torch.cat([tensor, torch.flip(tensor, [dim])], dim)
    return tensor.narrow(dim, 0, target)


def _restore(model, img_lq, window_size, scale, device):
    """Run ``model`` on one image and return the result as a uint8 array.

    ``img_lq`` is indexed height x width x channels (``shape[2]`` is the
    channel count). For multi-channel input the channel order is reversed
    before inference and again afterwards (presumably cv2 BGR <-> model RGB;
    TODO confirm against ``get_image_pair``). The model output is clamped to
    [0, 1] and scaled to 0..255.
    """
    if img_lq.shape[2] != 1:
        img_lq = img_lq[:, :, [2, 1, 0]]
    tensor = torch.from_numpy(np.transpose(img_lq, (2, 0, 1))).float().unsqueeze(0).to(device)

    with torch.no_grad():
        _, _, h_old, w_old = tensor.size()
        # Grow H and W to the next multiple of window_size. Note this formula
        # always adds between 1 and window_size rows/cols, even when the size
        # is already a multiple.
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        tensor = _mirror_pad(tensor, 2, h_old + h_pad)
        tensor = _mirror_pad(tensor, 3, w_old + w_pad)
        output = model(tensor)
        # Drop the padded border, measured at output resolution.
        output = output[..., :h_old * scale, :w_old * scale]

    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    return (output * 255.0).round().astype(np.uint8)


def sentence_builder(image, task_type, noise, jpeg):
    """Restore one uploaded image with the SwinIR model chosen in the UI.

    Args:
        image: PIL image from the Gradio input component; ``None`` when the
            user submitted without uploading anything.
        task_type: Task label from the task dropdown.
        noise: Noise level for the denoising tasks. The Gradio dropdown
            delivers it as a string such as ``"15"``; ints are accepted too.
        jpeg: JPEG quality for artifact reduction, same string/int handling
            as ``noise``.

    Returns:
        pathlib.Path of ``out.png`` inside a fresh temporary directory,
        written with ``cv2.imwrite``.

    Raises:
        ValueError: if no image was supplied, or ``noise`` / ``jpeg`` is not
            an integer string.
        KeyError: for an unsupported task label or level.
    """
    if image is None:
        raise ValueError('Please upload an input image.')

    # The checkpoint table is keyed by int, but the dropdowns hand over
    # strings, so convert before any lookup.
    args = _build_args(task_type, int(noise), int(jpeg))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # A private scratch directory per request: nothing is shared with other
    # requests, and no leftovers can be picked up by the glob below.
    input_dir = tempfile.mkdtemp(prefix='input_cog_temp_')
    try:
        input_path = os.path.join(input_dir, guess_filename(image))
        # Normalise modes such as RGBA/P/CMYK/16-bit to plain 8-bit RGB, then
        # save losslessly so no compression artifacts are added before
        # restoration. The reader decodes by content, not by file extension.
        if image.mode not in ('RGB', 'L'):
            image = image.convert('RGB')
        image.save(input_path, "PNG")

        if args.task == 'real_sr':
            args.folder_lq = input_dir
        else:
            args.folder_gt = input_dir

        model = define_model(args)
        model.eval()
        model = model.to(device)

        folder, _save_dir, _border, window_size = setup(args)
        out_path = Path(tempfile.mkdtemp()) / "out.png"

        for path in sorted(glob.glob(os.path.join(folder, '*'))):
            _, img_lq, _ = get_image_pair(args, path)
            cv2.imwrite(str(out_path), _restore(model, img_lq, window_size, args.scale, device))
    finally:
        shutil.rmtree(input_dir, ignore_errors=True)

    return out_path
|
|
|
def guess_filename(obj: object) -> str:
    """Return a bare file name for ``obj``, falling back to ``"input"``.

    Uses the basename of ``obj.name`` when that attribute is a ``str``,
    ``bytes`` or ``os.PathLike`` path (bytes are decoded with the file-system
    encoding). Anything else yields ``"input"``, so callers always get a
    non-empty ``str`` they can join onto a directory. This covers objects with
    no ``name`` attribute (e.g. PIL images), ``None``, an integer file
    descriptor such as ``open(fd).name``, and a path ending in a separator.
    """
    name = getattr(obj, "name", None)
    if isinstance(name, (str, bytes, os.PathLike)):
        base = os.path.basename(os.fsdecode(name))
        if base:
            return base
    return "input"
|
|
|
def clean_folder(folder):
    """Empty ``folder`` while keeping the directory itself.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. Cleanup is best-effort: a failure on one entry is printed
    and the remaining entries are still processed.
    """
    entry_paths = (os.path.join(folder, entry) for entry in os.listdir(folder))
    for entry_path in entry_paths:
        try:
            unlinkable = os.path.isfile(entry_path) or os.path.islink(entry_path)
            if unlinkable:
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as err:  # best-effort: report and keep going
            print(f'Failed to delete {entry_path}. Reason: {err}')
|
|
|
|
|
# Text shown around the Gradio demo.
title = "Dmonin-SwinIR"

description = "Gradio demo for SwinIR. SwinIR achieves state-of-the-art performance on six tasks: image super-resolution (including classical, lightweight and real-world image super-resolution), image denoising (including grayscale and color image denoising) and JPEG compression artifact reduction. See the paper and GitHub repository linked below for detailed results. This demo supports real-world image super-resolution, grayscale and color image denoising, and JPEG compression artifact reduction. To use it, simply upload your image (or click the example to load it) and pick a task; the noise level applies only to denoising and the JPEG quality only to artifact reduction."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.10257' target='_blank'>SwinIR: Image Restoration Using Swin Transformer</a> | <a href='https://github.com/JingyunLiang/SwinIR' target='_blank'>Github Repo</a></p>"
|
|
|
# One full row per example: values for all four inputs, in input order.
examples = [['ETH_LR.png', 'Real-World Image Super-Resolution', '15', '40']]

gr.Interface(
    sentence_builder,
    [
        gr.Image(type="pil", label="Input"),
        gr.Dropdown(
            choices=["Real-World Image Super-Resolution", "Grayscale Image Denoising",
                     "Color Image Denoising", "JPEG Compression Artifact Reduction"],
            value="Real-World Image Super-Resolution",
            label="Task",
        ),
        # Values are strings; sentence_builder converts them to int.
        gr.Dropdown(choices=["15", "25", "50"], value="15",
                    label="Noise level (denoising only)"),
        gr.Dropdown(choices=["10", "20", "30", "40"], value="40",
                    label="JPEG quality (artifact reduction only)"),
    ],
    gr.Image(type="filepath", label="Output"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    examples=examples,
).launch()