File size: 5,729 Bytes
22d8ab7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from torchvision import models
from collections import namedtuple
import torch
import torch.nn as nn
def vgg_preprocess(tensor):
    """Normalize an RGB image tensor with the ImageNet channel statistics.

    Args:
        tensor: RGB image tensor with values in [0, 1]. The channel axis is
            expected third from last (e.g. ``(3, H, W)`` or ``(B, 3, H, W)``),
            because the statistics are shaped ``(3, 1, 1)`` for broadcasting.

    Returns:
        ``(tensor - mean) / std`` computed per channel, where mean/std are cast
        to the type (dtype and device) of ``tensor`` via ``type_as``.
    """
    # Per-channel ImageNet statistics; view(-1, 1, 1) broadcasts over H and W.
    imagenet_mean = torch.Tensor([0.485, 0.456, 0.406])
    imagenet_std = torch.Tensor([0.229, 0.224, 0.225])
    mean = imagenet_mean.type_as(tensor).view(-1, 1, 1)
    std = imagenet_std.type_as(tensor).view(-1, 1, 1)
    return (tensor - mean) / std
class vgg19(nn.Module):
    """VGG-19 feature extractor built on ``torchvision.models.vgg19``.

    ``forward`` normalizes the input with ``vgg_preprocess`` and runs it through
    the network's ``features`` stack only as far as the requested layer.

    Args:
        pretrained_path: path to a ``state_dict`` for ``models.vgg19()``.
            ``None`` skips loading, so the model keeps whatever weights
            ``models.vgg19()`` was constructed with.
        require_grad: if False (default), every parameter is frozen.
    """

    # Names of the 37 modules of ``models.vgg19().features``, in order;
    # LAYER_NAMES[i] names seq_list[i].
    LAYER_NAMES = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2',
        'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5',
    )
    # Layers returned, in this order, when layer_name == 'all'.
    ALL_LAYERS = ('relu1_2', 'relu2_2', 'relu3_2', 'relu4_2', 'relu5_2')

    def __init__(self, pretrained_path='./experiments/VGG19/vgg19-dcbb9e9d.pth', require_grad=False):
        super(vgg19, self).__init__()
        self.vgg_model = models.vgg19()
        if pretrained_path is not None:
            print('----load pretrained vgg19----')
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
            # hosts; load_state_dict then copies into the model's own tensors.
            self.vgg_model.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
            print('----load done!----')
        self.vgg_feature = self.vgg_model.features
        # One single-layer Sequential per feature module. This is a plain list,
        # so the wrappers are not registered as submodules; the wrapped layers
        # are registered through self.vgg_model.
        self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature]
        if not require_grad:
            for parameter in self.parameters():
                parameter.requires_grad = False

    def forward(self, x, layer_name='relu5_2'):
        """Return VGG-19 activations for ``x``.

        Args:
            x: RGB image tensor with values in [0, 1]; normalized here with
                ``vgg_preprocess``.
            layer_name: any name in ``LAYER_NAMES`` (returns a one-element
                list), or ``'all'`` (returns the ``ALL_LAYERS`` activations).

        Returns:
            list of tensors in the requested order.

        Raises:
            ValueError: if ``layer_name`` is not recognized.
        """
        return self._extract(vgg_preprocess(x), layer_name)

    def _extract(self, x, layer_name):
        """Run already-normalized ``x`` through ``seq_list``; collect requested layers."""
        if layer_name == 'all':
            targets = self.ALL_LAYERS
        elif layer_name in self.LAYER_NAMES:
            targets = (layer_name,)
        else:
            raise ValueError(
                "unknown layer_name %r; expected 'all' or one of: %s"
                % (layer_name, ', '.join(self.LAYER_NAMES)))
        wanted = set(targets)
        # Stop right after the deepest requested layer. Deeper layers are not
        # needed, and with in-place ReLUs (torchvision builds VGG with
        # nn.ReLU(inplace=True)) running the next ReLU would overwrite a
        # requested conv output with its activation.
        last = max(self.LAYER_NAMES.index(name) for name in targets)
        feats = {}
        for name, layer in zip(self.LAYER_NAMES[:last + 1], self.seq_list):
            x = layer(x)
            if name in wanted:
                feats[name] = x
        return [feats[name] for name in targets]
class vgg19_class_fea(nn.Module):
    """VGG-19 classifier that also returns the relu5_2 feature map.

    Args:
        pretrained_path: path to a ``state_dict`` for ``models.vgg19()``.
            ``None`` skips loading, so the model keeps whatever weights
            ``models.vgg19()`` was constructed with.
        require_grad: if False (default), every parameter is frozen.
    """

    # Position of relu5_2 within ``models.vgg19().features`` (0-based).
    RELU5_2_INDEX = 31

    def __init__(self, pretrained_path='./experiments/vgg19-dcbb9e9d.pth', require_grad=False):
        super(vgg19_class_fea, self).__init__()
        self.vgg_model = models.vgg19()
        if pretrained_path is not None:
            print('----load pretrained vgg19----')
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
            self.vgg_model.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
            print('----load done!----')
        self.vgg_feature = self.vgg_model.features
        self.avgpool = self.vgg_model.avgpool
        self.classifier = self.vgg_model.classifier
        # One single-layer Sequential per feature module (37 layers). This is a
        # plain list, so the wrappers are not registered as submodules; the
        # wrapped layers are registered through self.vgg_model.
        self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature]
        if not require_grad:
            for parameter in self.parameters():
                parameter.requires_grad = False

    def forward(self, x):
        """Classify ``x`` and capture the relu5_2 activation along the way.

        Args:
            x: RGB image tensor with values in [0, 1]; normalized here with
                ``vgg_preprocess``.

        Returns:
            ``(x_class, relu5_2)``: the classifier output and the activation
            after ``seq_list[RELU5_2_INDEX]``.
        """
        x = vgg_preprocess(x)
        relu5_2 = None
        for index, layer in enumerate(self.seq_list):
            x = layer(x)
            if index == self.RELU5_2_INDEX:
                relu5_2 = x
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x_class = self.classifier(x)
        return x_class, relu5_2