Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# some code from https://raw.githubusercontent.com/weigq/3d_pose_baseline_pytorch/master/src/model.py | |
from __future__ import absolute_import | |
from __future__ import print_function | |
import torch | |
import torch.nn as nn | |
import os | |
import sys | |
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) | |
# from priors.vae_pose_model.vae_model import VAEmodel | |
from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior | |
def weight_init_dangerous(m):
    """Re-initialise the weight of an ``nn.Linear`` module with Kaiming-normal values.

    Takes a single module, so it can be used with ``nn.Module.apply``.
    Modules that are not ``nn.Linear`` are left untouched, and the bias of a
    linear layer is not modified.

    Warning: this is dangerous as it may overwrite the normalizing flow
    weights, because every ``nn.Linear`` it is applied to is re-initialised.

    Args:
        m: the module to (possibly) re-initialise in place.
    """
    if isinstance(m, nn.Linear):
        # Use the in-place initialiser; the un-suffixed ``kaiming_normal``
        # alias is deprecated and emits a warning.
        nn.init.kaiming_normal_(m.weight)
class Linear(nn.Module):
    """Residual block built from two fully connected steps.

    Each step is ``nn.Linear -> nn.BatchNorm1d -> ReLU -> Dropout`` and keeps
    the feature width unchanged; the block input is added to the result of
    the second step.

    Args:
        linear_size: number of input (and output) features.
        p_dropout: dropout probability used after each activation.
    """

    def __init__(self, linear_size, p_dropout=0.5):
        super(Linear, self).__init__()
        self.l_size = linear_size
        width = self.l_size

        # parameter-free helpers shared by both steps
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(p_dropout)

        # first step
        self.w1 = nn.Linear(width, width)
        self.batch_norm1 = nn.BatchNorm1d(width)

        # second step
        self.w2 = nn.Linear(width, width)
        self.batch_norm2 = nn.BatchNorm1d(width)

    def forward(self, x):
        """Return ``x`` plus the output of the two-step branch."""
        steps = (
            (self.w1, self.batch_norm1),
            (self.w2, self.batch_norm2),
        )
        branch = x
        for fully_connected, batch_norm in steps:
            branch = self.dropout(self.relu(batch_norm(fully_connected(branch))))
        # skip connection
        return x + branch
class LinearModel(nn.Module):
    """Fully connected network: input layer, residual stages, output layer.

    The input is mapped to ``linear_size`` features
    (``nn.Linear -> nn.BatchNorm1d -> ReLU -> Dropout``), passed through
    ``num_stage`` residual ``Linear`` blocks and finally projected to
    ``output_size`` features by a plain ``nn.Linear``.

    Args:
        linear_size: width of the hidden representation.
        num_stage: number of residual ``Linear`` blocks.
        p_dropout: dropout probability used throughout the network.
        input_size: number of input features (default 16 * 2, "2d joints").
        output_size: number of output features (default 16 * 3, "3d joints").
    """

    def __init__(self,
                 linear_size=1024,
                 num_stage=2,
                 p_dropout=0.5,
                 input_size=16*2,
                 output_size=16*3):
        super(LinearModel, self).__init__()

        # remember the configuration
        self.linear_size = linear_size
        self.p_dropout = p_dropout
        self.num_stage = num_stage
        self.input_size = input_size      # 2d joints: 16 * 2
        self.output_size = output_size    # 3d joints: 16 * 3

        # lift the input to the hidden width
        self.w1 = nn.Linear(self.input_size, self.linear_size)
        self.batch_norm1 = nn.BatchNorm1d(self.linear_size)

        # residual stages
        self.linear_stages = nn.ModuleList(
            [Linear(self.linear_size, self.p_dropout) for _ in range(num_stage)]
        )

        # project the hidden representation to the output width
        self.w2 = nn.Linear(self.linear_size, self.output_size)

        # activation and dropout used by the input layer
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(self.p_dropout)

    def forward(self, x):
        """Run ``x`` through the input layer, all stages and the output layer."""
        hidden = self.dropout(self.relu(self.batch_norm1(self.w1(x))))
        for stage_idx in range(self.num_stage):
            hidden = self.linear_stages[stage_idx](hidden)
        return self.w2(hidden)
