# MLP decoder modules: Conv1d/Linear MLPs with optional skip connections,
# plus POP-style shape/offset decoders and weight-less (hyper-network) MLPs.
# (Hugging Face file-viewer residue removed: uploader pengc02, commit ec9a6bc, 19.9 kB.)
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
def init_out_weights(self):
    """Initialize a module's weights near zero and its biases to zero.

    Intended to be passed to ``nn.Module.apply`` so an output layer starts
    out producing (almost) zero predictions. The parameter is named ``self``
    because ``apply`` passes each submodule positionally; this is a free
    function, not a method.
    """
    # named_parameters() already recurses into submodules, so one pass is
    # enough; the previous modules() x named_parameters() double loop
    # re-initialized nested parameters multiple times.
    for name, param in self.named_parameters():
        if 'weight' in name:
            nn.init.uniform_(param.data, -1e-5, 1e-5)
        elif 'bias' in name:
            nn.init.constant_(param.data, 0.0)
class MLP(nn.Module):
    """Point-wise MLP implemented with 1x1 Conv1d layers.

    Operates on tensors of shape (B, C, N) — batch, channels, points.
    Layers listed in ``res_layers`` receive the network input concatenated
    to their features (DeepSDF-style skip connection).

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output channels.
        inter_channels: hidden widths (default: [512, 512, 512, 343, 512, 512]).
        res_layers: indices of layers that get the input re-concatenated.
        nlactv: activation module after every hidden layer (default: ReLU;
            note the same instance is shared by all layers, which is fine
            for stateless activations).
        last_op: 'sigmoid', 'tanh' or None — applied after the final layer.
        norm: 'weight' enables weight normalization on hidden layers.
        init_last_layer: if True, initialize the last layer near zero via
            ``init_out_weights``.
    """

    def __init__(self, in_channels, out_channels, inter_channels = None,
                 res_layers = None, nlactv = None, last_op = None, norm = None, init_last_layer = False):
        super(MLP, self).__init__()
        # None-sentinels instead of mutable/module default arguments; the
        # effective defaults are unchanged.
        if inter_channels is None:
            inter_channels = [512, 512, 512, 343, 512, 512]
        self.nlactv = nlactv if nlactv is not None else nn.ReLU()
        self.fc_list = nn.ModuleList()
        self.res_layers = res_layers if res_layers is not None else []
        self.all_channels = [in_channels] + list(inter_channels) + [out_channels]
        for l in range(0, len(self.all_channels) - 2):
            if l in self.res_layers:
                # Skip layer: forward() concatenates the input, so the conv
                # consumes all_channels[l] + all_channels[0] channels.
                conv = nn.Conv1d(self.all_channels[l] + self.all_channels[0], self.all_channels[l + 1], 1)
                # Record the widened input size (matches what forward feeds in).
                self.all_channels[l] += self.all_channels[0]
            else:
                conv = nn.Conv1d(self.all_channels[l], self.all_channels[l + 1], 1)
            if norm == 'weight':
                conv = nn.utils.weight_norm(conv)
            self.fc_list.append(nn.Sequential(conv, self.nlactv))
        self.fc_list.append(nn.Conv1d(self.all_channels[-2], out_channels, 1))
        if init_last_layer:
            self.fc_list[-1].apply(init_out_weights)
        if last_op == 'sigmoid':
            self.last_op = nn.Sigmoid()
        elif last_op == 'tanh':
            self.last_op = nn.Tanh()
        else:
            self.last_op = None

    def forward(self, x, return_inter_layer = None):
        """Run the MLP.

        Args:
            x: (B, in_channels, N) input features.
            return_inter_layer: optional list of layer indices whose
                activations are also returned (as clones).

        Returns:
            (B, out_channels, N), or ``(output, [intermediate feats])``
            when ``return_inter_layer`` is non-empty.
        """
        if return_inter_layer is None:
            return_inter_layer = []
        tmpx = x
        inter_feat_list = []
        for i, fc in enumerate(self.fc_list):
            if i in self.res_layers:
                x = fc(torch.cat([x, tmpx], dim = 1))
            else:
                x = fc(x)
            if i == len(self.fc_list) - 1 and self.last_op is not None:  # last layer
                x = self.last_op(x)
            if i in return_inter_layer:
                inter_feat_list.append(x.clone())
        if len(return_inter_layer) > 0:
            return x, inter_feat_list
        else:
            return x
