|
import datetime |
|
import math |
|
import os |
|
from functools import partial |
|
|
|
import albumentations as A |
|
import torch.optim as optim |
|
from termcolor import cprint |
|
from timm.scheduler import create_scheduler |
|
from torch.utils.data import DataLoader |
|
|
|
import utils.misc as misc |
|
from datasets import crop_to_smallest_collate_fn, get_dataset |
|
from engine import bundled_evaluate, train |
|
from losses import get_bundled_loss, get_loss |
|
from models import get_ensemble_model, get_single_modal_model |
|
from opt import get_opt |
|
|
|
|
|
def main(opt):
    """Train the ensemble model with periodic validation, or evaluate it once.

    Steps, in order:

    1. Set up the run through ``misc.setup_env`` and obtain the summary writer
       (which may be ``None``).
    2. Build the training loaders (skipped when ``opt.eval`` is set) and the
       validation loaders.
    3. Build the ensemble model and one optimizer/scheduler pair for every
       entry of ``opt.modality``.
    4. Restore weights from ``opt.resume`` when it is set.
    5. If ``opt.eval`` is set, run a single evaluation pass and return.
       Otherwise train for ``opt.epochs`` epochs, evaluating and saving a
       checkpoint every ``opt.eval_freq`` epochs and after the last epoch;
       the checkpoint with the highest ``"image_f1/AVG_ensemble"`` score is
       additionally saved as ``best.pt``.

    Args:
        opt: Option namespace (as produced by ``get_opt``). ``opt.epochs`` is
            overwritten with the epoch count returned by ``create_scheduler``.

    Raises:
        RuntimeError: If ``opt.optimizer`` is neither ``"adamw"`` nor
            ``"sgd"`` (compared case-insensitively).
    """
    writer = misc.setup_env(opt)

    # Debug runs load data in the main process (no worker subprocesses).
    num_workers = 0 if opt.debug else opt.num_workers

    # ------------------------------------------------------------------
    # Training loaders (only needed when we are actually training).
    # ------------------------------------------------------------------
    train_loaders = {}
    if not opt.eval:
        enlarged_size = int(opt.input_size * 1.5)
        train_transform = A.Compose(
            [
                # The probability must be given by keyword: on albumentations
                # releases whose constructors are ``(always_apply, p)`` a
                # positional 0.5 lands on ``always_apply`` and the flip would
                # be applied to every sample.
                A.HorizontalFlip(p=0.5),
                # Disabled augmentations are kept as NoOp placeholders so the
                # pipeline always has the same number of stages.
                A.SmallestMaxSize(max_size=enlarged_size)
                if opt.resize_aug
                else A.NoOp(),
                A.RandomSizedCrop(
                    min_max_height=(opt.input_size, enlarged_size),
                    height=opt.input_size,
                    width=opt.input_size,
                )
                if opt.resize_aug
                else A.NoOp(),
                A.NoOp() if opt.no_gaussian_blur else A.GaussianBlur(p=0.5),
                A.NoOp() if opt.no_color_jitter else A.ColorJitter(p=0.5),
                A.NoOp() if opt.no_jpeg_compression else A.ImageCompression(p=0.5),
            ]
        )
        train_sets = get_dataset(opt.train_datalist, "train", train_transform, opt)
        train_collate_fn = partial(
            crop_to_smallest_collate_fn,
            max_size=opt.input_size,
            uncorrect_label=opt.uncorrect_label,
        )
        for name, dataset in train_sets.items():
            train_loaders[name] = DataLoader(
                dataset,
                batch_size=opt.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=num_workers,
                collate_fn=train_collate_fn,
            )

    # ------------------------------------------------------------------
    # Validation loaders (always built; evaluated one sample at a time).
    # ------------------------------------------------------------------
    if opt.large_image_strategy == "rescale":
        val_transform = A.Compose([A.SmallestMaxSize(opt.tile_size)])
    else:
        val_transform = None
    val_sets = get_dataset(opt.val_datalist, opt.val_set, val_transform, opt)
    val_loaders = {}
    for name, dataset in val_sets.items():
        val_loaders[name] = DataLoader(
            dataset,
            batch_size=1,
            shuffle=opt.val_shuffle,
            pin_memory=True,
            num_workers=num_workers,
        )

    # ------------------------------------------------------------------
    # Model, plus one optimizer/scheduler pair per modality sub-model.
    # ------------------------------------------------------------------
    optimizer_dict = {}
    scheduler_dict = {}
    model = get_ensemble_model(opt).to(opt.device)
    n_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(
        f"Number of total params: {n_param}, num params per model: {int(n_param / len(opt.modality))}"
    )

    optimizer_name = opt.optimizer.lower()
    for modality in opt.modality:
        sub_model_params = model.sub_models[modality].parameters()
        if optimizer_name == "adamw":
            optimizer = optim.AdamW(
                sub_model_params,
                lr=opt.lr,
                weight_decay=opt.weight_decay,
            )
        elif optimizer_name == "sgd":
            optimizer = optim.SGD(
                sub_model_params,
                lr=opt.lr,
                momentum=opt.momentum,
                weight_decay=opt.weight_decay,
            )
        else:
            raise RuntimeError(f"Unsupported optimizer {opt.optimizer}.")

        scheduler, num_epoch = create_scheduler(opt, optimizer)

        optimizer_dict[modality] = optimizer
        scheduler_dict[modality] = scheduler
        # The scheduler factory returns the number of epochs to run; it
        # replaces whatever epoch count was configured on ``opt``.
        opt.epochs = num_epoch

    # ------------------------------------------------------------------
    # Losses: the bundled criterion drives training, the single criterion
    # is the one handed to evaluation.
    # ------------------------------------------------------------------
    bundled_criterion = get_bundled_loss(opt).to(opt.device)
    single_criterion = get_loss(opt).to(opt.device)

    if opt.resume:
        misc.resume_from(model, opt.resume)

    # Evaluation-only mode: one validation pass, then stop.
    if opt.eval:
        bundled_evaluate(
            model, val_loaders, single_criterion, 0, writer, suffix="val", opt=opt
        )
        return

    cprint(f"The training will last for {opt.epochs} epochs.", "blue")
    checkpoint_dir = os.path.join(opt.save_root_path, opt.dir_name, "checkpoint")
    best_ensemble_image_f1 = -math.inf
    for epoch in range(opt.epochs):
        for title, dataloader in train_loaders.items():
            train(
                model,
                dataloader,
                title,
                optimizer_dict,
                bundled_criterion,
                epoch,
                writer,
                suffix="train",
                opt=opt,
            )

        for sched_idx, scheduler in enumerate(scheduler_dict.values()):
            # Only the first scheduler's learning rate is logged.
            if sched_idx == 0 and writer is not None:
                writer.add_scalar("lr", scheduler._get_lr(epoch)[0], epoch)
            scheduler.step(epoch)

        # Evaluate every ``eval_freq`` epochs and always after the last one.
        is_last_epoch = epoch == opt.epochs - 1
        if (epoch + 1) % opt.eval_freq == 0 or is_last_epoch:
            result = bundled_evaluate(
                model,
                val_loaders,
                single_criterion,
                epoch,
                writer,
                suffix="val",
                opt=opt,
            )
            misc.save_model(
                os.path.join(checkpoint_dir, f"{epoch}.pt"),
                model,
                epoch,
                opt,
                performance=result,
            )
            ensemble_image_f1 = result["image_f1/AVG_ensemble"]
            if ensemble_image_f1 > best_ensemble_image_f1:
                best_ensemble_image_f1 = ensemble_image_f1
                misc.save_model(
                    os.path.join(checkpoint_dir, "best.pt"),
                    model,
                    epoch,
                    opt,
                    performance=result,
                )
                misc.update_record(result, epoch, opt, "best_record")
            misc.update_record(result, epoch, opt, "latest_record")

    print("best performance:", best_ensemble_image_f1)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse options, run main(), and report wall-clock time.
    opt = get_opt()

    st = datetime.datetime.now()
    main(opt)
    total_time = datetime.datetime.now() - st
    # Use total_seconds(): the `.seconds` attribute is only the sub-day
    # component and would silently drop whole days on runs longer than 24 h.
    # Truncating to an int keeps the printed value free of microseconds.
    total_time = str(datetime.timedelta(seconds=int(total_time.total_seconds())))
    print(f"Total time: {total_time}")

    print("finished")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|