from typing import Any, Callable, Iterable, List, Optional, Tuple, Union

import torch
import torch.jit as jit
import torch.nn as nn
import torch.nn.functional as F
from torch import Size, Tensor
from torch.nn import LayerNorm

from model.pvcnn.pvcnn_utils import get_timestep_embedding


def sample_b(size: Size, sigma: float) -> Tensor:
    """Sample b matrix for Fourier features

    Arguments:
        size (Size): b matrix size
        sigma (float): std of the gaussian

    Returns:
        b (Tensor): b matrix
    """
    return torch.randn(size) * sigma


def map_positional_encoding(v: Tensor, freq_bands: Tensor) -> Tensor:
    """Map v to a positional encoding representation phi(v)

    Arguments:
        v (Tensor): input features (B, IFeatures)
        freq_bands (Tensor): frequency bands (N_freqs, )

    Returns:
        phi(v) (Tensor): Fourier features (B, IFeatures + (2 * N_freqs) * IFeatures)
    """
    pe = [v]
    for freq in freq_bands:
        fv = freq * v
        pe += [torch.sin(fv), torch.cos(fv)]
    return torch.cat(pe, dim=-1)


def map_fourier_features(v: Tensor, b: Tensor) -> Tensor:
    """Map v to a Fourier features representation phi(v)

    Arguments:
        v (Tensor): input features (B, IFeatures)
        b (Tensor): b matrix (OFeatures, IFeatures)

    Returns:
        phi(v) (Tensor): Fourier features (B, 2 * OFeatures)
    """
    PI = 3.141592653589793
    a = 2 * PI * v @ b.T
    return torch.cat((torch.sin(a), torch.cos(a)), dim=-1)
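

# Usage sketch (added for illustration; not part of the original module).
# Both mapping functions act on the last dimension, so they accept (B, IFeatures)
# as well as (B, N, IFeatures) inputs. The sizes below are arbitrary examples.
def _demo_feature_mappings() -> None:
    v = torch.rand(8, 3)                         # (B, 3) input coordinates
    freq_bands = 2 ** torch.arange(10).float()   # (N_freqs,) frequency bands
    b = sample_b(Size((256, 3)), sigma=26.)      # (OFeatures, IFeatures) random projection
    pe = map_positional_encoding(v, freq_bands)  # (B, 3 + 2 * 10 * 3) = (8, 63)
    ff = map_fourier_features(v, b)              # (B, 2 * 256) = (8, 512)
    assert pe.shape == (8, 63) and ff.shape == (8, 512)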


class FeatureMapping(nn.Module):
    """FeatureMapping nn.Module

    Maps v to features following the transformation phi(v)

    Arguments:
        i_dim (int): input dimensions
        o_dim (int): output dimensions
    """

    def __init__(self, i_dim: int, o_dim: int) -> None:
        super().__init__()
        self.i_dim = i_dim
        self.o_dim = o_dim

    def forward(self, v: Tensor) -> Tensor:
        """FeatureMapping forward pass

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): mapped features (B, OFeatures)
        """
        raise NotImplementedError("Forward pass not implemented yet!")


class PositionalEncoding(FeatureMapping):
    """PositionalEncoding module

    Maps v to a positional encoding representation phi(v)

    Arguments:
        i_dim (int): input dimension for v
        N_freqs (int): number of frequencies to sample (default: 10)
    """

    def __init__(
        self,
        i_dim: int,
        N_freqs: int = 10,
    ) -> None:
        super().__init__(i_dim, i_dim + (2 * N_freqs) * i_dim)
        self.N_freqs = N_freqs

        a, b = 1, self.N_freqs - 1
        freq_bands = 2 ** torch.linspace(a, b, self.N_freqs)
        self.register_buffer("freq_bands", freq_bands)

    def forward(self, v: Tensor) -> Tensor:
        """Map v to a positional encoding representation phi(v)

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): Fourier features (B, IFeatures + (2 * N_freqs) * IFeatures)
        """
        return map_positional_encoding(v, self.freq_bands)
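

# Usage sketch (added for illustration; not part of the original module).
# PositionalEncoding with i_dim=3 and N_freqs=10 matches the configuration
# used by BasePointModel below.
def _demo_positional_encoding() -> None:
    pe = PositionalEncoding(i_dim=3, N_freqs=10)
    coords = torch.rand(8, 1024, 3)   # (B, N, 3) point coordinates
    phi = pe(coords)                  # (B, N, 3 + 2 * 10 * 3) = (8, 1024, 63)
    assert phi.shape == (8, 1024, 63)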


class FourierFeatures(FeatureMapping):
    """FourierFeatures module

    Maps v to a Fourier features representation phi(v)

    Arguments:
        i_dim (int): input dimension for v
        features (int): output dimension (default: 256)
        sigma (float): std of the gaussian (default: 26.)
    """

    def __init__(
        self,
        i_dim: int,
        features: int = 256,
        sigma: float = 26.,
    ) -> None:
        super().__init__(i_dim, 2 * features)
        self.features = features
        self.sigma = sigma
        self.size = Size((self.features, self.i_dim))
        self.register_buffer("b", sample_b(self.size, self.sigma))

    def forward(self, v: Tensor) -> Tensor:
        """Map v to a Fourier features representation phi(v)

        Arguments:
            v (Tensor): input features (B, IFeatures)

        Returns:
            phi(v) (Tensor): Fourier features (B, 2 * Features)
        """
        return map_fourier_features(v, self.b)
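

# Usage sketch (added for illustration; not part of the original module).
# FourierFeatures keeps its random projection matrix as a buffer, so it moves
# with the module across devices and is saved in the state dict.
def _demo_fourier_features() -> None:
    ff = FourierFeatures(i_dim=3, features=256, sigma=26.)
    coords = torch.rand(8, 3)   # (B, 3) point coordinates
    phi = ff(coords)            # (B, 2 * 256) = (8, 512)
    assert phi.shape == (8, 512)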


class FeedForward(nn.Module):
    """Adapted from the FeedForward layer from labmlai"""

    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        d_out: int,
        activation: Callable = nn.ReLU(),
        is_gated: bool = False,
        bias1: bool = True,
        bias2: bool = True,
        bias_gate: bool = True,
        dropout: float = 0.1,
        use_layernorm: bool = False,
    ):
        super().__init__()
        # Layer one parameterized by weight $W_1$ and bias $b_1$
        self.layer1 = nn.Linear(d_in, d_hidden, bias=bias1)
        # Layer two parameterized by weight $W_2$ and bias $b_2$
        self.layer2 = nn.Linear(d_hidden, d_out, bias=bias2)
        # Hidden layer dropout
        self.dropout = nn.Dropout(dropout)
        # Activation function $f$
        self.activation = activation
        # Whether there is a gate
        self.is_gated = is_gated
        if is_gated:
            # If there is a gate, the linear layer to transform inputs to
            # be multiplied by the gate, parameterized by weight $V$ and bias $c$
            self.linear_v = nn.Linear(d_in, d_hidden, bias=bias_gate)
        # Whether to add a layernorm layer
        self.use_layernorm = use_layernorm
        if use_layernorm:
            self.layernorm = LayerNorm(d_in)

    def forward(self, x: Tensor, coords: Optional[Tensor] = None) -> Tensor:
        """Applies a simple feed-forward layer (coords is accepted for interface
        compatibility but is unused here)."""
        x = self.layernorm(x) if self.use_layernorm else x
        g = self.activation(self.layer1(x))
        x = (g * self.linear_v(x)) if self.is_gated else g
        x = self.dropout(x)
        x = self.layer2(x)
        return x
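

# Usage sketch (added for illustration; not part of the original module).
# A gated FeedForward block with pre-LayerNorm; the coords argument is optional
# and ignored by this layer.
def _demo_feed_forward() -> None:
    ff = FeedForward(d_in=128, d_hidden=512, d_out=128, is_gated=True, use_layernorm=True)
    x = torch.rand(8, 1024, 128)   # (B, N, D) per-point features
    y = ff(x)                      # (B, N, D), same shape as the input
    assert y.shape == x.shape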


class BasePointModel(nn.Module):
    """A base class providing useful methods for point cloud processing."""

    def __init__(
        self,
        *,
        num_classes,
        embed_dim,
        extra_feature_channels,
        dim: int = 128,
        num_layers: int = 6,
    ):
        super().__init__()
        self.extra_feature_channels = extra_feature_channels
        self.timestep_embed_dim = embed_dim
        self.output_dim = num_classes
        self.dim = dim
        self.num_layers = num_layers

        # Timestep embedding projection
        self.timestep_projection = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(embed_dim, embed_dim),
        )

        # Positional encoding of the point coordinates
        self.positional_encoding = PositionalEncoding(i_dim=3, N_freqs=10)
        positional_encoding_d_out = 3 + (2 * 10) * 3

        # Input projection (point coords, point coord encodings, other features, and timestep embeddings)
        self.input_projection = nn.Linear(
            in_features=(3 + positional_encoding_d_out + extra_feature_channels + self.timestep_embed_dim),
            out_features=self.dim,
        )

        # Transformer layers
        self.layers = self.get_layers()

        # Output projection
        self.output_projection = nn.Linear(self.dim, self.output_dim)

    def get_layers(self):
        raise NotImplementedError('This method should be implemented by subclasses')
    def prepare_inputs(self, inputs: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embed the timesteps, positionally encode the point coordinates, and project
        everything to the model dimension.

        The inputs have size (B, 3 + S, N), where S is the number of additional
        feature channels and N is the number of points. The timesteps t can be either
        continuous or discrete. The model projects the inputs down to the model
        dimension here and back up to the output dimension in the output projection.

        Returns the projected per-point features (B, N, D_model) together with the
        raw point coordinates (B, N, 3).
        """
        # Embed and project timesteps
        t_emb = get_timestep_embedding(self.timestep_embed_dim, t, inputs.device)
        t_emb = self.timestep_projection(t_emb)[:, None, :].expand(-1, inputs.shape[-1], -1)  # (B, N, D_t_emb)

        # Separate input coordinates and features
        x = torch.transpose(inputs, -2, -1)  # -> (B, N, 3 + S)
        coords = x[:, :, :3]  # (B, N, 3), point coordinates

        # Positional encoding of point coords
        coords_posenc = self.positional_encoding(coords)  # (B, N, D_p_enc)

        # Project
        x = torch.cat((x, coords_posenc, t_emb), dim=2)  # (B, N, 3 + S + D_p_enc + D_t_emb)
        x = self.input_projection(x)  # (B, N, D_model)

        return x, coords
    def get_global_tensors(self, x: Tensor):
        B, N, D = x.shape
        x_pool_max = torch.max(x, dim=1, keepdim=True).values.repeat(1, N, 1)  # (B, N, D), max pooled over points
        x_pool_std = torch.std(x, dim=1, keepdim=True).repeat(1, N, 1)  # (B, N, D), std pooled over points
        return x_pool_max, x_pool_std
    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        raise NotImplementedError('This method should be implemented by subclasses')
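

# Subclass sketch (added for illustration; the concrete models built on
# BasePointModel define their own get_layers/forward). It assumes a plain stack
# of residual FeedForward blocks and per-point outputs of shape (B, num_classes, N).
class _SimplePointModel(BasePointModel):

    def get_layers(self):
        return nn.ModuleList([
            FeedForward(self.dim, 2 * self.dim, self.dim, use_layernorm=True)
            for _ in range(self.num_layers)
        ])

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        x, coords = self.prepare_inputs(inputs, t)  # (B, N, D_model), (B, N, 3)
        for layer in self.layers:
            x = x + layer(x, coords)                # residual feed-forward blocks
        x = self.output_projection(x)               # (B, N, num_classes)
        return torch.transpose(x, -2, -1)           # (B, num_classes, N)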