# Nicolas Burrus
# Fix image padding.
# dcf98e8
#!/usr/bin/env python
import gradio as gr
import torch
from torch import Tensor
from torchvision import transforms
import numpy as np
import sys
# Module-level handle to the TorchScript model; stays None until load_model() runs.
model = None


def load_model():
    """Load the TorchScript checkpoint from disk into the global ``model``.

    The checkpoint path is relative, so it is resolved against the current
    working directory of the process.
    """
    global model
    checkpoint_path = "v4_gated_unet-rn18-rn18_mse_bn32_5e-3_1e-5.pt"
    model = torch.jit.load(checkpoint_path)
def denormalize_and_clip_as_tensor(im: Tensor) -> Tensor:
    """Map ``im`` through ``x * 0.5 + 0.5`` and clamp the result to [0, 1].

    This is the inverse of a Normalize with mean 0.5 and std 0.5, followed by
    clipping of any value that falls outside the valid range.
    """
    rescaled = im * 0.5 + 0.5
    return torch.clip(rescaled, min=0.0, max=1.0)
def denormalize_and_clip_as_numpy(im: Tensor) -> np.ndarray:
    """Denormalize and clip ``im``, returning a contiguous NumPy array.

    A leading dimension of size 1 is squeezed away first, then the first of
    the three remaining axes is moved to the end (presumably C,H,W -> H,W,C;
    confirm against the model output).
    """
    unbatched = im.squeeze(0)
    clipped = denormalize_and_clip_as_tensor(unbatched)
    axis_moved = clipped.permute(1, 2, 0)
    # detach + cpu so the conversion works regardless of grad state or device.
    as_array = axis_moved.detach().cpu().numpy()
    return np.ascontiguousarray(as_array)
def pad_width(size: int, multiple: int) -> int:
    """Return how much must be added to ``size`` to reach a multiple of ``multiple``.

    The result is 0 when ``size`` is already an exact multiple.
    """
    remainder = size % multiple
    if remainder:
        return multiple - remainder
    return 0
# Pad an image batch so its height and width are multiples of `multiple` (64 by default).
def pad_image(im, multiple=64):
    """Pad the right and bottom borders of ``im`` up to a multiple of ``multiple``.

    Args:
        im: tensor indexed as B,C,H,W (axes 2 and 3 are read as rows and
            columns).
        multiple: the value both spatial sizes must be divisible by.

    Returns:
        ``im`` itself when no padding is required, otherwise a padded tensor.
        Padding is added only on the right and bottom, so the original
        content stays at the top-left and can be recovered by cropping.
    """
    # B,C,H,W
    rows = im.shape[2]
    cols = im.shape[3]
    rows_to_pad = pad_width(rows, multiple)
    cols_to_pad = pad_width(cols, multiple)
    if rows_to_pad == 0 and cols_to_pad == 0:
        return im
    # Reflection padding requires the pad to be strictly smaller than the
    # dimension being padded; small images (e.g. 20 rows needing 44 more)
    # would otherwise raise. Fall back to repeating the border pixels then.
    can_reflect = rows_to_pad < rows and cols_to_pad < cols
    padding_mode = 'reflect' if can_reflect else 'edge'
    # torchvision padding order is (left, top, right, bottom).
    return transforms.Pad(padding=(0, 0, cols_to_pad, rows_to_pad), padding_mode=padding_mode)(im)
def undo_antialiasing(im):
    """Run the global ``model`` on an image array and return a uint8 result.

    Args:
        im: NumPy image indexed as (rows, cols, channels). Values are divided
            by 255, so presumably 8-bit RGB as delivered by the Gradio image
            input - confirm against the interface configuration.

    Returns:
        The model's ``rgb`` output as a uint8 array, cropped back to the
        input's rows and columns.
    """
    # Channels-last array -> 1 x C x H x W float tensor scaled by 1/255.
    batch = torch.from_numpy(im).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    batch = normalize(batch)

    # Remember the input size so the padding added below can be cropped off.
    height, width = im.shape[0], im.shape[1]
    batch = pad_image(batch)

    with torch.no_grad():
        prediction = model(batch)

    restored = denormalize_and_clip_as_numpy(prediction.rgb)
    # Scale by 255.99 before truncating so that 1.0 maps to 255.
    restored = (restored * 255.99).astype(np.uint8)
    return restored[:height, :width, :]
# Load the model once at start-up, before the interface can serve requests.
load_model()

# Example inputs offered by the interface (one image path per example row).
example_images = [
    ['examples/opencv.png'],
    ['examples/matplotlib.png'],
    ['examples/coco_beach.png'],
    ['examples/plot-bowling.png'],
]
# FIXME: add OpenCV-background
# FIXME: add some arXiv test image

iface = gr.Interface(
    fn=undo_antialiasing,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Image(),
    examples=example_images,
)

iface.launch()