# Add local model code; tidy requirements (commit f87d582)
from torch.utils.data import DataLoader
from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate, t2m_prefix_collate
def get_dataset_class(name):
    """Return the dataset class registered under *name*.

    Imports are deferred into each branch so that only the requested
    dataset's dependencies need to be installed/importable.

    Raises:
        ValueError: if *name* is not a known dataset identifier.
    """
    if name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D
        return HumanML3D
    if name == "kit":
        from data_loaders.humanml.data.dataset import KIT
        return KIT
    if name == "amass":
        from .amass import AMASS
        return AMASS
    if name == "uestc":
        from .a2m.uestc import UESTC
        return UESTC
    if name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses
        return HumanAct12Poses
    raise ValueError(f'Unsupported dataset name [{name}]')
def get_collate_fn(name, hml_mode='train', pred_len=0, batch_size=1):
    """Select the collate function for the given dataset/mode combination.

    Ground-truth evaluation mode ('gt') takes precedence over everything
    else; the text-to-motion datasets get their dedicated collates
    (prefix variant when a prediction horizon is requested), and all
    remaining datasets fall back to the generic collate.
    """
    if hml_mode == 'gt':
        # Evaluation against ground truth uses the raw dataset collate.
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
        return t2m_eval_collate
    if name not in ("humanml", "kit"):
        return all_collate
    if pred_len > 0:
        return lambda batch: t2m_prefix_collate(batch, pred_len=pred_len)
    return lambda batch: t2m_collate(batch, batch_size)
def get_dataset(name, num_frames, split='train', hml_mode='train', abs_path='.', fixed_len=0,
                device=None, autoregressive=False, cache_path=None):
    """Instantiate the dataset registered under *name*.

    HumanML3D/KIT constructors receive the full configuration; every
    other dataset only takes ``split`` and ``num_frames``.

    NOTE(review): ``cache_path`` is accepted but never forwarded to any
    constructor here — confirm whether it is vestigial or intended for a
    future dataset.
    """
    dataset_cls = get_dataset_class(name)
    if name not in ("humanml", "kit"):
        return dataset_cls(split=split, num_frames=num_frames)
    return dataset_cls(split=split, num_frames=num_frames, mode=hml_mode,
                       abs_path=abs_path, fixed_len=fixed_len,
                       device=device, autoregressive=autoregressive)
def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train', fixed_len=0, pred_len=0,
                       device=None, autoregressive=False, num_workers=8):
    """Build a shuffling DataLoader over the requested dataset.

    Args:
        name: dataset identifier ('humanml', 'kit', 'amass', 'uestc', 'humanact12').
        batch_size: samples per batch; incomplete final batches are dropped.
        num_frames: sequence length forwarded to the dataset constructor.
        split: dataset split to load (e.g. 'train').
        hml_mode: HumanML3D/KIT mode; 'gt' switches to the eval collate.
        fixed_len: forwarded to HumanML3D/KIT constructors.
        pred_len: when > 0, selects the prefix collate for humanml/kit.
        device: forwarded to HumanML3D/KIT constructors.
        autoregressive: forwarded to HumanML3D/KIT constructors.
        num_workers: DataLoader worker processes (previously hard-coded to 8;
            default preserves the old behavior).

    Returns:
        torch.utils.data.DataLoader yielding collated batches.
    """
    dataset = get_dataset(name, num_frames, split=split, hml_mode=hml_mode, fixed_len=fixed_len,
                          device=device, autoregressive=autoregressive)
    collate = get_collate_fn(name, hml_mode, pred_len, batch_size)
    # shuffle + drop_last keep batch shapes fixed for training.
    # NOTE(review): these are applied regardless of split/mode — confirm
    # shuffling is intended for evaluation splits as well.
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, drop_last=True, collate_fn=collate,
    )
    return loader