# PatchFusion / infer_user.py
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Zhenyu Li
import os
import cv2
import argparse
from zoedepth.utils.config import get_config_user
from zoedepth.models.builder import build_model
from zoedepth.utils.arg_utils import parse_unknown
import numpy as np
from zoedepth.models.base_models.midas import Resize
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import copy
from zoedepth.utils.misc import get_boundaries
from zoedepth.utils.misc import compute_metrics, RunningAverageDict
from tqdm import tqdm
import matplotlib
import torch.nn.functional as F
from zoedepth.data.middleburry import readPFM
import random
import imageio
from PIL import Image
def load_state_dict(model, state_dict):
    """Load state_dict into model, handling DataParallel and DistributedDataParallel.

    Also checks for a "model" key in state_dict.
    DataParallel prefixes state_dict keys with 'module.' when saving.
    If the model is not a DataParallel model but the state_dict is, the prefixes are removed.
    If the model is a DataParallel model but the state_dict is not, the prefixes are added.
    """
state_dict = state_dict.get('model', state_dict)
# if model is a DataParallel model, then state_dict keys are prefixed with 'module.'
do_prefix = isinstance(
model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel))
state = {}
for k, v in state_dict.items():
if k.startswith('module.') and not do_prefix:
k = k[7:]
if not k.startswith('module.') and do_prefix:
k = 'module.' + k
state[k] = v
model.load_state_dict(state, strict=True)
# model.load_state_dict(state, strict=False)
print("Loaded successfully")
return model
def load_wts(model, checkpoint_path):
ckpt = torch.load(checkpoint_path, map_location='cpu')
return load_state_dict(model, ckpt)
def load_ckpt(model, checkpoint):
model = load_wts(model, checkpoint)
print("Loaded weights from {0}".format(checkpoint))
return model
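# Usage sketch (hypothetical checkpoint path): weights saved from a
# DataParallel model carry a 'module.' prefix on every key; load_state_dict
# above remaps the keys, so both wrapped and unwrapped models load cleanly.
#   model = build_model(config)                          # plain nn.Module
#   model = load_ckpt(model, './ckpts/patchfusion.pt')   # prefixes handled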
#### dataset helpers
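# Note on read_image: it returns an RGB image scaled to [0, 1], but the return
# type differs by branch: 'u4k' and 'nyu' give float32 numpy arrays at native
# resolution, while 'mid' and the default branch give torch tensors bicubically
# resized to IMG_RESOLUTION (a module-level global set in the __main__ block).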
def read_image(path, dataset_name):
if dataset_name == 'u4k':
img = np.fromfile(open(path, 'rb'), dtype=np.uint8).reshape(2160, 3840, 3) / 255.0
img = img.astype(np.float32)[:, :, ::-1].copy()
elif dataset_name == 'mid':
img = cv2.imread(path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
img = F.interpolate(torch.tensor(img).unsqueeze(dim=0).permute(0, 3, 1, 2), IMG_RESOLUTION, mode='bicubic', align_corners=True)
img = img.squeeze().permute(1, 2, 0)
elif dataset_name == 'nyu':
img = Image.open(path)
img = np.asarray(img, dtype=np.float32) / 255.0
else:
img = cv2.imread(path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
img = F.interpolate(torch.tensor(img).unsqueeze(dim=0).permute(0, 3, 1, 2), IMG_RESOLUTION, mode='bicubic', align_corners=True)
img = img.squeeze().permute(1, 2, 0)
return img
class Images:
def __init__(self, root_dir, files, index, dataset_name=None):
self.root_dir = root_dir
name = files[index]
self.dataset_name = dataset_name
self.rgb_image = read_image(os.path.join(self.root_dir, name), dataset_name)
name = name.replace(".jpg", "")
name = name.replace(".png", "")
name = name.replace(".jpeg", "")
self.name = name
class DepthMap:
def __init__(self, root_dir, files, index, dataset_name, pred=False):
self.root_dir = root_dir
name = files[index]
gt_path = os.path.join(self.root_dir, name)
if dataset_name == 'u4k':
depth_factor = gt_path.replace('val_gt', 'val_factor')
depth_factor = depth_factor.replace('.npy', '.txt')
with open(depth_factor, 'r') as f:
df = f.readline()
df = float(df)
gt_disp = np.load(gt_path, mmap_mode='c')
gt_disp = gt_disp.astype(np.float32)
edges = get_boundaries(gt_disp, th=1, dilation=0)
gt_depth = df/gt_disp
self.gt = gt_depth
self.edge = edges
elif dataset_name == 'gta':
gt_depth = imageio.imread(gt_path)
gt_depth = np.array(gt_depth).astype(np.float32) / 256
edges = get_boundaries(gt_depth, th=1, dilation=0)
self.gt = gt_depth
self.edge = edges
elif dataset_name == 'mid':
depth_factor = gt_path.replace('gts', 'calibs')
depth_factor = depth_factor.replace('.pfm', '.txt')
with open(depth_factor, 'r') as f:
ext_l = f.readlines()
cam_info = ext_l[0].strip()
cam_info_f = float(cam_info.split(' ')[0].split('[')[1])
base = float(ext_l[3].strip().split('=')[1])
doffs = float(ext_l[2].strip().split('=')[1])
depth_factor = base * cam_info_f
height = 1840
width = 2300
disp_gt, scale = readPFM(gt_path)
disp_gt = disp_gt.astype(np.float32)
disp_gt_copy = disp_gt.copy()
invalid_mask = disp_gt == np.inf
depth_gt = depth_factor / (disp_gt + doffs)
depth_gt = depth_gt / 1000
            depth_gt[invalid_mask] = 0  # zero out invalid (infinite-disparity) pixels
disp_gt_copy[invalid_mask] = 0
edges = get_boundaries(disp_gt_copy, th=1, dilation=0)
self.gt = depth_gt
self.edge = edges
elif dataset_name == 'nyu':
if pred:
depth_gt = np.load(gt_path.replace('png', 'npy'))
depth_gt = nn.functional.interpolate(
torch.tensor(depth_gt).unsqueeze(dim=0).unsqueeze(dim=0), (480, 640), mode='bilinear', align_corners=True).squeeze().numpy()
edges = get_boundaries(depth_gt, th=1, dilation=0)
else:
depth_gt = np.asarray(Image.open(gt_path), dtype=np.float32) / 1000
edges = get_boundaries(depth_gt, th=1, dilation=0)
self.gt = depth_gt
self.edge = edges
else:
raise NotImplementedError
name = name.replace(".npy", "") # u4k
name = name.replace(".exr", "") # gta
self.name = name
class ImageDataset:
def __init__(self, rgb_image_dir, gt_dir=None, dataset_name=''):
self.rgb_image_dir = rgb_image_dir
self.files = sorted(os.listdir(self.rgb_image_dir))
        self.gt_dir = gt_dir
        self.dataset_name = dataset_name
        if gt_dir is not None:
            self.gt_files = sorted(os.listdir(self.gt_dir))
def __len__(self):
return len(self.files)
def __getitem__(self, index):
if self.dataset_name == 'nyu':
return Images(self.rgb_image_dir, self.files, index, self.dataset_name), DepthMap(self.gt_dir, self.gt_files, index, self.dataset_name), DepthMap('/ibex/ai/home/liz0l/codes/ZoeDepth/nfs/save/nyu', self.gt_files, index, self.dataset_name, pred=True)
if self.gt_dir is not None:
return Images(self.rgb_image_dir, self.files, index, self.dataset_name), DepthMap(self.gt_dir, self.gt_files, index, self.dataset_name)
else:
return Images(self.rgb_image_dir, self.files, index)
def crop(img, crop_bbox):
crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
    template = torch.zeros((1, 1, img.shape[-2], img.shape[-1]), dtype=torch.float)
    template[:, :, crop_y1:crop_y2, crop_x1:crop_x2] = 1.0
    img = img[:, :, crop_y1:crop_y2, crop_x1:crop_x2]
    return img, template
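# Example of the crop contract: for a (1, 3, 2160, 3840) image and
# crop_bbox = (0, 540, 0, 960) (y1, y2, x1, x2), `crop` returns the top-left
# 540x960 patch together with a full-resolution 0/1 mask marking where the
# patch was taken from (the 'crop_area' fed to the model below).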
def generatemask(size):
    # Generates a Gaussian-feathered blending mask
mask = np.zeros(size, dtype=np.float32)
sigma = int(size[0]/16)
k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1)
mask[int(0.1*size[0]):size[0] - int(0.1*size[0]), int(0.1*size[1]): size[1] - int(0.1*size[1])] = 1
mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
mask = (mask - mask.min()) / (mask.max() - mask.min())
mask = mask.astype(np.float32)
return mask
def generatemask_coarse(size):
    # Generates a Gaussian-feathered blending mask with a thin border
mask = np.zeros(size, dtype=np.float32)
sigma = int(size[0]/64)
k_size = int(2 * np.ceil(2 * int(size[0]/64)) + 1)
mask[int(0.001*size[0]):size[0] - int(0.001*size[0]), int(0.001*size[1]): size[1] - int(0.001*size[1])] = 1
mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma)
mask = (mask - mask.min()) / (mask.max() - mask.min())
mask = mask.astype(np.float32)
return mask
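# Both masks above are feathered (soft-edged) blending windows: ones in the
# interior, Gaussian falloff toward the borders, rescaled to [0, 1].
# generatemask keeps a 10% margin per side and is used when stitching patch
# predictions; generatemask_coarse keeps nearly the full extent and is only
# referenced by the commented-out whole-image blending code further down.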
class RunningAverageMap:
    """A running weighted average maintained over a 2D map."""
def __init__(self, average_map, count_map):
self.average_map = average_map
self.count_map = count_map
self.average_map = self.average_map / self.count_map
def update(self, pred_map, ct_map):
self.average_map = (pred_map + self.count_map * self.average_map) / (self.count_map + ct_map)
self.count_map = self.count_map + ct_map
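# Worked example of the running mean: start with average_map = 4*1 and
# count_map = 1 (average 4.0); update(pred_map=2*3, ct_map=3) yields
# (6 + 1*4) / (1 + 3) = 2.5, the weighted mean of 4 (weight 1) and 2
# (weight 3). Callers therefore pass pred_map already multiplied by its
# weight (a blur mask or a 0/1 count map).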
# For the default 4K setting (2160x3840 image, 540x960 crops), the base tile
# grid starts at x_start = [0, 960, 1920, 2880] and y_start = [0, 540, 1080, 1620].
def regular_tile(model, image, offset_x=0, offset_y=0, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False):
# crop size
# height = 540
# width = 960
height = CROP_SIZE[0]
width = CROP_SIZE[1]
assert offset_x >= 0 and offset_y >= 0
tile_num_x = (IMG_RESOLUTION[1] - offset_x) // width
tile_num_y = (IMG_RESOLUTION[0] - offset_y) // height
x_start = [width * x + offset_x for x in range(tile_num_x)]
y_start = [height * y + offset_y for y in range(tile_num_y)]
imgs_crop = []
crop_areas = []
bboxs_roi = []
bboxs_raw = []
if iter_pred is not None:
iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0)
iter_priors = []
for x in x_start: # w
for y in y_start: # h
bbox = (int(y), int(y+height), int(x), int(x+width))
img_crop, crop_area = crop(image, bbox)
imgs_crop.append(img_crop)
crop_areas.append(crop_area)
crop_y1, crop_y2, crop_x1, crop_x2 = bbox
bbox_roi = torch.tensor([crop_x1 / IMG_RESOLUTION[1] * 512, crop_y1 / IMG_RESOLUTION[0] * 384, crop_x2 / IMG_RESOLUTION[1] * 512, crop_y2 / IMG_RESOLUTION[0] * 384])
bboxs_roi.append(bbox_roi)
bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
bboxs_raw.append(bbox_raw)
if iter_pred is not None:
iter_prior, _ = crop(iter_pred, bbox)
iter_priors.append(iter_prior)
crop_areas = torch.cat(crop_areas, dim=0)
imgs_crop = torch.cat(imgs_crop, dim=0)
bboxs_roi = torch.stack(bboxs_roi, dim=0)
bboxs_raw = torch.stack(bboxs_raw, dim=0)
if iter_pred is not None:
iter_priors = torch.cat(iter_priors, dim=0)
iter_priors = TRANSFORM(iter_priors)
iter_priors = iter_priors.cuda().float()
crop_areas = TRANSFORM(crop_areas)
imgs_crop = TRANSFORM(imgs_crop)
imgs_crop = imgs_crop.cuda().float()
bboxs_roi = bboxs_roi.cuda().float()
crop_areas = crop_areas.cuda().float()
img_lr = img_lr.cuda().float()
pred_depth_crops = []
with torch.no_grad():
for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)):
if iter_pred is not None:
iter_prior = iter_priors[i].unsqueeze(dim=0)
else:
iter_prior = None
if i == 0:
out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)
whole_depth_pred = out_dict['coarse_depth_pred']
# return whole_depth_pred.squeeze()
# pred_depth_crop = out_dict['fine_depth_pred']
pred_depth_crop = out_dict['metric_depth']
else:
pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth']
# pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['fine_depth_pred']
pred_depth_crop = nn.functional.interpolate(
pred_depth_crop, (height, width), mode='bilinear', align_corners=True)
# pred_depth_crop = nn.functional.interpolate(
# pred_depth_crop, (height, width), mode='nearest')
pred_depth_crops.append(pred_depth_crop.squeeze())
whole_depth_pred = whole_depth_pred.squeeze()
whole_depth_pred = nn.functional.interpolate(whole_depth_pred.unsqueeze(dim=0).unsqueeze(dim=0), IMG_RESOLUTION, mode='bilinear', align_corners=True).squeeze()
    ####### stitch part
inner_idx = 0
init_flag = False
if offset_x == 0 and offset_y == 0:
init_flag = True
# pred_depth = whole_depth_pred
pred_depth = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
else:
iter_pred = iter_pred.squeeze()
pred_depth = iter_pred
blur_mask = generatemask((height, width)) + 1e-3
count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
for ii, x in enumerate(x_start):
for jj, y in enumerate(y_start):
if init_flag:
                blur_mask = torch.as_tensor(blur_mask, device=whole_depth_pred.device)
count_map[y: y+height, x: x+width] = blur_mask
pred_depth[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
else:
# ensemble with running mean
if blr_mask:
                    blur_mask = torch.as_tensor(blur_mask, device=whole_depth_pred.device)
count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
count_map[y: y+height, x: x+width] = blur_mask
pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
avg_depth_map.update(pred_map, count_map)
else:
if boundary != 0:
count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1
pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary]
avg_depth_map.update(pred_map, count_map)
else:
count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
count_map[y: y+height, x: x+width] = 1
pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx]
avg_depth_map.update(pred_map, count_map)
inner_idx += 1
if init_flag:
avg_depth_map = RunningAverageMap(pred_depth, count_map)
# blur_mask = generatemask_coarse(IMG_RESOLUTION)
# blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
# count_map = (1 - blur_mask)
# pred_map = whole_depth_pred * (1 - blur_mask)
# avg_depth_map.update(pred_map, count_map)
return avg_depth_map
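# regular_tile_param below is a drop-in variant of regular_tile that takes
# crop_size, img_resolution, and the resize transform as explicit arguments
# instead of the CROP_SIZE / IMG_RESOLUTION / TRANSFORM globals, presumably so
# it can be driven from external callers without touching module state.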
def regular_tile_param(model, image, offset_x=0, offset_y=0, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False, crop_size=None,
img_resolution=None, transform=None):
# crop size
# height = 540
# width = 960
height = crop_size[0]
width = crop_size[1]
assert offset_x >= 0 and offset_y >= 0
tile_num_x = (img_resolution[1] - offset_x) // width
tile_num_y = (img_resolution[0] - offset_y) // height
x_start = [width * x + offset_x for x in range(tile_num_x)]
y_start = [height * y + offset_y for y in range(tile_num_y)]
imgs_crop = []
crop_areas = []
bboxs_roi = []
bboxs_raw = []
if iter_pred is not None:
iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0)
iter_priors = []
for x in x_start: # w
for y in y_start: # h
bbox = (int(y), int(y+height), int(x), int(x+width))
img_crop, crop_area = crop(image, bbox)
imgs_crop.append(img_crop)
crop_areas.append(crop_area)
crop_y1, crop_y2, crop_x1, crop_x2 = bbox
bbox_roi = torch.tensor([crop_x1 / img_resolution[1] * 512, crop_y1 / img_resolution[0] * 384, crop_x2 / img_resolution[1] * 512, crop_y2 / img_resolution[0] * 384])
bboxs_roi.append(bbox_roi)
bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
bboxs_raw.append(bbox_raw)
if iter_pred is not None:
iter_prior, _ = crop(iter_pred, bbox)
iter_priors.append(iter_prior)
crop_areas = torch.cat(crop_areas, dim=0)
imgs_crop = torch.cat(imgs_crop, dim=0)
bboxs_roi = torch.stack(bboxs_roi, dim=0)
bboxs_raw = torch.stack(bboxs_raw, dim=0)
if iter_pred is not None:
iter_priors = torch.cat(iter_priors, dim=0)
iter_priors = transform(iter_priors)
iter_priors = iter_priors.to(image.device).float()
crop_areas = transform(crop_areas)
imgs_crop = transform(imgs_crop)
imgs_crop = imgs_crop.to(image.device).float()
bboxs_roi = bboxs_roi.to(image.device).float()
crop_areas = crop_areas.to(image.device).float()
img_lr = img_lr.to(image.device).float()
pred_depth_crops = []
with torch.no_grad():
for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)):
if iter_pred is not None:
iter_prior = iter_priors[i].unsqueeze(dim=0)
else:
iter_prior = None
if i == 0:
out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)
whole_depth_pred = out_dict['coarse_depth_pred']
# return whole_depth_pred.squeeze()
# pred_depth_crop = out_dict['fine_depth_pred']
pred_depth_crop = out_dict['metric_depth']
else:
pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth']
# pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['fine_depth_pred']
pred_depth_crop = nn.functional.interpolate(
pred_depth_crop, (height, width), mode='bilinear', align_corners=True)
# pred_depth_crop = nn.functional.interpolate(
# pred_depth_crop, (height, width), mode='nearest')
pred_depth_crops.append(pred_depth_crop.squeeze())
whole_depth_pred = whole_depth_pred.squeeze()
whole_depth_pred = nn.functional.interpolate(whole_depth_pred.unsqueeze(dim=0).unsqueeze(dim=0), img_resolution, mode='bilinear', align_corners=True).squeeze()
    ####### stitch part
inner_idx = 0
init_flag = False
if offset_x == 0 and offset_y == 0:
init_flag = True
# pred_depth = whole_depth_pred
pred_depth = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
else:
iter_pred = iter_pred.squeeze()
pred_depth = iter_pred
blur_mask = generatemask((height, width)) + 1e-3
count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
for ii, x in enumerate(x_start):
for jj, y in enumerate(y_start):
if init_flag:
                blur_mask = torch.as_tensor(blur_mask, device=whole_depth_pred.device)
count_map[y: y+height, x: x+width] = blur_mask
pred_depth[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
else:
# ensemble with running mean
if blr_mask:
                    blur_mask = torch.as_tensor(blur_mask, device=whole_depth_pred.device)
count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
count_map[y: y+height, x: x+width] = blur_mask
pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
avg_depth_map.update(pred_map, count_map)
else:
if boundary != 0:
count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1
pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary]
avg_depth_map.update(pred_map, count_map)
else:
count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
count_map[y: y+height, x: x+width] = 1
pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx]
avg_depth_map.update(pred_map, count_map)
inner_idx += 1
if init_flag:
avg_depth_map = RunningAverageMap(pred_depth, count_map)
# blur_mask = generatemask_coarse(img_resolution)
# blur_mask = torch.tensor(blur_mask, device=whole_depth_pred.device)
# count_map = (1 - blur_mask)
# pred_map = whole_depth_pred * (1 - blur_mask)
# avg_depth_map.update(pred_map, count_map)
return avg_depth_map
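# random_tile draws a single randomly placed crop and folds its prediction
# into the running average map, using the current average as the iterative
# prior; run() calls it repeatedly in the 'rN' modes to smooth residual seams.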
def random_tile(model, image, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False):
height = CROP_SIZE[0]
width = CROP_SIZE[1]
x_start = [random.randint(0, IMG_RESOLUTION[1] - width - 1)]
y_start = [random.randint(0, IMG_RESOLUTION[0] - height - 1)]
imgs_crop = []
crop_areas = []
bboxs_roi = []
bboxs_raw = []
if iter_pred is not None:
iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0)
iter_priors = []
for x in x_start: # w
for y in y_start: # h
bbox = (int(y), int(y+height), int(x), int(x+width))
img_crop, crop_area = crop(image, bbox)
imgs_crop.append(img_crop)
crop_areas.append(crop_area)
crop_y1, crop_y2, crop_x1, crop_x2 = bbox
bbox_roi = torch.tensor([crop_x1 / IMG_RESOLUTION[1] * 512, crop_y1 / IMG_RESOLUTION[0] * 384, crop_x2 / IMG_RESOLUTION[1] * 512, crop_y2 / IMG_RESOLUTION[0] * 384])
bboxs_roi.append(bbox_roi)
bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
bboxs_raw.append(bbox_raw)
if iter_pred is not None:
iter_prior, _ = crop(iter_pred, bbox)
iter_priors.append(iter_prior)
crop_areas = torch.cat(crop_areas, dim=0)
imgs_crop = torch.cat(imgs_crop, dim=0)
bboxs_roi = torch.stack(bboxs_roi, dim=0)
bboxs_raw = torch.stack(bboxs_raw, dim=0)
if iter_pred is not None:
iter_priors = torch.cat(iter_priors, dim=0)
iter_priors = TRANSFORM(iter_priors)
iter_priors = iter_priors.cuda().float()
crop_areas = TRANSFORM(crop_areas)
imgs_crop = TRANSFORM(imgs_crop)
imgs_crop = imgs_crop.cuda().float()
bboxs_roi = bboxs_roi.cuda().float()
crop_areas = crop_areas.cuda().float()
img_lr = img_lr.cuda().float()
pred_depth_crops = []
with torch.no_grad():
for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)):
if iter_pred is not None:
iter_prior = iter_priors[i].unsqueeze(dim=0)
else:
iter_prior = None
if i == 0:
out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)
whole_depth_pred = out_dict['coarse_depth_pred']
pred_depth_crop = out_dict['metric_depth']
# return whole_depth_pred.squeeze()
else:
pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth']
pred_depth_crop = nn.functional.interpolate(
pred_depth_crop, (height, width), mode='bilinear', align_corners=True)
# pred_depth_crop = nn.functional.interpolate(
# pred_depth_crop, (height, width), mode='nearest')
pred_depth_crops.append(pred_depth_crop.squeeze())
whole_depth_pred = whole_depth_pred.squeeze()
    ####### stitch part
inner_idx = 0
init_flag = False
iter_pred = iter_pred.squeeze()
pred_depth = iter_pred
blur_mask = generatemask((height, width)) + 1e-3
for ii, x in enumerate(x_start):
for jj, y in enumerate(y_start):
if init_flag:
                # unreachable: init_flag is always False in this function
                crop_temp = copy.deepcopy(whole_depth_pred[y: y+height, x: x+width])
                blur_mask = torch.ones((height, width), device=whole_depth_pred.device)
                pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx] + (1 - blur_mask) * crop_temp
else:
if blr_mask:
                    blur_mask = torch.as_tensor(blur_mask, device=whole_depth_pred.device)
count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
count_map[y: y+height, x: x+width] = blur_mask
pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
avg_depth_map.update(pred_map, count_map)
else:
# ensemble with running mean
if boundary != 0:
count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1
pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary]
avg_depth_map.update(pred_map, count_map)
else:
count_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
count_map[y: y+height, x: x+width] = 1
pred_map = torch.zeros(IMG_RESOLUTION, device=pred_depth_crops[inner_idx].device)
pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx]
avg_depth_map.update(pred_map, count_map)
inner_idx += 1
if avg_depth_map is None:
return pred_depth
def random_tile_param(model, image, img_lr=None, iter_pred=None, boundary=0, update=False, avg_depth_map=None, blr_mask=False, crop_size=None,
img_resolution=None, transform=None):
height = crop_size[0]
width = crop_size[1]
x_start = [random.randint(0, img_resolution[1] - width - 1)]
y_start = [random.randint(0, img_resolution[0] - height - 1)]
imgs_crop = []
crop_areas = []
bboxs_roi = []
bboxs_raw = []
if iter_pred is not None:
iter_pred = iter_pred.unsqueeze(dim=0).unsqueeze(dim=0)
iter_priors = []
for x in x_start: # w
for y in y_start: # h
bbox = (int(y), int(y+height), int(x), int(x+width))
img_crop, crop_area = crop(image, bbox)
imgs_crop.append(img_crop)
crop_areas.append(crop_area)
crop_y1, crop_y2, crop_x1, crop_x2 = bbox
bbox_roi = torch.tensor([crop_x1 / img_resolution[1] * 512, crop_y1 / img_resolution[0] * 384, crop_x2 / img_resolution[1] * 512, crop_y2 / img_resolution[0] * 384])
bboxs_roi.append(bbox_roi)
bbox_raw = torch.tensor([crop_x1, crop_y1, crop_x2, crop_y2])
bboxs_raw.append(bbox_raw)
if iter_pred is not None:
iter_prior, _ = crop(iter_pred, bbox)
iter_priors.append(iter_prior)
crop_areas = torch.cat(crop_areas, dim=0)
imgs_crop = torch.cat(imgs_crop, dim=0)
bboxs_roi = torch.stack(bboxs_roi, dim=0)
bboxs_raw = torch.stack(bboxs_raw, dim=0)
if iter_pred is not None:
iter_priors = torch.cat(iter_priors, dim=0)
iter_priors = transform(iter_priors)
iter_priors = iter_priors.cuda().float()
crop_areas = transform(crop_areas)
imgs_crop = transform(imgs_crop)
imgs_crop = imgs_crop.cuda().float()
bboxs_roi = bboxs_roi.cuda().float()
crop_areas = crop_areas.cuda().float()
img_lr = img_lr.cuda().float()
pred_depth_crops = []
with torch.no_grad():
for i, (img, bbox, crop_area) in enumerate(zip(imgs_crop, bboxs_roi, crop_areas)):
if iter_pred is not None:
iter_prior = iter_priors[i].unsqueeze(dim=0)
else:
iter_prior = None
if i == 0:
out_dict = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)
whole_depth_pred = out_dict['coarse_depth_pred']
pred_depth_crop = out_dict['metric_depth']
# return whole_depth_pred.squeeze()
else:
pred_depth_crop = model(img.unsqueeze(dim=0), mode='eval', image_raw=img_lr, bbox=bbox.unsqueeze(dim=0), crop_area=crop_area.unsqueeze(dim=0), iter_prior=iter_prior if update is True else None)['metric_depth']
pred_depth_crop = nn.functional.interpolate(
pred_depth_crop, (height, width), mode='bilinear', align_corners=True)
# pred_depth_crop = nn.functional.interpolate(
# pred_depth_crop, (height, width), mode='nearest')
pred_depth_crops.append(pred_depth_crop.squeeze())
whole_depth_pred = whole_depth_pred.squeeze()
    ####### stitch part
inner_idx = 0
init_flag = False
iter_pred = iter_pred.squeeze()
pred_depth = iter_pred
blur_mask = generatemask((height, width)) + 1e-3
for ii, x in enumerate(x_start):
for jj, y in enumerate(y_start):
if init_flag:
                # unreachable: init_flag is always False in this function
                crop_temp = copy.deepcopy(whole_depth_pred[y: y+height, x: x+width])
                blur_mask = torch.ones((height, width), device=whole_depth_pred.device)
                pred_depth[y: y+height, x: x+width] = blur_mask * pred_depth_crops[inner_idx] + (1 - blur_mask) * crop_temp
else:
if blr_mask:
                    blur_mask = torch.as_tensor(blur_mask, device=whole_depth_pred.device)
count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
count_map[y: y+height, x: x+width] = blur_mask
pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx] * blur_mask
avg_depth_map.update(pred_map, count_map)
else:
# ensemble with running mean
if boundary != 0:
count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
count_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = 1
pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
pred_map[y+boundary: y+height-boundary, x+boundary: x+width-boundary] = pred_depth_crops[inner_idx][boundary:-boundary, boundary:-boundary]
avg_depth_map.update(pred_map, count_map)
else:
count_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
count_map[y: y+height, x: x+width] = 1
pred_map = torch.zeros(img_resolution, device=pred_depth_crops[inner_idx].device)
pred_map[y: y+height, x: x+width] = pred_depth_crops[inner_idx]
avg_depth_map.update(pred_map, count_map)
inner_idx += 1
if avg_depth_map is None:
return pred_depth
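# random_tile_param mirrors random_tile with explicit crop_size /
# img_resolution / transform parameters, analogous to regular_tile_param.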
def colorize_infer(value, cmap='magma_r', vmin=None, vmax=None):
# normalize
vmin = value.min() if vmin is None else vmin
# vmax = value.max() if vmax is None else vmax
vmax = np.percentile(value, 95) if vmax is None else vmax
if vmin != vmax:
value = (value - vmin) / (vmax - vmin) # vmin..vmax
else:
value = value * 0.
cmapper = matplotlib.cm.get_cmap(cmap)
    value = cmapper(value, bytes=True)  # (h, w, 4) RGBA, uint8
    value = value[:, :, :3]  # drop alpha, keep RGB
    bgr_value = value[..., ::-1]  # RGB -> BGR for cv2.imwrite
    return bgr_value
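# colorize_infer returns a BGR uint8 image, so it pairs directly with
# cv2.imwrite, e.g. (hypothetical output path):
#   cv2.imwrite('./vis/depth.png', colorize_infer(depth.cpu().numpy()))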
def colorize(value, vmin=None, vmax=None, cmap='turbo_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None, dataset_name=None):
"""Converts a depth map to a color image.
Args:
        value (torch.Tensor, numpy.ndarray): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
        vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
        vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, a high percentile of the valid values is used. Defaults to None.
        cmap (str, optional): matplotlib colormap to use. Defaults to 'turbo_r'.
invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
        background_color (tuple[int], optional): 4-tuple RGBA color given to invalid pixels. Defaults to (128, 128, 128, 255).
gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
Returns:
numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
"""
if isinstance(value, torch.Tensor):
value = value.detach().cpu().numpy()
value = value.squeeze()
if invalid_mask is None:
invalid_mask = value == invalid_val
mask = np.logical_not(invalid_mask)
# normalize
# vmin = np.percentile(value[mask],2) if vmin is None else vmin
# vmin = value.min() if vmin is None else vmin
# vmax = np.percentile(value[mask],95) if vmax is None else vmax
# mid gt
if dataset_name == 'mid':
        vmin = np.percentile(value[mask], 2) if vmin is None else vmin
        vmax = np.percentile(value[mask], 85) if vmax is None else vmax
    else:
        vmin = value.min() if vmin is None else vmin
        vmax = np.percentile(value[mask], 95) if vmax is None else vmax
if vmin != vmax:
value = (value - vmin) / (vmax - vmin) # vmin..vmax
else:
# Avoid 0-division
value = value * 0.
    # grey out the invalid values
    value[invalid_mask] = np.nan
cmapper = matplotlib.cm.get_cmap(cmap)
if value_transform:
value = value_transform(value)
# value = value / value.max()
value = cmapper(value, bytes=True) # (nxmx4)
    img = value
img[invalid_mask] = background_color
# return img.transpose((2, 0, 1))
if gamma_corrected:
# gamma correction
img = img / 255
img = np.power(img, 2.2)
img = img * 255
img = img.astype(np.uint8)
return img
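# colorize returns an RGBA uint8 array of shape (H, W, 4); a sketch of saving
# it (hypothetical depth bounds for a 10 m indoor scene):
#   rgba = colorize(depth, vmin=1e-3, vmax=10.0)
#   Image.fromarray(rgba).save('./vis/depth_color.png')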
def rescale(A, lbound=0, ubound=1):
"""
Rescale an array to [lbound, ubound].
Parameters:
- A: Input data as numpy array
- lbound: Lower bound of the scale, default is 0.
- ubound: Upper bound of the scale, default is 1.
Returns:
- Rescaled array
"""
A_min = np.min(A)
A_max = np.max(A)
return (ubound - lbound) * (A - A_min) / (A_max - A_min) + lbound
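# e.g. rescale(np.array([2.0, 4.0, 6.0])) -> array([0. , 0.5, 1. ]); note the
# function assumes A is not constant (A_max > A_min), otherwise it divides by
# zero.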
def run(model, dataset, gt_dir=None, show_path=None, show=False, save_flag=False, save_path=None, mode=None, dataset_name=None, base_zoed=False, blr_mask=False):
data_len = len(dataset)
if gt_dir is not None:
metrics_avg = RunningAverageDict()
for image_ind in tqdm(range(data_len)):
if dataset_name == 'nyu':
images, depths, pred_depths = dataset[image_ind]
else:
if gt_dir is None:
images = dataset[image_ind]
else:
images, depths = dataset[image_ind]
# Load image from dataset
        img = torch.as_tensor(images.rgb_image).unsqueeze(dim=0).permute(0, 3, 1, 2)  # shape: 1, 3, h, w
img_lr = TRANSFORM(img)
if base_zoed:
with torch.no_grad():
pred_depth = model(img.cuda())['metric_depth'].squeeze()
                avg_depth_map = RunningAverageMap(pred_depth, torch.ones_like(pred_depth))
else:
# pred_depth, count_map = regular_tile(model, img, offset_x=0, offset_y=0, img_lr=img_lr)
# avg_depth_map = RunningAverageMap(pred_depth, count_map)
avg_depth_map = regular_tile(model, img, offset_x=0, offset_y=0, img_lr=img_lr)
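            # Tiling modes ('p16'/'p49'/'rN', following PatchFusion's
            # patch-ensemble naming): 'p16' keeps only the base grid (16
            # patches at the default 4x4 tiling); 'p49' adds three
            # half-tile-shifted passes (12 + 12 + 9 extra patches, 49 total);
            # 'rN' (e.g. 'r128') additionally blends in N randomly placed
            # patches.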
            if mode == 'p16':
                pass
            elif mode == 'p49':
regular_tile(model, img, offset_x=CROP_SIZE[1]//2, offset_y=0, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
regular_tile(model, img, offset_x=0, offset_y=CROP_SIZE[0]//2, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
regular_tile(model, img, offset_x=CROP_SIZE[1]//2, offset_y=CROP_SIZE[0]//2, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
elif mode[0] == 'r':
regular_tile(model, img, offset_x=CROP_SIZE[1]//2, offset_y=0, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
regular_tile(model, img, offset_x=0, offset_y=CROP_SIZE[0]//2, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
regular_tile(model, img, offset_x=CROP_SIZE[1]//2, offset_y=CROP_SIZE[0]//2, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
for i in tqdm(range(int(mode[1:]))):
random_tile(model, img, img_lr=img_lr, iter_pred=avg_depth_map.average_map, boundary=BOUNDARY, update=True, avg_depth_map=avg_depth_map, blr_mask=blr_mask)
if show:
color_map = copy.deepcopy(avg_depth_map.average_map)
color_map = colorize_infer(color_map.detach().cpu().numpy())
cv2.imwrite(os.path.join(show_path, '{}.png'.format(images.name)), color_map)
if save_flag:
np.save(os.path.join(save_path, '{}.npy'.format(images.name)), avg_depth_map.average_map.squeeze().detach().cpu().numpy())
# np.save(os.path.join(save_path, '{}.npy'.format(images.name)), depths.gt)
if gt_dir is not None:
if dataset_name == 'nyu':
metrics = compute_metrics(torch.tensor(depths.gt), avg_depth_map.average_map, disp_gt_edges=depths.edge, min_depth_eval=1e-3, max_depth_eval=10, garg_crop=False, eigen_crop=True, dataset='nyu', pred_depths=torch.tensor(pred_depths.gt))
# metrics = compute_metrics(torch.tensor(depths.gt), avg_depth_map.average_map, disp_gt_edges=depths.edge, min_depth_eval=1e-3, max_depth_eval=10, garg_crop=False, eigen_crop=True, dataset='nyu')
else:
metrics = compute_metrics(torch.tensor(depths.gt), avg_depth_map.average_map, disp_gt_edges=depths.edge, min_depth_eval=1e-3, max_depth_eval=80, garg_crop=False, eigen_crop=False, dataset='')
metrics_avg.update(metrics)
print(metrics)
if gt_dir is not None:
print(metrics_avg.get_value())
else:
        print("Inference finished successfully!")
return avg_depth_map
#### entry point
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--rgb_dir', type=str, required=True)
    parser.add_argument('--show_path', type=str, default=None)
parser.add_argument("--ckp_path", type=str, required=True)
parser.add_argument("-m", "--model", type=str, default="zoedepth")
parser.add_argument("--model_cfg_path", type=str, default="")
parser.add_argument("--gt_dir", type=str, default=None)
parser.add_argument("--dataset_name", type=str, default=None)
parser.add_argument("--show", action='store_true')
parser.add_argument("--save", action='store_true')
parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--img_resolution", type=str, default=None)
parser.add_argument("--crop_size", type=str, default=None)
parser.add_argument("--mode", type=str, default=None)
parser.add_argument("--base_zoed", action='store_true')
parser.add_argument("--boundary", type=int, default=0)
parser.add_argument("--blur_mask", action='store_true')
args, unknown_args = parser.parse_known_args()
# prepare some global args
global IMG_RESOLUTION
if args.dataset_name == 'u4k':
IMG_RESOLUTION = (2160, 3840)
elif args.dataset_name == 'gta':
IMG_RESOLUTION = (1080, 1920)
elif args.dataset_name == 'nyu':
IMG_RESOLUTION = (480, 640)
else:
IMG_RESOLUTION = (2160, 3840)
global TRANSFORM
TRANSFORM = Compose([Resize(512, 384, keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")])
global BOUNDARY
BOUNDARY = args.boundary
if args.img_resolution is not None:
IMG_RESOLUTION = (int(args.img_resolution.split('x')[0]), int(args.img_resolution.split('x')[1]))
global CROP_SIZE
CROP_SIZE = (int(IMG_RESOLUTION[0] // 4), int(IMG_RESOLUTION[1] // 4))
if args.crop_size is not None:
CROP_SIZE = (int(args.crop_size.split('x')[0]), int(args.crop_size.split('x')[1]))
    print("\nCurrent image resolution: {}\nCurrent crop size: {}".format(IMG_RESOLUTION, CROP_SIZE))
overwrite_kwargs = parse_unknown(unknown_args)
overwrite_kwargs['model_cfg_path'] = args.model_cfg_path
overwrite_kwargs["model"] = args.model
# blur_mask_crop = generatemask(CROP_SIZE)
# plt.imshow(blur_mask_crop)
# plt.savefig('./nfs/results_show/crop_mask.png')
# blur_mask_crop = generatemask_coarse(IMG_RESOLUTION)
# plt.imshow(blur_mask_crop)
# plt.savefig('./nfs/results_show/whole_mask.png')
config = get_config_user(args.model, **overwrite_kwargs)
config["pretrained_resource"] = ''
model = build_model(config)
model = load_ckpt(model, args.ckp_path)
model.eval()
model.cuda()
# Create dataset from input images
dataset_custom = ImageDataset(args.rgb_dir, args.gt_dir, args.dataset_name)
# start running
if args.show:
os.makedirs(args.show_path, exist_ok=True)
if args.save:
os.makedirs(args.save_path, exist_ok=True)
run(model, dataset_custom, args.gt_dir, args.show_path, args.show, args.save, args.save_path, mode=args.mode, dataset_name=args.dataset_name, base_zoed=args.base_zoed, blr_mask=args.blur_mask)
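# Example invocation (hypothetical paths; flags as defined above):
#   python infer_user.py --rgb_dir ./examples/rgb --ckp_path ./ckpts/patchfusion.pt \
#       -m zoedepth --dataset_name u4k --mode r128 --blur_mask \
#       --show --show_path ./vis --save --save_path ./preds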