class LinearModelComplete(nn.Module):
    """Shared ``LinearModel`` trunk followed by one ``LinearModel`` head per output.

    The trunk (``linear_combined``) maps the input vector to
    ``intermediate_size`` features. For every entry of ``output_info`` a
    separate head is created (``output_info_linear_models``) and applied to
    the trunk output.

    When ``output_info`` is ``None`` the following entries are generated:
    ``'pose'`` (``n_joints * 6`` values, reshaped to ``[n_joints, 6]``),
    ``'trans'`` (3 values) or ``'trans_xy'`` + ``'trans_z'`` when
    ``trans_sep`` is set, ``'flength'`` (1 value) and, for
    ``structure_pose_net`` in ``('vae', 'normflow')``, ``'global_pose'``
    (``[1, 6]``).

    Args:
        linear_size: hidden width of the trunk and of every head.
        num_stage_comb: number of residual stages in the trunk.
        num_stage_heads: number of residual stages in non-pose heads.
        num_stage_heads_pose: number of residual stages in the ``'pose'`` and
            ``'global_pose'`` heads.
        trans_sep: predict translation with separate xy and z heads.
        p_dropout: dropout probability for trunk and heads.
        input_size: number of input features without the optional extras.
        intermediate_size: number of features produced by the trunk.
        output_info: optional list of dicts with at least ``'name'`` and
            ``'n'`` (plus ``'out_shape'`` for pose entries). Every dict is
            annotated in place with ``'linear_model_index'``.
        n_joints: number of joints used for the default ``'pose'`` entry.
        n_z: number of extra ``z`` features, used if ``add_z_to_3d_input``.
        add_z_to_3d_input: whether the trunk input contains ``n_z`` features.
        n_segbps: number of extra segbps features, used if
            ``add_segbps_to_3d_input``.
        add_segbps_to_3d_input: whether the trunk input contains ``n_segbps``
            features.
        structure_pose_net: ``'default'``, ``'vae'`` or ``'normflow'``.
        fix_vae_weights: only relevant if ``structure_pose_net == 'vae'``.
        nf_version: forwarded to ``NormalizingFlowPrior``.

    Raises:
        NotImplementedError: for an unknown ``structure_pose_net``.
    """

    def __init__(self,
                 linear_size=1024,
                 num_stage_comb=2,
                 num_stage_heads=1,
                 num_stage_heads_pose=1,
                 trans_sep=False,
                 p_dropout=0.5,
                 input_size=16*2,
                 intermediate_size=1024,
                 output_info=None,
                 n_joints=25,
                 n_z=512,
                 add_z_to_3d_input=False,
                 n_segbps=64*2,
                 add_segbps_to_3d_input=False,
                 structure_pose_net='default',
                 fix_vae_weights=True,
                 nf_version=None):
        super(LinearModelComplete, self).__init__()
        # optional extra input features (0 when the corresponding flag is off)
        self.n_z_to_add = n_z if add_z_to_3d_input else 0
        self.n_segbps_to_add = n_segbps if add_segbps_to_3d_input else 0
        self.input_size = input_size
        self.linear_size = linear_size
        self.p_dropout = p_dropout
        self.num_stage_comb = num_stage_comb
        self.num_stage_heads = num_stage_heads
        self.num_stage_heads_pose = num_stage_heads_pose
        self.trans_sep = trans_sep
        self.intermediate_size = intermediate_size
        self.structure_pose_net = structure_pose_net
        self.fix_vae_weights = fix_vae_weights    # only relevant if structure_pose_net='vae'
        self.nf_version = nf_version

        if output_info is None:
            pose = {'name': 'pose', 'n': n_joints*6, 'out_shape': [n_joints, 6]}
            cam = {'name': 'flength', 'n': 1}
            if self.trans_sep:
                translation_xy = {'name': 'trans_xy', 'n': 2}
                translation_z = {'name': 'trans_z', 'n': 1}
                self.output_info = [pose, translation_xy, translation_z, cam]
            else:
                translation = {'name': 'trans', 'n': 3}
                self.output_info = [pose, translation, cam]
            if self.structure_pose_net == 'vae' or self.structure_pose_net == 'normflow':
                # these structures predict the global rotation with its own head
                global_pose = {'name': 'global_pose', 'n': 1*6, 'out_shape': [1, 6]}
                self.output_info.append(global_pose)
        else:
            self.output_info = output_info

        # shared trunk
        self.linear_combined = LinearModel(linear_size=self.linear_size,
                                           num_stage=self.num_stage_comb,
                                           p_dropout=p_dropout,
                                           input_size=self.input_size + self.n_segbps_to_add + self.n_z_to_add,
                                           output_size=self.intermediate_size)

        # one head per output entry, stored in the same order as output_info
        heads = []
        for ind_el, element in enumerate(self.output_info):
            if element['name'] == 'pose':
                num_stage = self.num_stage_heads_pose
                if self.structure_pose_net == 'default':
                    head_output_size = element['n']
                elif self.structure_pose_net == 'vae':
                    # The module-level import of VAEmodel is commented out, so
                    # import it here; otherwise this branch fails with a NameError.
                    # NOTE(review): path taken from the commented-out import --
                    # confirm the module is available in this repository.
                    from priors.vae_pose_model.vae_model import VAEmodel
                    # load vae decoder
                    self.pose_vae_model = VAEmodel()
                    self.pose_vae_model.initialize_with_pretrained_weights()
                    # the head predicts the input of the vae decoder
                    head_output_size = self.pose_vae_model.latent_size
                elif self.structure_pose_net == 'normflow':
                    # the following will automatically be initialized
                    self.pose_normflow_model = NormalizingFlowPrior(nf_version=self.nf_version)
                    head_output_size = element['n'] - 6    # no global rotation
                else:
                    raise NotImplementedError
            else:
                if element['name'] == 'global_pose':
                    num_stage = self.num_stage_heads_pose
                else:
                    num_stage = self.num_stage_heads
                head_output_size = element['n']
            heads.append(LinearModel(linear_size=self.linear_size,
                                     num_stage=num_stage,
                                     p_dropout=p_dropout,
                                     input_size=self.intermediate_size,
                                     output_size=head_output_size))
            # forward() uses this index to find the head of an entry
            element['linear_model_index'] = ind_el
        self.output_info_linear_models = nn.ModuleList(heads)

    def forward(self, x):
        """Predict all outputs for the input batch ``x``.

        ``x.shape[1]`` must be ``input_size + n_segbps_to_add + n_z_to_add``,
        or ``input_size + n_segbps_to_add``; in the latter case
        ``n_z_to_add`` standard-normal values per sample are appended.

        Returns:
            dict with one entry per non-pose output name (``'trans_xy'`` and
            ``'trans_z'`` are merged into ``'trans'`` when ``trans_sep`` is
            set), ``'pose'``, and ``'normflow_z'`` (``None`` unless
            ``structure_pose_net == 'normflow'``).

        Raises:
            ValueError: if ``x`` has an unexpected number of features.
            NotImplementedError: for an unknown ``structure_pose_net``.
        """
        device = x.device
        full_width = self.input_size + self.n_segbps_to_add + self.n_z_to_add
        width_without_z = self.input_size + self.n_segbps_to_add

        # combined stage
        if x.shape[1] == full_width:
            y = self.linear_combined(x)
        elif x.shape[1] == width_without_z:
            # z is missing: append random noise in its place
            noise = torch.normal(0, 1, size=(x.shape[0], self.n_z_to_add)).to(device)
            x_mod = torch.cat((x, noise), dim=1)
            y = self.linear_combined(x_mod)
        else:
            raise ValueError(
                'unexpected input shape {}: expected {} or {} features '
                '(input_size={}, n_segbps_to_add={}, n_z_to_add={})'.format(
                    list(x.shape), full_width, width_without_z,
                    self.input_size, self.n_segbps_to_add, self.n_z_to_add))

        # heads
        results = {}
        results_trans = {}
        # defined up front so that an output_info without a 'pose' entry does
        # not end in an UnboundLocalError when the result dict is assembled
        normflow_z = None
        for element in self.output_info:
            linear_model = self.output_info_linear_models[element['linear_model_index']]
            name = element['name']
            if name == 'pose':
                if self.structure_pose_net == 'default':
                    results['pose'] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1]))
                    normflow_z = None
                elif self.structure_pose_net == 'vae':
                    res_lin = linear_model(y)
                    if self.fix_vae_weights:
                        self.pose_vae_model.requires_grad_(False)    # let gradients flow through but don't update the parameters
                        res_vae = self.pose_vae_model.inference(feat=res_lin)
                        self.pose_vae_model.requires_grad_(True)
                    else:
                        res_vae = self.pose_vae_model.inference(feat=res_lin)
                    res_pose_not_glob = res_vae.reshape((-1, element['out_shape'][0], element['out_shape'][1]))
                    normflow_z = None
                elif self.structure_pose_net == 'normflow':
                    normflow_z = linear_model(y)*0.1
                    # let gradients flow through but don't update the parameters
                    # (requires_grad is not switched back on afterwards)
                    self.pose_normflow_model.requires_grad_(False)
                    res_pose_not_glob = self.pose_normflow_model.run_backwards(z=normflow_z).reshape((-1, element['out_shape'][0]-1, element['out_shape'][1]))
                else:
                    raise NotImplementedError
            elif name == 'global_pose':
                res_pose_glob = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1]))
            elif name == 'trans_xy' or name == 'trans_z':
                results_trans[name] = linear_model(y)
            else:
                results[name] = linear_model(y)

        if self.trans_sep:
            results['trans'] = torch.cat((results_trans['trans_xy'], results_trans['trans_z']), dim=1)

        # prepare pose including global rotation
        if self.structure_pose_net == 'vae':
            # the first entry of the vae output is replaced by the global pose head
            results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, 1:, :]), dim=1)
        elif self.structure_pose_net == 'normflow':
            results['pose'] = torch.cat((res_pose_glob, res_pose_not_glob[:, :, :]), dim=1)

        # return a dictionary which contains all results
        results['normflow_z'] = normflow_z
        return results
