File size: 4,657 Bytes
95e767b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a54cd3e
 
 
 
95e767b
253a98e
 
 
 
95e767b
 
 
6b31eaf
95e767b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253a98e
 
 
 
 
95e767b
 
253a98e
95e767b
 
 
 
 
 
 
 
 
 
 
 
 
253a98e
 
 
 
 
 
 
 
 
95e767b
 
1db4bf6
95e767b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import torch
from PIL import Image
from torch import nn

from nets.cyclegan import Generator
from utils.utils import (cvtColor, postprocess_output, preprocess_input,
                         resize_image, show_config)


class CYCLEGAN(object):
    _defaults = {
        #-----------------------------------------------#
        #   model_path指向logs文件夹下的权值文件
        #-----------------------------------------------#
        "model_path"        : 'model_data/G_model_B2A_last_epoch_weights.pth',
        #-----------------------------------------------#
        #   输入图像大小的设置
        #-----------------------------------------------#
        "input_shape"       : [112, 112],
        #-------------------------------#
        #   是否进行不失真的resize
        #-------------------------------#
        "letterbox_image"   : True,
        #-------------------------------#
        #   是否使用Cuda
        #   没有GPU可以设置成False
        #-------------------------------#
        "cuda"              : False,
    }

    #---------------------------------------------------#
    #   初始化CYCLEGAN
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)  
            self._defaults[name] = value 
        self.generate()
        
        show_config(**self._defaults)

    def generate(self):
        #----------------------------------------#
        #   创建GAN模型
        #----------------------------------------#
        self.net    = Generator(upscale=1, img_size=tuple(self.input_shape),
                   window_size=7, img_range=1., depths=[3, 3, 3, 3],
                   embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=1, upsampler='1conv').eval()

        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net    = self.net.eval()
        print('{} model loaded.'.format(self.model_path))

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    #---------------------------------------------------#
    #   生成1x1的图片
    #---------------------------------------------------#
    def detect_image(self, image):
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   添加上batch_size维度
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
        
        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
                
            #---------------------------------------------------#
            #   图片传入网络进行预测
            #---------------------------------------------------#
            pr = self.net(images)[0]
            #---------------------------------------------------#
            #   转为numpy
            #---------------------------------------------------#
            pr = pr.permute(1, 2, 0).cpu().numpy()
            
            #--------------------------------------#
            #   将灰条部分截取掉
            #--------------------------------------#
            if nw is not None:
                pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
                        int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
                
            
        image = postprocess_output(pr)
        image = np.clip(image, 0, 255)
        image = Image.fromarray(np.uint8(image))

        return image