File size: 2,334 Bytes
3ef1661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn

class SoftWeight(nn.Module):
    """
    Transfer n-channel discrete depth-bin logits to a depth map.

    The logits are softmax-normalized over the channel axis and used to weight
    the per-bin values in ``depth_bins_border``. The weighted sum is treated as
    a log10 depth and mapped back to depth with ``10 ** x``.

    Args (constructor):
        @depth_bins_border: sequence of c per-bin values (log10 depth), one per
            channel of the network output.
    Args (forward):
        @pred_logit: n-channel output of the network, [b, c, h, w]; a torch
            tensor or anything ``torch.tensor`` accepts (e.g. a numpy array).
    Return:
        depth: 1-channel depth, [b, 1, h, w]
        confidence: per-pixel maximum raw logit over the channels, [b, 1, h, w]
    """
    def __init__(self, depth_bins_border):
        super(SoftWeight, self).__init__()
        # Non-persistent buffer: moves with .to()/.cuda() but is not stored in state_dict.
        self.register_buffer("depth_bins_border", torch.tensor(depth_bins_border), persistent=False)

    def forward(self, pred_logit):
        if not isinstance(pred_logit, torch.Tensor):
            # Convert onto the same device as the bin values instead of hard-coding
            # "cuda", so the module also works on CPU and after .to(device).
            pred_logit = torch.tensor(pred_logit, dtype=torch.float32,
                                      device=self.depth_bins_border.device)
        pred_score = nn.functional.softmax(pred_logit, dim=1)
        pred_score_ch = pred_score.permute(0, 2, 3, 1)  # [b, h, w, c]
        # Broadcast the c bin values over the last (channel) axis.
        pred_score_weight = pred_score_ch * self.depth_bins_border
        depth_log = torch.sum(pred_score_weight, dim=3, dtype=torch.float32, keepdim=True)
        depth = 10 ** depth_log
        depth = depth.permute(0, 3, 1, 2)  # [b, 1, h, w]
        # Confidence is the max of the raw logits, not of the softmax scores.
        confidence, _ = torch.max(pred_logit, dim=1, keepdim=True)
        return depth, confidence

def soft_weight(pred_logit, depth_bins_border):
    """
    Transfer n-channel discrete depth-bin logits to a depth map (functional
    counterpart of the SoftWeight module).

    Args:
        @pred_logit: n-channel output of the network, [b, c, h, w]; a torch
            tensor or anything ``torch.tensor`` accepts (e.g. a numpy array).
        @depth_bins_border: c per-bin values (log10 depth), one per channel;
            a torch tensor or a sequence/array.
    Return:
        depth: 1-channel depth, 10 ** (softmax-weighted sum of the bin values),
            [b, 1, h, w]
        confidence: per-pixel maximum raw logit over the channels, [b, 1, h, w]
    """
    if not isinstance(pred_logit, torch.Tensor):
        # Prefer the device of an already-materialized bin tensor; otherwise use
        # CUDA when available (the original intent) and fall back to CPU.
        if isinstance(depth_bins_border, torch.Tensor):
            device = depth_bins_border.device
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        pred_logit = torch.tensor(pred_logit, dtype=torch.float32, device=device)
    if isinstance(depth_bins_border, torch.Tensor):
        depth_bins_border = depth_bins_border.to(pred_logit.device)
    else:
        depth_bins_border = torch.tensor(depth_bins_border, dtype=torch.float32,
                                         device=pred_logit.device)

    pred_score = nn.functional.softmax(pred_logit, dim=1)
    pred_score_ch = pred_score.permute(0, 2, 3, 1)  # [b, h, w, c]
    # Broadcast the c bin values over the last (channel) axis and sum them out.
    depth_log = torch.sum(pred_score_ch * depth_bins_border, dim=3,
                          dtype=torch.float32, keepdim=True)
    depth = 10 ** depth_log
    depth = depth.permute(0, 3, 1, 2)  # [b, 1, h, w]

    # Confidence is the max of the raw logits, not of the softmax scores.
    confidence, _ = torch.max(pred_logit, dim=1, keepdim=True)
    return depth, confidence



if __name__ == '__main__':
    import numpy as np

    # Depth range spanned by the bins and the number of bins (one per channel).
    depth_max = 100
    depth_min = 0.5
    num_bins = 200

    log_min = np.log10(depth_min)
    depth_bin_interval = (np.log10(depth_max) - log_min) / num_bins
    # Evenly spaced values in log10 space, offset by half an interval from log_min.
    depth_bins_border = [log_min + depth_bin_interval * (idx + 0.5) for idx in range(num_bins)]

    sw = SoftWeight(depth_bins_border)