# DiffIR2VR / GMFlow / main.py
# Uploaded by jimmycv07 -- commit 1de8821 ("first commit"), 26.8 kB
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import argparse
import numpy as np
import os
from data import build_train_dataset
from gmflow.gmflow import GMFlow
from loss import flow_loss_func
from evaluate import (validate_chairs, validate_things, validate_sintel, validate_kitti,
create_sintel_submission, create_kitti_submission, inference_on_dir)
from utils.logger import Logger
from utils import misc
from utils.dist_utils import get_dist_info, init_dist, setup_for_distributed
def get_args_parser():
    """Build the command-line parser for GMFlow training, evaluation and inference.

    Returns:
        argparse.ArgumentParser: parser covering dataset, optimisation, resume,
        model, loss, evaluation, directory-inference, submission and
        distributed-training options.
    """
    parser = argparse.ArgumentParser()

    # dataset
    parser.add_argument('--checkpoint_dir', default='tmp', type=str,
                        help='where to save the training log and models')
    parser.add_argument('--stage', default='chairs', type=str,
                        help='training stage')
    parser.add_argument('--image_size', default=[384, 512], type=int, nargs='+',
                        help='image size for training')
    parser.add_argument('--padding_factor', default=16, type=int,
                        help='the input should be divisible by padding_factor, otherwise do padding')
    parser.add_argument('--max_flow', default=400, type=int,
                        help='exclude very large motions during training')
    parser.add_argument('--val_dataset', default=['chairs'], type=str, nargs='+',
                        help='validation dataset')
    parser.add_argument('--with_speed_metric', action='store_true',
                        help='with speed metric when evaluation')

    # training
    parser.add_argument('--lr', default=4e-4, type=float)
    parser.add_argument('--batch_size', default=12, type=int)
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--grad_clip', default=1.0, type=float)
    parser.add_argument('--num_steps', default=100000, type=int)
    parser.add_argument('--seed', default=326, type=int)
    parser.add_argument('--summary_freq', default=100, type=int)
    parser.add_argument('--val_freq', default=10000, type=int)
    parser.add_argument('--save_ckpt_freq', default=10000, type=int)
    parser.add_argument('--save_latest_ckpt_freq', default=1000, type=int)

    # resume pretrained model or resume training
    parser.add_argument('--resume', default=None, type=str,
                        help='resume from pretrain model for finetuning or resume from terminated training')
    parser.add_argument('--strict_resume', action='store_true')
    parser.add_argument('--no_resume_optimizer', action='store_true')

    # GMFlow model
    parser.add_argument('--num_scales', default=1, type=int,
                        help='basic gmflow model uses a single 1/8 feature, the refinement uses 1/4 feature')
    parser.add_argument('--feature_channels', default=128, type=int)
    parser.add_argument('--upsample_factor', default=8, type=int)
    parser.add_argument('--num_transformer_layers', default=6, type=int)
    parser.add_argument('--num_head', default=1, type=int)
    parser.add_argument('--attention_type', default='swin', type=str)
    parser.add_argument('--ffn_dim_expansion', default=4, type=int)
    parser.add_argument('--attn_splits_list', default=[2], type=int, nargs='+',
                        help='number of splits in attention')
    parser.add_argument('--corr_radius_list', default=[-1], type=int, nargs='+',
                        help='correlation radius for matching, -1 indicates global matching')
    parser.add_argument('--prop_radius_list', default=[-1], type=int, nargs='+',
                        help='self-attention radius for flow propagation, -1 indicates global attention')

    # loss
    parser.add_argument('--gamma', default=0.9, type=float,
                        help='loss weight')

    # evaluation
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--save_eval_to_file', action='store_true')
    parser.add_argument('--evaluate_matched_unmatched', action='store_true')

    # inference on a directory
    parser.add_argument('--inference_dir', default=None, type=str)
    parser.add_argument('--inference_size', default=None, type=int, nargs='+',
                        help='can specify the inference size')
    parser.add_argument('--dir_paired_data', action='store_true',
                        help='Paired data in a dir instead of a sequence')
    parser.add_argument('--save_flo_flow', action='store_true')
    parser.add_argument('--pred_bidir_flow', action='store_true',
                        help='predict bidirectional flow')
    parser.add_argument('--fwd_bwd_consistency_check', action='store_true',
                        help='forward backward consistency check with bidirection flow')

    # predict on sintel and kitti test set for submission
    parser.add_argument('--submission', action='store_true',
                        help='submission to sintel or kitti test sets')
    parser.add_argument('--output_path', default='output', type=str,
                        help='where to save the prediction results')
    parser.add_argument('--save_vis_flow', action='store_true',
                        help='visualize flow prediction as .png image')
    parser.add_argument('--no_save_flo', action='store_true',
                        help='not save flow as .flo')

    # distributed training
    # Newer PyTorch launchers pass `--local-rank` (hyphen). Accept both
    # spellings; argparse derives dest from the first long option, so both
    # store into args.local_rank.
    parser.add_argument('--local_rank', '--local-rank', default=0, type=int)
    # NOTE: main() overwrites args.distributed based on --launcher.
    parser.add_argument('--distributed', action='store_true')
    parser.add_argument('--launcher', default='none', type=str, choices=['none', 'pytorch'])
    # List default so args.gpu_ids is always a list, matching nargs='+'.
    parser.add_argument('--gpu_ids', default=[0], type=int, nargs='+')
    parser.add_argument('--count_time', action='store_true',
                        help='measure the inference time on sintel')

    return parser
