Spaces:
Running
Running
File size: 6,348 Bytes
d8530c7 08db8da d8530c7 d425bee d8530c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from os.path import exists
from typing import List, Sequence, Union

import numpy as np
import torch
from einops import rearrange
from torch import Tensor

from gen_utils import cast_dict_to_tensors
class Normalizer:
    """Per-feature mean/std normaliser backed by a statistics file.

    ``self.stats`` maps a feature name to a dict holding ``'mean'`` and
    ``'std'`` tensors. Every (un)normalisation looks statistics up by the
    feature's name, so tensors and names must stay aligned.
    """

    # Added to ``std`` so features with zero spread do not divide by zero.
    _EPS = 1e-5

    def __init__(self, statistics_path: str = 'deps/statistics_motionfix.npy',
                 nfeats: int = 207,
                 input_feats: Sequence[str] = ("body_transl_delta_pelv",
                                               "body_orient_xy",
                                               "z_orient_delta", "body_pose",
                                               "body_joints_local_wo_z_rot"),
                 dim_per_feat: Sequence[int] = (3, 6, 6, 126, 66),
                 *args, **kwargs):
        """Load the statistics and record the feature layout.

        Args:
            statistics_path: path to a ``.npy`` statistics file; its content
                is loaded onto the CPU.
            nfeats: total feature size; stored as ``self.nfeats`` and not
                read elsewhere in this class.
            input_feats: feature names, in concatenation order, used by
                ``norm_delta`` / ``unnorm_delta``.
            dim_per_feat: last-axis width of each entry of ``input_feats``.
            *args, **kwargs: accepted and ignored.
        """
        self.stats = self.load_norm_statistics(statistics_path, 'cpu')
        self.nfeats = nfeats
        # Copy the sequences so an instance never aliases (and can never
        # mutate) a caller's object or a shared default.
        self.dim_per_feat = list(dim_per_feat)
        self.input_feats_dims = list(dim_per_feat)
        self.input_feats = list(input_feats)

    def load_norm_statistics(self, path, device):
        """Load the statistics from ``path`` and cast them to tensors on ``device``.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not exists(path):
            raise FileNotFoundError(
                f"normalisation statistics not found: {path}")
        # SECURITY: ``allow_pickle=True`` can execute code embedded in the
        # file -- only load statistics from trusted sources.
        # Presumably the file holds a dict saved with ``np.save``; ``[()]``
        # pulls it out of the resulting 0-d object array.
        stats = np.load(path, allow_pickle=True)[()]
        return cast_dict_to_tensors(stats, device=device)

    @staticmethod
    def _seq_first(t: Tensor) -> Tensor:
        """Swap the two leading axes: ``(b, s, ...) -> (s, b, ...)``."""
        return rearrange(t, 'b s ... -> s b ...')

    def norm_and_cat(self, batch, features_types):
        """Normalise and concatenate the source and target features of a batch.

        For each ``mot`` in ``('source', 'target')``, the tensors stored under
        ``f'{feat_type}_{mot}'`` are made sequence-first, normalised with the
        statistics of ``feat_type`` and concatenated along the last axis.
        Feature types missing from ``batch`` for a given motion are skipped.

        Returns:
            dict with keys ``'source'`` and ``'target'``.
        """
        input_batch = {}
        for mot in ('source', 'target'):
            # Names must match the tensors actually collected; otherwise a
            # missing feature shifts every later tensor onto the wrong stats.
            present = [ft for ft in features_types if f'{ft}_{mot}' in batch]
            feats = [self._seq_first(batch[f'{ft}_{mot}']) for ft in present]
            feats_normed = self.norm_inputs(feats, present)
            input_batch[mot], _ = self.cat_inputs(feats_normed)
        return input_batch

    def norm_and_cat_single_motion(self, batch, features_types):
        """Normalise and concatenate the features of a single motion.

        ``batch[feat_type]`` must exist for every requested feature type.

        Returns:
            dict with the single key ``'motion'``.
        """
        # Materialise once so a one-shot iterable is not exhausted before
        # it is used for the statistics lookup.
        names = list(features_types)
        feats = [self._seq_first(batch[name]) for name in names]
        x_norm, _ = self.cat_inputs(self.norm_inputs(feats, names))
        return {'motion': x_norm}

    def norm(self, x, stats):
        """Return ``(x - stats['mean']) / (stats['std'] + eps)``.

        The statistics are moved to ``x``'s device, so CPU tensors and any
        GPU index work (previously this was hard-coded to ``'cuda'``).
        """
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return (x - mean) / (std + self._EPS)

    def unnorm(self, x, stats):
        """Invert :meth:`norm`: return ``x * (stats['std'] + eps) + stats['mean']``."""
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return x * (std + self._EPS) + mean

    def _transform(self, x: Tensor, dims: List[int], names: List[str],
                   fn) -> Tensor:
        """Split ``x`` by ``dims``, apply ``fn(parts, names)``, re-concatenate."""
        return self.cat_inputs(fn(self.uncat_inputs(x, dims), names))[0]

    def unnorm_state(self, state_norm: Tensor) -> Tensor:
        """Un-normalise a state vector laid out as ``self.first_pose_feats``."""
        # NOTE(review): first_pose_feats / first_pose_feats_dims are not set
        # in __init__; they must be assigned on the instance before this is
        # called, otherwise AttributeError is raised.
        return self._transform(state_norm, self.first_pose_feats_dims,
                               self.first_pose_feats, self.unnorm_inputs)

    def unnorm_delta(self, delta_norm: Tensor) -> Tensor:
        """Un-normalise a vector laid out as ``self.input_feats``."""
        return self._transform(delta_norm, self.input_feats_dims,
                               self.input_feats, self.unnorm_inputs)

    def norm_state(self, state: Tensor) -> Tensor:
        """Normalise a state vector laid out as ``self.first_pose_feats``."""
        # NOTE(review): see unnorm_state -- relies on attributes that
        # __init__ does not set.
        return self._transform(state, self.first_pose_feats_dims,
                               self.first_pose_feats, self.norm_inputs)

    def norm_delta(self, delta: Tensor) -> Tensor:
        """Normalise a vector laid out as ``self.input_feats``."""
        return self._transform(delta, self.input_feats_dims,
                               self.input_feats, self.norm_inputs)

    def cat_inputs(self, x_list: List[Tensor]):
        """Concatenate along the last axis.

        Returns:
            ``(concatenated, widths)``, where ``widths`` lists each input's
            last-axis size so the result can be split again later.
        """
        return torch.cat(x_list, dim=-1), [x.shape[-1] for x in x_list]

    def uncat_inputs(self, x: Tensor, lengths: List[int]):
        """Split ``x`` along the last axis into pieces of the given widths."""
        return torch.split(x, lengths, dim=-1)

    def norm_inputs(self, x_list: List[Tensor], names: List[str]):
        """Normalise each tensor with the statistics of its paired name."""
        return [self.norm(x, self.stats[name])
                for x, name in zip(x_list, names)]

    def unnorm_inputs(self, x_list: List[Tensor], names: List[str]):
        """Un-normalise each tensor with the statistics of its paired name."""
        return [self.unnorm(x, self.stats[name])
                for x, name in zip(x_list, names)]