import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def init_out_weights(self):
    # Initialize all weights of the given (sub)module to near-zero values and biases to zero.
    # Intended to be used via `module.apply(init_out_weights)` on an output layer.
    for m in self.modules():
        for name, param in m.named_parameters():
            if 'weight' in name:
                nn.init.uniform_(param.data, -1e-5, 1e-5)
            elif 'bias' in name:
                nn.init.constant_(param.data, 0.0)
class MLP(nn.Module):
    def __init__(self, in_channels, out_channels, inter_channels = [512, 512, 512, 343, 512, 512],
                 res_layers = [], nlactv = nn.ReLU(), last_op = None, norm = None, init_last_layer = False):
        super(MLP, self).__init__()
        self.nlactv = nlactv
        self.fc_list = nn.ModuleList()
        self.res_layers = res_layers
        if self.res_layers is None:
            self.res_layers = []
        self.all_channels = [in_channels] + inter_channels + [out_channels]
        for l in range(0, len(self.all_channels) - 2):
            if l in self.res_layers:
                if norm == 'weight':
                    # print('layer %d weight normalization in fusion mlp' % l)
                    self.fc_list.append(nn.Sequential(
                        nn.utils.weight_norm(nn.Conv1d(self.all_channels[l] + self.all_channels[0], self.all_channels[l + 1], 1)),
                        self.nlactv
                    ))
                else:
                    self.fc_list.append(nn.Sequential(
                        nn.Conv1d(self.all_channels[l] + self.all_channels[0], self.all_channels[l + 1], 1),
                        self.nlactv
                    ))
                self.all_channels[l] += self.all_channels[0]
            else:
                if norm == 'weight':
                    # print('layer %d weight normalization in fusion mlp' % l)
                    self.fc_list.append(nn.Sequential(
                        nn.utils.weight_norm(nn.Conv1d(self.all_channels[l], self.all_channels[l + 1], 1)),
                        self.nlactv
                    ))
                else:
                    self.fc_list.append(nn.Sequential(
                        nn.Conv1d(self.all_channels[l], self.all_channels[l + 1], 1),
                        self.nlactv
                    ))
        self.fc_list.append(nn.Conv1d(self.all_channels[-2], out_channels, 1))
        if init_last_layer:
            self.fc_list[-1].apply(init_out_weights)
        if last_op == 'sigmoid':
            self.last_op = nn.Sigmoid()
        elif last_op == 'tanh':
            self.last_op = nn.Tanh()
        else:
            self.last_op = None

    def forward(self, x, return_inter_layer = []):
        tmpx = x
        inter_feat_list = []
        for i, fc in enumerate(self.fc_list):
            if i in self.res_layers:
                # skip connection: concatenate the input features back in
                x = fc(torch.cat([x, tmpx], dim = 1))
            else:
                x = fc(x)
            if i == len(self.fc_list) - 1 and self.last_op is not None:  # last layer
                x = self.last_op(x)
            if i in return_inter_layer:
                inter_feat_list.append(x.clone())
        if len(return_inter_layer) > 0:
            return x, inter_feat_list
        else:
            return x
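
# Example usage of MLP (a minimal sketch; the channel sizes and shapes below are illustrative assumptions):
#   mlp = MLP(in_channels = 256, out_channels = 1, inter_channels = [512, 512, 512],
#             res_layers = [2], last_op = 'sigmoid')
#   feat = torch.randn(4, 256, 1024)   # (B, C, N) per-point features
#   occ = mlp(feat)                    # (4, 1, 1024)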
class MLPLinear(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 inter_channels,
                 res_layers = [],
                 nlactv = nn.ReLU(),
                 last_op = None):
        super(MLPLinear, self).__init__()
        self.fc_list = nn.ModuleList()
        self.all_channels = [in_channels] + inter_channels + [out_channels]
        self.res_layers = res_layers
        self.nlactv = nlactv
        self.last_op = last_op
        for l in range(0, len(self.all_channels) - 2):
            if l in self.res_layers:
                self.all_channels[l] += in_channels
            self.fc_list.append(
                nn.Sequential(
                    nn.Linear(self.all_channels[l], self.all_channels[l + 1]),
                    self.nlactv
                )
            )
        self.fc_list.append(nn.Linear(self.all_channels[-2], self.all_channels[-1]))

    def forward(self, x):
        tmpx = x
        for i, layer in enumerate(self.fc_list):
            if i in self.res_layers:
                x = torch.cat([x, tmpx], dim = -1)
            x = layer(x)
        if self.last_op is not None:
            x = self.last_op(x)
        return x
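
# Example usage of MLPLinear (a minimal sketch; channel sizes are illustrative assumptions):
#   net = MLPLinear(in_channels = 3, out_channels = 3, inter_channels = [128, 128], res_layers = [1])
#   pts = torch.randn(4, 1024, 3)   # (..., in_channels); nn.Linear acts on the last dimension
#   out = net(pts)                  # (4, 1024, 3)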
def parallel_concat(tensors: list, n_parallel_group: int):
    """
    Concatenate tensors group-wise along the channel dimension.
    :param tensors: list of tensors, each of which has a shape of [B, G*C, N]
    :param n_parallel_group: number of parallel groups G
    :return: [B, G*C', N], where C' is the sum of the inputs' per-group channels
    """
    batch_size = tensors[0].shape[0]
    point_num = tensors[0].shape[-1]
    assert all([t.shape[0] == batch_size for t in tensors]), 'All tensors should have the same batch size'
    assert all([t.shape[2] == point_num for t in tensors]), 'All tensors should have the same point num'
    assert all([t.shape[1] % n_parallel_group == 0 for t in tensors]), 'Invalid tensor channels'
    tensors_ = [
        t.reshape(batch_size, n_parallel_group, -1, point_num) for t in tensors
    ]
    concated = torch.cat(tensors_, dim=2)
    concated = concated.reshape(batch_size, -1, point_num)
    return concated
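
# Example for parallel_concat (a minimal sketch): with G = 2, tensors of shapes (B, 2*4, N) and
# (B, 2*6, N) are interleaved per group rather than concatenated channel-wise as a whole.
#   a = torch.randn(1, 8, 16)
#   b = torch.randn(1, 12, 16)
#   c = parallel_concat([a, b], 2)   # c.shape == (1, 20, 16)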
class ParallelMLP(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 group_num,
                 inter_channels,
                 res_layers = [],
                 nlactv = nn.ReLU(),
                 last_op = None):
        super(ParallelMLP, self).__init__()
        self.fc_list = nn.ModuleList()
        self.all_channels = [in_channels] + inter_channels + [out_channels]
        self.group_num = group_num
        self.res_layers = res_layers
        self.nlactv = nlactv
        self.last_op = last_op
        for l in range(0, len(self.all_channels) - 2):
            if l in self.res_layers:
                self.all_channels[l] += in_channels
            self.fc_list.append(
                nn.Sequential(
                    nn.Conv1d(self.all_channels[l] * self.group_num, self.all_channels[l + 1] * self.group_num, 1, groups = self.group_num),
                    self.nlactv
                )
            )
        self.fc_list.append(nn.Conv1d(self.all_channels[-2] * self.group_num, self.all_channels[-1] * self.group_num, 1, groups = self.group_num))

    def forward(self, x):
        """
        :param x: (batch_size, group_num, point_num, in_channels)
        :return: (batch_size, group_num, point_num, out_channels)
        """
        assert len(x.shape) == 4, 'input tensor should be a shape of [B, G, N, C]'
        assert x.shape[1] == self.group_num, 'input tensor should have %d parallel groups, but it has %s' % (self.group_num, x.shape[1])
        B, G, N, C = x.shape
        x = x.permute(0, 1, 3, 2).reshape(B, G * C, N)
        tmpx = x
        for i, layer in enumerate(self.fc_list):
            if i in self.res_layers:
                x = parallel_concat([x, tmpx], G)
            x = layer(x)
        if self.last_op is not None:
            x = self.last_op(x)
        x = x.view(B, G, -1, N).permute(0, 1, 3, 2)
        return x
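
# Example usage of ParallelMLP (a minimal sketch; grouped 1x1 convolutions give G independent
# sub-MLPs that do not share weights across groups; the sizes below are illustrative assumptions):
#   net = ParallelMLP(in_channels = 64, out_channels = 3, group_num = 8,
#                     inter_channels = [128, 128], res_layers = [1])
#   x = torch.randn(2, 8, 512, 64)   # (B, G, N, C)
#   y = net(x)                       # (2, 8, 512, 3)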
class SdfMLP(MLPLinear):
    def __init__(self,
                 in_channels,
                 out_channels,
                 inter_channels,
                 res_layers = [],
                 nlactv = nn.Softplus(beta = 100),
                 geometric_init = True,
                 bias = 0.5,
                 weight_norm = True
                 ):
        super(SdfMLP, self).__init__(in_channels,
                                     out_channels,
                                     inter_channels,
                                     res_layers,
                                     nlactv,
                                     None)
        for l, layer in enumerate(self.fc_list):
            if isinstance(layer, nn.Sequential):
                lin = layer[0]
            elif isinstance(layer, nn.Linear):
                lin = layer
            else:
                raise TypeError('Invalid layer %d' % l)
            if geometric_init:
                in_dim, out_dim = lin.in_features, lin.out_features
                if l == len(self.fc_list) - 1:
                    torch.nn.init.normal_(lin.weight, mean = np.sqrt(np.pi) / np.sqrt(in_dim), std = 0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                elif l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif l in self.res_layers:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(in_channels - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
            if weight_norm:
                if isinstance(layer, nn.Sequential):
                    layer[0] = nn.utils.weight_norm(lin)
                elif isinstance(layer, nn.Linear):
                    # rebind through fc_list so weight norm on the bare Linear layer actually takes effect
                    self.fc_list[l] = nn.utils.weight_norm(lin)
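
# Example usage of SdfMLP (a minimal sketch; the geometric initialization makes the network
# approximate the SDF of a sphere at the start of training, as in IGR-style SDF decoders;
# the channel sizes below are illustrative assumptions):
#   sdf_net = SdfMLP(in_channels = 3, out_channels = 1, inter_channels = [256, 256, 256, 256])
#   pts = torch.randn(4, 2048, 3)   # query points (B, N, 3)
#   sdf = sdf_net(pts)              # (4, 2048, 1)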
class OffsetDecoder(nn.Module):
    """
    Same architecture as the ShapeDecoder in POP (https://github.com/qianlim/POP),
    but with a single branch that predicts per-point 3D offsets.
    """
    def __init__(self, in_size, hsize = 256, actv_fn = 'softplus'):
        self.hsize = hsize
        super(OffsetDecoder, self).__init__()
        self.conv1 = torch.nn.Conv1d(in_size, self.hsize, 1)
        self.conv2 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv3 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv4 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv5 = torch.nn.Conv1d(self.hsize + in_size, self.hsize, 1)
        self.conv6 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv7 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv8 = torch.nn.Conv1d(self.hsize, 3, 1)

        # init last layer so that the predicted offsets start near zero
        nn.init.uniform_(self.conv8.weight, -1e-5, 1e-5)
        nn.init.constant_(self.conv8.bias, 0.)

        self.bn1 = torch.nn.BatchNorm1d(self.hsize)
        self.bn2 = torch.nn.BatchNorm1d(self.hsize)
        self.bn3 = torch.nn.BatchNorm1d(self.hsize)
        self.bn4 = torch.nn.BatchNorm1d(self.hsize)
        self.bn5 = torch.nn.BatchNorm1d(self.hsize)
        self.bn6 = torch.nn.BatchNorm1d(self.hsize)
        self.bn7 = torch.nn.BatchNorm1d(self.hsize)

        self.actv_fn = nn.ReLU() if actv_fn == 'relu' else nn.Softplus()

    def forward(self, x):
        x1 = self.actv_fn(self.bn1(self.conv1(x)))
        x2 = self.actv_fn(self.bn2(self.conv2(x1)))
        x3 = self.actv_fn(self.bn3(self.conv3(x2)))
        x4 = self.actv_fn(self.bn4(self.conv4(x3)))
        x5 = self.actv_fn(self.bn5(self.conv5(torch.cat([x, x4], dim=1))))  # skip connection from the input
        # offset pred
        x6 = self.actv_fn(self.bn6(self.conv6(x5)))
        x7 = self.actv_fn(self.bn7(self.conv7(x6)))
        x8 = self.conv8(x7)
        return x8

    def forward_wo_bn(self, x):
        x1 = self.actv_fn(self.conv1(x))
        x2 = self.actv_fn(self.conv2(x1))
        x3 = self.actv_fn(self.conv3(x2))
        x4 = self.actv_fn(self.conv4(x3))
        x5 = self.actv_fn(self.conv5(torch.cat([x, x4], dim=1)))
        # offset pred
        x6 = self.actv_fn(self.conv6(x5))
        x7 = self.actv_fn(self.conv7(x6))
        x8 = self.conv8(x7)
        return x8
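
# Example usage of OffsetDecoder (a minimal sketch; the feature size is an illustrative assumption):
#   dec = OffsetDecoder(in_size = 256)
#   feats = torch.randn(4, 256, 1024)   # (B, in_size, N) per-point features
#   offsets = dec(feats)                # (4, 3, 1024) predicted 3D offsets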
class ShapeDecoder(nn.Module):
    '''
    The "Shape Decoder" in the POP paper, Fig. 2; the same as the "shared MLP" in the SCALE paper.
    - with skip connection from the input features to the 4th layer's output features (like DeepSDF)
    - branches out at the second-to-last layer, one branch for position pred, one for normal pred
    '''
    def __init__(self, in_size, hsize = 256, actv_fn = 'softplus'):
        self.hsize = hsize
        super(ShapeDecoder, self).__init__()
        self.conv1 = torch.nn.Conv1d(in_size, self.hsize, 1)
        self.conv2 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv3 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv4 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv5 = torch.nn.Conv1d(self.hsize + in_size, self.hsize, 1)
        self.conv6 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv7 = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv8 = torch.nn.Conv1d(self.hsize, 3, 1)

        self.conv6N = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv7N = torch.nn.Conv1d(self.hsize, self.hsize, 1)
        self.conv8N = torch.nn.Conv1d(self.hsize, 3, 1)

        self.bn1 = torch.nn.BatchNorm1d(self.hsize)
        self.bn2 = torch.nn.BatchNorm1d(self.hsize)
        self.bn3 = torch.nn.BatchNorm1d(self.hsize)
        self.bn4 = torch.nn.BatchNorm1d(self.hsize)
        self.bn5 = torch.nn.BatchNorm1d(self.hsize)
        self.bn6 = torch.nn.BatchNorm1d(self.hsize)
        self.bn7 = torch.nn.BatchNorm1d(self.hsize)

        self.bn6N = torch.nn.BatchNorm1d(self.hsize)
        self.bn7N = torch.nn.BatchNorm1d(self.hsize)

        self.actv_fn = nn.ReLU() if actv_fn == 'relu' else nn.Softplus()

        # init last layer
        nn.init.uniform_(self.conv8.weight, -1e-5, 1e-5)
        nn.init.constant_(self.conv8.bias, 0)

    def forward(self, x):
        x1 = self.actv_fn(self.bn1(self.conv1(x)))
        x2 = self.actv_fn(self.bn2(self.conv2(x1)))
        x3 = self.actv_fn(self.bn3(self.conv3(x2)))
        x4 = self.actv_fn(self.bn4(self.conv4(x3)))
        x5 = self.actv_fn(self.bn5(self.conv5(torch.cat([x, x4], dim=1))))

        # position pred
        x6 = self.actv_fn(self.bn6(self.conv6(x5)))
        x7 = self.actv_fn(self.bn7(self.conv7(x6)))
        x8 = self.conv8(x7)

        # normals pred
        xN6 = self.actv_fn(self.bn6N(self.conv6N(x5)))
        xN7 = self.actv_fn(self.bn7N(self.conv7N(xN6)))
        xN8 = self.conv8N(xN7)

        return x8, xN8
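
# Example usage of ShapeDecoder (a minimal sketch; the feature size is an illustrative assumption):
#   dec = ShapeDecoder(in_size = 256)
#   feats = torch.randn(4, 256, 1024)      # (B, in_size, N) per-point features
#   pred_xyz, pred_normals = dec(feats)    # each of shape (4, 3, 1024)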
class MLPwoWeight(object):
    def __init__(self,
                 in_channels,
                 out_channels,
                 inter_channels,
                 res_layers = [],
                 nlactv = nn.ReLU(),
                 last_op = None):
        super(MLPwoWeight, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.all_channels = [in_channels] + inter_channels + [out_channels]
        self.res_layers = res_layers
        self.nlactv = nlactv
        self.last_op = last_op
        # count the number of parameters (weights + biases) expected by `forward`
        self.param_num = 0
        for i in range(len(self.all_channels) - 1):
            in_ch = self.all_channels[i]
            if i in self.res_layers:
                in_ch += self.in_channels
            out_ch = self.all_channels[i + 1]
            self.param_num += (in_ch * out_ch + out_ch)
        self.param_num_per_group = self.param_num

    def forward(self, x, params):
        """
        :param x: (batch_size, point_num, in_channels)
        :param params: (param_num, )
        :return: (batch_size, point_num, out_channels)
        """
        x = x.permute(0, 2, 1)  # (B, C, N)
        tmpx = x
        param_id = 0
        for i in range(len(self.all_channels) - 1):
            in_ch = self.all_channels[i]
            if i in self.res_layers:
                in_ch += self.in_channels
                x = torch.cat([x, tmpx], 1)
            out_ch = self.all_channels[i + 1]
            weight_len = out_ch * in_ch
            weight = params[param_id: param_id + weight_len].reshape(out_ch, in_ch, 1)
            param_id += weight_len
            bias_len = out_ch
            bias = params[param_id: param_id + bias_len]
            param_id += bias_len
            x = F.conv1d(x, weight, bias)
            if i < len(self.all_channels) - 2:
                x = self.nlactv(x)
        if self.last_op is not None:
            x = self.last_op(x)
        return x.permute(0, 2, 1)

    def __repr__(self):
        main_str = self.__class__.__name__ + '(\n'
        for i in range(len(self.all_channels) - 1):
            main_str += '\tF.conv1d(in_features=%d, out_features=%d, bias=True)\n' % (self.all_channels[i], self.all_channels[i + 1])
        main_str += '\tnlactv: %s\n' % self.nlactv.__repr__()
        main_str += ')'
        return main_str
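
# Example usage of MLPwoWeight (a minimal sketch): the MLP owns no weights itself; the parameters
# are supplied externally as a flat vector, e.g. predicted by another network (an assumption about
# the intended use; the sizes below are illustrative):
#   mlp = MLPwoWeight(in_channels = 16, out_channels = 3, inter_channels = [64, 64])
#   params = torch.randn(mlp.param_num)   # flat weight + bias vector
#   x = torch.randn(2, 1024, 16)          # (B, N, C)
#   y = mlp.forward(x, params)            # (2, 1024, 3)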
class ParallelMLPwoWeight(object):
    def __init__(self,
                 in_channels,
                 out_channels,
                 inter_channels,
                 group_num = 1,
                 res_layers = [],
                 nlactv = nn.ReLU(),
                 last_op = None):
        super(ParallelMLPwoWeight, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.all_channels = [in_channels] + inter_channels + [out_channels]
        self.res_layers = res_layers
        self.group_num = group_num
        self.nlactv = nlactv
        self.last_op = last_op
        # count the number of externally supplied parameters across all groups
        self.param_num = 0
        for i in range(len(self.all_channels) - 1):
            in_ch = self.all_channels[i]
            if i in self.res_layers:
                in_ch += self.in_channels
            out_ch = self.all_channels[i + 1]
            self.param_num += (in_ch * out_ch + out_ch) * self.group_num
        self.param_num_per_group = self.param_num // self.group_num

    def forward(self, x, params):
        """
        :param x: (batch_size, group_num, point_num, in_channels)
        :param params: (group_num, param_num_per_group)
        :return: (batch_size, group_num, point_num, out_channels)
        """
        batch_size, group_num, point_num, in_channels = x.shape
        assert group_num == self.group_num and in_channels == self.in_channels
        x = x.permute(0, 1, 3, 2)  # (B, G, C, N)
        x = x.reshape(batch_size, group_num * in_channels, point_num)
        tmpx = x
        param_id = 0
        for i in range(len(self.all_channels) - 1):
            in_ch = self.all_channels[i]
            if i in self.res_layers:
                in_ch += self.in_channels
                x = parallel_concat([x, tmpx], group_num)
            out_ch = self.all_channels[i + 1]
            weight_len = out_ch * in_ch
            weight = params[:, param_id: param_id + weight_len].reshape(group_num * out_ch, in_ch, 1)
            param_id += weight_len
            bias_len = out_ch
            bias = params[:, param_id: param_id + bias_len].reshape(group_num * out_ch)
            param_id += bias_len
            x = F.conv1d(x, weight, bias, groups = group_num)
            if i < len(self.all_channels) - 2:
                x = self.nlactv(x)
        if self.last_op is not None:
            x = self.last_op(x)
        x = x.reshape(batch_size, group_num, self.out_channels, point_num)
        return x.permute(0, 1, 3, 2)

    def __repr__(self):
        main_str = self.__class__.__name__ + '(\n'
        main_str += '\tgroup_num: %d\n' % self.group_num
        for i in range(len(self.all_channels) - 1):
            main_str += '\tF.conv1d(in_features=%d, out_features=%d, bias=True)\n' % (self.all_channels[i], self.all_channels[i + 1])
        main_str += '\tnlactv: %s\n' % self.nlactv.__repr__()
        main_str += ')'
        return main_str
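
# Example usage of ParallelMLPwoWeight (a minimal sketch): G groups of points, each decoded with its
# own externally supplied parameter vector via a grouped 1x1 convolution (sizes are illustrative):
#   mlp = ParallelMLPwoWeight(in_channels = 16, out_channels = 3, inter_channels = [64, 64], group_num = 8)
#   params = torch.randn(8, mlp.param_num_per_group)   # one flat parameter vector per group
#   x = torch.randn(2, 8, 512, 16)                     # (B, G, N, C)
#   y = mlp.forward(x, params)                         # (2, 8, 512, 3)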