Spaces:
Runtime error
Runtime error
""" | |
UNet Network in PyTorch, modified from https://github.com/milesial/Pytorch-UNet | |
with architecture referenced from https://keras.io/examples/vision/depth_estimation | |
for monocular depth estimation from RGB images, i.e. one output channel. | |
""" | |
import torch | |
from torch import nn | |
class UNet(nn.Module):
    """
    Encoder / bridge / decoder network assembled from the block modules
    defined in this file.

    ``forward`` runs, in order: a 3x3 convolution from 3 to 16 channels,
    four ``DownBlock`` stages (16 -> 32 -> 64 -> 128 -> 256), the
    ``BottleNeckBlock`` bridge at 256 channels, four ``UpBlock`` stages
    (256 -> 128 -> 64 -> 32 -> 16) that each also receive the skip tensor
    produced by the mirrored encoder stage, a 1x1 convolution down to one
    channel, and a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Channel width after the input conv and after each encoder stage.
        widths = (16, 32, 64, 128, 256)
        mirrored = widths[::-1]
        self.downscale_blocks = nn.ModuleList(
            [DownBlock(n_in, n_out) for n_in, n_out in zip(widths, widths[1:])]
        )
        self.upscale_blocks = nn.ModuleList(
            [UpBlock(n_in, n_out) for n_in, n_out in zip(mirrored, mirrored[1:])]
        )
        self.input_conv = nn.Conv2d(3, widths[0], kernel_size=3, padding="same")
        self.output_conv = nn.Conv2d(widths[0], 1, kernel_size=1)
        self.bridge = BottleNeckBlock(widths[-1])
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-activated single-channel output for ``x``."""
        features = self.input_conv(x)

        # Encoder: keep each stage's pre-pooling tensor for the decoder.
        skips = []
        for down in self.downscale_blocks:
            skip, features = down(features)
            skips.append(skip)

        features = self.bridge(features)

        # Decoder: the deepest skip is consumed first, hence the reversal.
        for up, skip in zip(self.upscale_blocks, reversed(skips)):
            features = up(features, skip)

        return self.activation(self.output_conv(features))
class DownBlock(nn.Module):
    """
    Encoder stage: two bias-free 3x3 convolutions, each followed by batch
    normalisation and LeakyReLU(0.2), with the raw output of the first
    convolution added back onto the result, then 2x2 max pooling.

    ``forward`` returns a pair ``(features, pooled)``: the tensor before
    pooling and the max-pooled tensor.
    """

    def __init__(self, in_channels, out_channels, padding="same", stride=1):
        super().__init__()
        # Both convolutions share every setting except their input width.
        conv_options = dict(kernel_size=3, stride=stride, padding=padding, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        """Return ``(features, pooled)`` for the input tensor ``x``."""
        # The residual branch is the first convolution's output taken
        # before normalisation and activation.
        residual = self.conv1(x)
        hidden = self.relu(self.bn1(residual))
        hidden = self.relu(self.bn2(self.conv2(hidden)))
        features = hidden + residual
        return features, self.maxpool(features)
class UpBlock(nn.Module):
    """
    Decoder stage: upsample the input, concatenate it with a skip tensor
    along the channel axis, then apply two bias-free 3x3 convolutions, each
    followed by batch normalisation and LeakyReLU(0.2).

    ``conv1`` expects ``in_channels * 2`` input channels, so the upsampled
    input and the skip tensor together must carry that many channels.
    ``forward`` returns a tensor with ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, padding="same", stride=1):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv1 = nn.Conv2d(
            in_channels * 2,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.conv2 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x, skip):
        """
        Upsample ``x`` to the spatial size of ``skip``, concatenate the two
        on dim 1 and run the convolution stack.
        """
        target_size = skip.shape[-2:]
        upsampled = self.up(x)
        if upsampled.shape[-2:] != target_size:
            # A fixed 2x upsample cannot reproduce an odd-sized skip map
            # (a size-7 map pools to 3, which doubles to 6), and torch.cat
            # would then fail. Resample the original input straight to the
            # skip's size instead. Skips that are exactly 2x the input never
            # reach this branch, so their results are unchanged.
            upsampled = nn.functional.interpolate(
                x, size=target_size, mode="bilinear", align_corners=True
            )
        x = torch.cat([upsampled, skip], dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x
class BottleNeckBlock(nn.Module):
    """
    BottleNeckBlock that serves as the UNet bridge.

    Applies two 3x3 convolutions that keep the channel count at
    ``channels``, each followed by LeakyReLU(0.2). No normalisation is
    applied.

    Args:
        channels: Number of input and output channels of both convolutions.
        padding: Padding passed to both convolutions. Defaults to "same".
        strides: Stride passed to both convolutions. Defaults to 1.
    """

    def __init__(self, channels, padding="same", strides=1):
        super().__init__()
        # Forward the constructor arguments; they were previously accepted
        # but ignored in favour of hard-coded 1 / "same".
        self.conv1 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=strides, padding=padding
        )
        self.conv2 = nn.Conv2d(
            channels, channels, kernel_size=3, stride=strides, padding=padding
        )
        self.relu = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Return ``x`` after both convolution + LeakyReLU steps."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x