import torch.nn as nn import torch import torch.nn.functional as F import numpy as np def init_out_weights(self): for m in self.modules(): for name, param in m.named_parameters(): if 'weight' in name: nn.init.uniform_(param.data, -1e-5, 1e-5) elif 'bias' in name: nn.init.constant_(param.data, 0.0) class MLP(nn.Module): def __init__(self, in_channels, out_channels, inter_channels = [512, 512, 512, 343, 512, 512], res_layers = [], nlactv = nn.ReLU(), last_op=None, norm = None, init_last_layer = False): super(MLP, self).__init__() self.nlactv = nlactv self.fc_list = nn.ModuleList() self.res_layers = res_layers if self.res_layers is None: self.res_layers = [] self.all_channels = [in_channels] + inter_channels + [out_channels] for l in range(0, len(self.all_channels) - 2): if l in self.res_layers: if norm == 'weight': # print('layer %d weight normalization in fusion mlp' % l) self.fc_list.append(nn.Sequential( nn.utils.weight_norm(nn.Conv1d(self.all_channels[l] + self.all_channels[0], self.all_channels[l + 1], 1)), self.nlactv )) else: self.fc_list.append(nn.Sequential( nn.Conv1d(self.all_channels[l] + self.all_channels[0], self.all_channels[l + 1], 1), self.nlactv )) self.all_channels[l] += self.all_channels[0] else: if norm == 'weight': # print('layer %d weight normalization in fusion mlp' % l) self.fc_list.append(nn.Sequential( nn.utils.weight_norm(nn.Conv1d(self.all_channels[l], self.all_channels[l + 1], 1)), self.nlactv )) else: self.fc_list.append(nn.Sequential( nn.Conv1d(self.all_channels[l], self.all_channels[l + 1], 1), self.nlactv )) self.fc_list.append(nn.Conv1d(self.all_channels[-2], out_channels, 1)) if init_last_layer: self.fc_list[-1].apply(init_out_weights) if last_op == 'sigmoid': self.last_op = nn.Sigmoid() elif last_op == 'tanh': self.last_op = nn.Tanh() else: self.last_op = None def forward(self, x, return_inter_layer = []): tmpx = x inter_feat_list = [] for i, fc in enumerate(self.fc_list): if i in self.res_layers: x = fc(torch.cat([x, tmpx], dim = 1)) else: x = fc(x) if i == len(self.fc_list) - 1 and self.last_op is not None: # last layer x = self.last_op(x) if i in return_inter_layer: inter_feat_list.append(x.clone()) if len(return_inter_layer) > 0: return x, inter_feat_list else: return x class MLPLinear(nn.Module): def __init__(self, in_channels, out_channels, inter_channels, res_layers = [], nlactv = nn.ReLU(), last_op = None): super(MLPLinear, self).__init__() self.fc_list = nn.ModuleList() self.all_channels = [in_channels] + inter_channels + [out_channels] self.res_layers = res_layers self.nlactv = nlactv self.last_op = last_op for l in range(0, len(self.all_channels) - 2): if l in self.res_layers: self.all_channels[l] += in_channels self.fc_list.append( nn.Sequential( nn.Linear(self.all_channels[l], self.all_channels[l + 1]), self.nlactv ) ) self.fc_list.append(nn.Linear(self.all_channels[-2], self.all_channels[-1])) def forward(self, x): tmpx = x for i, layer in enumerate(self.fc_list): if i in self.res_layers: x = torch.cat([x, tmpx], dim = -1) x = layer(x) if self.last_op is not None: x = self.last_op(x) return x def parallel_concat(tensors: list, n_parallel_group: int): """ :param tensors: list of tensors, each of which has a shape of [B, G*C, N] :param n_parallel_group: :return: [B, G*C', N] """ batch_size = tensors[0].shape[0] point_num = tensors[0].shape[-1] assert all([t.shape[0] == batch_size for t in tensors]), 'All tensors should have the same batch size' assert all([t.shape[2] == point_num for t in tensors]), 'All tensors should have the same point num' assert all([t.shape[1] % n_parallel_group==0 for t in tensors]), 'Invalid tensor channels' tensors_ = [ t.reshape(batch_size, n_parallel_group, -1, point_num) for t in tensors ] concated = torch.cat(tensors_, dim=2) concated = concated.reshape(batch_size, -1, point_num) return concated class ParallelMLP(nn.Module): def __init__(self, in_channels, out_channels, group_num, inter_channels, res_layers = [], nlactv = nn.ReLU(), last_op = None): super(ParallelMLP, self).__init__() self.fc_list = nn.ModuleList() self.all_channels = [in_channels] + inter_channels + [out_channels] self.group_num = group_num self.res_layers = res_layers self.nlactv = nlactv self.last_op = last_op for l in range(0, len(self.all_channels) - 2): if l in self.res_layers: self.all_channels[l] += in_channels self.fc_list.append( nn.Sequential( nn.Conv1d(self.all_channels[l] * self.group_num, self.all_channels[l + 1] * self.group_num, 1, groups = self.group_num), self.nlactv ) ) self.fc_list.append(nn.Conv1d(self.all_channels[-2] * self.group_num, self.all_channels[-1] * self.group_num, 1, groups = self.group_num)) def forward(self, x): """ :param x: (batch_size, group_num, point_num, in_channels) :return: (batch_size, group_num, point_num, out_channels) """ assert len(x.shape) == 4, 'input tensor should be a shape of [B, G, N, C]' assert x.shape[1] == self.group_num, 'input tensor should have %d parallel groups, but it has %s' % (self.group_num, x.shape[1]) B, G, N, C = x.shape x = x.permute(0, 1, 3, 2).reshape(B, G * C, N) tmpx = x for i, layer in enumerate(self.fc_list): if i in self.res_layers: x = parallel_concat([x, tmpx], G) x = layer(x) if self.last_op is not None: x = self.last_op(x) x = x.view(B, G, -1, N).permute(0, 1, 3, 2) return x class SdfMLP(MLPLinear): def __init__(self, in_channels, out_channels, inter_channels, res_layers = [], nlactv = nn.Softplus(beta = 100), geometric_init = True, bias = 0.5, weight_norm = True ): super(SdfMLP, self).__init__(in_channels, out_channels, inter_channels, res_layers, nlactv, None) for l, layer in enumerate(self.fc_list): if isinstance(layer, nn.Sequential): lin = layer[0] elif isinstance(layer, nn.Linear): lin = layer else: raise TypeError('Invalid %d layer' % l) if geometric_init: in_dim, out_dim = lin.in_features, lin.out_features if l == len(self.fc_list) - 1: torch.nn.init.normal_(lin.weight, mean = np.sqrt(np.pi) / np.sqrt(in_dim), std = 0.0001) torch.nn.init.constant_(lin.bias, -bias) elif l == 0: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.weight[:, 3:], 0.0) torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) elif l in self.res_layers: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) torch.nn.init.constant_(lin.weight[:, -(in_channels - 3):], 0.0) else: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) if weight_norm: if isinstance(layer, nn.Sequential): layer[0] = nn.utils.weight_norm(lin) elif isinstance(layer, nn.Linear): layer = nn.utils.weight_norm(lin) class OffsetDecoder(nn.Module): """ Same architecture with ShapeDecoder in POP (https://github.com/qianlim/POP). """ def __init__(self, in_size, hsize = 256, actv_fn='softplus'): self.hsize = hsize super(OffsetDecoder, self).__init__() self.conv1 = torch.nn.Conv1d(in_size, self.hsize, 1) self.conv2 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv3 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv4 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv5 = torch.nn.Conv1d(self.hsize+in_size, self.hsize, 1) self.conv6 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv7 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv8 = torch.nn.Conv1d(self.hsize, 3, 1) nn.init.uniform_(self.conv8.weight, -1e-5, 1e-5) nn.init.constant_(self.conv8.bias, 0.) self.bn1 = torch.nn.BatchNorm1d(self.hsize) self.bn2 = torch.nn.BatchNorm1d(self.hsize) self.bn3 = torch.nn.BatchNorm1d(self.hsize) self.bn4 = torch.nn.BatchNorm1d(self.hsize) self.bn5 = torch.nn.BatchNorm1d(self.hsize) self.bn6 = torch.nn.BatchNorm1d(self.hsize) self.bn7 = torch.nn.BatchNorm1d(self.hsize) self.actv_fn = nn.ReLU() if actv_fn=='relu' else nn.Softplus() def forward(self, x): x1 = self.actv_fn(self.bn1(self.conv1(x))) x2 = self.actv_fn(self.bn2(self.conv2(x1))) x3 = self.actv_fn(self.bn3(self.conv3(x2))) x4 = self.actv_fn(self.bn4(self.conv4(x3))) x5 = self.actv_fn(self.bn5(self.conv5(torch.cat([x,x4],dim=1)))) # position pred x6 = self.actv_fn(self.bn6(self.conv6(x5))) x7 = self.actv_fn(self.bn7(self.conv7(x6))) x8 = self.conv8(x7) return x8 def forward_wo_bn(self, x): x1 = self.actv_fn(self.conv1(x)) x2 = self.actv_fn(self.conv2(x1)) x3 = self.actv_fn(self.conv3(x2)) x4 = self.actv_fn(self.conv4(x3)) x5 = self.actv_fn(self.conv5(torch.cat([x,x4],dim=1))) # position pred x6 = self.actv_fn(self.conv6(x5)) x7 = self.actv_fn(self.conv7(x6)) x8 = self.conv8(x7) return x8 class ShapeDecoder(nn.Module): ''' The "Shape Decoder" in the POP paper Fig. 2. The same as the "shared MLP" in the SCALE paper. - with skip connection from the input features to the 4th layer's output features (like DeepSDF) - branches out at the second-to-last layer, one branch for position pred, one for normal pred ''' def __init__(self, in_size, hsize = 256, actv_fn='softplus'): self.hsize = hsize super(ShapeDecoder, self).__init__() self.conv1 = torch.nn.Conv1d(in_size, self.hsize, 1) self.conv2 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv3 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv4 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv5 = torch.nn.Conv1d(self.hsize+in_size, self.hsize, 1) self.conv6 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv7 = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv8 = torch.nn.Conv1d(self.hsize, 3, 1) self.conv6N = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv7N = torch.nn.Conv1d(self.hsize, self.hsize, 1) self.conv8N = torch.nn.Conv1d(self.hsize, 3, 1) self.bn1 = torch.nn.BatchNorm1d(self.hsize) self.bn2 = torch.nn.BatchNorm1d(self.hsize) self.bn3 = torch.nn.BatchNorm1d(self.hsize) self.bn4 = torch.nn.BatchNorm1d(self.hsize) self.bn5 = torch.nn.BatchNorm1d(self.hsize) self.bn6 = torch.nn.BatchNorm1d(self.hsize) self.bn7 = torch.nn.BatchNorm1d(self.hsize) self.bn6N = torch.nn.BatchNorm1d(self.hsize) self.bn7N = torch.nn.BatchNorm1d(self.hsize) self.actv_fn = nn.ReLU() if actv_fn=='relu' else nn.Softplus() # init last layer nn.init.uniform_(self.conv8.weight, -1e-5, 1e-5) nn.init.constant_(self.conv8.bias, 0) def forward(self, x): x1 = self.actv_fn(self.bn1(self.conv1(x))) x2 = self.actv_fn(self.bn2(self.conv2(x1))) x3 = self.actv_fn(self.bn3(self.conv3(x2))) x4 = self.actv_fn(self.bn4(self.conv4(x3))) x5 = self.actv_fn(self.bn5(self.conv5(torch.cat([x,x4],dim=1)))) # position pred x6 = self.actv_fn(self.bn6(self.conv6(x5))) x7 = self.actv_fn(self.bn7(self.conv7(x6))) x8 = self.conv8(x7) # normals pred xN6 = self.actv_fn(self.bn6N(self.conv6N(x5))) xN7 = self.actv_fn(self.bn7N(self.conv7N(xN6))) xN8 = self.conv8N(xN7) return x8, xN8 class MLPwoWeight(object): def __init__(self, in_channels, out_channels, inter_channels, res_layers = [], nlactv = nn.ReLU(), last_op = None): super(MLPwoWeight, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.all_channels = [in_channels] + inter_channels + [out_channels] self.res_layers = res_layers self.nlactv = nlactv self.last_op = last_op self.param_num = 0 for i in range(len(self.all_channels) - 1): in_ch = self.all_channels[i] if i in self.res_layers: in_ch += self.in_channels out_ch = self.all_channels[i + 1] self.param_num += (in_ch * out_ch + out_ch) self.param_num_per_group = self.param_num def forward(self, x, params): """ :param x: (batch_size, point_num, in_channels) :param params: (param_num, ) :return: (batch_size, point_num, out_channels) """ x = x.permute(0, 2, 1) # (B, C, N) tmpx = x param_id = 0 for i in range(len(self.all_channels) - 1): in_ch = self.all_channels[i] if i in self.res_layers: in_ch += self.in_channels x = torch.cat([x, tmpx], 1) out_ch = self.all_channels[i + 1] weight_len = out_ch * in_ch weight = params[param_id: param_id + weight_len].reshape(out_ch, in_ch, 1) param_id += weight_len bias_len = out_ch bias = params[param_id: param_id + bias_len] param_id += bias_len x = F.conv1d(x, weight, bias) if i < len(self.all_channels) - 2: x = self.nlactv(x) if self.last_op is not None: x = self.last_op(x) return x.permute(0, 2, 1) def __repr__(self): main_str = self.__class__.__name__ + '(\n' for i in range(len(self.all_channels) - 1): main_str += '\tF.conv1d(in_features=%d, out_features=%d, bias=True)\n' % (self.all_channels[i], self.all_channels[i + 1]) main_str += '\tnlactv: %s\n' % self.nlactv.__repr__() main_str += ')' return main_str class ParallelMLPwoWeight(object): def __init__(self, in_channels, out_channels, inter_channels, group_num = 1, res_layers = [], nlactv = nn.ReLU(), last_op = None): super(ParallelMLPwoWeight, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.all_channels = [in_channels] + inter_channels + [out_channels] self.res_layers = res_layers self.group_num = group_num self.nlactv = nlactv self.last_op = last_op self.param_num = 0 for i in range(len(self.all_channels) - 1): in_ch = self.all_channels[i] if i in self.res_layers: in_ch += self.in_channels out_ch = self.all_channels[i + 1] self.param_num += (in_ch * out_ch + out_ch) * self.group_num self.param_num_per_group = self.param_num // self.group_num def forward(self, x, params): """ :param x: (batch_size, group_num, point_num, in_channels) :param params: (group_num, param_num) :return: (batch_size, group_num, point_num, out_channels) """ batch_size, group_num, point_num, in_channels = x.shape assert group_num == self.group_num and in_channels == self.in_channels x = x.permute(0, 1, 3, 2) # (B, G, C, N) x = x.reshape(batch_size, group_num * in_channels, point_num) tmpx = x param_id = 0 for i in range(len(self.all_channels) - 1): in_ch = self.all_channels[i] if i in self.res_layers: in_ch += self.in_channels x = parallel_concat([x, tmpx], group_num) out_ch = self.all_channels[i + 1] weight_len = out_ch * in_ch weight = params[:, param_id: param_id + weight_len].reshape(group_num * out_ch, in_ch, 1) param_id += weight_len bias_len = out_ch bias = params[:, param_id: param_id + bias_len].reshape(group_num * out_ch) param_id += bias_len x = F.conv1d(x, weight, bias, groups = group_num) if i < len(self.all_channels) - 2: x = self.nlactv(x) if self.last_op is not None: x = self.last_op(x) x = x.reshape(batch_size, group_num, self.out_channels, point_num) return x.permute(0, 1, 3, 2) def __repr__(self): main_str = self.__class__.__name__ + '(\n' main_str += '\tgroup_num: %d\n' % self.group_num for i in range(len(self.all_channels) - 1): main_str += '\tF.conv1d(in_features=%d, out_features=%d, bias=True)\n' % (self.all_channels[i], self.all_channels[i + 1]) main_str += '\tnlactv: %s\n' % self.nlactv.__repr__() main_str += ')' return main_str