digit-draw-detect / src /model_architecture.py
Andrey
Update code following deepsource checks.
99d3d67
raw
history blame
4.64 kB
import torch
import torch.nn as nn
class CNNBlock(nn.Module):
    """2-D convolution, optionally followed by batch norm and LeakyReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by the convolution.
        bn_act: When true, the convolution output is passed through
            ``BatchNorm2d`` and ``LeakyReLU(0.1)`` and the convolution is
            created without a bias term. When false, only the (biased)
            convolution is applied.
        **kwargs: Passed straight through to ``nn.Conv2d`` (for example
            ``kernel_size``, ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        # The convolution carries a bias only when no batch norm follows it.
        conv_has_bias = not bn_act
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            bias=conv_has_bias,
            **kwargs,
        )
        # Both modules are registered regardless of ``bn_act``; forward()
        # decides whether they are actually applied.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        """Apply the convolution and, if enabled, batch norm + LeakyReLU."""
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        normalised = self.bn(features)
        return self.leaky(normalised)
class ResidualBlock(nn.Module):
    """Stack of bottleneck stages, each optionally wrapped in a skip connection.

    Every stage is a 1x1 ``CNNBlock`` that halves the channel count followed
    by a 3x3 ``CNNBlock`` (padding 1) that restores it, so the channel count
    of the output equals that of the input.

    Args:
        channels: Channel count of the input (and output) feature map.
        use_residual: When true, each stage's output is added to its input;
            when false, the stage output replaces the input.
        num_repeats: Number of stages to stack.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        squeezed = channels // 2
        stages = [
            nn.Sequential(
                CNNBlock(channels, squeezed, kernel_size=1),
                CNNBlock(squeezed, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        ]
        self.layers = nn.ModuleList(stages)
        self.use_residual = use_residual
        # Kept as an attribute so callers can inspect how deep the block is.
        self.num_repeats = num_repeats

    def forward(self, x):
        """Run ``x`` through every stage in order and return the result."""
        for stage in self.layers:
            transformed = stage(x)
            x = x + transformed if self.use_residual else transformed
        return x
class ScalePrediction(nn.Module):
    """Prediction head that turns a feature map into per-cell outputs.

    The head is a 3x3 ``CNNBlock`` that doubles the channel count followed by
    a plain 1x1 convolution (no batch norm / activation) producing
    ``(num_classes + 5) * 3`` channels.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of classes; each of the 3 groups in the output
            carries ``num_classes + 5`` values.
            NOTE(review): the 3 groups are presumably anchors and the extra
            5 values box coordinates plus objectness - confirm.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        hidden_channels = 2 * in_channels
        out_channels = (num_classes + 5) * 3
        self.pred = nn.Sequential(
            CNNBlock(in_channels, hidden_channels, kernel_size=3, padding=1),
            CNNBlock(hidden_channels, out_channels, bn_act=False, kernel_size=1),
        )
        self.num_classes = num_classes

    def forward(self, x):
        """Return predictions shaped ``(N, 3, H, W, num_classes + 5)``.

        ``N``, ``H`` and ``W`` are taken from dimensions 0, 2 and 3 of ``x``.
        """
        raw = self.pred(x)
        batch = x.shape[0]
        height = x.shape[2]
        width = x.shape[3]
        grouped = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        # Move the per-group value axis to the end.
        return grouped.permute(0, 1, 3, 4, 2)
class Net(nn.Module):
    """Convolutional network with three prediction heads.

    The architecture is described by ``self.config``, a list whose entries
    are interpreted by ``_create_conv_layers``:

    * ``(out_channels, kernel_size, stride)`` tuple - one ``CNNBlock``
      (padding 1 for 3x3 kernels, 0 otherwise).
    * ``['B', num_repeats]`` list - one ``ResidualBlock`` with that many
      repeats.
    * ``'S'`` - a prediction branch: a non-residual ``ResidualBlock``, a 1x1
      ``CNNBlock`` halving the channels, and a ``ScalePrediction`` head.
    * ``'U'`` - ``nn.Upsample(scale_factor=2)``; in ``forward`` its output is
      concatenated with the most recently stored route tensor.

    Args:
        in_channels: Channel count of the input images. Defaults to 3.
        num_classes: Number of classes passed to every ``ScalePrediction``
            head. Defaults to 12.
    """

    def __init__(self, in_channels=3, num_classes=12):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.config = [
            (32, 3, 1),
            (64, 3, 2),
            ['B', 1],
            (128, 3, 2),
            ['B', 2],
            (256, 3, 2),
            ['B', 8],
            (512, 3, 2),
            ['B', 8],
            (1024, 3, 2),
            ['B', 4],
            (512, 1, 1),
            (1024, 3, 1),
            'S',
            (256, 1, 1),
            'U',
            (256, 1, 1),
            (512, 3, 1),
            'S',
            (128, 1, 1),
            'U',
            (128, 1, 1),
            (256, 3, 1),
            'S',
        ]
        self.layers = self._create_conv_layers()

    def forward(self, x):
        """Run the network and collect the output of every prediction head.

        Args:
            x: Input tensor fed to the first layer.

        Returns:
            A list with one tensor per ``ScalePrediction`` layer, in the
            order the heads appear in ``self.layers``.
        """
        outputs = []  # one entry per prediction head
        route_connections = []  # feature maps saved for later concatenation

        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off: their result is collected and ``x``
                # continues through the trunk unchanged.
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                # Outputs of the 8-repeat residual blocks are the ones that
                # get merged back in after upsampling.
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Merge with the most recently stored route (LIFO) along the
                # channel dimension.
                x = torch.cat([x, route_connections.pop()], dim=1)

        return outputs

    def _create_conv_layers(self):
        """Translate ``self.config`` into an ``nn.ModuleList``.

        Returns:
            The modules in execution order.

        Raises:
            ValueError: If ``self.config`` contains an entry that is not a
                tuple, a list, ``'S'`` or ``'U'``.
        """
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in self.config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels
            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))
            elif module == 'S':
                layers += [
                    ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                    CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                    ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                ]
                # The trunk continues from the 1x1 block, not from the head.
                in_channels = in_channels // 2
            elif module == 'U':
                layers.append(nn.Upsample(scale_factor=2))
                # forward() concatenates a route tensor after upsampling. The
                # factor 3 relies on that route having twice as many channels
                # as the upsampled tensor, which holds for the built-in
                # config (256 + 512 and 128 + 256).
                in_channels = in_channels * 3
            else:
                # Fail loudly instead of silently building a mis-wired model.
                raise ValueError(f'Unknown config entry: {module!r}')

        return layers