File size: 5,729 Bytes
22d8ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from torchvision import models
from collections import namedtuple
import torch
import torch.nn as nn

def vgg_preprocess(tensor):
    """Normalize an RGB tensor per channel with fixed mean/std statistics.

    The constants are the mean/std used by torchvision's ImageNet-pretrained
    models. They are reshaped to ``(3, 1, 1)``, so the channel axis of
    ``tensor`` must be third from the end, e.g. ``(3, H, W)`` or
    ``(B, 3, H, W)``.

    Args:
        tensor: RGB image tensor whose values are expected to lie in [0, 1].

    Returns:
        ``(tensor - mean) / std`` computed per channel, with the statistics
        first cast via ``type_as`` to match ``tensor``.
    """
    # Row 0 holds the per-channel mean, row 1 the per-channel std.
    channel_stats = torch.Tensor([[0.485, 0.456, 0.406],
                                  [0.229, 0.224, 0.225]]).type_as(tensor)
    mean, std = channel_stats.view(2, -1, 1, 1)
    return (tensor - mean) / std

class vgg19(nn.Module):
    """VGG-19 convolutional trunk that returns selected intermediate activations.

    Args:
        pretrained_path: path to a state dict for ``models.vgg19()``. ``None``
            skips loading and keeps the default initialization.
        require_grad: when False, every parameter gets ``requires_grad=False``.
    """

    # Names of the 37 layers of ``models.vgg19().features``, in order.
    _LAYER_NAMES = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5',
    )

    # Accepted ``layer_name`` values of ``forward`` -> activations returned, in order.
    _LAYER_SELECTION = {
        'relu5_2': ('relu5_2',),
        'conv5_2': ('conv5_2',),
        'relu5_4': ('relu5_4',),
        'pool5': ('pool5',),
        'all': ('relu1_2', 'relu2_2', 'relu3_2', 'relu4_2', 'relu5_2'),
    }

    def __init__(self, pretrained_path='./experiments/VGG19/vgg19-dcbb9e9d.pth', require_grad=False):
        super(vgg19, self).__init__()
        self.vgg_model = models.vgg19()
        if pretrained_path is not None:
            print('----load pretrained vgg19----')
            # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only hosts;
            # load_state_dict then copies the values into the (CPU) model parameters.
            self.vgg_model.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
            print('----load done!----')
        self.vgg_feature = self.vgg_model.features
        # One single-layer Sequential per feature layer. These wrap the same module
        # objects as vgg_model.features, so parameters stay registered via vgg_model.
        self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature]

        if not require_grad:
            for parameter in self.parameters():
                parameter.requires_grad = False

    def forward(self, x, layer_name='relu5_2'):
        """Run the trunk and return the activations selected by ``layer_name``.

        Args:
            x: RGB tensor with values in [0, 1]; it is normalized by
                ``vgg_preprocess`` before the first layer.
            layer_name: one of ``'relu5_2'``, ``'conv5_2'``, ``'relu5_4'``,
                ``'pool5'`` or ``'all'``.

        Returns:
            A list of tensors. It holds a single activation for a named layer, or
            ``[relu1_2, relu2_2, relu3_2, relu4_2, relu5_2]`` for ``'all'``.
            (The original annotation gave relu5_2 as ``[B, 512, 16, 16]`` and
            pool5 as ``[B, 512, 8, 8]``; presumably for 256x256 inputs. Confirm.)

        Raises:
            ValueError: if ``layer_name`` is not one of the supported values.
        """
        try:
            wanted = self._LAYER_SELECTION[layer_name]
        except KeyError:
            raise ValueError(
                f"unsupported layer_name {layer_name!r}; "
                f"expected one of {sorted(self._LAYER_SELECTION)}"
            ) from None
        # Layers after the deepest requested one cannot affect the result, so stop there.
        last_index = max(self._LAYER_NAMES.index(name) for name in wanted)

        x = vgg_preprocess(x)
        captured = {}
        for index, (name, layer) in enumerate(zip(self._LAYER_NAMES, self.seq_list)):
            x = layer(x)
            if name in wanted:
                captured[name] = x
            if index == last_index:
                break
        return [captured[name] for name in wanted]

class vgg19_class_fea(nn.Module):
    """Full VGG-19 (features, avgpool, classifier) that also exposes relu5_2.

    Args:
        pretrained_path: path to a state dict for ``models.vgg19()``. ``None``
            skips loading and keeps the default initialization.
        require_grad: when False, every parameter gets ``requires_grad=False``.
    """

    # Position of relu5_2 inside ``vgg_model.features`` (37 layers in total).
    _RELU5_2_INDEX = 31

    # NOTE(review): this default path differs from vgg19's
    # './experiments/VGG19/vgg19-dcbb9e9d.pth'. Confirm which location is intended.
    def __init__(self, pretrained_path='./experiments/vgg19-dcbb9e9d.pth', require_grad=False):
        super(vgg19_class_fea, self).__init__()
        self.vgg_model = models.vgg19()
        if pretrained_path is not None:
            print('----load pretrained vgg19----')
            # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only hosts.
            self.vgg_model.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
            print('----load done!----')
        self.vgg_feature = self.vgg_model.features
        self.avgpool = self.vgg_model.avgpool
        self.classifier = self.vgg_model.classifier

        # One single-layer Sequential per feature layer (37 layers). These share
        # modules with vgg_model.features, so parameters are registered only once.
        self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature]
        if not require_grad:
            for parameter in self.parameters():
                parameter.requires_grad = False

    def forward(self, x):
        """Classify ``x`` and return the relu5_2 activation alongside.

        Args:
            x: RGB tensor with values in [0, 1]; it is normalized by
                ``vgg_preprocess`` before the first layer.

        Returns:
            Tuple ``(x_class, relu5_2)``. ``x_class`` is the classifier output
            computed after ``avgpool`` and ``torch.flatten(x, 1)``. ``relu5_2`` is
            the output of feature layer 31.
        """
        x = vgg_preprocess(x)

        relu5_2 = None
        for index, layer in enumerate(self.seq_list):
            x = layer(x)
            if index == self._RELU5_2_INDEX:
                relu5_2 = x

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x_class = self.classifier(x)
        return x_class, relu5_2