|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import sys |
|
sys.path.append("..") |
|
|
|
from dataloader.mix_loader import MixDataset |
|
from torch.utils.data import DataLoader |
|
from dataloader import transforms |
|
import os |
|
|
|
|
|
|
|
def prepare_dataset(data_dir=None,
                    batch_size=1,
                    test_batch=1,
                    datathread=4,
                    logger=None):
    """Build the shuffled training DataLoader over ``MixDataset``.

    Args:
        data_dir: Passed through as ``MixDataset(data_dir=data_dir)``.
        batch_size: Batch size of the training loader.
        test_batch: Not used by this function; kept for interface
            compatibility with existing callers.
        datathread: Default number of DataLoader worker processes. If the
            ``datathread`` environment variable is set, its integer value
            takes precedence.
        logger: Optional logger; when given, the worker count is logged via
            ``logger.info``.

    Returns:
        ``(train_loader, dataset_config_dict)`` where the dict contains
        ``'num_batches_per_epoch'`` (``len(train_loader)``) and
        ``'img_size'`` (the ``(height, width)`` pair returned by
        ``train_dataset.get_img_size()``).

    Raises:
        ValueError: if the ``datathread`` environment variable is set but
            cannot be parsed by ``int``.
    """
    train_dataset = MixDataset(data_dir=data_dir)
    img_height, img_width = train_dataset.get_img_size()

    # Environment override (single lookup) beats the function argument.
    env_datathread = os.environ.get('datathread')
    if env_datathread is not None:
        datathread = int(env_datathread)

    if logger is not None:
        logger.info("Use %d processes to load data..." % datathread)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=datathread,
                              pin_memory=True)

    dataset_config_dict = {
        'num_batches_per_epoch': len(train_loader),
        'img_size': (img_height, img_width),
    }

    return train_loader, dataset_config_dict
|
|
|
def depth_scale_shift_normalization(depth, low_percentile=2, high_percentile=98, eps=1e-5):
    """Per-sample robust normalization of a depth batch to [-1, 1].

    For each sample ``b``, the ``low_percentile`` and ``high_percentile``
    percentiles of ``depth[b, 0]`` (channel 0 only) are mapped to -1 and 1
    respectively. The same affine map is applied to every channel of that
    sample, and the result is clipped to [-1, 1].

    Args:
        depth: 4-D tensor indexed as ``(batch, channel, H, W)``.
        low_percentile: Percentile (0-100) mapped to -1. Default 2.
        high_percentile: Percentile (0-100) mapped to 1. Default 98.
        eps: Added to the percentile range to avoid division by zero.

    Returns:
        Tensor with the same shape, dtype and device as ``depth``. Gradients
        still flow through ``depth``; the percentile bounds are treated as
        constants.
    """
    bsz = depth.shape[0]

    # Percentiles are computed with NumPy on CPU. Detach first so tensors that
    # require grad can be converted, and upcast bfloat16, which NumPy lacks.
    flat = depth[:, 0, :, :].detach().reshape(bsz, -1)
    if flat.dtype == torch.bfloat16:
        flat = flat.float()
    flat = flat.cpu().numpy()

    # One call computes both bounds; result has shape (2, bsz).
    low, high = np.percentile(a=flat, q=[low_percentile, high_percentile], axis=1)

    # Cast back to depth's dtype/device and broadcast over (C, H, W).
    min_value = torch.from_numpy(np.ascontiguousarray(low)).to(depth)[..., None, None, None]
    max_value = torch.from_numpy(np.ascontiguousarray(high)).to(depth)[..., None, None, None]

    normalized_depth = ((depth - min_value) / (max_value - min_value + eps) - 0.5) * 2
    normalized_depth = torch.clip(normalized_depth, -1., 1.)

    return normalized_depth
|
|
|
|
|
|
|
def resize_max_res_tensor(input_tensor, mode, recom_resolution=768):
    """Rescale a 3-channel image batch so its longer side equals ``recom_resolution``.

    The scale factor is ``min(recom_resolution / H, recom_resolution / W)``.
    ``mode == 'normal'`` resamples with nearest-neighbour interpolation. Any
    other mode uses bilinear interpolation with ``align_corners=False``.
    For ``mode == 'depth'`` the resized values are additionally divided by
    the scale factor.

    Args:
        input_tensor: 4-D tensor indexed as ``(batch, 3, H, W)``.
        mode: ``'normal'``, ``'depth'``, or anything else (bilinear, no value rescale).
        recom_resolution: Target length of the longer spatial side.

    Returns:
        The resized tensor.
    """
    assert input_tensor.shape[1] == 3

    height, width = input_tensor.shape[2:]
    scale = min(recom_resolution / height, recom_resolution / width)

    # Pick the resampling kernel: nearest for normal maps, bilinear otherwise.
    if mode == 'normal':
        interp_kwargs = {'mode': 'nearest'}
    else:
        interp_kwargs = {'mode': 'bilinear', 'align_corners': False}

    resized = F.interpolate(input_tensor, scale_factor=scale, **interp_kwargs)

    return resized / scale if mode == 'depth' else resized
|
|