# lightweightmr/vgnet.py — uploaded by bdck (commit 70d6e03, verified)
"""
VGNetwork — Vertex Generator Network (MLP-only, no PointTransformerV3).
Inputs: sample points + normals. Outputs: 3D displacement.
"""
import torch
import torch.nn as nn
import numpy as np
from .embedder import get_embedder
class VGNetwork(nn.Module):
    """Vertex Generator Network: an MLP that predicts a 3D displacement field.

    Given query points and their estimated normals, the network predicts a
    per-point displacement ``delta`` and returns ``samples + delta``. The
    architecture follows the NeuS/IDR SDF-MLP recipe: positional encoding of
    the coordinates, skip connections that re-inject the encoded input,
    geometric weight initialization, and optional weight normalization.

    Args:
        d_in: Dimensionality of the raw coordinate input (3 for xyz).
        d_out: Dimensionality of the predicted displacement (3).
        d_hidden: Width of each hidden layer.
        n_layers: Number of hidden layers.
        skip_in: Layer indices whose input is concatenated with the encoded
            network input (IDR-style skip connections).
        multires: Number of positional-encoding frequency bands; 0 disables
            the encoding.
        scale: Coordinates are multiplied by ``scale`` before the MLP and the
            predicted displacement is divided by it, so the MLP operates in a
            normalized coordinate space.
        geometric_init: Apply NeuS-style geometric initialization.
        weight_norm: Wrap each linear layer in weight normalization.
    """
    def __init__(self,
                 d_in=3,
                 d_out=3,
                 d_hidden=256,
                 n_layers=8,
                 skip_in=(4,),
                 multires=8,
                 scale=1.0,
                 geometric_init=True,
                 weight_norm=True):
        super(VGNetwork, self).__init__()
        # Layer widths: input, n_layers hidden layers of d_hidden, output.
        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
        self.embed_fn_fine = None
        if multires > 0:
            # get_embedder returns the encoding fn and its output width
            # (which includes the raw xyz when include_input is on —
            # presumed from the l == 0 init below; TODO confirm in .embedder).
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch + 3  # positional encoding + original xyz + normals
        else:
            dims[0] += 3  # add normals
        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale
        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                # The next layer re-concatenates the dims[0]-wide input, so
                # this layer's output is narrowed to keep total width dims[l+1].
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)
            if geometric_init:
                if multires > 0 and l == 0:
                    # First layer: keep only the raw-xyz columns active;
                    # zero the weights on the PE frequencies and normals.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    # NOTE(review): this slice is inherited from NeuS, where
                    # the input layout is [xyz, PE] and the last dims[0]-3
                    # columns are exactly the PE bands. Here the input ends
                    # with the 3 normal channels, so this also zeroes the
                    # normals' skip weights at init — confirm that is intended.
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    # Plain He-style init for interior layers.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
            if weight_norm:
                # NOTE(review): nn.utils.weight_norm is deprecated since
                # PyTorch 2.1 (use nn.utils.parametrizations.weight_norm);
                # left unchanged because switching renames state_dict keys.
                lin = nn.utils.weight_norm(lin)
            # Layers are registered as attributes lin0, lin1, ... so they
            # appear in the module's parameter list.
            setattr(self, "lin" + str(l), lin)
        self.activation = nn.ReLU()

    def forward(self, samples: torch.Tensor, normals: torch.Tensor) -> torch.Tensor:
        """Displace ``samples`` along the network-predicted offsets.

        Args:
            samples: (B, 3) query points.
            normals: (B, 3) estimated normals at ``samples``.

        Returns:
            (B, 3) displaced points, ``samples + delta``.
        """
        # Work in the normalized coordinate space.
        inputs = samples * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)
        # Append normals so the MLP sees [encoded xyz, normals].
        inputs = torch.cat((inputs, normals), dim=-1)
        x = inputs
        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))
            if l in self.skip_in:
                # IDR-style skip: re-inject the input; 1/sqrt(2) keeps the
                # activation variance roughly constant after concatenation.
                x = torch.cat([x, inputs], 1) / np.sqrt(2)
            x = lin(x)
            # No activation after the final (output) layer.
            if l < self.num_layers - 2:
                x = self.activation(x)
        # Undo the input scaling so the displacement lives in world units.
        moving_pcd = samples + x / self.scale
        return moving_pcd