class MLPLinear(nn.Module):
    """Fully-connected MLP (nn.Linear) with optional input skip connections.

    Operates along the last dimension of the input tensor.

    Args:
        in_channels / out_channels: input / output feature sizes.
        inter_channels: list of hidden layer widths.
        res_layers: layer indices that receive the input concatenated along
            the last dimension (DeepSDF-style skip).
        nlactv: activation after every hidden layer (default: ReLU; the same
            instance is shared by all layers).
        last_op: optional callable applied to the final output.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 inter_channels,
                 res_layers = None,
                 nlactv = None,
                 last_op = None):
        super(MLPLinear, self).__init__()
        self.fc_list = nn.ModuleList()
        self.all_channels = [in_channels] + list(inter_channels) + [out_channels]
        # None-sentinels instead of mutable/module default arguments.
        self.res_layers = res_layers if res_layers is not None else []
        self.nlactv = nlactv if nlactv is not None else nn.ReLU()
        self.last_op = last_op
        for l in range(0, len(self.all_channels) - 2):
            if l in self.res_layers:
                # forward() concatenates the input at this layer, so widen
                # the recorded input size accordingly.
                self.all_channels[l] += in_channels
            self.fc_list.append(
                nn.Sequential(
                    nn.Linear(self.all_channels[l], self.all_channels[l + 1]),
                    self.nlactv
                )
            )
        self.fc_list.append(nn.Linear(self.all_channels[-2], self.all_channels[-1]))

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        tmpx = x
        for i, layer in enumerate(self.fc_list):
            if i in self.res_layers:
                x = torch.cat([x, tmpx], dim = -1)
            x = layer(x)
        if self.last_op is not None:
            x = self.last_op(x)
        return x
def parallel_concat(tensors: list, n_parallel_group: int):
    """
    Concatenate the per-group channels of several group-interleaved tensors.

    :param tensors: list of tensors, each of which has a shape of [B, G*C, N]
    :param n_parallel_group: number of parallel groups G
    :return: [B, G*C', N]
    """
    first = tensors[0]
    batch_size, point_num = first.shape[0], first.shape[-1]
    assert all([t.shape[0] == batch_size for t in tensors]), 'All tensors should have the same batch size'
    assert all([t.shape[2] == point_num for t in tensors]), 'All tensors should have the same point num'
    assert all([t.shape[1] % n_parallel_group==0 for t in tensors]), 'Invalid tensor channels'
    # Unfold the group axis, concatenate along per-group channels, re-flatten.
    grouped = torch.cat(
        [t.reshape(batch_size, n_parallel_group, -1, point_num) for t in tensors],
        dim=2,
    )
    return grouped.reshape(batch_size, -1, point_num)
class ParallelMLP(nn.Module):
    """``group_num`` independent MLPs run in parallel via grouped 1x1 convs.

    Each group has its own weights; data is laid out as
    (batch, group, point, channel).

    Args:
        in_channels / out_channels: per-group feature sizes.
        group_num: number of parallel groups.
        inter_channels: per-group hidden widths.
        res_layers: layer indices that get the input re-concatenated
            group-wise (see ``parallel_concat``).
        nlactv: activation after each hidden layer (default: ReLU; the same
            instance is shared by all layers).
        last_op: optional callable applied to the final output.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 group_num,
                 inter_channels,
                 res_layers = None,
                 nlactv = None,
                 last_op = None):
        super(ParallelMLP, self).__init__()
        self.fc_list = nn.ModuleList()
        self.all_channels = [in_channels] + list(inter_channels) + [out_channels]
        self.group_num = group_num
        # None-sentinels instead of mutable/module default arguments.
        self.res_layers = res_layers if res_layers is not None else []
        self.nlactv = nlactv if nlactv is not None else nn.ReLU()
        self.last_op = last_op
        for l in range(0, len(self.all_channels) - 2):
            if l in self.res_layers:
                # forward() concatenates the input at this layer.
                self.all_channels[l] += in_channels
            self.fc_list.append(
                nn.Sequential(
                    nn.Conv1d(self.all_channels[l] * self.group_num, self.all_channels[l + 1] * self.group_num, 1, groups = self.group_num),
                    self.nlactv
                )
            )
        self.fc_list.append(nn.Conv1d(self.all_channels[-2] * self.group_num, self.all_channels[-1] * self.group_num, 1, groups = self.group_num))

    def forward(self, x):
        """
        :param x: (batch_size, group_num, point_num, in_channels)
        :return: (batch_size, group_num, point_num, out_channels)
        """
        assert len(x.shape) == 4, 'input tensor should be a shape of [B, G, N, C]'
        assert x.shape[1] == self.group_num, 'input tensor should have %d parallel groups, but it has %s' % (self.group_num, x.shape[1])
        B, G, N, C = x.shape
        # Flatten to (B, G*C, N) so the grouped Conv1d keeps groups separate.
        x = x.permute(0, 1, 3, 2).reshape(B, G * C, N)
        tmpx = x
        for i, layer in enumerate(self.fc_list):
            if i in self.res_layers:
                x = parallel_concat([x, tmpx], G)
            x = layer(x)
        if self.last_op is not None:
            x = self.last_op(x)
        x = x.view(B, G, -1, N).permute(0, 1, 3, 2)
        return x
