""" utils
"""
import os
import torch
import numpy as np


def load_checkpoint(fpath, model):
    """Load weights from a checkpoint file into `model`.

    Keys saved from a (Distributed)DataParallel model carry a 'module.'
    prefix; it is stripped so the weights load into an unwrapped model.
    """
    print('loading checkpoint... {}'.format(fpath))

    ckpt = torch.load(fpath, map_location='cpu')['model']

    load_dict = {}
    for k, v in ckpt.items():
        if k.startswith('module.'):
            # strip the 'module.' prefix added by (Distributed)DataParallel
            k_ = k.replace('module.', '')
            load_dict[k_] = v
        else:
            load_dict[k] = v

    model.load_state_dict(load_dict)
    print('loading checkpoint... / done')
    return model


def compute_normal_error(pred_norm, gt_norm):
    """Per-pixel angular error (in degrees) between predicted and ground-truth normals."""
    pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
    pred_error = torch.clamp(pred_error, min=-1.0, max=1.0)
    pred_error = torch.acos(pred_error) * 180.0 / np.pi
    pred_error = pred_error.unsqueeze(1)    # (B, 1, H, W)
    return pred_error


def compute_normal_metrics(total_normal_errors):
    """Summarize a flat tensor of per-pixel angular errors (in degrees):
    mean, median, RMSE, and the percentage of pixels with error below
    5 / 7.5 / 11.25 / 22.5 / 30 degrees (a1-a5)."""
    total_normal_errors = total_normal_errors.detach().cpu().numpy()
    num_pixels = total_normal_errors.shape[0]

    metrics = {
        'mean': np.average(total_normal_errors),
        'median': np.median(total_normal_errors),
        'rmse': np.sqrt(np.sum(total_normal_errors * total_normal_errors) / num_pixels),
        'a1': 100.0 * (np.sum(total_normal_errors < 5) / num_pixels),
        'a2': 100.0 * (np.sum(total_normal_errors < 7.5) / num_pixels),
        'a3': 100.0 * (np.sum(total_normal_errors < 11.25) / num_pixels),
        'a4': 100.0 * (np.sum(total_normal_errors < 22.5) / num_pixels),
        'a5': 100.0 * (np.sum(total_normal_errors < 30) / num_pixels)
    }

    return metrics


def pad_input(orig_H, orig_W):
    """Compute (left, right, top, bottom) padding that rounds the width and
    height up to the next multiple of 32, splitting each pad as evenly as
    possible between the two sides."""
    if orig_W % 32 == 0:
        l = 0
        r = 0
    else:
        new_W = 32 * ((orig_W // 32) + 1)
        l = (new_W - orig_W) // 2
        r = (new_W - orig_W) - l

    if orig_H % 32 == 0:
        t = 0
        b = 0
    else:
        new_H = 32 * ((orig_H // 32) + 1)
        t = (new_H - orig_H) // 2
        b = (new_H - orig_H) - t
    return l, r, t, b


def get_intrins_from_fov(new_fov, H, W, device):
    """Build a 3x3 pinhole intrinsics matrix from a field of view (in degrees)
    measured along the longer image side.

    NOTE: top-left pixel should be (0,0)
    """
    # same focal length for both axes, derived from the FOV of the longer side
    if W >= H:
        new_fu = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
        new_fv = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
    else:
        new_fu = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
        new_fv = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))

    # principal point at the image center (top-left pixel center is (0,0))
    new_cu = (W / 2.0) - 0.5
    new_cv = (H / 2.0) - 0.5

    new_intrins = torch.tensor([
        [new_fu,    0,          new_cu  ],
        [0,         new_fv,     new_cv  ],
        [0,         0,          1       ]
    ], dtype=torch.float32, device=device)

    return new_intrins


def get_intrins_from_txt(intrins_path, device):
    """Read 'fx,fy,cx,cy' (comma-separated, on the first line) from a text file
    and build a 3x3 pinhole intrinsics matrix.

    NOTE: top-left pixel should be (0,0)
    """
    with open(intrins_path, 'r') as f:
        intrins_ = f.readline().split()[0].split(',')
        intrins_ = [float(i) for i in intrins_]
        fx, fy, cx, cy = intrins_

    intrins = torch.tensor([
        [fx, 0,  cx],
        [0,  fy, cy],
        [0,  0,  1 ]
    ], dtype=torch.float32, device=device)

    return intrins
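

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# It shows how pad_input() and get_intrins_from_fov() are typically combined
# before running a network. The 427x561 resolution and 60-degree FOV below
# are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    H, W = 427, 561
    l, r, t, b = pad_input(H, W)              # pad so both sides become multiples of 32
    print('pad (l, r, t, b):', l, r, t, b)

    img = torch.zeros((1, 3, H, W))           # dummy RGB batch
    img = torch.nn.functional.pad(img, (l, r, t, b), mode='constant', value=0.0)
    print('padded shape:', tuple(img.shape))  # (1, 3, 448, 576)

    intrins = get_intrins_from_fov(new_fov=60.0, H=H, W=W, device='cpu')
    print('intrinsics:\n', intrins)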