Spaces:
Runtime error
Runtime error
File size: 5,364 Bytes
753e275 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import torch
import torch.nn as nn
import torch.nn.functional as F
def mask_zero(mask, value):
    """Return ``value`` where ``mask`` is True and zero elsewhere.

    ``mask`` and ``value`` broadcast against each other exactly as in
    ``torch.where``; the fill tensor is ``torch.zeros_like(value)``.
    """
    return value.where(mask, torch.zeros_like(value))
def clampped_one_hot(x, num_classes):
    """One-hot encode integer labels, emitting an all-zero row for out-of-range labels.

    Labels outside ``[0, num_classes)`` are clamped only so that ``F.one_hot``
    accepts them; the resulting rows are then zeroed by the validity mask.
    The output gains a trailing axis of size ``num_classes``.
    """
    in_range = (x >= 0) & (x < num_classes)  # same shape as x
    safe_labels = torch.clamp(x, 0, num_classes - 1)
    encoded = F.one_hot(safe_labels, num_classes)  # (*x.shape, C)
    return encoded * in_range.unsqueeze(-1)
class DistanceToBins(nn.Module):
    """Expand scalar distances into a histogram-like vector over ``num_bins`` channels.

    Two modes:
      * ``use_onehot=True``: ``num_bins`` evenly spaced centers in
        ``[dist_min, dist_max]``; each distance becomes a one-hot vector at the
        nearest center.
      * ``use_onehot=False``: ``num_bins - 1`` evenly spaced Gaussian centers
        plus one final overflow channel set to 1 when ``dist >= dist_max``.
    """

    def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False):
        super().__init__()
        self.dist_min = dist_min
        self.dist_max = dist_max
        self.num_bins = num_bins
        self.use_onehot = use_onehot
        if use_onehot:
            offset = torch.linspace(dist_min, dist_max, self.num_bins)
        else:
            offset = torch.linspace(dist_min, dist_max, self.num_bins - 1)  # 1 overflow flag
        # Gaussian std is 0.2x the center spacing: `*0.2` makes it not too blurred.
        self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2
        self.register_buffer('offset', offset)

    @property
    def out_channels(self):
        return self.num_bins

    def forward(self, dist, dim, normalize=True):
        """
        Args:
            dist: tensor whose size along ``dim`` is 1, e.g. (N, *, 1, *).
            dim: the singleton axis of ``dist``; bins are laid out along it.
            normalize: if True, divide by the sum along ``dim``.
        Returns:
            (N, *, num_bins, *)
        """
        assert dist.size()[dim] == 1
        offset_shape = [1] * dist.dim()
        offset_shape[dim] = -1
        offset = self.offset.view(*offset_shape)
        if self.use_onehot:
            diff = torch.abs(dist - offset)  # (N, *, num_bins, *)
            bin_idx = torch.argmin(diff, dim=dim, keepdim=True)  # (N, *, 1, *)
            y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0)
        else:
            overflow_symb = (dist >= self.dist_max).float()  # (N, *, 1, *)
            y = dist - offset  # (N, *, num_bins-1, *)
            y = torch.exp(self.coeff * torch.pow(y, 2))  # (N, *, num_bins-1, *)
            y = torch.cat([y, overflow_symb], dim=dim)  # (N, *, num_bins, *)
        if normalize:
            # Far below dist_min every Gaussian underflows to 0 and the overflow
            # flag is 0, so the plain sum is 0 and 0/0 would yield NaN. Flooring the
            # denominator at the smallest normal float gives an all-zero row there.
            # Any distance in [dist_min, dist_max] lies within half a spacing of a
            # center (Gaussian value ~exp(-3.125)), so those rows are unaffected.
            denom = y.sum(dim=dim, keepdim=True)
            y = y / denom.clamp(min=torch.finfo(y.dtype).tiny)
        return y
