import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image

from nets.SwinIR import Generator
from utils.utils import cvtColor, preprocess_input


class ESRGAN(object):
    """Inference wrapper that loads a SwinIR ``Generator`` and upscales images.

    Configuration comes from ``_defaults`` and may be overridden per instance
    through keyword arguments to ``__init__``. The network is built and its
    weights are loaded as part of construction.
    """

    #-----------------------------------------#
    #   Remember to update model_path
    #-----------------------------------------#
    _defaults = {
        #-----------------------------------------------#
        #   model_path points to a weights file
        #   (e.g. one saved under the logs folder)
        #-----------------------------------------------#
        "model_path"        : 'model_data/Generator_SwinIR.pth',
        #-----------------------------------------------#
        #   Upsampling factor; must match the value
        #   used during training
        #-----------------------------------------------#
        "scale_factor"      : 4,
        #-----------------------------------------------#
        #   hr_shape; passed to the generator as img_size
        #-----------------------------------------------#
        "hr_shape"          : [128, 224],
        #-------------------------------#
        #   Whether to use CUDA.
        #   Set to False when no GPU is available.
        #-------------------------------#
        "cuda"              : False,
    }

    #---------------------------------------------------#
    #   Initialize the ESRGAN wrapper
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply defaults, then keyword overrides, then build the network.

        Args:
            **kwargs: Attribute overrides; each key/value pair is set on the
                instance after the defaults (e.g. ``model_path``, ``cuda``).
        """
        # Copy list-valued defaults (hr_shape) so that mutating the attribute
        # on one instance cannot leak into the class-level defaults or into
        # other instances.
        self.__dict__.update({
            name: list(value) if isinstance(value, list) else value
            for name, value in self._defaults.items()
        })
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

    def generate(self):
        """Build the generator, load its weights and switch it to eval mode.

        When ``self.cuda`` is true the network is additionally wrapped in
        ``torch.nn.DataParallel``, cudnn benchmarking is enabled and the
        network is moved to the GPU.
        """
        self.net = Generator(
            upscale=self.scale_factor,
            img_size=tuple(self.hr_shape),
            window_size=8,
            img_range=1.,
            depths=[3, 3, 3, 3],
            embed_dim=60,
            num_heads=[3, 3, 3, 3],
            mlp_ratio=2,
            upsampler='pixelshuffledirect',
        )

        # Honour the `cuda` setting when choosing where to deserialize the
        # checkpoint: with cuda=False the weights must not be staged on the
        # GPU just because one happens to be present.
        use_gpu = self.cuda and torch.cuda.is_available()
        device = torch.device('cuda' if use_gpu else 'cpu')

        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True
            self.net = self.net.cuda()

    def generate_1x1_image(self, image) -> Image.Image:
        """Run the generator on a single image and return the result.

        Args:
            image: Input image (presumably a ``PIL.Image`` - confirm against
                callers). It is passed through ``cvtColor`` before inference.

        Returns:
            A ``PIL.Image`` built from the network output after mapping it
            back to the 0-255 range and casting to uint8.
        """
        #---------------------------------------------------------#
        #   Convert the image to RGB here so that grayscale inputs
        #   do not fail at prediction time. Only RGB prediction is
        #   supported; every other image type is converted to RGB.
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------------#
        #   Normalize, move channels first (HWC -> CHW) and add
        #   the batch dimension.
        #---------------------------------------------------------#
        normalized = preprocess_input(
            np.array(image, dtype=np.float32), [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        )
        image_data = np.expand_dims(np.transpose(normalized, [2, 0, 1]), 0)

        with torch.no_grad():
            image_data = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                image_data = image_data.cuda()

            #---------------------------------------------------------#
            #   Feed the image through the network; take the first
            #   (only) item of the batch.
            #---------------------------------------------------------#
            hr_image = self.net(image_data)[0]
            #---------------------------------------------------------#
            #   Undo the 0.5/0.5 normalization and go back to an
            #   HWC array in the 0-255 range.
            #---------------------------------------------------------#
            hr_image = hr_image.detach().cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5
            hr_image = np.clip(hr_image * 255, 0, 255)

        hr_image = Image.fromarray(np.uint8(hr_image))
        return hr_image