# GCycleGAN / cyclegan.py
# Egrt's picture
# Commit message: "fix" (修复)
# Commit: 1db4bf6
import numpy as np
import torch
from PIL import Image
from torch import nn
from nets.cyclegan import Generator
from utils.utils import (cvtColor, postprocess_output, preprocess_input,
resize_image, show_config)
class CYCLEGAN(object):
    """Inference wrapper around a trained CycleGAN ``Generator``.

    Settings come from ``_defaults`` and may be overridden per instance with
    keyword arguments. Every setting is exposed as an instance attribute
    (``self.model_path``, ``self.input_shape``, ...).
    """

    _defaults = {
        # Path of the generator weight file (usually produced under the
        # logs folder during training).
        "model_path": 'model_data/G_model_B2A_last_epoch_weights.pth',
        # Network input size. Index 0 is paired with ``nh`` and index 1 with
        # ``nw`` when the letterbox padding is cropped away.
        "input_shape": [112, 112],
        # Whether to resize without distortion (letterbox padding).
        "letterbox_image": True,
        # Whether to use CUDA; set to False when no GPU is available.
        "cuda": False,
    }

    @classmethod
    def _merge_config(cls, overrides):
        """Return a new dict of the class defaults updated with *overrides*.

        The class-level ``_defaults`` dict is left untouched so that settings
        passed to one instance never leak into instances created later.
        """
        config = dict(cls._defaults)
        config.update(overrides)
        return config

    @staticmethod
    def _crop_letterbox(pr, input_shape, nw, nh):
        """Cut the centred ``nh`` x ``nw`` region out of a padded H x W x C array."""
        top = int((input_shape[0] - nh) // 2)
        bottom = int((input_shape[0] - nh) // 2 + nh)
        left = int((input_shape[1] - nw) // 2)
        right = int((input_shape[1] - nw) // 2 + nw)
        return pr[top:bottom, left:right]

    def __init__(self, **kwargs):
        """Initialise the wrapper, load the model and print the configuration.

        Args:
            **kwargs: Overrides for entries of ``_defaults``. Any additional
                names are stored as instance attributes as well.
        """
        config = self._merge_config(kwargs)
        self.__dict__.update(config)
        self.generate()
        # Show the configuration of this instance, not the shared defaults.
        show_config(**config)

    def generate(self):
        """Build the generator, load its weights and optionally move it to CUDA."""
        self.net = Generator(
            upscale=1, img_size=tuple(self.input_shape),
            window_size=7, img_range=1., depths=[3, 3, 3, 3],
            embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=1, upsampler='1conv',
        )
        # The freshly built network lives on the CPU, so stage the checkpoint
        # on the CPU too. This avoids initialising CUDA (and holding a second
        # copy of the weights on the GPU) when ``cuda`` is disabled.
        state_dict = torch.load(self.model_path, map_location=torch.device('cpu'))
        self.net.load_state_dict(state_dict)
        self.net = self.net.eval()
        print('{} model loaded.'.format(self.model_path))

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    def detect_image(self, image):
        """Run the generator on one image and return the translated image.

        Args:
            image: Input image. It is passed through ``cvtColor`` first; the
                original comments state that this converts any non-RGB image
                (e.g. grayscale) to RGB so prediction does not fail.

        Returns:
            A ``PIL.Image.Image`` built from the network output after
            ``postprocess_output``, clipped to [0, 255] and cast to ``uint8``.
        """
        image = cvtColor(image)

        # Resize to the network input size. With letterboxing enabled the
        # image is padded rather than distorted.
        # NOTE(review): ``nw``/``nh`` are presumably the size of the un-padded
        # content, and ``None`` when no padding was applied - confirm against
        # ``resize_image``.
        target_size = (self.input_shape[1], self.input_shape[0])
        image_data, nw, nh = resize_image(image, target_size, self.letterbox_image)

        # HWC -> CHW, then add the batch dimension.
        normalized = preprocess_input(np.array(image_data, dtype='float32'))
        image_data = np.expand_dims(np.transpose(normalized, (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            # Forward pass; keep the first (only) item of the batch.
            pr = self.net(images)[0]
            # CHW tensor -> HWC numpy array.
            pr = pr.permute(1, 2, 0).cpu().numpy()

            # Remove the letterbox padding, if any.
            if nw is not None:
                pr = self._crop_letterbox(pr, self.input_shape, nw, nh)

            image = postprocess_output(pr)
            image = np.clip(image, 0, 255)
            image = Image.fromarray(np.uint8(image))

        return image