Spaces:
Sleeping
Sleeping
| from torch.utils.data import DataLoader | |
| from data_loaders.tensors import collate as all_collate | |
| from data_loaders.tensors import t2m_collate, t2m_prefix_collate | |
def get_dataset_class(name):
    """Look up the dataset class registered under ``name``.

    Each dataset module is imported only when its key is requested, so
    selecting one dataset imports only that dataset's module.

    Args:
        name: One of ``"amass"``, ``"uestc"``, ``"humanact12"``,
            ``"humanml"`` or ``"kit"``.

    Returns:
        The dataset class itself (not an instance).

    Raises:
        ValueError: If ``name`` matches none of the supported keys.
    """

    def _load_amass():
        from .amass import AMASS
        return AMASS

    def _load_uestc():
        from .a2m.uestc import UESTC
        return UESTC

    def _load_humanact12():
        from .a2m.humanact12poses import HumanAct12Poses
        return HumanAct12Poses

    def _load_humanml():
        from data_loaders.humanml.data.dataset import HumanML3D
        return HumanML3D

    def _load_kit():
        from data_loaders.humanml.data.dataset import KIT
        return KIT

    # Checked in order; the first key equal to ``name`` wins.
    registry = (
        ("amass", _load_amass),
        ("uestc", _load_uestc),
        ("humanact12", _load_humanact12),
        ("humanml", _load_humanml),
        ("kit", _load_kit),
    )
    for key, load in registry:
        if name == key:
            return load()
    raise ValueError(f'Unsupported dataset name [{name}]')
| def get_collate_fn(name, hml_mode='train', pred_len=0, batch_size=1): | |
| if hml_mode == 'gt': | |
| from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate | |
| return t2m_eval_collate | |
| if name in ["humanml", "kit"]: | |
| if pred_len > 0: | |
| return lambda x: t2m_prefix_collate(x, pred_len=pred_len) | |
| return lambda x: t2m_collate(x, batch_size) | |
| else: | |
| return all_collate | |
def get_dataset(name, num_frames, split='train', hml_mode='train', abs_path='.', fixed_len=0,
                device=None, autoregressive=False, cache_path=None):
    """Instantiate the dataset registered under ``name``.

    Every dataset class receives ``split`` and ``num_frames``. The
    HumanML3D/KIT classes (``"humanml"``, ``"kit"``) additionally receive
    ``mode`` (from ``hml_mode``), ``abs_path``, ``fixed_len``, ``device`` and
    ``autoregressive``.

    Args:
        name: Dataset key resolved via ``get_dataset_class``.
        num_frames: Forwarded to the dataset constructor.
        split: Forwarded to the dataset constructor.
        hml_mode: Forwarded as ``mode`` for HumanML3D/KIT only.
        abs_path, fixed_len, device, autoregressive: Forwarded for
            HumanML3D/KIT only.
        cache_path: Accepted for interface compatibility; not used here.

    Returns:
        The constructed dataset instance.
    """
    dataset_cls = get_dataset_class(name)
    ctor_kwargs = dict(split=split, num_frames=num_frames)
    if name in ("humanml", "kit"):
        ctor_kwargs.update(
            mode=hml_mode,
            abs_path=abs_path,
            fixed_len=fixed_len,
            device=device,
            autoregressive=autoregressive,
        )
    return dataset_cls(**ctor_kwargs)
def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train', fixed_len=0, pred_len=0,
                       device=None, autoregressive=False, *, abs_path='.', shuffle=True, num_workers=8,
                       drop_last=True):
    """Build a ``DataLoader`` over the dataset registered under ``name``.

    Args:
        name: Dataset key understood by ``get_dataset`` / ``get_collate_fn``.
        batch_size: Samples per batch. Passed to both ``DataLoader`` and
            ``get_collate_fn``.
        num_frames: Forwarded to ``get_dataset``.
        split, hml_mode, fixed_len, device, autoregressive: Forwarded to
            ``get_dataset`` (``hml_mode`` also to ``get_collate_fn``).
        pred_len: Forwarded to ``get_collate_fn``; ``> 0`` selects the prefix
            collate for HumanML3D/KIT.
        abs_path: Forwarded to ``get_dataset``. The default ``'.'`` matches
            ``get_dataset``'s own default.
        shuffle, num_workers, drop_last: Forwarded to ``DataLoader``. The
            defaults reproduce the previously hard-coded settings. Pass
            ``shuffle=False, drop_last=False`` for deterministic evaluation
            over every sample.

    Returns:
        A ``torch.utils.data.DataLoader``. With ``drop_last=True``, a dataset
        holding fewer than ``batch_size`` samples yields no batches at all.
    """
    dataset = get_dataset(name, num_frames, split=split, hml_mode=hml_mode, abs_path=abs_path,
                          fixed_len=fixed_len, device=device, autoregressive=autoregressive)
    collate = get_collate_fn(name, hml_mode, pred_len, batch_size)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        collate_fn=collate,
    )