|
""" |
|
Main file for the YOLO model on the Pascal VOC and COCO datasets: restores a trained checkpoint and plots example predictions.
|
""" |
|
|
|
import config |
|
import torch |
|
import torch.optim as optim |
|
import albumentations as A |
|
import cv2 |
|
from albumentations.pytorch import ToTensorV2 |
|
|
|
|
|
from model import YOLOv3,YOLOV3LITE |
|
from tqdm import tqdm |
|
from utils import ( |
|
mean_average_precision, |
|
cells_to_bboxes, |
|
get_evaluation_bboxes, |
|
save_checkpoint, |
|
load_checkpoint, |
|
check_class_accuracy, |
|
get_loaders, |
|
plot_couple_examples |
|
) |
|
from dataset import YOLODatasetOK |
|
from utils import non_max_suppression,plot_image |
|
from loss import YoloLoss |
|
import warnings |
|
warnings.filterwarnings("ignore") |
|
from pytorch_lightning.callbacks import LearningRateMonitor |
|
from pytorch_lightning.callbacks.progress import TQDMProgressBar |
|
from pytorch_lightning.loggers import CSVLogger,TensorBoardLogger |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
import torch.optim as optim |
|
import pytorch_lightning as pl |
|
|
|
# Let cuDNN auto-tune: it benchmarks the available convolution algorithms and
# keeps the fastest one (see the torch.backends.cudnn documentation).
torch.backends.cudnn.benchmark = True
|
|
|
|
|
|
|
def load_checkpoint(checkpoint_file, model, optimizer=None, lr=None):
    """Restore model weights from a saved checkpoint file.

    Only the weights are restored. ``optimizer`` and ``lr`` are kept in the
    signature so existing four-argument calls keep working, but they are never
    read or modified, which is why they are optional.

    NOTE: this definition shadows the ``load_checkpoint`` name imported from
    ``utils`` at the top of the file.

    Args:
        checkpoint_file: Path (or file-like object) handed to ``torch.load``.
        model: Module whose ``load_state_dict`` receives the stored weights.
        optimizer: Ignored; accepted for call compatibility.
        lr: Ignored; accepted for call compatibility.

    Returns:
        The same ``model`` object, after its weights have been loaded.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry
            (assuming the checkpoint is a dict).
    """
    print("=> Loading checkpoint")
    # map_location remaps the stored tensors onto the configured device.
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    return model
|
|
|
|
|
def _build_test_transform():
    """Build the Albumentations pipeline applied to every test image.

    The steps are: ``LongestMaxSize`` to ``config.IMAGE_SIZE``, constant-border
    padding up to a ``config.IMAGE_SIZE`` square, ``Normalize`` with zero mean,
    unit std and ``max_pixel_value=255``, then conversion to a tensor.
    Bounding boxes are declared in YOLO format and dropped when less than
    40% of the box stays visible.
    """
    return A.Compose(
        [
            A.LongestMaxSize(max_size=config.IMAGE_SIZE),
            A.PadIfNeeded(
                min_height=config.IMAGE_SIZE,
                min_width=config.IMAGE_SIZE,
                border_mode=cv2.BORDER_CONSTANT,
            ),
            A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(
            format="yolo",
            min_visibility=0.4,
            label_fields=[],
        ),
    )


def _build_test_loader(csv_name):
    """Return a shuffled, batch-size-8 ``DataLoader`` over ``csv_name``.

    Args:
        csv_name: File name of the CSV inside ``config.DATASET``.
    """
    dataset = YOLODatasetOK(
        csv_file=config.DATASET + "/" + csv_name,
        img_dir=config.IMG_DIR,
        label_dir=config.LABEL_DIR,
        # NOTE(review): grid sizes are hard-coded here while the anchor
        # scaling in main() reads config.S -- confirm the two agree.
        S=[13, 26, 52],
        anchors=config.ANCHORS,
        transform=_build_test_transform(),
    )
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=8,
        shuffle=True,
    )


def main():
    """Load a trained YOLOV3LITE checkpoint and plot a few test predictions.

    Steps:
        1. Instantiate ``YOLOV3LITE`` and restore its weights from
           ``config.CHECKPOINT_FILE``.
        2. Build a data loader over the test CSV (``8examples.csv``).
        3. Scale ``config.ANCHORS`` by ``config.S``.
        4. Hand the model, loader and scaled anchors to
           ``plot_couple_examples``.
    """
    test_data = "8examples.csv"

    model_handler = YOLOV3LITE()
    loaded_model = load_checkpoint(
        config.CHECKPOINT_FILE,
        model_handler,
        model_handler.optimizer,
        config.LEARNING_RATE,
    )
    # The scaled anchors below live on config.DEVICE, so keep the model's
    # weights on that device too (a no-op when they are already there).
    # NOTE(review): presumably plot_couple_examples runs the model on tensors
    # placed on config.DEVICE -- confirm against utils.
    loaded_model = loaded_model.to(config.DEVICE)

    loader = _build_test_loader(test_data)

    # Multiply each scale's anchors by that scale's grid size. The unsqueeze
    # and repeat calls expand config.S so it lines up element-wise with
    # config.ANCHORS (assumes three anchors of two values per scale).
    scaled_anchors = (
        torch.tensor(config.ANCHORS)
        * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
    ).to(config.DEVICE)

    # 0.6 and 0.5 are presumably the confidence and IoU thresholds --
    # TODO confirm against utils.plot_couple_examples.
    plot_couple_examples(loaded_model, loader, 0.6, 0.5, scaled_anchors)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Call main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()