Spaces:
Running
Running
import numpy as np | |
import torch | |
import sys | |
sys.path.append('models') | |
from SRFlow.code import imread, impad, load_model, t, rgb | |
from PIL import Image | |
from torchvision.transforms import PILToTensor | |
def return_SRFlow_result(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=None):
    """
    Apply Super-Resolution using the SRFlow model to an LR (low-resolution) PIL image.

    NOTE: the model is (re)loaded from ``conf_path`` on every call.

    Args:
    - lr: PIL Image. Images not in 'RGB' mode (e.g. RGBA or palette PNGs) are
      converted to RGB first.
    - conf_path (str): Configuration file path for the SRFlow model.
      Default is SRFlow_DF2K_4X.yml.
    - heat (float or None): Heat value forwarded to ``model.get_sr``. When None
      (the default), ``opt['heat']`` from the loaded configuration is used.

    Returns:
    - sr: PIL Image, ``opt['scale']`` times larger than ``lr`` in each dimension.
    """
    model, opt = load_model(conf_path)
    scale = opt['scale']
    if heat is None:
        heat = opt['heat']

    # NOTE(review): assumes the SRFlow network takes 3-channel input; normalise
    # other PIL modes (RGBA, L, P, ...) so the channel count matches.
    if isinstance(lr, Image.Image) and lr.mode != 'RGB':
        lr = lr.convert('RGB')
    # PILToTensor yields (C, H, W); the helpers below work on (H, W, C) arrays.
    lr = PILToTensor()(lr).permute(1, 2, 0).numpy()
    h, w, _ = lr.shape

    # Pad bottom/right so H and W are multiples of pad_factor; the padded
    # region is cropped off again after upscaling.
    pad_factor = 2
    lr = impad(lr,
               bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
               right=int(np.ceil(w / pad_factor) * pad_factor - w))

    lr_t = t(lr)
    sr_t = model.get_sr(lq=lr_t, heat=heat)
    sr = rgb(torch.clamp(sr_t, 0, 1))
    # Drop the rows/columns that were produced from the padding.
    sr = sr[:h * scale, :w * scale]
    return Image.fromarray(sr.astype('uint8'))
def test_srflow(lr, conf_path='models/SRFlow/code/confs/SRFlow_DF2K_4X.yml', heat=None):
    """
    Apply Super-Resolution using the SRFlow model to an LR (low-resolution) image.

    NOTE: the model is (re)loaded from ``conf_path`` on every call.

    Args:
    - lr: torch.Tensor of shape (1, C, H, W) or (C, H, W), or an array-like of
      shape (H, W, C). Values are handed to SRFlow's ``t`` helper unchanged;
      presumably they should be in the 0-255 range, like the uint8 output of
      ``PILToTensor`` -- TODO confirm.
    - conf_path (str): Configuration file path for the SRFlow model.
      Default is SRFlow_DF2K_4X.yml.
    - heat (float or None): Heat value forwarded to ``model.get_sr``. When None
      (the default), ``opt['heat']`` from the loaded configuration is used.

    Returns:
    - sr: torch.Tensor of shape (1, C, H * scale, W * scale), with the dtype
      produced by SRFlow's ``rgb`` helper.

    Raises:
    - ValueError: if a 4-D tensor with a batch size other than 1 is given.
    """
    model, opt = load_model(conf_path)
    scale = opt['scale']
    if heat is None:
        heat = opt['heat']

    # Bring the input to the (H, W, C) NumPy layout that impad/t work on.
    if isinstance(lr, torch.Tensor):
        # .numpy() requires a CPU tensor that does not track gradients.
        lr = lr.detach().cpu()
        if lr.dim() == 4:
            if lr.shape[0] != 1:
                raise ValueError(
                    f'test_srflow expects a single image, got shape {tuple(lr.shape)}')
            lr = lr[0]
        lr = lr.permute(1, 2, 0).numpy()  # (C, H, W) -> (H, W, C)
    else:
        lr = np.asarray(lr)
    h, w, _ = lr.shape

    # Pad bottom/right so H and W are multiples of pad_factor; the padded
    # region is cropped off again after upscaling.
    pad_factor = 2
    lr = impad(lr,
               bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
               right=int(np.ceil(w / pad_factor) * pad_factor - w))

    lr_t = t(lr)
    sr_t = model.get_sr(lq=lr_t, heat=heat)
    sr = rgb(torch.clamp(sr_t, 0, 1))
    # Drop the rows/columns that were produced from the padding.
    sr = sr[:h * scale, :w * scale]

    # rgb() returns a NumPy (H, W, C) array; wrap it as a (1, C, H, W) tensor.
    sr = torch.from_numpy(np.ascontiguousarray(sr))
    return sr.unsqueeze(0).permute(0, 3, 1, 2)
if __name__ == '__main__':
    # Demo: upscale images/demo.png through the tensor-based entry point.
    ip = Image.open('images/demo.png').convert('RGB')
    lr = PILToTensor()(ip).unsqueeze(0)  # (1, 3, H, W) tensor, as test_srflow expects
    print(tuple(lr.shape))
    sr = test_srflow(lr)
    # PIL-in / PIL-out alternative: sr = return_SRFlow_result(ip)
    print(tuple(sr.shape))