# yolov3 / model.py
# catchlui's picture
# Upload 4 files
# 5cbcc4c
"""
Implementation of YOLOv3 architecture
"""
import torch
import torch.nn as nn
import config
import torch.optim as optim
import pytorch_lightning as pl
#from model import YOLOv3
from tqdm import tqdm
from utils import (
mean_average_precision,
cells_to_bboxes,
get_evaluation_bboxes,
save_checkpoint,
load_checkpoint,
check_class_accuracy,
get_loaders,
get_loaders_new,
plot_couple_examples
)
from loss import YoloLoss
import warnings
warnings.filterwarnings("ignore")
"""
Information about the architecture config:
- A tuple is (filters, kernel_size, stride) and describes one conv block.
  3x3 convs use padding 1 and 1x1 convs use padding 0.
- A list is ["B", num_repeats]: a residual block repeated num_repeats times.
- "S" marks a scale-prediction block; its output is one of the tensors the
  YOLO loss is computed on.
- "U" upsamples the feature map and concatenates it with the output of a
  previous layer.
"""
# Layer specification consumed by YOLOv3._create_conv_layers.
# Tuples are (filters, kernel_size, stride); ["B", n] is a residual block with
# n repeats; "S" is a scale-prediction head; "U" is an upsample + concat.
config_layers = [
    # Stem.
    (32, 3, 1),
    # Each stride-2 conv is followed by a stack of residual blocks.
    (64, 3, 2), ["B", 1],
    (128, 3, 2), ["B", 2],
    (256, 3, 2), ["B", 8],
    (512, 3, 2), ["B", 8],
    (1024, 3, 2), ["B", 4],  # To this point is Darknet-53
    # First prediction head.
    (512, 1, 1), (1024, 3, 1), "S",
    # Second prediction head, after upsampling.
    (256, 1, 1), "U",
    (256, 1, 1), (512, 3, 1), "S",
    # Third prediction head, after upsampling again.
    (128, 1, 1), "U",
    (128, 1, 1), (256, 3, 1), "S",
]
class CNNBlock(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        in_channels: Number of input channels for the convolution.
        out_channels: Number of output channels for the convolution.
        bn_act: When True, apply batch norm and leaky ReLU after the
            convolution and create the convolution without a bias term.
            When False, return the raw convolution output (with bias).
        **kwargs: Forwarded to ``nn.Conv2d`` (kernel_size, stride, padding...).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.use_bn_act = bn_act
        # The conv bias is only created when batch norm is skipped.
        conv_bias = not bn_act
        self.conv = nn.Conv2d(in_channels, out_channels, bias=conv_bias, **kwargs)
        # bn and leaky are always registered, even if forward() never uses
        # them, so state_dict keys do not depend on bn_act.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply the convolution and, if enabled, batch norm + leaky ReLU."""
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.leaky(self.bn(features))
class ResidualBlock(nn.Module):
    """Stack of (1x1 conv -> 3x3 conv) units, optionally with skip connections.

    Args:
        channels: Input and output channel count; the 1x1 conv halves it and
            the 3x3 conv restores it.
        use_residual: When True, each unit's output is added to its input.
        num_repeats: Number of units in the stack.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.use_residual = use_residual
        self.num_repeats = num_repeats
        half = channels // 2
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    CNNBlock(channels, half, kernel_size=1),
                    CNNBlock(half, channels, kernel_size=3, padding=1),
                )
                for _ in range(num_repeats)
            ]
        )

    def forward(self, x):
        """Run every unit in order, adding the skip path when enabled."""
        for unit in self.layers:
            branch = unit(x)
            x = x + branch if self.use_residual else branch
        return x
class ScalePrediction(nn.Module):
    """Prediction head for one scale.

    A 3x3 conv block doubles the channels, then a plain 1x1 conv (no batch
    norm / activation) produces ``3 * (num_classes + 5)`` channels.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of object classes.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.num_classes = num_classes
        hidden = 2 * in_channels
        out_channels = (num_classes + 5) * 3
        self.pred = nn.Sequential(
            CNNBlock(in_channels, hidden, kernel_size=3, padding=1),
            CNNBlock(hidden, out_channels, bn_act=False, kernel_size=1),
        )

    def forward(self, x):
        """Return predictions shaped (N, 3, H, W, num_classes + 5)."""
        batch_size = x.shape[0]
        grid_h = x.shape[2]
        grid_w = x.shape[3]
        raw = self.pred(x)
        # Split the channel axis into 3 groups of (num_classes + 5) values,
        # then move that per-group vector to the last axis.
        grouped = raw.reshape(batch_size, 3, self.num_classes + 5, grid_h, grid_w)
        return grouped.permute(0, 1, 3, 4, 2)
class YOLOv3(nn.Module):
    """YOLOv3 network assembled from the module-level ``config_layers`` list.

    Args:
        in_channels: Channel count of the input images.
        num_classes: Number of object classes predicted by each head.
    """

    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        """Return a list with one prediction tensor per ScalePrediction layer."""
        predictions = []  # one entry per scale
        skips = []  # outputs of the 8-repeat residual stacks, used after upsampling
        for block in self.layers:
            if isinstance(block, ScalePrediction):
                # Heads branch off; they do not modify the main feature map.
                predictions.append(block(x))
            else:
                x = block(x)
                if isinstance(block, ResidualBlock) and block.num_repeats == 8:
                    skips.append(x)
                elif isinstance(block, nn.Upsample):
                    # Concatenate with the most recently stored skip tensor.
                    x = torch.cat([x, skips.pop()], dim=1)
        return predictions

    def _create_conv_layers(self):
        """Translate ``config_layers`` into an ``nn.ModuleList``."""
        modules = nn.ModuleList()
        channels = self.in_channels
        for spec in config_layers:
            if isinstance(spec, tuple):
                filters, kernel, stride = spec
                pad = 1 if kernel == 3 else 0
                modules.append(
                    CNNBlock(
                        channels,
                        filters,
                        kernel_size=kernel,
                        stride=stride,
                        padding=pad,
                    )
                )
                channels = filters
            elif isinstance(spec, list):
                repeats = spec[1]
                modules.append(ResidualBlock(channels, num_repeats=repeats))
            elif isinstance(spec, str):
                if spec == "S":
                    halved = channels // 2
                    modules.extend(
                        [
                            ResidualBlock(channels, use_residual=False, num_repeats=1),
                            CNNBlock(channels, halved, kernel_size=1),
                            ScalePrediction(halved, num_classes=self.num_classes),
                        ]
                    )
                    channels = halved
                elif spec == "U":
                    modules.append(nn.Upsample(scale_factor=2))
                    # forward() concatenates a skip tensor after upsampling;
                    # with the shipped config that skip has twice the current
                    # channel count, hence the factor of 3.
                    channels = channels * 3
        return modules
from typing import Any, Callable, Optional, Union
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import STEP_OUTPUT, TRAIN_DATALOADERS
from torch.optim.optimizer import Optimizer
import torch
import torch.nn as nn
import torch.optim as optim
class YOLOV3LITE(pl.LightningModule):
    """LightningModule that trains ``YOLOv3`` with ``YoloLoss``.

    Args:
        train_loader: Loader returned by ``train_dataloader`` and used to size
            the OneCycleLR schedule.
        test_loader: Loader used for the periodic accuracy / final mAP
            evaluation in ``on_train_epoch_end``.
        valid_data: Loader returned by ``valid_data_dataloader``.
    """

    def __init__(self, train_loader=None, test_loader=None, valid_data=None):
        super().__init__()
        self.model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
        self.loss_fn = YoloLoss()
        self.scaler = torch.cuda.amp.GradScaler()
        # Anchors multiplied by the per-scale grid size taken from config.S.
        self.scaled_anchors = (
            torch.tensor(config.ANCHORS)
            * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        ).to(config.DEVICE)
        self.config = config
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()
        # NOTE(review): this also captures the loader arguments as
        # hyperparameters; consider `ignore=[...]` if checkpoints get large.
        self.save_hyperparameters()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.config.LEARNING_RATE,
            weight_decay=self.config.WEIGHT_DECAY,
        )
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.valid_dataloader = valid_data

    def forward(self, x):
        """Return the wrapped model's per-scale predictions for ``x``."""
        return self.model(x)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Return the training loader given to the constructor."""
        # Fix: the original returned `self.train_dataloader`, i.e. this bound
        # method itself, instead of the stored loader.
        return self.train_loader

    def valid_data_dataloader(self) -> TRAIN_DATALOADERS:
        """Return the validation loader given to the constructor.

        NOTE(review): Lightning's hook for validation data is named
        ``val_dataloader``; this method is presumably only reached when
        callers invoke it explicitly - confirm how validation data is
        passed to the Trainer.
        """
        return self.valid_dataloader

    def _shared_step(self, batch):
        """Compute loss and accuracies for one batch.

        Returns:
            Tuple ``(loss, class_accuracy, no_obj_accuracy, obj_accuracy)``;
            the three accuracies are percentages.
        """
        x, y = batch
        device = self.config.DEVICE
        x = x.to(device)
        # Build a local list so the caller's batch is not mutated.
        targets = [y[idx].to(device) for idx in range(3)]

        with torch.cuda.amp.autocast():
            out = self.model(x)
            loss = (
                self.loss_fn(out[0], targets[0], self.scaled_anchors[0])
                + self.loss_fn(out[1], targets[1], self.scaled_anchors[1])
                + self.loss_fn(out[2], targets[2], self.scaled_anchors[2])
            )

        correct_class = 0
        correct_obj, tot_obj = 0, 0
        correct_noobj, tot_noobj = 0, 0
        for pred, target in zip(out, targets):
            obj = target[..., 0] == 1  # in paper this is Iobj_i
            noobj = target[..., 0] == 0  # cells whose objectness target is 0

            correct_class += torch.sum(
                torch.argmax(pred[..., 5:][obj], dim=-1) == target[..., 5][obj]
            )
            obj_preds = torch.sigmoid(pred[..., 0]) > self.config.CONF_THRESHOLD
            correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
            tot_obj += torch.sum(obj)
            correct_noobj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
            tot_noobj += torch.sum(noobj)

        # The epsilon keeps the division defined when a batch has no such cells.
        # Class accuracy is measured over the same object cells as obj accuracy.
        class_accuracy = (correct_class / (tot_obj + 1e-16)) * 100
        no_obj_accuracy = (correct_noobj / (tot_noobj + 1e-16)) * 100
        obj_accuracy = (correct_obj / (tot_obj + 1e-16)) * 100
        return loss, class_accuracy, no_obj_accuracy, obj_accuracy

    def _log_epoch_metric(self, name, value):
        """Log ``value`` as an epoch-level metric shown in the progress bar."""
        self.log(name, value, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        """Run one training batch, log loss/accuracies, and return the loss."""
        loss, class_acc, noobj_acc, obj_acc = self._shared_step(batch)
        self._log_epoch_metric("train_loss", loss.item())
        self._log_epoch_metric("train_class_acc", class_acc)
        self._log_epoch_metric("train_noobj_acc", noobj_acc)
        self._log_epoch_metric("train_obj_acc", obj_acc)
        return loss

    def save_checkpoint(self, filename="my_checkpoint.pth.tar"):
        """Save the model and optimizer state dicts to ``filename``."""
        print("=> Saving checkpoint")
        checkpoint = {
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(checkpoint, filename)

    def on_train_epoch_end(self):
        """Checkpoint, and periodically evaluate on the test loader.

        Accuracy is evaluated every 10th epoch (except epoch 0) and on the
        final epoch; mean average precision is computed on the final epoch
        only.
        """
        if config.SAVE_MODEL:
            self.save_checkpoint(filename="checkpoint_e3.pth.tar")

        epoch = self.current_epoch
        is_periodic = epoch > 0 and epoch % 10 == 0
        is_final = epoch > 0 and epoch == (config.NUM_EPOCHS - 1)
        if not (is_periodic or is_final):
            return

        # Fix: the original only computed these in the periodic branch but
        # also logged them in the final-epoch branch, raising NameError
        # whenever the last epoch index was not a multiple of 10.
        class_acc, noobj_acc, obj_acc = check_class_accuracy(
            self.model, self.test_loader, self.config.CONF_THRESHOLD, "Test_data"
        )
        self._log_epoch_metric("test_class_acc", class_acc)
        self._log_epoch_metric("test_noobj_acc", noobj_acc)
        self._log_epoch_metric("test_obj_acc", obj_acc)

        if is_final:
            pred_boxes, true_boxes = get_evaluation_bboxes(
                self.test_loader,
                self.model,
                iou_threshold=self.config.NMS_IOU_THRESH,
                anchors=self.config.ANCHORS,
                threshold=self.config.CONF_THRESHOLD,
            )
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=self.config.NUM_CLASSES,
            )
            print("mapval.item()", mapval.item())
            self.log("test_mapval", mapval.item())

        # Put the model back in training mode in case the evaluation helpers
        # left it in eval mode.
        self.model.train()

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log loss/accuracies, and return the loss."""
        loss, class_acc, noobj_acc, obj_acc = self._shared_step(batch)
        self._log_epoch_metric("val_loss", loss.item())
        self._log_epoch_metric("valid_class_acc", class_acc)
        self._log_epoch_metric("valid_noobj_acc", noobj_acc)
        self._log_epoch_metric("valid_obj_acc", obj_acc)
        return loss

    def test_step(self, batch=None, batch_idx=None):
        """No-op test step.

        Fix: accepts ``batch`` / ``batch_idx`` so a call that passes them does
        not raise TypeError; the defaults keep the old zero-argument call
        working.
        """
        return None

    def configure_optimizers(self):
        """Return the Adam optimizer with a per-step OneCycleLR schedule."""
        # NOTE(review): the schedule spans 2/5 of NUM_EPOCHS; this presumably
        # matches the Trainer's max_epochs - confirm, since OneCycleLR raises
        # once stepped past its total step count.
        schedule_epochs = self.config.NUM_EPOCHS * 2 // 5
        scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=2.83e-02,
            epochs=schedule_epochs,
            steps_per_epoch=len(self.train_loader),
            pct_start=5 / schedule_epochs,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy="linear",
        )
        scheduler_dict = {"scheduler": scheduler, "interval": "step"}
        return {"optimizer": self.optimizer, "lr_scheduler": scheduler_dict}
if __name__ == "__main__":
    # Smoke test: check the output shapes for a random batch of two images.
    num_classes = 20
    IMAGE_SIZE = 416
    model = YOLOv3(num_classes=num_classes)
    x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
    # Run the forward pass once and reuse it; the original discarded `out`
    # and re-ran the model for every assertion.
    out = model(x)
    expected_strides = (32, 16, 8)
    assert len(out) == len(expected_strides)
    for prediction, stride in zip(out, expected_strides):
        grid = IMAGE_SIZE // stride
        assert prediction.shape == (2, 3, grid, grid, num_classes + 5)
    print("Success!")