import math
import os
import sys

import matplotlib.cm  # needed by colorize()
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Sampler
from torchvision import transforms
def convert_arg_line_to_args(arg_line):
    """Split one line of an argument file into individual arguments."""
    for arg in arg_line.split():
        if not arg.strip():
            continue
        yield arg
def block_print():
    """Redirect stdout to /dev/null (e.g. to silence non-zero ranks)."""
    sys.stdout = open(os.devnull, 'w')


def enable_print():
    """Restore stdout after block_print()."""
    sys.stdout = sys.__stdout__
def get_num_lines(file_path):
    """Return the number of lines in a text file."""
    with open(file_path, 'r') as f:
        return len(f.readlines())
def colorize(value, vmin=None, vmax=None, cmap='Greys'):
    """Colorize a 1,H,W tensor on a log10 scale; returns a 3,H,W uint8 array."""
    value = value.cpu().numpy()[0, :, :]  # 1,H,W -> H,W (as in normalize_result)
    value = np.log10(value)

    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax

    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)  # rescale to [0, 1]
    else:
        value = value * 0.  # degenerate range: output all zeros

    cmapper = matplotlib.cm.get_cmap(cmap)
    value = cmapper(value, bytes=True)  # H, W, 4 uint8 (RGBA)

    img = value[:, :, :3]  # drop the alpha channel
    return img.transpose((2, 0, 1))  # HWC -> CHW
def normalize_result(value, vmin=None, vmax=None):
    """Min-max normalize a 1,H,W tensor to [0, 1]; returns a 1,H,W array."""
    value = value.cpu().numpy()[0, :, :]

    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax

    if vmin != vmax:
        value = (value - vmin) / (vmax - vmin)
    else:
        value = value * 0.

    return np.expand_dims(value, 0)
# Undo the standard ImageNet normalization applied to input images
inv_normalize = transforms.Normalize(
    mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
    std=[1/0.229, 1/0.224, 1/0.225]
)
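
# A minimal round-trip sketch (not from the original source): inv_normalize
# undoes the standard ImageNet normalization, so normalizing and then
# inverse-normalizing should approximately recover the input image.
def _demo_inv_normalize():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img = torch.rand(3, 8, 8)  # hypothetical RGB image in [0, 1]
    restored = inv_normalize(normalize(img))
    assert torch.allclose(img, restored, atol=1e-5)
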
eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3']
def compute_errors(gt, pred):
    """Standard depth-estimation error metrics over valid (masked) pixels."""
    # Threshold accuracies: fraction of pixels with max(gt/pred, pred/gt) < 1.25**k
    thresh = np.maximum((gt / pred), (pred / gt))
    d1 = (thresh < 1.25).mean()
    d2 = (thresh < 1.25 ** 2).mean()
    d3 = (thresh < 1.25 ** 3).mean()

    rms = (gt - pred) ** 2
    rms = np.sqrt(rms.mean())

    log_rms = (np.log(gt) - np.log(pred)) ** 2
    log_rms = np.sqrt(log_rms.mean())

    abs_rel = np.mean(np.abs(gt - pred) / gt)
    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    # Scale-invariant log error (Eigen et al.), reported x100
    err = np.log(pred) - np.log(gt)
    silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100

    err = np.abs(np.log10(pred) - np.log10(gt))
    log10 = np.mean(err)

    return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3]
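
# A hedged usage sketch (not from the original source): compute_errors expects
# flat, strictly positive depth arrays that were already filtered by a validity
# mask; eval_metrics gives the names matching the returned list.
def _demo_compute_errors():
    rng = np.random.default_rng(0)
    gt = rng.uniform(1.0, 80.0, size=1000)        # hypothetical ground-truth depths (m)
    pred = gt * rng.uniform(0.9, 1.1, size=1000)  # predictions within +/-10% of gt
    metrics = compute_errors(gt, pred)
    print(dict(zip(eval_metrics, metrics)))
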
class silog_loss(nn.Module):
    """Scale-invariant log loss; variance_focus in (0, 1] weights the variance term."""
    def __init__(self, variance_focus):
        super(silog_loss, self).__init__()
        self.variance_focus = variance_focus

    def forward(self, depth_est, depth_gt, mask):
        d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
        return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0
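
# A hedged usage sketch (not from the original source): the mask selects valid
# ground-truth pixels, and both depth maps must be strictly positive under the
# log. variance_focus=0.85 is a common choice (e.g. in BTS), assumed here.
def _demo_silog_loss():
    criterion = silog_loss(variance_focus=0.85)
    depth_est = torch.rand(2, 1, 8, 8) * 10 + 0.1
    depth_gt = torch.rand(2, 1, 8, 8) * 10 + 0.1
    mask = depth_gt > 1.0  # boolean valid-pixel mask
    print(criterion(depth_est, depth_gt, mask).item())
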
def entropy_loss(preds, gt_label, mask):
    # preds: B, C, H, W
    # gt_label: B, H, W
    # mask: B, H, W
    mask = mask > 0.0  # B, H, W (boolean)
    preds = preds.permute(0, 2, 3, 1)  # B, H, W, C
    preds_mask = preds[mask]  # N, C
    gt_label_mask = gt_label[mask]  # N
    loss = F.cross_entropy(preds_mask, gt_label_mask, reduction='mean')
    return loss
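
# A hedged usage sketch (not from the original source): gt_label holds integer
# class indices per pixel and mask marks the pixels that contribute to the loss.
def _demo_entropy_loss():
    B, C, H, W = 2, 5, 8, 8
    preds = torch.randn(B, C, H, W)            # per-pixel class logits
    gt_label = torch.randint(0, C, (B, H, W))  # per-pixel class indices
    mask = torch.rand(B, H, W)                 # soft mask; > 0.0 counts as valid
    print(entropy_loss(preds, gt_label, mask).item())
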
def _apply_colormap(inputs, cmap_name, normalize=True, torch_transpose=True):
    """Shared implementation for colormap() and colormap_magma()."""
    if isinstance(inputs, torch.Tensor):
        inputs = inputs.detach().cpu().numpy()

    _DEPTH_COLORMAP = plt.get_cmap(cmap_name, 256)  # for plotting

    vis = inputs
    if normalize:
        ma = float(vis.max())
        mi = float(vis.min())
        d = ma - mi if ma != mi else 1e5
        vis = (vis - mi) / d

    if vis.ndim == 4:
        vis = vis.transpose([0, 2, 3, 1])
        vis = _DEPTH_COLORMAP(vis)  # B, H, W, 1, 4
        vis = vis[:, :, :, 0, :3]   # drop the channel dim and alpha
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 3:
        vis = _DEPTH_COLORMAP(vis)  # B, H, W, 4
        vis = vis[:, :, :, :3]
        if torch_transpose:
            vis = vis.transpose(0, 3, 1, 2)
    elif vis.ndim == 2:
        vis = _DEPTH_COLORMAP(vis)  # H, W, 4
        vis = vis[..., :3]
        if torch_transpose:
            vis = vis.transpose(2, 0, 1)
        return vis  # no batch dimension to strip

    return vis[0, :, :, :]  # strip the batch dimension


def colormap(inputs, normalize=True, torch_transpose=True):
    """Colorize a depth map with the 'jet' colormap; returns a 3,H,W array."""
    return _apply_colormap(inputs, 'jet', normalize=normalize, torch_transpose=torch_transpose)


def colormap_magma(inputs, normalize=True, torch_transpose=True):
    """Colorize a depth map with the 'magma' colormap; returns a 3,H,W array."""
    return _apply_colormap(inputs, 'magma', normalize=normalize, torch_transpose=torch_transpose)
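
# A hedged usage sketch (not from the original source): colorize a random depth
# map. The colormap helpers accept 2-D (H,W), 3-D (B,H,W), or 4-D (B,1,H,W)
# inputs and return a 3,H,W float array, e.g. for TensorBoard's add_image.
def _demo_colormap():
    depth = torch.rand(1, 1, 16, 16)  # hypothetical B,1,H,W depth map
    vis = colormap_magma(depth)       # 3, H, W in [0, 1]
    print(vis.shape)
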
def flip_lr(image):
    """
    Flip image horizontally
    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Image to be flipped
    Returns
    -------
    image_flipped : torch.Tensor [B,3,H,W]
        Flipped image
    """
    assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip'
    return torch.flip(image, [3])
def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'):
    """
    Fuse inverse depth and flipped inverse depth maps
    Parameters
    ----------
    inv_depth : torch.Tensor [B,1,H,W]
        Inverse depth map
    inv_depth_hat : torch.Tensor [B,1,H,W]
        Flipped inverse depth map produced from a flipped image
    method : str
        Method that will be used to fuse the inverse depth maps
    Returns
    -------
    fused_inv_depth : torch.Tensor [B,1,H,W]
        Fused inverse depth map
    """
    if method == 'mean':
        return 0.5 * (inv_depth + inv_depth_hat)
    elif method == 'max':
        return torch.max(inv_depth, inv_depth_hat)
    elif method == 'min':
        return torch.min(inv_depth, inv_depth_hat)
    else:
        raise ValueError('Unknown post-process method {}'.format(method))
def post_process_depth(depth, depth_flipped, method='mean'):
    """
    Post-process a depth map and a depth map from a flipped input
    Parameters
    ----------
    depth : torch.Tensor [B,1,H,W]
        Depth map
    depth_flipped : torch.Tensor [B,1,H,W]
        Depth map produced from a horizontally flipped image
    method : str
        Method that will be used to fuse the depth maps
    Returns
    -------
    depth_pp : torch.Tensor [B,1,H,W]
        Post-processed depth map
    """
    B, C, H, W = depth.shape
    depth_hat = flip_lr(depth_flipped)  # undo the flip so both maps align
    depth_fused = fuse_inv_depth(depth, depth_hat, method=method)
    # Ramp masks that favour each prediction away from its own image border
    xs = torch.linspace(0., 1., W, device=depth.device,
                        dtype=depth.dtype).repeat(B, C, H, 1)
    mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)
    mask_hat = flip_lr(mask)
    return mask_hat * depth + mask * depth_hat + \
        (1.0 - mask - mask_hat) * depth_fused
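
# A hedged sketch (not from the original source) of flip-based test-time
# augmentation: run the network on the image and its mirror, then fuse the two
# predictions. `model` is a hypothetical callable returning a B,1,H,W depth map.
def _demo_post_process(model, image):
    depth = model(image)
    depth_flipped = model(flip_lr(image))
    return post_process_depth(depth, depth_flipped, method='mean')
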
class DistributedSamplerNoEvenlyDivisible(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.
    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.
    Unlike the stock ``DistributedSampler``, no extra samples are added to make
    the dataset evenly divisible, so per-rank sample counts may differ by one.
    .. note::
        Dataset is assumed to be of constant size.
    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): If true (default), sampler will shuffle the indices.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Each rank gets floor(N / world_size) samples; the first
        # N % world_size ranks get one extra, so nothing is dropped or padded.
        num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
        rest = len(self.dataset) - num_samples * self.num_replicas
        if self.rank < rest:
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = len(dataset)
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # subsample without padding: indices are interleaved across ranks
        indices = indices[self.rank:self.total_size:self.num_replicas]
        self.num_samples = len(indices)

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
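
# A hedged usage sketch (not from the original source): with explicit
# num_replicas and rank, the sampler works without an initialized process
# group. Per-rank counts differ by at most one, unlike the stock sampler.
def _demo_sampler():
    dataset = list(range(10))  # any object with __len__
    for rank in range(3):
        sampler = DistributedSamplerNoEvenlyDivisible(
            dataset, num_replicas=3, rank=rank, shuffle=False)
        print(rank, list(iter(sampler)))  # rank 0 gets 4 samples, ranks 1-2 get 3
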
class D_to_cloud(nn.Module):
    """Layer to transform depth into a camera-space point cloud
    """
    def __init__(self, batch_size, height, width):
        super(D_to_cloud, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)  # 2, H, W
        self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
                                      requires_grad=False)  # 2, H, W

        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)  # B, 1, H*W

        self.pix_coords = torch.unsqueeze(torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)  # 1, 2, L
        self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)  # B, 2, L
        # Homogeneous pixel coordinates (x, y, 1)
        self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
                                       requires_grad=False)  # B, 3, L

    def forward(self, depth, inv_K):
        # Back-project: X_cam = depth * K^{-1} @ (x, y, 1)
        cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        cam_points = depth.view(self.batch_size, 1, -1) * cam_points

        return cam_points.permute(0, 2, 1)  # B, L, 3
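
# A hedged usage sketch (not from the original source): back-project a constant
# depth map with hypothetical pinhole intrinsics. The output is a B, H*W, 3
# tensor of camera-space points.
def _demo_d_to_cloud():
    B, H, W = 1, 4, 6
    K = torch.tensor([[10., 0., W / 2., 0.],
                      [0., 10., H / 2., 0.],
                      [0., 0., 1., 0.],
                      [0., 0., 0., 1.]]).unsqueeze(0)  # B, 4, 4 intrinsics
    inv_K = torch.inverse(K)
    depth = torch.ones(B, 1, H, W) * 2.0  # constant 2 m depth
    cloud = D_to_cloud(B, H, W)(depth, inv_K)
    print(cloud.shape)  # torch.Size([1, 24, 3])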