File size: 4,906 Bytes
7629b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

from torch import distributions
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributions import Normal 
import numpy as np
import cv2
import trimesh
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
import FrEIA.framework as Ff
import FrEIA.modules as Fm
from configs.barc_cfg_defaults import get_cfg_global_updated


class NormalizingFlowPrior(nn.Module):
    def __init__(self, nf_version=None):
        super(NormalizingFlowPrior, self).__init__()
        # the normalizing flow network takes as input a vector of size (35-1)*6 which is 
        # [all joints except root joint]*6. At the moment the rotation is represented as 6D 
        # representation, which is actually not ideal. Nevertheless, in practice the 
        # results seem to be ok.
        n_dim = (35 - 1) * 6        
        self.param_dict = self.get_version_param_dict(nf_version)
        self.model_inn = self.build_inn_network(n_dim, k_tot=self.param_dict['k_tot']) 
        self.initialize_with_pretrained_weights()

    def get_version_param_dict(self, nf_version):
        # we had trained several version of the normalizing flow pose prior, here we just provide 
        # the option that was user for the cvpr 2022 paper (nf_version=3)
        if nf_version == 3:
            param_dict = {
                'k_tot': 2,
                'path_pretrained': get_cfg_global_updated().paths.MODELPATH_NORMFLOW,
                'subnet_fc_type': '3_64'}  
        else:
            print(nf_version)
            raise ValueError
        return param_dict

    def initialize_with_pretrained_weights(self, weight_path=None):
        # The normalizing flow pose prior is pretrained separately. Afterwards all weights 
        # are kept fixed. Here we load those pretrained weights.
        if weight_path is None:
            weight_path = self.param_dict['path_pretrained']
        print(' normalizing flow pose prior: loading {}..'.format(weight_path))
        pretrained_dict = torch.load(weight_path)['model_state_dict']
        self.model_inn.load_state_dict(pretrained_dict, strict=True)

    def subnet_fc(self, c_in, c_out):
        if self.param_dict['subnet_fc_type']=='3_512':
            subnet = nn.Sequential(nn.Linear(c_in, 512), nn.ReLU(),
                                nn.Linear(512, 512), nn.ReLU(),
                                nn.Linear(512,  c_out))
        elif self.param_dict['subnet_fc_type']=='3_64':
            subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(),
                                    nn.Linear(64, 64), nn.ReLU(),
                                    nn.Linear(64,  c_out))
        return subnet

    def build_inn_network(self, n_input, k_tot=12, verbose=False):
        coupling_block = Fm.RNVPCouplingBlock
        nodes = [Ff.InputNode(n_input, name='input')]
        for k in range(k_tot):
            nodes.append(Ff.Node(nodes[-1],
                                coupling_block,
                                {'subnet_constructor':self.subnet_fc, 'clamp':2.0},
                                name=F'coupling_{k}'))
            nodes.append(Ff.Node(nodes[-1],
                                Fm.PermuteRandom,
                                {'seed':k},
                                name=F'permute_{k}'))
        nodes.append(Ff.OutputNode(nodes[-1], name='output'))
        model = Ff.ReversibleGraphNet(nodes, verbose=verbose)
        return model

    def calculate_loss_from_z(self, z, type='square'):
        assert type in ['square', 'neg_log_prob']
        if type == 'square':
            loss = (z**2).mean()    # * 0.00001
        elif type == 'neg_log_prob':
            means = torch.zeros((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device)
            stds = torch.ones((z.shape[0], z.shape[1]), dtype=z.dtype, device=z.device)
            normal_distribution = Normal(means, stds)
            log_prob = normal_distribution.log_prob(z)
            loss = - log_prob.mean()
        return loss

    def calculate_loss(self, poses_rot6d, type='square'):
        assert type in ['square', 'neg_log_prob']
        poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6))     
        z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False)
        loss = self.calculate_loss_from_z(z, type=type)
        return loss

    def forward(self, poses_rot6d):
        # from pose to latent pose representation z 
        # poses_rot6d has shape (bs, 34, 6)
        poses_rot6d_noglob = poses_rot6d[:, 1:, :].reshape((-1, 34*6))   
        z, _ = self.model_inn(poses_rot6d_noglob, rev=False, jac=False)
        return z

    def run_backwards(self, z):
        # from latent pose representation z to pose 
        poses_rot6d_noglob, _ = self.model_inn(z, rev=True, jac=False)
        return poses_rot6d_noglob