import torch
import torch.nn as nn
from mono.utils.comm import get_func


class BaseDepthModel(nn.Module):
    """Thin wrapper around a depth pipeline chosen by configuration.

    The concrete pipeline is looked up by name (``cfg.model.type``) inside
    ``mono.model.model_pipelines`` via ``get_func``, then constructed with
    ``cfg`` and stored as ``self.depth_model``.
    """

    def __init__(self, cfg, **kwargs) -> None:
        # ``kwargs`` is accepted for call-site compatibility but not used here.
        super().__init__()
        pipeline_name = cfg.model.type
        pipeline_path = f'mono.model.model_pipelines.{pipeline_name}'
        pipeline_builder = get_func(pipeline_path)
        # Attribute name kept as ``depth_model`` so checkpoint keys still match.
        self.depth_model = pipeline_builder(cfg)

    def forward(self, data):
        """Run the wrapped pipeline on ``data``.

        Args:
            data: mapping unpacked as keyword arguments into the pipeline
                call (its exact keys depend on the pipeline -- not visible here).

        Returns:
            A 3-tuple ``(prediction, confidence, outputs)``, where the first two
            are ``outputs['prediction']`` and ``outputs['confidence']`` and
            ``outputs`` is the pipeline's full return value. A missing key
            raises ``KeyError``.
        """
        outputs = self.depth_model(**data)
        prediction = outputs['prediction']
        confidence = outputs['confidence']
        return prediction, confidence, outputs

    @torch.no_grad()
    def inference(self, data):
        """Run ``forward`` with gradient tracking disabled.

        Only the prediction and confidence are returned; the full output
        mapping is dropped. This method does not call ``self.eval()``;
        switching the module to eval mode, if needed, is left to the caller.
        """
        prediction, confidence, _unused_outputs = self.forward(data)
        return prediction, confidence