| | """ |
| | Implementation of YOLOv3 architecture |
| | """ |
| |
|
| | from typing import Any, Dict |
| | from lightning.pytorch.utilities.types import STEP_OUTPUT |
| | import torch |
| | import torch.nn as nn |
| | import lightning as L |
| |
|
| | import config as config_ |
| | from utils.common import one_cycle_lr |
| | from utils.data import PascalDataModule |
| | from utils.loss import YoloLoss |
| | from utils.utils import ( |
| | mean_average_precision, |
| | cells_to_bboxes, |
| | get_evaluation_bboxes, |
| | save_checkpoint, |
| | load_checkpoint, |
| | check_class_accuracy, |
| | get_loaders, |
| | plot_couple_examples, |
| | ) |
| |
|
| |
|
| | """ |
| | Information about architecture config: |
| | Tuple is structured by (filters, kernel_size, stride) |
| | Every conv is a same convolution. |
| | List is structured by "B" indicating a residual block followed by the number of repeats |
| | "S" is for scale prediction block and computing the yolo loss |
| | "U" is for upsampling the feature map and concatenating with a previous layer |
| | """ |
| | config = [ |
| | (32, 3, 1), |
| | (64, 3, 2), |
| | ["B", 1], |
| | (128, 3, 2), |
| | ["B", 2], |
| | (256, 3, 2), |
| | ["B", 8], |
| | (512, 3, 2), |
| | ["B", 8], |
| | (1024, 3, 2), |
| | ["B", 4], |
| | (512, 1, 1), |
| | (1024, 3, 1), |
| | "S", |
| | (256, 1, 1), |
| | "U", |
| | (256, 1, 1), |
| | (512, 3, 1), |
| | "S", |
| | (128, 1, 1), |
| | "U", |
| | (128, 1, 1), |
| | (256, 3, 1), |
| | "S", |
| | ] |
| |
|
| |
|
class CNNBlock(L.LightningModule):
    """Convolution optionally followed by BatchNorm and LeakyReLU(0.1).

    With ``bn_act=False`` the block degenerates to a bare (biased) convolution,
    which is what the final prediction layer of each scale head needs.
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        # Bias is redundant when BatchNorm immediately follows the conv.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)
        self.use_bn_act = bn_act

    def forward(self, x):
        out = self.conv(x)
        if not self.use_bn_act:
            return out
        return self.leaky(self.bn(out))
| |
|
| |
|
class ResidualBlock(L.LightningModule):
    """``num_repeats`` stacked (1x1 squeeze -> 3x3 expand) pairs.

    When ``use_residual`` is True each pair's output is added back onto its
    input (channel count is unchanged, so the skip connection is always valid).
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(
                CNNBlock(channels, channels // 2, kernel_size=1),
                CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        )
        self.use_residual = use_residual
        # Exposed so YOLOv3.forward can detect the 8-repeat route-connection blocks.
        self.num_repeats = num_repeats

    def forward(self, x):
        for block in self.layers:
            out = block(x)
            x = x + out if self.use_residual else out
        return x
| |
|
| |
|
class ScalePrediction(L.LightningModule):
    """Per-scale detection head.

    Produces, for each of the 3 anchors at every grid cell, 5 box/objectness
    values plus ``num_classes`` class scores. The final conv has no BN or
    activation (``bn_act=False``) so raw logits are emitted.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
            CNNBlock(2 * in_channels, 3 * (num_classes + 5), bn_act=False, kernel_size=1),
        )
        self.num_classes = num_classes

    def forward(self, x):
        batch, _, height, width = x.shape
        raw = self.pred(x)
        raw = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        # -> (batch, anchors=3, H, W, 5 + num_classes)
        return raw.permute(0, 1, 3, 4, 2)
| |
|
| |
|
class YOLOv3(L.LightningModule):
    """YOLOv3 detector (Darknet-53-style backbone + 3-scale heads) as a LightningModule.

    The network is built from the module-level ``config`` spec. ``forward``
    returns a list of 3 tensors, one per scale (stride 32, 16, 8), each shaped
    ``(batch, 3, S, S, num_classes + 5)``.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=80,
        epochs=40,
        loss_fn=YoloLoss,
        datamodule=None,
        learning_rate=None,
        maxlr=None,
        scheduler_steps=None,
        device_count=2,
    ):
        """
        Args:
            in_channels: channels of the input image (3 for RGB).
            num_classes: number of object classes predicted per anchor.
            epochs: total training epochs, forwarded to the LR scheduler.
            loss_fn: loss *class* (not instance); instantiated here.
            datamodule: optional datamodule; defaults to a fresh PascalDataModule.
                (Fix: the previous default ``PascalDataModule()`` was evaluated
                once at class-definition time and shared by every instance —
                the classic mutable-default-argument bug.)
            learning_rate: base LR for Adam.
            maxlr: peak LR for the one-cycle schedule.
            scheduler_steps: steps per epoch for the one-cycle schedule.
            device_count: number of devices (kept for callers; unused here).
        """
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.epochs = epochs
        self.loss_fn = loss_fn()
        self.layers = self._create_conv_layers()
        # Anchors scaled to grid units per scale: (3 scales, 3 anchors, 2).
        # Built on CPU; _common_step moves it to the active device lazily
        # (self.device is still CPU at construction time, so an eager .to()
        # here was a no-op).
        self.scaled_anchors = torch.tensor(config_.ANCHORS) * torch.tensor(
            config_.S
        ).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        self.datamodule = PascalDataModule() if datamodule is None else datamodule
        self.learning_rate = learning_rate
        self.maxlr = maxlr
        self.scheduler_steps = scheduler_steps
        self.device_count = device_count

    def forward(self, x):
        """Run the backbone + heads; return [coarse, medium, fine] predictions."""
        outputs = []
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off; they do not feed the main path.
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                # The two 8-repeat blocks mark the route connections.
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Fuse the upsampled map with the most recent route connection.
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()

        return outputs

    def _create_conv_layers(self):
        """Translate the module-level ``config`` spec into an nn.ModuleList."""
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        # "Same" convolution for 3x3; 1x1 needs no padding.
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(
                    ResidualBlock(
                        in_channels,
                        num_repeats=num_repeats,
                    )
                )

            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2

                elif module == "U":
                    layers.append(
                        nn.Upsample(scale_factor=2),
                    )
                    # After upsampling we concat a route connection with 2x the
                    # channels, hence 3x total.
                    in_channels = in_channels * 3

        return layers

    def configure_optimizers(self) -> Dict:
        """Adam + one-cycle LR schedule stepped per batch."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=config_.WEIGHT_DECAY
        )
        scheduler = one_cycle_lr(
            optimizer=optimizer,
            maxlr=self.maxlr,
            steps=self.scheduler_steps,
            epochs=self.epochs,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def _common_step(self, batch, batch_idx):
        """Shared train/val logic: forward pass + summed 3-scale YOLO loss."""
        # Lazily move anchors to the active device (no-op once already there).
        self.scaled_anchors = self.scaled_anchors.to(self.device)
        x, y = batch
        y0, y1, y2 = y[0], y[1], y[2]
        out = self(x)
        loss = (
            self.loss_fn(out[0], y0, self.scaled_anchors[0])
            + self.loss_fn(out[1], y1, self.scaled_anchors[1])
            + self.loss_fn(out[2], y2, self.scaled_anchors[2])
        )
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        self.log(name="train_loss", value=loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._common_step(batch, batch_idx)
        self.log(name="val_loss", value=loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # NOTE(review): check_class_accuracy iterates the ENTIRE test loader,
        # but Lightning calls test_step once per batch — so the full test set
        # is re-evaluated on every batch (quadratic cost, identical logged
        # values each time). Consider moving this to on_test_epoch_end.
        class_acc, noobj_acc, obj_acc = check_class_accuracy(
            model=self,
            loader=self.datamodule.test_dataloader(),
            threshold=config_.CONF_THRESHOLD,
        )

        self.log_dict(
            {
                "class_acc": class_acc,
                "noobj_acc": noobj_acc,
                "obj_acc": obj_acc,
            },
            prog_bar=True,
        )
| |
|
| |
|
| | if __name__ == "__main__": |
| | num_classes = 20 |
| | IMAGE_SIZE = 416 |
| | model = YOLOv3(num_classes=num_classes) |
| | x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE)) |
| | out = model(x) |
| | assert model(x)[0].shape == ( |
| | 2, |
| | 3, |
| | IMAGE_SIZE // 32, |
| | IMAGE_SIZE // 32, |
| | num_classes + 5, |
| | ) |
| | assert model(x)[1].shape == ( |
| | 2, |
| | 3, |
| | IMAGE_SIZE // 16, |
| | IMAGE_SIZE // 16, |
| | num_classes + 5, |
| | ) |
| | assert model(x)[2].shape == ( |
| | 2, |
| | 3, |
| | IMAGE_SIZE // 8, |
| | IMAGE_SIZE // 8, |
| | num_classes + 5, |
| | ) |
| | print("Success!") |
| |
|