josedolot committed on
Commit
4b573e4
·
1 Parent(s): 2537604

Upload encoders/vgg.py

Browse files
Files changed (1) hide show
  1. encoders/vgg.py +157 -0
encoders/vgg.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2
+
3
+ Attributes:
4
+
5
+ _out_channels (list of int): specify number of channels for each encoder feature tensor
6
+ _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7
+ _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8
+
9
+ Methods:
10
+
11
+ forward(self, x: torch.Tensor)
12
+ produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13
+ shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14
+ with resolution same as input `x` tensor).
15
+
16
+ Input: `x` with shape (1, 3, 64, 64)
17
+ Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18
+ [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19
+ (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20
+
21
+ also should support number of features according to specified depth, e.g. if depth = 5,
22
+ number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23
+ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24
+ """
25
+
26
+ import torch.nn as nn
27
+ from torchvision.models.vgg import VGG
28
+ from torchvision.models.vgg import make_layers
29
+ from pretrainedmodels.models.torchvision_models import pretrained_settings
30
+
31
+ from ._base import EncoderMixin
32
+
33
+ # fmt: off
34
+ cfg = {
35
+ 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
36
+ 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
37
+ 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
38
+ 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
39
+ }
40
+ # fmt: on
41
+
42
+
43
+ class VGGEncoder(VGG, EncoderMixin):
44
+ def __init__(self, out_channels, config, batch_norm=False, depth=5, **kwargs):
45
+ super().__init__(make_layers(config, batch_norm=batch_norm), **kwargs)
46
+ self._out_channels = out_channels
47
+ self._depth = depth
48
+ self._in_channels = 3
49
+ del self.classifier
50
+
51
+ def make_dilated(self, stage_list, dilation_list):
52
+ raise ValueError("'VGG' models do not support dilated mode due to Max Pooling"
53
+ " operations for downsampling!")
54
+
55
+ def get_stages(self):
56
+ stages = []
57
+ stage_modules = []
58
+ for module in self.features:
59
+ if isinstance(module, nn.MaxPool2d):
60
+ stages.append(nn.Sequential(*stage_modules))
61
+ stage_modules = []
62
+ stage_modules.append(module)
63
+ stages.append(nn.Sequential(*stage_modules))
64
+ return stages
65
+
66
+ def forward(self, x):
67
+ stages = self.get_stages()
68
+
69
+ features = []
70
+ for i in range(self._depth + 1):
71
+ x = stages[i](x)
72
+ features.append(x)
73
+
74
+ return features
75
+
76
+ def load_state_dict(self, state_dict, **kwargs):
77
+ keys = list(state_dict.keys())
78
+ for k in keys:
79
+ if k.startswith("classifier"):
80
+ state_dict.pop(k, None)
81
+ super().load_state_dict(state_dict, **kwargs)
82
+
83
+
84
+ vgg_encoders = {
85
+ "vgg11": {
86
+ "encoder": VGGEncoder,
87
+ "pretrained_settings": pretrained_settings["vgg11"],
88
+ "params": {
89
+ "out_channels": (64, 128, 256, 512, 512, 512),
90
+ "config": cfg["A"],
91
+ "batch_norm": False,
92
+ },
93
+ },
94
+ "vgg11_bn": {
95
+ "encoder": VGGEncoder,
96
+ "pretrained_settings": pretrained_settings["vgg11_bn"],
97
+ "params": {
98
+ "out_channels": (64, 128, 256, 512, 512, 512),
99
+ "config": cfg["A"],
100
+ "batch_norm": True,
101
+ },
102
+ },
103
+ "vgg13": {
104
+ "encoder": VGGEncoder,
105
+ "pretrained_settings": pretrained_settings["vgg13"],
106
+ "params": {
107
+ "out_channels": (64, 128, 256, 512, 512, 512),
108
+ "config": cfg["B"],
109
+ "batch_norm": False,
110
+ },
111
+ },
112
+ "vgg13_bn": {
113
+ "encoder": VGGEncoder,
114
+ "pretrained_settings": pretrained_settings["vgg13_bn"],
115
+ "params": {
116
+ "out_channels": (64, 128, 256, 512, 512, 512),
117
+ "config": cfg["B"],
118
+ "batch_norm": True,
119
+ },
120
+ },
121
+ "vgg16": {
122
+ "encoder": VGGEncoder,
123
+ "pretrained_settings": pretrained_settings["vgg16"],
124
+ "params": {
125
+ "out_channels": (64, 128, 256, 512, 512, 512),
126
+ "config": cfg["D"],
127
+ "batch_norm": False,
128
+ },
129
+ },
130
+ "vgg16_bn": {
131
+ "encoder": VGGEncoder,
132
+ "pretrained_settings": pretrained_settings["vgg16_bn"],
133
+ "params": {
134
+ "out_channels": (64, 128, 256, 512, 512, 512),
135
+ "config": cfg["D"],
136
+ "batch_norm": True,
137
+ },
138
+ },
139
+ "vgg19": {
140
+ "encoder": VGGEncoder,
141
+ "pretrained_settings": pretrained_settings["vgg19"],
142
+ "params": {
143
+ "out_channels": (64, 128, 256, 512, 512, 512),
144
+ "config": cfg["E"],
145
+ "batch_norm": False,
146
+ },
147
+ },
148
+ "vgg19_bn": {
149
+ "encoder": VGGEncoder,
150
+ "pretrained_settings": pretrained_settings["vgg19_bn"],
151
+ "params": {
152
+ "out_channels": (64, 128, 256, 512, 512, 512),
153
+ "config": cfg["E"],
154
+ "batch_norm": True,
155
+ },
156
+ },
157
+ }