File size: 2,372 Bytes
7629b39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61


from torch import distributions
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributions import Normal 
import numpy as np
import cv2
import trimesh
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
import FrEIA.framework as Ff
import FrEIA.modules as Fm


class INNForShape(nn.Module):
    """Invertible network mapping a latent vector to shape coefficients and back.

    The underlying model is a FrEIA ``ReversibleGraphNet`` built from
    ``k_tot`` stages, each an ``RNVPCouplingBlock`` (using :meth:`subnet_fc`
    as its subnet) followed by a seeded ``PermuteRandom`` layer. The network
    input/output width is ``n_betas + n_betas_limbs``.

    In :meth:`forward`, the first ``n_betas`` output columns are multiplied by
    ``betas_scale`` and the remaining columns by ``betas_limbs_scale``.
    :meth:`reverse` undoes those scalings before running the network backwards.

    Args:
        n_betas: Number of output columns returned as ``betas``.
        n_betas_limbs: Number of output columns returned as ``betas_limbs``.
        k_tot: Number of coupling + permutation stages in the network.
        betas_scale: Multiplier applied to the ``betas`` part of the output.
        betas_limbs_scale: Multiplier applied to the ``betas_limbs`` part.
    """

    def __init__(self, n_betas, n_betas_limbs, k_tot=2, betas_scale=1.0, betas_limbs_scale=0.1):
        super().__init__()
        self.n_betas = n_betas
        self.n_betas_limbs = n_betas_limbs
        self.n_dim = n_betas + n_betas_limbs
        self.betas_scale = betas_scale
        self.betas_limbs_scale = betas_limbs_scale
        # Use the caller's k_tot (it was previously hard-coded to 2). The
        # number of stages determines which coupling_/permute_ submodules
        # exist, so it must match the value used when a checkpoint was saved.
        self.k_tot = k_tot
        self.model_inn = self.build_inn_network(self.n_dim, k_tot=self.k_tot)

    def subnet_fc(self, c_in, c_out):
        """Build the fully connected subnet used inside each coupling block.

        Architecture: ``Linear(c_in, 64) -> ReLU -> Linear(64, 64) -> ReLU ->
        Linear(64, c_out)``. Passed to FrEIA as ``subnet_constructor``.

        Args:
            c_in: Input feature width.
            c_out: Output feature width.

        Returns:
            An ``nn.Sequential`` implementing the MLP above.
        """
        hidden = 64
        return nn.Sequential(
            nn.Linear(c_in, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, c_out),
        )

    def build_inn_network(self, n_input, k_tot=12, verbose=False):
        """Assemble the invertible FrEIA graph.

        The graph is: input node, then ``k_tot`` repetitions of
        (``RNVPCouplingBlock`` with ``clamp=2.0`` -> ``PermuteRandom`` seeded
        with the stage index), then an output node. Seeding the permutation
        with the stage index makes the permutations reproducible across runs.

        Note: this method's own default ``k_tot=12`` differs from the
        constructor's default of 2; ``__init__`` always passes ``k_tot``
        explicitly.

        Args:
            n_input: Width of the input (and output) vectors.
            k_tot: Number of coupling + permutation stages.
            verbose: Forwarded to ``Ff.ReversibleGraphNet``.

        Returns:
            The constructed ``Ff.ReversibleGraphNet``.
        """
        coupling_block = Fm.RNVPCouplingBlock
        nodes = [Ff.InputNode(n_input, name='input')]
        for k in range(k_tot):
            nodes.append(Ff.Node(nodes[-1],
                                 coupling_block,
                                 {'subnet_constructor': self.subnet_fc, 'clamp': 2.0},
                                 name=f'coupling_{k}'))
            nodes.append(Ff.Node(nodes[-1],
                                 Fm.PermuteRandom,
                                 {'seed': k},
                                 name=f'permute_{k}'))
        nodes.append(Ff.OutputNode(nodes[-1], name='output'))
        return Ff.ReversibleGraphNet(nodes, verbose=verbose)

    def forward(self, latent_rep):
        """Map a latent representation to scaled shape coefficients.

        Args:
            latent_rep: Tensor indexed as ``[batch, feature]`` (at least 2-D,
                given the column slicing below) with ``n_dim`` feature columns
                — TODO confirm exact shape expected by the caller.

        Returns:
            Tuple ``(betas, betas_limbs)``: the first ``n_betas`` output
            columns times ``betas_scale``, and the remaining columns times
            ``betas_limbs_scale``.
        """
        shape, _ = self.model_inn(latent_rep, rev=False, jac=False)
        betas = shape[:, :self.n_betas] * self.betas_scale
        betas_limbs = shape[:, self.n_betas:] * self.betas_limbs_scale
        return betas, betas_limbs

    def reverse(self, betas, betas_limbs):
        """Invert :meth:`forward`: recover the latent representation.

        Args:
            betas: Tensor of scaled ``betas`` columns (concatenated on dim 1).
            betas_limbs: Tensor of scaled ``betas_limbs`` columns.

        Returns:
            The latent representation produced by running the network with
            ``rev=True`` on the unscaled, concatenated coefficients.
        """
        # Undo the per-part scaling applied in forward() before inverting.
        shape = torch.cat((betas / self.betas_scale, betas_limbs / self.betas_limbs_scale), dim=1)
        latent_rep, _ = self.model_inn(shape, rev=True, jac=False)
        return latent_rep