jpterry committed
Commit: 00e189e
Parent: 467722d

trying again

model_utils/efficientnet_config.py CHANGED
@@ -14,8 +14,8 @@ from torch import Tensor, nn
 from torchvision.models._utils import _make_divisible
 from torchvision.ops import StochasticDepth
 
-sys.path.insert(1, "./")
-from .vision_modifications import Conv2dNormActivation, SqueezeExcitation
+sys.path.insert(1, "../")
+from utils.vision_modifications import Conv2dNormActivation, SqueezeExcitation
 
 
 @dataclass
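Note that the updated import only resolves when the directory containing the `utils/` package is on `sys.path`, and `sys.path.insert(1, "../")` achieves that only if the process is launched from inside `model_utils/`. A minimal, working-directory-independent sketch (an assumption about the repository layout implied by this commit, not part of the change itself) would anchor the path on the file location instead:

import sys
from pathlib import Path

# Hypothetical alternative to sys.path.insert(1, "../"): resolve the repository
# root (the directory holding utils/ and model_utils/) from this file's location,
# so the import works regardless of the current working directory.
repo_root = Path(__file__).resolve().parent.parent
if str(repo_root) not in sys.path:
    sys.path.insert(1, str(repo_root))

from utils.vision_modifications import Conv2dNormActivation, SqueezeExcitation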
utils/vision_modifications.py ADDED
@@ -0,0 +1,310 @@
+import warnings
+from typing import Callable, List, Optional
+
+import torch
+from torch import Tensor
+
+interpolate = torch.nn.functional.interpolate
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed
+
+    Args:
+        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
+        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+    ):
+        super().__init__()
+        # _log_api_usage_once(self)
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(num_features))
+        self.register_buffer("bias", torch.zeros(num_features))
+        self.register_buffer("running_mean", torch.zeros(num_features))
+        self.register_buffer("running_var", torch.ones(num_features))
+
+    def _load_from_state_dict(
+        self,
+        state_dict: dict,
+        prefix: str,
+        local_metadata: dict,
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ):
+        num_batches_tracked_key = prefix + "num_batches_tracked"
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        # move reshapes to the beginning
+        # to make it fuser-friendly
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        rv = self.running_var.reshape(1, -1, 1, 1)
+        rm = self.running_mean.reshape(1, -1, 1, 1)
+        scale = w * (rv + self.eps).rsqrt()
+        bias = b - rm * scale
+        return x * scale + bias
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
+
+
+class ConvNormActivation(torch.nn.Sequential):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: Optional[int] = None,
+        groups: int = 1,
+        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        dilation: int = 1,
+        inplace: Optional[bool] = True,
+        bias: Optional[bool] = None,
+        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
+    ) -> None:
+
+        if padding is None:
+            padding = (kernel_size - 1) // 2 * dilation
+        if bias is None:
+            bias = norm_layer is None
+
+        layers = [
+            conv_layer(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                padding,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+            )
+        ]
+
+        if norm_layer is not None:
+            layers.append(norm_layer(out_channels))
+
+        if activation_layer is not None:
+            params = {} if inplace is None else {"inplace": inplace}
+            layers.append(activation_layer(**params))
+        super().__init__(*layers)
+        # _log_api_usage_once(self)
+        self.out_channels = out_channels
+
+        if self.__class__ == ConvNormActivation:
+            warnings.warn(
+                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
+            )
+
+
+class Conv2dNormActivation(ConvNormActivation):
+    """
+    Configurable block used for Convolution2d-Normalization-Activation blocks.
+
+    Args:
+        in_channels (int): Number of channels in the input image
+        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
+        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
+        stride (int, optional): Stride of the convolution. Default: 1
+        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        dilation (int): Spacing between kernel elements. Default: 1
+        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
+        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
+
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: Optional[int] = None,
+        groups: int = 1,
+        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        dilation: int = 1,
+        inplace: Optional[bool] = True,
+        bias: Optional[bool] = None,
+    ) -> None:
+
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            groups,
+            norm_layer,
+            activation_layer,
+            dilation,
+            inplace,
+            bias,
+            torch.nn.Conv2d,
+        )
+
+
+class Conv3dNormActivation(ConvNormActivation):
+    """
+    Configurable block used for Convolution3d-Normalization-Activation blocks.
+
+    Args:
+        in_channels (int): Number of channels in the input video.
+        out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
+        kernel_size: (int, optional): Size of the convolving kernel. Default: 3
+        stride (int, optional): Stride of the convolution. Default: 1
+        padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        dilation (int): Spacing between kernel elements. Default: 1
+        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
+        bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: Optional[int] = None,
+        groups: int = 1,
+        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        dilation: int = 1,
+        inplace: Optional[bool] = True,
+        bias: Optional[bool] = None,
+    ) -> None:
+
+        super().__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            groups,
+            norm_layer,
+            activation_layer,
+            dilation,
+            inplace,
+            bias,
+            torch.nn.Conv3d,
+        )
+
+
+class SqueezeExcitation(torch.nn.Module):
+    """
+    This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
+    Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.
+
+    Args:
+        input_channels (int): Number of channels in the input image
+        squeeze_channels (int): Number of squeeze channels
+        activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
+        scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
+    """
+
+    def __init__(
+        self,
+        input_channels: int,
+        squeeze_channels: int,
+        activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
+        scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
+    ) -> None:
+        super().__init__()
+        # _log_api_usage_once(self)
+        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
+        self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
+        self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
+        self.activation = activation()
+        self.scale_activation = scale_activation()
+
+    def _scale(self, input: Tensor) -> Tensor:
+        scale = self.avgpool(input)
+        scale = self.fc1(scale)
+        scale = self.activation(scale)
+        scale = self.fc2(scale)
+        return self.scale_activation(scale)
+
+    def forward(self, input: Tensor) -> Tensor:
+        scale = self._scale(input)
+        return scale * input
+
+
+class MLP(torch.nn.Sequential):
+    """This block implements the multi-layer perceptron (MLP) module.
+
+    Args:
+        in_channels (int): Number of channels of the input
+        hidden_channels (List[int]): List of the hidden channel dimensions
+        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``None``
+        activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
+        inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
+        bias (bool): Whether to use bias in the linear layer. Default ``True``
+        dropout (float): The probability for the dropout layer. Default: 0.0
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: List[int],
+        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
+        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
+        inplace: Optional[bool] = True,
+        bias: bool = True,
+        dropout: float = 0.0,
+    ):
+        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
+        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
+        params = {} if inplace is None else {"inplace": inplace}
+
+        layers = []
+        in_dim = in_channels
+        for hidden_dim in hidden_channels[:-1]:
+            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
+            if norm_layer is not None:
+                layers.append(norm_layer(hidden_dim))
+            layers.append(activation_layer(**params))
+            layers.append(torch.nn.Dropout(dropout, **params))
+            in_dim = hidden_dim
+
+        layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
+        layers.append(torch.nn.Dropout(dropout, **params))
+
+        super().__init__(*layers)
+        # _log_api_usage_once(self)
+
+
+class Permute(torch.nn.Module):
+    """This module returns a view of the tensor input with its dimensions permuted.
+
+    Args:
+        dims (List[int]): The desired ordering of dimensions
+    """
+
+    def __init__(self, dims: List[int]):
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.permute(x, self.dims)
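The classes above track `torchvision.ops.misc`, so a quick smoke test of the two blocks that `efficientnet_config.py` now imports might look like the sketch below. It is run from the repository root so that `utils/` is importable; the tensor shapes and the `SiLU` activation are illustrative assumptions, not part of this commit.

import torch

from utils.vision_modifications import Conv2dNormActivation, SqueezeExcitation

# Hypothetical input chosen for illustration only: (N, C, H, W).
x = torch.randn(2, 3, 64, 64)

# Conv2d -> BatchNorm2d -> SiLU, an EfficientNet-style stem block.
# Padding defaults to (kernel_size - 1) // 2 * dilation = 1, so stride 2 halves H and W.
stem = Conv2dNormActivation(
    in_channels=3,
    out_channels=32,
    kernel_size=3,
    stride=2,
    activation_layer=torch.nn.SiLU,
)
features = stem(x)  # -> (2, 32, 32, 32)

# Channel-wise reweighting of the stem output; output shape matches the input.
se = SqueezeExcitation(input_channels=32, squeeze_channels=8)
out = se(features)  # -> (2, 32, 32, 32)
print(features.shape, out.shape)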