# blure_remover / predict.py
# itishalogicgo's picture
# Initial commit for HF Space
# 9279e0d
import torch
import numpy as np
import cv2
import tempfile
import matplotlib.pyplot as plt
from cog import BasePredictor, Path, Input, BaseModel
from basicsr.models import create_model
from basicsr.utils import img2tensor as _img2tensor, tensor2img, imwrite
from basicsr.utils.options import parse
class Predictor(BasePredictor):
    """Cog predictor that serves three image-restoration tasks.

    One model per task is created in ``setup`` and stored in ``self.models``
    keyed by the task name offered to the user in ``predict``.
    """

    def setup(self):
        """Parse the option file of every task and build its model."""
        option_files = {
            "Image Denoising": "options/test/SIDD/NAFNet-width64.yml",
            "Image Debluring": "options/test/GoPro/NAFNet-width64.yml",
            "Stereo Image Super-Resolution": "options/test/NAFSSR/NAFSSR-L_4x.yml",
        }

        # Parse every option file first, then build the models, so the order
        # of parse/create calls matches the straight-line original.
        parsed = {}
        for task, opt_path in option_files.items():
            opt = parse(opt_path, is_train=False)
            opt["dist"] = False
            parsed[task] = opt

        self.models = {task: create_model(opt) for task, opt in parsed.items()}

    def predict(
        self,
        task_type: str = Input(
            choices=[
                "Image Denoising",
                "Image Debluring",
                "Stereo Image Super-Resolution",
            ],
            default="Image Debluring",
            description="Choose task type.",
        ),
        image: Path = Input(
            description="Input image. Stereo Image Super-Resolution, upload the left image here.",
        ),
        image_r: Path = Input(
            default=None,
            description="Right Input image for Stereo Image Super-Resolution. Optional, only valid for Stereo"
            " Image Super-Resolution task.",
        ),
    ) -> Path:
        """Run the selected task on the input image(s) and return the result file.

        Args:
            task_type: Key into ``self.models`` selecting which model to run.
            image: Input image; the left view for the stereo task.
            image_r: Right view; required for the stereo task and not read
                by the other tasks.

        Returns:
            Path to ``output.png`` inside a newly created temporary directory.

        Raises:
            ValueError: If the stereo task is selected without ``image_r``.
        """
        model = self.models[task_type]
        is_stereo = task_type == "Stereo Image Super-Resolution"

        # Validate with an explicit raise (an ``assert`` disappears under
        # ``python -O``) and do it before any temporary directory is created.
        if is_stereo and image_r is None:
            raise ValueError(
                "Please provide both left and right input image for "
                "Stereo Image Super-Resolution task."
            )

        # Exactly one temporary directory per request.
        out_path = Path(tempfile.mkdtemp()) / "output.png"

        if is_stereo:
            inp_l = img2tensor(imread(str(image)))
            inp_r = img2tensor(imread(str(image_r)))
            stereo_image_inference(model, inp_l, inp_r, str(out_path))
        else:
            inp = img2tensor(imread(str(image)))
            single_image_inference(model, inp, str(out_path))

        return out_path
def imread(img_path):
    """Load an image with OpenCV and convert its channel order from BGR to RGB.

    Args:
        img_path: Filesystem path of the image, as a string.

    Returns:
        The decoded image after ``cv2.COLOR_BGR2RGB`` conversion.

    Raises:
        ValueError: If OpenCV could not read or decode the file.
    """
    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than inside cvtColor.
    if img is None:
        raise ValueError(f"Could not read image file: {img_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def img2tensor(img, bgr2rgb=False, float32=True):
    """Scale an image array by 1/255 and hand it to basicsr's ``img2tensor``.

    Args:
        img: Image array; presumably 8-bit values in [0, 255] - TODO confirm.
        bgr2rgb: Forwarded unchanged to basicsr's ``img2tensor``.
        float32: Forwarded unchanged to basicsr's ``img2tensor``.

    Returns:
        Whatever basicsr's ``img2tensor`` returns for the scaled array.
    """
    as_float = img.astype(np.float32)
    scaled = as_float / 255.0
    return _img2tensor(scaled, bgr2rgb=bgr2rgb, float32=float32)
def single_image_inference(model, img, save_path):
    """Run ``model`` on one image tensor and write the result to ``save_path``.

    Args:
        model: Object exposing ``feed_data``, ``test``, ``get_current_visuals``,
            ``grids``/``grids_inverse`` and an ``opt`` mapping with a "val" entry.
        img: Input tensor; a leading dimension is added before feeding it.
        save_path: Destination path passed to ``imwrite``.
    """

    def grids_enabled():
        # Looked up on each call, exactly as the option is consulted both
        # before and after the forward pass.
        return model.opt["val"].get("grids", False)

    batch = img.unsqueeze(dim=0)
    model.feed_data(data={"lq": batch})

    if grids_enabled():
        model.grids()

    model.test()

    if grids_enabled():
        model.grids_inverse()

    result = model.get_current_visuals()["result"]
    restored = tensor2img([result])
    imwrite(restored, save_path)
def stereo_image_inference(model, img_l, img_r, out_path):
    """Run the stereo model on a left/right pair and save both outputs as one figure.

    The two inputs are concatenated along dim 0, a leading dimension is added,
    and the model result is split at channel index 3 into the left and right
    outputs. Both are drawn one above the other in a matplotlib figure that is
    saved to ``out_path``.

    Args:
        model: Object exposing ``feed_data``, ``test``, ``get_current_visuals``,
            ``grids``/``grids_inverse`` and an ``opt`` mapping with a "val" entry.
        img_l: Left-view tensor (presumably CHW with 3 channels - TODO confirm).
        img_r: Right-view tensor, same layout as ``img_l``.
        out_path: Destination file for the rendered figure.
    """
    stacked = torch.cat([img_l, img_r], dim=0)
    model.feed_data(data={"lq": stacked.unsqueeze(dim=0)})

    if model.opt["val"].get("grids", False):
        model.grids()

    model.test()

    if model.opt["val"].get("grids", False):
        model.grids_inverse()

    result = model.get_current_visuals()["result"]
    # First three channels are the left output, the remainder the right one.
    img_L, img_R = tensor2img([result[:, :3], result[:, 3:]], rgb2bgr=False)

    # save_stereo_image
    h, w = img_L.shape[:2]
    # Clamp to 1 inch: integer division yields 0 for outputs under 40 px,
    # which matplotlib rejects as a figure size.
    fig = plt.figure(figsize=(max(1, w // 40), max(1, h // 40)))
    try:
        panels = (
            ("NAFSSR output (Left)", img_L),
            ("NAFSSR output (Right)", img_R),
        )
        for row, (title, view) in enumerate(panels, start=1):
            ax = fig.add_subplot(2, 1, row)
            ax.set_title(title, fontsize=14)
            ax.axis("off")
            ax.imshow(view)

        fig.subplots_adjust(hspace=0.08)
        fig.savefig(str(out_path), bbox_inches="tight", dpi=600)
    finally:
        # pyplot keeps every figure alive until it is closed; without this a
        # long-running predictor accumulates one figure per request.
        plt.close(fig)