"""
VGNetwork — Vertex Generator Network (MLP-only, no PointTransformerV3).
Inputs: sample points + normals. Outputs: displaced points (samples + predicted 3D displacement).
"""
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from .embedder import get_embedder |
|
|
|
|
class VGNetwork(nn.Module):
    """MLP that maps (point, normal) pairs to displaced points.

    Each query point is scaled, optionally positionally encoded, concatenated
    with its (unencoded) normal and passed through ``n_layers`` hidden
    ``nn.Linear`` layers with ReLU activations. At every layer index listed in
    ``skip_in`` the raw network input is concatenated back onto the hidden
    features (the result is divided by sqrt(2)). The final linear layer
    produces an offset which ``forward`` adds to the input points.

    Args:
        d_in: Dimensionality of the query points.
        d_out: Dimensionality of the predicted offset; ``forward`` adds it to
            ``samples``, so it must be compatible with ``d_in``.
        d_hidden: Width of every hidden layer.
        n_layers: Number of hidden layers (``n_layers + 1`` linear layers in
            total).
        skip_in: Layer indices whose input is concatenated with the raw
            network input. May not contain ``0`` or ``n_layers + 1``.
        multires: Passed to ``get_embedder``; ``0`` disables positional
            encoding and feeds raw coordinates.
        scale: Points are multiplied by ``scale`` before encoding and the
            network output is divided by it before being added back.
        geometric_init: If True, use the custom initialisation below instead
            of PyTorch's default ``nn.Linear`` initialisation.
        weight_norm: If True, wrap every linear layer with weight
            normalisation.

    Raises:
        ValueError: If ``skip_in`` contains ``0`` or ``n_layers + 1``.
    """

    def __init__(self,
                 d_in=3,
                 d_out=3,
                 d_hidden=256,
                 n_layers=8,
                 skip_in=(4,),
                 multires=8,
                 scale=1.0,
                 geometric_init=True,
                 weight_norm=True):
        super(VGNetwork, self).__init__()

        # Per-layer widths; dims[0] is replaced below by the real input width.
        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        # Optional positional encoding of the points. The normals (always 3
        # channels) are appended afterwards without encoding.
        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch + 3
        else:
            dims[0] += 3

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        # A skip at index 0 would double the first layer's input width, and a
        # skip at the output index would make the last layer emit
        # d_out - dims[0] channels; neither can yield a valid forward pass.
        for skip in self.skip_in:
            if skip == 0 or skip == self.num_layers - 1:
                raise ValueError(
                    f"skip_in may not contain 0 or {self.num_layers - 1}; "
                    f"got {tuple(self.skip_in)}"
                )

        for layer in range(0, self.num_layers - 1):
            # The layer feeding a skip index leaves room for the re-injected
            # network input so the concatenation has width dims[layer + 1].
            if layer + 1 in self.skip_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]

            lin = nn.Linear(dims[layer], out_dim)

            if geometric_init:
                if multires > 0 and layer == 0:
                    # Only the first d_in input channels get random weights;
                    # the remaining encoded channels and the normals start at
                    # zero. NOTE(review): assumes get_embedder places the raw
                    # coordinates first in its output — confirm in embedder.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, d_in:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :d_in], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and layer in self.skip_in:
                    # Same treatment for the re-injected input (the trailing
                    # dims[0] channels): keep only its first d_in channels.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - d_in):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                # nn.utils.weight_norm is deprecated in recent PyTorch in favour
                # of nn.utils.parametrizations.weight_norm, but switching would
                # rename the weight_g/weight_v state_dict entries and break
                # existing checkpoints, so it is kept deliberately.
                lin = nn.utils.weight_norm(lin)
            # Registered as lin0, lin1, ... (names are part of the state_dict).
            setattr(self, f"lin{layer}", lin)

        self.activation = nn.ReLU()

    def forward(self, samples, normals):
        """Predict an offset for every sample and return the moved samples.

        Args:
            samples: (B, 3) query points; any extra leading dimensions are
                also accepted as long as ``normals`` matches them.
            normals: (B, 3) estimated normals at samples.

        Returns:
            moving_pcd: Same shape as ``samples``; displaced points
            ``samples + delta`` where ``delta`` is the network output divided
            by ``self.scale``.
        """
        inputs = samples * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)
        inputs = torch.cat((inputs, normals), dim=-1)

        x = inputs
        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, f"lin{layer}")
            if layer in self.skip_in:
                # Concatenate on the feature (last) axis, like the
                # input/normal concat above, so leading batch dims work.
                x = torch.cat([x, inputs], dim=-1) / np.sqrt(2)
            x = lin(x)
            # No activation after the final layer: the offset is unbounded.
            if layer < self.num_layers - 2:
                x = self.activation(x)

        # Divide by scale to mirror the scaling applied to the inputs above.
        moving_pcd = samples + x / self.scale
        return moving_pcd
|
|