# Source: Metric3D/training/mono/utils/logit_to_depth.py
# Author: zach
# Commit 3ef1661: "initial commit based on github repo"
import torch
import torch.nn as nn
class SoftWeight(nn.Module):
    """
    Transfer n-channel discrete depth bins to a depth map.

    The logits are converted to per-pixel probabilities with a softmax over
    the channel axis. The probability-weighted sum of ``depth_bins_border``
    is then used as a base-10 exponent, so the bin values act as log10 depths.

    Args:
        depth_bins_border: sequence with one value per channel of the
            network output (broadcast against the channel axis).
    """

    def __init__(self, depth_bins_border):
        super(SoftWeight, self).__init__()
        # Non-persistent buffer: follows the module across .to()/.cuda()
        # calls but is not written to the state dict.
        self.register_buffer("depth_bins_border", torch.tensor(depth_bins_border), persistent=False)

    def forward(self, pred_logit):
        """
        Args:
            pred_logit: n-channel output of the network, [b, c, h, w].
                Non-tensor input (e.g. nested lists, numpy arrays) is
                converted to a float32 tensor on the same device as the
                bin buffer.
        Return:
            depth: 1-channel depth, [b, 1, h, w]
            confidence: maximum raw logit over the channel axis, [b, 1, h, w]
        """
        if not isinstance(pred_logit, torch.Tensor):
            # Follow the buffer's device instead of assuming CUDA so the
            # module also works on CPU-only hosts.
            pred_logit = torch.tensor(
                pred_logit, dtype=torch.float32, device=self.depth_bins_border.device
            )

        pred_score = nn.functional.softmax(pred_logit, dim=1)
        pred_score_ch = pred_score.permute(0, 2, 3, 1)  # [b, h, w, c]
        pred_score_weight = pred_score_ch * self.depth_bins_border
        depth_log = torch.sum(pred_score_weight, dim=3, dtype=torch.float32, keepdim=True)
        depth = 10 ** depth_log
        depth = depth.permute(0, 3, 1, 2)  # [b, 1, h, w]
        # Confidence is taken from the logits, not from the softmax scores.
        confidence, _ = torch.max(pred_logit, dim=1, keepdim=True)
        return depth, confidence
def soft_weight(pred_logit, depth_bins_border):
    """
    Transfer n-channel discrete depth bins to depth map.

    The logits are converted to per-pixel probabilities with a softmax over
    the channel axis. The probability-weighted sum of ``depth_bins_border``
    is then used as a base-10 exponent, so the bin values act as log10 depths.

    Args:
        pred_logit: n-channel output of the network, [b, c, h, w]
        depth_bins_border: one value per channel, broadcast against the
            channel axis
    Non-tensor arguments are converted to float32 tensors. They are placed
    on the device of whichever argument is already a tensor; if neither is,
    CUDA is used when available and CPU otherwise.

    Return:
        depth: 1-channel depth, [b, 1, h, w]
        confidence: maximum raw logit over the channel axis, [b, 1, h, w]
    """
    logit_is_tensor = isinstance(pred_logit, torch.Tensor)
    border_is_tensor = isinstance(depth_bins_border, torch.Tensor)

    # Pick the device from an existing tensor instead of assuming CUDA.
    if logit_is_tensor:
        device = pred_logit.device
    elif border_is_tensor:
        device = depth_bins_border.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not logit_is_tensor:
        pred_logit = torch.tensor(pred_logit, dtype=torch.float32, device=device)
    if not border_is_tensor:
        depth_bins_border = torch.tensor(depth_bins_border, dtype=torch.float32, device=device)

    pred_score = nn.functional.softmax(pred_logit, dim=1)
    # Weight each bin value by its probability; channels are moved last so
    # the 1-D bin vector broadcasts over them.
    depth_bins_ch = pred_score.permute(0, 2, 3, 1) * depth_bins_border  # [b, h, w, c]
    depth = torch.sum(depth_bins_ch, dim=3, dtype=torch.float32, keepdim=True)
    depth = 10 ** depth
    depth = depth.permute(0, 3, 1, 2)  # [b, 1, h, w]
    # Confidence is taken from the logits, not from the softmax scores.
    confidence, _ = torch.max(pred_logit, dim=1, keepdim=True)
    return depth, confidence
if __name__ == '__main__':
    # Smoke check: build the centres of 200 bins evenly spaced in log10
    # space between depth_min and depth_max, then instantiate the module.
    import numpy as np

    depth_min = 0.5
    depth_max = 100
    num_bins = 200

    log_depth_min = np.log10(depth_min)
    depth_bin_interval = (np.log10(depth_max) - log_depth_min) / num_bins
    # (i + 0.5) places each value in the middle of its bin.
    depth_bins_border = [
        log_depth_min + depth_bin_interval * (i + 0.5) for i in range(num_bins)
    ]

    sw = SoftWeight(depth_bins_border)