define-hf-demo / vidar /arch /models /depth /SelfSupervisedModel.py
Jiading Fang
add define
fc16538
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
from abc import ABC
from vidar.arch.blocks.image.ViewSynthesis import ViewSynthesis
from vidar.arch.models.BaseModel import BaseModel
from vidar.arch.models.utils import make_rgb_scales, create_cameras
from vidar.utils.data import get_from_dict
from vidar.utils.config import cfg_has
class SelfSupervisedModel(BaseModel, ABC):
    """
    Self-supervised depth estimation model

    Parameters
    ----------
    cfg : Config
        Configuration with parameters
    """

    # Names of the learned intrinsics for each supported camera model, in the
    # order they appear in the predicted intrinsics vector. Used only to log
    # the predicted intrinsics as training metrics.
    _INTRINSICS_PARAM_NAMES = {
        'UCM': ('fx', 'fy', 'cx', 'cy', 'alpha'),
        'EUCM': ('fx', 'fy', 'cx', 'cy', 'alpha', 'beta'),
        'DS': ('fx', 'fy', 'cx', 'cy', 'xi', 'alpha'),
    }

    def __init__(self, cfg):
        super().__init__(cfg)
        self.view_synthesis = ViewSynthesis()

        # NOTE(review): set_attr comes from BaseModel; presumably it copies
        # cfg.model.<name> onto self, falling back to the given default.
        self.set_attr(cfg.model, 'use_gt_pose', False)
        self.set_attr(cfg.model, 'use_gt_intrinsics', True)

        if not self.use_gt_intrinsics:
            # Intrinsics are predicted by a network, so we need to know which
            # camera model they parameterize ('UCM' if not configured).
            self.camera_model = cfg_has(cfg.networks.intrinsics, 'camera_model', 'UCM')
            self.camera_class = self._get_camera_class(self.camera_model)

    @staticmethod
    def _get_camera_class(camera_model):
        """
        Return the camera class for a learned-intrinsics camera model.

        Imports are done lazily so that only the selected camera module is loaded.

        Raises
        ------
        NotImplementedError
            If the camera model is not one of 'UCM', 'EUCM' or 'DS'.
        """
        if camera_model == 'UCM':
            from vidar.geometry.camera_ucm import UCMCamera
            return UCMCamera
        if camera_model == 'EUCM':
            from vidar.geometry.camera_eucm import EUCMCamera
            return EUCMCamera
        if camera_model == 'DS':
            from vidar.geometry.camera_ds import DSCamera
            return DSCamera
        raise NotImplementedError('Invalid camera type')

    def forward(self, batch, epoch=0):
        """
        Model forward pass

        Parameters
        ----------
        batch : Dict
            Input batch. Must contain 'rgb', a dictionary of images keyed by
            frame with the target frame at key 0. May contain 'intrinsics',
            'mask', 'depth' and (required when use_gt_pose is set) 'pose'.
        epoch : int
            Current epoch (not used by this model)

        Returns
        -------
        output : Dict
            Always contains 'predictions' ('depth', optionally 'logvar' and
            'pose'). In training mode it also contains 'loss' and 'metrics'.
        """
        rgb = batch['rgb']

        # Intrinsics come either from the batch (ground truth) or from the
        # intrinsics network applied to the target image.
        if self.use_gt_intrinsics:
            intrinsics = get_from_dict(batch, 'intrinsics')
        else:
            intrinsics = self.networks['intrinsics'](rgb=rgb[0])
        valid_mask = get_from_dict(batch, 'mask')

        # The depth network only receives intrinsics when they are ground truth.
        if self.use_gt_intrinsics:
            depth_output = self.networks['depth'](rgb=rgb[0], intrinsics=intrinsics[0])
        else:
            depth_output = self.networks['depth'](rgb=rgb[0])

        pred_depth = depth_output['depths']
        predictions = {'depth': {0: pred_depth}}

        pred_logvar = get_from_dict(depth_output, 'logvar')
        if pred_logvar is not None:
            predictions['logvar'] = {0: pred_logvar}

        # Inference only needs predictions; cameras and losses are training-only.
        if not self.training:
            return {'predictions': predictions}

        if self.use_gt_pose:
            assert 'pose' in batch, 'You need input pose'
            pose = batch['pose']
        elif 'pose' in self.networks:
            pose = self.compute_pose(rgb, self.networks['pose'], tgt=0, invert=True)
            predictions['pose'] = pose
        else:
            pose = None

        if self.use_gt_intrinsics:
            cams = create_cameras(rgb[0], intrinsics[0], pose)
        else:
            # All frames share the learned intrinsics; each posed frame also
            # gets its transformation. Without any pose source only the target
            # camera can be built (previously this crashed on None.keys()).
            cams = {0: self.camera_class(I=intrinsics)}
            if pose is not None:
                for key, tcw in pose.items():
                    cams[key] = self.camera_class(I=intrinsics, Tcw=tcw)

        gt_depth = batch['depth'][0] if 'depth' in batch else None
        loss, metrics = self.compute_loss_and_metrics(
            rgb, pred_depth, cams, gt_depth=gt_depth,
            logvar=pred_logvar, valid_mask=valid_mask,
        )

        if not self.use_gt_intrinsics:
            metrics.update(self._intrinsics_metrics(intrinsics))

        return {
            'loss': loss,
            'metrics': metrics,
            'predictions': predictions,
        }

    def _intrinsics_metrics(self, intrinsics):
        """
        Name the learned intrinsics so they can be logged as metrics.

        Uses the first entry of ``intrinsics`` (squeezed) and pairs its values
        with the parameter names of the configured camera model.

        Raises
        ------
        NotImplementedError
            If the camera model is unknown.
        ValueError
            If the number of predicted values does not match the camera model.
        """
        names = self._INTRINSICS_PARAM_NAMES.get(self.camera_model)
        if names is None:
            raise NotImplementedError('Invalid camera type')
        # Iterating the squeezed tensor yields the same elements that tuple
        # unpacking would.
        values = tuple(intrinsics[0].squeeze())
        if len(values) != len(names):
            raise ValueError(
                f'Expected {len(names)} {self.camera_model} intrinsics, got {len(values)}')
        return dict(zip(names, values))

    def compute_loss_and_metrics(self, rgb, depths, cams, gt_depth=None,
                                 logvar=None, valid_mask=None):
        """
        Compute loss and metrics for training

        Parameters
        ----------
        rgb : Dict
            Dictionary with input images [B,3,H,W], target frame at key 0
        depths : list[torch.Tensor]
            List with target depth maps in different scales [B,1,H,W]
        cams : Dict
            Dictionary with cameras for each input image
        gt_depth : torch.Tensor
            Ground-truth depth map for supervised training
        logvar : list[torch.Tensor]
            Log-variance maps for uncertainty training
        valid_mask : list[torch.Tensor]
            Binary mask for masking out invalid pixels [B,1,H,W]

        Returns
        -------
        loss : torch.Tensor
            Training loss (sum of all active loss terms; 0 if none are active)
        metrics : Dict
            Dictionary with training metrics
        """
        tgt = 0
        ctx = [key for key in rgb.keys() if key != tgt]
        num_scales = self.get_num_scales(depths)

        # Build image pyramids matching the predicted depth scales.
        rgbs = make_rgb_scales(rgb, pyramid=depths)
        rgb_tgt = [rgbs[tgt][i] for i in range(num_scales)]
        rgb_ctx = [[rgbs[j][i] for j in ctx] for i in range(num_scales)]

        loss, metrics = [], {}

        if 'reprojection' in self.losses:
            # Warp context images into the target view using predicted depth.
            synth = self.view_synthesis(
                rgbs, depths=depths, cams=cams, return_masks=True)
            reprojection_output = self.losses['reprojection'](
                rgb_tgt, rgb_ctx, synth['warps'], logvar=logvar,
                valid_mask=valid_mask, overlap_mask=synth['masks'])
            loss.append(reprojection_output['loss'])
            metrics.update(reprojection_output['metrics'])

        if 'smoothness' in self.losses:
            smoothness_output = self.losses['smoothness'](rgb_tgt, depths)
            loss.append(smoothness_output['loss'])
            metrics.update(smoothness_output['metrics'])

        # Supervised terms only apply when ground-truth depth is available.
        if 'supervision' in self.losses and gt_depth is not None:
            supervision_output = self.losses['supervision'](depths, gt_depth)
            loss.append(supervision_output['loss'])
            metrics.update(supervision_output['metrics'])

        if 'normals' in self.losses and gt_depth is not None:
            normals_output = self.losses['normals'](depths, gt_depth, cams[0])
            loss.append(normals_output['loss'])
            metrics.update(normals_output['metrics'])

        loss = sum(loss)
        return loss, metrics