import torch
import torch.nn as nn


class CNNBlock(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        bn_act: When True, ``forward`` applies batch norm and the leaky ReLU
            after the convolution and the convolution is created without a
            bias. When False, ``forward`` returns the raw convolution output
            and the convolution keeps its bias.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels: int, out_channels: int, bn_act: bool = True, **kwargs) -> None:
        super().__init__()
        # The conv bias is disabled exactly when batch norm is applied.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        # NOTE: bn and leaky are always constructed, even when bn_act is
        # False, so the module's parameter/state_dict layout does not depend
        # on bn_act. They are simply not used by forward() in that case.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv (+ batch norm + leaky ReLU when enabled) to ``x``."""
        if self.use_bn_act:
            return self.leaky(self.bn(self.conv(x)))
        return self.conv(x)


class ResidualBlock(nn.Module):
    """Stack of ``num_repeats`` bottleneck units (1x1 squeeze, 3x3 expand).

    Each unit maps ``channels -> channels // 2 -> channels``, so the channel
    count and spatial size of the input are preserved.

    Args:
        channels: Number of input (and output) channels.
        use_residual: When True each unit's output is added to its input;
            when False the units are simply applied in sequence.
        num_repeats: Number of bottleneck units in the stack.
    """

    def __init__(self, channels: int, use_residual: bool = True, num_repeats: int = 1) -> None:
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_repeats):
            self.layers.append(
                nn.Sequential(
                    CNNBlock(channels, channels // 2, kernel_size=1),
                    CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
                )
            )
        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every unit, with or without skip connections."""
        for layer in self.layers:
            x = x + layer(x) if self.use_residual else layer(x)
        return x


class ScalePrediction(nn.Module):
    """Prediction head that emits ``3 * (num_classes + 5)`` channels.

    The output channels are regrouped into 3 groups of ``num_classes + 5``
    values per spatial location (presumably 3 anchors, each with 4 box
    coordinates + 1 objectness score + class scores -- confirm against the
    loss function).

    Args:
        in_channels: Number of channels of the incoming feature map.
        num_classes: Number of classes predicted per group.
    """

    def __init__(self, in_channels: int, num_classes: int) -> None:
        super().__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            # Final projection: plain convolution, no batch norm/activation.
            CNNBlock(2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1),
        )
        self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return predictions of shape ``(N, 3, H, W, num_classes + 5)``.

        ``x`` must be a 4-D tensor ``(N, C, H, W)``; N, H and W are taken
        from its shape.
        """
        batch, _, height, width = x.shape
        raw = self.pred(x)
        grouped = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        # Move the per-group prediction vector to the last dimension.
        return grouped.permute(0, 1, 3, 4, 2)


class Net(nn.Module):
    """Multi-scale detection network producing three prediction tensors.

    The layer layout resembles a YOLOv3-style detector (presumably; confirm
    against the training code). The network is described by ``self.config``:

    * ``(out_channels, kernel_size, stride)`` tuple -> ``CNNBlock``
    * ``['B', num_repeats]`` list -> ``ResidualBlock``
    * ``'S'`` -> non-residual block, 1x1 channel-halving conv and a
      ``ScalePrediction`` head
    * ``'U'`` -> 2x ``nn.Upsample`` whose output is concatenated with a
      stored route connection in ``forward``

    Args:
        in_channels: Number of channels of the input image. Defaults to 3.
        num_classes: Number of classes each head predicts. Defaults to 12.

    Raises:
        ValueError: If ``in_channels`` or ``num_classes`` is smaller than 1,
            or if ``self.config`` contains an unrecognised entry.
    """

    # Residual stacks with this many repeats have their output saved in
    # forward() and later concatenated after an upsample layer.
    _ROUTE_NUM_REPEATS = 8

    def __init__(self, in_channels: int = 3, num_classes: int = 12) -> None:
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.config = [
            (32, 3, 1),
            (64, 3, 2),
            ['B', 1],
            (128, 3, 2),
            ['B', 2],
            (256, 3, 2),
            ['B', 8],
            (512, 3, 2),
            ['B', 8],
            (1024, 3, 2),
            ['B', 4],
            (512, 1, 1),
            (1024, 3, 1),
            'S',
            (256, 1, 1),
            'U',
            (256, 1, 1),
            (512, 3, 1),
            'S',
            (128, 1, 1),
            'U',
            (128, 1, 1),
            (256, 3, 1),
            'S',
        ]
        self.layers = self._create_conv_layers()

    def forward(self, x: torch.Tensor):
        """Run the network and return one prediction tensor per scale.

        Args:
            x: Input of shape ``(N, in_channels, H, W)``. H and W should be
                multiples of 32 so that each upsampled map has the same
                spatial size as the route connection it is concatenated with.

        Returns:
            A list with one tensor per ``ScalePrediction`` head, in the order
            the heads appear in ``self.layers`` (coarsest grid first). Each
            tensor has shape ``(N, 3, H_s, W_s, num_classes + 5)``.
        """
        outputs = []  # one entry per prediction head
        route_connections = []  # saved feature maps awaiting concatenation
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off: their output is collected and x is left
                # untouched for the layers that follow.
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == self._ROUTE_NUM_REPEATS:
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Merge with the most recently stored route along channels.
                x = torch.cat([x, route_connections.pop()], dim=1)
        return outputs

    def _create_conv_layers(self) -> nn.ModuleList:
        """Build the ``nn.ModuleList`` described by ``self.config``.

        Tracks the running channel count so every layer is created with the
        number of input channels it will actually receive in ``forward``.
        """
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in self.config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels
            elif isinstance(module, list):
                layers.append(ResidualBlock(in_channels, num_repeats=module[1]))
            elif module == 'S':
                layers.extend(
                    [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                )
                # Subsequent layers consume the output of the 1x1 conv; the
                # head itself does not modify x in forward().
                in_channels = in_channels // 2
            elif module == 'U':
                layers.append(nn.Upsample(scale_factor=2))
                # After upsampling, forward() concatenates a route connection
                # carrying twice as many channels as x, hence the factor 3.
                in_channels = in_channels * 3
            else:
                raise ValueError(f"Unrecognised config entry: {module!r}")

        return layers