xiexh20's picture
add hdm demo v1
2fd6166
raw
history blame
3.04 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from .simple_model_utils import FeedForward, BasePointModel
class SimplePointModel(BasePointModel):
    """
    Point-cloud model built from a stack of residual, per-point gated MLPs.

    At every layer each point sees the concatenation of its own features with the
    two pooled global tensors returned by ``self.get_global_tensors`` (a max-pooled
    and a std-pooled summary). The MLP output is added back to the point features
    as a residual update.
    """

    def get_layers(self):
        """Build ``self.num_layers`` gated FeedForward blocks mapping 3*dim -> dim."""
        blocks = []
        for _ in range(self.num_layers):
            blocks.append(
                FeedForward(
                    d_in=3 * self.dim,
                    d_hidden=4 * self.dim,
                    d_out=self.dim,
                    activation=nn.SiLU(),
                    is_gated=True,
                    bias1=False,
                    bias2=False,
                    bias_gate=False,
                    use_layernorm=True,
                )
            )
        return nn.ModuleList(blocks)

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        """
        Run the residual per-point MLP stack and project to the output channels.

        Args:
            inputs: raw model input, handed unchanged to ``self.prepare_inputs``.
            t: conditioning tensor (presumably a diffusion timestep — TODO confirm),
                handed unchanged to ``self.prepare_inputs``.

        Returns:
            Tensor of shape (B, D_out, N), i.e. channels-first.
        """
        h, _coords = self.prepare_inputs(inputs, t)
        for block in self.layers:
            pooled_max, pooled_std = self.get_global_tensors(h)
            # Own features + two global summaries -> (B, N, 3 * D)
            block_in = torch.cat([h, pooled_max, pooled_std], dim=-1)
            h = h + block(block_in)  # residual update, stays (B, N, D_model)
        out = self.output_projection(h)  # (B, N, D_out)
        return out.transpose(-2, -1)  # -> (B, D_out, N)
class SimpleNearestNeighborsPointModel(BasePointModel):
    """
    A simple model that processes a point cloud by applying a series of MLPs to each point
    individually, along with some pooled global features, and the features of its nearest
    neighbors.

    For every point, each layer's input is the concatenation of:
      * the features of its ``num_neighbors + 1`` nearest points in coordinate space
        (the query set equals the reference set, so this normally includes the point
        itself), and
      * the two pooled global tensors returned by ``self.get_global_tensors``.

    This gives ``(3 + num_neighbors) * dim`` input channels per layer. The layer output
    is added back to the point features as a residual update.
    """

    def __init__(self, num_neighbors: int = 4, **kwargs):
        """
        Args:
            num_neighbors: number of neighbors, in addition to the point itself, whose
                features are gathered for each point.
            **kwargs: forwarded unchanged to ``BasePointModel.__init__``.
        """
        # Set before super().__init__() because get_layers() reads it; presumably the
        # base constructor calls get_layers() (TODO confirm in BasePointModel).
        self.num_neighbors = num_neighbors
        super().__init__(**kwargs)
        # Imported lazily so pytorch3d is only required when this model is actually used.
        from pytorch3d.ops import knn_points
        self.knn_points = knn_points

    def get_layers(self):
        """Build ``self.num_layers`` gated FeedForward blocks mapping (3+num_neighbors)*dim -> dim."""
        return nn.ModuleList([FeedForward(
            d_in=((3 + self.num_neighbors) * self.dim), d_hidden=(4 * self.dim), d_out=self.dim,
            activation=nn.SiLU(), is_gated=True, bias1=False, bias2=False, bias_gate=False, use_layernorm=True
        ) for _ in range(self.num_layers)])

    def forward(self, inputs: torch.Tensor, t: torch.Tensor):
        """
        Run the residual neighbor-aware MLP stack and project to the output channels.

        Args:
            inputs: raw model input, handed unchanged to ``self.prepare_inputs``.
            t: conditioning tensor, handed unchanged to ``self.prepare_inputs``.

        Returns:
            Tensor of shape (B, D_out, N), i.e. channels-first.
        """
        # Prepare inputs
        x, coords = self.prepare_inputs(inputs, t)  # (B, N, D), (B, N, 3)

        # Get nearest neighbors. Query and reference are the same cloud, so the first
        # neighbor is normally the point itself (exact duplicate coordinates can tie).
        _dists, indices, _neighbors = self.knn_points(
            p1=coords, p2=coords, K=(self.num_neighbors + 1),
            return_nn=False)  # (B, N, K), (B, N, K)
        (B, N, D), (_B, _N, K) = x.shape, indices.shape

        # The neighbor graph depends only on coords, so build the gather index once
        # instead of re-deriving it on every layer.
        flat_indices = indices.reshape(B, N * K)  # (B, N*K)
        batch_indices = torch.arange(B, device=x.device).unsqueeze(-1)  # (B, 1), broadcasts over N*K

        # Model
        for layer in self.layers:
            # Batched gather x[b, flat_indices[b, j]] -> (B, N*K, D), replacing a
            # Python loop over the batch; also works for an empty batch.
            x_neighbor = x[batch_indices, flat_indices].reshape(B, N, K * D)
            x_pool_max, x_pool_std = self.get_global_tensors(x)
            # K neighbor slots + 2 global tensors: (B, N, (K + 2) * D) == (3 + num_neighbors) * D
            x_input = torch.cat((x_neighbor, x_pool_max, x_pool_std), dim=-1)
            x = x + layer(x_input)  # (B, N, D_model)

        # Project
        x = self.output_projection(x)  # (B, N, D_out)
        x = torch.transpose(x, -2, -1)  # -> (B, D_out, N)
        return x