import torch
import torch.nn as nn

from .model_pipelines.__base_model__ import BaseDepthModel


class DepthModel(BaseDepthModel):
    """Depth model that adds a gradient-free inference entry point.

    All construction work is delegated to ``BaseDepthModel``; this subclass
    only contributes :meth:`inference`.
    """

    def __init__(self, cfg, **kwargs):
        """Initialise the underlying ``BaseDepthModel`` from ``cfg``.

        Args:
            cfg: Model configuration, passed unchanged to
                ``BaseDepthModel.__init__``.
            **kwargs: Accepted for call-site compatibility but not forwarded
                to the base class.
        """
        super().__init__(cfg)

    def inference(self, data):
        """Run a forward pass with gradient tracking disabled.

        Args:
            data: Input passed unchanged to ``self.forward``.

        Returns:
            The three values produced by ``self.forward(data)``, in order:
            ``(pred_depth, confidence, output_dict)``.
        """
        # Unpacking enforces that ``forward`` yields exactly three values.
        with torch.no_grad():
            pred_depth, confidence, output_dict = self.forward(data)
        return pred_depth, confidence, output_dict


def get_monodepth_model(
    cfg: dict,
    **kwargs
) -> nn.Module:
    """Build a :class:`DepthModel` from a configuration.

    Args:
        cfg: Model configuration handed to ``DepthModel``.
            NOTE(review): annotated as ``dict``, but the original constructor
            read ``cfg.model.type``, so an attribute-style config object is
            presumably expected -- confirm against callers.
        **kwargs: Extra keyword arguments passed through to ``DepthModel``.

    Returns:
        The constructed depth model.
    """
    model = DepthModel(cfg, **kwargs)
    # Sanity check that the base class really is a torch module.
    assert isinstance(model, nn.Module)
    return model


def get_configured_monodepth_model(
    cfg: dict,
) -> nn.Module:
    """Build the depth model described by ``cfg``.

    Thin wrapper around :func:`get_monodepth_model` that passes no extra
    keyword arguments.

    Args:
        cfg: Configuration for the network.

    Returns:
        The constructed depth model.
    """
    model = get_monodepth_model(cfg)
    return model