File size: 3,878 Bytes
905cd18
 
 
 
db5513e
905cd18
 
 
 
 
 
 
 
 
 
 
db5513e
905cd18
 
 
db5513e
 
 
 
 
905cd18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db5513e
 
 
 
905cd18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db5513e
905cd18
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from nets.SwinIR import Generator
from utils.utils import cvtColor, preprocess_input


class ESRGAN(object):
    """Inference wrapper around a SwinIR ``Generator`` for super-resolution.

    Configuration comes from ``_defaults``; every keyword argument passed to
    the constructor is stored as an instance attribute and overrides the
    default of the same name. The network is built and its weights are loaded
    as part of construction.
    """

    #-----------------------------------------#
    #   Remember to update model_path
    #-----------------------------------------#
    _defaults = {
        #-----------------------------------------------#
        #   model_path points to the weights file
        #   (e.g. one saved under the logs folder)
        #-----------------------------------------------#
        "model_path"        : 'model_data/Generator_SwinIR.pth',
        #-----------------------------------------------#
        #   Upsampling factor; must match the one used
        #   during training
        #-----------------------------------------------#
        "scale_factor"      : 4,
        #-----------------------------------------------#
        #   hr_shape; passed to Generator as img_size
        #-----------------------------------------------#
        "hr_shape"          : [128, 224],
        #-------------------------------#
        #   Whether to use CUDA.
        #   Set to False when there is no GPU.
        #-------------------------------#
        "cuda"              : False,
    }

    #---------------------------------------------------#
    #   Initialise the model wrapper
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Apply the defaults, then the keyword overrides, then build the net.

        Args:
            **kwargs: Attribute overrides (e.g. ``model_path``, ``cuda``).
                Each one is set on the instance with ``setattr``.
        """
        for name, value in self._defaults.items():
            # Copy list defaults so instances never share the class-level
            # list object (mutating one instance's hr_shape would otherwise
            # change the default for every other instance).
            setattr(self, name, list(value) if isinstance(value, list) else value)
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.generate()

    def _weights_device(self):
        """Return the device checkpoint tensors are mapped to while loading."""
        # Only stage weights on the GPU when the caller asked for CUDA;
        # otherwise a CPU-only run would still allocate GPU memory whenever a
        # GPU happens to be present.
        use_gpu = self.cuda and torch.cuda.is_available()
        return torch.device('cuda' if use_gpu else 'cpu')

    def generate(self):
        """Build the generator, load its weights and put it in eval mode.

        When ``self.cuda`` is true the network is additionally wrapped in
        ``torch.nn.DataParallel`` and moved to the GPU.
        """
        self.net    = Generator(upscale=self.scale_factor, img_size=tuple(self.hr_shape),
                   window_size=8, img_range=1., depths=[3, 3, 3, 3],
                   embed_dim=60, num_heads=[3, 3, 3, 3], mlp_ratio=2, upsampler='pixelshuffledirect')

        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        state_dict  = torch.load(self.model_path, map_location=self._weights_device())
        self.net.load_state_dict(state_dict)
        self.net    = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True
            self.net = self.net.cuda()

    def generate_1x1_image(self, image):
        """Run the network on one image and return the result as a PIL image.

        Args:
            image: The input image (presumably a PIL image -- it is passed to
                the project helper ``cvtColor`` and then to ``np.array``).

        Returns:
            A ``PIL.Image.Image`` built from the network output after mapping
            it back to the 0-255 range.
        """
        #---------------------------------------------------------#
        #   Convert the image to RGB so that grayscale inputs do
        #   not fail at prediction time. Only RGB prediction is
        #   supported; every other image type is converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   Normalise (mean 0.5 / std 0.5 per channel), move the
        #   channel axis first and add the batch dimension.
        #---------------------------------------------------------#
        pixels      = np.array(image, dtype=np.float32)
        normalized  = preprocess_input(pixels, [0.5,0.5,0.5], [0.5,0.5,0.5])
        image_data  = np.expand_dims(np.transpose(normalized, [2,0,1]), 0)

        with torch.no_grad():
            image_data = torch.from_numpy(image_data).float()
            if self.cuda:
                image_data = image_data.cuda()

            #---------------------------------------------------------#
            #   Feed the image through the network; keep the first
            #   (only) item of the batch.
            #---------------------------------------------------------#
            hr_image = self.net(image_data)[0]

        #---------------------------------------------------------#
        #   Undo the 0.5 / 0.5 normalisation and convert the
        #   channel-first output back to an H x W x C RGB array.
        #---------------------------------------------------------#
        hr_image = (hr_image.detach().cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5)
        hr_image = np.clip(hr_image * 255, 0, 255)

        # NOTE: np.uint8 truncates rather than rounds; kept as-is so existing
        # outputs stay bit-identical.
        return Image.fromarray(np.uint8(hr_image))