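# yolov3/train.py — uploaded by catchlui ("Upload 4 files", commit 5cbcc4c)
"""
Main script for the YOLOv3 model on the Pascal VOC and COCO datasets.

As written, ``main`` does not run training (``trainer.fit`` is never called):
it loads a saved checkpoint and plots the model's predictions on a few test
images.
"""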
import config
import torch
import torch.optim as optim
import albumentations as A
import cv2
from albumentations.pytorch import ToTensorV2
from model import YOLOv3,YOLOV3LITE
from tqdm import tqdm
from utils import (
mean_average_precision,
cells_to_bboxes,
get_evaluation_bboxes,
save_checkpoint,
load_checkpoint,
check_class_accuracy,
get_loaders,
plot_couple_examples
)
from dataset import YOLODatasetOK
from utils import non_max_suppression,plot_image
from loss import YoloLoss
import warnings
warnings.filterwarnings("ignore")
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger,TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.optim as optim
import pytorch_lightning as pl
torch.backends.cudnn.benchmark = True
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model weights from ``checkpoint_file`` and return ``model``.

    This local definition shadows the ``load_checkpoint`` imported from
    ``utils`` at the top of the file. Only the model weights are restored.

    Args:
        checkpoint_file: Path to a file readable by ``torch.load`` that holds
            a ``"state_dict"`` entry.
        model: Module whose parameters are overwritten in place.
        optimizer: Accepted so call sites keep the four-argument form; unused.
        lr: Accepted so call sites keep the four-argument form; unused.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Restoring the optimizer state (and resetting its learning rate to
    # ``lr``) is intentionally disabled, which is why ``optimizer`` and ``lr``
    # are ignored.
    print("=> Loading checkpoint")
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_file, map_location=config.DEVICE)
    weights = ckpt["state_dict"]
    model.load_state_dict(weights)
    return model
# CSV split files, relative to ``config.DATASET``.
_TRAIN_CSV = "100examples.csv"
_VALID_CSV = "2examples.csv"
_TEST_CSV = "8examples.csv"


def _build_training_setup():
    """Build the trainer and data loaders for a training run.

    ``main`` does not call this: training is currently disabled there. To
    train, pass the returned objects to
    ``trainer.fit(model, train_dataloaders=train_loader,
    val_dataloaders=train_eval_loader)``.

    Returns:
        ``(trainer, train_loader, test_loader, train_eval_loader)``; the
        loaders come from ``get_loaders`` over the train/test/valid CSV
        splits in ``config.DATASET``.
    """
    train_loader, test_loader, train_eval_loader = get_loaders(
        train_csv_path=config.DATASET + "/" + _TRAIN_CSV,
        test_csv_path=config.DATASET + "/" + _TEST_CSV,
        valid_csv_path=config.DATASET + "/" + _VALID_CSV,
    )
    trainer = pl.Trainer(
        max_epochs=8,
        accelerator="auto",
        check_val_every_n_epoch=5,
        # Limit to one device when CUDA is available (for IPython runs).
        devices=1 if torch.cuda.is_available() else None,
        logger=TensorBoardLogger(save_dir="logs/"),
        precision=16,
        # callbacks=[LearningRateMonitor(logging_interval="epoch")],
    )
    return trainer, train_loader, test_loader, train_eval_loader


def _build_test_transform():
    """Return the evaluation-time albumentations pipeline.

    Images are rescaled so the longest side equals ``config.IMAGE_SIZE``,
    padded with zeros to a square of that size, divided by 255 (mean 0,
    std 1) and converted to a PyTorch tensor. Boxes are in YOLO format and
    are dropped when less than 40% of a box remains visible.
    """
    return A.Compose(
        [
            A.LongestMaxSize(max_size=config.IMAGE_SIZE),
            A.PadIfNeeded(
                min_height=config.IMAGE_SIZE,
                min_width=config.IMAGE_SIZE,
                border_mode=cv2.BORDER_CONSTANT,
            ),
            A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(
            format="yolo",
            min_visibility=0.4,
            label_fields=[],
        ),
    )


def _build_plot_loader(csv_file):
    """Return a shuffled DataLoader (batch size 8) over the images in ``csv_file``."""
    dataset = YOLODatasetOK(
        csv_file=csv_file,
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        # Use the same grid sizes as the scaled anchors (config.S) instead of
        # a hard-coded [13, 26, 52], which presumably matches only a 416px
        # input; mismatched grids would misalign targets and anchors.
        S=config.S,
        anchors=config.ANCHORS,
        transform=_build_test_transform(),
    )
    return torch.utils.data.DataLoader(dataset=dataset, batch_size=8, shuffle=True)


def _scaled_anchors():
    """Return ``config.ANCHORS`` multiplied by each grid size in ``config.S``.

    The result is placed on ``config.DEVICE``. ``repeat(1, 3, 2)`` assumes
    three anchors per scale, each a (width, height) pair.
    """
    grid_sizes = torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    return (torch.tensor(config.ANCHORS) * grid_sizes).to(config.DEVICE)


def main():
    """Load the YOLOv3-lite checkpoint and plot its predictions on test images.

    Training is not run here (see ``_build_training_setup``). This entry
    point restores ``config.CHECKPOINT_FILE`` into a fresh ``YOLOV3LITE`` and
    visualises detections on a shuffled batch from the test split.
    """
    model_handler = YOLOV3LITE()
    model = load_checkpoint(
        config.CHECKPOINT_FILE,
        model_handler,
        model_handler.optimizer,
        config.LEARNING_RATE,
    )
    # map_location only affects where the checkpoint tensors are loaded;
    # load_state_dict copies them into the model's existing parameters.
    # Move the model explicitly so it shares a device with the anchors.
    model = model.to(config.DEVICE)

    loader = _build_plot_loader(config.DATASET + "/" + _TEST_CSV)
    # Presumably (confidence threshold, NMS IoU threshold); verify against
    # utils.plot_couple_examples.
    plot_couple_examples(model, loader, 0.6, 0.5, _scaled_anchors())
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()