# ------------------------------------------ | |
# for pretraining of the 3d model only: | |
# (see combined_model/model_shape_v2.py) | |
class Wrapper_LinearModelComplete(nn.Module):
    """Wrapper that holds a ``LinearModelComplete`` as ``self.model_3d``.

    All constructor arguments are forwarded unchanged to
    ``LinearModelComplete``; ``add_segbps_to_3d_input`` is additionally kept
    as an attribute of the wrapper. ``forward`` delegates to ``model_3d``.
    """

    def __init__(self,
                 linear_size=1024,
                 num_stage_comb=2,
                 num_stage_heads=1,
                 num_stage_heads_pose=1,
                 trans_sep=False,
                 p_dropout=0.5,
                 input_size=16*2,
                 intermediate_size=1024,
                 output_info=None,
                 n_joints=25,
                 n_z=512,
                 add_z_to_3d_input=False,
                 n_segbps=64*2,
                 add_segbps_to_3d_input=False,
                 structure_pose_net='default',
                 fix_vae_weights=True,
                 nf_version=None):
        super(Wrapper_LinearModelComplete, self).__init__()
        self.add_segbps_to_3d_input = add_segbps_to_3d_input

        # every argument is handed to the wrapped model as is
        model_kwargs = dict(
            linear_size=linear_size,
            num_stage_comb=num_stage_comb,
            num_stage_heads=num_stage_heads,
            num_stage_heads_pose=num_stage_heads_pose,
            trans_sep=trans_sep,
            p_dropout=p_dropout,
            input_size=input_size,
            intermediate_size=intermediate_size,
            output_info=output_info,
            n_joints=n_joints,
            n_z=n_z,
            add_z_to_3d_input=add_z_to_3d_input,
            n_segbps=n_segbps,
            add_segbps_to_3d_input=add_segbps_to_3d_input,
            structure_pose_net=structure_pose_net,
            fix_vae_weights=fix_vae_weights,
            nf_version=nf_version,
        )
        self.model_3d = LinearModelComplete(**model_kwargs)

    def forward(self, input_vec):
        """Return the output of ``model_3d`` for ``input_vec``.

        Per the original author's note, the predicted 3d parameters are
        normalized; mean and std are corrected in a later step outside of
        this module.
        """
        return self.model_3d(input_vec)