Spaces:
Runtime error
Runtime error
Upload encoders/_base.py
Browse files- encoders/_base.py +53 -0
encoders/_base.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import List
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
from . import _utils as utils
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EncoderMixin:
|
| 10 |
+
"""Add encoder functionality such as:
|
| 11 |
+
- output channels specification of feature tensors (produced by encoder)
|
| 12 |
+
- patching first convolution for arbitrary input channels
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
@property
|
| 16 |
+
def out_channels(self):
|
| 17 |
+
"""Return channels dimensions for each tensor of forward output of encoder"""
|
| 18 |
+
return self._out_channels[: self._depth + 1]
|
| 19 |
+
|
| 20 |
+
def set_in_channels(self, in_channels, pretrained=True):
|
| 21 |
+
"""Change first convolution channels"""
|
| 22 |
+
if in_channels == 3:
|
| 23 |
+
return
|
| 24 |
+
|
| 25 |
+
self._in_channels = in_channels
|
| 26 |
+
if self._out_channels[0] == 3:
|
| 27 |
+
self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
|
| 28 |
+
|
| 29 |
+
utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)
|
| 30 |
+
|
| 31 |
+
def get_stages(self):
|
| 32 |
+
"""Method should be overridden in encoder"""
|
| 33 |
+
raise NotImplementedError
|
| 34 |
+
|
| 35 |
+
def make_dilated(self, output_stride):
|
| 36 |
+
|
| 37 |
+
if output_stride == 16:
|
| 38 |
+
stage_list=[5,]
|
| 39 |
+
dilation_list=[2,]
|
| 40 |
+
|
| 41 |
+
elif output_stride == 8:
|
| 42 |
+
stage_list=[4, 5]
|
| 43 |
+
dilation_list=[2, 4]
|
| 44 |
+
|
| 45 |
+
else:
|
| 46 |
+
raise ValueError("Output stride should be 16 or 8, got {}.".format(output_stride))
|
| 47 |
+
|
| 48 |
+
stages = self.get_stages()
|
| 49 |
+
for stage_indx, dilation_rate in zip(stage_list, dilation_list):
|
| 50 |
+
utils.replace_strides_with_dilation(
|
| 51 |
+
module=stages[stage_indx],
|
| 52 |
+
dilation_rate=dilation_rate,
|
| 53 |
+
)
|