"""
Lightning Module for YOLOv3
"""
|
import torch
import torch.optim as optim
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.utilities.memory import garbage_collection_cuda
from torch.utils.data import DataLoader
from torchmetrics import MeanMetric

import config
from model import ScalePrediction, ResidualBlock, CNNBlock, model_config
from loss import YoloLoss
from dataset import YOLODataset
|


class YOLOv3(pl.LightningModule):
    """
    PyTorch Lightning implementation of YOLOv3
    """

    def __init__(self, in_channels=3, num_classes=20):
        """
        Constructor
        :param in_channels: Number of input image channels
        :param num_classes: Number of object classes to predict
        """
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

        self._learning_rate = 0.03
        self.loss_fn = YoloLoss()
        self.scaled_anchors = config.SCALED_ANCHORS
        self.combined_data = None
        self.train_dataset = None
        self.test_dataset = None
        self.val_dataset = None
        self._data_directory = None
        self.epochs = config.NUM_EPOCHS
        self.batch_size = config.BATCH_SIZE

        # Run CUDA garbage collection after every batch ("batch") or every epoch ("epoch")
        self.enable_gc = "batch"
        self.my_train_loss = MeanMetric()
        self.my_val_loss = MeanMetric()
|
    def forward(self, x):
        """
        Forward pass: returns a list with one prediction tensor per scale
        """
        outputs = []
        route_connections = []
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Detection head: collect the prediction for this scale and move on
                outputs.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                # Save this feature map for a later skip connection
                route_connections.append(x)

            elif isinstance(layer, nn.Upsample):
                # Concatenate the upsampled features with the most recent route connection
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()

        return outputs
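
    # Shape note (an assumption based on the standard YOLOv3 head and
    # config.IMAGE_SIZE == 416): for x of shape (N, 3, 416, 416) the returned
    # list holds tensors of shape (N, 3, 13, 13, num_classes + 5),
    # (N, 3, 26, 26, num_classes + 5) and (N, 3, 52, 52, num_classes + 5),
    # one prediction per anchor per grid cell at strides 32, 16 and 8.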
|

    def _create_conv_layers(self):
        """
        Build the backbone and detection heads from model_config
        """
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in model_config:
            if isinstance(module, tuple):
                # Tuple entry: a single convolution (out_channels, kernel_size, stride)
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                # List entry: a residual block repeated module[1] times
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))

            elif isinstance(module, str):
                if module == "S":
                    # "S": a scale prediction head preceded by a residual and a 1x1 conv
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    in_channels = in_channels // 2

                elif module == "U":
                    # "U": upsample; the route connection concatenated afterwards
                    # triples the channel count
                    layers.append(nn.Upsample(scale_factor=2))
                    in_channels = in_channels * 3

        return layers
|

    def configure_optimizers(self):
        """
        Method to configure the optimizer and learning rate scheduler
        """
        optimizer = optim.Adam(self.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)

        # One-cycle policy: ramp up to max_lr over the first 30% of steps, then
        # anneal linearly; the scheduler is stepped every batch, not every epoch
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self._learning_rate,
            steps_per_epoch=len(self.train_dataloader()),
            epochs=self.epochs,
            pct_start=0.3,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy="linear",
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
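
    # With these settings (assuming max_lr stays at the 0.03 default), the
    # schedule starts at max_lr / div_factor = 3e-4, peaks at 3e-2 after 30%
    # of the steps, then decays linearly to initial_lr / final_div_factor = 3e-6.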
|

    @property
    def learning_rate(self) -> float:
        """
        Method to get the learning rate value
        """
        return self._learning_rate

    @learning_rate.setter
    def learning_rate(self, value: float):
        """
        Method to set the learning rate value
        :param value: Updated value of the learning rate
        """
        self._learning_rate = value

    @property
    def data_directory(self) -> str:
        """
        Method to return the data directory path
        """
        return self._data_directory

    @data_directory.setter
    def data_directory(self, address: str):
        """
        Method to set the data directory path
        :param address: Path to the directory containing the dataset
        """
        self._data_directory = address

    def set_training_config(self, *, epochs, batch_size):
        """
        Method to set parameters required for model training
        :param epochs: Number of epochs for which the model is to be trained
        :param batch_size: Batch size
        """
        self.epochs = epochs
        self.batch_size = batch_size
|

    def training_step(self, train_batch, batch_index):
        """
        Method called on the training dataset to train the model
        :param train_batch: Batch containing images and labels
        :param batch_index: Index of the batch
        """
        x, y = train_batch
        logits = self.forward(x)
        # Total loss is the sum of the YOLO loss at the three prediction scales,
        # each computed against the anchors rescaled for that grid size
        loss = (
            self.loss_fn(logits[0], y[0], self.scaled_anchors[0])
            + self.loss_fn(logits[1], y[1], self.scaled_anchors[1])
            + self.loss_fn(logits[2], y[2], self.scaled_anchors[2])
        )
        self.my_train_loss.update(loss, x.shape[0])
        self.log("train_loss", loss, prog_bar=True, on_epoch=True)

        del x, y, logits
        return loss
|

    def validation_step(self, val_batch, batch_index):
        """
        Method called on the validation dataset to evaluate the model
        :param val_batch: Batch containing images and labels
        :param batch_index: Index of the batch
        """
        x, y = val_batch
        logits = self.forward(x)
        loss = (
            self.loss_fn(logits[0], y[0], self.scaled_anchors[0])
            + self.loss_fn(logits[1], y[1], self.scaled_anchors[1])
            + self.loss_fn(logits[2], y[2], self.scaled_anchors[2])
        )
        self.my_val_loss.update(loss, x.shape[0])
        self.log("val_loss", loss, prog_bar=True, on_epoch=True)

        del x, y, logits
        return loss

    def test_step(self, batch, batch_idx):
        """
        Method called on the test dataset; reuses the validation logic
        """
        return self.validation_step(batch, batch_idx)
|

    def prepare_data(self):
        """
        Method to download the dataset; a no-op here, since the data is
        expected to already be present under the data directory
        """
        pass
|

    def setup(self, stage=None):
        """
        Method to split the dataset into train, val and test sets
        """
        train_csv_path = self._data_directory + config.DATASET + "/train.csv"
        test_csv_path = self._data_directory + config.DATASET + "/test.csv"
        IMG_DIR = self._data_directory + config.IMG_DIR
        LABEL_DIR = self._data_directory + config.LABEL_DIR
        IMAGE_SIZE = config.IMAGE_SIZE

        self.train_dataset = YOLODataset(
            train_csv_path,
            transform=config.train_transforms,
            S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
            img_dir=IMG_DIR,
            label_dir=LABEL_DIR,
            anchors=config.ANCHORS,
            mosaic=0.75,
        )
        # Note: the validation set reuses the training CSV, but with the
        # eval-time transforms and without mosaic augmentation
        self.val_dataset = YOLODataset(
            train_csv_path,
            transform=config.test_transforms,
            S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
            img_dir=IMG_DIR,
            label_dir=LABEL_DIR,
            anchors=config.ANCHORS,
        )

        self.test_dataset = YOLODataset(
            test_csv_path,
            transform=config.test_transforms,
            S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
            img_dir=IMG_DIR,
            label_dir=LABEL_DIR,
            anchors=config.ANCHORS,
        )
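
        # Assuming config.IMAGE_SIZE is 416, S works out to [13, 26, 52]:
        # the target grid sizes for strides 32, 16 and 8.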
|

    def train_dataloader(self):
        """
        Method to return the DataLoader for the training set
        """
        if not self.train_dataset:
            self.setup()
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            shuffle=True,
            drop_last=False,
        )

    def val_dataloader(self):
        """
        Method to return the DataLoader for the validation set
        """
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            shuffle=False,
            drop_last=False,
        )

    def test_dataloader(self):
        """
        Method to return the DataLoader for the test set
        """
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            shuffle=False,
            drop_last=False,
        )
|

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """
        Garbage collection for memory optimization after each training batch
        :param outputs: Outputs of the training step
        :param batch: The current batch
        :param batch_idx: Index of the batch
        """
        if self.enable_gc == "batch":
            garbage_collection_cuda()

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Garbage collection for memory optimization after each validation batch
        :param outputs: Outputs of the validation step
        :param batch: The current batch
        :param batch_idx: Index of the batch
        :param dataloader_idx: Index of the dataloader
        """
        if self.enable_gc == "batch":
            garbage_collection_cuda()

    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Garbage collection for memory optimization after each prediction batch
        :param outputs: Outputs of the prediction step
        :param batch: The current batch
        :param batch_idx: Index of the batch
        :param dataloader_idx: Index of the dataloader
        """
        if self.enable_gc == "batch":
            garbage_collection_cuda()

    def on_train_epoch_end(self):
        """
        Garbage collection for memory optimization at the end of a training epoch
        """
        if self.enable_gc == "epoch":
            garbage_collection_cuda()
        print(
            f"Epoch: {self.current_epoch}, Global Steps: {self.global_step}, "
            f"Train Loss: {self.my_train_loss.compute()}"
        )
        self.my_train_loss.reset()

    def on_validation_epoch_end(self):
        """
        Garbage collection for memory optimization at the end of a validation epoch
        """
        if self.enable_gc == "epoch":
            garbage_collection_cuda()
        print(
            f"Epoch: {self.current_epoch}, Global Steps: {self.global_step}, "
            f"Val Loss: {self.my_val_loss.compute()}"
        )
        self.my_val_loss.reset()

    def on_predict_epoch_end(self):
        """
        Garbage collection for memory optimization at the end of a prediction epoch
        """
        if self.enable_gc == "epoch":
            garbage_collection_cuda()
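

# Minimal training sketch (illustrative, not part of the module): wires this
# LightningModule to a Trainer. The data-directory path below is a hypothetical
# placeholder; everything else comes from config.
if __name__ == "__main__":
    model = YOLOv3(num_classes=20)
    model.data_directory = "./data/"  # hypothetical path to the dataset root
    model.set_training_config(epochs=config.NUM_EPOCHS, batch_size=config.BATCH_SIZE)

    # The Trainer picks up the dataloaders and optimizer defined on the module
    trainer = pl.Trainer(max_epochs=model.epochs)
    trainer.fit(model)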