# GCycleGAN / cyclegan.py
# Egrt's picture
# 修复 (English: "fix")
# fcb7410
# raw
# history blame
# 3.72 kB
import numpy as np
import torch
from PIL import Image
from torch import nn
from nets.cyclegan import Generator
from utils.utils import (cvtColor, postprocess_output, preprocess_input,
resize_image, show_config)
class CYCLEGAN(object):
    """Inference wrapper around a trained CycleGAN generator.

    The effective configuration is ``_defaults`` overridden by the keyword
    arguments given to the constructor. Every resolved key is exposed as an
    instance attribute (``self.model_path``, ``self.input_shape``,
    ``self.cuda``, plus any extra keyword passed in).
    """

    _defaults = {
        #-----------------------------------------------#
        # model_path: path to the generator weights file,
        # e.g. a checkpoint saved in the logs folder.
        #-----------------------------------------------#
        "model_path"  : 'model_data/G_model_B2A_last_epoch_weights.pth',
        #-----------------------------------------------#
        # input_shape: input image size; passed to the
        # Generator as its img_size.
        #-----------------------------------------------#
        "input_shape" : [112, 112],
        #-------------------------------#
        # cuda: whether to run on CUDA.
        # Set to False if there is no GPU.
        #-------------------------------#
        "cuda"        : False,
    }

    @classmethod
    def _build_config(cls, overrides):
        """Return a new dict holding the class defaults updated with ``overrides``.

        The class-level ``_defaults`` dict is never modified, so overrides
        passed to one instance cannot leak into instances created later.
        """
        config = dict(cls._defaults)
        config.update(overrides)
        return config

    #---------------------------------------------------#
    # Initialise CYCLEGAN
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Resolve the configuration, build the network and print the config.

        Args:
            **kwargs: overrides for keys of ``_defaults``; every keyword is
                also set as an instance attribute.
        """
        # Per-instance copy: the original code wrote overrides back into the
        # shared class dict, changing the defaults for all later instances.
        config = self._build_config(kwargs)
        for name, value in config.items():
            setattr(self, name, value)
        self.generate()
        show_config(**config)

    def generate(self):
        """Build the generator, load its weights and optionally move it to CUDA.

        Sets ``self.net``: the Generator in eval mode, wrapped in
        ``nn.DataParallel`` and moved to the GPU when ``self.cuda`` is true.
        """
        #----------------------------------------#
        # Build the GAN generator
        #----------------------------------------#
        self.net = Generator(upscale=1, img_size=tuple(self.input_shape),
                             window_size=7, img_range=1., depths=[3, 3, 3, 3],
                             embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=1,
                             upsampler='1conv')
        # The freshly built network is on the CPU, so load the checkpoint
        # there. This does not depend on GPU availability, and it avoids
        # staging weights on a GPU that will not be used when cuda=False.
        state_dict = torch.load(self.model_path, map_location='cpu')
        self.net.load_state_dict(state_dict)
        self.net = self.net.eval()
        print('{} model loaded.'.format(self.model_path))
        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    #---------------------------------------------------#
    # Translate one image with the generator
    #---------------------------------------------------#
    def detect_image(self, image):
        """Run the generator on one image and return the result as a PIL image.

        Args:
            image: input image; converted to RGB by ``cvtColor`` first
                (presumably a PIL image -- confirm against callers).

        Returns:
            ``PIL.Image.Image`` built from the network output after
            ``postprocess_output``, clipped to [0, 255] and cast to uint8.
        """
        #---------------------------------------------------------#
        # Convert to RGB so grayscale inputs do not fail at predict
        # time. Only RGB prediction is supported; every other image
        # type is converted to RGB.
        #---------------------------------------------------------#
        image = cvtColor(image)
        # NOTE(review): the image is fed at its own resolution and is not
        # resized to input_shape (resize_image is imported but unused here);
        # confirm the Generator accepts arbitrary input sizes.
        #---------------------------------------------------------#
        # HWC -> CHW, then add the batch dimension -> (1, C, H, W).
        # preprocess_input presumably normalises pixel values -- see
        # utils.utils.
        #---------------------------------------------------------#
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------#
            # Forward pass; [0] drops the batch dimension
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            # CHW -> HWC and convert to numpy
            #---------------------------------------------------#
            pr = pr.permute(1, 2, 0).cpu().numpy()

        image = postprocess_output(pr)
        image = np.clip(image, 0, 255)
        image = Image.fromarray(np.uint8(image))
        return image