class SdfMLP(MLPLinear):
    """MLPLinear with the SDF-specific "geometric initialization" used by
    IGR/SAL-style networks, plus optional weight normalization.

    Geometric init biases the network toward representing the signed
    distance of a sphere of radius ``bias`` at the start of training.

    Args:
        in_channels / out_channels / inter_channels / res_layers / nlactv:
            as in MLPLinear (default activation: Softplus(beta=100)).
        geometric_init: apply the SDF-friendly initialization.
        bias: sphere radius; the last layer's bias is set to ``-bias``.
        weight_norm: wrap every linear layer in weight normalization.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 inter_channels,
                 res_layers = None,
                 nlactv = None,
                 geometric_init = True,
                 bias = 0.5,
                 weight_norm = True
                 ):
        # None-sentinels instead of mutable/module default arguments.
        if nlactv is None:
            nlactv = nn.Softplus(beta = 100)
        super(SdfMLP, self).__init__(in_channels,
                                     out_channels,
                                     inter_channels,
                                     res_layers if res_layers is not None else [],
                                     nlactv,
                                     None)
        for l, layer in enumerate(self.fc_list):
            if isinstance(layer, nn.Sequential):
                lin = layer[0]
            elif isinstance(layer, nn.Linear):
                lin = layer
            else:
                raise TypeError('Invalid %d layer' % l)
            if geometric_init:
                in_dim, out_dim = lin.in_features, lin.out_features
                if l == len(self.fc_list) - 1:
                    # Last layer: near-constant positive weights plus a
                    # negative bias approximate a sphere SDF.
                    torch.nn.init.normal_(lin.weight, mean = np.sqrt(np.pi) / np.sqrt(in_dim), std = 0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                elif l == 0:
                    # First layer: only the first 3 input dims carry signal at
                    # init — assumes xyz coords come first (TODO confirm).
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif l in self.res_layers:
                    # Skip layer: zero the weights seeing the re-injected
                    # non-coordinate input channels.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(in_channels - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
            if weight_norm:
                # weight_norm mutates `lin` in place and returns the same
                # module; store the result back explicitly rather than relying
                # on that side effect (the old `layer = weight_norm(lin)` only
                # rebound a local variable).
                if isinstance(layer, nn.Sequential):
                    layer[0] = nn.utils.weight_norm(lin)
                elif isinstance(layer, nn.Linear):
                    self.fc_list[l] = nn.utils.weight_norm(lin)
class OffsetDecoder(nn.Module):
    """
    Per-point offset regressor, same architecture as the ShapeDecoder in POP
    (https://github.com/qianlim/POP): a Conv1d MLP whose 5th layer receives a
    skip connection from the input, predicting a 3-channel offset. The output
    conv is initialized near zero so initial predictions are ~0.
    """
    def __init__(self, in_size, hsize = 256, actv_fn='softplus'):
        super(OffsetDecoder, self).__init__()
        self.hsize = hsize
        # Trunk convs 1-4, skip conv 5 (input re-injected), convs 6-7, head 8.
        self.conv1 = torch.nn.Conv1d(in_size, hsize, 1)
        self.conv2 = torch.nn.Conv1d(hsize, hsize, 1)
        self.conv3 = torch.nn.Conv1d(hsize, hsize, 1)
        self.conv4 = torch.nn.Conv1d(hsize, hsize, 1)
        self.conv5 = torch.nn.Conv1d(hsize + in_size, hsize, 1)
        self.conv6 = torch.nn.Conv1d(hsize, hsize, 1)
        self.conv7 = torch.nn.Conv1d(hsize, hsize, 1)
        self.conv8 = torch.nn.Conv1d(hsize, 3, 1)
        # Near-zero head init => near-zero offsets at the start of training.
        nn.init.uniform_(self.conv8.weight, -1e-5, 1e-5)
        nn.init.constant_(self.conv8.bias, 0.)
        for k in range(1, 8):
            setattr(self, 'bn%d' % k, torch.nn.BatchNorm1d(hsize))
        self.actv_fn = nn.ReLU() if actv_fn=='relu' else nn.Softplus()

    def forward(self, x):
        """x: (B, in_size, N) -> (B, 3, N) offsets, with batch norm."""
        act = self.actv_fn
        h = act(self.bn1(self.conv1(x)))
        h = act(self.bn2(self.conv2(h)))
        h = act(self.bn3(self.conv3(h)))
        h = act(self.bn4(self.conv4(h)))
        # Skip connection: re-inject the raw input features.
        h = act(self.bn5(self.conv5(torch.cat([x, h], dim=1))))
        # position pred
        h = act(self.bn6(self.conv6(h)))
        h = act(self.bn7(self.conv7(h)))
        return self.conv8(h)

    def forward_wo_bn(self, x):
        """Same as forward() but with the batch-norm layers skipped."""
        act = self.actv_fn
        h = act(self.conv1(x))
        h = act(self.conv2(h))
        h = act(self.conv3(h))
        h = act(self.conv4(h))
        h = act(self.conv5(torch.cat([x, h], dim=1)))
        # position pred
        h = act(self.conv6(h))
        h = act(self.conv7(h))
        return self.conv8(h)
class ShapeDecoder(nn.Module):
    '''
    The "Shape Decoder" in the POP paper Fig. 2 (identical to the "shared
    MLP" of SCALE): a Conv1d MLP with a DeepSDF-style skip connection from
    the input into the 5th layer, branching at the second-to-last layer into
    a position head and a normal head.
    '''
    def __init__(self, in_size, hsize = 256, actv_fn='softplus'):
        super(ShapeDecoder, self).__init__()
        self.hsize = hsize
        conv1d = torch.nn.Conv1d
        # Shared trunk (conv5 consumes the input skip).
        self.conv1 = conv1d(in_size, hsize, 1)
        self.conv2 = conv1d(hsize, hsize, 1)
        self.conv3 = conv1d(hsize, hsize, 1)
        self.conv4 = conv1d(hsize, hsize, 1)
        self.conv5 = conv1d(hsize + in_size, hsize, 1)
        # Position branch.
        self.conv6 = conv1d(hsize, hsize, 1)
        self.conv7 = conv1d(hsize, hsize, 1)
        self.conv8 = conv1d(hsize, 3, 1)
        # Normal branch.
        self.conv6N = conv1d(hsize, hsize, 1)
        self.conv7N = conv1d(hsize, hsize, 1)
        self.conv8N = conv1d(hsize, 3, 1)
        for k in range(1, 8):
            setattr(self, 'bn%d' % k, torch.nn.BatchNorm1d(hsize))
        self.bn6N = torch.nn.BatchNorm1d(hsize)
        self.bn7N = torch.nn.BatchNorm1d(hsize)
        self.actv_fn = nn.ReLU() if actv_fn=='relu' else nn.Softplus()
        # init last layer
        nn.init.uniform_(self.conv8.weight, -1e-5, 1e-5)
        nn.init.constant_(self.conv8.bias, 0)

    def forward(self, x):
        """x: (B, in_size, N) -> (positions (B, 3, N), normals (B, 3, N))."""
        act = self.actv_fn
        h = act(self.bn1(self.conv1(x)))
        h = act(self.bn2(self.conv2(h)))
        h = act(self.bn3(self.conv3(h)))
        h = act(self.bn4(self.conv4(h)))
        trunk = act(self.bn5(self.conv5(torch.cat([x, h], dim=1))))
        # position pred
        p = act(self.bn6(self.conv6(trunk)))
        p = act(self.bn7(self.conv7(p)))
        p = self.conv8(p)
        # normals pred
        n = act(self.bn6N(self.conv6N(trunk)))
        n = act(self.bn7N(self.conv7N(n)))
        n = self.conv8N(n)
        return p, n
class MLPwoWeight(object):
    """A weight-less MLP whose parameters are supplied at call time.

    Behaves like a 1x1-Conv1d MLP over point features but holds no
    parameters itself: a flat parameter vector is sliced into per-layer
    weights and biases inside ``forward``. Useful for hyper-network setups.

    Args mirror MLPLinear; ``param_num`` is the required length of the
    external parameter vector.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 inter_channels,
                 res_layers = None,
                 nlactv = None,
                 last_op = None):
        super(MLPwoWeight, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.all_channels = [in_channels] + list(inter_channels) + [out_channels]
        # None-sentinels instead of mutable/module default arguments.
        self.res_layers = res_layers if res_layers is not None else []
        self.nlactv = nlactv if nlactv is not None else nn.ReLU()
        self.last_op = last_op
        # Total number of externally supplied parameters (weights + biases).
        self.param_num = 0
        for i in range(len(self.all_channels) - 1):
            in_ch = self._layer_in_channels(i)
            out_ch = self.all_channels[i + 1]
            self.param_num += (in_ch * out_ch + out_ch)
        self.param_num_per_group = self.param_num

    def _layer_in_channels(self, i):
        """Actual input width of layer ``i`` (accounts for the skip concat)."""
        in_ch = self.all_channels[i]
        if i in self.res_layers:
            in_ch += self.in_channels
        return in_ch

    def forward(self, x, params):
        """
        :param x: (batch_size, point_num, in_channels)
        :param params: (param_num, ) flat weight/bias vector
        :return: (batch_size, point_num, out_channels)
        """
        x = x.permute(0, 2, 1)  # (B, C, N)
        tmpx = x
        param_id = 0
        for i in range(len(self.all_channels) - 1):
            in_ch = self._layer_in_channels(i)
            if i in self.res_layers:
                x = torch.cat([x, tmpx], 1)
            out_ch = self.all_channels[i + 1]
            # Slice this layer's weight and bias out of the flat vector.
            weight_len = out_ch * in_ch
            weight = params[param_id: param_id + weight_len].reshape(out_ch, in_ch, 1)
            param_id += weight_len
            bias = params[param_id: param_id + out_ch]
            param_id += out_ch
            x = F.conv1d(x, weight, bias)
            if i < len(self.all_channels) - 2:  # no activation after the last layer
                x = self.nlactv(x)
        if self.last_op is not None:
            x = self.last_op(x)
        return x.permute(0, 2, 1)

    def __repr__(self):
        main_str = self.__class__.__name__ + '(\n'
        for i in range(len(self.all_channels) - 1):
            # Report the true input width, including the skip concat
            # (previously misreported for res layers).
            main_str += '\tF.conv1d(in_features=%d, out_features=%d, bias=True)\n' % (self._layer_in_channels(i), self.all_channels[i + 1])
        main_str += '\tnlactv: %s\n' % self.nlactv.__repr__()
        main_str += ')'
        return main_str
class ParallelMLPwoWeight(object):
    """Group-parallel version of MLPwoWeight.

    ``group_num`` independent weight-less MLPs evaluated in one grouped
    conv1d call per layer; each group's weights come from its row of the
    external ``params`` tensor of shape (group_num, param_num_per_group).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 inter_channels,
                 group_num = 1,
                 res_layers = None,
                 nlactv = None,
                 last_op = None):
        super(ParallelMLPwoWeight, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.all_channels = [in_channels] + list(inter_channels) + [out_channels]
        # None-sentinels instead of mutable/module default arguments.
        self.res_layers = res_layers if res_layers is not None else []
        self.group_num = group_num
        self.nlactv = nlactv if nlactv is not None else nn.ReLU()
        self.last_op = last_op
        # Total parameter count across all groups (weights + biases).
        self.param_num = 0
        for i in range(len(self.all_channels) - 1):
            in_ch = self._layer_in_channels(i)
            out_ch = self.all_channels[i + 1]
            self.param_num += (in_ch * out_ch + out_ch) * self.group_num
        self.param_num_per_group = self.param_num // self.group_num

    def _layer_in_channels(self, i):
        """Per-group input width of layer ``i`` (accounts for the skip concat)."""
        in_ch = self.all_channels[i]
        if i in self.res_layers:
            in_ch += self.in_channels
        return in_ch

    def forward(self, x, params):
        """
        :param x: (batch_size, group_num, point_num, in_channels)
        :param params: (group_num, param_num_per_group)
        :return: (batch_size, group_num, point_num, out_channels)
        """
        batch_size, group_num, point_num, in_channels = x.shape
        assert group_num == self.group_num and in_channels == self.in_channels
        x = x.permute(0, 1, 3, 2)  # (B, G, C, N)
        x = x.reshape(batch_size, group_num * in_channels, point_num)
        tmpx = x
        param_id = 0
        for i in range(len(self.all_channels) - 1):
            in_ch = self._layer_in_channels(i)
            if i in self.res_layers:
                x = parallel_concat([x, tmpx], group_num)
            out_ch = self.all_channels[i + 1]
            weight_len = out_ch * in_ch
            # Per-group weights are stacked along the output-channel axis,
            # as grouped conv1d expects.
            weight = params[:, param_id: param_id + weight_len].reshape(group_num * out_ch, in_ch, 1)
            param_id += weight_len
            bias = params[:, param_id: param_id + out_ch].reshape(group_num * out_ch)
            param_id += out_ch
            x = F.conv1d(x, weight, bias, groups = group_num)
            if i < len(self.all_channels) - 2:  # no activation after the last layer
                x = self.nlactv(x)
        if self.last_op is not None:
            x = self.last_op(x)
        x = x.reshape(batch_size, group_num, self.out_channels, point_num)
        return x.permute(0, 1, 3, 2)

    def __repr__(self):
        main_str = self.__class__.__name__ + '(\n'
        main_str += '\tgroup_num: %d\n' % self.group_num
        for i in range(len(self.all_channels) - 1):
            # Report the true per-group input width, including the skip concat
            # (previously misreported for res layers).
            main_str += '\tF.conv1d(in_features=%d, out_features=%d, bias=True)\n' % (self._layer_in_channels(i), self.all_channels[i + 1])
        main_str += '\tnlactv: %s\n' % self.nlactv.__repr__()
        main_str += ')'
        return main_str