Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.optim as optim | |
| import lightning.pytorch as pl | |
| from tqdm import tqdm | |
| from .model import YOLOv3 | |
| from .loss import YoloLoss | |
| from .utils import get_loaders, load_checkpoint, check_class_accuracy, intersection_over_union | |
| from . import config | |
| from torch.optim.lr_scheduler import OneCycleLR | |
class YOLOv3Lightning(pl.LightningModule):
    """LightningModule wrapping a YOLOv3 network trained with the project's YoloLoss.

    Args:
        config: Configuration object. Attributes read by this class:
            NUM_CLASSES, LEARNING_RATE, WEIGHT_DECAY, NUM_EPOCHS, DATASET,
            ANCHORS, S, LOAD_MODEL and CHECKPOINT_FILE.
        lr_value: Optional learning-rate override. Any falsy value (the
            default ``0``, or ``None``) falls back to ``config.LEARNING_RATE``.
            The result is stored in ``self.learning_rate``.
    """

    def __init__(self, config, lr_value=0):
        super().__init__()
        self.automatic_optimization = True
        self.config = config
        self.model = YOLOv3(num_classes=self.config.NUM_CLASSES)
        self.loss_fn = YoloLoss()
        self.learning_rate = lr_value if lr_value else self.config.LEARNING_RATE
        # Computed here rather than in on_train_start. Lightning runs its
        # sanity-check validation before on_train_start, and trainer.validate /
        # trainer.test never call it, so the steps would otherwise hit a
        # missing attribute. As a non-persistent buffer it moves with the
        # module's device and stays out of the state_dict.
        self.register_buffer(
            "scaled_anchors", self._compute_scaled_anchors(), persistent=False
        )

    def _compute_scaled_anchors(self):
        """Return config.ANCHORS multiplied by the per-scale grid sizes in config.S."""
        # S has shape (len(S),). Expanding it to (len(S), 3, 2) gives one grid
        # size per anchor coordinate.
        # NOTE(review): this assumes ANCHORS is 3 scales x 3 anchors x (w, h) in
        # grid-relative units, as implied by repeat(1, 3, 2); confirm against config.
        grid_sizes = (
            torch.tensor(self.config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        )
        return torch.tensor(self.config.ANCHORS) * grid_sizes

    def forward(self, x):
        """Run the wrapped YOLOv3 model on ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Create Adam and a per-step, two-phase linear OneCycleLR schedule."""
        optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.config.WEIGHT_DECAY,
        )
        # The schedule spans 2/5 of NUM_EPOCHS. Clamp to at least one epoch to
        # avoid dividing by zero for tiny NUM_EPOCHS.
        # NOTE(review): the Trainer must not run more epochs than this, or
        # OneCycleLR raises once its total step count is exceeded. Confirm
        # max_epochs at the call site.
        epochs = max(1, self.config.NUM_EPOCHS * 2 // 5)
        # Warm up over 5 epochs. OneCycleLR rejects pct_start > 1, so clamp.
        pct_start = min(1.0, 5 / epochs)
        # OneCycleLR resets each group's lr to max_lr / div_factor, so the Adam
        # lr above only matters if the scheduler is removed.
        # NOTE(review): max_lr is fixed at 1e-3. Confirm whether lr_value was
        # meant to drive it.
        scheduler = OneCycleLR(
            optimizer,
            max_lr=1e-3,
            steps_per_epoch=len(self.train_dataloader()),
            epochs=epochs,
            pct_start=pct_start,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy="linear",
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def _get_loaders(self):
        """Return the 3-tuple of loaders produced by get_loaders for config.DATASET."""
        return get_loaders(
            train_csv_path=self.config.DATASET + "/train.csv",
            test_csv_path=self.config.DATASET + "/test.csv",
        )

    def _shared_step(self, batch):
        """Return the YoloLoss summed over the three prediction scales of ``batch``."""
        # Assumes batch is (images, targets), with targets indexable per scale
        # like the model output.
        x, y = batch
        out = self(x)
        return sum(
            self.loss_fn(out[scale], y[scale].to(self.device), self.scaled_anchors[scale])
            for scale in range(3)
        )

    def train_dataloader(self):
        """Return the first loader from get_loaders (training data)."""
        train_loader, _, _ = self._get_loaders()
        return train_loader

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return loss

    def val_dataloader(self):
        """Return the third loader from get_loaders as the validation loader."""
        # NOTE(review): the third element is used for validation and the second
        # for testing. Confirm this matches get_loaders' return order.
        _, _, val_loader = self._get_loaders()
        return val_loader

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)

    def test_dataloader(self):
        """Return the second loader from get_loaders as the test loader."""
        _, test_loader, _ = self._get_loaders()
        return test_loader

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss for one batch."""
        loss = self._shared_step(batch)
        self.log("test_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)

    def on_train_start(self):
        """Restore model and optimizer weights from CHECKPOINT_FILE when LOAD_MODEL is set."""
        if self.config.LOAD_MODEL:
            # Pass the raw torch optimizer, not the default LightningOptimizer
            # wrapper. The wrapper holds a copy of the optimizer's __dict__, so
            # state loaded into it (presumably via load_state_dict inside
            # load_checkpoint) would not reach the optimizer that steps.
            optimizer = self.optimizers(use_pl_optimizer=False)
            load_checkpoint(
                self.config.CHECKPOINT_FILE,
                self.model,
                optimizer,
                self.config.LEARNING_RATE,
            )