from torch import distributions import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.distributions import Normal import numpy as np import cv2 import trimesh from tqdm import tqdm import warnings warnings.filterwarnings("ignore", category=DeprecationWarning) import FrEIA.framework as Ff import FrEIA.modules as Fm class INNForShape(nn.Module): def __init__(self, n_betas, n_betas_limbs, k_tot=2, betas_scale=1.0, betas_limbs_scale=0.1): super(INNForShape, self).__init__() self.n_betas = n_betas self.n_betas_limbs = n_betas_limbs self.n_dim = n_betas + n_betas_limbs self.betas_scale = betas_scale self.betas_limbs_scale = betas_limbs_scale self.k_tot = 2 self.model_inn = self.build_inn_network(self.n_dim, k_tot=self.k_tot) def subnet_fc(self, c_in, c_out): subnet = nn.Sequential(nn.Linear(c_in, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, c_out)) return subnet def build_inn_network(self, n_input, k_tot=12, verbose=False): coupling_block = Fm.RNVPCouplingBlock nodes = [Ff.InputNode(n_input, name='input')] for k in range(k_tot): nodes.append(Ff.Node(nodes[-1], coupling_block, {'subnet_constructor':self.subnet_fc, 'clamp':2.0}, name=F'coupling_{k}')) nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}, name=F'permute_{k}')) nodes.append(Ff.OutputNode(nodes[-1], name='output')) model = Ff.ReversibleGraphNet(nodes, verbose=verbose) return model def forward(self, latent_rep): shape, _ = self.model_inn(latent_rep, rev=False, jac=False) betas = shape[:, :self.n_betas]*self.betas_scale betas_limbs = shape[:, self.n_betas:]*self.betas_limbs_scale return betas, betas_limbs def reverse(self, betas, betas_limbs): shape = torch.cat((betas/self.betas_scale, betas_limbs/self.betas_limbs_scale), dim=1) latent_rep, _ = self.model_inn(shape, rev=True, jac=False) return latent_rep