| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import pytorch_lightning as pl |
| | import torch.nn.functional as F |
| | from torch.autograd import grad |
| | |
| | import numpy as np |
| |
|
class SDF2Density(pl.LightningModule):
    """Map signed-distance values to a density using a learnable scale ``beta``."""

    def __init__(self):
        super().__init__()
        # Learnable scalar, initialised to 0.1.
        self.beta = nn.Parameter(torch.tensor(0.1))

    def forward(self, sdf):
        """Return ``(1 / b) * sigmoid(-sdf / b)`` where ``b = beta + 1e-6``.

        The 1e-6 offset keeps the division finite when ``beta`` is exactly 0.
        """
        scale = self.beta + 1e-6
        return 1.0 / scale * torch.sigmoid(-sdf / scale)
| |
|
class SDF2Occ(pl.LightningModule):
    """Map signed-distance values to a value in (0, 1) using a learnable scale ``beta``."""

    def __init__(self):
        super().__init__()
        # Learnable scalar, initialised to 0.1.
        self.beta = nn.Parameter(torch.tensor(0.1))

    def forward(self, sdf):
        """Return ``sigmoid(-sdf / b)`` where ``b = beta + 1e-6``.

        The 1e-6 offset keeps the division finite when ``beta`` is exactly 0.
        """
        scale = self.beta + 1e-6
        return torch.sigmoid(-sdf / scale)
| |
|
| |
|
class DeformationMLP(pl.LightningModule):
    """MLP head fed with point features, coordinates, a visibility term and a learned per-point code."""

    def __init__(self, input_dim=64, output_dim=3, activation='LeakyReLU', name=None, opt=None):
        """Build the MLP and the per-point embedding table.

        Args:
            input_dim: channel count of ``feature`` passed to ``forward``.
            output_dim: channel count of the prediction.
            activation: stored on the instance; not otherwise used here.
            name: stored on the instance.
            opt: options object; ``opt.res_layers`` and ``opt.norm_mlp`` are
                read, so passing ``None`` fails with ``AttributeError``.
        """
        super().__init__()
        self.name = name
        self.activation = activation
        self.activate = nn.LeakyReLU(inplace=True)

        # Input width = feature channels + 8 (per-point code) + 1 + 3.
        # NOTE(review): the 1 and 3 presumably correspond to smpl_vis and
        # xyz in forward() -- confirm against the caller.
        layer_widths = [input_dim + 8 + 1 + 3, 128, 64, output_dim]
        self.deform_mlp = MLP(
            filter_channels=layer_widths,
            name="if",
            res_layers=opt.res_layers,
            norm=opt.norm_mlp,
            last_op=None,
        )

        # One learned 8-dim code per index; table has 10475 rows.
        num_entries = 10475
        code_dim = 8
        self.per_pt_code = nn.Embedding(num_entries, code_dim)

    def forward(self, feature, smpl_vis, pts_id, xyz):
        '''
        Concatenate all inputs along dim 1 and run them through the MLP.

        args:
            feature: [B, C_in, N]
            smpl_vis: concatenated along dim 1 (assumes [B, 1, N] -- TODO confirm)
            pts_id: integer indices into the embedding table
                (assumes [B, N] -- TODO confirm)
            xyz: concatenated along dim 1 (assumes [B, 3, N] -- TODO confirm)
        return:
            [B, C_out, N] prediction
        '''
        # Embedding lookup yields channels last; move them to dim 1.
        point_codes = self.per_pt_code(pts_id).permute(0, 2, 1)
        stacked = torch.cat([feature, xyz, smpl_vis, point_codes], 1)
        return self.deform_mlp(stacked)
| |
|
class MLP(pl.LightningModule):
    """Stack of kernel-size-1 ``Conv1d`` layers with optional normalisation.

    Layers whose index appears in ``res_layers`` receive the original input
    concatenated (along the channel axis) with the running activation.
    """

    def __init__(self,
                 filter_channels,
                 name=None,
                 res_layers=None,
                 norm='group',
                 last_op=None):
        """Create the convolution and normalisation layers.

        Args:
            filter_channels: channel widths; layer ``l`` maps
                ``filter_channels[l]`` to ``filter_channels[l + 1]``.
            name: stored on the instance.
            res_layers: indices of layers that also take the network input;
                ``None`` (the default) means no such layers.
            norm: ``'group'``, ``'batch'``, ``'instance'`` or ``'weight'``;
                any other value leaves the hidden layers un-normalised.
            last_op: optional callable applied to the final output.
        """
        super(MLP, self).__init__()

        # A ``None`` sentinel avoids sharing one mutable default list
        # between every instance created without ``res_layers``.
        if res_layers is None:
            res_layers = []

        self.filters = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.res_layers = res_layers
        self.norm = norm
        self.last_op = last_op
        self.name = name
        self.activate = nn.LeakyReLU(inplace=True)

        num_layers = len(filter_channels) - 1
        for l in range(num_layers):
            in_channels = filter_channels[l]
            if l in self.res_layers:
                # Skip layers see the running activation plus the raw input.
                in_channels += filter_channels[0]
            self.filters.append(
                nn.Conv1d(in_channels, filter_channels[l + 1], 1))

            # The last layer gets neither normalisation nor weight norm.
            if l != num_layers - 1:
                if norm == 'group':
                    self.norms.append(nn.GroupNorm(32, filter_channels[l + 1]))
                elif norm == 'batch':
                    self.norms.append(nn.BatchNorm1d(filter_channels[l + 1]))
                elif norm == 'instance':
                    self.norms.append(
                        nn.InstanceNorm1d(filter_channels[l + 1]))
                elif norm == 'weight':
                    self.filters[l] = nn.utils.weight_norm(self.filters[l],
                                                           name='weight')

    def forward(self, feature):
        '''
        feature may include multiple view inputs
        args:
            feature: [B, C_in, N]
        return:
            [B, C_out, N] prediction
        '''
        y = feature
        tmpy = feature  # kept untouched for the skip concatenations

        last_index = len(self.filters) - 1
        for i, f in enumerate(self.filters):
            y = f(y if i not in self.res_layers else torch.cat([y, tmpy], 1))
            if i != last_index:
                if self.norm not in ['batch', 'group', 'instance']:
                    # 'weight' (or an unknown norm) has no separate norm layer.
                    y = self.activate(y)
                else:
                    y = self.activate(self.norms[i](y))

        if self.last_op is not None:
            y = self.last_op(y)

        return y
