from os.path import exists
from typing import List, Sequence, Union

import numpy as np
import torch
from einops import rearrange
from torch import Tensor

from gen_utils import cast_dict_to_tensors


class Normalizer:
    """Per-feature normaliser for batches of feature tensors.

    Statistics are loaded once from ``statistics_path`` and looked up by
    feature name (``self.stats[name]['mean']`` / ``['std']``).
    ``norm`` computes ``(x - mean) / (std + eps)`` and ``unnorm`` inverts it.
    """

    # Added to the std in both directions so zero-variance features do not
    # cause a division by zero.
    _EPS = 1e-5

    def __init__(self,
                 statistics_path: str = 'deps/statistics_motionfix.npy',
                 nfeats: int = 207,
                 input_feats: Sequence[str] = ("body_transl_delta_pelv",
                                               "body_orient_xy",
                                               "z_orient_delta",
                                               "body_pose",
                                               "body_joints_local_wo_z_rot"),
                 dim_per_feat: Sequence[int] = (3, 6, 6, 126, 66),
                 *args, **kwargs):
        """Load the statistics and record the feature layout.

        Args:
            statistics_path: path to the ``.npy`` statistics file.
            nfeats: total feature size (stored only; not validated here).
            input_feats: feature names, in concatenation order, used by
                ``norm_delta`` / ``unnorm_delta``.
            dim_per_feat: last-dim size of each entry of ``input_feats``.
            *args, **kwargs: accepted and ignored (kept for callers that pass
                extra config).
        """
        self.stats = self.load_norm_statistics(statistics_path, 'cpu')
        self.nfeats = nfeats
        # Copy so an instance never aliases the caller's list or a shared default.
        self.dim_per_feat = list(dim_per_feat)
        self.input_feats_dims = list(dim_per_feat)
        self.input_feats = list(input_feats)

    def load_norm_statistics(self, path, device):
        """Load the statistics file at ``path`` and convert it to tensors on ``device``.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        if not exists(path):
            raise FileNotFoundError(f"Normalisation statistics not found: {path}")
        # NOTE(security): allow_pickle=True can execute code embedded in the
        # file; only load trusted statistics files.
        # `[()]` unwraps the single object of a 0-d object array (presumably a
        # dict written with np.save -- confirm against the producer script).
        stats = np.load(path, allow_pickle=True)[()]
        return cast_dict_to_tensors(stats, device=device)

    @staticmethod
    def _seq_first(t: Tensor) -> Tensor:
        """Swap the two leading axes: ``(b, s, ...) -> (s, b, ...)``."""
        return rearrange(t, 'b s ... -> s b ...')

    def _norm_cat_seq_first(self, tensors: List[Tensor], names: List[str]) -> Tensor:
        """Make each tensor seq-first, normalise it with the stats for its name,
        then concatenate along the last dim.

        ``tensors`` and ``names`` must be aligned index-by-index.
        """
        normed = self.norm_inputs([self._seq_first(t) for t in tensors], names)
        x_norm, _ = self.cat_inputs(normed)
        return x_norm

    def norm_and_cat(self, batch, features_types):
        """Turn batch data into the format the forward() function expects.

        For each motion in ``('source', 'target')``, collects the entries
        ``batch[f'{feat}_{motion}']`` for the requested feature types, then
        normalises and concatenates them.

        Returns:
            A dict keyed by motion name. A motion with none of the requested
            features present in ``batch`` (e.g. no ``*_source`` keys) is
            omitted; previously it reached ``torch.cat`` with an empty list
            and raised.
        """
        input_batch = {}
        for mot in ('source', 'target'):
            # Keep names aligned with the tensors actually present, so each
            # tensor is normalised with its own feature's statistics.
            names = [ft for ft in features_types if f'{ft}_{mot}' in batch]
            if not names:
                continue
            tensors = [batch[f'{ft}_{mot}'] for ft in names]
            input_batch[mot] = self._norm_cat_seq_first(tensors, names)
        return input_batch

    def norm_and_cat_single_motion(self, batch, features_types):
        """Turn single-motion batch data into the format the forward() function expects.

        Every name in ``features_types`` must be a key of ``batch``; a missing
        name raises ``KeyError``.

        Returns:
            ``{'motion': x_norm}`` with the normalised features concatenated
            along the last dim.
        """
        tensors = [batch[ft] for ft in features_types]
        return {'motion': self._norm_cat_seq_first(tensors, features_types)}

    def norm(self, x, stats):
        """Return ``(x - mean) / (std + eps)``.

        The statistics are moved to ``x``'s device; previously they were
        always moved to ``'cuda'``.
        """
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return (x - mean) / (std + self._EPS)

    def unnorm(self, x, stats):
        """Inverse of ``norm``: return ``x * (std + eps) + mean``, with the
        statistics moved to ``x``'s device."""
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return x * (std + self._EPS) + mean

    def _first_pose_spec(self):
        """Return ``(first_pose_feats, first_pose_feats_dims)``.

        ``__init__`` does not set these attributes, so they must be assigned
        on the instance before ``norm_state`` / ``unnorm_state`` are used.

        Raises:
            AttributeError: if either attribute is missing.
        """
        try:
            return self.first_pose_feats, self.first_pose_feats_dims
        except AttributeError as err:
            raise AttributeError(
                "Normalizer.first_pose_feats and Normalizer.first_pose_feats_dims "
                "must be set before calling norm_state/unnorm_state") from err

    def unnorm_state(self, state_norm: Tensor) -> Tensor:
        """Un-normalise a concatenated state tensor, splitting its last dim by
        ``first_pose_feats_dims``."""
        feats, dims = self._first_pose_spec()
        return self.cat_inputs(
            self.unnorm_inputs(self.uncat_inputs(state_norm, dims), feats))[0]

    def unnorm_delta(self, delta_norm: Tensor) -> Tensor:
        """Un-normalise a concatenated tensor, splitting its last dim by
        ``input_feats_dims``."""
        return self.cat_inputs(
            self.unnorm_inputs(self.uncat_inputs(delta_norm, self.input_feats_dims),
                               self.input_feats))[0]

    def norm_state(self, state: Tensor) -> Tensor:
        """Normalise a concatenated state tensor, splitting its last dim by
        ``first_pose_feats_dims``."""
        feats, dims = self._first_pose_spec()
        return self.cat_inputs(
            self.norm_inputs(self.uncat_inputs(state, dims), feats))[0]

    def norm_delta(self, delta: Tensor) -> Tensor:
        """Normalise a concatenated tensor, splitting its last dim by
        ``input_feats_dims``."""
        return self.cat_inputs(
            self.norm_inputs(self.uncat_inputs(delta, self.input_feats_dims),
                             self.input_feats))[0]

    def cat_inputs(self, x_list: List[Tensor]):
        """Concatenate the inputs along the last dim.

        Returns:
            ``(concatenated, lengths)``, where ``lengths`` holds each input's
            last-dim size so the result can be split back with ``uncat_inputs``.
        """
        return torch.cat(x_list, dim=-1), [x.shape[-1] for x in x_list]

    def uncat_inputs(self, x: Tensor, lengths: List[int]):
        """Split ``x`` along its last dim into pieces of the given ``lengths``."""
        return torch.split(x, lengths, dim=-1)

    def norm_inputs(self, x_list: List[Tensor], names: List[str]):
        """Normalise each tensor with ``self.stats[name]`` for its paired name.

        Pairs are formed with ``zip``, so surplus entries in the longer
        argument are ignored.
        """
        return [self.norm(x, self.stats[name]) for x, name in zip(x_list, names)]

    def unnorm_inputs(self, x_list: List[Tensor], names: List[str]):
        """Un-normalise each tensor with ``self.stats[name]`` for its paired name.

        Pairs are formed with ``zip``, so surplus entries in the longer
        argument are ignored.
        """
        return [self.unnorm(x, self.stats[name]) for x, name in zip(x_list, names)]