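# Residue from the Hugging Face file-viewer page, kept as comments so the
# file parses as Python:
#   committer: PelosiFilippo
#   commit 7d1fb4b - "Removed super resolution model load"
#   size: 3.49 kB (virus scan: no virus)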
import gradio as gr
from PIL import Image
import torch
import numpy as np
from models.network_swinir import SwinIR as net
# model load
# Key under which a checkpoint may nest its state dict; when the key is absent
# the loaded object itself is used as the state dict.
param_key_g = 'params_ema'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

fisheye_correction_model = net(
    upscale=4, in_chans=3, img_size=64, window_size=8,
    img_range=1., depths=[6] * 9, embed_dim=240,
    num_heads=[8] * 9,
    mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv')

# map_location='cpu' lets a checkpoint that was saved from a CUDA device load
# on a CPU-only host; the model is moved to `device` later, at inference time.
fisheye_correction_pretrained_model = torch.load(
    "model_zoo/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth",
    map_location='cpu')
fisheye_correction_model.load_state_dict(
    fisheye_correction_pretrained_model[param_key_g]
    if param_key_g in fisheye_correction_pretrained_model
    else fisheye_correction_pretrained_model,
    strict=True)
# Inference only: switch the module to evaluation mode.
fisheye_correction_model.eval()
def predict(input_img):
    """Run the SwinIR model on an image from the Gradio UI.

    Args:
        input_img: NumPy image array with values in 0-255, laid out as
            H x W x C in RGB order (what the Gradio image input delivers),
            or H x W for grayscale. ``None`` when no image was provided.

    Returns:
        A ``PIL.Image`` holding the model output cropped to 4x the input
        height and width, or ``None`` when ``input_img`` is ``None``.
    """
    if input_img is None:
        return None

    scale = 4        # matches the `upscale` the model is constructed with
    window_size = 8  # matches the model's `window_size`

    img_lq = input_img.astype(np.float32) / 255
    if img_lq.ndim == 2:
        # Grayscale input: the model is built with in_chans=3, so replicate
        # the single plane into three channels.
        img_lq = np.stack([img_lq] * 3, axis=-1)
    # HWC -> CHW. Gradio already supplies RGB, which is the order the network
    # is fed in; the former BGR<->RGB swap (a leftover from cv2-style loading)
    # reversed the channels the model saw and is therefore dropped.
    img_lq = np.transpose(img_lq, (2, 0, 1))
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW -> NCHW

    # inference
    model = fisheye_correction_model.to(device)
    with torch.no_grad():
        # Mirror-pad bottom/right so both sides are multiples of window_size
        # (an already aligned side still receives one extra window).
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        output = test(model, img_lq)
        # Drop the region that corresponds to the padding.
        output = output[..., :h_old * scale, :w_old * scale]

    # NCHW float in [0, 1] -> HWC uint8 (still RGB, as PIL expects)
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output, (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    return Image.fromarray(output)
def test(model, img_lq, tile=800, tile_overlap=32, sf=4):
    """Run ``model`` over ``img_lq`` tile by tile and average the overlaps.

    Args:
        model: Callable mapping an ``(b, c, t, t)`` tensor to an
            ``(b, c, t * sf, t * sf)`` tensor.
        img_lq: Input tensor of shape ``(b, c, h, w)``.
        tile: Maximum tile side; clamped to ``min(tile, h, w)``.
            NOTE(review): the caller pads ``h``/``w`` to multiples of the
            model's window size; a custom ``tile`` presumably has to be such
            a multiple as well - confirm against the model.
        tile_overlap: Requested overlap between neighbouring tiles, in input
            pixels. Ignored when the effective tile is not larger than it.
        sf: Upscaling factor of ``model``.

    Returns:
        Tensor of shape ``(b, c, h * sf, w * sf)`` in which every output pixel
        is the mean of all tile outputs covering it.
    """
    b, c, h, w = img_lq.size()
    tile = min(tile, h, w)
    # With a tile no larger than the overlap the stride would be zero or
    # negative: range() then raises or yields nothing, leaving most of the
    # image uncovered and the final division 0/0 (NaN). Fall back to
    # non-overlapping tiles in that case.
    overlap = tile_overlap if tile_overlap < tile else 0
    stride = tile - overlap

    # Regular grid plus a final tile flush with the bottom/right border.
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]

    E = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)  # summed outputs
    W = torch.zeros_like(E)                                 # coverage counts

    for h_idx in h_idx_list:
        rows = slice(h_idx * sf, (h_idx + tile) * sf)
        for w_idx in w_idx_list:
            cols = slice(w_idx * sf, (w_idx + tile) * sf)
            in_patch = img_lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
            out_patch = model(in_patch)
            E[..., rows, cols].add_(out_patch)
            W[..., rows, cols].add_(1)

    return E.div_(W)


# Despite its name this helper is not a unit test; opt out of pytest/nose
# collection in case the module gets star-imported into a test file.
test.__test__ = False
# Build the web UI around `predict` and start serving it: one image in, one
# image out. The output slot uses the output-side Image component (the
# original passed the input-side class there); `predict` returns a PIL image.
gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Image(),
    ],
    outputs=[
        gr.outputs.Image(type="pil"),
    ],
    title="SwinIR moon distortion",
    description="Description of the app",
    # Example inputs offered in the UI (paths relative to the app directory).
    examples=[
        "render0001_DC.png", "render1546_DC.png", "render1682_DC.png",
    ],
).launch()