def _run_validation(model, args):
    """Run every validator named in ``args.val_dataset`` and merge their metrics.

    Args:
        model: the unwrapped (non-DDP / non-DataParallel) model to evaluate.
        args: parsed command-line namespace (see ``get_args_parser``).

    Returns:
        dict: union of the result dicts returned by the selected validators.
    """
    # settings shared by every validator
    common = dict(
        with_speed_metric=args.with_speed_metric,
        attn_splits_list=args.attn_splits_list,
        corr_radius_list=args.corr_radius_list,
        prop_radius_list=args.prop_radius_list,
    )

    val_results = {}
    if 'chairs' in args.val_dataset:
        val_results.update(validate_chairs(model, **common))
    if 'things' in args.val_dataset:
        val_results.update(validate_things(model,
                                           padding_factor=args.padding_factor,
                                           **common))
    if 'sintel' in args.val_dataset:
        val_results.update(validate_sintel(model,
                                           count_time=args.count_time,
                                           padding_factor=args.padding_factor,
                                           evaluate_matched_unmatched=args.evaluate_matched_unmatched,
                                           **common))
    if 'kitti' in args.val_dataset:
        val_results.update(validate_kitti(model,
                                          padding_factor=args.padding_factor,
                                          **common))
    return val_results


def _val_metric_names(evaluate_matched_unmatched):
    """Return the ordered metric keys that may be written to ``val_results.txt``."""
    speed_bins = ['s0_10', 's10_40', 's40+']
    sintel_extra = ['matched', 'unmatched'] if evaluate_matched_unmatched else []

    names = ['chairs_epe'] + ['chairs_' + b for b in speed_bins]
    for prefix in ('things_clean', 'things_final'):
        names += [prefix + '_epe'] + [prefix + '_' + b for b in speed_bins]
    for prefix in ('sintel_clean', 'sintel_final'):
        names += [prefix + '_epe'] + [prefix + '_' + x for x in sintel_extra + speed_bins]
    names += ['kitti_epe', 'kitti_f1'] + ['kitti_' + b for b in speed_bins]
    return names


def _append_val_results(val_file, header, val_results, evaluate_matched_unmatched):
    """Append ``val_results`` to ``val_file`` as a two-row markdown-style table.

    Only metrics present in ``val_results`` are written, in the order given by
    ``_val_metric_names``.

    Args:
        val_file: path of the text file to append to.
        header: text written before the table.
        val_results: mapping of metric name to numeric value.
        evaluate_matched_unmatched: whether Sintel matched/unmatched columns are included.
    """
    names = [m for m in _val_metric_names(evaluate_matched_unmatched) if m in val_results]
    # matched/unmatched names (e.g. 'sintel_clean_unmatched') exceed 20 chars
    width = 25 if evaluate_matched_unmatched else 20
    with open(val_file, 'a') as f:
        f.write(header)
        f.write(''.join(f'| {name:>{width}} ' for name in names) + '\n')
        f.write(''.join(f'| {val_results[name]:{width}.3f} ' for name in names))
        f.write('\n\n')


