Spaces:
Runtime error
Runtime error
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import os | |
import random | |
from abc import ABC | |
from collections import OrderedDict | |
import torch | |
from vidar.utils.config import cfg_has, read_config | |
from vidar.utils.data import set_random_seed | |
from vidar.utils.distributed import rank, world_size | |
from vidar.utils.flip import flip_batch, flip_output | |
from vidar.utils.logging import pcolor, set_debug | |
from vidar.utils.networks import load_checkpoint, save_checkpoint, freeze_layers_and_norms | |
from vidar.utils.setup import setup_arch, setup_datasets, setup_metrics | |
from vidar.utils.types import is_str | |
class Wrapper(torch.nn.Module, ABC):
    """
    Trainer class for model optimization and inference

    Parameters
    ----------
    cfg : Config or String
        Configuration with parameters (a string is loaded with read_config)
    ckpt : String
        Name of the model checkpoint to start from
    verbose : Bool
        Print information on screen if enabled
    """
    def __init__(self, cfg, ckpt=None, verbose=False):
        super().__init__()

        # Banner is printed only by the rank-0 process
        if verbose and rank() == 0:
            font = {'color': 'cyan', 'attrs': ('bold', 'dark')}
            print(pcolor('#' * 100, **font))
            print(pcolor('#' * 42 + ' VIDAR WRAPPER ' + '#' * 43, **font))
            print(pcolor('#' * 100, **font))

        # Get configuration
        cfg = read_config(cfg) if is_str(cfg) else cfg
        self.cfg = cfg

        # Data augmentations
        self.flip_lr_prob = cfg_has(cfg.wrapper, 'flip_lr_prob', 0.0)
        self.validate_flipped = cfg_has(cfg.wrapper, 'validate_flipped', False)

        # Set random seed (offset by rank, so each process gets a different seed)
        set_random_seed(cfg.wrapper.seed + rank())
        set_debug(cfg_has(cfg.wrapper, 'debug', False))

        # Setup architecture, datasets and tasks (each one is optional)
        self.arch = setup_arch(cfg.arch, checkpoint=ckpt, verbose=verbose) \
            if cfg_has(cfg, 'arch') else None
        self.datasets, self.datasets_cfg = setup_datasets(cfg.datasets, verbose=verbose) \
            if cfg_has(cfg, 'datasets') else (None, None)
        self.metrics = setup_metrics(cfg.evaluation) \
            if cfg_has(cfg, 'evaluation') else {}

        # Convert batch norm layers only in DDP mode. A missing DIST_MODE variable
        # or a missing architecture simply skips the conversion instead of raising.
        sync_batch_norm = cfg_has(cfg.wrapper, 'sync_batch_norm', False)
        if sync_batch_norm and self.arch is not None and os.environ.get('DIST_MODE') == 'ddp':
            self.arch = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.arch)

        self.mixed_precision = cfg_has(cfg.wrapper, 'mixed_precision', False)

        # Set to 'step' or 'epoch' by configure_optimizers_and_schedulers
        self.update_schedulers = None

    def _unwrap_arch(self):
        """Return self.arch.module if the architecture has a 'module' attribute, else self.arch"""
        return self.arch.module if hasattr(self.arch, 'module') else self.arch

    def _run_arch_maybe_autocast(self, batch, epoch, flip, unflip):
        """Run run_arch, inside a CUDA autocast context when mixed precision is enabled"""
        if self.mixed_precision:
            with torch.cuda.amp.autocast():
                return self.run_arch(batch, epoch=epoch, flip=flip, unflip=unflip)
        return self.run_arch(batch, epoch=epoch, flip=flip, unflip=unflip)

    def save(self, filename, epoch=None):
        """Save checkpoint"""
        save_checkpoint(filename, self, epoch=epoch)

    def load(self, checkpoint, strict=True, verbose=False):
        """Load checkpoint"""
        load_checkpoint(self, checkpoint, strict=strict, verbose=verbose)

    def train_custom(self, in_optimizers, out_optimizers):
        """
        Customized training flag for the model

        Parameters
        ----------
        in_optimizers : Dict
            Keys name the networks passed to freeze_layers_and_norms with flag_freeze=False
        out_optimizers : Dict
            Keys name the networks passed to freeze_layers_and_norms with flag_freeze=True
        """
        self.arch.train()
        arch = self._unwrap_arch()
        for key in in_optimizers:
            freeze_layers_and_norms(arch.networks[key], ['ALL'], flag_freeze=False)
        for key in out_optimizers:
            freeze_layers_and_norms(arch.networks[key], ['ALL'], flag_freeze=True)

    def eval_custom(self):
        """Customized evaluation flag for the model"""
        self.arch.eval()

    def configure_optimizers_and_schedulers(self):
        """
        Configure one optimizer (and optionally one scheduler) per entry in cfg.optimizers

        Returns
        -------
        optimizers : OrderedDict or None
            Per-network dictionary with 'optimizer' and 'settings' entries
            (None if the configuration has no optimizers)
        schedulers : OrderedDict or None
            Per-network scheduler, only for optimizers that define a scheduler
            (None if the configuration has no optimizers)
        """
        if not cfg_has(self.cfg, 'optimizers'):
            return None, None

        # Same unwrapping as train_custom, so lookups also work on a wrapped architecture
        arch = self._unwrap_arch()

        optimizers = OrderedDict()
        schedulers = OrderedDict()

        for key, val in self.cfg.optimizers.__dict__.items():
            assert key in arch.networks, f'There is no network for optimizer {key}'
            optimizer = getattr(torch.optim, val.name)(**{
                'lr': val.lr,
                'weight_decay': cfg_has(val, 'weight_decay', 0.0),
                'params': arch.networks[key].parameters(),
            })
            optimizers[key] = {
                'optimizer': optimizer,
                'settings': val.settings.__dict__ if cfg_has(val, 'settings') else {},
            }

            if not cfg_has(val, 'scheduler'):
                continue

            # NOTE: update_schedulers is shared by all schedulers; the last one configured wins
            scheduler_cfg = val.scheduler
            if scheduler_cfg.name == 'CosineAnnealingWarmUpRestarts':
                from cosine_annealing_warmup import CosineAnnealingWarmupRestarts
                train_cfg = self.datasets_cfg['train']
                # Scale factor applied to the cycle and warmup lengths from the configuration
                epoch = float(len(self.datasets['train']) / (
                    world_size() * train_cfg.dataloader.batch_size * train_cfg.repeat[0]))
                schedulers[key] = CosineAnnealingWarmupRestarts(**{
                    'optimizer': optimizer,
                    'first_cycle_steps': int(scheduler_cfg.first_cycle_steps * epoch),
                    'cycle_mult': scheduler_cfg.cycle_mult,
                    'min_lr': scheduler_cfg.min_lr,
                    'max_lr': scheduler_cfg.max_lr,
                    'warmup_steps': int(scheduler_cfg.warmup_steps * epoch),
                    'gamma': scheduler_cfg.gamma,
                })
                self.update_schedulers = 'step'
            elif scheduler_cfg.name == 'LinearWarmUp':
                from externals.huggingface.transformers.src.transformers.optimization import get_linear_schedule_with_warmup
                schedulers[key] = get_linear_schedule_with_warmup(**{
                    'optimizer': optimizer,
                    'num_warmup_steps': scheduler_cfg.num_warmup_steps,
                    'num_training_steps': scheduler_cfg.num_training_steps,
                })
                self.update_schedulers = 'step'
            else:
                # Any other name is looked up in torch.optim.lr_scheduler
                # and constructed with step_size and gamma
                schedulers[key] = getattr(torch.optim.lr_scheduler, scheduler_cfg.name)(**{
                    'optimizer': optimizer,
                    'step_size': scheduler_cfg.step_size,
                    'gamma': scheduler_cfg.gamma,
                })
                self.update_schedulers = 'epoch'

        # Return optimizer and scheduler
        return optimizers, schedulers

    def run_arch(self, batch, epoch, flip, unflip):
        """
        Run model on a batch

        Parameters
        ----------
        batch : Dict
            Dictionary with batch information
        epoch : Int
            Current epoch
        flip : Bool
            Batch should be flipped
        unflip : Bool
            Output should be unflipped (only applies when flip is also set)

        Returns
        -------
        output : Dict
            Dictionary with model outputs
        """
        batch = flip_batch(batch) if flip else batch
        output = self.arch(batch, epoch=epoch)
        return flip_output(output) if flip and unflip else output

    def training_step(self, batch, epoch):
        """
        Processes a training batch

        Parameters
        ----------
        batch : Dict
            Dictionary with batch information
        epoch : Int
            Current epoch

        Returns
        -------
        output : Dict
            Every model output whose key starts with 'loss', plus the 'metrics' entry
        """
        # No random number is drawn when flipping is disabled
        flip_lr = False if self.flip_lr_prob == 0 else \
            random.random() < self.flip_lr_prob
        output = self._run_arch_maybe_autocast(batch, epoch=epoch, flip=flip_lr, unflip=False)
        losses = {key: val for key, val in output.items() if key.startswith('loss')}
        return {
            **losses,
            'metrics': output['metrics'],
        }

    def validation_step(self, batch, epoch):
        """
        Processes a validation batch

        Parameters
        ----------
        batch : Dict
            Dictionary with batch information (must contain 'idx')
        epoch : Int
            Current epoch

        Returns
        -------
        output : Dict
            Dictionary with model outputs for the non-flipped batch
        results : List
            One dictionary per batch['idx'] entry, with that 'idx' and its metric values
        """
        output = self._run_arch_maybe_autocast(batch, epoch=epoch, flip=False, unflip=False)
        flipped_output = None if not self.validate_flipped else \
            self._run_arch_maybe_autocast(batch, epoch=epoch, flip=True, unflip=True)

        # The model may return its own version of the batch, which is then used for evaluation
        if 'batch' in output:
            batch = output['batch']

        results = self.evaluate(batch, output, flipped_output)

        # Split batched metrics into one entry per sample
        results = [{
            'idx': idx,
            **{key: val[i] for key, val in results['metrics'].items()}
        } for i, idx in enumerate(batch['idx'])]

        return output, results

    @staticmethod
    def training_epoch_end():
        """Finishes a training epoch (do nothing for now)"""
        return {}

    def validation_epoch_end(self, output, prefixes):
        """
        Finishes a validation epoch

        Parameters
        ----------
        output : List
            Validation results, either a list of dictionaries or a list of such lists
        prefixes : List
            Prefixes forwarded to each task's reduce method

        Returns
        -------
        metrics_dict : Dict
            Merged reduced metrics from all tasks
        """
        # A flat list of dictionaries is wrapped into a single-element list
        if isinstance(output[0], dict):
            output = [output]

        metrics_dict = {}
        for task in self.metrics:
            metrics_dict.update(
                self.metrics[task].reduce(
                    output, self.datasets['validation'], prefixes))
        return metrics_dict

    def evaluate(self, batch, output, flipped_output=None):
        """
        Evaluate batch to produce predictions and metrics for different tasks

        Parameters
        ----------
        batch : Dict
            Dictionary with batch information
        output : Dict
            Dictionary with output information
        flipped_output : Dict
            Dictionary with flipped output information

        Returns
        -------
        results: Dict
            Dictionary with evaluation results
        """
        # Evaluate different tasks
        metrics, predictions = OrderedDict(), OrderedDict()
        flipped_predictions = flipped_output['predictions'] if flipped_output else None
        for task in self.metrics:
            task_metrics, task_predictions = self.metrics[task].evaluate(
                batch, output['predictions'], flipped_predictions)
            metrics.update(task_metrics)
            predictions.update(task_predictions)
        # Create results dictionary with metrics and predictions
        results = {'metrics': metrics, 'predictions': predictions}
        # Return final results
        return results