Spaces:
Sleeping
Sleeping
File size: 3,040 Bytes
2fd6166 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from .simple_model_utils import FeedForward, BasePointModel
class SimplePointModel(BasePointModel):
    """
    Point-cloud model made of a stack of per-point gated feed-forward blocks.

    Every block receives each point's own features concatenated with two global
    tensors obtained from `get_global_tensors`, and its output is added back to the
    point features as a residual update.
    """

    def get_layers(self):
        """Build one `FeedForward` block per layer, each mapping 3 * dim -> dim features."""
        blocks = []
        for _ in range(self.num_layers):
            blocks.append(
                FeedForward(
                    d_in=(3 * self.dim),
                    d_hidden=(4 * self.dim),
                    d_out=self.dim,
                    activation=nn.SiLU(),
                    is_gated=True,
                    bias1=False,
                    bias2=False,
                    bias_gate=False,
                    use_layernorm=True,
                )
            )
        return nn.ModuleList(blocks)

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        """
        Run the residual stack and return the projected features, channels first.

        `inputs` and `t` are handed straight to `prepare_inputs`; the result has its
        last two axes swapped, i.e. (B, D_out, N) per the shape notes below.
        """
        # `prepare_inputs` also returns coordinates, which this model does not use
        features, _coords = self.prepare_inputs(inputs, t)

        # Residual blocks over [own features | global tensor 1 | global tensor 2]
        for block in self.layers:
            pooled_max, pooled_std = self.get_global_tensors(features)
            block_input = torch.cat(
                (features, pooled_max, pooled_std), dim=-1
            )  # (B, N, 3 * D)
            features = features + block(block_input)  # (B, N, D_model)

        # Final projection, then move the channel axis in front of the point axis
        projected = self.output_projection(features)  # (B, N, D_out)
        return projected.transpose(-2, -1)  # -> (B, D_out, N)
class SimpleNearestNeighborsPointModel(BasePointModel):
    """
    A simple model that processes a point cloud by applying a series of MLPs to each point
    individually, along with some pooled global features, and the features of its nearest
    neighbors.

    Args:
        num_neighbors: Number of nearest neighbors (not counting the point itself) whose
            features are concatenated to each point's own features.
        **kwargs: Forwarded unchanged to `BasePointModel.__init__`.
    """

    def __init__(self, num_neighbors: int = 4, **kwargs):
        # Must be set before `super().__init__` because `get_layers` reads it;
        # presumably the base constructor builds the layers -- TODO confirm.
        self.num_neighbors = num_neighbors
        super().__init__(**kwargs)

        # Imported here rather than at module level, so pytorch3d is only needed
        # when this particular model is instantiated.
        from pytorch3d.ops import knn_points
        self.knn_points = knn_points

    def get_layers(self):
        """
        Build one `FeedForward` block per layer.

        The input width is (3 + num_neighbors) * dim: the point itself plus its
        `num_neighbors` neighbors, plus the two global tensors used in `forward`.
        """
        return nn.ModuleList([FeedForward(
            d_in=((3 + self.num_neighbors) * self.dim), d_hidden=(4 * self.dim), d_out=self.dim,
            activation=nn.SiLU(), is_gated=True, bias1=False, bias2=False, bias_gate=False,
            use_layernorm=True,
        ) for _ in range(self.num_layers)])

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        """
        Run the residual stack with nearest-neighbor features and project the result.

        `inputs` and `t` are passed to `prepare_inputs`. Neighbors are looked up once
        from the returned coordinates and reused by every layer. The output has its
        last two axes swapped, i.e. (B, D_out, N) per the shape notes below.
        """
        # Prepare inputs
        x, coords = self.prepare_inputs(inputs, t)  # (B, N, D), (B, N, 3)

        # Get nearest neighbors. Note that the first neighbor is the identity, which is
        # convenient: the gathered features already include each point's own features.
        _dists, indices, _neighbors = self.knn_points(
            p1=coords, p2=coords, K=(self.num_neighbors + 1),
            return_nn=False)  # indices: (B, N, K)
        B, N, D = x.shape
        K = indices.shape[-1]

        # The neighbor indices do not change between layers, so the gather indices are
        # built once. Indexing with a (B, 1) batch index and a (B, N * K) point index
        # gathers all batch elements in one call (and also works for an empty batch).
        flat_indices = indices.reshape(B, N * K)  # (B, N * K)
        batch_index = torch.arange(B, device=flat_indices.device).unsqueeze(-1)  # (B, 1)

        # Model
        for layer in self.layers:
            x_neighbor = x[batch_index, flat_indices].reshape(B, N, K * D)  # (B, N, K * D)
            x_pool_max, x_pool_std = self.get_global_tensors(x)
            # Width is (K + 2) * D == (3 + num_neighbors) * D, matching `d_in` in `get_layers`
            x_input = torch.cat((x_neighbor, x_pool_max, x_pool_std), dim=-1)
            x = x + layer(x_input)  # (B, N, D_model)

        # Project
        x = self.output_projection(x)  # (B, N, D_out)
        x = torch.transpose(x, -2, -1)  # -> (B, D_out, N)
        return x
|