from os.path import exists
from typing import List

import numpy as np
import torch
from einops import rearrange
from torch import Tensor

from gen_utils import cast_dict_to_tensors


class Normalizer:
    def __init__(self, statistics_path: str = 'deps/statistics_motionfix.npy',
                 nfeats: int = 207,
                 input_feats: List[str] = ["body_transl_delta_pelv",
                                           "body_orient_xy",
                                           "z_orient_delta",
                                           "body_pose",
                                           "body_joints_local_wo_z_rot"],
                 dim_per_feat: List[int] = [3, 6, 6, 126, 66],
                 *args, **kwargs):
        # per-feature normalisation statistics (mean/std), keyed by feature name
        self.stats = self.load_norm_statistics(statistics_path, 'cpu')

        self.nfeats = nfeats
        self.dim_per_feat = dim_per_feat
        self.input_feats_dims = list(dim_per_feat)
        self.input_feats = list(input_feats)
 
 
    def load_norm_statistics(self, path, device):
        """Load the per-feature mean/std statistics and cast them to tensors."""
        # workaround for cluster local/sync
        assert exists(path), f'statistics file not found: {path}'
        stats = np.load(path, allow_pickle=True)[()]
        return cast_dict_to_tensors(stats, device=device)

    def norm_and_cat(self, batch, features_types):
        """
        Turn batch data into the format the forward() function expects.
        """
        seq_first = lambda t: rearrange(t, 'b s ... -> s b ...')
        input_batch = {}
        ## PREPARE INPUT ##
        # only process the source motion if the batch actually contains one
        motion_condition = any('source' in key for key in batch.keys())
        mo_types = ['source', 'target'] if motion_condition else ['target']
        for mot in mo_types:
            list_of_feat_tensors = [seq_first(batch[f'{feat_type}_{mot}'])
                                    for feat_type in features_types
                                    if f'{feat_type}_{mot}' in batch.keys()]
            # normalise and cat to a unified feature vector
            list_of_feat_tensors_normed = self.norm_inputs(list_of_feat_tensors,
                                                           features_types)
            x_norm, _ = self.cat_inputs(list_of_feat_tensors_normed)
            input_batch[mot] = x_norm
        return input_batch
    
    def norm_and_cat_single_motion(self, batch, features_types):
        """
        Turn batch data into the format the forward() function expects.
        """
        seq_first = lambda t: rearrange(t, 'b s ... -> s b ...')
        input_batch = {}
        ## PREPARE INPUT ##
        list_of_feat_tensors = [seq_first(batch[feat_type])
                                for feat_type in features_types]
        # normalise and cat to a unified feature vector
        list_of_feat_tensors_normed = self.norm_inputs(list_of_feat_tensors,
                                                       features_types)
        x_norm, _ = self.cat_inputs(list_of_feat_tensors_normed)
        input_batch['motion'] = x_norm
        return input_batch

    def norm(self, x, stats):
        # move the statistics to the same device as the input instead of
        # hard-coding 'cuda', so the normaliser also works on CPU tensors
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return (x - mean) / (std + 1e-5)

    def unnorm(self, x, stats):
        mean = stats['mean'].to(x.device)
        std = stats['std'].to(x.device)
        return x * (std + 1e-5) + mean

    def unnorm_state(self, state_norm: Tensor) -> Tensor:
        # un-normalise a single-pose state; note that first_pose_feats and
        # first_pose_feats_dims are not initialised in __init__ and are
        # expected to be provided by the caller or a subclass
        return self.cat_inputs(
            self.unnorm_inputs(self.uncat_inputs(state_norm,
                                                 self.first_pose_feats_dims),
                               self.first_pose_feats))[0]
        
    def unnorm_delta(self, delta_norm: Tensor) -> Tensor:
        # unnorm delta
        return self.cat_inputs(
            self.unnorm_inputs(self.uncat_inputs(delta_norm,
                                                 self.input_feats_dims),
                               self.input_feats))[0]

    def norm_state(self, state: Tensor) -> Tensor:
        # normalise a single-pose state (see the note in unnorm_state about
        # first_pose_feats / first_pose_feats_dims)
        return self.cat_inputs(
            self.norm_inputs(self.uncat_inputs(state,
                                               self.first_pose_feats_dims),
                             self.first_pose_feats))[0]

    def norm_delta(self, delta: Tensor) -> Tensor:
        # normalise delta
        return self.cat_inputs(
            self.norm_inputs(self.uncat_inputs(delta, self.input_feats_dims),
                             self.input_feats))[0]

    def cat_inputs(self, x_list: List[Tensor]):
        """
        Cat the inputs to a unified vector and return their lengths in order
        to un-cat them later
        """
        return torch.cat(x_list, dim=-1), [x.shape[-1] for x in x_list]

    def uncat_inputs(self, x: Tensor, lengths: List[int]):
        """
        Split the unified feature vector back to its original parts
        """
        return torch.split(x, lengths, dim=-1)
    
    def norm_inputs(self, x_list: List[Tensor], names: List[str]):
        """
        Normalise inputs using the statistics stored in self.stats
        """
        x_norm = []
        for x, name in zip(x_list, names):
            x_norm.append(self.norm(x, self.stats[name]))
        return x_norm

    def unnorm_inputs(self, x_list: List[Tensor], names: List[str]):
        """
        Un-normalise inputs using the statistics stored in self.stats
        """
        x_unnorm = []
        for x, name in zip(x_list, names):
            x_unnorm.append(self.unnorm(x, self.stats[name]))
        return x_unnorm
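

# ---------------------------------------------------------------------------
# Minimal usage sketch. It assumes gen_utils.cast_dict_to_tensors converts a
# nested dict of numpy arrays into torch tensors, and it writes a dummy
# statistics file with the layout the class reads back:
#     {feature_name: {'mean': array(dim,), 'std': array(dim,)}}
# The temporary path, shapes and zero-mean/unit-std values below are
# illustrative only, not the real MotionFix statistics.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import os
    import tempfile

    feats = ["body_transl_delta_pelv", "body_orient_xy", "z_orient_delta",
             "body_pose", "body_joints_local_wo_z_rot"]
    dims = [3, 6, 6, 126, 66]

    # dummy statistics: zero mean, unit std for every feature
    dummy_stats = {name: {'mean': np.zeros(d, dtype=np.float32),
                          'std': np.ones(d, dtype=np.float32)}
                   for name, d in zip(feats, dims)}
    stats_path = os.path.join(tempfile.mkdtemp(), 'statistics_motionfix.npy')
    np.save(stats_path, dummy_stats)

    normalizer = Normalizer(statistics_path=stats_path,
                            input_feats=feats, dim_per_feat=dims)

    # toy batch: 2 sequences of 10 frames for both source and target motions
    batch = {}
    for mot in ['source', 'target']:
        for name, d in zip(feats, dims):
            batch[f'{name}_{mot}'] = torch.randn(2, 10, d)

    out = normalizer.norm_and_cat(batch, feats)
    # each entry is sequence-first: (frames, batch, sum(dims)) == (10, 2, 207)
    print(out['source'].shape, out['target'].shape)

    # round trip: un-normalise the concatenated target features again
    recovered = normalizer.unnorm_delta(out['target'])
    print(recovered.shape)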