# A reimplemented version in public environments by Xiao Fu and Mu Hu

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
sys.path.append("..")

from dataloader.mix_loader import MixDataset
from torch.utils.data import DataLoader
from dataloader import transforms
import os


# Get Dataset Here
def prepare_dataset(data_dir=None,
                    batch_size=1,
                    test_batch=1,
                    datathread=4,
                    logger=None):
    """Build the training DataLoader and a small dict describing it.

    Args:
        data_dir: Forwarded to ``MixDataset(data_dir=...)``.
        batch_size: Batch size of the training loader.
        test_batch: Accepted for interface compatibility; not used by this
            function.
        datathread: Number of DataLoader worker processes. Overridden by the
            ``datathread`` environment variable when that variable is set.
        logger: Optional logger; when given, the worker count is reported
            through ``logger.info``.

    Returns:
        A ``(train_loader, dataset_config_dict)`` tuple. The dict holds
        ``'num_batches_per_epoch'`` (``len(train_loader)``) and ``'img_size'``
        (the pair returned by ``train_dataset.get_img_size()``).

    Raises:
        ValueError: If the ``datathread`` environment variable is set but is
            not a valid integer.
    """
    # set the config parameters
    dataset_config_dict = dict()

    train_dataset = MixDataset(data_dir=data_dir)
    # NOTE(review): presumably (height, width) -- confirm against MixDataset.
    img_height, img_width = train_dataset.get_img_size()

    # The environment variable, when present, wins over the argument.
    env_datathread = os.environ.get('datathread')
    if env_datathread is not None:
        datathread = int(env_datathread)
    if logger is not None:
        logger.info("Use %d processes to load data..." % datathread)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=datathread,
        pin_memory=True,
    )

    num_batches_per_epoch = len(train_loader)

    dataset_config_dict['num_batches_per_epoch'] = num_batches_per_epoch
    dataset_config_dict['img_size'] = (img_height, img_width)

    return train_loader, dataset_config_dict


def depth_scale_shift_normalization(depth):
    """Normalize a batch of depth maps to [-1, 1] using robust percentiles.

    For every sample, the 2nd and 98th percentiles of channel 0 are used as
    the lower/upper bounds; all channels of that sample are shifted and
    scaled with those bounds and the result is clipped to ``[-1, 1]``.

    Args:
        depth: 4-D tensor indexed as ``depth[batch, channel, ...]``; only
            channel 0 contributes to the percentile statistics.

    Returns:
        A tensor with the same shape as ``depth`` and values in ``[-1, 1]``.
    """
    bsz = depth.shape[0]

    # The percentiles are computed with NumPy on the CPU. ``detach()`` is
    # required because ``Tensor.numpy()`` refuses tensors that require grad;
    # the bounds are treated as constants either way.
    depth_ = depth[:, 0, :, :].reshape(bsz, -1).detach().cpu().numpy()

    # ``.to(depth)`` restores the dtype/device of the input; the trailing
    # ``None`` axes make the per-sample bounds broadcast over (C, H, W).
    min_value = torch.from_numpy(
        np.percentile(a=depth_, q=2, axis=1)
    ).to(depth)[..., None, None, None]
    max_value = torch.from_numpy(
        np.percentile(a=depth_, q=98, axis=1)
    ).to(depth)[..., None, None, None]

    # Map [min, max] -> [0, 1] -> [-1, 1]; the epsilon avoids division by
    # zero when both percentiles coincide.
    normalized_depth = (
        (depth - min_value) / (max_value - min_value + 1e-5) - 0.5
    ) * 2
    normalized_depth = torch.clip(normalized_depth, -1., 1.)

    return normalized_depth


def resize_max_res_tensor(input_tensor, mode, recom_resolution=768):
    """Rescale a 3-channel image batch by a single isotropic factor.

    The factor is ``min(recom_resolution / H, recom_resolution / W)``, i.e.
    the longer spatial side is scaled towards ``recom_resolution`` (this may
    enlarge inputs that are smaller than ``recom_resolution``).

    Args:
        input_tensor: Tensor of shape ``(N, 3, H, W)``.
        mode: ``'normal'`` selects nearest-neighbour interpolation; any other
            value selects bilinear interpolation. For ``'depth'`` the resized
            values are additionally divided by the scale factor.
        recom_resolution: Target size used to derive the scale factor.

    Returns:
        The resized tensor.

    Raises:
        AssertionError: If ``input_tensor`` does not have 3 channels.
    """
    assert input_tensor.shape[1] == 3
    original_H, original_W = input_tensor.shape[2:]
    downscale_factor = min(recom_resolution / original_H,
                           recom_resolution / original_W)

    if mode == 'normal':
        resized_input_tensor = F.interpolate(input_tensor,
                                             scale_factor=downscale_factor,
                                             mode='nearest')
    else:
        resized_input_tensor = F.interpolate(input_tensor,
                                             scale_factor=downscale_factor,
                                             mode='bilinear',
                                             align_corners=False)

    if mode == 'depth':
        # NOTE(review): the values themselves are rescaled by the inverse of
        # the spatial factor -- confirm this is the intended convention.
        return resized_input_tensor / downscale_factor
    else:
        return resized_input_tensor