"""PyTorch Lightning wrapper around the YOLOv3 model, plus a shape smoke test."""

import torch
from pytorch_lightning import LightningModule
from torch import optim
from torch.utils.data import DataLoader

import config
from dataset import YOLODataset
from loss import YoloLoss
from model import YOLOv3


class YOLOV3_PL(LightningModule):
    """LightningModule that trains and evaluates a ``YOLOv3`` network.

    The module owns the network, the ``YoloLoss`` criterion, the train /
    validation / test dataloaders and the optimizer + one-cycle LR schedule.
    All dataset locations, transforms, anchors and loader settings are read
    from the project ``config`` module.
    """

    def __init__(
        self,
        in_channels=3,
        num_classes=config.NUM_CLASSES,
        batch_size=config.BATCH_SIZE,
        learning_rate=config.LEARNING_RATE,
        num_epochs=config.NUM_EPOCHS,
    ):
        """Build the network and store the training hyper-parameters.

        Args:
            in_channels: Number of input image channels passed to ``YOLOv3``.
            num_classes: Number of object classes passed to ``YOLOv3``.
            batch_size: Batch size used by every dataloader of this module.
            learning_rate: Peak learning rate (``max_lr``) of the one-cycle
                schedule built in ``configure_optimizers``.
            num_epochs: Number of epochs the one-cycle schedule is sized for.
        """
        super().__init__()
        self.model = YOLOv3(in_channels, num_classes)
        self.criterion = YoloLoss()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.scaled_anchors = config.SCALED_ANCHORS
        # Alias of the wrapped network's layer container.
        self.layers = self.model.layers

    def _build_dataset(self, csv_name, transform):
        """Create a ``YOLODataset`` for the CSV file ``config.DATASET + csv_name``."""
        return YOLODataset(
            config.DATASET + csv_name,
            transform=transform,
            img_dir=config.IMG_DIR,
            label_dir=config.LABEL_DIR,
            anchors=config.ANCHORS,
        )

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` using the shared loader settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=config.NUM_WORKERS,
            pin_memory=config.PIN_MEMORY,
            shuffle=shuffle,
        )

    def _loss_step(self, batch, metric_name):
        """Run a forward pass on ``batch``, log the loss as ``metric_name`` and return it."""
        x, y = batch
        out = self.forward(x)
        loss = self.criterion(out, y, self.scaled_anchors)
        self.log(metric_name, loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def train_dataloader(self):
        """Return a shuffled loader over ``train.csv``.

        A fresh dataset is built on every call and stored in
        ``self.train_data``.
        """
        self.train_data = self._build_dataset(
            '/train.csv', config.train_transforms
        )
        return self._build_loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffled loader over ``test.csv``.

        A fresh dataset is built on every call and stored in
        ``self.valid_data``.
        """
        self.valid_data = self._build_dataset(
            '/test.csv', config.test_transforms
        )
        return self._build_loader(self.valid_data, shuffle=False)

    def test_dataloader(self):
        """Return the validation loader; testing reuses the validation data."""
        return self.val_dataloader()

    def forward(self, x):
        """Return the wrapped network's output for input ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for one batch."""
        return self._loss_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the loss for one batch."""
        return self._loss_step(batch, "val_loss")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Return raw network predictions for one test batch.

        ``batch`` may be an ``(inputs, targets)`` tuple/list, in which case
        the targets are ignored, or the inputs alone.
        """
        if isinstance(batch, (tuple, list)):
            x, _ = batch
        else:
            x = batch
        return self.forward(x)

    def configure_optimizers(self):
        """Build the Adam optimizer and a per-step ``OneCycleLR`` schedule.

        Returns:
            A dict with the optimizer and an ``lr_scheduler`` config whose
            ``interval`` is ``"step"``.
        """
        # NOTE(review): OneCycleLR derives its starting rate from
        # max_lr / div_factor, so the ``lr`` given to Adam here is presumably
        # overridden once the scheduler is attached -- confirm against the
        # PyTorch OneCycleLR documentation.
        optimizer = optim.Adam(
            self.parameters(),
            lr=self.learning_rate / 100,
            weight_decay=config.WEIGHT_DECAY,
        )
        # Calling train_dataloader() here builds a dataset and loader solely
        # to count the batches in one epoch.
        # NOTE(review): this count ignores gradient accumulation and
        # multi-device sharding, if either is used -- verify against the
        # Trainer configuration.
        steps_per_epoch = len(self.train_dataloader())
        scheduler = optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.learning_rate,
            steps_per_epoch=steps_per_epoch,
            epochs=self.num_epochs,
            pct_start=0.2,
            div_factor=10,
            three_phase=False,
            final_div_factor=10,
            anneal_strategy='linear',
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                "scheduler": scheduler,
                "interval": "step",
            },
        }


def main():
    """Smoke-test the module: print a summary and check the output shapes.

    Feeds a random batch of two 416x416 RGB images through the model and
    asserts that the three outputs have grid sizes of 1/32, 1/16 and 1/8 of
    the image size, three entries on the second axis and ``num_classes + 5``
    values on the last axis.
    """
    num_classes = 20
    IMAGE_SIZE = 416
    INPUT_SIZE = IMAGE_SIZE
    model = YOLOV3_PL(num_classes=num_classes)

    # Imported here so torchinfo is only required when running the smoke test.
    from torchinfo import summary

    print(summary(model, input_size=(2, 3, INPUT_SIZE, INPUT_SIZE)))

    inp = torch.randn((2, 3, INPUT_SIZE, INPUT_SIZE))
    out = model(inp)
    assert out[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
    assert out[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
    assert out[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
    print("Success!")


if __name__ == "__main__":
    main()