File size: 3,716 Bytes
95e767b a54cd3e 95e767b 6b31eaf 95e767b fcb7410 95e767b 7c02a0d 95e767b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import numpy as np
import torch
from PIL import Image
from torch import nn
from nets.cyclegan import Generator
from utils.utils import (cvtColor, postprocess_output, preprocess_input,
resize_image, show_config)
class CYCLEGAN(object):
_defaults = {
#-----------------------------------------------#
# model_path指向logs文件夹下的权值文件
#-----------------------------------------------#
"model_path" : 'model_data/G_model_B2A_last_epoch_weights.pth',
#-----------------------------------------------#
# 输入图像大小的设置
#-----------------------------------------------#
"input_shape" : [112, 112],
#-------------------------------#
# 是否使用Cuda
# 没有GPU可以设置成False
#-------------------------------#
"cuda" : False,
}
#---------------------------------------------------#
# 初始化CYCLEGAN
#---------------------------------------------------#
def __init__(self, **kwargs):
self.__dict__.update(self._defaults)
for name, value in kwargs.items():
setattr(self, name, value)
self._defaults[name] = value
self.generate()
show_config(**self._defaults)
def generate(self):
#----------------------------------------#
# 创建GAN模型
#----------------------------------------#
self.net = Generator(upscale=1, img_size=tuple(self.input_shape),
window_size=7, img_range=1., depths=[3, 3, 3, 3],
embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=1, upsampler='1conv').eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.net.load_state_dict(torch.load(self.model_path, map_location=device))
self.net = self.net.eval()
print('{} model loaded.'.format(self.model_path))
if self.cuda:
self.net = nn.DataParallel(self.net)
self.net = self.net.cuda()
#---------------------------------------------------#
# 生成1x1的图片
#---------------------------------------------------#
def detect_image(self, image):
#---------------------------------------------------------#
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
#---------------------------------------------------------#
image = cvtColor(image)
#---------------------------------------------------------#
# 添加上batch_size维度
#---------------------------------------------------------#
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)), 0)
with torch.no_grad():
images = torch.from_numpy(image_data)
if self.cuda:
images = images.cuda()
#---------------------------------------------------#
# 图片传入网络进行预测
#---------------------------------------------------#
pr = self.net(images)[0]
#---------------------------------------------------#
# 转为numpy
#---------------------------------------------------#
pr = pr.permute(1, 2, 0).cpu().numpy()
image = postprocess_output(pr)
image = np.clip(image, 0, 255)
image = Image.fromarray(np.uint8(image))
return image
|