class PositionalEncoding(nn.Module):
    """Fourier-feature encoding with power-of-two frequencies.

    Each scalar ``v`` on the last axis becomes
    ``[v, sin(2^0 v), ..., sin(2^(f-1) v), cos(2^0 v), ..., cos(2^(f-1) v)]``
    with ``f = num_funcs``; the per-feature codes are flattened into the last axis.
    """

    def __init__(self, num_funcs=6):
        super().__init__()
        self.num_funcs = num_funcs
        exponents = torch.linspace(0.0, num_funcs - 1, num_funcs)
        self.register_buffer('freq_bands', torch.pow(2.0, exponents))

    def get_out_dim(self, in_dim):
        per_feature = 1 + 2 * self.num_funcs
        return in_dim * per_feature

    def forward(self, x):
        """
        Args:
            x: (..., d).
        Returns:
            (..., d * (2 * num_funcs + 1))
        """
        lead = tuple(x.shape[:-1])
        col = x[..., None]  # (..., d, 1)
        angles = col * self.freq_bands  # (..., d, f)
        code = torch.cat((col, angles.sin(), angles.cos()), dim=-1)  # (..., d, 2f+1)
        return code.reshape(*lead, -1)
class AngularEncoding(nn.Module):
    """Sinusoidal encoding with integer and fractional frequencies.

    The frequency bands are ``[1, 2, ..., f, 1/1, 1/2, ..., 1/f]`` with
    ``f = num_funcs`` (2f bands; frequency 1 appears in both halves). Each scalar
    ``a`` on the last axis becomes ``[a, sin(k*a) for k in bands, cos(k*a) for k in bands]``,
    i.e. ``1 + 4f`` values, flattened into the last axis.
    """

    def __init__(self, num_funcs=3):
        super().__init__()
        self.num_funcs = num_funcs
        harmonics = [float(i + 1) for i in range(num_funcs)]
        subharmonics = [1.0 / (i + 1) for i in range(num_funcs)]
        # torch.tensor replaces the legacy torch.FloatTensor constructor; float32 as before.
        self.register_buffer('freq_bands', torch.tensor(harmonics + subharmonics, dtype=torch.float32))

    def get_out_dim(self, in_dim):
        return in_dim * (1 + 2 * 2 * self.num_funcs)

    def forward(self, x):
        """
        Args:
            x: (..., d).
        Returns:
            (..., d * (1 + 4 * num_funcs))
        """
        shape = list(x.shape[:-1]) + [-1]
        x = x.unsqueeze(-1)  # (..., d, 1)
        code = torch.cat([x, torch.sin(x * self.freq_bands), torch.cos(x * self.freq_bands)], dim=-1)  # (..., d, 1+4f)
        code = code.reshape(shape)  # (..., d*(1+4f))
        return code
class LayerNorm(nn.Module):
    """Layer normalization over the last axis with optional learnable scale/offset."""

    def __init__(self,
                 normal_shape,
                 gamma=True,
                 beta=True,
                 epsilon=1e-10):
        """Layer normalization layer
        See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
        :param normal_shape: An int feature size, or a shape whose last entry is the
            feature size (only that last entry is used).
        :param gamma: Create a learnable per-feature scale (initialized to 1) if True.
        :param beta: Create a learnable per-feature offset (initialized to 0) if True.
        :param epsilon: Added to the variance before taking the square root.
        """
        super().__init__()
        features = normal_shape if isinstance(normal_shape, int) else normal_shape[-1]
        self.normal_shape = torch.Size((features,))
        self.epsilon = epsilon
        # Register gamma before beta; a disabled one is registered as None.
        for name, enabled in (('gamma', gamma), ('beta', beta)):
            weight = nn.Parameter(torch.empty(self.normal_shape)) if enabled else None
            self.register_parameter(name, weight)
        self.reset_parameters()

    def reset_parameters(self):
        if self.gamma is not None:
            nn.init.ones_(self.gamma)
        if self.beta is not None:
            nn.init.zeros_(self.beta)

    def forward(self, x):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = centered.pow(2).mean(dim=-1, keepdim=True)  # biased variance
        y = centered / torch.sqrt(variance + self.epsilon)
        # In-place updates: y keeps its own dtype even if gamma/beta are wider.
        if self.gamma is not None:
            y *= self.gamma
        if self.beta is not None:
            y += self.beta
        return y

    def extra_repr(self):
        return (
            f'normal_shape={self.normal_shape}, gamma={self.gamma is not None}, '
            f'beta={self.beta is not None}, epsilon={self.epsilon}'
        )
|