| |
| |
|
|
| import fvcore.nn.weight_init as weight_init |
| from torch import nn |
|
|
| from .batch_norm import FrozenBatchNorm2d, get_norm |
| from .wrappers import Conv2d |
|
|
|
|
| """ |
| CNN building blocks. |
| """ |
|
|
|
|
class CNNBlockBase(nn.Module):
    """
    Base class for a CNN block with a fixed channel/stride contract.

    Subclasses must accept and produce NCHW tensors in ``forward()`` and may
    implement any computation, as long as the output matches the declared
    ``out_channels`` and ``stride``.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        Record the channel/stride specification. Subclass ``__init__`` methods
        should accept these same arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and convert all BatchNorm layers to FrozenBatchNorm

        Returns:
            the block itself
        """
        # Disable gradients on every parameter in one call, then freeze
        # any BatchNorm layers so their statistics stop updating.
        self.requires_grad_(False)
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self
|
|
|
|
class DepthwiseSeparableConv2d(nn.Module):
    """
    A depthwise-separable convolution: a kxk grouped (depthwise) convolution
    followed by a 1x1 (pointwise) convolution.

    In :paper:`xception`, norm & activation are applied on the second conv.
    :paper:`mobilenet` uses norm & activation on both convs.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        padding=1,
        dilation=1,
        *,
        norm1=None,
        activation1=None,
        norm2=None,
        activation2=None,
    ):
        """
        Args:
            norm1, norm2 (str or callable): normalization for the two conv layers.
            activation1, activation2 (callable(Tensor) -> Tensor): activation
                function for the two conv layers.
        """
        super().__init__()
        # Depthwise stage: groups == in_channels, one filter per channel.
        # Bias is redundant (and therefore disabled) when a norm follows.
        self.depthwise = Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=not norm1,
            norm=get_norm(norm1, in_channels),
            activation=activation1,
        )
        # Pointwise stage: 1x1 conv mixing channels to the output width.
        self.pointwise = Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=not norm2,
            norm=get_norm(norm2, out_channels),
            activation=activation2,
        )

        # Caffe2-style MSRA (He) initialization for both conv layers.
        for layer in (self.depthwise, self.pointwise):
            weight_init.c2_msra_fill(layer)

    def forward(self, x):
        out = self.depthwise(x)
        return self.pointwise(out)
|
|