|
import torch.nn as nn |
|
|
|
from .common import ResBlock, default_conv |
|
|
|
def encoder(in_channels, n_feats):
    """Build the RGB / IR feature encoder.

    A stack of three 5x5 convolutions (padding 2) that widens the features
    from ``in_channels`` to ``n_feats``, ``2 * n_feats`` and finally
    ``3 * n_feats`` channels. The first convolution keeps the spatial size;
    the second and third use stride 2, so each maps a side of length L to
    ceil(L / 2).

    Args:
        in_channels: number of channels of the input tensor.
        n_feats: base feature width; the output has ``3 * n_feats`` channels.

    Returns:
        ``nn.Sequential`` holding the three ``nn.Conv2d`` layers at
        indices 0, 1 and 2.
    """
    widths = [in_channels, n_feats, 2 * n_feats, 3 * n_feats]
    strides = (1, 2, 2)
    # Layers are created in order, so indices (and state_dict keys) are 0..2.
    layers = [
        nn.Conv2d(c_in, c_out, 5, stride=step, padding=2)
        for c_in, c_out, step in zip(widths[:-1], widths[1:], strides)
    ]
    # NOTE(review): there is no nonlinearity between the convolutions, so the
    # stack composes linear maps only — confirm this is intended.
    return nn.Sequential(*layers)
|
|
|
def decoder(out_channels, n_feats):
    """Build the RGB / IR / Depth decoder.

    Two transposed convolutions narrow the features from ``3 * n_feats`` to
    ``2 * n_feats`` and then ``n_feats`` channels while doubling each spatial
    side. A final 5x5 convolution (stride 1, padding 2) projects to
    ``out_channels`` without changing the spatial size.

    Args:
        out_channels: number of channels of the produced output.
        n_feats: base feature width; the input must have ``3 * n_feats``
            channels.

    Returns:
        ``nn.Sequential`` of (ConvTranspose2d, ConvTranspose2d, Conv2d) at
        indices 0, 1 and 2.
    """
    def _upsample(c_in, c_out):
        # kernel 3, stride 2, padding 1, output_padding 1 gives
        # (L - 1) * 2 - 2 + 3 + 1 = 2 * L, i.e. an exact x2 upsampling.
        return nn.ConvTranspose2d(
            c_in, c_out, 3, stride=2, padding=1, output_padding=1
        )

    first = _upsample(3 * n_feats, 2 * n_feats)
    second = _upsample(2 * n_feats, n_feats)
    head = nn.Conv2d(n_feats, out_channels, 5, stride=1, padding=2)
    return nn.Sequential(first, second, head)
|
|
|
|
|
def ResNet(n_feats, kernel_size, n_blocks, in_channels=None, out_channels=None):
    """Build a sequential ResNet trunk.

    Layout: [optional head conv] -> ``n_blocks`` ResBlocks -> [optional tail conv].

    Args:
        n_feats: channel width passed to every ``ResBlock``.
        kernel_size: kernel size for the optional head/tail ``default_conv``
            layers. The residual blocks are always built with 3.
        n_blocks: number of residual blocks (0 yields no residual blocks).
        in_channels: if not None, prepend
            ``default_conv(in_channels, n_feats, kernel_size)``.
        out_channels: if not None, append
            ``default_conv(n_feats, out_channels, kernel_size)``.

    Returns:
        ``nn.Sequential`` containing the layers above, in order.
    """
    layers = []

    if in_channels is not None:
        layers.append(default_conv(in_channels, n_feats, kernel_size))

    # Build a fresh ResBlock for each position. The previous
    # ``[ResBlock(n_feats, 3)] * n_blocks`` repeated ONE module instance, so
    # every block shared the same weights. state_dict keys are unchanged, so
    # checkpoints saved with the shared version still load (each block then
    # starts from the same values).
    # NOTE(review): ResBlocks use a hard-coded kernel size of 3 instead of
    # ``kernel_size``; kept as-is — confirm whether that is intended.
    layers.extend(ResBlock(n_feats, 3) for _ in range(n_blocks))

    if out_channels is not None:
        layers.append(default_conv(n_feats, out_channels, kernel_size))

    return nn.Sequential(*layers)
|
|
|
|