# motionfix-demo / normalization.py
# Author: atnikos
# Commit: d425bee ("fix gpu issue")
from os.path import exists
from gen_utils import cast_dict_to_tensors
from einops import rearrange
from torch import Tensor
from typing import List, Union
import torch
import numpy as np
class Normalizer:
    """Per-feature mean/std normaliser for motion feature tensors.

    Statistics are loaded from a ``.npy`` file holding a dict that maps each
    feature name to a dict with ``'mean'`` and ``'std'`` entries (the keys
    ``norm``/``unnorm`` read). Each feature is normalised on its own and the
    results are concatenated along the last dimension.
    """

    def __init__(self, statistics_path: str = 'deps/statistics_motionfix.npy',
                 nfeats: int = 207,
                 input_feats: List[str] = ["body_transl_delta_pelv",
                                           "body_orient_xy",
                                           "z_orient_delta", "body_pose",
                                           "body_joints_local_wo_z_rot"],
                 dim_per_feat: List[int] = [3, 6, 6, 126, 66],
                 *args, **kwargs):
        """
        Args:
            statistics_path: path to the ``.npy`` statistics file.
            nfeats: total feature size, stored as-is (the default 207 equals
                the sum of the default ``dim_per_feat``).
            input_feats: feature names, in concatenation order.
            dim_per_feat: last-dim size of each entry of ``input_feats``.
            *args, **kwargs: accepted and ignored (presumably so the class can
                be built from a larger config -- confirm with callers).
        """
        # Stats are kept on the CPU; norm/unnorm move them to the input's device.
        self.stats = self.load_norm_statistics(statistics_path, 'cpu')
        self.nfeats = nfeats
        # Copy the (mutable, shared) defaults so instances never alias them.
        self.dim_per_feat = list(dim_per_feat)
        self.input_feats_dims = list(dim_per_feat)
        self.input_feats = list(input_feats)

    def load_norm_statistics(self, path, device):
        """Load the statistics dict from ``path`` as tensors on ``device``.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not exists(path):
            raise FileNotFoundError(
                f"Normalization statistics file not found: {path}")
        # NOTE: allow_pickle=True unpickles arbitrary objects -- only load
        # trusted statistics files. `[()]` unwraps the 0-d object array that
        # np.save produces when given a dict (presumably how the file was made).
        stats = np.load(path, allow_pickle=True)[()]
        return cast_dict_to_tensors(stats, device=device)

    @staticmethod
    def _seq_first(t: Tensor) -> Tensor:
        """Swap the first two dims: ``(b, s, ...) -> (s, b, ...)``.

        Same result as einops ``rearrange(t, 'b s ... -> s b ...')``.
        """
        return t.transpose(0, 1)

    def norm_and_cat(self, batch, features_types):
        """Normalise and concatenate the source/target features of ``batch``.

        For each motion type ``mot`` in ``('source', 'target')``, the key
        ``f'{feat}_{mot}'`` is looked up for every ``feat`` in
        ``features_types``. Missing keys are skipped. Each present tensor is
        made sequence-first, normalised with the statistics of *its own*
        feature, and concatenated along the last dim.

        Returns:
            dict mapping ``'source'`` / ``'target'`` to the concatenated
            tensor. A motion type with no features present is omitted.
        """
        input_batch = {}
        for mot in ('source', 'target'):
            present = [ft for ft in features_types if f'{ft}_{mot}' in batch]
            if not present:
                # Nothing to concatenate (torch.cat would fail on an empty list).
                continue
            tensors = [self._seq_first(batch[f'{ft}_{mot}']) for ft in present]
            # Pair each tensor with the name of the feature actually present,
            # not with the full `features_types`; otherwise a missing feature
            # shifts every later tensor onto the wrong statistics.
            normed = self.norm_inputs(tensors, present)
            input_batch[mot], _ = self.cat_inputs(normed)
        return input_batch

    def norm_and_cat_single_motion(self, batch, features_types):
        """Normalise and concatenate ``batch[feat]`` for every feature type.

        Every name in ``features_types`` must be a key of ``batch``.

        Returns:
            ``{'motion': tensor}`` with the sequence-first, normalised,
            concatenated features.
        """
        tensors = [self._seq_first(batch[ft]) for ft in features_types]
        normed = self.norm_inputs(tensors, features_types)
        x_norm, _ = self.cat_inputs(normed)
        return {'motion': x_norm}

    def norm(self, x, stats):
        """Return ``(x - mean) / (std + 1e-5)`` using ``stats['mean'/'std']``.

        The statistics are moved to ``x.device``, so this works on CPU and on
        any GPU (the 1e-5 guards against division by a zero std).
        """
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return (x - mean) / (std + 1e-5)

    def unnorm(self, x, stats):
        """Inverse of ``norm``: return ``x * (std + 1e-5) + mean``."""
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return x * (std + 1e-5) + mean

    # NOTE(review): norm_state/unnorm_state read `self.first_pose_feats` and
    # `self.first_pose_feats_dims`, which this class never sets -- they must be
    # assigned externally before use, otherwise AttributeError is raised.
    def unnorm_state(self, state_norm: Tensor) -> Tensor:
        """Split ``state_norm`` by ``first_pose_feats_dims``, unnormalise, re-concatenate."""
        parts = self.uncat_inputs(state_norm, self.first_pose_feats_dims)
        return self.cat_inputs(self.unnorm_inputs(parts, self.first_pose_feats))[0]

    def unnorm_delta(self, delta_norm: Tensor) -> Tensor:
        """Split ``delta_norm`` by ``input_feats_dims``, unnormalise, re-concatenate."""
        parts = self.uncat_inputs(delta_norm, self.input_feats_dims)
        return self.cat_inputs(self.unnorm_inputs(parts, self.input_feats))[0]

    def norm_state(self, state: Tensor) -> Tensor:
        """Split ``state`` by ``first_pose_feats_dims``, normalise, re-concatenate."""
        parts = self.uncat_inputs(state, self.first_pose_feats_dims)
        return self.cat_inputs(self.norm_inputs(parts, self.first_pose_feats))[0]

    def norm_delta(self, delta: Tensor) -> Tensor:
        """Split ``delta`` by ``input_feats_dims``, normalise, re-concatenate."""
        parts = self.uncat_inputs(delta, self.input_feats_dims)
        return self.cat_inputs(self.norm_inputs(parts, self.input_feats))[0]

    def cat_inputs(self, x_list: List[Tensor]):
        """Concatenate along the last dim.

        Returns:
            ``(cat_tensor, lengths)``, where ``lengths`` holds each part's
            last-dim size so the result can be split back by ``uncat_inputs``.
        """
        return torch.cat(x_list, dim=-1), [x.shape[-1] for x in x_list]

    def uncat_inputs(self, x: Tensor, lengths: List[int]):
        """Split ``x`` along the last dim into chunks of the given ``lengths``."""
        return torch.split(x, lengths, dim=-1)

    def norm_inputs(self, x_list: List[Tensor], names: List[str]):
        """Normalise each tensor with ``self.stats[name]`` of its paired name."""
        return [self.norm(x, self.stats[name]) for x, name in zip(x_list, names)]

    def unnorm_inputs(self, x_list: List[Tensor], names: List[str]):
        """Unnormalise each tensor with ``self.stats[name]`` of its paired name."""
        return [self.unnorm(x, self.stats[name]) for x, name in zip(x_list, names)]