#!/usr/bin/env python
# -*- coding: utf-8 -*-
# some code from https://raw.githubusercontent.com/weigq/3d_pose_baseline_pytorch/master/src/model.py

from __future__ import absolute_import
from __future__ import print_function

import torch
import torch.nn as nn
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
# only needed if structure_pose_net='vae' is used below:
# from priors.vae_pose_model.vae_model import VAEmodel
from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior


def weight_init_dangerous(m):
    # this is dangerous as it may overwrite the normalizing flow weights
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)


class Linear(nn.Module):
    # residual block: two fully connected layers, each followed by batch norm, relu and dropout
    def __init__(self, linear_size, p_dropout=0.5):
        super(Linear, self).__init__()
        self.l_size = linear_size

        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p_dropout)

        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.batch_norm1 = nn.BatchNorm1d(self.l_size)

        self.w2 = nn.Linear(self.l_size, self.l_size)
        self.batch_norm2 = nn.BatchNorm1d(self.l_size)

    def forward(self, x):
        y = self.w1(x)
        y = self.batch_norm1(y)
        y = self.relu(y)
        y = self.dropout(y)

        y = self.w2(y)
        y = self.batch_norm2(y)
        y = self.relu(y)
        y = self.dropout(y)

        out = x + y
        return out


class LinearModel(nn.Module):
    def __init__(self,
                 linear_size=1024,
                 num_stage=2,
                 p_dropout=0.5,
                 input_size=16*2,
                 output_size=16*3):
        super(LinearModel, self).__init__()

        self.linear_size = linear_size
        self.p_dropout = p_dropout
        self.num_stage = num_stage
        # input
        self.input_size = input_size        # 2d joints: 16 * 2
        # output
        self.output_size = output_size      # 3d joints: 16 * 3

        # process input to linear size
        self.w1 = nn.Linear(self.input_size, self.linear_size)
        self.batch_norm1 = nn.BatchNorm1d(self.linear_size)

        self.linear_stages = []
        for l in range(num_stage):
            self.linear_stages.append(Linear(self.linear_size, self.p_dropout))
        self.linear_stages = nn.ModuleList(self.linear_stages)

        # post-processing
        self.w2 = nn.Linear(self.linear_size, self.output_size)

        # helpers (relu and dropout)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(self.p_dropout)

    def forward(self, x):
        # pre-processing
        y = self.w1(x)
        y = self.batch_norm1(y)
        y = self.relu(y)
        y = self.dropout(y)
        # linear layers
        for i in range(self.num_stage):
            y = self.linear_stages[i](y)
        # post-processing
        y = self.w2(y)
        return y
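# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): a minimal forward pass
# through LinearModel with its default sizes (16 2d joints in, 16 3d joints
# out). The batch size of 4 and the zero-filled input are arbitrary values
# chosen for demonstration only.
# -----------------------------------------------------------------------------
def _example_linear_model():
    model = LinearModel(linear_size=1024, num_stage=2, p_dropout=0.5,
                        input_size=16*2, output_size=16*3)
    model.eval()    # eval mode so BatchNorm1d uses its running statistics
    dummy_keypoints_2d = torch.zeros((4, 16*2))     # batch of 4 flattened 2d poses
    with torch.no_grad():
        pred_3d = model(dummy_keypoints_2d)
    assert pred_3d.shape == (4, 16*3)
    return pred_3d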
class LinearModelComplete(nn.Module):
    def __init__(self,
                 linear_size=1024,
                 num_stage_comb=2,
                 num_stage_heads=1,
                 num_stage_heads_pose=1,
                 trans_sep=False,
                 p_dropout=0.5,
                 input_size=16*2,
                 intermediate_size=1024,
                 output_info=None,
                 n_joints=25,
                 n_z=512,
                 add_z_to_3d_input=False,
                 n_segbps=64*2,
                 add_segbps_to_3d_input=False,
                 structure_pose_net='default',
                 fix_vae_weights=True,
                 nf_version=None):      # 0): n_silh_enc
        super(LinearModelComplete, self).__init__()
        if add_z_to_3d_input:
            self.n_z_to_add = n_z               # 512
        else:
            self.n_z_to_add = 0
        if add_segbps_to_3d_input:
            self.n_segbps_to_add = n_segbps     # 64
        else:
            self.n_segbps_to_add = 0
        self.input_size = input_size
        self.linear_size = linear_size
        self.p_dropout = p_dropout
        self.num_stage_comb = num_stage_comb
        self.num_stage_heads = num_stage_heads
        self.num_stage_heads_pose = num_stage_heads_pose
        self.trans_sep = trans_sep
        self.intermediate_size = intermediate_size
        self.structure_pose_net = structure_pose_net
        self.fix_vae_weights = fix_vae_weights      # only relevant if structure_pose_net='vae'
        self.nf_version = nf_version
        if output_info is None:
            pose = {'name': 'pose', 'n': n_joints*6, 'out_shape': [n_joints, 6]}
            cam = {'name': 'flength', 'n': 1}
            if self.trans_sep:
                translation_xy = {'name': 'trans_xy', 'n': 2}
                translation_z = {'name': 'trans_z', 'n': 1}
                self.output_info = [pose, translation_xy, translation_z, cam]
            else:
                translation = {'name': 'trans', 'n': 3}
                self.output_info = [pose, translation, cam]
            if self.structure_pose_net == 'vae' or self.structure_pose_net == 'normflow':
                global_pose = {'name': 'global_pose', 'n': 1*6, 'out_shape': [1, 6]}
                self.output_info.append(global_pose)
        else:
            self.output_info = output_info

        # combined stage: map the (optionally extended) input to an intermediate feature
        self.linear_combined = LinearModel(linear_size=self.linear_size,
                                           num_stage=self.num_stage_comb,
                                           p_dropout=p_dropout,
                                           input_size=self.input_size + self.n_segbps_to_add + self.n_z_to_add,
                                           output_size=self.intermediate_size)

        # one prediction head per output element
        self.output_info_linear_models = []
        for ind_el, element in enumerate(self.output_info):
            if element['name'] == 'pose':
                num_stage = self.num_stage_heads_pose
                if self.structure_pose_net == 'default':
                    output_size_pose_lin = element['n']
                elif self.structure_pose_net == 'vae':
                    # load vae decoder (requires the VAEmodel import at the top of this file)
                    self.pose_vae_model = VAEmodel()
                    self.pose_vae_model.initialize_with_pretrained_weights()
                    # define the input size of the vae decoder
                    output_size_pose_lin = self.pose_vae_model.latent_size
                elif self.structure_pose_net == 'normflow':
                    # the following will automatically be initialized
                    self.pose_normflow_model = NormalizingFlowPrior(nf_version=self.nf_version)
                    output_size_pose_lin = element['n'] - 6     # no global rotation
                else:
                    raise NotImplementedError
                self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size,
                                                                  num_stage=num_stage,
                                                                  p_dropout=p_dropout,
                                                                  input_size=self.intermediate_size,
                                                                  output_size=output_size_pose_lin))
            else:
                if element['name'] == 'global_pose':
                    num_stage = self.num_stage_heads_pose
                else:
                    num_stage = self.num_stage_heads
                self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size,
                                                                  num_stage=num_stage,
                                                                  p_dropout=p_dropout,
                                                                  input_size=self.intermediate_size,
                                                                  output_size=element['n']))
            element['linear_model_index'] = ind_el
        self.output_info_linear_models = nn.ModuleList(self.output_info_linear_models)

    def forward(self, x):
        device = x.device
        # combined stage
        if x.shape[1] == self.input_size + self.n_segbps_to_add + self.n_z_to_add:
            y = self.linear_combined(x)
        elif x.shape[1] == self.input_size + self.n_segbps_to_add:
            # the noise vector z is missing -> sample it from a standard normal distribution
            x_mod = torch.cat((x, torch.normal(0, 1, size=(x.shape[0], self.n_z_to_add)).to(device)), dim=1)
            y = self.linear_combined(x_mod)
        else:
            print(x.shape)
            print(self.input_size)
            print(self.n_segbps_to_add)
            print(self.n_z_to_add)
            raise ValueError
        # heads
        results = {}
        results_trans = {}
        for element in self.output_info:
            linear_model = self.output_info_linear_models[element['linear_model_index']]
            if element['name'] == 'pose':
                if self.structure_pose_net == 'default':
                    results['pose'] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1]))
                    normflow_z = None
                elif self.structure_pose_net == 'vae':
                    res_lin = linear_model(y)
                    if self.fix_vae_weights:
                        self.pose_vae_model.requires_grad_(False)   # let gradients flow through but don't update the parameters
                        res_vae = self.pose_vae_model.inference(feat=res_lin)
                        self.pose_vae_model.requires_grad_(True)
                    else:
                        res_vae = self.pose_vae_model.inference(feat=res_lin)
                    res_pose_not_glob = res_vae.reshape((-1, element['out_shape'][0], element['out_shape'][1]))
                    normflow_z = None
                elif self.structure_pose_net == 'normflow':
                    normflow_z = linear_model(y)*0.1
                    self.pose_normflow_model.requires_grad_(False)  # let gradients flow through but don't update the parameters
                    res_pose_not_glob = self.pose_normflow_model.run_backwards(z=normflow_z).reshape((-1, element['out_shape'][0]-1, element['out_shape'][1]))
                else:
                    raise NotImplementedError
            elif element['name'] == 'global_pose':
                res_pose_glob = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1]))
            elif element['name'] == 'trans_xy' or element['name'] == 'trans_z':
                results_trans[element['name']] = linear_model(y)
            else:
                results[element['name']] = linear_model(y)
        if self.trans_sep:
            results['trans'] = torch.cat((results_trans['trans_xy'], results_trans['trans_z']), dim=1)
        # prepare pose including global rotation
        if self.structure_pose_net == 'vae':
            # results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob), dim=1)
            results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, 1:, :]), dim=1)
        elif self.structure_pose_net == 'normflow':
            results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, :, :]), dim=1)
        # return a dictionary which contains all results
        results['normflow_z'] = normflow_z
        return results      # this is a dictionary
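# -----------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): with output_info=None
# and structure_pose_net='default', the forward pass returns a dictionary with
# the keys 'pose' (batch_size x n_joints x 6), 'trans' (batch_size x 3),
# 'flength' (batch_size x 1) and 'normflow_z' (None in this configuration).
# The dummy input below assumes the default input_size of 16*2.
# -----------------------------------------------------------------------------
def _example_linear_model_complete():
    model = LinearModelComplete(input_size=16*2, n_joints=25,
                                structure_pose_net='default')
    model.eval()
    dummy_input = torch.zeros((4, 16*2))
    with torch.no_grad():
        results = model(dummy_input)
    assert results['pose'].shape == (4, 25, 6)
    assert results['trans'].shape == (4, 3)
    assert results['flength'].shape == (4, 1)
    return results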
# ------------------------------------------
# for pretraining of the 3d model only:
# (see combined_model/model_shape_v2.py)
class Wrapper_LinearModelComplete(nn.Module):
    def __init__(self,
                 linear_size=1024,
                 num_stage_comb=2,
                 num_stage_heads=1,
                 num_stage_heads_pose=1,
                 trans_sep=False,
                 p_dropout=0.5,
                 input_size=16*2,
                 intermediate_size=1024,
                 output_info=None,
                 n_joints=25,
                 n_z=512,
                 add_z_to_3d_input=False,
                 n_segbps=64*2,
                 add_segbps_to_3d_input=False,
                 structure_pose_net='default',
                 fix_vae_weights=True,
                 nf_version=None):
        super(Wrapper_LinearModelComplete, self).__init__()
        self.add_segbps_to_3d_input = add_segbps_to_3d_input
        self.model_3d = LinearModelComplete(linear_size=linear_size,
                                            num_stage_comb=num_stage_comb,
                                            num_stage_heads=num_stage_heads,
                                            num_stage_heads_pose=num_stage_heads_pose,
                                            trans_sep=trans_sep,
                                            p_dropout=p_dropout,        # 0.5,
                                            input_size=input_size,
                                            intermediate_size=intermediate_size,
                                            output_info=output_info,
                                            n_joints=n_joints,
                                            n_z=n_z,
                                            add_z_to_3d_input=add_z_to_3d_input,
                                            n_segbps=n_segbps,
                                            add_segbps_to_3d_input=add_segbps_to_3d_input,
                                            structure_pose_net=structure_pose_net,
                                            fix_vae_weights=fix_vae_weights,
                                            nf_version=nf_version)

    def forward(self, input_vec):
        # input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
        # predict 3d parameters (they are normalized; mean and std are corrected in a subsequent step)
        output = self.model_3d(input_vec)
        return output
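# -----------------------------------------------------------------------------
# Minimal smoke test (assumption, not part of the original training code):
# instantiate the wrapper with its default configuration and print the shape of
# every predicted quantity for a dummy batch of two samples. This assumes the
# repo's 'priors' package is importable (see the sys.path insert at the top).
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    wrapper = Wrapper_LinearModelComplete(input_size=16*2, n_joints=25)
    wrapper.eval()
    dummy_input_vec = torch.zeros((2, 16*2))
    with torch.no_grad():
        predictions = wrapper(dummy_input_vec)
    for name, value in predictions.items():
        print(name, None if value is None else tuple(value.shape))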