# Author: Nguyễn Bá Thiêm
# Commit d09509f: "Refactor function to improve performance and readability"
import numpy as np
import torch
import sys
sys.path.append('models')
from SRFlow.code import imread, impad, load_model, t, rgb
from PIL import Image
from torchvision.transforms import PILToTensor
def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=None):
    """
    Apply Super-Resolution using SRFlow model to the input LR (low-resolution) image.

    Args:
    - lr: PIL Image. It is converted to RGB first, so RGBA / palette / grayscale
      images are accepted.
    - conf_path (str): Configuration file path for the SRFlow model.
      Default is SRFlow_DF2K_4X.yml.
    - heat (float or None): Heat value passed to ``model.get_sr``. When None
      (the default), the ``heat`` entry of the loaded config is used.
    Returns:
    - sr: PIL Image whose height and width are ``opt['scale']`` times those of ``lr``.
    """
    model, opt = load_model(conf_path)
    if heat is None:
        # Previously the argument was always overwritten by the config value;
        # keep that as the default so existing callers see identical output.
        heat = opt['heat']
    scale = opt['scale']

    # PILToTensor yields (C, H, W); the SRFlow helpers below work on (H, W, C) arrays.
    lr = PILToTensor()(lr.convert('RGB')).permute(1, 2, 0).numpy()
    h, w = lr.shape[:2]

    # Pad bottom/right so height and width become multiples of pad_factor
    # (presumably a model requirement — confirm); the padded region is cropped
    # from the output below. (-n) % k is the distance up to the next multiple of k.
    pad_factor = 2
    lr = impad(lr, bottom=-h % pad_factor, right=-w % pad_factor)

    sr_t = model.get_sr(lq=t(lr), heat=heat)
    sr = rgb(torch.clamp(sr_t, 0, 1))
    # Discard the output rows/columns that came from the padding.
    sr = sr[:h * scale, :w * scale]
    return Image.fromarray(sr.astype('uint8'))
def test_srflow(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=None):
    """
    Apply Super-Resolution using SRFlow model to a low-resolution image tensor.

    Args:
    - lr: tensor (or array) of shape (1, C, H, W) or (C, H, W) holding 0-255
      pixel values, e.g. ``PILToTensor()(img).unsqueeze(0)``. Assumed to be
      RGB (C == 3) — TODO confirm against callers.
    - conf_path (str): Configuration file path for the SRFlow model.
      Default is SRFlow_DF2K_4X.yml.
    - heat (float or None): Heat value passed to ``model.get_sr``. When None
      (the default), the ``heat`` entry of the loaded config is used.
    Returns:
    - sr: uint8 tensor of shape (1, C, H * scale, W * scale), where ``scale``
      is ``opt['scale']``.
    Raises:
    - ValueError: if a 4-D input has a batch size other than 1.
    """
    model, opt = load_model(conf_path)
    if heat is None:
        # The argument used to be overwritten unconditionally; the config value
        # remains the default.
        heat = opt['heat']
    scale = opt['scale']

    lr = torch.as_tensor(lr)
    if lr.dim() == 4:
        if lr.shape[0] != 1:
            raise ValueError(f"expected a batch of one image, got batch size {lr.shape[0]}")
        lr = lr[0]
    # (C, H, W) -> (H, W, C): the SRFlow helpers operate on channel-last arrays.
    # The previous permute(0, 1, 2) was a no-op and left the layout channel-first.
    lr = lr.permute(1, 2, 0).detach().cpu().numpy()
    h, w = lr.shape[:2]

    # Pad bottom/right up to the next multiple of pad_factor; cropped off below.
    pad_factor = 2
    lr = impad(lr, bottom=-h % pad_factor, right=-w % pad_factor)

    sr_t = model.get_sr(lq=t(lr), heat=heat)
    sr = rgb(torch.clamp(sr_t, 0, 1))
    # Discard the output rows/columns that came from the padding.
    sr = sr[:h * scale, :w * scale]
    # rgb() returns a channel-last NumPy array (NumPy arrays have no .unsqueeze),
    # so convert to a tensor before restoring the batched (1, C, H, W) layout.
    sr = torch.from_numpy(np.ascontiguousarray(sr))
    return sr.unsqueeze(0).permute(0, 3, 1, 2)
if __name__ == '__main__':
    ip = Image.open('images/demo.png')
    # Batched (1, C, H, W) uint8 tensor; convert('RGB') drops any alpha channel.
    lr = PILToTensor()(ip.convert('RGB')).unsqueeze(0)
    print(lr.shape)
    sr = test_srflow(lr)
    # sr is a tensor: .shape gives its dimensions (tensor.size is a method).
    print(sr.shape)