|
import copy |
|
from typing import Callable, Optional |
|
|
|
import numpy as np |
|
from omegaconf import DictConfig |
|
|
|
import torch |
|
|
|
from .base import BASEDataModule |
|
from .humanml.dataset import Text2MotionDatasetV2 |
|
from .humanml.scripts.motion_process import recover_from_ric |
|
|
|
|
|
class HumanML3DDataModule(BASEDataModule):
    """Data module wiring ``Text2MotionDatasetV2`` into ``BASEDataModule``.

    Besides dataset selection, the class offers helpers that move tensors in
    and out of the normalised spaces used by the dataset (``mean``/``std``,
    ``mean_eval``/``std_eval`` from the keyword arguments, and
    ``raw_mean``/``raw_std`` from a sample dataset), plus a switch that
    temporarily restricts the test set to a random subset.
    """

    def __init__(self,
                 cfg: DictConfig,
                 batch_size: int,
                 num_workers: int,
                 collate_fn: Optional[Callable] = None,
                 persistent_workers: bool = True,
                 phase: str = "train",
                 **kwargs) -> None:
        """Build the data module.

        Args:
            cfg: Project configuration; ``cfg.TEST.MM_NUM_SAMPLES`` is read
                later by :meth:`mm_mode`.
            batch_size: Forwarded to the base class.
            num_workers: Forwarded to the base class.
            collate_fn: Forwarded to the base class.
            persistent_workers: Forwarded to the base class.
            phase: Any value except ``"text_only"`` selects
                ``Text2MotionDatasetV2``.
            **kwargs: Deep-copied into ``self.hparams``. The methods of this
                class read the keys ``mean``, ``std``, ``mean_eval`` and
                ``std_eval`` from it.

        Raises:
            NotImplementedError: If ``phase == "text_only"``.
        """
        super().__init__(batch_size=batch_size,
                         num_workers=num_workers,
                         collate_fn=collate_fn,
                         persistent_workers=persistent_workers)
        # Deep copy so later mutation of the caller's objects cannot leak in.
        self.hparams = copy.deepcopy(kwargs)
        self.name = "humanml3d"
        # Joint count handed to ``recover_from_ric`` in ``feats2joints``.
        self.njoints = 22
        if phase == "text_only":
            raise NotImplementedError
        self.Dataset = Text2MotionDatasetV2
        self.cfg = cfg

        # NOTE(review): ``get_sample_set`` is not defined in this class; it is
        # presumably inherited from ``BASEDataModule`` -- confirm.
        sample_overrides = {"tiny": True, "progress_bar": False}
        self._sample_set = self.get_sample_set(overrides=sample_overrides)
        self.nfeats = self._sample_set.nfeats

    def _raw_stats(self, like: torch.Tensor):
        """Return ``(raw_mean, raw_std)`` of the sample set, cast like ``like``."""
        raw_mean = torch.tensor(self._sample_set.raw_mean).to(like)
        raw_std = torch.tensor(self._sample_set.raw_std).to(like)
        return raw_mean, raw_std

    def _hparam_stats(self, mean_key: str, std_key: str, like: torch.Tensor):
        """Return ``(hparams[mean_key], hparams[std_key])`` cast like ``like``."""
        mean = torch.tensor(self.hparams[mean_key]).to(like)
        std = torch.tensor(self.hparams[std_key]).to(like)
        return mean, std

    def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
        """Undo raw normalisation: ``hint * raw_std + raw_mean``.

        Args:
            hint: Tensor normalised with the sample set's raw statistics.

        Returns:
            A new tensor; the statistics are converted to ``hint``'s dtype
            and device before use.
        """
        raw_mean, raw_std = self._raw_stats(hint)
        return hint * raw_std + raw_mean

    def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
        """Apply raw normalisation: ``(hint - raw_mean) / raw_std``.

        Inverse of :meth:`denorm_spatial`.

        Args:
            hint: Un-normalised tensor.

        Returns:
            A new, normalised tensor.
        """
        raw_mean, raw_std = self._raw_stats(hint)
        return (hint - raw_mean) / raw_std

    def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
        """De-normalise ``features`` and pass them to ``recover_from_ric``.

        Args:
            features: Tensor normalised with ``hparams['mean']`` /
                ``hparams['std']``.

        Returns:
            The result of ``recover_from_ric(features, self.njoints)`` on the
            de-normalised features.
        """
        mean, std = self._hparam_stats('mean', 'std', features)
        return recover_from_ric(features * std + mean, self.njoints)

    def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
        """Re-express ``features`` in the ``mean_eval``/``std_eval`` space.

        The input is first de-normalised with ``hparams['mean']`` /
        ``hparams['std']`` and then normalised with ``hparams['mean_eval']``
        / ``hparams['std_eval']``.

        Args:
            features: Tensor normalised with ``mean``/``std``.

        Returns:
            A new tensor normalised with ``mean_eval``/``std_eval``.
        """
        ori_mean, ori_std = self._hparam_stats('mean', 'std', features)
        eval_mean, eval_std = self._hparam_stats('mean_eval', 'std_eval',
                                                 features)
        features = features * ori_std + ori_mean
        return (features - eval_mean) / eval_std

    def mm_mode(self, mm_on: bool = True) -> None:
        """Toggle restriction of the test set to a random subset of names.

        When enabled, ``cfg.TEST.MM_NUM_SAMPLES`` names are drawn without
        replacement from the full test name list and installed on
        ``test_dataset``; the full list is kept in ``self.name_list``. When
        disabled, the saved full list is put back.

        Enabling twice in a row re-samples from the saved full list instead
        of from the current subset, and disabling without a prior enable is
        a no-op, so the full list can never be lost.

        Args:
            mm_on: ``True`` to install a fresh random subset, ``False`` to
                restore the full list.

        Raises:
            ValueError: From ``np.random.choice`` if ``MM_NUM_SAMPLES``
                exceeds the number of available names.
        """
        if mm_on:
            # Snapshot only when the dataset still holds the full list;
            # otherwise the snapshot would be overwritten by the subset.
            if not (getattr(self, "is_mm", False)
                    and hasattr(self, "name_list")):
                self.name_list = self.test_dataset.name_list
            self.is_mm = True
            self.mm_list = np.random.choice(self.name_list,
                                            self.cfg.TEST.MM_NUM_SAMPLES,
                                            replace=False)
            self.test_dataset.name_list = self.mm_list
        else:
            self.is_mm = False
            # Nothing to restore if the subset was never installed.
            if hasattr(self, "name_list"):
                self.test_dataset.name_list = self.name_list
|
|