File size: 3,716 Bytes
95e767b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a54cd3e
 
 
 
95e767b
 
 
 
6b31eaf
95e767b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcb7410
95e767b
 
 
 
 
 
 
 
 
 
 
 
 
7c02a0d
95e767b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
import torch
from PIL import Image
from torch import nn

from nets.cyclegan import Generator
from utils.utils import (cvtColor, postprocess_output, preprocess_input,
                         resize_image, show_config)


class CYCLEGAN(object):
    """Inference wrapper that loads a ``Generator`` and runs it on single images.

    Configuration comes from ``_defaults`` and can be overridden per instance
    through keyword arguments to the constructor. Every configuration entry is
    exposed as an instance attribute (``self.model_path``, ``self.input_shape``,
    ``self.cuda``).
    """

    _defaults = {
        #-----------------------------------------------#
        #   model_path points to the weight file
        #   (e.g. one saved under the logs folder)
        #-----------------------------------------------#
        "model_path"        : 'model_data/G_model_B2A_last_epoch_weights.pth',
        #-----------------------------------------------#
        #   Input image size used to build the network
        #-----------------------------------------------#
        "input_shape"       : [112, 112],
        #-------------------------------#
        #   Whether to use CUDA.
        #   Set to False when no GPU is available.
        #-------------------------------#
        "cuda"              : False,
    }

    #---------------------------------------------------#
    #   Initialise CYCLEGAN
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Resolve the configuration, build the network and print the config.

        Args:
            **kwargs: Overrides for the entries of ``_defaults``. Keys that are
                not in ``_defaults`` are also stored as instance attributes and
                forwarded to ``show_config``.
        """
        # Merge overrides into a per-instance dict. Writing them back into the
        # class-level ``_defaults`` would leak one instance's overrides into
        # every instance created afterwards.
        config = dict(self._defaults)
        config.update(kwargs)

        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

        show_config(**config)

    def generate(self):
        """Create the generator, load its weights and switch it to eval mode.

        Reads ``self.input_shape``, ``self.model_path`` and ``self.cuda``;
        stores the ready-to-use network in ``self.net``.
        """
        #----------------------------------------#
        #   Create the GAN generator
        #----------------------------------------#
        self.net    = Generator(upscale=1, img_size=tuple(self.input_shape),
                   window_size=7, img_range=1., depths=[3, 3, 3, 3],
                   embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=1, upsampler='1conv')

        # Stage the checkpoint on the CPU: load_state_dict copies the values
        # into the module's own parameters, so a GPU staging copy is wasted
        # memory, and it must not be made at all when ``cuda`` is False.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict  = torch.load(self.model_path, map_location='cpu')
        self.net.load_state_dict(state_dict)
        self.net    = self.net.eval()
        print('{} model loaded.'.format(self.model_path))

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    #---------------------------------------------------#
    #   Generate an image at 1x scale (upscale=1)
    #---------------------------------------------------#
    def detect_image(self, image):
        """Run the generator on one image and return the generated image.

        Args:
            image: Input image; presumably a ``PIL.Image`` (it is passed to
                ``cvtColor`` and then to ``np.array``) -- TODO confirm.

        Returns:
            A ``PIL.Image`` built from the network output after
            ``postprocess_output``, clipping to [0, 255] and a uint8 cast.

        NOTE(review): the image is not resized to ``input_shape`` here; it is
        fed to the network at its own resolution -- confirm this is intended.
        """
        #---------------------------------------------------------#
        #   Convert the image to RGB so that grayscale inputs do
        #   not fail at prediction time. Only RGB prediction is
        #   supported; other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   Move the last axis first (assumed HWC -> CHW) and add
        #   the batch dimension at axis 0.
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            #---------------------------------------------------#
            #   Feed the batch through the network and keep the
            #   first (only) item of the output.
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   Back to channel-last layout and to numpy
            #---------------------------------------------------#
            pr = pr.permute(1, 2, 0).cpu().numpy()

        image = postprocess_output(pr)
        # Clip before the uint8 cast so out-of-range values do not wrap around.
        image = np.clip(image, 0, 255)
        image = Image.fromarray(np.uint8(image))

        return image