# File: geowizard/utils/dataset_configuration.py
# Uploaded by lemonaddie ("Upload 11 files"), commit 2e23827
# A reimplemented version in public environments by Xiao Fu and Mu Hu
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
sys.path.append("..")
from dataloader.mix_loader import MixDataset
from torch.utils.data import DataLoader
from dataloader import transforms
import os
# Get Dataset Here
def prepare_dataset(data_dir=None,
                    batch_size=1,
                    test_batch=1,
                    datathread=4,
                    logger=None):
    """Create the shuffled training DataLoader over ``MixDataset``.

    Args:
        data_dir: Directory passed straight to ``MixDataset(data_dir=...)``.
        batch_size: Samples per training batch.
        test_batch: Accepted for interface compatibility; unused here.
        datathread: Default number of DataLoader worker processes. If the
            ``datathread`` environment variable is set, its integer value
            is used instead.
        logger: Optional logger; when provided, the chosen worker count is
            reported via ``logger.info``.

    Returns:
        ``(train_loader, dataset_config_dict)``, where the dict contains
        ``'num_batches_per_epoch'`` (``len(train_loader)``) and
        ``'img_size'`` (the ``(height, width)`` reported by the dataset).
    """
    train_dataset = MixDataset(data_dir=data_dir)
    img_height, img_width = train_dataset.get_img_size()

    # The environment variable, when present, takes precedence over the argument.
    env_workers = os.environ.get('datathread')
    num_workers = datathread if env_workers is None else int(env_workers)
    if logger is not None:
        logger.info("Use %d processes to load data..." % num_workers)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    dataset_config_dict = {
        'num_batches_per_epoch': len(train_loader),
        'img_size': (img_height, img_width),
    }
    return train_loader, dataset_config_dict
def depth_scale_shift_normalization(depth, q_low=2, q_high=98, eps=1e-5):
    """Map each sample's depth to [-1, 1] using robust per-sample percentiles.

    For every batch element, the ``q_low`` and ``q_high`` percentiles of
    channel index 0 serve as that sample's lower/upper bound. The whole tensor
    (all channels) is shifted and scaled so those bounds land on -1 and 1, then
    clipped to [-1, 1].

    Args:
        depth: 4-D tensor; dim 0 is the batch and index 0 of dim 1 is used to
            estimate the percentiles (presumably ``(B, C, H, W)``).
        q_low: Percentile (0-100) mapped to -1. Defaults to 2.
        q_high: Percentile (0-100) mapped to 1. Defaults to 98.
        eps: Added to the percentile range so constant inputs do not divide
            by zero. Defaults to 1e-5.

    Returns:
        Tensor with the same shape, dtype and device as ``depth``, with values
        in [-1, 1].
    """
    bsz = depth.shape[0]
    # Percentiles are computed with NumPy on CPU. ``detach`` is required because
    # ``.numpy()`` refuses tensors that track gradients, and float64 covers
    # dtypes NumPy cannot represent (e.g. bfloat16).
    flat = depth[:, 0, :, :].reshape(bsz, -1).detach().to('cpu', torch.float64).numpy()
    # One pass for both percentiles; result has shape (2, bsz).
    percentiles = np.percentile(a=flat, q=[q_low, q_high], axis=1)
    bounds = torch.from_numpy(percentiles).to(depth)
    # Broadcast the per-sample bounds over the remaining three dimensions.
    min_value = bounds[0][:, None, None, None]
    max_value = bounds[1][:, None, None, None]
    normalized_depth = ((depth - min_value) / (max_value - min_value + eps) - 0.5) * 2
    return torch.clip(normalized_depth, -1., 1.)
def resize_max_res_tensor(input_tensor, mode, recom_resolution=768):
    """Rescale a 3-channel 4-D tensor so its longer spatial side is ~``recom_resolution``.

    The scale factor is ``min(recom_resolution / H, recom_resolution / W)``.
    It can be above 1, in which case the tensor is upsampled.

    Args:
        input_tensor: 4-D tensor whose dim 1 must have size 3; dims 2 and 3
            are treated as height and width.
        mode: ``'normal'`` selects nearest-neighbour resampling. Any other
            value selects bilinear resampling (``align_corners=False``).
            For ``'depth'``, the resized values are additionally divided by
            the scale factor. NOTE(review): presumably this keeps depth
            consistent with the new pixel scale; confirm with callers.
        recom_resolution: Target length of the longer side. Defaults to 768.

    Returns:
        The resized tensor.
    """
    assert input_tensor.shape[1] == 3
    height, width = input_tensor.shape[2:]
    scale = min(recom_resolution / height, recom_resolution / width)

    if mode == 'normal':
        # Nearest-neighbour copies input pixels without blending neighbours.
        interp_options = {'mode': 'nearest'}
    else:
        interp_options = {'mode': 'bilinear', 'align_corners': False}
    resized = F.interpolate(input_tensor, scale_factor=scale, **interp_options)

    return resized / scale if mode == 'depth' else resized