trying again
Browse files
model_utils/efficientnet_config.py
CHANGED
@@ -14,8 +14,8 @@ from torch import Tensor, nn
|
|
14 |
from torchvision.models._utils import _make_divisible
|
15 |
from torchvision.ops import StochasticDepth
|
16 |
|
17 |
-
sys.path.insert(1, "
|
18 |
-
from .vision_modifications import Conv2dNormActivation, SqueezeExcitation
|
19 |
|
20 |
|
21 |
@dataclass
|
|
|
14 |
from torchvision.models._utils import _make_divisible
|
15 |
from torchvision.ops import StochasticDepth
|
16 |
|
17 |
+
sys.path.insert(1, "../")
|
18 |
+
from utils.vision_modifications import Conv2dNormActivation, SqueezeExcitation
|
19 |
|
20 |
|
21 |
@dataclass
|
utils/vision_modifications.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from typing import Callable, List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
interpolate = torch.nn.functional.interpolate
|
8 |
+
|
9 |
+
|
10 |
+
class FrozenBatchNorm2d(torch.nn.Module):
|
11 |
+
"""
|
12 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed
|
13 |
+
|
14 |
+
Args:
|
15 |
+
num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
|
16 |
+
eps (float): a value added to the denominator for numerical stability. Default: 1e-5
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
num_features: int,
|
22 |
+
eps: float = 1e-5,
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
# _log_api_usage_once(self)
|
26 |
+
self.eps = eps
|
27 |
+
self.register_buffer("weight", torch.ones(num_features))
|
28 |
+
self.register_buffer("bias", torch.zeros(num_features))
|
29 |
+
self.register_buffer("running_mean", torch.zeros(num_features))
|
30 |
+
self.register_buffer("running_var", torch.ones(num_features))
|
31 |
+
|
32 |
+
def _load_from_state_dict(
|
33 |
+
self,
|
34 |
+
state_dict: dict,
|
35 |
+
prefix: str,
|
36 |
+
local_metadata: dict,
|
37 |
+
strict: bool,
|
38 |
+
missing_keys: List[str],
|
39 |
+
unexpected_keys: List[str],
|
40 |
+
error_msgs: List[str],
|
41 |
+
):
|
42 |
+
num_batches_tracked_key = prefix + "num_batches_tracked"
|
43 |
+
if num_batches_tracked_key in state_dict:
|
44 |
+
del state_dict[num_batches_tracked_key]
|
45 |
+
|
46 |
+
super()._load_from_state_dict(
|
47 |
+
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
48 |
+
)
|
49 |
+
|
50 |
+
def forward(self, x: Tensor) -> Tensor:
|
51 |
+
# move reshapes to the beginning
|
52 |
+
# to make it fuser-friendly
|
53 |
+
w = self.weight.reshape(1, -1, 1, 1)
|
54 |
+
b = self.bias.reshape(1, -1, 1, 1)
|
55 |
+
rv = self.running_var.reshape(1, -1, 1, 1)
|
56 |
+
rm = self.running_mean.reshape(1, -1, 1, 1)
|
57 |
+
scale = w * (rv + self.eps).rsqrt()
|
58 |
+
bias = b - rm * scale
|
59 |
+
return x * scale + bias
|
60 |
+
|
61 |
+
def __repr__(self) -> str:
|
62 |
+
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})"
|
63 |
+
|
64 |
+
|
65 |
+
class ConvNormActivation(torch.nn.Sequential):
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
in_channels: int,
|
69 |
+
out_channels: int,
|
70 |
+
kernel_size: int = 3,
|
71 |
+
stride: int = 1,
|
72 |
+
padding: Optional[int] = None,
|
73 |
+
groups: int = 1,
|
74 |
+
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
|
75 |
+
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
76 |
+
dilation: int = 1,
|
77 |
+
inplace: Optional[bool] = True,
|
78 |
+
bias: Optional[bool] = None,
|
79 |
+
conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
|
80 |
+
) -> None:
|
81 |
+
|
82 |
+
if padding is None:
|
83 |
+
padding = (kernel_size - 1) // 2 * dilation
|
84 |
+
if bias is None:
|
85 |
+
bias = norm_layer is None
|
86 |
+
|
87 |
+
layers = [
|
88 |
+
conv_layer(
|
89 |
+
in_channels,
|
90 |
+
out_channels,
|
91 |
+
kernel_size,
|
92 |
+
stride,
|
93 |
+
padding,
|
94 |
+
dilation=dilation,
|
95 |
+
groups=groups,
|
96 |
+
bias=bias,
|
97 |
+
)
|
98 |
+
]
|
99 |
+
|
100 |
+
if norm_layer is not None:
|
101 |
+
layers.append(norm_layer(out_channels))
|
102 |
+
|
103 |
+
if activation_layer is not None:
|
104 |
+
params = {} if inplace is None else {"inplace": inplace}
|
105 |
+
layers.append(activation_layer(**params))
|
106 |
+
super().__init__(*layers)
|
107 |
+
# _log_api_usage_once(self)
|
108 |
+
self.out_channels = out_channels
|
109 |
+
|
110 |
+
if self.__class__ == ConvNormActivation:
|
111 |
+
warnings.warn(
|
112 |
+
"Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
class Conv2dNormActivation(ConvNormActivation):
|
117 |
+
"""
|
118 |
+
Configurable block used for Convolution2d-Normalization-Activation blocks.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
in_channels (int): Number of channels in the input image
|
122 |
+
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
|
123 |
+
kernel_size: (int, optional): Size of the convolving kernel. Default: 3
|
124 |
+
stride (int, optional): Stride of the convolution. Default: 1
|
125 |
+
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
|
126 |
+
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
|
127 |
+
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm2d``
|
128 |
+
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
|
129 |
+
dilation (int): Spacing between kernel elements. Default: 1
|
130 |
+
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
|
131 |
+
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
|
132 |
+
|
133 |
+
"""
|
134 |
+
|
135 |
+
def __init__(
|
136 |
+
self,
|
137 |
+
in_channels: int,
|
138 |
+
out_channels: int,
|
139 |
+
kernel_size: int = 3,
|
140 |
+
stride: int = 1,
|
141 |
+
padding: Optional[int] = None,
|
142 |
+
groups: int = 1,
|
143 |
+
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
|
144 |
+
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
145 |
+
dilation: int = 1,
|
146 |
+
inplace: Optional[bool] = True,
|
147 |
+
bias: Optional[bool] = None,
|
148 |
+
) -> None:
|
149 |
+
|
150 |
+
super().__init__(
|
151 |
+
in_channels,
|
152 |
+
out_channels,
|
153 |
+
kernel_size,
|
154 |
+
stride,
|
155 |
+
padding,
|
156 |
+
groups,
|
157 |
+
norm_layer,
|
158 |
+
activation_layer,
|
159 |
+
dilation,
|
160 |
+
inplace,
|
161 |
+
bias,
|
162 |
+
torch.nn.Conv2d,
|
163 |
+
)
|
164 |
+
|
165 |
+
|
166 |
+
class Conv3dNormActivation(ConvNormActivation):
|
167 |
+
"""
|
168 |
+
Configurable block used for Convolution3d-Normalization-Activation blocks.
|
169 |
+
|
170 |
+
Args:
|
171 |
+
in_channels (int): Number of channels in the input video.
|
172 |
+
out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
|
173 |
+
kernel_size: (int, optional): Size of the convolving kernel. Default: 3
|
174 |
+
stride (int, optional): Stride of the convolution. Default: 1
|
175 |
+
padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None, in which case it will calculated as ``padding = (kernel_size - 1) // 2 * dilation``
|
176 |
+
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
|
177 |
+
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``torch.nn.BatchNorm3d``
|
178 |
+
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
|
179 |
+
dilation (int): Spacing between kernel elements. Default: 1
|
180 |
+
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
|
181 |
+
bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is None``.
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
in_channels: int,
|
187 |
+
out_channels: int,
|
188 |
+
kernel_size: int = 3,
|
189 |
+
stride: int = 1,
|
190 |
+
padding: Optional[int] = None,
|
191 |
+
groups: int = 1,
|
192 |
+
norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
|
193 |
+
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
194 |
+
dilation: int = 1,
|
195 |
+
inplace: Optional[bool] = True,
|
196 |
+
bias: Optional[bool] = None,
|
197 |
+
) -> None:
|
198 |
+
|
199 |
+
super().__init__(
|
200 |
+
in_channels,
|
201 |
+
out_channels,
|
202 |
+
kernel_size,
|
203 |
+
stride,
|
204 |
+
padding,
|
205 |
+
groups,
|
206 |
+
norm_layer,
|
207 |
+
activation_layer,
|
208 |
+
dilation,
|
209 |
+
inplace,
|
210 |
+
bias,
|
211 |
+
torch.nn.Conv3d,
|
212 |
+
)
|
213 |
+
|
214 |
+
|
215 |
+
class SqueezeExcitation(torch.nn.Module):
|
216 |
+
"""
|
217 |
+
This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).
|
218 |
+
Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
input_channels (int): Number of channels in the input image
|
222 |
+
squeeze_channels (int): Number of squeeze channels
|
223 |
+
activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
|
224 |
+
scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
|
225 |
+
"""
|
226 |
+
|
227 |
+
def __init__(
|
228 |
+
self,
|
229 |
+
input_channels: int,
|
230 |
+
squeeze_channels: int,
|
231 |
+
activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
|
232 |
+
scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
|
233 |
+
) -> None:
|
234 |
+
super().__init__()
|
235 |
+
# _log_api_usage_once(self)
|
236 |
+
self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
|
237 |
+
self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
|
238 |
+
self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
|
239 |
+
self.activation = activation()
|
240 |
+
self.scale_activation = scale_activation()
|
241 |
+
|
242 |
+
def _scale(self, input: Tensor) -> Tensor:
|
243 |
+
scale = self.avgpool(input)
|
244 |
+
scale = self.fc1(scale)
|
245 |
+
scale = self.activation(scale)
|
246 |
+
scale = self.fc2(scale)
|
247 |
+
return self.scale_activation(scale)
|
248 |
+
|
249 |
+
def forward(self, input: Tensor) -> Tensor:
|
250 |
+
scale = self._scale(input)
|
251 |
+
return scale * input
|
252 |
+
|
253 |
+
|
254 |
+
class MLP(torch.nn.Sequential):
|
255 |
+
"""This block implements the multi-layer perceptron (MLP) module.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
in_channels (int): Number of channels of the input
|
259 |
+
hidden_channels (List[int]): List of the hidden channel dimensions
|
260 |
+
norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the convolution layer. If ``None`` this layer wont be used. Default: ``None``
|
261 |
+
activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked on top of the normalization layer (if not None), otherwise on top of the conv layer. If ``None`` this layer wont be used. Default: ``torch.nn.ReLU``
|
262 |
+
inplace (bool): Parameter for the activation layer, which can optionally do the operation in-place. Default ``True``
|
263 |
+
bias (bool): Whether to use bias in the linear layer. Default ``True``
|
264 |
+
dropout (float): The probability for the dropout layer. Default: 0.0
|
265 |
+
"""
|
266 |
+
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
in_channels: int,
|
270 |
+
hidden_channels: List[int],
|
271 |
+
norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
|
272 |
+
activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
|
273 |
+
inplace: Optional[bool] = True,
|
274 |
+
bias: bool = True,
|
275 |
+
dropout: float = 0.0,
|
276 |
+
):
|
277 |
+
# The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
|
278 |
+
# https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
|
279 |
+
params = {} if inplace is None else {"inplace": inplace}
|
280 |
+
|
281 |
+
layers = []
|
282 |
+
in_dim = in_channels
|
283 |
+
for hidden_dim in hidden_channels[:-1]:
|
284 |
+
layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
|
285 |
+
if norm_layer is not None:
|
286 |
+
layers.append(norm_layer(hidden_dim))
|
287 |
+
layers.append(activation_layer(**params))
|
288 |
+
layers.append(torch.nn.Dropout(dropout, **params))
|
289 |
+
in_dim = hidden_dim
|
290 |
+
|
291 |
+
layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
|
292 |
+
layers.append(torch.nn.Dropout(dropout, **params))
|
293 |
+
|
294 |
+
super().__init__(*layers)
|
295 |
+
# _log_api_usage_once(self)
|
296 |
+
|
297 |
+
|
298 |
+
class Permute(torch.nn.Module):
|
299 |
+
"""This module returns a view of the tensor input with its dimensions permuted.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
dims (List[int]): The desired ordering of dimensions
|
303 |
+
"""
|
304 |
+
|
305 |
+
def __init__(self, dims: List[int]):
|
306 |
+
super().__init__()
|
307 |
+
self.dims = dims
|
308 |
+
|
309 |
+
def forward(self, x: Tensor) -> Tensor:
|
310 |
+
return torch.permute(x, self.dims)
|