from os.path import join as pjoin
from typing import Callable, Optional

import numpy as np
from omegaconf import DictConfig

from .humanml.utils.word_vectorizer import WordVectorizer
from .HumanML3D import HumanML3DDataModule
from .Kit import KitDataModule
from .base import BASEDataModule
from .utils import mld_collate


def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
    """Return the (mean, std) feature statistics for the given phase and dataset."""
    name = "t2m" if dataset_name == "humanml3d" else dataset_name
    assert name in ["t2m", "kit"]
    if phase in ["val"]:
        # Validation reuses the statistics stored with the pretrained
        # text-to-motion evaluators under cfg.model.t2m_path.
        if name == "t2m":
            data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta")
        elif name == "kit":
            data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta")
        else:
            raise ValueError("Only t2m and kit are supported")
        mean = np.load(pjoin(data_root, "mean.npy"))
        std = np.load(pjoin(data_root, "std.npy"))
    else:
        # Other phases use the statistics stored in the dataset root.
        data_root = cfg.DATASET[dataset_name.upper()].ROOT
        mean = np.load(pjoin(data_root, "Mean.npy"))
        std = np.load(pjoin(data_root, "Std.npy"))

    return mean, std
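
# Illustrative sketch (not part of the original module): the statistics returned
# above are used for z-normalizing motion features; `motion` below is a
# hypothetical (num_frames, nfeats) feature array.
#
#   mean, std = get_mean_std("train", cfg, "humanml3d")
#   motion_norm = (motion - mean) / std      # normalize before feeding the model
#   motion_rec = motion_norm * std + mean    # de-normalize model output
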


def get_WordVectorizer(cfg: DictConfig, phase: str, dataset_name: str) -> Optional[WordVectorizer]:
    """Build the word vectorizer for text tokens; the text-only phase does not need one."""
    if phase not in ["text_only"]:
        if dataset_name.lower() in ["humanml3d", "kit"]:
            # NOTE: the key name (including its spelling) follows the project config.
            return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")
        else:
            raise ValueError("WordVectorizer is only supported for HumanML3D and KIT")
    else:
        return None


def get_collate_fn(name: str) -> Callable:
    """Return the collate function matching the dataset name."""
    if name.lower() in ["humanml3d", "kit"]:
        return mld_collate
    else:
        raise NotImplementedError
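
# Illustrative sketch (not part of the original module): the returned collate
# function is intended for a torch DataLoader; `dataset` and the batch size
# below are hypothetical placeholders.
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=32,
#                       collate_fn=get_collate_fn("humanml3d"))
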


dataset_module_map = {"humanml3d": HumanML3DDataModule, "kit": KitDataModule}
motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"}


def get_datasets(cfg: DictConfig, phase: str = "train") -> list[BASEDataModule]:
    """Instantiate the data modules listed in the config for the given phase."""
    dataset_names = cfg[phase.upper()].DATASETS
    datasets = []
    for dataset_name in dataset_names:
        if dataset_name.lower() in ["humanml3d", "kit"]:
            data_root = cfg.DATASET[dataset_name.upper()].ROOT
            mean, std = get_mean_std(phase, cfg, dataset_name)
            mean_eval, std_eval = get_mean_std("val", cfg, dataset_name)
            wordVectorizer = get_WordVectorizer(cfg, phase, dataset_name)
            collate_fn = get_collate_fn(dataset_name)
            dataset = dataset_module_map[dataset_name.lower()](
                cfg=cfg,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                collate_fn=collate_fn,
                persistent_workers=cfg.TRAIN.PERSISTENT_WORKERS,
                mean=mean,
                std=std,
                mean_eval=mean_eval,
                std_eval=std_eval,
                w_vectorizer=wordVectorizer,
                text_dir=pjoin(data_root, "texts"),
                motion_dir=pjoin(data_root, motion_subdir[dataset_name.lower()]),
                max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
                min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
                max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
                unit_length=cfg.DATASET[dataset_name.upper()].UNIT_LEN,
                model_kwargs=cfg.model,
            )
            datasets.append(dataset)

        elif dataset_name.lower() in ["humanact12", "uestc", "amass"]:
            raise NotImplementedError

    # Expose the feature/joint dimensions of the first dataset to the rest of the config.
    cfg.DATASET.NFEATS = datasets[0].nfeats
    cfg.DATASET.NJOINTS = datasets[0].njoints
    return datasets
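
# Illustrative usage sketch (assumption: an OmegaConf YAML config laid out with
# the keys read above; "configs/config.yaml" is a hypothetical path):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/config.yaml")
#   datamodules = get_datasets(cfg, phase="train")
#   dm = datamodules[0]    # a HumanML3DDataModule or KitDataModule
#   print(cfg.DATASET.NFEATS, cfg.DATASET.NJOINTS)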