#!/usr/bin/env python3
"""
Lightning Module for YOLOv3
"""
# Third-Party Imports
import torch
import torch.optim as optim
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.utilities.memory import garbage_collection_cuda
from torch.utils.data import DataLoader
from torchmetrics import MeanMetric
# Local Imports
import config
from model import ScalePrediction, ResidualBlock, CNNBlock, model_config
from loss import YoloLoss
from dataset import YOLODataset


class YOLOv3(pl.LightningModule):
"""
PyTorch Lightning Code for YOLOv3
"""
def __init__(self, in_channels=3, num_classes=20):
"""
Constructor
"""
super().__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.layers = self._create_conv_layers()
self._learning_rate = 0.03
self.loss_fn = YoloLoss()
        self.scaled_anchors = config.SCALED_ANCHORS  # anchors rescaled to each prediction-grid size
self.combined_data = None
self.train_dataset = None
self.test_dataset = None
self.val_dataset = None
self._data_directory = None
self.epochs = config.NUM_EPOCHS
self.batch_size = config.BATCH_SIZE
        self.enable_gc = "batch"  # run CUDA garbage collection per 'batch' or per 'epoch'
self.my_train_loss = MeanMetric()
self.my_val_loss = MeanMetric()
    def forward(self, x):
        """
        Forward pass through the network
        :param x: Batch of input images
        :return: List of prediction tensors, one per detection scale (coarsest grid first)
        """
        outputs = []  # predictions for each scale
route_connections = []
for layer in self.layers:
if isinstance(layer, ScalePrediction):
outputs.append(layer(x))
continue
            x = layer(x)
            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                # Save feature maps as route connections for the later detection heads
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Fuse upsampled features with the most recent route connection
                x = torch.cat([x, route_connections[-1]], dim=1)
                route_connections.pop()
return outputs
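    # For a 416x416 input, the three entries of `outputs` correspond to the
    # 13x13, 26x26 and 52x52 grids (strides 32, 16 and 8); see ScalePrediction
    # in model.py for the exact per-scale tensor shape.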
    def _create_conv_layers(self):
        """
        Build the backbone and detection heads from `model_config`, where a
        tuple is (out_channels, kernel_size, stride) for a conv block, a list
        is a residual block repeated module[1] times, "S" adds a scale
        prediction head and "U" adds 2x upsampling
        """
        layers = nn.ModuleList()
in_channels = self.in_channels
for module in model_config:
if isinstance(module, tuple):
out_channels, kernel_size, stride = module
layers.append(
CNNBlock(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=1 if kernel_size == 3 else 0,
)
)
in_channels = out_channels
elif isinstance(module, list):
num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))
elif isinstance(module, str):
if module == "S":
layers += [
ResidualBlock(in_channels, use_residual=False, num_repeats=1),
CNNBlock(in_channels, in_channels // 2, kernel_size=1),
ScalePrediction(in_channels // 2, num_classes=self.num_classes),
]
in_channels = in_channels // 2
                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2))
                    # Upsampled features get concatenated with a route connection,
                    # tripling the channel count
                    in_channels = in_channels * 3
return layers
# ##################################################################################################
# ############################## Training Configuration Related Hooks ##############################
# ##################################################################################################
def configure_optimizers(self):
"""
Method to configure the optimizer and learning rate scheduler
"""
optimizer = optim.Adam(self.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)
# Scheduler
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
max_lr=self._learning_rate,
steps_per_epoch=len(self.train_dataloader()),
epochs=self.epochs,
pct_start=0.3,
div_factor=100,
three_phase=False,
final_div_factor=100,
anneal_strategy="linear"
)
return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
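    # Note: OneCycleLR overrides the optimizer's initial LR, warming up from
    # max_lr / div_factor to self._learning_rate (adjustable via the property
    # below); 'interval': 'step' makes Lightning step the scheduler every batch.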
@property
def learning_rate(self) -> float:
"""
Method to get the learning rate value
"""
return self._learning_rate
@learning_rate.setter
def learning_rate(self, value: float):
"""
Method to set the learning rate value
:param value: Updated value of learning rate
"""
self._learning_rate = value
@property
def data_directory(self) -> str:
"""
Method to return data directory
"""
return self._data_directory
@data_directory.setter
def data_directory(self, address: str):
"""
Method to set the data directory path
"""
self._data_directory = address
def set_training_config(self, *, epochs, batch_size):
"""
Method to set parameters required for model training
:param epochs: Number of epochs for which model is to be trained
:param batch_size: Batch Size
"""
self.epochs = epochs
self.batch_size = batch_size
# #################################################################################################
# ################################## Training Loop Related Hooks ##################################
# #################################################################################################
def training_step(self, train_batch, batch_index):
"""
Method called on training dataset to train the model
:param train_batch: Batch containing images and labels
:param batch_index: Index of the batch
"""
x, y = train_batch
logits = self.forward(x)
        # Total loss is the sum of the YOLO loss at each of the three scales
        loss = (
self.loss_fn(logits[0], y[0], self.scaled_anchors[0])
+ self.loss_fn(logits[1], y[1], self.scaled_anchors[1])
+ self.loss_fn(logits[2], y[2], self.scaled_anchors[2])
)
self.my_train_loss.update(loss, x.shape[0])
self.log("train_loss", loss, prog_bar=True, on_epoch=True)
del x, y, logits
return loss
    def validation_step(self, val_batch, batch_index):
        """
        Method called on validation dataset to evaluate the model
        :param val_batch: Batch containing images and labels
        :param batch_index: Index of the batch
        """
        x, y = val_batch
logits = self.forward(x)
loss = (
self.loss_fn(logits[0], y[0], self.scaled_anchors[0])
+ self.loss_fn(logits[1], y[1], self.scaled_anchors[1])
+ self.loss_fn(logits[2], y[2], self.scaled_anchors[2])
)
self.my_val_loss.update(loss, x.shape[0])
self.log("val_loss", loss, prog_bar=True, on_epoch=True)
del x, y, logits
return loss
    def test_step(self, batch, batch_idx):
        """
        Method called on test dataset; reuses the validation logic
        :param batch: Batch containing images and labels
        :param batch_idx: Index of the batch
        """
return self.validation_step(batch, batch_idx)
# ###########################################################################################
# ##################################### Data Related Hooks ##################################
# ###########################################################################################
def prepare_data(self):
"""
Method to download the dataset
"""
# Since data is already downloaded
pass
    def setup(self, stage=None):
        """
        Method to split the dataset into train, test and validation sets
        :param stage: Stage for which setup is run ('fit', 'validate', 'test' or 'predict')
        """
train_csv_path = self._data_directory + config.DATASET + "/train.csv"
test_csv_path = self._data_directory + config.DATASET + "/test.csv"
IMG_DIR = self._data_directory + config.IMG_DIR
LABEL_DIR = self._data_directory + config.LABEL_DIR
IMAGE_SIZE = config.IMAGE_SIZE
# Assign train/val datasets for use in dataloaders
self.train_dataset = YOLODataset(
train_csv_path,
transform=config.train_transforms,
S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
img_dir=IMG_DIR,
label_dir=LABEL_DIR,
anchors=config.ANCHORS,
            mosaic=0.75  # fraction of training samples drawn with mosaic augmentation
)
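        # NOTE: validation reuses train.csv, but with deterministic test-time
        # transforms and without mosaic augmentation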
self.val_dataset = YOLODataset(
train_csv_path,
transform=config.test_transforms,
S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
img_dir=IMG_DIR,
label_dir=LABEL_DIR,
anchors=config.ANCHORS,
)
# Assign test dataset for use in dataloader(s)
self.test_dataset = YOLODataset(
test_csv_path,
transform=config.test_transforms,
S=[IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8],
img_dir=IMG_DIR,
label_dir=LABEL_DIR,
anchors=config.ANCHORS,
)
def train_dataloader(self):
"""
Method to return the DataLoader for Training set
"""
if not self.train_dataset:
self.setup()
return DataLoader(
dataset=self.train_dataset,
batch_size=self.batch_size,
num_workers=config.NUM_WORKERS,
pin_memory=config.PIN_MEMORY,
shuffle=True,
drop_last=False,
)
def val_dataloader(self):
"""
Method to return the DataLoader for the Validation set
"""
return DataLoader(
dataset=self.val_dataset,
batch_size=self.batch_size,
num_workers=config.NUM_WORKERS,
pin_memory=config.PIN_MEMORY,
shuffle=False,
drop_last=False,
)
def test_dataloader(self):
"""
Method to return the DataLoader for the Test set
"""
return DataLoader(
dataset=self.test_dataset,
batch_size=self.batch_size,
num_workers=config.NUM_WORKERS,
pin_memory=config.PIN_MEMORY,
shuffle=False,
drop_last=False,
)
# ###########################################################################################
# ##################################### Memory Related Hooks ################################
# ###########################################################################################
    def on_train_batch_end(self, outputs, batch, batch_idx):
        """
        Garbage collection after each training batch for memory optimization
        :param outputs: Outputs of the training step
        :param batch: Current batch
        :param batch_idx: Index of the current batch
        """
if self.enable_gc == 'batch':
garbage_collection_cuda()
    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Garbage collection after each validation batch for memory optimization
        :param outputs: Outputs of the validation step
        :param batch: Current batch
        :param batch_idx: Index of the current batch
        :param dataloader_idx: Index of the dataloader
        """
if self.enable_gc == 'batch':
garbage_collection_cuda()
    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """
        Garbage collection after each prediction batch for memory optimization
        :param outputs: Outputs of the prediction step
        :param batch: Current batch
        :param batch_idx: Index of the current batch
        :param dataloader_idx: Index of the dataloader
        """
if self.enable_gc == 'batch':
garbage_collection_cuda()
    def on_train_epoch_end(self):
        """
        Per-epoch garbage collection and training loss reporting
        """
if self.enable_gc == 'epoch':
garbage_collection_cuda()
print(
f"Epoch: {self.current_epoch}, Global Steps: {self.global_step}, Train Loss: {self.my_train_loss.compute()}")
self.my_train_loss.reset()
    def on_validation_epoch_end(self):
        """
        Per-epoch garbage collection and validation loss reporting
        """
if self.enable_gc == 'epoch':
garbage_collection_cuda()
print(f"Epoch: {self.current_epoch}, Global Steps: {self.global_step}, Val Loss: {self.my_val_loss.compute()}")
self.my_val_loss.reset()
def on_predict_epoch_end(self):
"""
Garbage Collection for memory optimization
"""
if self.enable_gc == 'epoch':
garbage_collection_cuda()
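

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): it assumes the dataset referenced
    # by config.DATASET, config.IMG_DIR and config.LABEL_DIR lives under the
    # hypothetical "./data/" root; adjust paths and hyperparameters to your setup.
    model = YOLOv3(in_channels=3, num_classes=20)
    model.data_directory = "./data/"
    model.set_training_config(epochs=config.NUM_EPOCHS, batch_size=config.BATCH_SIZE)

    trainer = pl.Trainer(
        max_epochs=model.epochs,
        accelerator="auto",  # use GPU when available, else CPU
    )
    trainer.fit(model)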