| |
|
| |
|
| | |
class Embedder(pl.LightningModule):
    """Frequency encoder built from a list of periodic functions.

    Expected keyword arguments: ``input_dims``, ``include_input``,
    ``max_freq_log2``, ``num_freqs``, ``log_sampling`` and ``periodic_fns``.
    """

    def __init__(self, **kwargs):
        # Initialise the base module first; without this the instance has no
        # module bookkeeping and cannot be printed, moved or called safely.
        super().__init__()
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Populate ``self.embed_fns`` and ``self.out_dim`` from ``self.kwargs``."""
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            # Frequencies spaced evenly in log2: 2**0 ... 2**max_freq.
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            # Frequencies spaced evenly between 2**0 and 2**max_freq.
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                # Default arguments bind the current p_fn/freq to each lambda
                # (avoids the late-binding closure pitfall).
                embed_fns.append(
                    lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        """Apply every embedding function to ``inputs`` and concatenate on the last axis."""
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
| |
|
| |
|
def get_embedder(multires=6, i=0):
    """Return ``(embed_fn, out_dim)`` for a 3-D frequency encoding.

    With ``i == -1`` no encoding is applied: an ``nn.Identity`` module and a
    dimension of 3 are returned. Otherwise an ``Embedder`` using ``multires``
    log-spaced sin/cos frequency bands (plus the raw input) is built.
    """
    if i == -1:
        return nn.Identity(), 3

    embedder_obj = Embedder(
        include_input=True,
        input_dims=3,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
| |
|
| |
|
| | |
| | |
| | |
class TransformerEncoderLayer(pl.LightningModule):
    """Geometry and colour MLP heads applied to concatenated point/SMPL features."""

    def __init__(self, d_model=256, skips=4, multires=6, num_mlp_layers=8, dropout=0.1, opt=None):
        """Build the two MLP heads.

        ``opt`` must provide ``mlp_dim``, ``mlp_dim_color``, ``res_layers``
        and ``norm_mlp``. The first entry of both ``opt.mlp_dim`` and
        ``opt.mlp_dim_color`` is overwritten in place with the input width.
        """
        super().__init__()

        encoder, _ = get_embedder(multires=multires)
        self.skips = skips
        self.dropout = dropout
        self.positional_encoding = encoder
        self.d_model = d_model

        # Input width of both heads: 64 + 6 + 8 channels.
        triplane_dim = 64
        in_channels = triplane_dim + 6 + 8
        opt.mlp_dim[0] = in_channels
        opt.mlp_dim_color[0] = in_channels

        # Geometry head: output squashed by a sigmoid.
        self.geo_mlp = MLP(
            filter_channels=opt.mlp_dim,
            name="if",
            res_layers=opt.res_layers,
            norm=opt.norm_mlp,
            last_op=nn.Sigmoid(),
        )

        # Colour head: output squashed by tanh.
        self.color_mlp = MLP(
            filter_channels=opt.mlp_dim_color,
            name="color_if",
            res_layers=opt.res_layers,
            norm=opt.norm_mlp,
            last_op=nn.Tanh(),
        )

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query_points, key_points, point_features, smpl_feat, training=True, type='shape'):
        """Run the requested head(s) on ``cat([point_features, smpl_feat], dim=1)``.

        ``query_points``, ``key_points`` and ``training`` are accepted but not
        used. ``type`` selects the output:

        * ``'shape'``       -> geometry head output
        * ``'color'``       -> colour head output
        * ``'shape_color'`` -> ``(geometry, colour)`` tuple
        * anything else     -> ``None``
        """
        feature = torch.cat([point_features, smpl_feat], dim=1)

        if type == 'shape':
            return self.geo_mlp(feature)

        if type == 'color':
            return self.color_mlp(feature)

        if type == 'shape_color':
            shape_out = self.geo_mlp(feature)
            color_out = self.color_mlp(feature)
            return shape_out, color_out

        # Unrecognised selector: no head is evaluated.
        return None
| | |
| |
|
| |
|
| |
|
class Swish(pl.LightningModule):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its sigmoid."""
        return x * torch.sigmoid(x)
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding.

    A table of shape ``(1, max_len, d_model)`` is precomputed and stored as
    the buffer ``pe``; even columns hold sines and odd columns cosines.
    """

    def __init__(self, d_model, max_len=1000):
        """Precompute the encoding table.

        Args:
            d_model: width of the encoding (odd values are supported).
            max_len: number of positions stored in the table.
        """
        super(PositionalEncoding, self).__init__()
        # Imported locally so the class works even though the module's
        # top-level imports do not include ``math``.
        import math

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column,
        # so drop the surplus frequency before filling the cosines.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x``.

        Assumes ``x`` is ``(batch, seq_len, d_model)`` with
        ``seq_len <= max_len`` -- TODO confirm against callers.
        """
        x = x + self.pe[:, :x.size(1)]
        return x
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| |
|
| |
|