def main(args):
    """Train GMFlow, or evaluate / create a submission / run directory inference.

    ``--eval``, ``--submission`` and ``--inference_dir`` each run their task
    and return early; otherwise the model is trained until ``args.num_steps``.

    Args:
        args: parsed command-line namespace (see ``get_args_parser``). Mutated
            in place: ``distributed``, and in distributed mode ``batch_size``
            and ``gpu_ids``.
    """
    is_training = not args.eval and not args.submission and args.inference_dir is None

    if is_training and args.local_rank == 0:
        print('pytorch version:', torch.__version__)
        print(args)
        misc.save_args(args)
        misc.check_path(args.checkpoint_dir)
        misc.save_command(args.checkpoint_dir)

    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    torch.backends.cudnn.benchmark = True

    if args.launcher == 'none':
        args.distributed = False
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        args.distributed = True

        dist_params = dict(backend='nccl')
        init_dist(args.launcher, **dist_params)
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        args.gpu_ids = range(world_size)

        # --batch_size is the global batch: split it over every process
        # (world size, not the local GPU count, so multi-node runs stay correct)
        if args.batch_size % world_size != 0:
            raise ValueError('batch_size (%d) must be divisible by the number of processes (%d)'
                             % (args.batch_size, world_size))
        args.batch_size = args.batch_size // world_size

        device = torch.device('cuda:{}'.format(args.local_rank))

        setup_for_distributed(args.local_rank == 0)

    # model
    model = GMFlow(feature_channels=args.feature_channels,
                   num_scales=args.num_scales,
                   upsample_factor=args.upsample_factor,
                   num_head=args.num_head,
                   attention_type=args.attention_type,
                   ffn_dim_expansion=args.ffn_dim_expansion,
                   num_transformer_layers=args.num_transformer_layers,
                   ).to(device)

    if is_training:
        print('Model definition:')
        print(model)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank)
        model_without_ddp = model.module
    elif torch.cuda.device_count() > 1:
        print('Use %d GPUs' % torch.cuda.device_count())
        model = torch.nn.DataParallel(model)
        model_without_ddp = model.module
    else:
        model_without_ddp = model

    num_params = sum(p.numel() for p in model.parameters())
    print('Number of params:', num_params)
    if is_training:
        # empty marker file whose name records the parameter count
        save_name = '%d_parameters' % num_params
        open(os.path.join(args.checkpoint_dir, save_name), 'a').close()

    optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay)

    start_epoch = 0
    start_step = 0
    # resume checkpoints
    if args.resume:
        print('Load checkpoint: %s' % args.resume)

        # map onto this process's device (also works on CPU-only machines)
        checkpoint = torch.load(args.resume, map_location=device)

        # accept either a {'model': ...} checkpoint or a bare state dict
        weights = checkpoint['model'] if 'model' in checkpoint else checkpoint

        model_without_ddp.load_state_dict(weights, strict=args.strict_resume)

        if 'optimizer' in checkpoint and 'step' in checkpoint and 'epoch' in checkpoint and not \
                args.no_resume_optimizer:
            print('Load optimizer')
            optimizer.load_state_dict(checkpoint['optimizer'])
            start_epoch = checkpoint['epoch']
            start_step = checkpoint['step']

        print('start_epoch: %d, start_step: %d' % (start_epoch, start_step))

    # evaluate
    if args.eval:
        val_results = _run_validation(model_without_ddp, args)

        if args.save_eval_to_file:
            misc.check_path(args.checkpoint_dir)
            val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')
            _append_val_results(val_file, '\neval results after training done\n\n',
                                val_results, args.evaluate_matched_unmatched)

        return

    # Sintel and KITTI submission
    if args.submission:
        # NOTE: args.val_dataset is a list
        if args.val_dataset[0] == 'sintel':
            create_sintel_submission(model_without_ddp,
                                     output_path=args.output_path,
                                     padding_factor=args.padding_factor,
                                     save_vis_flow=args.save_vis_flow,
                                     no_save_flo=args.no_save_flo,
                                     attn_splits_list=args.attn_splits_list,
                                     corr_radius_list=args.corr_radius_list,
                                     prop_radius_list=args.prop_radius_list,
                                     )
        elif args.val_dataset[0] == 'kitti':
            create_kitti_submission(model_without_ddp,
                                    output_path=args.output_path,
                                    padding_factor=args.padding_factor,
                                    save_vis_flow=args.save_vis_flow,
                                    attn_splits_list=args.attn_splits_list,
                                    corr_radius_list=args.corr_radius_list,
                                    prop_radius_list=args.prop_radius_list,
                                    )
        else:
            raise ValueError(f'Not supported dataset for submission: {args.val_dataset[0]}')

        return

    # inference on a dir
    if args.inference_dir is not None:
        inference_on_dir(model_without_ddp,
                         inference_dir=args.inference_dir,
                         output_path=args.output_path,
                         padding_factor=args.padding_factor,
                         inference_size=args.inference_size,
                         paired_data=args.dir_paired_data,
                         save_flo_flow=args.save_flo_flow,
                         attn_splits_list=args.attn_splits_list,
                         corr_radius_list=args.corr_radius_list,
                         prop_radius_list=args.prop_radius_list,
                         pred_bidir_flow=args.pred_bidir_flow,
                         fwd_bwd_consistency_check=args.fwd_bwd_consistency_check,
                         )

        return

    # training dataset
    train_dataset = build_train_dataset(args)
    print('Number of training images:', len(train_dataset))

    # Multi-processing
    if args.distributed:
        # shard data over *all* processes using their global rank
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank())
    else:
        train_sampler = None

    # DataLoader forbids shuffle=True together with a sampler; the
    # DistributedSampler shuffles itself (reseeded via set_epoch below)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=not args.distributed,
                                               num_workers=args.num_workers,
                                               pin_memory=True, drop_last=True,
                                               sampler=train_sampler)

    # continue the LR schedule from the resumed step (start_step > 0 only
    # when the optimizer state was restored)
    last_epoch = start_step if args.resume and start_step > 0 else -1
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, args.lr,
        args.num_steps + 10,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy='cos',
        last_epoch=last_epoch,
    )

    if args.local_rank == 0:
        summary_writer = SummaryWriter(args.checkpoint_dir)
        logger = Logger(lr_scheduler, summary_writer, args.summary_freq,
                        start_step=start_step)

    total_steps = start_step
    epoch = start_epoch
    print('Start training')

    while total_steps < args.num_steps:
        model.train()

        # manually change the random seed for shuffling every epoch
        if args.distributed:
            train_sampler.set_epoch(epoch)

        for sample in train_loader:
            img1, img2, flow_gt, valid = [x.to(device) for x in sample]

            results_dict = model(img1, img2,
                                 attn_splits_list=args.attn_splits_list,
                                 corr_radius_list=args.corr_radius_list,
                                 prop_radius_list=args.prop_radius_list,
                                 )

            flow_preds = results_dict['flow_preds']

            loss, metrics = flow_loss_func(flow_preds, flow_gt, valid,
                                           gamma=args.gamma,
                                           max_flow=args.max_flow,
                                           )

            # NOTE(review): skipping a step on only one rank makes DDP ranks
            # disagree on the number of backward passes and can hang training.
            if isinstance(loss, float):
                continue

            if torch.isnan(loss):
                continue

            metrics.update({'total_loss': loss.item()})

            # more efficient zero_grad
            for param in model_without_ddp.parameters():
                param.grad = None

            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            optimizer.step()

            lr_scheduler.step()

            if args.local_rank == 0:
                logger.push(metrics)

                logger.add_image_summary(img1, img2, flow_preds, flow_gt)

            total_steps += 1

            if total_steps % args.save_ckpt_freq == 0 or total_steps == args.num_steps:
                if args.local_rank == 0:
                    checkpoint_path = os.path.join(args.checkpoint_dir, 'step_%06d.pth' % total_steps)
                    torch.save({
                        'model': model_without_ddp.state_dict()
                    }, checkpoint_path)

            if total_steps % args.save_latest_ckpt_freq == 0:
                checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint_latest.pth')

                if args.local_rank == 0:
                    torch.save({
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': total_steps,
                        'epoch': epoch,
                    }, checkpoint_path)

            if total_steps % args.val_freq == 0:
                print('Start validation')

                # every rank runs validation; only rank 0 records the results
                val_results = _run_validation(model_without_ddp, args)

                if args.local_rank == 0:
                    logger.write_dict(val_results)

                    # Save validation results
                    val_file = os.path.join(args.checkpoint_dir, 'val_results.txt')
                    _append_val_results(val_file, 'step: %06d\n' % total_steps,
                                        val_results, args.evaluate_matched_unmatched)

                model.train()

            if total_steps >= args.num_steps:
                print('Training done')

                return

        epoch += 1
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    # torchrun exports LOCAL_RANK instead of passing --local_rank, so prefer the
    # environment when it is set; otherwise publish the CLI value so code that
    # reads the environment sees the same rank.
    if 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])
    else:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    main(args)