|
""" |
|
Implementation of YOLOv3 architecture |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
import config |
|
|
|
import torch.optim as optim |
|
import pytorch_lightning as pl |
|
|
|
|
|
from tqdm import tqdm |
|
from utils import ( |
|
mean_average_precision, |
|
cells_to_bboxes, |
|
get_evaluation_bboxes, |
|
save_checkpoint, |
|
load_checkpoint, |
|
check_class_accuracy, |
|
get_loaders, |
|
get_loaders_new, |
|
plot_couple_examples |
|
) |
|
from loss import YoloLoss |
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
""" |
|
Information about architecture config: |
|
Tuple is structured by (filters, kernel_size, stride) |
|
Every conv is a same convolution. |
|
List is structured by "B" indicating a residual block followed by the number of repeats |
|
"S" is for scale prediction block and computing the yolo loss |
|
"U" is for upsampling the feature map and concatenating with a previous layer |
|
""" |
|
# Layer specification walked in order by YOLOv3._create_conv_layers.
#   (filters, kernel_size, stride) -> convolution block
#   ["B", n]                       -> residual block repeated n times
#   "S"                            -> scale-prediction branch
#   "U"                            -> 2x upsample followed by a concatenation
config_layers = [
    # Feature extractor: each stride-2 convolution halves the resolution and
    # is followed by a residual stage.
    (32, 3, 1),
    (64, 3, 2), ["B", 1],
    (128, 3, 2), ["B", 2],
    (256, 3, 2), ["B", 8],
    (512, 3, 2), ["B", 8],
    (1024, 3, 2), ["B", 4],
    # First prediction branch.
    (512, 1, 1), (1024, 3, 1), "S",
    # Upsample and second prediction branch.
    (256, 1, 1), "U",
    (256, 1, 1), (512, 3, 1), "S",
    # Upsample and third prediction branch.
    (128, 1, 1), "U",
    (128, 1, 1), (256, 3, 1), "S",
]
|
|
|
|
|
class CNNBlock(nn.Module):
    """Convolution optionally followed by batch norm and LeakyReLU.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        bn_act: When true, ``forward`` applies ``BatchNorm2d`` and
            ``LeakyReLU(0.1)`` after the convolution, and the convolution is
            built without a bias term. When false, ``forward`` returns the
            plain convolution output (with bias).
        **kwargs: Passed through unchanged to ``nn.Conv2d``
            (e.g. ``kernel_size``, ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.use_bn_act = bn_act
        # The convolution only carries a bias when no batch norm follows it.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        # Both modules are registered regardless of ``bn_act``; when it is
        # false they are simply never called in ``forward``.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply the convolution and, if enabled, batch norm + LeakyReLU."""
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.leaky(self.bn(features))
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Sequence of ``num_repeats`` two-convolution units with optional skips.

    Each unit is a 1x1 convolution that halves the channel count followed by
    a 3x3 convolution (padding 1) that restores it, so the tensor shape is
    unchanged by every unit.

    Args:
        channels: Channel count of the input (and output) feature map.
        use_residual: When true, each unit's output is added to its input;
            when false, the units are simply applied one after another.
        num_repeats: Number of units to stack.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        self.use_residual = use_residual
        # Also read by YOLOv3.forward to decide which outputs to keep as routes.
        self.num_repeats = num_repeats

        squeezed = channels // 2
        self.layers = nn.ModuleList(
            [
                nn.Sequential(
                    CNNBlock(channels, squeezed, kernel_size=1),
                    CNNBlock(squeezed, channels, kernel_size=3, padding=1),
                )
                for _ in range(num_repeats)
            ]
        )

    def forward(self, x):
        """Run ``x`` through every unit, adding the skip when enabled."""
        for unit in self.layers:
            branch = unit(x)
            x = x + branch if self.use_residual else branch
        return x
|
|
|
|
|
class ScalePrediction(nn.Module):
    """Prediction head for one output scale.

    A 3x3 convolution doubles the channel count, then a 1x1 convolution
    (no batch norm / activation) emits ``3 * (num_classes + 5)`` channels.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_classes: Number of object classes; each of the 3 predictions per
            cell carries ``num_classes + 5`` values.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.num_classes = num_classes

        hidden_channels = 2 * in_channels
        head_channels = (num_classes + 5) * 3
        self.pred = nn.Sequential(
            CNNBlock(in_channels, hidden_channels, kernel_size=3, padding=1),
            CNNBlock(hidden_channels, head_channels, bn_act=False, kernel_size=1),
        )

    def forward(self, x):
        """Return predictions shaped ``(batch, 3, height, width, num_classes + 5)``.

        ``height`` and ``width`` are taken from the input ``x``.
        """
        raw = self.pred(x)
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        # Split the channel axis into (3, num_classes + 5), then move the
        # per-prediction values to the last axis.
        per_prediction = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        return per_prediction.permute(0, 1, 3, 4, 2)
|
|
|
|
|
class YOLOv3(nn.Module):
    """YOLOv3 network assembled from the module-level ``config_layers`` spec.

    Args:
        in_channels: Channel count of the input images.
        num_classes: Number of object classes predicted per box.
    """

    def __init__(self, in_channels=3, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()

    def forward(self, x):
        """Return the list of ``ScalePrediction`` outputs, in layer order."""
        predictions = []
        # Feature maps saved for concatenation after a later upsample.
        saved_routes = []

        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Prediction heads branch off; ``x`` continues unchanged.
                predictions.append(layer(x))
                continue

            x = layer(x)

            if isinstance(layer, ResidualBlock):
                # Only the 8-repeat residual stages feed the later concatenations.
                if layer.num_repeats == 8:
                    saved_routes.append(x)
            elif isinstance(layer, nn.Upsample):
                # Join with the most recently saved route along the channel axis.
                x = torch.cat([x, saved_routes.pop()], dim=1)

        return predictions

    def _create_conv_layers(self):
        """Translate ``config_layers`` into an ``nn.ModuleList``."""
        modules = nn.ModuleList()
        channels = self.in_channels

        for spec in config_layers:
            if isinstance(spec, tuple):
                out_channels, kernel_size, stride = spec
                padding = 1 if kernel_size == 3 else 0
                modules.append(
                    CNNBlock(
                        channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                    )
                )
                channels = out_channels

            elif isinstance(spec, list):
                modules.append(ResidualBlock(channels, num_repeats=spec[1]))

            elif isinstance(spec, str):
                if spec == "S":
                    halved = channels // 2
                    modules.extend(
                        [
                            ResidualBlock(channels, use_residual=False, num_repeats=1),
                            CNNBlock(channels, halved, kernel_size=1),
                            ScalePrediction(halved, num_classes=self.num_classes),
                        ]
                    )
                    channels = halved

                elif spec == "U":
                    modules.append(nn.Upsample(scale_factor=2))
                    # With the current config the concatenated route has twice
                    # the channels of the upsampled tensor, hence the factor 3.
                    channels = channels * 3

        return modules
|
from typing import Any, Callable, Optional, Union |
|
from pytorch_lightning.core.optimizer import LightningOptimizer |
|
|
|
|
|
from pytorch_lightning.utilities.types import STEP_OUTPUT, TRAIN_DATALOADERS |
|
from torch.optim.optimizer import Optimizer |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
|
|
|
|
|
|
class YOLOV3LITE(pl.LightningModule):
    """Lightning module that trains and evaluates a :class:`YOLOv3` model.

    Hyperparameters (device, class count, anchors, thresholds, learning rate,
    epochs, ...) are read from the imported ``config`` module.

    Args:
        train_loader: Loader returned by ``train_dataloader``; its length also
            sets ``steps_per_epoch`` of the LR schedule.
        test_loader: Loader handed to the evaluation helpers in
            ``on_train_epoch_end``.
        valid_data: Loader returned by ``valid_data_dataloader``.
    """

    def __init__(self, train_loader=None, test_loader=None, valid_data=None):
        super().__init__()
        self.model = YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
        self.loss_fn = YoloLoss()
        # Not used by any method of this class; kept because external code
        # may read the attribute.
        self.scaler = torch.cuda.amp.GradScaler()
        # Anchors multiplied by the per-scale factors in config.S.
        self.scaled_anchors = (
            torch.tensor(config.ANCHORS)
            * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        ).to(config.DEVICE)
        self.config = config
        self.mse = nn.MSELoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.cross_entropy = nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()
        self.save_hyperparameters()
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.config.LEARNING_RATE,
            weight_decay=self.config.WEIGHT_DECAY,
        )
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.valid_dataloader = valid_data

    def forward(self, x):
        """Return the wrapped model's predictions for ``x``."""
        return self.model(x)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Return the training loader supplied to the constructor."""
        # Fixed: this previously returned ``self.train_dataloader`` (the bound
        # method itself) rather than the stored loader.
        return self.train_loader

    def valid_data_dataloader(self) -> TRAIN_DATALOADERS:
        """Return the validation loader supplied to the constructor.

        NOTE(review): Lightning's validation hook is named ``val_dataloader``;
        this method is not picked up automatically. The name is kept for
        backward compatibility.
        """
        return self.valid_dataloader

    def _shared_step(self, batch):
        """Compute loss and accuracy metrics for one batch.

        Returns:
            Tuple ``(loss, class_accuracy, no_obj_accuracy, obj_accuracy)``;
            the accuracies are percentages.
        """
        x, y = batch
        device = self.config.DEVICE
        x = x.to(device)
        # One target tensor per prediction scale.
        targets = [y[i].to(device) for i in range(3)]

        with torch.cuda.amp.autocast():
            out = self.model(x)
            loss = (
                self.loss_fn(out[0], targets[0], self.scaled_anchors[0])
                + self.loss_fn(out[1], targets[1], self.scaled_anchors[1])
                + self.loss_fn(out[2], targets[2], self.scaled_anchors[2])
            )

        tot_class_preds, correct_class = 0, 0
        tot_noobj, correct_noobj = 0, 0
        tot_obj, correct_obj = 0, 0

        # Metrics are for reporting only, so no autograd graph is needed.
        with torch.no_grad():
            for pred, target in zip(out, targets):
                obj = target[..., 0] == 1
                noobj = target[..., 0] == 0

                correct_class += torch.sum(
                    torch.argmax(pred[..., 5:][obj], dim=-1) == target[..., 5][obj]
                )
                tot_class_preds += torch.sum(obj)

                obj_preds = torch.sigmoid(pred[..., 0]) > self.config.CONF_THRESHOLD
                correct_obj += torch.sum(obj_preds[obj] == target[..., 0][obj])
                tot_obj += torch.sum(obj)
                correct_noobj += torch.sum(obj_preds[noobj] == target[..., 0][noobj])
                tot_noobj += torch.sum(noobj)

        # The epsilon avoids a division by zero when a batch has no such cells.
        class_accuracy = (correct_class / (tot_class_preds + 1e-16)) * 100
        no_obj_accuracy = (correct_noobj / (tot_noobj + 1e-16)) * 100
        obj_accuracy = (correct_obj / (tot_obj + 1e-16)) * 100
        return loss, class_accuracy, no_obj_accuracy, obj_accuracy

    def _log_epoch_metrics(self, metrics):
        """Log each ``name -> value`` pair as an epoch-level progress-bar metric."""
        for name, value in metrics.items():
            self.log(name, value, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for ``batch`` and log loss/accuracies."""
        loss, class_accuracy, no_obj_accuracy, obj_accuracy = self._shared_step(batch)
        self._log_epoch_metrics(
            {
                "train_loss": loss.item(),
                "train_class_acc": class_accuracy,
                "train_noobj_acc": no_obj_accuracy,
                "train_obj_acc": obj_accuracy,
            }
        )
        return loss

    def save_checkpoint(self, filename="my_checkpoint.pth.tar"):
        """Save the model and optimizer state dicts to ``filename``."""
        print("=> Saving checkpoint")
        checkpoint = {
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(checkpoint, filename)

    def on_train_epoch_end(self):
        """Checkpoint, and periodically evaluate on the test loader.

        * Saves a checkpoint every epoch when ``config.SAVE_MODEL`` is set.
        * Logs test accuracies every 10th epoch (epoch > 0) and on the final
          epoch (``config.NUM_EPOCHS - 1``).
        * Additionally logs mean average precision on the final epoch.
        """
        if self.config.SAVE_MODEL:
            self.save_checkpoint(filename="checkpoint_e3.pth.tar")

        epoch = self.current_epoch
        is_periodic_eval = epoch > 0 and epoch % 10 == 0
        is_final_epoch = epoch > 0 and epoch == (self.config.NUM_EPOCHS - 1)

        if not (is_periodic_eval or is_final_epoch):
            return

        # Fixed: the accuracies used to be computed only on the periodic
        # branch but were also logged on the final-epoch branch, which raised
        # UnboundLocalError when the last epoch was not a multiple of 10.
        # They are now computed and logged once for either case.
        class_acc, noobj_acc, obj_acc = check_class_accuracy(
            self.model, self.test_loader, self.config.CONF_THRESHOLD, "Test_data"
        )
        self._log_epoch_metrics(
            {
                "test_class_acc": class_acc,
                "test_noobj_acc": noobj_acc,
                "test_obj_acc": obj_acc,
            }
        )

        if is_final_epoch:
            pred_boxes, true_boxes = get_evaluation_bboxes(
                self.test_loader,
                self.model,
                iou_threshold=self.config.NMS_IOU_THRESH,
                anchors=self.config.ANCHORS,
                threshold=self.config.CONF_THRESHOLD,
            )
            mapval = mean_average_precision(
                pred_boxes,
                true_boxes,
                iou_threshold=self.config.MAP_IOU_THRESH,
                box_format="midpoint",
                num_classes=self.config.NUM_CLASSES,
            )
            print("mapval.item()", mapval.item())
            self.log("test_mapval", mapval.item())

        # Put the model back in training mode in case an evaluation helper
        # left it in eval mode.
        self.model.train()

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for ``batch`` and log loss/accuracies."""
        loss, class_accuracy, no_obj_accuracy, obj_accuracy = self._shared_step(batch)
        self._log_epoch_metrics(
            {
                "val_loss": loss.item(),
                "valid_class_acc": class_accuracy,
                "valid_noobj_acc": no_obj_accuracy,
                "valid_obj_acc": obj_accuracy,
            }
        )
        return loss

    def test_step(self, batch=None, batch_idx=None):
        """No-op test step.

        Accepts (and ignores) ``batch`` / ``batch_idx`` so that a Lightning
        test loop can call it; both default to ``None`` so existing
        zero-argument calls keep working.
        """
        return None

    def configure_optimizers(self):
        """Return the Adam optimizer with a per-step OneCycleLR schedule.

        NOTE(review): the schedule covers ``NUM_EPOCHS * 2 // 5`` epochs and
        OneCycleLR raises once stepped past its total step count, so confirm
        the Trainer's ``max_epochs`` matches. ``pct_start = 5 / epochs`` also
        requires at least 5 scheduled epochs.
        """
        scheduled_epochs = self.config.NUM_EPOCHS * 2 // 5
        one_cycle = optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=2.83E-02,
            epochs=scheduled_epochs,
            steps_per_epoch=len(self.train_loader),
            pct_start=5 / scheduled_epochs,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy='linear',
        )
        scheduler_dict = {
            "scheduler": one_cycle,
            "interval": "step",
        }
        return {"optimizer": self.optimizer, "lr_scheduler": scheduler_dict}
|
|
|
if __name__ == "__main__":
    # Smoke test: build the network and check the shape of each scale's output.
    num_classes = 20
    IMAGE_SIZE = 416
    model = YOLOv3(num_classes=num_classes)
    x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))

    # One forward pass is enough; the original re-ran the model for every
    # assertion and never used the first result.
    with torch.no_grad():
        out = model(x)

    assert len(out) == 3
    # The three outputs are expected at 1/32, 1/16 and 1/8 of the input size.
    for prediction, stride in zip(out, (32, 16, 8)):
        grid = IMAGE_SIZE // stride
        assert prediction.shape == (2, 3, grid, grid, num_classes + 5)

    print("Success!")