from os.path import exists

from typing import List

import numpy as np
import torch
from einops import rearrange
from torch import Tensor

from gen_utils import cast_dict_to_tensors


class Normalizer:

    def __init__(self,
                 statistics_path: str = 'deps/statistics_motionfix.npy',
                 nfeats: int = 207,
                 input_feats: List[str] = ["body_transl_delta_pelv",
                                           "body_orient_xy",
                                           "z_orient_delta",
                                           "body_pose",
                                           "body_joints_local_wo_z_rot"],
                 dim_per_feat: List[int] = [3, 6, 6, 126, 66],
                 *args, **kwargs):
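        # The entries of dim_per_feat line up one-to-one with input_feats and
        # sum to nfeats (3 + 6 + 6 + 126 + 66 = 207 for the defaults above).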
        # Per-feature normalisation statistics, loaded once and kept on CPU.
        self.stats = self.load_norm_statistics(statistics_path, 'cpu')
        self.nfeats = nfeats
        self.dim_per_feat = dim_per_feat
        self.input_feats_dims = list(dim_per_feat)
        self.input_feats = list(input_feats)

    def load_norm_statistics(self, path, device):
        """
        Load the per-feature statistics dict from disk and move it to `device`.
        The result maps each feature name to a dict with 'mean' and 'std' tensors.
        """
        # workaround for cluster local/sync
        assert exists(path), f'Statistics file not found: {path}'
        stats = np.load(path, allow_pickle=True)[()]
        return cast_dict_to_tensors(stats, device=device)

    def norm_and_cat(self, batch, features_types):
        """
        Turn batch data into the format the forward() function expects:
        for each motion stream ('source'/'target'), normalise the per-feature
        tensors (keyed as '<feature_name>_<stream>') and concatenate them into
        a single sequence-first feature vector.
        """
        seq_first = lambda t: rearrange(t, 'b s ... -> s b ...')
        input_batch = {}
        ## PREPARE INPUT ##
        # Only include the 'source' stream if the batch actually provides one.
        motion_condition = any('source' in key for key in batch.keys())
        mo_types = ['source', 'target'] if motion_condition else ['target']
        for mot in mo_types:
            present_feats = [feat_type for feat_type in features_types
                             if f'{feat_type}_{mot}' in batch.keys()]
            list_of_feat_tensors = [seq_first(batch[f'{feat_type}_{mot}'])
                                    for feat_type in present_feats]
            # normalise and cat to a unified feature vector
            list_of_feat_tensors_normed = self.norm_inputs(list_of_feat_tensors,
                                                           present_feats)
            x_norm, _ = self.cat_inputs(list_of_feat_tensors_normed)
            input_batch[mot] = x_norm
        return input_batch

    def norm_and_cat_single_motion(self, batch, features_types):
        """
        Same as norm_and_cat, but for a batch holding a single motion stream:
        normalise each feature tensor and concatenate them into one
        sequence-first feature vector under the 'motion' key.
        """
        seq_first = lambda t: rearrange(t, 'b s ... -> s b ...')
        input_batch = {}
        ## PREPARE INPUT ##
        list_of_feat_tensors = [seq_first(batch[feat_type])
                                for feat_type in features_types]
        # normalise and cat to a unified feature vector
        list_of_feat_tensors_normed = self.norm_inputs(list_of_feat_tensors,
                                                       features_types)
        x_norm, _ = self.cat_inputs(list_of_feat_tensors_normed)
        input_batch['motion'] = x_norm
        return input_batch

    def norm(self, x, stats):
        # Move the statistics to the input's device instead of hard-coding CUDA,
        # so normalisation also works for CPU tensors.
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return (x - mean) / (std + 1e-5)

    def unnorm(self, x, stats):
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return x * (std + 1e-5) + mean

    def unnorm_state(self, state_norm: Tensor) -> Tensor:
        # un-normalise a first-pose state vector
        # NOTE: relies on self.first_pose_feats / self.first_pose_feats_dims,
        # which are expected to be defined elsewhere (e.g. by a subclass).
        return self.cat_inputs(
            self.unnorm_inputs(self.uncat_inputs(state_norm,
                                                 self.first_pose_feats_dims),
                               self.first_pose_feats))[0]

    def unnorm_delta(self, delta_norm: Tensor) -> Tensor:
        # unnorm delta
        return self.cat_inputs(
            self.unnorm_inputs(self.uncat_inputs(delta_norm,
                                                 self.input_feats_dims),
                               self.input_feats))[0]

    def norm_state(self, state: Tensor) -> Tensor:
        # normalise state
        return self.cat_inputs(
            self.norm_inputs(self.uncat_inputs(state,
                                               self.first_pose_feats_dims),
                             self.first_pose_feats))[0]

    def norm_delta(self, delta: Tensor) -> Tensor:
        # normalise delta
        return self.cat_inputs(
            self.norm_inputs(self.uncat_inputs(delta, self.input_feats_dims),
                             self.input_feats))[0]

    def cat_inputs(self, x_list: List[Tensor]):
        """
        Concatenate the inputs into a unified vector and return their lengths
        in order to un-cat them later.
        """
        return torch.cat(x_list, dim=-1), [x.shape[-1] for x in x_list]

    def uncat_inputs(self, x: Tensor, lengths: List[int]):
        """
        Split the unified feature vector back into its original parts.
        """
        return torch.split(x, lengths, dim=-1)

    def norm_inputs(self, x_list: List[Tensor], names: List[str]):
        """
        Normalise each input tensor using the corresponding self.stats statistics.
        """
        x_norm = []
        for x, name in zip(x_list, names):
            x_norm.append(self.norm(x, self.stats[name]))
        return x_norm

    def unnorm_inputs(self, x_list: List[Tensor], names: List[str]):
        """
        Un-normalise each input tensor using the corresponding self.stats statistics.
        """
        x_unnorm = []
        for x, name in zip(x_list, names):
            x_unnorm.append(self.unnorm(x, self.stats[name]))
        return x_unnorm