import torch
import torch.nn as nn
from torch.distributions import Normal
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
import FrEIA.framework as Ff
import FrEIA.modules as Fm

from configs.barc_cfg_defaults import get_cfg_global_updated


class NormalizingFlowPrior(nn.Module):
    def __init__(self, nf_version=None):
        super(NormalizingFlowPrior, self).__init__()
        # The normalizing flow network takes as input a vector of size (35-1)*6, i.e.
        # [all joints except the root joint] * [6D rotation representation]. The 6D
        # rotation representation is not ideal for this purpose, but in practice the
        # results are acceptable.
        n_dim = (35 - 1) * 6
        self.param_dict = self.get_version_param_dict(nf_version)
        self.model_inn = self.build_inn_network(n_dim, k_tot=self.param_dict['k_tot'])
        self.initialize_with_pretrained_weights()

    def get_version_param_dict(self, nf_version):
        # We trained several versions of the normalizing flow pose prior; here we only
        # provide the option that was used for the CVPR 2022 paper (nf_version=3).
        if nf_version == 3:
            param_dict = {
                'k_tot': 2,
                'path_pretrained': get_cfg_global_updated().paths.MODELPATH_NORMFLOW,
                'subnet_fc_type': '3_64'}
        else:
            raise ValueError('unknown nf_version: {}'.format(nf_version))
        return param_dict

    def initialize_with_pretrained_weights(self, weight_path=None):
        # The normalizing flow pose prior is pretrained separately; afterwards all
        # weights are kept fixed. Here we load those pretrained weights.
        if weight_path is None:
            weight_path = self.param_dict['path_pretrained']
        print(' normalizing flow pose prior: loading {}..'.format(weight_path))
        pretrained_dict = torch.load(weight_path)['model_state_dict']
        self.model_inn.load_state_dict(pretrained_dict, strict=True)

    def subnet_fc(self, c_in, c_out):
        # Fully connected subnetwork used inside each coupling block.
        if self.param_dict['subnet_fc_type'] == '3_512':
            subnet = nn.Sequential(nn.Linear(c_in, 512), nn.ReLU(),
                                   nn.Linear(512, 512), nn.ReLU(),
                                   nn.Linear(512, c_out))
        elif self.param_dict['subnet_fc_type'] == '3_64':
            subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(),
                                   nn.Linear(64, 64), nn.ReLU(),
                                   nn.Linear(64, c_out))
        else:
            raise ValueError('unknown subnet_fc_type: {}'.format(self.param_dict['subnet_fc_type']))
        return subnet

    def build_inn_network(self, n_input, k_tot=12, verbose=False):
        # Stack k_tot RealNVP coupling blocks, each followed by a fixed random
        # permutation so that all dimensions get mixed across blocks.
        coupling_block = Fm.RNVPCouplingBlock
        nodes = [Ff.InputNode(n_input, name='input')]
        for k in range(k_tot):
            nodes.append(Ff.Node(nodes[-1], coupling_block,
                                 {'subnet_constructor': self.subnet_fc, 'clamp': 2.0},
                                 name=f'coupling_{k}'))
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k},
                                 name=f'permute_{k}'))
        nodes.append(Ff.OutputNode(nodes[-1], name='output'))
        model = Ff.ReversibleGraphNet(nodes, verbose=verbose)
        return model

    def calculate_loss_from_z(self, z, type='square'):
        # 'square' penalizes the mean squared latent code; 'neg_log_prob' is the
        # exact negative log-likelihood under the standard normal base distribution.
        # The two differ only by a factor of 0.5 and an additive constant, so they
        # induce gradients in the same direction.
        assert type in ['square', 'neg_log_prob']
        if type == 'square':
            loss = (z**2).mean()    # * 0.00001
        elif type == 'neg_log_prob':
            means = torch.zeros((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device)
            stds = torch.ones((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device)
            normal_distribution = Normal(means, stds)
            log_prob = normal_distribution.log_prob(z)
            loss = - log_prob.mean()
        return loss

    def calculate_loss(self, poses_rot6d, type='square'):
        assert type in ['square', 'neg_log_prob']
        poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6))
        z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False)
        loss = self.calculate_loss_from_z(z, type=type)
        return loss

    def forward(self, poses_rot6d):
        # From pose to latent pose representation z.
        # poses_rot6d has shape (bs, 35, 6); the root joint (index 0) is dropped
        # before the flow is applied.
        poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6))
        z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False)
        return z

    def run_backwards(self, z):
        # From latent pose representation z back to a pose (without the root joint).
        poses_rot6d_noglob, _ = self.model_inn(z, rev=True, jac=False)
        return poses_rot6d_noglob
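

# ----------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original module). It
# assumes the pretrained weights referenced by MODELPATH_NORMFLOW exist on
# disk; the batch size and random pose tensor below are placeholders only.
if __name__ == '__main__':
    prior = NormalizingFlowPrior(nf_version=3)
    prior.eval()

    # Dummy batch of 6D joint rotations: (batch, 35 joints including root, 6).
    # forward() strips the root joint before feeding the flow.
    poses_rot6d = torch.randn(4, 35, 6)

    with torch.no_grad():
        # Pose -> latent code z of shape (batch, (35-1)*6).
        z = prior(poses_rot6d)
        # Prior losses, e.g. to regularize pose estimates during optimization.
        loss_sq = prior.calculate_loss(poses_rot6d, type='square')
        loss_nll = prior.calculate_loss(poses_rot6d, type='neg_log_prob')
        # Latent -> pose: decode a standard-normal sample into a plausible
        # pose for the 34 non-root joints.
        z_sample = torch.randn(4, (35 - 1) * 6)
        poses_sampled = prior.run_backwards(z_sample).reshape(4, 34, 6)

    print('z:', tuple(z.shape),
          'square loss:', float(loss_sq),
          'neg-log-prob loss:', float(loss_nll))
    print('sampled poses (no root joint):', tuple(poses_sampled.shape))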