# File size: 2,676 bytes (revision 205a7af)
"""Perspective field utilities.
Adapted from https://github.com/jinlinyi/PerspectiveFields
"""
import torch
from siclib.utils.conversions import deg2rad, rad2deg
def encode_up_bin(vector_field: torch.Tensor, num_bin: int) -> torch.Tensor:
    """Encode vector field into classification bins.

    The per-pixel direction ``atan2(channel 1, channel 0)`` is shifted to
    [0, 360) degrees and rounded to the nearest of ``num_bin - 1`` evenly spaced
    angular bins. The last index, ``num_bin - 1``, is reserved for invalid
    pixels, i.e. pixels whose channels are all exactly zero.

    Args:
        vector_field (torch.Tensor): gravity field of shape (2, h, w), with channel 0
            cos(theta) and 1 sin(theta)
        num_bin (int): number of classification bins, including the invalid bin

    Returns:
        torch.Tensor: int64 bin indices of shape (h, w), on the same device as
            ``vector_field``
    """
    angle = (
        torch.atan2(vector_field[1, :, :], vector_field[0, :, :]) / torch.pi * 180 + 180
    ) % 360  # [0,360)
    angle_bin = torch.round(torch.div(angle, (360 / (num_bin - 1)))).long()
    # Angles just below 360 degrees round up to num_bin - 1; wrap them onto bin 0
    # so that index stays reserved for invalid pixels.
    angle_bin[angle_bin == num_bin - 1] = 0
    # A pixel is invalid when every channel is exactly zero (no direction).
    invalid = (vector_field == 0).sum(0) == vector_field.size(0)
    angle_bin[invalid] = num_bin - 1
    # Return the raw indices: they are class labels, not angles, so no
    # degree/radian conversion applies. `.long()` above already fixed the dtype,
    # and skipping a cast to torch.LongTensor keeps the result on the input device.
    return angle_bin
def decode_up_bin(angle_bin: torch.Tensor, num_bin: int) -> torch.Tensor:
    """Turn classification bin indices back into a unit vector field.

    Bin ``k`` maps to the angle ``k * 360 / (num_bin - 1) - 180`` degrees; the
    cosine and sine of that angle are stacked along a new dimension 1. Pixels
    carrying the index ``num_bin - 1`` (the invalid bin) decode to zeros.

    Args:
        angle_bin (torch.Tensor): bin indices of shape (b, h, w)
        num_bin (int): number of classification bins, including the invalid bin

    Returns:
        torch.Tensor: decoded vector field of shape (b, 2, h, w), channel 0 being
            cos(theta) and channel 1 sin(theta)
    """
    degrees_per_bin = 360 / (num_bin - 1)
    theta = (angle_bin * degrees_per_bin - 180) / 180 * torch.pi
    field = torch.stack((torch.cos(theta), torch.sin(theta)), dim=1)
    # Broadcast the invalid-pixel mask over both channels before zeroing them.
    is_invalid = (angle_bin == num_bin - 1).unsqueeze(1).repeat(1, 2, 1, 1)
    field[is_invalid] = 0
    return field
def encode_bin_latitude(latimap: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Encode latitude map into classification bins.

    The range [-90, 90] degrees is split into ``num_classes`` equal-width bins
    and every pixel is assigned the index of the bin its latitude falls into.

    Args:
        latimap (torch.Tensor): latitude map of shape (h, w). It is passed through
            ``rad2deg`` before binning, so values are presumably radians in
            [-pi/2, pi/2] -- confirm against callers.
        num_classes (int): number of classes

    Returns:
        torch.Tensor: int64 latitude bin indices with the same shape and device
            as ``latimap``
    """
    # Interior bin edges in degrees; the first edge (-90) is dropped so the
    # lowest bin gets index 0.
    boundaries = torch.arange(-90, 90, 180 / num_classes)[1:]
    # torch.bucketize needs the boundaries on the same device as its input.
    boundaries = boundaries.to(latimap.device)
    binmap = torch.bucketize(rad2deg(latimap), boundaries)
    # `.long()` pins the dtype without forcing the result back to the CPU.
    return binmap.long()
def decode_bin_latitude(binmap: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Map latitude bin indices back to latitudes at the bin centers.

    The range [-90, 90] degrees is split into ``num_classes`` equal-width bins;
    each index is replaced by the center of its bin and the result is converted
    with ``deg2rad``.

    Args:
        binmap (torch.Tensor): encoded classification bins (integer indices)
        num_classes (int): number of classes

    Returns:
        torch.Tensor: latitude map with the same shape as ``binmap``, as
            returned by ``deg2rad``
    """
    width = 180 / num_classes
    # Lookup table of bin centers in degrees, placed on the indices' device.
    centers = (torch.arange(-90, 90, width) + width / 2).to(binmap.device)
    return deg2rad(centers[binmap])
|