import os, torch, cv2, re
import numpy as np

from PIL import Image
import torch.nn.functional as F
import torchvision.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR


# Basic image-metric helpers: MSE between images, MSE -> PSNR conversion
# (torch and numpy variants), and float [0, 1] -> uint8 conversion.
img2mse = lambda x, y: torch.mean((x - y) ** 2)
mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)
mse2psnr2 = lambda x: -10. * np.log(x) / np.log(10.)


def get_psnr(imgs_pred, imgs_gt):
    """Per-image PSNR for a batch of predicted images (numpy) and ground-truth images (torch tensors)."""
    psnrs = []
    for (img, tar) in zip(imgs_pred, imgs_gt):
        psnrs.append(mse2psnr2(np.mean((img - tar.cpu().numpy()) ** 2)))
    return np.array(psnrs)
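

# Illustrative sketch (assumed usage, not from the original file): how get_psnr
# is typically called, with predictions as numpy arrays and targets as tensors.
# The shapes below are arbitrary example values.
def _example_get_psnr():
    preds = [np.random.rand(4, 4, 3).astype(np.float32) for _ in range(2)]
    gts = [torch.rand(4, 4, 3) for _ in range(2)]
    return get_psnr(preds, gts)  # -> array of 2 PSNR values (dB)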


def init_log(log, keys):
    # Initialise each requested key with a zero scalar tensor (float64).
    for key in keys:
        log[key] = torch.tensor([0.0], dtype=torch.float64)
    return log


def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """
    depth: (H, W) depth map
    Returns a colour-mapped uint8 image (H, W, 3) and the [min, max] range used.
    """
    x = np.nan_to_num(depth)  # replace NaNs with 0
    if minmax is None:
        mi = np.min(x[x > 0])  # ignore non-positive (invalid) depth values
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)  # normalise to [0, 1]
    x = (255 * x).astype(np.uint8)
    x_ = cv2.applyColorMap(x, cmap)
    return x_, [mi, ma]


def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """
    depth: (H, W) depth map (torch tensor or numpy array)
    Returns a colour-mapped image as a (3, H, W) tensor and the [min, max] range used.
    """
    if type(depth) is not np.ndarray:
        depth = depth.cpu().numpy()

    x = np.nan_to_num(depth)  # replace NaNs with 0
    if minmax is None:
        mi = np.min(x[x > 0])  # ignore non-positive (invalid) depth values
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x - mi) / (ma - mi + 1e-8)  # normalise to [0, 1]
    x = (255 * x).astype(np.uint8)
    x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
    x_ = T.ToTensor()(x_)
    return x_, [mi, ma]
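

# Illustrative sketch (assumed usage, not from the original file): colour-map a
# depth map and write it to disk with cv2. The output file name is arbitrary.
def _example_save_depth_visualization(depth_map, out_path='depth_vis.png'):
    vis, (mi, ma) = visualize_depth_numpy(depth_map)  # vis: (H, W, 3) uint8, BGR
    cv2.imwrite(out_path, vis)
    return mi, ma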


def abs_error_numpy(depth_pred, depth_gt, mask):
    # Absolute depth error over the masked pixels (numpy inputs).
    depth_pred, depth_gt = depth_pred[mask], depth_gt[mask]
    return np.abs(depth_pred - depth_gt)


def abs_error(depth_pred, depth_gt, mask):
    # Absolute depth error over the masked pixels (numpy arrays or torch tensors).
    depth_pred, depth_gt = depth_pred[mask], depth_gt[mask]
    err = depth_pred - depth_gt
    return np.abs(err) if type(depth_pred) is np.ndarray else err.abs()


def acc_threshold(depth_pred, depth_gt, mask, threshold):
    """
    Per-pixel indicator (1.0 / 0.0) of whether the absolute depth error is below
    @threshold; average the result to obtain the accuracy percentage.
    """
    errors = abs_error(depth_pred, depth_gt, mask)
    acc_mask = errors < threshold
    return acc_mask.astype('float') if type(depth_pred) is np.ndarray else acc_mask.float()
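

# Illustrative sketch (assumed usage, not from the original file): averaging the
# per-pixel indicator gives the fraction of valid pixels within the threshold.
def _example_depth_accuracy(depth_pred, depth_gt, mask, threshold=0.01):
    return acc_threshold(depth_pred, depth_gt, mask, threshold).mean()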


def to_tensor_cuda(data, device, filter):
    # Move every entry of the dict to the given device as float32,
    # skipping keys listed in `filter`; numpy arrays are converted to tensors first.
    for item in data.keys():
        if item in filter:
            continue
        if type(data[item]) is np.ndarray:
            data[item] = torch.tensor(data[item], dtype=torch.float32, device=device)
        else:
            data[item] = data[item].float().to(device)
    return data


def to_cuda(data, device, filter):
    # Move every tensor entry of the dict to the given device as float32,
    # skipping keys listed in `filter`.
    for item in data.keys():
        if item in filter:
            continue
        data[item] = data[item].float().to(device)
    return data


def tensor_unsqueeze(data, filter):
    # Prepend a batch dimension to every entry of the dict,
    # skipping keys listed in `filter`.
    for item in data.keys():
        if item in filter:
            continue
        data[item] = data[item][None]
    return data


def filter_keys(kwargs):
    # Drop 'N_samples', and 'ndc' / 'lindisp' if present, from the argument dict.
    kwargs.pop('N_samples')
    if 'ndc' in kwargs.keys():
        kwargs.pop('ndc')
    if 'lindisp' in kwargs.keys():
        kwargs.pop('lindisp')
    return kwargs


def sub_selete_data(data_batch, device, idx, filtKey=[],
                    filtIndex=['view_ids_all', 'c2ws_all', 'scan', 'bbox', 'w2ref', 'ref2w', 'light_id', 'ckpt',
                               'idx']):
    # Select the views given by `idx` along dim 1 for every tensor entry with
    # more than two dimensions, cast to float and move everything to `device`.
    data_sub_selete = {}
    for item in data_batch.keys():
        data_sub_selete[item] = data_batch[item][:, idx].float() if (
                item not in filtIndex and torch.is_tensor(data_batch[item])
                and data_batch[item].dim() > 2) else data_batch[item].float()
        if not data_sub_selete[item].is_cuda:
            data_sub_selete[item] = data_sub_selete[item].to(device)
    return data_sub_selete


def detach_data(dictionary):
    dictionary_new = {}
    for key in dictionary.keys():
        dictionary_new[key] = dictionary[key].detach().clone()
    return dictionary_new


def read_pfm(filename):
    # Read a PFM file: header ('PF' = colour, 'Pf' = grayscale), dimensions,
    # a scale factor whose sign encodes endianness, then raw float32 data
    # stored bottom-up (hence the final flip).
    file = open(filename, 'rb')
    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().decode('utf-8').rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # negative scale means little-endian data
        endian = '<'
        scale = -scale
    else:
        endian = '>'

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    file.close()
    return data, scale
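

# Illustrative sketch (assumed usage, not from the original file): load a PFM
# depth map and colour-map it for inspection. The path is a placeholder.
def _example_load_and_visualize_pfm(pfm_path='depth.pfm'):
    depth, _ = read_pfm(pfm_path)
    vis, minmax = visualize_depth_numpy(depth)
    return vis, minmax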


def get_scheduler(hparams, optimizer):
    # Build a learning-rate scheduler from hparams: multi-step decay or
    # cosine annealing down to a small epsilon.
    eps = 1e-8
    if hparams.lr_scheduler == 'steplr':
        scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step,
                                gamma=hparams.decay_gamma)
    elif hparams.lr_scheduler == 'cosine':
        scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps)
    else:
        raise ValueError('scheduler not recognized!')

    return scheduler
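

# Illustrative sketch (assumed usage, not from the original file): the hparams
# object only needs the attributes read above, so a SimpleNamespace suffices.
def _example_build_scheduler(model_params):
    from types import SimpleNamespace
    hparams = SimpleNamespace(lr_scheduler='cosine', num_epochs=30,
                              decay_step=[10, 20], decay_gamma=0.5)
    optimizer = torch.optim.Adam(model_params, lr=5e-4)
    return get_scheduler(hparams, optimizer)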


def get_nearest_pose_ids(tar_pose, ref_poses, num_select):
    '''
    Args:
        tar_pose: target poses [N, 4, 4]
        ref_poses: reference poses [M, 4, 4]
        num_select: the number of nearest views to select
    Returns: the selected indices [N, num_select]
    '''
    # Distance between camera centres (the translation part of each pose).
    dists = np.linalg.norm(tar_pose[:, None, :3, 3] - ref_poses[None, :, :3, 3], axis=-1)

    sorted_ids = np.argsort(dists, axis=-1)
    selected_ids = sorted_ids[:, :num_select]
    return selected_ids
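

# Illustrative sketch (assumed usage, not from the original file): pick the two
# reference views whose camera centres are closest to each target view.
def _example_nearest_views():
    tar_pose = np.tile(np.eye(4), (1, 1, 1))      # [1, 4, 4]
    ref_poses = np.tile(np.eye(4), (5, 1, 1))     # [5, 4, 4]
    ref_poses[:, 0, 3] = np.arange(5)             # spread camera centres along x
    return get_nearest_pose_ids(tar_pose, ref_poses, num_select=2)  # -> [[0, 1]]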