File size: 2,676 Bytes
205a7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""Perspective field utilities.

Adapted from https://github.com/jinlinyi/PerspectiveFields
"""

import torch

from siclib.utils.conversions import deg2rad, rad2deg


def encode_up_bin(vector_field: torch.Tensor, num_bin: int) -> torch.Tensor:
    """Encode an up-vector field into classification bin indices.

    The angle of each vector is shifted into [0, 360) degrees and quantized into
    ``num_bin - 1`` evenly spaced bins (indices ``0 .. num_bin - 2``). An angle that
    rounds up to a full turn wraps back to bin 0. The last index, ``num_bin - 1``, is
    reserved for invalid pixels whose vector is exactly zero.

    Args:
        vector_field (torch.Tensor): gravity field of shape (2, h, w), with channel 0
            cos(theta) and channel 1 sin(theta).
        num_bin (int): number of classification bins, including the invalid bin
            (must be at least 2, otherwise the bin width divides by zero).

    Returns:
        torch.Tensor: int64 bin indices of shape (h, w), on the same device as
        ``vector_field``.
    """
    # Convert atan2's radians to degrees, shift by 180 and wrap into [0, 360).
    angle = (torch.atan2(vector_field[1], vector_field[0]) / torch.pi * 180 + 180) % 360
    angle_bin = torch.round(angle / (360 / (num_bin - 1))).long()
    # Rounding can land on a full turn (360 degrees), which is the same direction as 0.
    angle_bin[angle_bin == num_bin - 1] = 0
    # A pixel whose vector is zero in every channel has no direction: give it the
    # reserved last bin.
    invalid = (vector_field == 0).all(dim=0)
    angle_bin[invalid] = num_bin - 1
    # These are class labels, not angles: return them as-is (no unit conversion).
    return angle_bin


def decode_up_bin(angle_bin: torch.Tensor, num_bin: int) -> torch.Tensor:
    """Decode classification bin indices back into a vector field.

    Bin ``k`` maps to the angle ``k * 360 / (num_bin - 1) - 180`` degrees. The
    reserved last bin (``num_bin - 1``) decodes to the zero vector.

    Args:
        angle_bin (torch.Tensor): integer bin indices of shape (..., h, w), e.g. (h, w)
            or (b, h, w).
        num_bin (int): number of classification bins, including the invalid bin.

    Returns:
        torch.Tensor: vector field of shape (..., 2, h, w), with channel 0 cos(theta)
        and channel 1 sin(theta). A (b, h, w) input gives (b, 2, h, w) and an (h, w)
        input gives (2, h, w).
    """
    angle = (angle_bin * (360 / (num_bin - 1)) - 180) / 180 * torch.pi
    # Insert the channel axis just before (h, w), so any number of leading dims works.
    vector_field = torch.stack((torch.cos(angle), torch.sin(angle)), dim=-3)
    # The mask has a singleton channel axis and broadcasts over both channels.
    invalid = (angle_bin == num_bin - 1).unsqueeze(-3)
    return vector_field.masked_fill(invalid, 0)


def encode_bin_latitude(latimap: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Encode latitude map into classification bins.

    The range [-90, 90] degrees is split into ``num_classes`` equal-width bins. The
    ``num_classes - 1`` interior edges are passed to ``torch.bucketize``.

    Args:
        latimap (torch.Tensor): latitude map of shape (h, w) in radians (values in
            [-pi/2, pi/2]). It is converted to degrees with ``rad2deg`` before binning.
        num_classes (int): number of classes

    Returns:
        torch.Tensor: int64 bin indices in [0, num_classes - 1], with the same shape
        and device as ``latimap``.
    """
    degrees = rad2deg(latimap)
    # linspace yields exactly num_classes + 1 edges, whereas arange with a fractional
    # step can emit an extra element through floating-point rounding. Only the
    # interior edges are kept. They are created on the input's device because
    # bucketize requires input and boundaries to share a device.
    boundaries = torch.linspace(-90, 90, num_classes + 1, device=latimap.device)[1:-1]
    # bucketize already returns int64. .long() is a no-op that, unlike
    # .type(torch.LongTensor), does not move the result to the CPU.
    return torch.bucketize(degrees, boundaries).long()


def decode_bin_latitude(binmap: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Decode classification bins to latitude map.

    Each index is replaced by the center of its bin, where [-90, 90] degrees is split
    into ``num_classes`` equal-width bins. The result is converted with ``deg2rad``.

    Args:
        binmap (torch.Tensor): integer bin indices in [0, num_classes - 1]; indices
            outside that range raise an IndexError.
        num_classes (int): number of classes

    Returns:
        torch.Tensor: latitude map in radians, with the same shape as ``binmap``.
    """
    bin_size = 180 / num_classes
    # Exactly num_classes centers (-90 + bin_size/2, ..., 90 - bin_size/2), built
    # directly on binmap's device. This avoids a host tensor plus copy, and avoids
    # arange's float-step rounding.
    bin_centers = torch.linspace(
        -90 + bin_size / 2, 90 - bin_size / 2, num_classes, device=binmap.device
    )
    # Advanced indexing with the integer bin map keeps binmap's shape.
    latimap = bin_centers[binmap]

    return deg2rad(latimap)