meow
init
d6d3a5b
raw
history blame
No virus
9.06 kB
from torch.utils.data import DataLoader
from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate, motion_ours_collate, motion_ours_singe_seq_collate, motion_ours_obj_base_rel_dist_collate
# from data_loaders.humanml.data.dataset import HumanML3D
import torch
def _motion_ours_single_seq_class(args):
    """Pick the single-sequence "motion_ours" dataset class selected by ``args``.

    The choice is driven by ``args.rep_type``; for the "obj_base_rel_dist_we*"
    representations it is refined by ``use_arctic``/``use_pose_pred``,
    ``use_hho``, ``cad_model_fn`` and ``predicted_info_fn`` (checked in that
    order). Only the module member that is actually selected gets imported.
    """
    rep_type = args.rep_type
    if rep_type == "obj_base_rel_dist":
        from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V17
        return GRAB_Dataset_V17
    if rep_type == "ambient_obj_base_rel_dist":
        from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V18
        return GRAB_Dataset_V18
    if rep_type not in ("obj_base_rel_dist_we", "obj_base_rel_dist_we_wj", "obj_base_rel_dist_we_wj_latents"):
        # Any other representation falls back to the V16 dataset.
        from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V16
        return GRAB_Dataset_V16

    # "obj_base_rel_dist_we*" family: the first matching flag wins.
    if args.use_arctic and args.use_pose_pred:
        from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Arctic_from_Pred
        return GRAB_Dataset_V19_Arctic_from_Pred
    if args.use_hho:
        from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_HHO
        return GRAB_Dataset_V19_HHO
    if args.use_arctic:
        from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Arctic
        return GRAB_Dataset_V19_Arctic
    if len(args.cad_model_fn) > 0:
        from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Ours
        return GRAB_Dataset_V19_Ours
    if len(args.predicted_info_fn) > 0:
        from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_From_Evaluated_Info
        return GRAB_Dataset_V19_From_Evaluated_Info
    from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19
    return GRAB_Dataset_V19


def _motion_ours_multi_seq_class(args):
    """Pick the "motion_ours" dataset class from ``dataset_ours`` selected by ``args``.

    The choice is driven by ``args.rep_type``; for the "obj_base_rel_dist_we*"
    representations it is refined by ``use_arctic``, ``use_vox_data``,
    ``use_predicted_infos`` and ``use_interpolated_infos`` (checked in that
    order). Only the module member that is actually selected gets imported.
    """
    rep_type = args.rep_type
    if rep_type == "obj_base_rel_dist":
        from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V17
        return GRAB_Dataset_V17
    if rep_type == "ambient_obj_base_rel_dist":
        from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V18
        return GRAB_Dataset_V18
    if rep_type not in ("obj_base_rel_dist_we", "obj_base_rel_dist_we_wj", "obj_base_rel_dist_we_wj_latents"):
        # Any other representation falls back to the V16 dataset.
        from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V16
        return GRAB_Dataset_V16

    # "obj_base_rel_dist_we*" family: the first matching flag wins.
    if args.use_arctic:
        from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V19_ARCTIC
        return GRAB_Dataset_V19_ARCTIC
    if args.use_vox_data:
        from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V20
        return GRAB_Dataset_V20
    if args.use_predicted_infos:
        # Per the original note: train with predicted infos for test-time adaptation.
        from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V21
        return GRAB_Dataset_V21
    if args.use_interpolated_infos:
        from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V22
        return GRAB_Dataset_V22
    from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V19
    return GRAB_Dataset_V19


