luost26's picture
Update
753e275
raw
history blame
5.36 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
def mask_zero(mask, value):
    """Keep the entries of `value` where `mask` is true and zero out the rest.

    Args:
        mask: Boolean tensor, broadcastable against `value`.
        value: Tensor whose unselected entries are replaced by zeros.
    Returns:
        Tensor produced by `torch.where(mask, value, 0)`.
    """
    zeros = torch.zeros_like(value)
    return torch.where(mask, value, zeros)
def clampped_one_hot(x, num_classes):
    """One-hot encode `x`, emitting an all-zero row for out-of-range indices.

    Args:
        x: Integer index tensor, e.g. (N, L).
        num_classes: Number of classes C; valid indices are 0..C-1.
    Returns:
        Tensor with one extra trailing axis of size C, e.g. (N, L, C).
        Positions where `x` is negative or >= `num_classes` are all zeros.
    """
    in_range = (x >= 0) & (x < num_classes)
    # Clamp first so that F.one_hot never sees an invalid index; the rows
    # belonging to clamped indices are wiped by the mask afterwards.
    safe_idx = torch.clamp(x, min=0, max=num_classes - 1)
    encoded = F.one_hot(safe_idx, num_classes)
    return encoded * in_range.unsqueeze(-1)
class DistanceToBins(nn.Module):
    """Expand a scalar distance channel into `num_bins` bin features.

    Two modes are supported:

    * soft (``use_onehot=False``): ``num_bins - 1`` Gaussian responses centred
      on evenly spaced values in ``[dist_min, dist_max]`` plus one trailing
      overflow flag that is 1 when ``dist >= dist_max``.
    * hard (``use_onehot=True``): a one-hot vector marking the centre (out of
      ``num_bins`` evenly spaced values) closest to the distance.

    The configuration must yield at least two centres, because the Gaussian
    width is derived from the spacing between the first two.
    """

    def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False):
        super().__init__()
        self.dist_min = dist_min
        self.dist_max = dist_max
        self.num_bins = num_bins
        self.use_onehot = use_onehot

        # Soft mode reserves the last output channel for the overflow flag,
        # so it has one centre fewer than the hard mode.
        num_centers = num_bins if use_onehot else num_bins - 1
        offset = torch.linspace(dist_min, dist_max, num_centers)
        # Gaussian exponent coefficient -1 / (2 * sigma^2) with
        # sigma = 0.2 * centre spacing; `*0.2` makes it not too blurred.
        self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2
        self.register_buffer('offset', offset)

    @property
    def out_channels(self):
        """Size of the binned axis produced by `forward`."""
        return self.num_bins

    def forward(self, dist, dim, normalize=True):
        """
        Args:
            dist: (N, *, 1, *), distances with a singleton axis at `dim`.
            dim: Index of the singleton axis that is expanded into bins.
            normalize: If true, rescale the soft encoding so that it sums to
                one along `dim`. One-hot encodings already sum to one.
        Returns:
            (N, *, num_bins, *)
        """
        assert dist.size()[dim] == 1
        offset_shape = [1] * dist.dim()
        offset_shape[dim] = -1
        centers = self.offset.view(*offset_shape)  # broadcastable against dist

        if self.use_onehot:
            diff = torch.abs(dist - centers)  # (N, *, num_bins, *)
            bin_idx = torch.argmin(diff, dim=dim, keepdim=True)  # (N, *, 1, *)
            y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0)
        else:
            overflow_symb = (dist >= self.dist_max).float()  # (N, *, 1, *)
            y = torch.exp(self.coeff * torch.pow(dist - centers, 2))  # (N, *, num_bins-1, *)
            y = torch.cat([y, overflow_symb], dim=dim)  # (N, *, num_bins, *)
            if normalize:
                total = y.sum(dim=dim, keepdim=True)
                # A distance far below `dist_min` makes every Gaussian
                # underflow to zero while the overflow flag is also zero;
                # dividing by that zero sum would yield NaN. Divide by one
                # instead so such entries come out as all zeros.
                total = torch.where(total > 0, total, torch.ones_like(total))
                y = y / total
        return y
