test2 / tests /test_models /test_unet.py
mccaly's picture
Upload 660 files
b13b124
raw
history blame
No virus
30.5 kB
import pytest
import torch
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from torch import nn
from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule,
InterpConv, UNet, UpConvBlock)
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_unet_basic_conv_block():
with pytest.raises(AssertionError):
# Not implemented yet.
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
BasicConvBlock(64, 64, dcn=dcn)
with pytest.raises(AssertionError):
# Not implemented yet.
plugins = [
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv3')
]
BasicConvBlock(64, 64, plugins=plugins)
with pytest.raises(AssertionError):
# Not implemented yet
plugins = [
dict(
cfg=dict(
type='GeneralizedAttention',
spatial_range=-1,
num_heads=8,
attention_type='0010',
kv_stride=2),
position='after_conv2')
]
BasicConvBlock(64, 64, plugins=plugins)
# test BasicConvBlock with checkpoint forward
block = BasicConvBlock(16, 16, with_cp=True)
assert block.with_cp
x = torch.randn(1, 16, 64, 64, requires_grad=True)
x_out = block(x)
assert x_out.shape == torch.Size([1, 16, 64, 64])
block = BasicConvBlock(16, 16, with_cp=False)
assert not block.with_cp
x = torch.randn(1, 16, 64, 64)
x_out = block(x)
assert x_out.shape == torch.Size([1, 16, 64, 64])
# test BasicConvBlock with stride convolution to downsample
block = BasicConvBlock(16, 16, stride=2)
x = torch.randn(1, 16, 64, 64)
x_out = block(x)
assert x_out.shape == torch.Size([1, 16, 32, 32])
# test BasicConvBlock structure and forward
block = BasicConvBlock(16, 64, num_convs=3, dilation=3)
assert block.convs[0].conv.in_channels == 16
assert block.convs[0].conv.out_channels == 64
assert block.convs[0].conv.kernel_size == (3, 3)
assert block.convs[0].conv.dilation == (1, 1)
assert block.convs[0].conv.padding == (1, 1)
assert block.convs[1].conv.in_channels == 64
assert block.convs[1].conv.out_channels == 64
assert block.convs[1].conv.kernel_size == (3, 3)
assert block.convs[1].conv.dilation == (3, 3)
assert block.convs[1].conv.padding == (3, 3)
assert block.convs[2].conv.in_channels == 64
assert block.convs[2].conv.out_channels == 64
assert block.convs[2].conv.kernel_size == (3, 3)
assert block.convs[2].conv.dilation == (3, 3)
assert block.convs[2].conv.padding == (3, 3)
def test_deconv_module():
with pytest.raises(AssertionError):
# kernel_size should be greater than or equal to scale_factor and
# (kernel_size - scale_factor) should be even numbers
DeconvModule(64, 32, kernel_size=1, scale_factor=2)
with pytest.raises(AssertionError):
# kernel_size should be greater than or equal to scale_factor and
# (kernel_size - scale_factor) should be even numbers
DeconvModule(64, 32, kernel_size=3, scale_factor=2)
with pytest.raises(AssertionError):
# kernel_size should be greater than or equal to scale_factor and
# (kernel_size - scale_factor) should be even numbers
DeconvModule(64, 32, kernel_size=5, scale_factor=4)
# test DeconvModule with checkpoint forward and upsample 2X.
block = DeconvModule(64, 32, with_cp=True)
assert block.with_cp
x = torch.randn(1, 64, 128, 128, requires_grad=True)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
block = DeconvModule(64, 32, with_cp=False)
assert not block.with_cp
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test DeconvModule with different kernel size for upsample 2X.
x = torch.randn(1, 64, 64, 64)
block = DeconvModule(64, 32, kernel_size=2, scale_factor=2)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 128, 128])
block = DeconvModule(64, 32, kernel_size=6, scale_factor=2)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 128, 128])
# test DeconvModule with different kernel size for upsample 4X.
x = torch.randn(1, 64, 64, 64)
block = DeconvModule(64, 32, kernel_size=4, scale_factor=4)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
block = DeconvModule(64, 32, kernel_size=6, scale_factor=4)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
def test_interp_conv():
# test InterpConv with checkpoint forward and upsample 2X.
block = InterpConv(64, 32, with_cp=True)
assert block.with_cp
x = torch.randn(1, 64, 128, 128, requires_grad=True)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
block = InterpConv(64, 32, with_cp=False)
assert not block.with_cp
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test InterpConv with conv_first=False for upsample 2X.
block = InterpConv(64, 32, conv_first=False)
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test InterpConv with conv_first=True for upsample 2X.
block = InterpConv(64, 32, conv_first=True)
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], ConvModule)
assert isinstance(block.interp_upsample[1], nn.Upsample)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test InterpConv with bilinear upsample for upsample 2X.
block = InterpConv(
64,
32,
conv_first=False,
upsampe_cfg=dict(scale_factor=2, mode='bilinear', align_corners=False))
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'bilinear'
# test InterpConv with nearest upsample for upsample 2X.
block = InterpConv(
64,
32,
conv_first=False,
upsampe_cfg=dict(scale_factor=2, mode='nearest'))
x = torch.randn(1, 64, 128, 128)
x_out = block(x)
assert isinstance(block.interp_upsample[0], nn.Upsample)
assert isinstance(block.interp_upsample[1], ConvModule)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.interp_upsample[0].mode == 'nearest'
def test_up_conv_block():
with pytest.raises(AssertionError):
# Not implemented yet.
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
UpConvBlock(BasicConvBlock, 64, 32, 32, dcn=dcn)
with pytest.raises(AssertionError):
# Not implemented yet.
plugins = [
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv3')
]
UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)
with pytest.raises(AssertionError):
# Not implemented yet
plugins = [
dict(
cfg=dict(
type='GeneralizedAttention',
spatial_range=-1,
num_heads=8,
attention_type='0010',
kv_stride=2),
position='after_conv2')
]
UpConvBlock(BasicConvBlock, 64, 32, 32, plugins=plugins)
# test UpConvBlock with checkpoint forward and upsample 2X.
block = UpConvBlock(BasicConvBlock, 64, 32, 32, with_cp=True)
skip_x = torch.randn(1, 32, 256, 256, requires_grad=True)
x = torch.randn(1, 64, 128, 128, requires_grad=True)
x_out = block(skip_x, x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test UpConvBlock with upsample=True for upsample 2X. The spatial size of
# skip_x is 2X larger than x.
block = UpConvBlock(
BasicConvBlock, 64, 32, 32, upsample_cfg=dict(type='InterpConv'))
skip_x = torch.randn(1, 32, 256, 256)
x = torch.randn(1, 64, 128, 128)
x_out = block(skip_x, x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test UpConvBlock with upsample=False for upsample 2X. The spatial size of
# skip_x is the same as that of x.
block = UpConvBlock(BasicConvBlock, 64, 32, 32, upsample_cfg=None)
skip_x = torch.randn(1, 32, 256, 256)
x = torch.randn(1, 64, 256, 256)
x_out = block(skip_x, x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test UpConvBlock with different upsample method for upsample 2X.
# The upsample method is interpolation upsample (bilinear or nearest).
block = UpConvBlock(
BasicConvBlock,
64,
32,
32,
upsample_cfg=dict(
type='InterpConv',
upsampe_cfg=dict(
scale_factor=2, mode='bilinear', align_corners=False)))
skip_x = torch.randn(1, 32, 256, 256)
x = torch.randn(1, 64, 128, 128)
x_out = block(skip_x, x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test UpConvBlock with different upsample method for upsample 2X.
# The upsample method is deconvolution upsample.
block = UpConvBlock(
BasicConvBlock,
64,
32,
32,
upsample_cfg=dict(type='DeconvModule', kernel_size=4, scale_factor=2))
skip_x = torch.randn(1, 32, 256, 256)
x = torch.randn(1, 64, 128, 128)
x_out = block(skip_x, x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
# test BasicConvBlock structure and forward
block = UpConvBlock(
conv_block=BasicConvBlock,
in_channels=64,
skip_channels=32,
out_channels=32,
num_convs=3,
dilation=3,
upsample_cfg=dict(
type='InterpConv',
upsampe_cfg=dict(
scale_factor=2, mode='bilinear', align_corners=False)))
skip_x = torch.randn(1, 32, 256, 256)
x = torch.randn(1, 64, 128, 128)
x_out = block(skip_x, x)
assert x_out.shape == torch.Size([1, 32, 256, 256])
assert block.conv_block.convs[0].conv.in_channels == 64
assert block.conv_block.convs[0].conv.out_channels == 32
assert block.conv_block.convs[0].conv.kernel_size == (3, 3)
assert block.conv_block.convs[0].conv.dilation == (1, 1)
assert block.conv_block.convs[0].conv.padding == (1, 1)
assert block.conv_block.convs[1].conv.in_channels == 32
assert block.conv_block.convs[1].conv.out_channels == 32
assert block.conv_block.convs[1].conv.kernel_size == (3, 3)
assert block.conv_block.convs[1].conv.dilation == (3, 3)
assert block.conv_block.convs[1].conv.padding == (3, 3)
assert block.conv_block.convs[2].conv.in_channels == 32
assert block.conv_block.convs[2].conv.out_channels == 32
assert block.conv_block.convs[2].conv.kernel_size == (3, 3)
assert block.conv_block.convs[2].conv.dilation == (3, 3)
assert block.conv_block.convs[2].conv.padding == (3, 3)
assert block.upsample.interp_upsample[1].conv.in_channels == 64
assert block.upsample.interp_upsample[1].conv.out_channels == 32
assert block.upsample.interp_upsample[1].conv.kernel_size == (1, 1)
assert block.upsample.interp_upsample[1].conv.dilation == (1, 1)
assert block.upsample.interp_upsample[1].conv.padding == (0, 0)
def test_unet():
with pytest.raises(AssertionError):
# Not implemented yet.
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
UNet(3, 64, 5, dcn=dcn)
with pytest.raises(AssertionError):
# Not implemented yet.
plugins = [
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv3')
]
UNet(3, 64, 5, plugins=plugins)
with pytest.raises(AssertionError):
# Not implemented yet
plugins = [
dict(
cfg=dict(
type='GeneralizedAttention',
spatial_range=-1,
num_heads=8,
attention_type='0010',
kv_stride=2),
position='after_conv2')
]
UNet(3, 64, 5, plugins=plugins)
with pytest.raises(AssertionError):
# Check whether the input image size can be devisible by the whole
# downsample rate of the encoder. The whole downsample rate of this
# case is 8.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=4,
strides=(1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2),
dec_num_convs=(2, 2, 2),
downsamples=(True, True, True),
enc_dilations=(1, 1, 1, 1),
dec_dilations=(1, 1, 1))
x = torch.randn(2, 3, 65, 65)
unet(x)
with pytest.raises(AssertionError):
# Check whether the input image size can be devisible by the whole
# downsample rate of the encoder. The whole downsample rate of this
# case is 16.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 65, 65)
unet(x)
with pytest.raises(AssertionError):
# Check whether the input image size can be devisible by the whole
# downsample rate of the encoder. The whole downsample rate of this
# case is 8.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 65, 65)
unet(x)
with pytest.raises(AssertionError):
# Check whether the input image size can be devisible by the whole
# downsample rate of the encoder. The whole downsample rate of this
# case is 8.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 2, 2, 2, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 65, 65)
unet(x)
with pytest.raises(AssertionError):
# Check whether the input image size can be devisible by the whole
# downsample rate of the encoder. The whole downsample rate of this
# case is 32.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=6,
strides=(1, 1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2, 2),
downsamples=(True, True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1, 1))
x = torch.randn(2, 3, 65, 65)
unet(x)
with pytest.raises(AssertionError):
# Check if num_stages matchs strides, len(strides)=num_stages
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 64, 64)
unet(x)
with pytest.raises(AssertionError):
# Check if num_stages matchs strides, len(enc_num_convs)=num_stages
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 64, 64)
unet(x)
with pytest.raises(AssertionError):
# Check if num_stages matchs strides, len(dec_num_convs)=num_stages-1
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 64, 64)
unet(x)
with pytest.raises(AssertionError):
# Check if num_stages matchs strides, len(downsamples)=num_stages-1
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 64, 64)
unet(x)
with pytest.raises(AssertionError):
# Check if num_stages matchs strides, len(enc_dilations)=num_stages
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 64, 64)
unet(x)
with pytest.raises(AssertionError):
# Check if num_stages matchs strides, len(dec_dilations)=num_stages-1
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1, 1))
x = torch.randn(2, 3, 64, 64)
unet(x)
# test UNet norm_eval=True
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1),
norm_eval=True)
unet.train()
assert check_norm_state(unet.modules(), False)
# test UNet norm_eval=False
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1),
norm_eval=False)
unet.train()
assert check_norm_state(unet.modules(), True)
# test UNet forward and outputs. The whole downsample rate is 16.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 8.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 8.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 2, 2, 2, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 4.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 4.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 2, 2, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 8.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 4.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 2.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, False, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 64, 64])
assert x_outs[1].shape == torch.Size([2, 512, 64, 64])
assert x_outs[2].shape == torch.Size([2, 256, 64, 64])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 1.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 1, 1, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(False, False, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 128, 128])
assert x_outs[1].shape == torch.Size([2, 512, 128, 128])
assert x_outs[2].shape == torch.Size([2, 256, 128, 128])
assert x_outs[3].shape == torch.Size([2, 128, 128, 128])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 16.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 2, 2, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, True),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 8, 8])
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 8.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 2, 2, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 8.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 2, 2, 2, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, True, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 16, 16])
assert x_outs[1].shape == torch.Size([2, 512, 16, 16])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet forward and outputs. The whole downsample rate is 4.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 2, 2, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])
# test UNet init_weights method.
unet = UNet(
in_channels=3,
base_channels=64,
num_stages=5,
strides=(1, 2, 2, 1, 1),
enc_num_convs=(2, 2, 2, 2, 2),
dec_num_convs=(2, 2, 2, 2),
downsamples=(True, True, False, False),
enc_dilations=(1, 1, 1, 1, 1),
dec_dilations=(1, 1, 1, 1))
unet.init_weights(pretrained=None)
print(unet)
x = torch.randn(2, 3, 128, 128)
x_outs = unet(x)
assert x_outs[0].shape == torch.Size([2, 1024, 32, 32])
assert x_outs[1].shape == torch.Size([2, 512, 32, 32])
assert x_outs[2].shape == torch.Size([2, 256, 32, 32])
assert x_outs[3].shape == torch.Size([2, 128, 64, 64])
assert x_outs[4].shape == torch.Size([2, 64, 128, 128])