Spaces:
Runtime error
Runtime error
| # TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved. | |
| from collections import OrderedDict | |
| from copy import deepcopy | |
| import torch | |
| import torch.nn as nn | |
| import wandb | |
| from vidar.utils.config import cfg_has | |
| from vidar.utils.distributed import world_size | |
| from vidar.utils.logging import pcolor | |
| from vidar.utils.types import is_dict, is_tensor, is_seq, is_namespace | |
| from vidar.utils.viz import viz_depth, viz_inv_depth, viz_normals, viz_optical_flow, viz_camera | |
class WandbLogger:
    """
    Wandb logger class to monitor training

    Parameters
    ----------
    cfg : Config
        Configuration with parameters. ``folder``, ``entity`` and ``project``
        are read directly; ``name``, ``tags``, ``notes``, ``only_first``,
        ``num_train_logs``, ``num_validation_logs`` and ``num_test_logs`` are
        read through ``cfg_has`` with defaults. After the run is created,
        ``cfg.name`` and ``cfg.url`` are overwritten with the run name and URL.
    verbose : Bool
        Print information on screen if enabled
    """
    def __init__(self, cfg, verbose=False):
        super().__init__()
        # Number of samples per split for which images are logged (0 disables)
        self.num_logs = {
            'train': cfg_has(cfg, 'num_train_logs', 0),
            'val': cfg_has(cfg, 'num_validation_logs', 0),
            'test': cfg_has(cfg, 'num_test_logs', 0),
        }

        self._name = cfg.name if cfg_has(cfg, 'name') else None
        self._dir = cfg.folder
        self._entity = cfg.entity
        self._project = cfg.project
        self._tags = cfg_has(cfg, 'tags', '')
        self._notes = cfg_has(cfg, 'notes', '')

        self._id = None
        self._anonymous = None
        self._log_model = True

        self._experiment = self._create_experiment()
        # Metrics/images accumulated until the next 'epochs'/'samples' update
        self._metrics = OrderedDict()

        self.only_first = cfg_has(cfg, 'only_first', False)

        # Write the resolved run information back into the configuration
        cfg.name = self.run_name
        cfg.url = self.run_url

        if verbose:
            self.print()

    @staticmethod
    def finish():
        """Finish wandb session (callable on the class or on an instance)"""
        wandb.finish()

    def print(self):
        """Print run name and URL on screen"""
        font_base = {'color': 'red', 'attrs': ('bold', 'dark')}
        font_name = {'color': 'red', 'attrs': ('bold',)}
        font_underline = {'color': 'red', 'attrs': ('underline',)}

        print(pcolor('#' * 60, **font_base))
        print(pcolor('### WandB: ', **font_base) +
              pcolor('{}'.format(self.run_name), **font_name))
        print(pcolor('### ', **font_base) +
              pcolor('{}'.format(self.run_url), **font_underline))
        print(pcolor('#' * 60, **font_base))

    def __getstate__(self):
        """
        Get the picklable logger state.

        The live experiment handle is dropped and only its id is kept, so the
        ``experiment`` property can lazily re-create (resume) the run after
        unpickling.
        """
        state = self.__dict__.copy()
        state['_id'] = self._experiment.id if self._experiment is not None else None
        state['_experiment'] = None
        return state

    def _create_experiment(self):
        """Creates and returns a new experiment (resumes run ``self._id`` if set)"""
        experiment = wandb.init(
            name=self._name, dir=self._dir, project=self._project,
            anonymous=self._anonymous, reinit=True, id=self._id, notes=self._notes,
            resume='allow', tags=self._tags, entity=self._entity
        )
        wandb.run.save()
        return experiment

    def watch(self, model: nn.Module, log='gradients', log_freq=100):
        """Watch training parameters"""
        self.experiment.watch(model, log=log, log_freq=log_freq)

    @property
    def experiment(self):
        """Returns the experiment (creates a new if it doesn't exist)"""
        if self._experiment is None:
            self._experiment = self._create_experiment()
        return self._experiment

    @property
    def run_name(self):
        """Returns run name (None if no experiment is active)"""
        return wandb.run.name if self._experiment else None

    @property
    def run_url(self):
        """Returns run URL (None if no experiment is active)"""
        if not self._experiment:
            return None
        run = wandb.run
        return f'https://app.wandb.ai/{run.entity}/{run.project}/runs/{run.id}'

    def log_config(self, cfg):
        """Log model configuration (a deep copy is converted, cfg is untouched)"""
        cfg = recursive_convert_config(deepcopy(cfg))
        self.experiment.config.update(cfg, allow_val_change=True)

    def log_metrics(self, metrics):
        """
        Log training metrics.

        Metrics are buffered and only sent to wandb (then cleared) once an
        update containing an 'epochs' or 'samples' entry arrives.
        """
        self._metrics.update(metrics)
        if 'epochs' in metrics or 'samples' in metrics:
            self.experiment.log(self._metrics)
            self._metrics.clear()

    def log_images(self, batch, output, prefix, ontology=None):
        """
        Log images depending on its nature

        Parameters
        ----------
        batch : Dict
            Dictionary containing batch information (logged with suffix '-gt')
        output : Dict
            Dictionary containing output information; its 'predictions'
            entry is logged with suffix '-pred'
        prefix : String
            Prefix string for the log name
        ontology : Dict
            Dictionary with ontology information (currently unused)
        """
        for data, suffix in zip([batch, output['predictions']], ['-gt', '-pred']):
            name = prefix + suffix
            for key in data:
                # Order matters: the first matching rule wins
                if key.startswith('rgb'):
                    log_fn = log_rgb
                elif key.startswith('depth'):
                    log_fn = log_depth
                elif key.startswith('inv_depth'):
                    log_fn = log_inv_depth
                elif 'normals' in key:
                    log_fn = log_normals
                elif key.startswith('stddev'):
                    log_fn = log_stddev
                elif key.startswith('logvar'):
                    log_fn = log_logvar
                elif 'optical_flow' in key:
                    log_fn = log_optical_flow
                elif 'mask' in key or 'valid' in key:
                    # Masks are shown as images; the gt/pred suffix keeps a
                    # predicted mask from overwriting a ground-truth one
                    log_fn = log_rgb
                else:
                    # Other keys are not logged (camera, uncertainty, semantic,
                    # scene-flow and score handlers are currently disabled)
                    continue
                self._metrics.update(log_fn(
                    key, name, data, only_first=self.only_first))

    def log_data(self, mode, batch, output, dataset, prefix, ontology=None):
        """
        Helper function used to log images

        Images are logged for roughly ``self.num_logs[mode]`` samples spread
        evenly over ``dataset``, selected by the first entry of ``batch['idx']``.
        """
        idx = batch['idx'][0]
        num_logs = self.num_logs[mode]
        if num_logs <= 0:
            return
        # Sample spacing, rounded to a multiple of the world size
        # (presumably so the selection is consistent across ranks — TODO confirm)
        interval = (len(dataset) // world_size() // num_logs) * world_size()
        # interval == 0 (small dataset) means every sample is logged
        if interval == 0 or (idx % interval == 0 and idx < interval * num_logs):
            prefix = '{}-{}-{}'.format(mode, prefix, idx.item())
            self.log_images(batch, output, prefix, ontology=ontology)
def recursive_convert_config(cfg):
    """
    Convert a configuration object into a dictionary, recursively.

    The object's own ``__dict__`` is returned, and it is modified in place:
    every value for which ``is_namespace`` is true is replaced with its own
    converted dictionary.
    """
    converted = cfg.__dict__
    for name, value in converted.items():
        if not is_namespace(value):
            continue
        # Replacing an existing key does not change the dict size, so this is safe mid-iteration
        converted[name] = recursive_convert_config(value)
    return converted
def prep_image(key, prefix, image):
    """
    Wrap an image in a single-entry dictionary ready for wandb logging.

    If ``image`` is a tensor, it is first reduced to three dimensions:
    - a 2-D tensor gains a leading axis;
    - a 4-D tensor keeps only its first element.

    The tensor is then detached, its leading axis is moved last
    (presumably channel-first to channel-last), and it is converted to a CPU
    numpy array. Other inputs are handed to ``wandb.Image`` unchanged.

    Returns
    -------
    Dict
        ``{'<prefix>-<key>': wandb.Image(image, caption=key)}``
    """
    if is_tensor(image):
        ndim = image.dim()
        if ndim == 2:
            image = image.unsqueeze(0)
        elif ndim == 4:
            image = image[0]
        image = image.detach().permute(1, 2, 0).cpu().numpy()
    return {f'{prefix}-{key}': wandb.Image(image, caption=key)}
def log_sequence(key, prefix, data, i, only_first, fn):
    """
    Log a container of images (dict, list or tuple) by calling ``fn`` per leaf.

    Parameters
    ----------
    key : String
        Base name of the logged entry
    prefix : String
        Prefix forwarded to ``fn``
    data : Dict, Sequence or leaf
        What to log. Dict values may themselves be sequences.
    i : Int
        Index forwarded to ``fn``
    only_first : Bool
        If truthy, only the first element of each sequence is logged
    fn : Callable
        Per-leaf logger, called as ``fn(name, prefix, leaf, i)``

    Inside dicts, leaves whose ``dim()`` is 5 are split along dimension 1 and
    one entry is logged per slice.

    Entry names by case:
    - dict leaf: ``key(ctx)``, or ``key(ctx_j)`` for a split slice;
    - element of a sequence stored in a dict: the same, followed by ``_idx``;
    - element of a top-level sequence: ``key_idx``;
    - anything else: ``key``.

    Returns
    -------
    Dict
        Union of all dictionaries returned by ``fn``
    """
    out = {}

    def _emit(tag, tensor, tail=''):
        # Split 5-D tensors along dim 1, otherwise log the tensor as-is
        if tensor.dim() == 5:
            for j in range(tensor.shape[1]):
                out.update(fn(f'{key!s}({tag}_{j}){tail}', prefix, tensor[:, j], i))
        else:
            out.update(fn(f'{key!s}({tag}){tail}', prefix, tensor, i))

    if is_dict(data):
        for ctx, value in data.items():
            tag = str(ctx)
            if is_seq(value):
                items = value[:1] if only_first else value
                for idx, item in enumerate(items):
                    _emit(tag, item, f'_{idx}')
            else:
                _emit(tag, value)
    elif is_seq(data):
        items = data[:1] if only_first else data
        for idx, item in enumerate(items):
            out.update(fn(f'{key!s}_{idx}', prefix, item, i))
    else:
        out.update(fn(f'{key!s}', prefix, data, i))
    return out
def log_rgb(key, prefix, batch, i=0, only_first=None):
    """
    Log an RGB image, or a nested dict/sequence of them.

    Uses ``batch[key]`` when ``batch`` is a dict, otherwise ``batch`` itself.
    Containers are unrolled through ``log_sequence``. A single image is
    indexed at ``i``, clamped to [0, 1] and passed to ``prep_image``.
    """
    if is_dict(batch):
        rgb = batch[key]
    else:
        rgb = batch
    if not (is_seq(rgb) or is_dict(rgb)):
        clamped = rgb[i].clamp(min=0.0, max=1.0)
        return prep_image(key, prefix, clamped)
    return log_sequence(key, prefix, rgb, i, only_first, log_rgb)
def log_depth(key, prefix, batch, i=0, only_first=None):
    """
    Log a depth map, or a nested dict/sequence of them.

    Uses ``batch[key]`` when ``batch`` is a dict, otherwise ``batch`` itself.
    Containers are unrolled through ``log_sequence``. A single map is indexed
    at ``i`` and visualized with ``viz_depth(..., filter_zeros=True)`` before
    being passed to ``prep_image``.
    """
    if is_dict(batch):
        depth = batch[key]
    else:
        depth = batch
    if not (is_seq(depth) or is_dict(depth)):
        colored = viz_depth(depth[i], filter_zeros=True)
        return prep_image(key, prefix, colored)
    return log_sequence(key, prefix, depth, i, only_first, log_depth)
def log_inv_depth(key, prefix, batch, i=0, only_first=None):
    """
    Log an inverse depth map, or a nested dict/sequence of them.

    Uses ``batch[key]`` when ``batch`` is a dict, otherwise ``batch`` itself.
    Containers are unrolled through ``log_sequence``. A single map is indexed
    at ``i`` and visualized with ``viz_inv_depth`` before being passed to
    ``prep_image``.
    """
    if is_dict(batch):
        inv_depth = batch[key]
    else:
        inv_depth = batch
    if not (is_seq(inv_depth) or is_dict(inv_depth)):
        colored = viz_inv_depth(inv_depth[i])
        return prep_image(key, prefix, colored)
    return log_sequence(key, prefix, inv_depth, i, only_first, log_inv_depth)
def log_normals(key, prefix, batch, i=0, only_first=None):
    """
    Log a surface-normals map, or a nested dict/sequence of them.

    Uses ``batch[key]`` when ``batch`` is a dict, otherwise ``batch`` itself.
    Containers are unrolled through ``log_sequence``. A single map is indexed
    at ``i`` and visualized with ``viz_normals`` before being passed to
    ``prep_image``.
    """
    if is_dict(batch):
        normals = batch[key]
    else:
        normals = batch
    if not (is_seq(normals) or is_dict(normals)):
        colored = viz_normals(normals[i])
        return prep_image(key, prefix, colored)
    return log_sequence(key, prefix, normals, i, only_first, log_normals)
def log_optical_flow(key, prefix, batch, i=0, only_first=None):
    """
    Log an optical-flow field, or a nested dict/sequence of them.

    Uses ``batch[key]`` when ``batch`` is a dict, otherwise ``batch`` itself.
    Containers are unrolled through ``log_sequence``. A single field is
    indexed at ``i`` and visualized with ``viz_optical_flow`` before being
    passed to ``prep_image``.
    """
    if is_dict(batch):
        optical_flow = batch[key]
    else:
        optical_flow = batch
    if not (is_seq(optical_flow) or is_dict(optical_flow)):
        colored = viz_optical_flow(optical_flow[i])
        return prep_image(key, prefix, colored)
    return log_sequence(key, prefix, optical_flow, i, only_first, log_optical_flow)
def log_stddev(key, prefix, batch, i=0, only_first=None):
    """
    Log a standard-deviation map, or a nested dict/sequence of them.

    Uses ``batch[key]`` when ``batch`` is a dict, otherwise ``batch`` itself.
    Containers are unrolled through ``log_sequence``. A single map is indexed
    at ``i`` and colored with ``viz_inv_depth(..., colormap='jet')`` before
    being passed to ``prep_image``.
    """
    if is_dict(batch):
        stddev = batch[key]
    else:
        stddev = batch
    if not (is_seq(stddev) or is_dict(stddev)):
        colored = viz_inv_depth(stddev[i], colormap='jet')
        return prep_image(key, prefix, colored)
    return log_sequence(key, prefix, stddev, i, only_first, log_stddev)
def log_logvar(key, prefix, batch, i=0, only_first=None):
    """
    Log a log-variance map as a variance image, or a nested dict/sequence of them.

    Uses ``batch[key]`` when ``batch`` is a dict, otherwise ``batch`` itself.
    Containers are unrolled through ``log_sequence``. A single map is indexed
    at ``i`` and exponentiated with ``torch.exp``, so the image shows the
    variance rather than the log-variance. It is then colored with
    ``viz_inv_depth(..., colormap='jet')`` before being passed to
    ``prep_image``.
    """
    if is_dict(batch):
        logvar = batch[key]
    else:
        logvar = batch
    if not (is_seq(logvar) or is_dict(logvar)):
        variance = torch.exp(logvar[i])
        return prep_image(key, prefix, viz_inv_depth(variance, colormap='jet'))
    return log_sequence(key, prefix, logvar, i, only_first, log_logvar)
def log_camera(key, prefix, batch, i=0, only_first=None):
    """
    Log a camera visualization, or a nested dict/sequence of them.

    Uses ``batch[key]`` when ``batch`` is a dict, otherwise ``batch`` itself.
    Containers are unrolled through ``log_sequence``. A single entry is
    indexed at ``i`` and rendered with ``viz_camera`` before being passed to
    ``prep_image``.
    """
    if is_dict(batch):
        camera = batch[key]
    else:
        camera = batch
    if not (is_seq(camera) or is_dict(camera)):
        rendered = viz_camera(camera[i])
        return prep_image(key, prefix, rendered)
    return log_sequence(key, prefix, camera, i, only_first, log_camera)