File size: 3,806 Bytes
4409449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import numpy as np
import torch
from os.path import join as pjoin
from .humanml.utils.word_vectorizer import WordVectorizer
from .humanml.scripts.motion_process import (process_file, recover_from_ric)
from .HumanML3D import HumanML3DDataModule
from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken


class KitDataModule(HumanML3DDataModule):
    """Data module for the KIT dataset (``name = "kit"``, 21 joints).

    Subclasses ``HumanML3DDataModule`` and replaces the dataset paths,
    normalization statistics and length limits with the values read from
    ``cfg.DATASET.KIT``.
    """

    def __init__(self, cfg, **kwargs):
        """Configure paths, statistics and length limits for KIT.

        Args:
            cfg: Configuration object. This constructor reads
                ``cfg.DATASET.KIT.ROOT``, ``MEAN_STD_PATH``,
                ``MAX_MOTION_LEN``, ``MIN_MOTION_LEN``, ``MAX_TEXT_LEN`` and
                ``UNIT_LEN``, and writes ``cfg.DATASET.NFEATS``.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super().__init__(cfg, **kwargs)

        # Basic info of the dataset
        self.name = "kit"
        self.njoints = 21

        # Path to the dataset
        data_root = cfg.DATASET.KIT.ROOT
        self.hparams.data_root = data_root
        self.hparams.text_dir = pjoin(data_root, "texts")
        self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')

        # Mean and std of the dataset
        dis_data_root = pjoin(cfg.DATASET.KIT.MEAN_STD_PATH, 'kit',
                              "VQVAEV3_CB1024_CMT_H1024_NRES3", "meta")
        self.hparams.mean = np.load(pjoin(dis_data_root, "mean.npy"))
        self.hparams.std = np.load(pjoin(dis_data_root, "std.npy"))

        # Mean and std for fair evaluation.
        # NOTE(review): these are loaded from the 't2m' sub-directory, not
        # 'kit' -- consistent with renorm4t2m() below, but confirm it is
        # intended for this dataset.
        dis_data_root_eval = pjoin(cfg.DATASET.KIT.MEAN_STD_PATH, 't2m',
                                   "Comp_v6_KLD005", "meta")
        self.hparams.mean_eval = np.load(pjoin(dis_data_root_eval, "mean.npy"))
        self.hparams.std_eval = np.load(pjoin(dis_data_root_eval, "std.npy"))

        # Length of the dataset
        self.hparams.max_motion_length = cfg.DATASET.KIT.MAX_MOTION_LEN
        self.hparams.min_motion_length = cfg.DATASET.KIT.MIN_MOTION_LEN
        self.hparams.max_text_len = cfg.DATASET.KIT.MAX_TEXT_LEN
        self.hparams.unit_length = cfg.DATASET.KIT.UNIT_LEN

        # Get additional info of the dataset: build a small test-split sample
        # set only to read its feature dimension. get_sample_set is not
        # defined in this class (presumably inherited from the parent).
        self._sample_set = self.get_sample_set(overrides={"split": "test", "tiny": True})
        self.nfeats = self._sample_set.nfeats
        cfg.DATASET.NFEATS = self.nfeats

    def feats2joints(self, features):
        """De-normalize ``features`` and convert them with ``recover_from_ric``.

        Args:
            features: Tensor normalized with this dataset's mean/std. The
                statistics are moved to its dtype and device before use.

        Returns:
            The result of ``recover_from_ric(features, self.njoints)`` applied
            to the de-normalized features.
        """
        mean = torch.tensor(self.hparams.mean).to(features)
        std = torch.tensor(self.hparams.std).to(features)
        features = features * std + mean
        return recover_from_ric(features, self.njoints)

    def joints2feats(self, features):
        """Convert joints to features via ``process_file``.

        Args:
            features: Input passed straight to ``process_file`` together with
                ``self.njoints``.

        Returns:
            The first element of the ``process_file`` result.
        """
        # The result is returned as-is, i.e. NOT standardized with the
        # dataset mean/std; call normalize() afterwards when that is needed.
        return process_file(features, self.njoints)[0]

    def normalize(self, features):
        """Standardize ``features`` with this dataset's mean and std.

        Args:
            features: Tensor to normalize; the statistics are moved to its
                dtype and device before use.

        Returns:
            ``(features - mean) / std``.
        """
        mean = torch.tensor(self.hparams.mean).to(features)
        std = torch.tensor(self.hparams.std).to(features)
        features = (features - mean) / std
        return features

    def renorm4t2m(self, features):
        """Re-normalize ``features`` from the dataset stats to the eval stats.

        The features are first de-normalized with ``hparams.mean``/``std`` and
        then normalized with ``hparams.mean_eval``/``std_eval`` so that they
        can be fed to the t2m evaluators.

        Args:
            features: Tensor normalized with this dataset's mean/std.

        Returns:
            The same features expressed in the evaluation normalization.
        """
        # renorm to t2m norms for using t2m evaluators
        ori_mean = torch.tensor(self.hparams.mean).to(features)
        ori_std = torch.tensor(self.hparams.std).to(features)
        eval_mean = torch.tensor(self.hparams.mean_eval).to(features)
        eval_std = torch.tensor(self.hparams.std_eval).to(features)
        features = features * ori_std + ori_mean
        features = (features - eval_mean) / eval_std
        return features

    def mm_mode(self, mm_on=True):
        """Switch the test dataset in or out of the random "mm" subset.

        With ``mm_on`` true, ``self.cfg.METRIC.MM_NUM_SAMPLES`` names are drawn
        without replacement from the test dataset's name list and installed as
        ``self.test_dataset.name_list``; the full list is kept in
        ``self.name_list``. With ``mm_on`` false, the full list is put back.

        The method is safe to call repeatedly: enabling mm mode twice
        re-samples from the full list instead of from the previous subset, and
        disabling it before it was ever enabled is a no-op for the name list.

        Args:
            mm_on: ``True`` to enable the subset, ``False`` to restore the
                full test list.

        Raises:
            ValueError: Raised by ``np.random.choice`` when more samples are
                requested than there are names in the list.
        """
        has_snapshot = hasattr(self, "name_list")
        in_mm = bool(getattr(self, "is_mm", False)) and has_snapshot

        if mm_on:
            # Snapshot the full list only when not already in mm mode;
            # otherwise the snapshot would be replaced by the previous
            # subset and the full list could never be restored.
            if not in_mm:
                self.name_list = self.test_dataset.name_list
            self.is_mm = True
            # random select samples for mm
            self.mm_list = np.random.choice(self.name_list,
                                            self.cfg.METRIC.MM_NUM_SAMPLES,
                                            replace=False)
            self.test_dataset.name_list = self.mm_list
        else:
            self.is_mm = False
            # Nothing to restore if mm mode was never enabled.
            if has_snapshot:
                self.test_dataset.name_list = self.name_list