def get_dataset_class(name, args=None):
    """Map a dataset name to the class that implements it.

    Every import is deferred to the branch that needs it, so asking for one
    dataset never imports the modules of the others.

    Args:
        name: Dataset identifier. Supported values are "amass", "uestc",
            "humanact12", "humanml", "kit" and "motion_ours".
        args: Option namespace. It is only read when ``name`` is
            "motion_ours"; in that case it must not be None, because
            ``single_seq_path``, ``rep_type`` and further flags are read
            from it.

    Returns:
        The dataset class itself (not an instance).

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    if name == "amass":
        from .amass import AMASS
        return AMASS
    if name == "uestc":
        from .a2m.uestc import UESTC
        return UESTC
    if name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses
        return HumanAct12Poses
    if name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D
        return HumanML3D
    if name == "kit":
        from data_loaders.humanml.data.dataset import KIT
        return KIT
    if name != "motion_ours":
        raise ValueError(f'Unsupported dataset name [{name}]')

    # The single-sequence datasets are used only when a sequence path is given
    # and neither predicted nor interpolated infos are requested.
    single_seq_mode = (
        len(args.single_seq_path) > 0
        and not args.use_predicted_infos
        and not args.use_interpolated_infos
    )
    if single_seq_mode:
        print(f"Using single frame dataset for evaluation purpose...")
        return _motion_ours_single_seq_class(args)
    return _motion_ours_multi_seq_class(args)
def get_collate_fn(name, hml_mode='train', args=None):
    """Return the collate function matching a dataset name and mode.

    Args:
        name: Dataset identifier (see ``get_dataset_class``).
        hml_mode: When equal to 'gt', the evaluation collate function from
            ``data_loaders.humanml.data.dataset`` is returned regardless of
            ``name``.
        args: Option namespace. Only read when ``name`` is "motion_ours",
            where ``single_seq_path`` and ``rep_type`` are accessed (so it
            must not be None in that case).

    Returns:
        A callable suitable for ``DataLoader(collate_fn=...)``.
    """
    print(f"name: {name}, hml_mode: {hml_mode}")

    # Ground-truth mode takes precedence over the dataset name.
    if hml_mode == 'gt':
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
        return t2m_eval_collate

    if name in ("humanml", "kit"):
        return t2m_collate

    if name != "motion_ours":
        return all_collate

    print(f"single_seq_path: {args.single_seq_path}, rep_type: {args.rep_type}")

    # Representations built on object-base relative distances share one
    # collate function, whether or not a single sequence is requested.
    rel_dist_rep_types = (
        "obj_base_rel_dist",
        "ambient_obj_base_rel_dist",
        "obj_base_rel_dist_we",
        "obj_base_rel_dist_we_wj",
        "obj_base_rel_dist_we_wj_latents",
    )
    if args.rep_type in rel_dist_rep_types:
        return motion_ours_obj_base_rel_dist_collate

    # Remaining representations: choose by whether a single sequence is given.
    if len(args.single_seq_path) > 0:
        return motion_ours_singe_seq_collate
    return motion_ours_collate
## get the dataset class and build the dataset ###
def get_dataset(name, num_frames, split='train', hml_mode='train', args=None):
    """Instantiate the dataset registered under ``name``.

    Args:
        name: Dataset identifier (see ``get_dataset_class``).
        num_frames: Forwarded as ``num_frames`` to every dataset except
            "motion_ours", which does not receive it.
        split: Data split name forwarded to the dataset constructor.
        hml_mode: Forwarded as ``mode`` for "humanml" and "kit" only.
        args: Option namespace. For "motion_ours" it supplies
            ``window_size`` and is also passed through to the dataset.

    Returns:
        The constructed dataset instance.
    """
    dataset_cls = get_dataset_class(name, args=args)

    if name in ("humanml", "kit"):
        return dataset_cls(split=split, num_frames=num_frames, mode=hml_mode)

    if name == "motion_ours":
        # NOTE(review): the data root, step size and point count are
        # hard-coded here; confirm they match the deployment environment.
        grab_root = "/data1/sim/GRAB_processed"
        motion_kwargs = dict(
            split=split,
            w_vectorizer=None,  # no word vectorizer is supplied for this dataset
            window_size=args.window_size,
            step_size=15,
            num_points=8000,
            args=args,
        )
        return dataset_cls(grab_root, **motion_kwargs)

    return dataset_cls(split=split, num_frames=num_frames)
def get_dataset_only(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
    """Build and return the dataset without wrapping it in a DataLoader.

    The signature mirrors ``get_dataset_loader``; ``batch_size`` is accepted
    but not used here.

    Returns:
        The dataset instance produced by ``get_dataset``.
    """
    return get_dataset(name, num_frames, split=split, hml_mode=hml_mode, args=args)
# python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset motion_ours
def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
    """Build a single-process ``DataLoader`` for the named dataset.

    Shuffling and dropping the last incomplete batch are both enabled, except
    for "motion_ours" with a non-empty ``args.single_seq_path``, where both
    are disabled.

    Args:
        name: Dataset identifier (see ``get_dataset_class``).
        batch_size: Batch size handed to the ``DataLoader``.
        num_frames: Forwarded to ``get_dataset``.
        split: Data split name forwarded to ``get_dataset``.
        hml_mode: Forwarded to ``get_dataset`` and ``get_collate_fn``.
        args: Option namespace; may be None for datasets that do not read it.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset, using 16 workers.
    """
    dataset = get_dataset(name, num_frames, split, hml_mode, args=args)
    collate = get_collate_fn(name, hml_mode, args=args)

    # Single-sequence "motion_ours" runs keep the sample order and every sample.
    single_seq_run = (
        args is not None
        and name == "motion_ours"
        and len(args.single_seq_path) > 0
    )
    shuffle_loader = drop_last = not single_seq_run

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle_loader,
        num_workers=16,
        drop_last=drop_last,
        collate_fn=collate,
    )
# python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset motion_ours
def get_dataset_loader_dist(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
    """Build a ``DataLoader`` backed by a ``DistributedSampler``.

    This is the distributed counterpart of ``get_dataset_loader`` and follows
    the same policy: shuffling and dropping the last incomplete batch are
    enabled, except for "motion_ours" with a non-empty
    ``args.single_seq_path``, where both are disabled.

    Args:
        name: Dataset identifier (see ``get_dataset_class``).
        batch_size: Per-process batch size handed to the ``DataLoader``.
        num_frames: Forwarded to ``get_dataset``.
        split: Data split name forwarded to ``get_dataset``.
        hml_mode: Forwarded to ``get_dataset`` and ``get_collate_fn``.
        args: Option namespace; may be None for datasets that do not read it.

    Returns:
        A ``torch.utils.data.DataLoader`` whose ``sampler`` is a
        ``DistributedSampler`` over the dataset, using 16 workers.
    """
    dataset = get_dataset(name, num_frames, split, hml_mode, args=args)
    collate = get_collate_fn(name, hml_mode, args=args)

    single_seq_run = (
        args is not None
        and name in ["motion_ours"]
        and len(args.single_seq_path) > 0
    )
    shuffle_loader = not single_seq_run
    drop_last = not single_seq_run

    num_workers = 16

    # DataLoader does not accept ``shuffle`` together with ``sampler``, so the
    # shuffle decision has to be given to the sampler itself. Without it,
    # DistributedSampler shuffles by default, which would reorder
    # single-sequence runs that get_dataset_loader deliberately leaves in order.
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle_loader)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        drop_last=drop_last,
        collate_fn=collate,
    )
    return loader