from typing import Optional
from os.path import join as pjoin

import numpy as np
from omegaconf import DictConfig

from .data import DataModule
from .base import BaseDataModule
from .utils import mld_collate, mld_collate_motion_only
from .humanml.utils.word_vectorizer import WordVectorizer

# Sub-directory (below ``cfg.model.t2m_path/<alias>``) that holds the "val"
# statistics, keyed by the alias used by ``get_mean_std``.
_EVAL_META_SUBDIR = {"t2m": "Comp_v6_KLD01", "kit": "Comp_v6_KLD005"}


def _dataset_cfg(cfg: DictConfig, dataset_name: str):
    """Return the ``cfg.DATASET.<DATASET_NAME upper-cased>`` config node."""
    # Plain attribute lookup; replaces eval() on a config-derived string.
    return getattr(cfg.DATASET, dataset_name.upper())


def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
    """Load the mean/std arrays for ``dataset_name``.

    Args:
        phase: ``"val"`` loads ``mean.npy``/``std.npy`` from the ``meta``
            directory under ``cfg.model.t2m_path``; any other value loads
            ``Mean.npy``/``Std.npy`` from the dataset's ``ROOT`` directory.
        cfg: Project configuration; ``cfg.model.t2m_path`` and
            ``cfg.DATASET.<NAME>.ROOT`` are read.
        dataset_name: ``"humanml3d"`` (mapped to the alias ``"t2m"``),
            ``"t2m"`` or ``"kit"``; matched case-insensitively.

    Returns:
        ``(mean, std)`` as loaded by ``np.load``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Normalise case so this agrees with the ``.lower()`` checks used by the
    # other helpers in this module.
    name = dataset_name.lower()
    if name == "humanml3d":
        name = "t2m"

    # Explicit raise instead of ``assert``: asserts are stripped under ``-O``.
    if name not in _EVAL_META_SUBDIR:
        raise ValueError("Only support t2m and kit")

    if phase == "val":
        data_root = pjoin(cfg.model.t2m_path, name, _EVAL_META_SUBDIR[name], "meta")
        mean = np.load(pjoin(data_root, "mean.npy"))
        std = np.load(pjoin(data_root, "std.npy"))
    else:
        # Note the capitalised file names in the dataset root.
        data_root = _dataset_cfg(cfg, dataset_name).ROOT
        mean = np.load(pjoin(data_root, "Mean.npy"))
        std = np.load(pjoin(data_root, "Std.npy"))

    return mean, std


def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]:
    """Build the ``WordVectorizer`` for a supported dataset.

    Args:
        cfg: Project configuration; ``cfg.DATASET.WORD_VERTILIZER_PATH`` is
            read (the key spelling is what the config uses).
        dataset_name: ``"humanml3d"`` or ``"kit"``, case-insensitive.

    Returns:
        A ``WordVectorizer`` constructed with the configured path and the
        ``"our_vab"`` prefix.

    Raises:
        ValueError: For any other dataset name.
    """
    if dataset_name.lower() not in ("humanml3d", "kit"):
        raise ValueError("Only support WordVectorizer for HumanML3D and KIT")
    return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")


# Data-module class to instantiate for each (lower-cased) dataset name.
dataset_module_map = {"humanml3d": DataModule, "kit": DataModule}
# Motion-feature sub-directory under the dataset root, per (lower-cased) name.
motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"}


def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule:
    """Create the data module selected by ``cfg.DATASET.NAME``.

    Side effect: writes ``cfg.DATASET.NFEATS`` and ``cfg.DATASET.NJOINTS``
    from the created data module.

    Args:
        cfg: Project configuration.
        motion_only: When true, no ``WordVectorizer`` is built and the
            motion-only collate function is used.

    Returns:
        The constructed data module for ``"humanml3d"`` / ``"kit"``.

    Raises:
        NotImplementedError: For ``"humanact12"``, ``"uestc"`` and ``"amass"``.
    """
    dataset_name = cfg.DATASET.NAME
    key = dataset_name.lower()

    if key in ("humanml3d", "kit"):
        ds_cfg = _dataset_cfg(cfg, dataset_name)
        sampler_cfg = cfg.DATASET.SAMPLER
        data_root = ds_cfg.ROOT

        mean, std = get_mean_std("train", cfg, dataset_name)
        mean_eval, std_eval = get_mean_std("val", cfg, dataset_name)

        if motion_only:
            word_vectorizer = None
            collate_fn = mld_collate_motion_only
        else:
            word_vectorizer = get_WordVectorizer(cfg, dataset_name)
            collate_fn = mld_collate

        dataset = dataset_module_map[key](
            name=key,
            cfg=cfg,
            motion_only=motion_only,
            collate_fn=collate_fn,
            mean=mean,
            std=std,
            mean_eval=mean_eval,
            std_eval=std_eval,
            w_vectorizer=word_vectorizer,
            text_dir=pjoin(data_root, "texts"),
            # Index with the lower-cased key, like dataset_module_map above;
            # the raw name raised KeyError for e.g. "HumanML3D".
            motion_dir=pjoin(data_root, motion_subdir[key]),
            max_motion_length=sampler_cfg.MAX_LEN,
            min_motion_length=sampler_cfg.MIN_LEN,
            max_text_len=sampler_cfg.MAX_TEXT_LEN,
            unit_length=ds_cfg.UNIT_LEN,
            fps=ds_cfg.FRAME_RATE,
            padding_to_max=cfg.DATASET.PADDING_TO_MAX,
            window_size=cfg.DATASET.WINDOW_SIZE,
            control_args=ds_cfg.CONTROL_ARGS,
        )

        # Publish the feature/joint counts back into the config.
        cfg.DATASET.NFEATS = dataset.nfeats
        cfg.DATASET.NJOINTS = dataset.njoints
        return dataset
    elif key in ("humanact12", "uestc", "amass"):
        raise NotImplementedError
    # NOTE(review): any other name falls through and implicitly returns None,
    # which contradicts the return annotation. Kept as in the original --
    # confirm whether callers rely on it before turning it into an error.