class PositionalEncoding(nn.Module):
    """Sinusoidal feature expansion with frequencies 2^0, 2^1, ..., 2^(f-1).

    Every input coordinate is mapped to itself followed by its sines and then
    its cosines at each of the `num_funcs` frequencies.
    """

    def __init__(self, num_funcs=6):
        super().__init__()
        self.num_funcs = num_funcs
        exponents = torch.linspace(0.0, num_funcs-1, num_funcs)
        self.register_buffer('freq_bands', 2.0 ** exponents)

    def get_out_dim(self, in_dim):
        """Return the size of the last axis of `forward`'s output for an input of width `in_dim`."""
        channels_per_coord = 2 * self.num_funcs + 1  # identity + f sines + f cosines
        return in_dim * channels_per_coord

    def forward(self, x):
        """
        Args:
            x: (..., d).
        Returns:
            (..., d * (2f+1)); for each coordinate the layout is
            [x, sin(x * freq_bands), cos(x * freq_bands)].
        """
        out_shape = (*x.shape[:-1], -1)
        coords = x.unsqueeze(-1)  # (..., d, 1)
        phases = coords * self.freq_bands  # (..., d, f)
        features = torch.cat((coords, phases.sin(), phases.cos()), dim=-1)  # (..., d, 2f+1)
        # Merge the coordinate axis and the per-coordinate feature axis.
        return features.reshape(out_shape)
class AngularEncoding(nn.Module):
    """Sinusoidal feature expansion with integer and reciprocal frequencies.

    The frequency set is [1, 2, ..., f, 1, 1/2, ..., 1/f] (2f values, with
    f = `num_funcs`). Every input coordinate is mapped to itself followed by
    its sines and then its cosines at each of these frequencies.
    """

    def __init__(self, num_funcs=3):
        super().__init__()
        self.num_funcs = num_funcs
        harmonics = [float(k) for k in range(1, num_funcs + 1)]
        freq_bands = harmonics + [1.0 / k for k in harmonics]
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.FloatTensor(list) constructor; the stored values are the same.
        self.register_buffer('freq_bands', torch.tensor(freq_bands, dtype=torch.float32))

    def get_out_dim(self, in_dim):
        """Return the size of the last axis of `forward`'s output for an input of width `in_dim`."""
        # identity + (sin, cos) for each of the 2f frequencies
        return in_dim * (1 + 2 * 2 * self.num_funcs)

    def forward(self, x):
        """
        Args:
            x: (..., d).
        Returns:
            (..., d * (4f+1)); for each coordinate the layout is
            [x, sin(x * freq_bands), cos(x * freq_bands)].
        """
        shape = list(x.shape[:-1]) + [-1]
        x = x.unsqueeze(-1)  # (..., d, 1)
        phases = x * self.freq_bands  # (..., d, 2f)
        code = torch.cat([x, torch.sin(phases), torch.cos(phases)], dim=-1)  # (..., d, 4f+1)
        # Merge the coordinate axis and the per-coordinate feature axis.
        code = code.reshape(shape)
        return code
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension of the input.

    See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
    """

    def __init__(self,
                 normal_shape,
                 gamma=True,
                 beta=True,
                 epsilon=1e-10):
        """Layer normalization layer

        :param normal_shape: The shape of the input tensor or the last dimension of the input tensor.
            Only the size of the last dimension is used.
        :param gamma: Add a scale parameter if it is True.
        :param beta: Add an offset parameter if it is True.
        :param epsilon: Value added to the variance before taking the square root.
        """
        super().__init__()
        if isinstance(normal_shape, int):
            normal_shape = (normal_shape,)
        else:
            normal_shape = (normal_shape[-1],)
        self.normal_shape = torch.Size(normal_shape)
        self.epsilon = epsilon
        # torch.empty replaces the legacy uninitialised torch.Tensor(*shape)
        # constructor; the values are set by reset_parameters() below.
        if gamma:
            self.gamma = nn.Parameter(torch.empty(self.normal_shape))
        else:
            self.register_parameter('gamma', None)
        if beta:
            self.beta = nn.Parameter(torch.empty(self.normal_shape))
        else:
            self.register_parameter('beta', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Set the scale to ones and the offset to zeros (where present)."""
        if self.gamma is not None:
            self.gamma.data.fill_(1)
        if self.beta is not None:
            self.beta.data.zero_()

    def forward(self, x):
        """Normalize `x` along its last dimension, then apply scale and offset.

        The variance is the mean of squared deviations (no Bessel correction).
        """
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean  # computed once and reused for variance and output
        var = (centered ** 2).mean(dim=-1, keepdim=True)
        std = (var + self.epsilon).sqrt()
        y = centered / std
        # `y` is a fresh tensor, so the in-place updates never touch `x`.
        if self.gamma is not None:
            y *= self.gamma
        if self.beta is not None:
            y += self.beta
        return y

    def extra_repr(self):
        return 'normal_shape={}, gamma={}, beta={}, epsilon={}'.format(
            self.normal_shape, self.gamma is not None, self.beta is not None, self.epsilon,
        )