import logging
import os
import shutil
import math
import sys
import numpy as np
from tensorboardX import SummaryWriter
from tqdm import tqdm
from typing import Iterable
from pathlib import Path
from time import time
import datetime
import wandb
import cv2

import torch
from torchvision.utils import save_image
from torchvision.transforms import functional as TF

from modules.components import make_components
import utils.misc as misc
from utils.plot import plot_samples_per_epoch, plot_val_samples
from utils.metrics import calculate_batch_psnr, calculate_batch_ssim
from utils.flowvis import flow2img
from utils.padder import InputPadder
from modules.loss import make_loss_dict
from modules.lr_scheduler import make_lr_scheduler
from modules.optimizer import make_optimizer
from modules.models import make, register
from modules.models.inference_video import inference_demo
from modules.models.unimatch.unimatch import UniMatch


@register('base_model')
class BaseModel:
    def __init__(self, cfgs):
        self.cfgs = cfgs
        self.device = torch.cuda.current_device()

        self.current_iteration = 0
        self.current_epoch = 0
        self.model = make_components(self.cfgs['model'])
        self.loss_dict = make_loss_dict(cfgs['loss'])

        self.logger = logging.getLogger(self.cfgs['model']['name'])
        self.move_components_to_device(cfgs['mode'])
        self.model_without_ddp = self.model
        if cfgs['distributed']:
            # Wrap in DDP for multi-GPU training; keep an unwrapped handle for
            # checkpointing and parameter counting.
            self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[cfgs['gpu']])
            self.model_without_ddp = self.model.module

        self.optimizer = make_optimizer(self.model_without_ddp.parameters(), self.cfgs['optimizer'])
        self.lr_scheduler = make_lr_scheduler(self.optimizer, cfgs['lr_scheduler'])

        print(f'Total params: {self.count_parameters()}')

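    # Config keys read by this class (listed here for reference only; the full
    # schema lives in the YAML configs elsewhere in the repo): 'model', 'loss',
    # 'mode', 'distributed', 'gpu', 'optimizer', 'lr_scheduler', 'gradient_clip',
    # 'checkpoint_dir', 'output_dir', 'summary_dir', 'enable_wandb', 'vis_every',
    # 'train_dataset', 'test_dataset'.
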
    def load_checkpoint(self, file_path):
        """
        Load a training checkpoint and restore model, optimizer and scheduler state
        """
        checkpoint = torch.load(file_path, map_location="cpu")

        self.current_epoch = checkpoint['epoch']
        self.current_iteration = checkpoint['iteration']
        self.model_without_ddp.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        self.logger.info('Checkpoint loaded successfully from {} at epoch: {} and iteration: {}'.format(
            file_path, checkpoint['epoch'], checkpoint['iteration']))
        self.move_components_to_device(self.cfgs['mode'])
        return self.current_epoch

    def load_pretrained(self, file_path):
        """
        Load pretrained weights only (no optimizer/scheduler state)
        """
        checkpoint = torch.load(file_path, map_location="cpu")

        # Some checkpoints store the weights under a 'state_dict' key.
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        # Prefix keys with 'module.' so they match the DDP-wrapped model.
        state_dict = {'module.' + key: value for key, value in state_dict.items()}
        self.model.load_state_dict(state_dict)
        self.logger.info('Pretrained model loaded successfully from {} '.format(
            file_path))
        self.move_components_to_device(self.cfgs['mode'])
        return self.current_epoch

    def save_checkpoint(self, file_name, is_best=0):
        """
        Save checkpoint
        """
        state = {
            'epoch': self.current_epoch,
            'iteration': self.current_iteration,
            'model': self.model_without_ddp.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict()
        }

        misc.save_on_master(state, os.path.join(self.cfgs['checkpoint_dir'], file_name))

        if is_best and misc.is_main_process():
            shutil.copyfile(os.path.join(self.cfgs['checkpoint_dir'], file_name),
                            os.path.join(self.cfgs['checkpoint_dir'], 'model_best.pth'))

    def adjust_learning_rate(self, epoch):
        """
        Adjust learning rate every epoch
        """
        self.lr_scheduler.step()

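    # Resuming (sketch): the dict written by save_checkpoint() is exactly what
    # load_checkpoint() expects back, so a typical resume looks like
    #   start_epoch = model.load_checkpoint(os.path.join(cfgs['checkpoint_dir'], 'model_best.pth'))
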
    def train_one_epoch(self, train_loader: Iterable, epoch: int, max_norm: float = 0):
        """
        Run one training epoch over the given loader
        """
        self.current_epoch = epoch
        self._reset_metric()

        self.model.train()

        header = 'Epoch: [{}]'.format(epoch)
        print_freq = 100
        for input_dict in self.metric_logger.log_every(train_loader, print_freq, header):
            input_dict = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in input_dict.items()}
            result_dict, extra_dict = self.model(**input_dict)
            imgt_pred = result_dict['imgt_pred']

            # Accumulate all loss terms defined in the config.
            loss = torch.zeros(1, device=self.device)
            losses = dict()
            for k, v in self.loss_dict.items():
                losses[k] = v(**result_dict, **input_dict)
                loss += losses[k]

            imgt_pred = torch.clamp(imgt_pred, 0, 1)
            self.optimizer.zero_grad()
            loss.backward()
            if 'gradient_clip' in self.cfgs:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfgs['gradient_clip'])
            self.optimizer.step()
            self.lr_scheduler.step()

            self.metric_logger.update(loss=loss, **losses)
            self.metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])

            if misc.is_main_process() and self.current_iteration % print_freq == 0:
                # Visualize a few samples: overlapped inputs, prediction, predicted flow
                # and either the ground-truth flow or the ground-truth frame.
                nsample = 4
                img0_p, img1_p, gt_p, imgt_pred_p = input_dict['img0'][:nsample].detach(), input_dict['img1'][:nsample].detach(), \
                    input_dict['imgt'][:nsample].detach(), imgt_pred[:nsample].detach()
                overlapped_img = img0_p * 0.5 + img1_p * 0.5

                flowfwd = flow2img(result_dict['flowfwd'][:nsample].detach())
                if self.cfgs['train_dataset']['args']['flow'] != 'none':
                    flowfwd_gt = flow2img(input_dict['flowt0'][:nsample])
                    figure = torch.stack([overlapped_img, imgt_pred_p, flowfwd, flowfwd_gt])
                else:
                    figure = torch.stack([overlapped_img, imgt_pred_p, flowfwd, gt_p])
                figure = torch.transpose(figure, 0, 1).reshape(-1, 3, self.cfgs['train_dataset']['args']['patch_size'],
                                                               self.cfgs['train_dataset']['args']['patch_size'])
                image = plot_samples_per_epoch(figure, os.path.join(self.cfgs['output_dir'], "imgs_train"),
                                               self.current_epoch, self.current_iteration, nsample)
                self.summary_writer.add_scalar("Train/loss", loss, self.current_iteration)
                for k, v in losses.items():
                    self.summary_writer.add_scalar(f'Train/loss_{k}', v, self.current_iteration)
                self.summary_writer.add_scalar("Train/LR", self.lr_scheduler.get_last_lr()[0], self.current_iteration)

                if self.cfgs['enable_wandb']:
                    wandb.log({"loss": loss}, step=self.current_iteration)
                    for k, v in losses.items():
                        wandb.log({f'loss_{k}': v}, step=self.current_iteration)
                    wandb.log({"lr": self.lr_scheduler.get_last_lr()[0]}, step=self.current_iteration)
                    if self.current_iteration % (print_freq * 10) == 0:
                        wandb.log({"Image": wandb.Image(image)}, step=self.current_iteration)

            self.current_iteration += 1

        self.metric_logger.synchronize_between_processes()
        self.current_epoch += 1
        if misc.is_main_process():
            self.logger.info(f"Averaged training stats: {self.metric_logger}")

    @torch.no_grad()
    def validate(self, val_loader):
        """
        Run validation over the given loader and report PSNR/SSIM
        """
        self.model.eval()

        self.metric_logger = misc.MetricLogger(delimiter=" ")
        self.metric_logger.add_meter('psnr', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))
        self.metric_logger.add_meter('ssim', misc.SmoothedValue(window_size=1, fmt='{value:.2f}'))
        header = 'Test:'
        psnr_dict = {}

        print_freq = 10

        for input_dict in self.metric_logger.log_every(val_loader, print_freq, header):
            input_dict = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in input_dict.items()}
            img0 = input_dict['img0']
            imgt = input_dict['imgt']
            img1 = input_dict['img1']
            result_dict, extra_dict = self.model(**input_dict)

            scene_names = input_dict['scene_name']

            imgt_pred = result_dict['imgt_pred']

            psnr, psnr_list = calculate_batch_psnr(imgt, imgt_pred)
            ssim, bs = calculate_batch_ssim(imgt, imgt_pred)
            self.metric_logger.update(psnr={'value': psnr, 'n': len(psnr_list)},
                                      ssim={'value': ssim, 'n': len(psnr_list)})

            # Optionally dump per-scene predictions and flow visualizations.
            if (self.current_epoch != 0) and ((self.current_epoch % self.cfgs['vis_every'] == 0) or (self.cfgs['mode'] != 'train' and self.cfgs['test_dataset']['save_imgs'])):
                for i in range(len(scene_names)):
                    psnr_dict[scene_names[i]] = float(psnr_list[i])
                    if self.cfgs['mode'] == "test":
                        scene_path = os.path.join(self.cfgs['output_dir'], "imgs_test",
                                                  f"{self.cfgs['test_dataset']['name']}_{self.cfgs['test_dataset']['args']['split']}",
                                                  scene_names[i])
                    else:
                        scene_path = os.path.join(self.cfgs['output_dir'], "imgs_val",
                                                  f"{self.cfgs['test_dataset']['name']}_{self.cfgs['test_dataset']['args']['split']}",
                                                  scene_names[i])
                    Path(scene_path).mkdir(exist_ok=True, parents=True)
                    save_image(img0[i], os.path.join(scene_path, "img0.png"))
                    save_image(imgt_pred[i], os.path.join(scene_path, "imgt_pred.png"))
                    save_image(imgt[i], os.path.join(scene_path, "imgt.png"))
                    save_image(img1[i], os.path.join(scene_path, "img1.png"))
                    save_image((img1[i] + img0[i]) / 2, os.path.join(scene_path, "overlayed.png"))
                    save_image(flow2img(result_dict['flowfwd'])[i], os.path.join(scene_path, "flow_fwd.png"))
                    save_image(flow2img(result_dict['flowbwd'])[i], os.path.join(scene_path, "flow_bwd.png"))

        self.logger.info(f"Averaged validate stats:{self.metric_logger.print_avg()}")
        if (self.current_epoch != 0) and ((self.current_epoch % self.cfgs['vis_every'] == 0) or (self.cfgs['mode'] != 'train' and self.cfgs['test_dataset']['save_imgs'])):
            # Write per-scene PSNR values, sorted from worst to best.
            psnr_str = []
            psnr_dict = sorted(psnr_dict.items(), key=lambda item: item[1])
            for key, val in psnr_dict:
                psnr_str.append("{}: {}".format(key, val))
            psnr_str = "\n".join(psnr_str)
            if self.cfgs['mode'] == "test":
                outdir = os.path.join(self.cfgs['output_dir'], "imgs_test",
                                      f"{self.cfgs['test_dataset']['name']}_{self.cfgs['test_dataset']['args']['split']}")
            else:
                outdir = os.path.join(self.cfgs['output_dir'], "imgs_val",
                                      f"{self.cfgs['test_dataset']['name']}_{self.cfgs['test_dataset']['args']['split']}")
            with open(os.path.join(outdir, "results.txt"), "w") as f:
                f.write(psnr_str)
        if misc.is_main_process() and self.cfgs['mode'] == 'train':
            self.summary_writer.add_scalar("Val/psnr", self.metric_logger.psnr.global_avg, self.current_epoch)
            self.summary_writer.add_scalar("Val/ssim", self.metric_logger.ssim.global_avg, self.current_epoch)
            if self.cfgs['enable_wandb']:
                wandb.log({'val_psnr': self.metric_logger.psnr.global_avg, 'val_ssim': self.metric_logger.ssim.global_avg},
                          step=self.current_iteration)
        return self.metric_logger.psnr.global_avg

    @torch.no_grad()
    def demo(self, video_dir):
        start_time = time()
        for video_name in os.listdir(video_dir):
            video_path = os.path.join(video_dir, video_name)
            out_path = os.path.join(self.cfgs['output_dir'], 'demo', video_name.split(".")[0])
            inference_demo(self.model, 2, video_path, out_path)
        total_time_str = str(datetime.timedelta(seconds=int(time() - start_time)))
        print("Total time: {}".format(total_time_str))

    def init_training_logger(self):
        """
        Initialize training logger specific for each model
        """
        if misc.is_main_process():
            self.summary_writer = SummaryWriter(log_dir=self.cfgs['summary_dir'], comment='m2mpwc')
            Path(os.path.join(self.cfgs['output_dir'], 'imgs_train')).mkdir(parents=True, exist_ok=True)
            Path(os.path.join(self.cfgs['output_dir'], 'imgs_val')).mkdir(parents=True, exist_ok=True)
        self._reset_metric()

    def init_validation_logger(self):
        """
        Initialize validation logger specific for each model
        """
        if misc.is_main_process():
            self.summary_writer = SummaryWriter(log_dir=self.cfgs['summary_dir'], comment='m2mpwc')
            Path(os.path.join(self.cfgs['output_dir'], 'imgs_val')).mkdir(parents=True, exist_ok=True)
        self._reset_metric()

    def init_testing_logger(self):
        """
        Initialize testing logger specific for each model
        """
        if misc.is_main_process():
            self.summary_writer = SummaryWriter(log_dir=self.cfgs['summary_dir'], comment='m2mpwc')
            Path(os.path.join(self.cfgs['output_dir'], 'imgs_test')).mkdir(parents=True, exist_ok=True)
        self._reset_metric()

    def init_demo_logger(self):
        """
        Initialize demo logger specific for each model
        """
        if misc.is_main_process():
            self.summary_writer = SummaryWriter(log_dir=self.cfgs['summary_dir'], comment='m2mpwc')
            Path(os.path.join(self.cfgs['output_dir'], 'demo')).mkdir(parents=True, exist_ok=True)
        self._reset_metric()

    def finalize_training(self):
        if misc.is_main_process():
            self.summary_writer.close()

    def move_components_to_device(self, mode):
        """
        Move model and loss components to the target device
        """
        self.model.to(self.device)
        for _, v in self.loss_dict.items():
            v.to(self.device)
        self.logger.info('Model: {}'.format(self.model))

    def _reset_metric(self):
        """
        Reset the metric logger and its average meters
        """
        self.metric_logger = misc.MetricLogger(delimiter=" ")
        self.metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        self.metric_logger.add_meter('loss', misc.SmoothedValue(window_size=20))

    def count_parameters(self):
        """
        Return the number of trainable parameters of the model
        """
        model_number = sum(p.numel() for p in self.model_without_ddp.parameters() if p.requires_grad)
        return model_number
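

# Example usage (sketch): the exact entry point and config loading live elsewhere
# in the repo, so the names below (cfgs dict, loaders, 'max_epoch', resume_path,
# best_psnr) are assumptions for illustration only.
#
#   model = BaseModel(cfgs)                      # or via the @register/make registry
#   model.init_training_logger()
#   start_epoch = model.load_checkpoint(resume_path) if resume_path else 0
#   for epoch in range(start_epoch, cfgs['max_epoch']):
#       model.train_one_epoch(train_loader, epoch)
#       psnr = model.validate(val_loader)
#       model.save_checkpoint('checkpoint.pth', is_best=psnr > best_psnr)
#   model.finalize_training()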