# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import os
import random
from abc import ABC
from collections import OrderedDict
import torch
from vidar.utils.config import cfg_has, read_config
from vidar.utils.data import set_random_seed
from vidar.utils.distributed import rank, world_size
from vidar.utils.flip import flip_batch, flip_output
from vidar.utils.logging import pcolor, set_debug
from vidar.utils.networks import load_checkpoint, save_checkpoint, freeze_layers_and_norms
from vidar.utils.setup import setup_arch, setup_datasets, setup_metrics
from vidar.utils.types import is_str


class Wrapper(torch.nn.Module, ABC):
"""
Trainer class for model optimization and inference
Parameters
----------
cfg : Config
Configuration with parameters
ckpt : String
Name of the model checkpoint to start from
verbose : Bool
Print information on screen if enabled
"""
def __init__(self, cfg, ckpt=None, verbose=False):
super().__init__()
if verbose and rank() == 0:
font = {'color': 'cyan', 'attrs': ('bold', 'dark')}
print(pcolor('#' * 100, **font))
print(pcolor('#' * 42 + ' VIDAR WRAPPER ' + '#' * 43, **font))
print(pcolor('#' * 100, **font))
# Get configuration
cfg = read_config(cfg) if is_str(cfg) else cfg
self.cfg = cfg
# Data augmentations
self.flip_lr_prob = cfg_has(cfg.wrapper, 'flip_lr_prob', 0.0)
self.validate_flipped = cfg_has(cfg.wrapper, 'validate_flipped', False)
# Set random seed
set_random_seed(cfg.wrapper.seed + rank())
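        # Enable debug mode if requested in the configuration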
set_debug(cfg_has(cfg.wrapper, 'debug', False))
# Setup architecture, datasets and tasks
self.arch = setup_arch(cfg.arch, checkpoint=ckpt, verbose=verbose) if cfg_has(cfg, 'arch') else None
self.datasets, self.datasets_cfg = setup_datasets(
cfg.datasets, verbose=verbose) if cfg_has(cfg, 'datasets') else (None, None)
self.metrics = setup_metrics(cfg.evaluation) if cfg_has(cfg, 'evaluation') else {}
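        # Optionally convert BatchNorm layers to synchronized BatchNorm when training with DDP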
sync_batch_norm = cfg_has(cfg.wrapper, 'sync_batch_norm', False)
        if sync_batch_norm and os.environ.get('DIST_MODE') == 'ddp':
self.arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.arch)
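        # Mixed-precision flag; update_schedulers records whether schedulers step per 'step'
        # or per 'epoch' (set when the schedulers are created below)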
self.mixed_precision = cfg_has(cfg.wrapper, 'mixed_precision', False)
self.update_schedulers = None
def save(self, filename, epoch=None):
"""Save checkpoint"""
save_checkpoint(filename, self, epoch=epoch)
def load(self, checkpoint, strict=True, verbose=False):
"""Load checkpoint"""
load_checkpoint(self, checkpoint, strict=strict, verbose=verbose)
    def train_custom(self, in_optimizers, out_optimizers):
        """Set training mode, unfreezing networks with active optimizers and freezing the rest"""
        self.arch.train()
        arch = self.arch.module if hasattr(self.arch, 'module') else self.arch
        for key in in_optimizers.keys():
            freeze_layers_and_norms(arch.networks[key], ['ALL'], flag_freeze=False)
        for key in out_optimizers.keys():
            freeze_layers_and_norms(arch.networks[key], ['ALL'], flag_freeze=True)
def eval_custom(self):
"""Customized evaluation flag for the model"""
self.arch.eval()
def configure_optimizers_and_schedulers(self):
"""Configure depth and pose optimizers and the corresponding scheduler"""
if not cfg_has(self.cfg, 'optimizers'):
return None, None
optimizers = OrderedDict()
schedulers = OrderedDict()
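        # Build one optimizer (and optional scheduler) per entry in cfg.optimizers, keyed by network name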
for key, val in self.cfg.optimizers.__dict__.items():
assert key in self.arch.networks, f'There is no network for optimizer {key}'
optimizers[key] = {
'optimizer': getattr(torch.optim, val.name)(**{
'lr': val.lr,
'weight_decay': cfg_has(val, 'weight_decay', 0.0),
'params': self.arch.networks[key].parameters(),
}),
'settings': {} if not cfg_has(val, 'settings') else val.settings.__dict__
}
if cfg_has(val, 'scheduler'):
if val.scheduler.name == 'CosineAnnealingWarmUpRestarts':
from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
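                    # Approximate number of optimization steps per epoch (samples / effective batch size)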
epoch = float(len(self.datasets['train']) / (
world_size() * self.datasets_cfg['train'].dataloader.batch_size * self.datasets_cfg['train'].repeat[0]))
schedulers[key] = CosineAnnealingWarmupRestarts(**{
'optimizer': optimizers[key]['optimizer'],
'first_cycle_steps': int(val.scheduler.first_cycle_steps * epoch),
'cycle_mult': val.scheduler.cycle_mult,
'min_lr': val.scheduler.min_lr,
'max_lr': val.scheduler.max_lr,
'warmup_steps': int(val.scheduler.warmup_steps * epoch),
'gamma': val.scheduler.gamma,
})
self.update_schedulers = 'step'
elif val.scheduler.name == 'LinearWarmUp':
from externals.huggingface.transformers.src.transformers.optimization import get_linear_schedule_with_warmup
schedulers[key] = get_linear_schedule_with_warmup(**{
'optimizer': optimizers[key]['optimizer'],
'num_warmup_steps': val.scheduler.num_warmup_steps,
'num_training_steps': val.scheduler.num_training_steps,
})
self.update_schedulers = 'step'
else:
schedulers[key] = getattr(torch.optim.lr_scheduler, val.scheduler.name)(**{
'optimizer': optimizers[key]['optimizer'],
'step_size': val.scheduler.step_size,
'gamma': val.scheduler.gamma,
})
self.update_schedulers = 'epoch'
        # Return optimizers and schedulers
return optimizers, schedulers
def run_arch(self, batch, epoch, flip, unflip):
"""
Run model on a batch
Parameters
----------
batch : Dict
Dictionary with batch information
epoch : Int
Current epoch
flip : Bool
Batch should be flipped
unflip : Bool
Output should be unflipped
Returns
-------
output : Dict
Dictionary with model outputs
"""
batch = flip_batch(batch) if flip else batch
output = self.arch(batch, epoch=epoch)
return flip_output(output) if flip and unflip else output
def training_step(self, batch, epoch):
"""Processes a training batch"""
flip_lr = False if self.flip_lr_prob == 0 else \
random.random() < self.flip_lr_prob
if self.mixed_precision:
with torch.cuda.amp.autocast():
output = self.run_arch(batch, epoch=epoch, flip=flip_lr, unflip=False)
else:
output = self.run_arch(batch, epoch=epoch, flip=flip_lr, unflip=False)
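        # Gather all loss terms returned by the model (output keys starting with 'loss')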
losses = {key: val for key, val in output.items() if key.startswith('loss')}
return {
**losses,
'metrics': output['metrics']
}
def validation_step(self, batch, epoch):
"""Processes a validation batch"""
# from vidar.utils.data import break_batch
# batch = break_batch(batch)
if self.mixed_precision:
with torch.cuda.amp.autocast():
output = self.run_arch(batch, epoch=epoch, flip=False, unflip=False)
flipped_output = None if not self.validate_flipped else \
self.run_arch(batch, epoch=epoch, flip=True, unflip=True)
else:
output = self.run_arch(batch, epoch=epoch, flip=False, unflip=False)
flipped_output = None if not self.validate_flipped else \
self.run_arch(batch, epoch=epoch, flip=True, unflip=True)
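        # Use the (possibly modified) batch returned by the model, if provided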
if 'batch' in output:
batch = output['batch']
results = self.evaluate(batch, output, flipped_output)
results = [{
'idx': batch['idx'][i],
**{key: val[i] for key, val in results['metrics'].items()}
} for i in range(len(batch['idx']))]
return output, results
@staticmethod
def training_epoch_end():
"""Finishes a training epoch (do nothing for now)"""
return {}
def validation_epoch_end(self, output, prefixes):
"""Finishes a validation epoch"""
if isinstance(output[0], dict):
output = [output]
metrics_dict = {}
for task in self.metrics:
metrics_dict.update(
self.metrics[task].reduce(
output, self.datasets['validation'], prefixes))
return metrics_dict
def evaluate(self, batch, output, flipped_output=None):
"""
Evaluate batch to produce predictions and metrics for different tasks
Parameters
----------
batch : Dict
Dictionary with batch information
output : Dict
Dictionary with output information
flipped_output : Dict
Dictionary with flipped output information
Returns
-------
results: Dict
Dictionary with evaluation results
"""
# Evaluate different tasks
metrics, predictions = OrderedDict(), OrderedDict()
for task in self.metrics:
task_metrics, task_predictions = \
self.metrics[task].evaluate(batch, output['predictions'],
flipped_output['predictions'] if flipped_output else None)
metrics.update(task_metrics)
predictions.update(task_predictions)
        # Create results dictionary with metrics and predictions
results = {'metrics': metrics, 'predictions': predictions}
# Return final results
return results
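

# Example usage (a minimal sketch; the config path, dataloader, and epoch count below
# are hypothetical and would normally be handled by the training script):
#
#   wrapper = Wrapper('configs/my_experiment.yaml', verbose=True)
#   optimizers, schedulers = wrapper.configure_optimizers_and_schedulers()
#   for epoch in range(num_epochs):
#       for batch in train_dataloader:
#           losses_and_metrics = wrapper.training_step(batch, epoch=epoch)
#           # ... backward pass and optimizer steps go here ...
#       wrapper.save('checkpoints/last.ckpt', epoch=epoch)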