import numpy as np
import torch
from PIL import Image
from torch import nn

from nets.cyclegan import Generator
from utils.utils import (cvtColor, postprocess_output, preprocess_input,
                         resize_image, show_config)


class CYCLEGAN(object):
    """Inference wrapper around a CycleGAN-style ``Generator`` checkpoint.

    Configuration is taken from ``_defaults``, overridden per instance by the
    keyword arguments given to the constructor. ``input_shape`` has no default
    here and must be passed as a keyword argument, because ``generate`` reads it.
    """

    _defaults = {
        #-----------------------------------------------#
        #   model_path points at the weights file under the logs folder
        #-----------------------------------------------#
        "model_path"        : 'model_data/G_model_B2A_last_epoch_weights.pth',
        #-------------------------------#
        #   Whether to use CUDA.
        #   Set to False when no GPU is available.
        #-------------------------------#
        "cuda"              : True,
    }

    #---------------------------------------------------#
    #   Initialise CYCLEGAN
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Build this instance's config from ``_defaults`` + ``kwargs`` and load the model.

        Keyword arguments override the class defaults for this instance only;
        the merged config is then printed via ``show_config``.
        """
        # Merge into a per-instance copy: writing into the class-level
        # ``_defaults`` dict would leak this instance's overrides into every
        # CYCLEGAN created afterwards.
        config = dict(self._defaults)
        config.update(kwargs)
        self.__dict__.update(config)
        self.generate()

        show_config(**config)

    def generate(self):
        """Create the generator network, load its weights and move it to the device.

        Reads ``self.input_shape``, ``self.model_path`` and ``self.cuda``.
        """
        #----------------------------------------#
        #   Build the GAN generator
        #----------------------------------------#
        self.net = Generator(upscale=1, img_size=tuple(self.input_shape), window_size=7, img_range=1.,
                             depths=[3, 3, 3, 3], embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=1,
                             upsampler='1conv')

        # Map the checkpoint onto the GPU only when CUDA is both requested via
        # the config and actually available; otherwise keep it on the CPU.
        use_gpu = self.cuda and torch.cuda.is_available()
        device = torch.device('cuda' if use_gpu else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model loaded.'.format(self.model_path))

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    #---------------------------------------------------#
    #   Translate a single image
    #---------------------------------------------------#
    def detect_image(self, image):
        """Run the generator on one image and return the translated image as a PIL Image.

        The input is converted to RGB with ``cvtColor`` first. The image is fed
        to the network at its own size; no resizing happens here.
        """
        #---------------------------------------------------------#
        #   Convert to RGB so grayscale images don't error at prediction time.
        #   Only RGB prediction is supported; every other mode is converted to RGB.
        #---------------------------------------------------------#
        image = cvtColor(image)
        # NOTE(review): resize_image is imported but never applied; if the
        # Generator requires input_shape-sized input, resize here -- confirm.
        #---------------------------------------------------------#
        #   Apply preprocess_input, go HWC -> CHW, and add the batch dimension.
        #---------------------------------------------------------#
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------#
            #   Forward pass; keep the first (only) item of the batch
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   CHW tensor -> HWC numpy array on the CPU
            #---------------------------------------------------#
            pr = pr.permute(1, 2, 0).cpu().numpy()

        image = postprocess_output(pr)
        image = np.clip(image, 0, 255)
        image = Image.fromarray(np.uint8(image))
        return image