import numpy as np
import torch
import sys
sys.path.append('models')
from SRFlow.code import imread, impad, load_model, t, rgb
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import PILToTensor, ToPILImage

DEFAULT_CONF_PATH = 'models/SRFlow/code/confs/SRFlow_DF2K_4X.yml'


def return_SRFlow_result(lr, conf_path=DEFAULT_CONF_PATH, *, heat=None, model=None, opt=None):
    """
    Apply Super-Resolution using the SRFlow model to a single LR (low-resolution) image.

    Args:
    - lr: PIL Image to upscale.
    - conf_path (str): Configuration file path passed to ``load_model``. Only used
      when ``model``/``opt`` are not supplied. Default is SRFlow_DF2K_4X.yml.
    - heat (float, optional): Heat (sampling temperature) forwarded to
      ``model.get_sr``. When None, ``opt['heat']`` from the configuration is used.
    - model, opt (optional): An already-loaded ``(model, opt)`` pair as returned
      by ``load_model``. Supplying both avoids reloading the model on every call
      (e.g. when processing a batch); if either is None the model is loaded from
      ``conf_path``.

    Returns:
    - sr: PIL Image, cropped to ``(h * opt['scale'], w * opt['scale'])`` where
      ``h, w`` are the height/width of the unpadded input.
    """
    if model is None or opt is None:
        model, opt = load_model(conf_path)
    if heat is None:
        heat = opt['heat']
    scale = opt['scale']

    # PIL -> CHW uint8 tensor -> HWC numpy array (the layout impad/t are given here).
    lr = PILToTensor()(lr).permute(1, 2, 0).numpy()
    h, w, _ = lr.shape

    # Pad bottom/right so both spatial dims are multiples of pad_factor; the
    # padded region is cropped away again after upscaling below.
    pad_factor = 2
    lr = impad(lr,
               bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
               right=int(np.ceil(w / pad_factor) * pad_factor - w))

    lr_t = t(lr)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        sr_t = model.get_sr(lq=lr_t, heat=heat)
    sr = rgb(torch.clamp(sr_t, 0, 1))

    # Drop the upscaled padding so the output matches scale * original size.
    sr = sr[:h * scale, :w * scale]
    return Image.fromarray(sr.astype('uint8'))


def return_SRFlow_result_from_tensor(lr_tensor, conf_path=DEFAULT_CONF_PATH, *, heat=None):
    """
    Apply Super-Resolution using the SRFlow model to a batched BCHW tensor.

    The model is loaded once and reused for every image in the batch.

    Args:
    - lr_tensor: Batched BCHW tensor. Each element is converted with
      ``ToPILImage``; NOTE(review): float tensors are presumably expected in
      [0, 1] (ToPILImage scales floats by 255) — confirm with callers.
    - conf_path (str): Configuration file path passed to ``load_model``.
    - heat (float, optional): Heat forwarded to ``return_SRFlow_result``; None
      uses ``opt['heat']`` from the configuration.

    Returns:
    - sr_tensor: Batched BCHW uint8 tensor of super-resolved images.
    """
    model, opt = load_model(conf_path)
    to_pil = ToPILImage()
    to_tensor = PILToTensor()

    sr_list = [
        to_tensor(return_SRFlow_result(to_pil(lr_image), heat=heat, model=model, opt=opt))
        for lr_image in lr_tensor
    ]
    # Stacking CHW tensors is equivalent to concatenating unsqueezed 1xCHW ones.
    return torch.stack(sr_list, dim=0)


if __name__ == '__main__':
    # Convert to RGB so RGBA/paletted PNGs do not produce extra channels;
    # the context manager closes the underlying file handle.
    with Image.open('images/demo.png') as img:
        lr = img.convert('RGB')
    lr_tensor = PILToTensor()(lr).unsqueeze(0)
    print(lr_tensor.shape)

    sr = return_SRFlow_result_from_tensor(lr_tensor)
    print(sr)

    # Show SR image of the first one in the batch (CHW -> HWC for matplotlib).
    plt.imshow(sr[0].detach().cpu().permute(1, 2, 0).numpy())
    plt.show()