diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..99e2bf1efa0e618f3e675b3daae9b179a7695c09 --- /dev/null +++ b/.gitignore @@ -0,0 +1,145 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +**/*.pyc + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/en/build + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# custom +data +!mmhuman3d/data +# data for pytest moved to http server +# !tests/data +.vscode +.idea +*.pkl +*.pkl.json +*.log.json +work_dirs/ +logs/ + +# Pytorch +*.pth +*.pt + + +# Visualization +*.mp4 +*.png +*.gif +*.jpg +*.obj +*.ply +!demo/resources/* + +# Resources as exception +!resources/* + +# Loaded/Saved data files +*.npz +*.npy +*.pickle + +# MacOS +*DS_Store* +# git +*.orig \ No newline at end of file diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4690913abcd55df5ac65531a832c375323ae6d07 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,13 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.9" + +sphinx: + configuration: docs/en/source/conf.py + +python: + install: + - requirements: requirements/docs.txt \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..c2016655a484ca5cf76c1f3f4a49a82277caae46 --- /dev/null +++ b/app.py @@ -0,0 +1,202 @@ +import os +import sys +import gradio as gr +import time + +os.makedirs("outputs", exist_ok=True) +sys.path.insert(0, '.') + +import argparse +import os.path as osp +import mmcv +import numpy as np +import torch +from mmcv.runner import load_checkpoint +from mmcv.parallel import MMDataParallel +from scipy.ndimage import gaussian_filter +from IPython.display import Image + +from mogen.models.utils.imagebind_wrapper import ( + extract_text_feature, + extract_audio_feature, + imagebind_huge +) +from mogen.models import build_architecture + +from mogen.utils.plot_utils import ( + plot_3d_motion, + add_audio, + get_audio_length +) +from mogen.datasets.paramUtil import ( + t2m_body_hand_kinematic_chain, + t2m_kinematic_chain +) +from mogen.datasets.utils import recover_from_ric +from mogen.datasets.pipelines import RetargetSkeleton + + +def motion_temporal_filter(motion, sigma=1): + motion = motion.reshape(motion.shape[0], -1) + for i in range(motion.shape[1]): + motion[:, i] = 
gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") + return motion.reshape(motion.shape[0], -1, 3) + +def plot_tomato(data, kinematic_chain, result_path, npy_path, fps, sigma=None): + joints = recover_from_ric(torch.from_numpy(data).float(), 52).numpy() + joints = motion_temporal_filter(joints, sigma=2.5) + joints = rtg_skl({"keypoints3d": joints, "meta_data": {"has_lhnd": True}})["keypoints3d"] + plot_3d_motion( + out_path=result_path, + joints=joints, + kinematic_chain=kinematic_chain, + title=None, + fps=fps) + if npy_path is not None: + np.save(npy_path, joints) + +def create_lmm(): + config_path = "configs/lmm/lmm_small_demo.py" + ckpt_path = "pretrained/lmm_small_demo.pth" + cfg = mmcv.Config.fromfile(config_path) + model = build_architecture(cfg.model) + load_checkpoint(model, ckpt_path, map_location='cpu') + if device == 'cpu': + model = model.cpu() + else: + model = MMDataParallel(model, device_ids=[0]) + model.eval() + return model + +# device = 'cpu' +device = 'cuda' +# os.environ["NO_PROXY"] = os.environ["no_proxy"] = "localhost, 127.0.0.1:7860" +model_lmm = create_lmm() +model_imagebind = imagebind_huge(pretrained=True) +model_imagebind.eval() +model_imagebind.to(device) +rtg_skl = RetargetSkeleton(tgt_skel_file='data/motionverse/statistics/skeleton.npy') + +mean_path = "data/mean.npy" +std_path = "data/std.npy" +mean = np.load(mean_path) +std = np.load(std_path) + +def show_generation_result(model, text, audio_path, motion_length, result_path): + fps = 20 + if audio_path is not None: + motion_length = min(200, int(get_audio_length(audio_path) * fps) + 1) + motion = torch.zeros(1, motion_length, 669).to(device) + motion_mask = torch.ones(1, motion_length).to(device) + motion_mask[0, :motion_length] = 1 + motion_mask = motion_mask.unsqueeze(-1).repeat(1, 1, 10) + motion_mask[:, :, 9] = 0 + dataset_name = "humanml3d_t2m" + kinematic_chain = t2m_body_hand_kinematic_chain + rotation_type = "h3d_rot" + motion_metas = [{ + 'meta_data': dict(framerate=fps, dataset_name=dataset_name, rotation_type=rotation_type) + }] + motion_length = torch.Tensor([motion_length]).long().to(device) + if text is None and audio_path is not None: + text = "A person is standing and speaking." 
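+    # The block below packs the diffusion model input: the placeholder motion
+    # tensor and per-part mask built above, together with ImageBind text/speech
+    # features. A modality that was not supplied is replaced by zero tensors and
+    # gated off through a 0.0 *_cond flag, so the model only conditions on the
+    # inputs the user actually provided.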
+ + model = model.to(device) + input = { + 'motion': motion, + 'motion_mask': motion_mask, + 'motion_length': motion_length, + 'motion_metas': motion_metas, + 'num_intervals': 1 + } + if text is not None: + text_word_feat, text_seq_feat = \ + extract_text_feature([text], model_imagebind, device) + assert text_word_feat.shape[0] == 1 + assert text_word_feat.shape[1] == 77 + assert text_word_feat.shape[2] == 1024 + assert text_seq_feat.shape[0] == 1 + assert text_seq_feat.shape[1] == 1024 + input['text_word_feat'] = text_word_feat + input['text_seq_feat'] = text_seq_feat + input['text_cond'] = torch.Tensor([1.0] * 1).to(device) + else: + input['text_word_feat'] = torch.zeros(1, 77, 1024).to(device) + input['text_seq_feat'] = torch.zeros(1, 1024) + input['text_cond'] = torch.Tensor([0] * 1).to(device) + if audio_path is not None: + speech_word_feat, speech_seq_feat = \ + extract_audio_feature([audio_path], model_imagebind, device) + assert speech_word_feat.shape[0] == 1 + assert speech_word_feat.shape[1] == 229 + assert speech_word_feat.shape[2] == 768 + assert speech_seq_feat.shape[0] == 1 + assert speech_seq_feat.shape[1] == 1024 + input['speech_word_feat'] = speech_word_feat + input['speech_seq_feat'] = speech_seq_feat + input['speech_cond'] = torch.Tensor([1.0] * 1).to(device) + else: + input['speech_word_feat'] = torch.zeros(1, 229, 768).to(device) + input['speech_seq_feat'] = torch.zeros(1, 1024) + input['speech_cond'] = torch.Tensor([0] * 1).to(device) + + all_pred_motion = [] + with torch.no_grad(): + input['inference_kwargs'] = {} + output = model(**input)[0]['pred_motion'][:motion_length] + pred_motion = output.cpu().detach().numpy() + pred_motion = pred_motion * std + mean + + plot_tomato(pred_motion, kinematic_chain, result_path, None, fps, 2) + + if audio_path is not None: + add_audio(result_path, [audio_path]) + +def generate(prompt, audio_path, length): + if not os.path.exists("outputs"): + os.mkdir("outputs") + result_path = "outputs/" + str(int(time.time())) + ".mp4" + print(audio_path) + if audio_path.endswith("placeholder.wav"): + audio_path = None + if len(prompt) == 0: + prompt = None + show_generation_result(model_lmm, prompt, audio_path, length, result_path) + return result_path + +input_audio = gr.Audio( + type='filepath', + format='wav', + label="Audio (1-10s, overwrite motion length):", + show_label=True, + sources=["upload", "microphone"], + min_length=1, + max_length=10, + waveform_options=gr.WaveformOptions( + waveform_color="#01C6FF", + waveform_progress_color="#0066B4", + skip_length=2, + show_controls=False, + ), +) + +input_text = gr.Textbox( + label="Text prompt:" +) + +demo = gr.Interface( + fn=generate, + inputs=[input_text, input_audio, gr.Slider(20, 200, value=60, label="Motion length (fps 20):")], + outputs=gr.Video(label="Video:"), + examples=[ + ["A person walks in a circle.", "examples/placeholder.m4a", 120], + ["A person jumps forward.", "examples/placeholder.m4a", 100], + ["A person is stretching arms.", "examples/placeholder.m4a", 80], + ["", "examples/surprise.m4a", 200], + ["", "examples/angry.m4a", 200], + ], + title="LMM: Large Motion Model for Unified Multi-Modal Motion Generation", + description="\nThis is an interactive demo for LMM. 
For more information, feel free to visit our project page(https://github.com/mingyuan-zhang/LMM).") + +demo.queue() +demo.launch() \ No newline at end of file diff --git a/configs/lmm/lmm.py b/configs/lmm/lmm.py new file mode 100644 index 0000000000000000000000000000000000000000..e82d32ce03137a982fb3f6de57d4d3d5db7f9509 --- /dev/null +++ b/configs/lmm/lmm.py @@ -0,0 +1,75 @@ +dataset_names = [ + 'all', + 'amass_mocap', 'motionx_mocap', 'humanact12_mocap', 'uestc_mocap', 'ntu_mocap', 'aist_mocap', + 'beat_mocap', 'tedg_mocap', 'tedex_mocap', 's2g3d_mocap', 'h36m_mocap', 'mpi_mocap', + + 'humanml3d_t2m', 'kitml_t2m', 'babel_t2m', 'motionx_t2m', + 'humanact12_t2m', 'uestc_t2m', 'ntu_t2m', + + 'aist_m2d', + 'beat_s2g', 'tedg_s2g', 'tedex_s2g', 's2g3d_s2g', + + 'h36m_v2m', 'mpi_v2m' +] +num_datasets = len(dataset_names) +# model settings +model = dict( + type='UnifiedMotionDiffusion', + model=dict( + type='LargeMotionModel', + input_feats=669, + max_seq_len=200, + num_parts=10, + latent_part_dim=64, + time_embed_dim=2048, + dataset_names=dataset_names, + num_layers=4, + num_cond_layers=2, + num_datasets=num_datasets, + dropout=0, + ca_block_cfg=dict( + type='ArtAttention', + num_experts=16, + topk=4, + gate_type='cosine_top', + gate_noise=1.0, + num_datasets=num_datasets, + has_text=True, + has_music=True, + has_speech=True, + has_video=True + ), + text_input_dim=1024, + music_input_dim=768, + speech_input_dim=768, + video_input_dim=1024, + guidance_cfg=dict( + all=dict(type='linear', scale=5.5), + ), + moe_route_loss_weight=10.0, + template_kl_loss_weight=0.0001, + use_pos_embedding=False, + cond_drop_rate=0.1 + ), + loss_recon=dict( + type='KinematicLoss', loss_type='mse', loss_weight=[20], reduction='none'), + train_repeat=1, + diffusion_train=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='start_x', + model_var_type='fixed_large', + ), + diffusion_test_dict=dict( + base=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='start_x', + model_var_type='fixed_large', + ), + all='15,15,8,6,6' + ), + inference_type='ddim', + loss_reduction='batch', + loss_weight='data/motionverse/statistics/loss_weight.npy' +) diff --git a/configs/lmm/lmm_small_demo.py b/configs/lmm/lmm_small_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..ec1cd154b8a846bbca5fb36cc7c4a2f3ce174969 --- /dev/null +++ b/configs/lmm/lmm_small_demo.py @@ -0,0 +1,22 @@ +_base_ = ['lmm.py'] + +model = dict( + model=dict( + latent_part_dim=64, + num_layers=8, + num_cond_layers=2, + dropout=0.1, + ca_block_cfg=dict( + num_experts=16, + topk=4 + ), + guidance_cfg=dict( + humanml3d_t2m=dict(type='linear', scale=10.5), + ), + ), + diffusion_test_dict=dict( + humanml3d_t2m='15,15,8,6,6', + ), +) + +data = dict(samples_per_gpu=32) \ No newline at end of file diff --git a/examples/angry.m4a b/examples/angry.m4a new file mode 100644 index 0000000000000000000000000000000000000000..247de8a6fcb87b032ddff8ca814511c616c50ee8 Binary files /dev/null and b/examples/angry.m4a differ diff --git a/examples/placeholder.m4a b/examples/placeholder.m4a new file mode 100644 index 0000000000000000000000000000000000000000..92d24114ff1101dc9773a7ef093ba2bffa9ee327 Binary files /dev/null and b/examples/placeholder.m4a differ diff --git a/examples/surprise.m4a b/examples/surprise.m4a new file mode 100644 index 0000000000000000000000000000000000000000..06b735959824431eda145c9b40d3808bed51cce0 Binary files /dev/null and b/examples/surprise.m4a differ diff --git a/mogen/__init__.py 
b/mogen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5574a5119b48f7ef906d9267a41e811e626b9b09 --- /dev/null +++ b/mogen/__init__.py @@ -0,0 +1,56 @@ +import warnings + +import mmcv +from packaging.version import parse + +from .version import __version__ + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + Returns: + tuple[int]: The version info in digits (integers). + """ + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) + + +mmcv_minimum_version = '1.4.2' +mmcv_maximum_version = '1.9.0' +mmcv_version = digit_version(mmcv.__version__) + + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version <= digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' + +__all__ = ['__version__', 'digit_version'] diff --git a/mogen/apis/__init__.py b/mogen/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..238e84e8a70e6aa229e67450ad9766864f90f34a --- /dev/null +++ b/mogen/apis/__init__.py @@ -0,0 +1,8 @@ +from mogen.apis.test import (collect_results_cpu, collect_results_gpu, + multi_gpu_test, single_gpu_test) +from mogen.apis.train import set_random_seed, train_model + +__all__ = [ + 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', + 'single_gpu_test', 'set_random_seed', 'train_model' +] diff --git a/mogen/apis/test.py b/mogen/apis/test.py new file mode 100644 index 0000000000000000000000000000000000000000..ce360e6bcf927744b3e02add971764286c78fd7a --- /dev/null +++ b/mogen/apis/test.py @@ -0,0 +1,158 @@ +import os.path as osp +import pickle +import shutil +import tempfile +import time + +import mmcv +import torch +import torch.distributed as dist +from mmcv.runner import get_dist_info + + +def single_gpu_test(model, data_loader): + """Test with single gpu.""" + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + + batch_size = len(result) + if isinstance(result, list): + results.extend(result) + else: + results.append(result) + + batch_size = data['motion'].size(0) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): + """Test model with multiple gpus. + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. 
By setting 'gpu_collect=True' + it encodes results to gpu tensors and use gpu communication for results + collection. On cpu mode it saves the results on different gpus to 'tmpdir' + and collects them by the rank 0 worker. + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + # Check if tmpdir is valid for cpu_collect + if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)): + raise OSError((f'The tmpdir {tmpdir} already exists.', + ' Since tmpdir will be deleted after testing,', + ' please make sure you specify an empty one.')) + prog_bar = mmcv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + if isinstance(result, list): + results.extend(result) + else: + results.append(result) + + if rank == 0: + batch_size = data['motion'].size(0) + for _ in range(batch_size * world_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results in cpu.""" + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + mmcv.mkdir_or_exist('.dist_test') + tmpdir = tempfile.mkdtemp(dir='.dist_test') + tmpdir = torch.tensor(bytearray(tmpdir.encode()), + dtype=torch.uint8, + device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + mmcv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, f'part_{i}.pkl') + part_result = mmcv.load(part_file) + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results + + +def collect_results_gpu(result_part, size): + """Collect results in gpu.""" + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor(bytearray(pickle.dumps(result_part)), + dtype=torch.uint8, + device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + 
part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + ordered_results = [] + for recv, shape in zip(part_recv_list, shape_list): + part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) + ordered_results.extend(part_result) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results diff --git a/mogen/apis/train.py b/mogen/apis/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7062f893872bcc13ea8f8a7d184ae612d41ff592 --- /dev/null +++ b/mogen/apis/train.py @@ -0,0 +1,161 @@ +import random +import warnings + +import numpy as np +import torch +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import ( + DistSamplerSeedHook, + Fp16OptimizerHook, + OptimizerHook, + GradientCumulativeFp16OptimizerHook, + GradientCumulativeOptimizerHook, + build_runner) + +from mogen.core.distributed_wrapper import DistributedDataParallelWrapper +from mogen.core.evaluation import DistEvalHook, EvalHook +from mogen.core.optimizer import build_optimizers +from mogen.datasets import build_dataloader, build_dataset +from mogen.utils import get_root_logger + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def train_model(model, + dataset, + cfg, + distributed=False, + validate=False, + timestamp=None, + device='cuda', + meta=None): + """Main api for training model.""" + logger = get_root_logger(cfg.log_level) + + # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] + + data_loaders = [ + build_dataloader( + ds, + cfg.data.samples_per_gpu, + cfg.data.workers_per_gpu, + # cfg.gpus will be ignored if distributed + num_gpus=len(cfg.gpu_ids), + dist=distributed, + round_up=True, + sampler_cfg=cfg.data.sampler_cfg, + batch_sampler_cfg=cfg.data.batch_sampler_cfg, + seed=cfg.seed) for ds in dataset + ] + + # determine whether use adversarial training precess or not + use_adversarial_train = cfg.get('use_adversarial_train', False) + + # put model on gpus + if distributed: + find_unused_parameters = cfg.get('find_unused_parameters', True) + # Sets the `find_unused_parameters` parameter in + # torch.nn.parallel.DistributedDataParallel + if use_adversarial_train: + # Use DistributedDataParallelWrapper for adversarial training + model = DistributedDataParallelWrapper( + model, + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + if device == 'cuda': + model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), + device_ids=cfg.gpu_ids) + elif device == 'cpu': + model = model.cpu() + else: + raise ValueError(F'unsupported device name {device}.') + + # build runner + optimizer 
= build_optimizers(model, cfg.optimizer) + + if cfg.get('runner') is None: + cfg.runner = { + 'type': 'EpochBasedRunner', + 'max_epochs': cfg.total_epochs + } + warnings.warn( + 'config is now expected to have a `runner` section, ' + 'please set `runner` in your config.', UserWarning) + + runner = build_runner(cfg.runner, + default_args=dict(model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + + # an ugly walkaround to make the .log and .log.json filenames the same + runner.timestamp = timestamp + + if use_adversarial_train: + # The optimizer step process is included in the train_step function + # of the model, so the runner should NOT include optimizer hook. + optimizer_config = None + else: + if distributed and 'type' not in cfg.optimizer_config: + optimizer_config = OptimizerHook(**cfg.optimizer_config) + else: + optimizer_config = cfg.optimizer_config + + # register hooks + runner.register_training_hooks(cfg.lr_config, + optimizer_config, + cfg.checkpoint_config, + cfg.log_config, + cfg.get('momentum_config', None), + custom_hooks_config=cfg.get( + 'custom_hooks', None)) + if distributed: + runner.register_hook(DistSamplerSeedHook()) + + # register eval hooks + if validate: + val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) + val_dataloader = build_dataloader( + val_dataset, + samples_per_gpu=cfg.data.samples_per_gpu, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False, + round_up=True) + eval_cfg = cfg.get('evaluation', {}) + eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' + eval_hook = DistEvalHook if distributed else EvalHook + runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) + + if cfg.resume_from: + runner.resume(cfg.resume_from) + elif cfg.load_from: + runner.load_checkpoint(cfg.load_from) + runner.run(data_loaders, cfg.workflow) diff --git a/mogen/core/__init__.py b/mogen/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mogen/core/distributed_wrapper.py b/mogen/core/distributed_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..fe24e1c554e3c953cbeae8811ef2de7118c1eeac --- /dev/null +++ b/mogen/core/distributed_wrapper.py @@ -0,0 +1,135 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel +from mmcv.parallel.scatter_gather import scatter_kwargs +from torch.cuda._utils import _get_device_index + + +@MODULE_WRAPPERS.register_module() +class DistributedDataParallelWrapper(nn.Module): + """A DistributedDataParallel wrapper for models in 3D mesh estimation task. + + In some pieplines, there is a need to wrap different modules in + the models with separate DistributedDataParallel. Otherwise, it will cause + errors for GAN training. + More specific, the GAN model, usually has two sub-modules: + generator and discriminator. If we wrap both of them in one + standard DistributedDataParallel, it will cause errors during training, + because when we update the parameters of the generator (or discriminator), + the parameters of the discriminator (or generator) is not updated, which is + not allowed for DistributedDataParallel. + So we design this wrapper to separately wrap DistributedDataParallel + for generator and discriminator. + In this wrapper, we perform two operations: + 1. 
Wrap the modules in the models with separate MMDistributedDataParallel. + Note that only modules with parameters will be wrapped. + 2. Do scatter operation for 'forward', 'train_step' and 'val_step'. + Note that the arguments of this wrapper is the same as those in + `torch.nn.parallel.distributed.DistributedDataParallel`. + Args: + module (nn.Module): Module that needs to be wrapped. + device_ids (list[int | `torch.device`]): Same as that in + `torch.nn.parallel.distributed.DistributedDataParallel`. + dim (int, optional): Same as that in the official scatter function in + pytorch. Defaults to 0. + broadcast_buffers (bool): Same as that in + `torch.nn.parallel.distributed.DistributedDataParallel`. + Defaults to False. + find_unused_parameters (bool, optional): Same as that in + `torch.nn.parallel.distributed.DistributedDataParallel`. + Traverse the autograd graph of all tensors contained in returned + value of the wrapped module’s forward function. Defaults to False. + kwargs (dict): Other arguments used in + `torch.nn.parallel.distributed.DistributedDataParallel`. + """ + + def __init__(self, + module, + device_ids, + dim=0, + broadcast_buffers=False, + find_unused_parameters=False, + **kwargs): + super().__init__() + assert len(device_ids) == 1, ( + 'Currently, DistributedDataParallelWrapper only supports one' + 'single CUDA device for each process.' + f'The length of device_ids must be 1, but got {len(device_ids)}.') + self.module = module + self.dim = dim + self.to_ddp(device_ids=device_ids, + dim=dim, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + **kwargs) + self.output_device = _get_device_index(device_ids[0], True) + + def to_ddp(self, device_ids, dim, broadcast_buffers, + find_unused_parameters, **kwargs): + """Wrap models with separate MMDistributedDataParallel. + + It only wraps the modules with parameters. + """ + for name, module in self.module._modules.items(): + if next(module.parameters(), None) is None: + module = module.cuda() + elif all(not p.requires_grad for p in module.parameters()): + module = module.cuda() + else: + module = MMDistributedDataParallel( + module.cuda(), + device_ids=device_ids, + dim=dim, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + **kwargs) + self.module._modules[name] = module + + def scatter(self, inputs, kwargs, device_ids): + """Scatter function. + + Args: + inputs (Tensor): Input Tensor. + kwargs (dict): Args for + ``mmcv.parallel.scatter_gather.scatter_kwargs``. + device_ids (int): Device id. + """ + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def forward(self, *inputs, **kwargs): + """Forward function. + + Args: + inputs (tuple): Input data. + kwargs (dict): Args for + ``mmcv.parallel.scatter_gather.scatter_kwargs``. + """ + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + return self.module(*inputs[0], **kwargs[0]) + + def train_step(self, *inputs, **kwargs): + """Train step function. + + Args: + inputs (Tensor): Input Tensor. + kwargs (dict): Args for + ``mmcv.parallel.scatter_gather.scatter_kwargs``. + """ + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.train_step(*inputs[0], **kwargs[0]) + return output + + def val_step(self, *inputs, **kwargs): + """Validation step function. + + Args: + inputs (tuple): Input data. + kwargs (dict): Args for ``scatter_kwargs``. 
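+
+        Returns:
+            The output of ``self.module.val_step`` on the scattered inputs.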
+ """ + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.val_step(*inputs[0], **kwargs[0]) + return output diff --git a/mogen/core/optimizer/__init__.py b/mogen/core/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faed650e5f6e4b1509b51b5c56840e91f58f2ce0 --- /dev/null +++ b/mogen/core/optimizer/__init__.py @@ -0,0 +1,3 @@ +from .builder import OPTIMIZERS, build_optimizers + +__all__ = ['build_optimizers', 'OPTIMIZERS'] diff --git a/mogen/core/optimizer/builder.py b/mogen/core/optimizer/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..80f659453a797def86628250cfbb7638ec0f323f --- /dev/null +++ b/mogen/core/optimizer/builder.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.runner import build_optimizer +from mmcv.utils import Registry + +OPTIMIZERS = Registry('optimizers') + + +def build_optimizers(model, cfgs): + """Build multiple optimizers from configs. If `cfgs` contains several dicts + for optimizers, then a dict for each constructed optimizers will be + returned. If `cfgs` only contains one optimizer config, the constructed + optimizer itself will be returned. For example, + + 1) Multiple optimizer configs: + + .. code-block:: python + + optimizer_cfg = dict( + model1=dict(type='SGD', lr=lr), + model2=dict(type='SGD', lr=lr)) + + The return dict is + ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)`` + + 2) Single optimizer config: + + .. code-block:: python + + optimizer_cfg = dict(type='SGD', lr=lr) + + The return is ``torch.optim.Optimizer``. + + Args: + model (:obj:`nn.Module`): The model with parameters to be optimized. + cfgs (dict): The config dict of the optimizer. + + Returns: + dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`: + The initialized optimizers. 
+ """ + optimizers = {} + if hasattr(model, 'module'): + model = model.module + # determine whether 'cfgs' has several dicts for optimizers + if all(isinstance(v, dict) for v in cfgs.values()): + for key, cfg in cfgs.items(): + cfg_ = cfg.copy() + module = getattr(model, key) + optimizers[key] = build_optimizer(module, cfg_) + return optimizers + + return build_optimizer(model, cfgs) diff --git a/mogen/datasets/__init__.py b/mogen/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6baa7a1db60b08dbdcaff2a73ef707e02a36b990 --- /dev/null +++ b/mogen/datasets/__init__.py @@ -0,0 +1,12 @@ +from .base_dataset import BaseMotionDataset +from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset +from .pipelines import Compose +from .samplers import DistributedSampler +from .text_motion_dataset import TextMotionDataset +from .motionverse_dataset import MotionVerse + +__all__ = [ + 'BaseMotionDataset', 'TextMotionDataset', 'DATASETS', 'PIPELINES', + 'build_dataloader', 'build_dataset', 'Compose', 'DistributedSampler', + 'MotionVerse' +] diff --git a/mogen/datasets/base_dataset.py b/mogen/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4494d27f5e3df3d2a25b4335d91efcbfefecec --- /dev/null +++ b/mogen/datasets/base_dataset.py @@ -0,0 +1,183 @@ +import copy +import os +import json +from abc import abstractmethod +from typing import Optional, Union, List, Dict + +import numpy as np +import torch +from torch.utils.data import Dataset + +# from mogen.core.evaluation import build_evaluator +from mogen.models.builder import build_submodule +from .builder import DATASETS +from .pipelines import Compose + + +@DATASETS.register_module() +class BaseMotionDataset(Dataset): + """ + Base class for motion datasets. + + Args: + data_prefix (str): The prefix of the data path. + pipeline (list): A list of dicts, where each element represents an operation + defined in `mogen.datasets.pipelines`. + dataset_name (Optional[Union[str, None]]): The name of the dataset. Used to + identify the type of evaluation metric. + fixed_length (Optional[Union[int, None]]): The fixed length of the dataset for + iteration. If None, the dataset length is based on the number + of annotations. + ann_file (Optional[Union[str, None]]): The annotation file. If it is a string, + it is expected to be read from the file. If None, it will be + read from `data_prefix`. + motion_dir (Optional[Union[str, None]]): The directory containing motion data. + eval_cfg (Optional[Union[dict, None]]): Configuration for evaluation metrics. + test_mode (Optional[bool]): Whether the dataset is in test mode. Default is False. + + Attributes: + data_infos (list): Loaded dataset annotations. + evaluators (list): List of evaluation objects. + eval_indexes (np.ndarray): Array of indices used for evaluation. + evaluator_model (torch.nn.Module): Model used for evaluation. + pipeline (Compose): Data processing pipeline. 
+ """ + + def __init__(self, + data_prefix: str, + pipeline: List[Dict], + dataset_name: Optional[Union[str, None]] = None, + fixed_length: Optional[Union[int, None]] = None, + ann_file: Optional[Union[str, None]] = None, + motion_dir: Optional[Union[str, None]] = None, + eval_cfg: Optional[Union[dict, None]] = None, + test_mode: Optional[bool] = False): + super(BaseMotionDataset, self).__init__() + + self.data_prefix = data_prefix + self.pipeline = Compose(pipeline) + self.dataset_name = dataset_name + self.fixed_length = fixed_length + self.ann_file = os.path.join(data_prefix, 'datasets', dataset_name, ann_file) + self.motion_dir = os.path.join(data_prefix, 'datasets', dataset_name, motion_dir) + self.eval_cfg = copy.deepcopy(eval_cfg) + self.test_mode = test_mode + + self.load_annotations() + if self.test_mode: + self.prepare_evaluation() + + @abstractmethod + def load_anno(self, name: str) -> dict: + """ + Abstract method to load a single annotation. + + Args: + name (str): Name or identifier of the annotation to load. + + Returns: + dict: Loaded annotation as a dictionary. + """ + pass + + def load_annotations(self): + """Load annotations from `ann_file` to `data_infos`.""" + self.data_infos = [] + idx = 0 + for line in open(self.ann_file, 'r').readlines(): + line = line.strip() + self.data_infos.append(self.load_anno(idx, line)) + idx += 1 + + def prepare_data(self, idx: int) -> dict: + """ + Prepare raw data for the given index. + + Args: + idx (int): Index of the data to prepare. + + Returns: + dict: Processed data for the given index. + """ + results = copy.deepcopy(self.data_infos[idx]) + results['dataset_name'] = self.dataset_name + results['sample_idx'] = idx + return self.pipeline(results) + + def __len__(self) -> int: + """Return the length of the current dataset. + + Returns: + int: Length of the dataset. + """ + if self.test_mode: + return len(self.eval_indexes) + elif self.fixed_length is not None: + return self.fixed_length + return len(self.data_infos) + + def __getitem__(self, idx: int) -> dict: + """ + Prepare data for the given index. + + Args: + idx (int): Index of the data. + + Returns: + dict: Data for the specified index. 
+ """ + if self.test_mode: + idx = self.eval_indexes[idx] + elif self.fixed_length is not None: + idx = idx % len(self.data_infos) + elif self.balanced_sampling: + cid = np.random.randint(0, len(self.category_list)) + idx = np.random.randint(0, len(self.category_list[cid])) + idx = self.category_list[cid][idx] + return self.prepare_data(idx) + + def prepare_evaluation(self): + """Prepare evaluation settings, including evaluators and evaluation indices.""" + self.evaluators = [] + self.eval_indexes = [] + self.evaluator_model = build_submodule(self.eval_cfg.get('evaluator_model', None)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.evaluator_model = self.evaluator_model.to(device) + self.evaluator_model.eval() + self.eval_cfg['evaluator_model'] = self.evaluator_model + + for _ in range(self.eval_cfg['replication_times']): + eval_indexes = np.arange(len(self.data_infos)) + if self.eval_cfg.get('shuffle_indexes', False): + np.random.shuffle(eval_indexes) + self.eval_indexes.append(eval_indexes) + + for metric in self.eval_cfg['metrics']: + evaluator, self.eval_indexes = build_evaluator( + metric, self.eval_cfg, len(self.data_infos), self.eval_indexes) + self.evaluators.append(evaluator) + + self.eval_indexes = np.concatenate(self.eval_indexes) + + def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict: + """ + Evaluate the model performance based on the results. + + Args: + results (list): A list of result dictionaries. + work_dir (str): Directory where evaluation logs will be stored. + logger: Logger object to record evaluation results (optional). + + Returns: + dict: Dictionary containing evaluation metrics. + """ + metrics = {} + for evaluator in self.evaluators: + metrics.update(evaluator.evaluate(results)) + if logger is not None: + logger.info(metrics) + eval_output = os.path.join(work_dir, 'eval_results.log') + with open(eval_output, 'w') as f: + for k, v in metrics.items(): + f.write(k + ': ' + str(v) + '\n') + return metrics diff --git a/mogen/datasets/builder.py b/mogen/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f048d15b2520f3c33c3bf7579201ba13c1aad64d --- /dev/null +++ b/mogen/datasets/builder.py @@ -0,0 +1,149 @@ +import platform +import random +from functools import partial +from typing import Optional, Union + +import numpy as np +from mmcv.parallel import collate +from mmcv.runner import get_dist_info +from mmcv.utils import Registry, build_from_cfg +from torch.utils.data import DataLoader +from torch.utils.data.dataset import Dataset + +from .samplers import ( + DistributedSampler, + DistributedWeightedRandomSampler, + MonoTaskBatchSampler +) + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + base_soft_limit = rlimit[0] + hard_limit = rlimit[1] + soft_limit = min(max(4096, base_soft_limit), hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + +DATASETS = Registry('dataset') +PIPELINES = Registry('pipeline') + + +def build_dataset(cfg: Union[dict, list, tuple], + default_args: Optional[Union[dict, None]] = None): + """"Build dataset by the given config.""" + from .dataset_wrappers import ConcatDataset, RepeatDataset + if isinstance(cfg, (list, tuple)): + dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) + elif cfg['type'] == 'RepeatDataset': + dataset = RepeatDataset(build_dataset(cfg['dataset'], default_args), + 
cfg['times']) + else: + dataset = build_from_cfg(cfg, DATASETS, default_args) + + return dataset + + +def build_dataloader(dataset: Dataset, + samples_per_gpu: int, + workers_per_gpu: int, + num_gpus: Optional[int] = 1, + dist: Optional[bool] = True, + shuffle: Optional[bool] = True, + round_up: Optional[bool] = True, + seed: Optional[Union[int, None]] = None, + sampler_cfg: Optional[dict] = None, + batch_sampler_cfg: Optional[dict] = None, + persistent_workers: Optional[bool] = True, + **kwargs): + """Build PyTorch DataLoader. + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + Args: + dataset (:obj:`Dataset`): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int, optional): Number of GPUs. Only used in non-distributed + training. + dist (bool, optional): Distributed training/test or not. Default: True. + shuffle (bool, optional): Whether to shuffle the data at every epoch. + Default: True. + round_up (bool, optional): Whether to round up the length of dataset by + adding extra samples to make it evenly divisible. Default: True. + kwargs: any keyword argument to be used to initialize DataLoader + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + if dist: + weighted_sample = False + if sampler_cfg is not None: + weighted_sample = sampler_cfg.get('weighted_sample', False) + if weighted_sample: + sampler_cls = DistributedWeightedRandomSampler + else: + sampler_cls = DistributedSampler + sampler = sampler_cls( + dataset, + world_size, + rank, + shuffle=shuffle, + round_up=round_up + ) + shuffle = False + batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + sampler = None + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + if batch_sampler_cfg is not None: + type_name = batch_sampler_cfg['type'] + assert type_name == 'MonoTaskBatchSampler' + batch_sampler = MonoTaskBatchSampler( + sampler=sampler, + batch_size=batch_size, + num_tasks = batch_sampler_cfg['num_tasks'] + ) + data_loader = DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=num_workers, + collate_fn=partial( + collate, samples_per_gpu=samples_per_gpu), + pin_memory=False, + shuffle=shuffle, + worker_init_fn=init_fn, + persistent_workers=persistent_workers, + **kwargs) + else: + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial( + collate, samples_per_gpu=samples_per_gpu), + pin_memory=False, + shuffle=shuffle, + worker_init_fn=init_fn, + persistent_workers=persistent_workers, + **kwargs) + + return data_loader + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): + """Init random seed for each worker.""" + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/mogen/datasets/dataset_wrappers.py b/mogen/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..dde13f0ad1ac8a4a710cebf25387fef067eaa441 --- /dev/null +++ b/mogen/datasets/dataset_wrappers.py @@ -0,0 
+1,42 @@ +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset +from torch.utils.data.dataset import Dataset + +from .builder import DATASETS + + +@DATASETS.register_module() +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + add `get_cat_ids` function. + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + """ + + def __init__(self, datasets: list): + super(ConcatDataset, self).__init__(datasets) + + +@DATASETS.register_module() +class RepeatDataset(object): + """A wrapper of repeated dataset. + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. + """ + + def __init__(self, dataset: Dataset, times: int): + self.dataset = dataset + self.times = times + + self._ori_len = len(self.dataset) + + def __getitem__(self, idx: int): + return self.dataset[idx % self._ori_len] + + def __len__(self): + return self.times * self._ori_len diff --git a/mogen/datasets/human_body_prior/__init__.py b/mogen/datasets/human_body_prior/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d95fab9b48be9d43467cd1e8d77f25c5a397f9 --- /dev/null +++ b/mogen/datasets/human_body_prior/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2018.01.02 diff --git a/mogen/datasets/human_body_prior/body_model/__init__.py b/mogen/datasets/human_body_prior/body_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d95fab9b48be9d43467cd1e8d77f25c5a397f9 --- /dev/null +++ b/mogen/datasets/human_body_prior/body_model/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. 
+# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2018.01.02 diff --git a/mogen/datasets/human_body_prior/body_model/body_model.py b/mogen/datasets/human_body_prior/body_model/body_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6f6fc68bfebad475651f2838f3ba58d4675e9d2d --- /dev/null +++ b/mogen/datasets/human_body_prior/body_model/body_model.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2018.12.13 + +import numpy as np + +import torch +import torch.nn as nn + +# from smplx.lbs import lbs +from .lbs import lbs +import sys + +class BodyModel(nn.Module): + + def __init__(self, + bm_fname, + num_betas=10, + num_dmpls=None, dmpl_fname=None, + num_expressions=80, + use_posedirs=True, + dtype=torch.float32, + persistant_buffer=False): + + super(BodyModel, self).__init__() + + ''' + :param bm_fname: path to a SMPL model as pkl file + :param num_betas: number of shape parameters to include. 
+ :param device: default on gpu + :param dtype: float precision of the computations + :return: verts, trans, pose, betas + ''' + + self.dtype = dtype + + + # -- Load SMPL params -- + if '.npz' in bm_fname: + smpl_dict = np.load(bm_fname, encoding='latin1') + else: + raise ValueError('bm_fname should be either a .pkl nor .npz file') + + # these are supposed for later convenient look up + self.num_betas = num_betas + self.num_dmpls = num_dmpls + self.num_expressions = num_expressions + + njoints = smpl_dict['posedirs'].shape[2] // 3 + self.model_type = {69: 'smpl', 153: 'smplh', 162: 'smplx', 45: 'mano', 105: 'animal_horse', 102: 'animal_dog', }[njoints] + + assert self.model_type in ['smpl', 'smplh', 'smplx', 'mano', 'mano', 'animal_horse', 'animal_dog'], ValueError( + 'model_type should be in smpl/smplh/smplx/mano.') + + self.use_dmpl = False + if num_dmpls is not None: + if dmpl_fname is not None: + self.use_dmpl = True + else: + raise (ValueError('dmpl_fname should be provided when using dmpls!')) + + if self.use_dmpl and self.model_type in ['smplx', 'mano', 'animal_horse', 'animal_dog']: raise ( + NotImplementedError('DMPLs only work with SMPL/SMPLH models for now.')) + + # Mean template vertices + self.comp_register('init_v_template', torch.tensor(smpl_dict['v_template'][None], dtype=dtype), persistent=persistant_buffer) + + self.comp_register('f', torch.tensor(smpl_dict['f'].astype(np.int32), dtype=torch.int32), persistent=persistant_buffer) + + num_total_betas = smpl_dict['shapedirs'].shape[-1] + if num_betas < 1: + num_betas = num_total_betas + + shapedirs = smpl_dict['shapedirs'][:, :, :num_betas] + self.comp_register('shapedirs', torch.tensor(shapedirs, dtype=dtype), persistent=persistant_buffer) + + if self.model_type == 'smplx': + if smpl_dict['shapedirs'].shape[-1] > 300: + begin_shape_id = 300 + else: + begin_shape_id = 10 + num_expressions = smpl_dict['shapedirs'].shape[-1] - 10 + + exprdirs = smpl_dict['shapedirs'][:, :, begin_shape_id:(begin_shape_id + num_expressions)] + self.comp_register('exprdirs', torch.tensor(exprdirs, dtype=dtype), persistent=persistant_buffer) + + expression = torch.tensor(np.zeros((1, num_expressions)), dtype=dtype) + self.comp_register('init_expression', expression, persistent=persistant_buffer) + + if self.use_dmpl: + dmpldirs = np.load(dmpl_fname)['eigvec'] + + dmpldirs = dmpldirs[:, :, :num_dmpls] + self.comp_register('dmpldirs', torch.tensor(dmpldirs, dtype=dtype), persistent=persistant_buffer) + + # Regressor for joint locations given shape - 6890 x 24 + self.comp_register('J_regressor', torch.tensor(smpl_dict['J_regressor'], dtype=dtype), persistent=persistant_buffer) + + # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*30 x 207 + if use_posedirs: + posedirs = smpl_dict['posedirs'] + posedirs = posedirs.reshape([posedirs.shape[0] * 3, -1]).T + self.comp_register('posedirs', torch.tensor(posedirs, dtype=dtype), persistent=persistant_buffer) + else: + self.posedirs = None + + # indices of parents for each joints + kintree_table = smpl_dict['kintree_table'].astype(np.int32) + self.comp_register('kintree_table', torch.tensor(kintree_table, dtype=torch.int32), persistent=persistant_buffer) + + # LBS weights + # weights = np.repeat(smpl_dict['weights'][np.newaxis], batch_size, axis=0) + weights = smpl_dict['weights'] + self.comp_register('weights', torch.tensor(weights, dtype=dtype), persistent=persistant_buffer) + + self.comp_register('init_trans', torch.zeros((1,3), dtype=dtype), persistent=persistant_buffer) + # 
self.register_parameter('trans', nn.Parameter(trans, requires_grad=True)) + + # root_orient + # if self.model_type in ['smpl', 'smplh']: + self.comp_register('init_root_orient', torch.zeros((1,3), dtype=dtype), persistent=persistant_buffer) + + # pose_body + if self.model_type in ['smpl', 'smplh', 'smplx']: + self.comp_register('init_pose_body', torch.zeros((1,63), dtype=dtype), persistent=persistant_buffer) + elif self.model_type == 'animal_horse': + self.comp_register('init_pose_body', torch.zeros((1,105), dtype=dtype), persistent=persistant_buffer) + elif self.model_type == 'animal_dog': + self.comp_register('init_pose_body', torch.zeros((1,102), dtype=dtype), persistent=persistant_buffer) + + # pose_hand + if self.model_type in ['smpl']: + self.comp_register('init_pose_hand', torch.zeros((1,1*3*2), dtype=dtype), persistent=persistant_buffer) + elif self.model_type in ['smplh', 'smplx']: + self.comp_register('init_pose_hand', torch.zeros((1,15*3*2), dtype=dtype), persistent=persistant_buffer) + elif self.model_type in ['mano']: + self.comp_register('init_pose_hand', torch.zeros((1,15*3), dtype=dtype), persistent=persistant_buffer) + + # face poses + if self.model_type == 'smplx': + self.comp_register('init_pose_jaw', torch.zeros((1,1*3), dtype=dtype), persistent=persistant_buffer) + self.comp_register('init_pose_eye', torch.zeros((1,2*3), dtype=dtype), persistent=persistant_buffer) + + self.comp_register('init_betas', torch.zeros((1,num_betas), dtype=dtype), persistent=persistant_buffer) + + if self.use_dmpl: + self.comp_register('init_dmpls', torch.zeros((1,num_dmpls), dtype=dtype), persistent=persistant_buffer) + + def comp_register(self, name, value, persistent=False): + if sys.version_info[0] > 2: + self.register_buffer(name, value, persistent) + else: + self.register_buffer(name, value) + + def r(self): + from human_body_prior.tools.omni_tools import copy2cpu as c2c + return c2c(self.forward().v) + + def forward(self, root_orient=None, pose_body=None, pose_hand=None, pose_jaw=None, pose_eye=None, betas=None, + trans=None, dmpls=None, expression=None, v_template =None, joints=None, v_shaped=None, return_dict=False, **kwargs): + ''' + + :param root_orient: Nx3 + :param pose_body: + :param pose_hand: + :param pose_jaw: + :param pose_eye: + :param kwargs: + :return: + ''' + batch_size = 1 + # compute batchsize by any of the provided variables + for arg in [root_orient,pose_body,pose_hand,pose_jaw,pose_eye,betas,trans, dmpls,expression, v_template,joints]: + if arg is not None: + batch_size = arg.shape[0] + break + + # assert not (v_template is not None and betas is not None), ValueError('vtemplate and betas could not be used jointly.') + assert self.model_type in ['smpl', 'smplh', 'smplx', 'mano', 'animal_horse', 'animal_dog'], ValueError( + 'model_type should be in smpl/smplh/smplx/mano') + if root_orient is None: root_orient = self.init_root_orient.expand(batch_size, -1) + if self.model_type in ['smplh', 'smpl']: + if pose_body is None: pose_body = self.init_pose_body.expand(batch_size, -1) + if pose_hand is None: pose_hand = self.init_pose_hand.expand(batch_size, -1) + elif self.model_type == 'smplx': + if pose_body is None: pose_body = self.init_pose_body.expand(batch_size, -1) + if pose_hand is None: pose_hand = self.init_pose_hand.expand(batch_size, -1) + if pose_jaw is None: pose_jaw = self.init_pose_jaw.expand(batch_size, -1) + if pose_eye is None: pose_eye = self.init_pose_eye.expand(batch_size, -1) + elif self.model_type in ['mano',]: + if pose_hand is None: pose_hand = 
self.init_pose_hand.expand(batch_size, -1) + elif self.model_type in ['animal_horse','animal_dog']: + if pose_body is None: pose_body = self.init_pose_body.expand(batch_size, -1) + + if pose_hand is None and self.model_type not in ['animal_horse', 'animal_dog']: pose_hand = self.init_pose_hand.expand(batch_size, -1) + + if trans is None: trans = self.init_trans.expand(batch_size, -1) + if v_template is None: v_template = self.init_v_template.expand(batch_size, -1,-1) + if betas is None: betas = self.init_betas.expand(batch_size, -1) + + if self.model_type in ['smplh', 'smpl']: + full_pose = torch.cat([root_orient, pose_body, pose_hand], dim=-1) + elif self.model_type == 'smplx': + full_pose = torch.cat([root_orient, pose_body, pose_jaw, pose_eye, pose_hand], dim=-1) # orient:3, body:63, jaw:3, eyel:3, eyer:3, handl, handr + elif self.model_type in ['mano', ]: + full_pose = torch.cat([root_orient, pose_hand], dim=-1) + elif self.model_type in ['animal_horse', 'animal_dog']: + full_pose = torch.cat([root_orient, pose_body], dim=-1) + + if self.use_dmpl: + if dmpls is None: dmpls = self.init_dmpls.expand(batch_size, -1) + shape_components = torch.cat([betas, dmpls], dim=-1) + shapedirs = torch.cat([self.shapedirs, self.dmpldirs], dim=-1) + elif self.model_type == 'smplx': + if expression is None: expression = self.init_expression.expand(batch_size, -1) + shape_components = torch.cat([betas, expression], dim=-1) + shapedirs = torch.cat([self.shapedirs, self.exprdirs], dim=-1) + else: + shape_components = betas + shapedirs = self.shapedirs + + verts, Jtr = lbs(betas=shape_components, pose=full_pose, v_template=v_template, + shapedirs=shapedirs, posedirs=self.posedirs, + J_regressor=self.J_regressor, parents=self.kintree_table[0].long(), + lbs_weights=self.weights, joints=joints, v_shaped=v_shaped, + dtype=self.dtype) + + Jtr = Jtr + trans.unsqueeze(dim=1) + verts = verts + trans.unsqueeze(dim=1) + + res = {} + res['v'] = verts + res['f'] = self.f + res['Jtr'] = Jtr # Todo: ik can be made with vposer + # res['bStree_table'] = self.kintree_table + + # if self.model_type == 'smpl': + # res['pose_body'] = pose_body + # elif self.model_type == 'smplh': + # res['pose_body'] = pose_body + # res['pose_hand'] = pose_hand + # elif self.model_type == 'smplx': + # res['pose_body'] = pose_body + # res['pose_hand'] = pose_hand + # res['pose_jaw'] = pose_jaw + # res['pose_eye'] = pose_eye + # elif self.model_type in ['mano', 'mano']: + # res['pose_hand'] = pose_hand + res['full_pose'] = full_pose + + if not return_dict: + class result_meta(object): + pass + + res_class = result_meta() + for k, v in res.items(): + res_class.__setattr__(k, v) + res = res_class + + return res + + diff --git a/mogen/datasets/human_body_prior/body_model/lbs.py b/mogen/datasets/human_body_prior/body_model/lbs.py new file mode 100644 index 0000000000000000000000000000000000000000..468adc90dbdf4c0072459e0dbd252ec5dde3461e --- /dev/null +++ b/mogen/datasets/human_body_prior/body_model/lbs.py @@ -0,0 +1,404 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. 
You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Vassilis Choutas +# + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np + +import torch +import torch.nn.functional as F + +def to_tensor(array, dtype=torch.float32): + if 'torch.tensor' not in str(type(array)): + return torch.tensor(array, dtype=dtype) + + +class Struct(object): + def __init__(self, **kwargs): + for key, val in kwargs.items(): + setattr(self, key, val) + + +def to_np(array, dtype=np.float32): + if 'scipy.sparse' in str(type(array)): + array = array.todense() + return np.array(array, dtype=dtype) + + +def rot_mat_to_euler(rot_mats): + # Calculates rotation matrix to euler angles + # Careful for extreme cases of eular angles like [0.0, pi, 0.0] + + sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + + rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) + return torch.atan2(-rot_mats[:, 2, 0], sy) + + +def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx, + dynamic_lmk_b_coords, + neck_kin_chain, dtype=torch.float32): + ''' Compute the faces, barycentric coordinates for the dynamic landmarks + + + To do so, we first compute the rotation of the neck around the y-axis + and then use a pre-computed look-up table to find the faces and the + barycentric coordinates that will be used. + + Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de) + for providing the original TensorFlow implementation and for the LUT. + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + pose: torch.tensor Bx(Jx3), dtype = torch.float32 + The current pose of the body model + dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long + The look-up table from neck rotation to faces + dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32 + The look-up table from neck rotation to barycentric coordinates + neck_kin_chain: list + A python list that contains the indices of the joints that form the + kinematic chain of the neck. + dtype: torch.dtype, optional + + Returns + ------- + dyn_lmk_faces_idx: torch.tensor, dtype = torch.long + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. + dyn_lmk_b_coords: torch.tensor, dtype = torch.float32 + A tensor of size BxL that contains the indices of the faces that + will be used to compute the current dynamic landmarks. 
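+
+ Note: dyn_lmk_b_coords holds the barycentric coordinates (one 3-vector per
+ selected face) paired with dyn_lmk_faces_idx, not face indices.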
+ ''' + + batch_size = vertices.shape[0] + + aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1, + neck_kin_chain) + rot_mats = batch_rodrigues( + aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3) + + rel_rot_mat = torch.eye(3, device=vertices.device, + dtype=dtype).unsqueeze_(dim=0) + for idx in range(len(neck_kin_chain)): + rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat) + + y_rot_angle = torch.round( + torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi, + max=39)).to(dtype=torch.long) + neg_mask = y_rot_angle.lt(0).to(dtype=torch.long) + mask = y_rot_angle.lt(-39).to(dtype=torch.long) + neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle) + y_rot_angle = (neg_mask * neg_vals + + (1 - neg_mask) * y_rot_angle) + + dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx, + 0, y_rot_angle) + dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords, + 0, y_rot_angle) + + return dyn_lmk_faces_idx, dyn_lmk_b_coords + + +def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords): + ''' Calculates landmarks by barycentric interpolation + + Parameters + ---------- + vertices: torch.tensor BxVx3, dtype = torch.float32 + The tensor of input vertices + faces: torch.tensor Fx3, dtype = torch.long + The faces of the mesh + lmk_faces_idx: torch.tensor L, dtype = torch.long + The tensor with the indices of the faces used to calculate the + landmarks. + lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32 + The tensor of barycentric coordinates that are used to interpolate + the landmarks + + Returns + ------- + landmarks: torch.tensor BxLx3, dtype = torch.float32 + The coordinates of the landmarks for each mesh in the batch + ''' + # Extract the indices of the vertices for each face + # BxLx3 + batch_size, num_verts = vertices.shape[:2] + device = vertices.device + + lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).view( + batch_size, -1, 3) + + lmk_faces += torch.arange( + batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts + + lmk_vertices = vertices.view(-1, 3)[lmk_faces].view( + batch_size, -1, 3, 3) + + landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords]) + return landmarks + + +def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, + lbs_weights, joints = None, pose2rot=True, v_shaped=None, dtype=torch.float32): + ''' Performs Linear Blend Skinning with the given shape and pose parameters + + Parameters + ---------- + betas : torch.tensor BxNB + The tensor of shape parameters + pose : torch.tensor Bx(J + 1) * 3 + The pose parameters in axis-angle format + v_template torch.tensor BxVx3 + The template mesh that will be deformed + shapedirs : torch.tensor 1xNB + The tensor of PCA shape displacements + posedirs : torch.tensor Px(V * 3) + The pose PCA coefficients + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from + the position of the vertices + parents: torch.tensor J + The array that describes the kinematic tree for the model + lbs_weights: torch.tensor N x V x (J + 1) + The linear blend skinning weights that represent how much the + rotation matrix of each part affects each vertex + pose2rot: bool, optional + Flag on whether to convert the input pose tensor to rotation + matrices. The default value is True. 
If False, then the pose tensor + should already contain rotation matrices and have a size of + Bx(J + 1)x9 + dtype: torch.dtype, optional + + Returns + ------- + verts: torch.tensor BxVx3 + The vertices of the mesh after applying the shape and pose + displacements. + joints: torch.tensor BxJx3 + The joints of the model + ''' + + batch_size = max(betas.shape[0], pose.shape[0]) + device = betas.device + + # Add shape contribution + if v_shaped is None: + v_shaped = v_template + blend_shapes(betas, shapedirs) + + # Get the joints + # NxJx3 array + if joints is not None: + J = joints + else: + J = vertices2joints(J_regressor, v_shaped) + + # 3. Add pose blend shapes + # N x J x 3 x 3 + ident = torch.eye(3, dtype=dtype, device=device) + if pose2rot: + rot_mats = batch_rodrigues( + pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3]) + + pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) + # (N x P) x (P, V * 3) -> N x V x 3 + pose_offsets = torch.matmul(pose_feature, posedirs).view(batch_size, -1, 3) + else: + pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident + rot_mats = pose.view(batch_size, -1, 3, 3) + + pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), + posedirs).view(batch_size, -1, 3) + + v_posed = pose_offsets + v_shaped + # 4. Get the global joint location + J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) + + # 5. Do skinning: + # W is N x V x (J + 1) + W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1]) + # (N x V x (J + 1)) x (N x (J + 1) x 16) + num_joints = J_regressor.shape[0] + T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \ + .view(batch_size, -1, 4, 4) + + homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1], + dtype=dtype, device=device) + v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2) + v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1)) + + verts = v_homo[:, :, :3, 0] + + return verts, J_transformed + + +def vertices2joints(J_regressor, vertices): + ''' Calculates the 3D joint locations from the vertices + + Parameters + ---------- + J_regressor : torch.tensor JxV + The regressor array that is used to calculate the joints from the + position of the vertices + vertices : torch.tensor BxVx3 + The tensor of mesh vertices + + Returns + ------- + torch.tensor BxJx3 + The location of the joints + ''' + + return torch.einsum('bik,ji->bjk', [vertices, J_regressor]) + + +def blend_shapes(betas, shape_disps): + ''' Calculates the per vertex displacement due to the blend shapes + + + Parameters + ---------- + betas : torch.tensor Bx(num_betas) + Blend shape coefficients + shape_disps: torch.tensor Vx3x(num_betas) + Blend shapes + + Returns + ------- + torch.tensor BxVx3 + The per-vertex displacement due to shape deformation + ''' + + # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l] + # i.e. Multiply each shape displacement by its corresponding beta and + # then sum them. 
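+ # Illustrative shape walk-through for the einsum below, assuming an
+ # SMPL-sized template with 10 shape coefficients:
+ #   betas:       B x 10
+ #   shape_disps: 6890 x 3 x 10
+ #   output:      B x 6890 x 3, i.e. one XYZ offset per vertex per batch item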
+ + #print(betas.device,shape_disps.device) + blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps]) + return blend_shape + + +def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32): + ''' Calculates the rotation matrices for a batch of rotation vectors + Parameters + ---------- + rot_vecs: torch.tensor Nx3 + array of N axis-angle vectors + Returns + ------- + R: torch.tensor Nx3x3 + The rotation matrices for the given axis-angle parameters + ''' + + batch_size = rot_vecs.shape[0] + device = rot_vecs.device + + angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True) + rot_dir = rot_vecs / angle + + cos = torch.unsqueeze(torch.cos(angle), dim=1) + sin = torch.unsqueeze(torch.sin(angle), dim=1) + + # Bx1 arrays + rx, ry, rz = torch.split(rot_dir, 1, dim=1) + K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device) + + zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \ + .view((batch_size, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0) + rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K) + return rot_mat + + +def transform_mat(R, t): + ''' Creates a batch of transformation matrices + Args: + - R: Bx3x3 array of a batch of rotation matrices + - t: Bx3x1 array of a batch of translation vectors + Returns: + - T: Bx4x4 Transformation matrix + ''' + # No padding left or right, only add an extra row + return torch.cat([F.pad(R, [0, 0, 0, 1]), + F.pad(t, [0, 0, 0, 1], value=1)], dim=2) + + +def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32): + """ + Applies a batch of rigid transformations to the joints + + Parameters + ---------- + rot_mats : torch.tensor BxNx3x3 + Tensor of rotation matrices + joints : torch.tensor BxNx3 + Locations of joints + parents : torch.tensor BxN + The kinematic tree of each object + dtype : torch.dtype, optional: + The data type of the created tensors, the default is torch.float32 + + Returns + ------- + posed_joints : torch.tensor BxNx3 + The locations of the joints after applying the pose rotations + rel_transforms : torch.tensor BxNx4x4 + The relative (with respect to the root joint) rigid transformations + for all the joints + """ + + joints = torch.unsqueeze(joints, dim=-1) + + rel_joints = joints.clone() + rel_joints[:, 1:] -= joints[:, parents[1:]] + + transforms_mat = transform_mat( + rot_mats.reshape(-1, 3, 3), + rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4) + + transform_chain = [transforms_mat[:, 0]] + for i in range(1, parents.shape[0]): + # Subtract the joint location at the rest pose + # No need for rotation, since it's identity when at rest + curr_res = torch.matmul(transform_chain[parents[i]], + transforms_mat[:, i]) + transform_chain.append(curr_res) + + transforms = torch.stack(transform_chain, dim=1) + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + # The last column of the transformations contains the posed joints + posed_joints = transforms[:, :, :3, 3] + + joints_homogen = F.pad(joints, [0, 0, 0, 1]) + + rel_transforms = transforms - F.pad( + torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0]) + + return posed_joints, rel_transforms diff --git a/mogen/datasets/human_body_prior/body_model/parts_segm/readme b/mogen/datasets/human_body_prior/body_model/parts_segm/readme new file mode 100644 index 0000000000000000000000000000000000000000..cc6a7dd0bd745bf982774b447a23220c1c638d8b --- 
/dev/null +++ b/mogen/datasets/human_body_prior/body_model/parts_segm/readme @@ -0,0 +1 @@ +### Parts segmentation file obtained from https://github.com/vchoutas/torch-mesh-isect#examples and put here for convenience \ No newline at end of file diff --git a/mogen/datasets/human_body_prior/body_model/rigid_object_model.py b/mogen/datasets/human_body_prior/body_model/rigid_object_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b0af3ee6587dcc395d188a4a626f76fe506672d5 --- /dev/null +++ b/mogen/datasets/human_body_prior/body_model/rigid_object_model.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2018.12.13 + +import numpy as np + +import torch +import torch.nn as nn + +# from smplx.lbs import lbs +from human_body_prior.body_model.lbs import lbs +# import trimesh # dont use this package for loading meshes since it messes up the order of vertices +from psbody.mesh import Mesh +from human_body_prior.body_model.lbs import batch_rodrigues + +class RigidObjectModel(nn.Module): + + def __init__(self, plpath, batch_size=1, dtype=torch.float32): + super(RigidObjectModel, self).__init__() + + trans = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True) + self.register_parameter('trans', nn.Parameter(trans, requires_grad=True)) + + root_orient = torch.tensor(np.zeros((batch_size, 3)), dtype=dtype, requires_grad=True) + self.register_parameter('root_orient', nn.Parameter(root_orient, requires_grad=True)) + + mesh = Mesh(filename=plpath) + + self.rigid_v = torch.from_numpy(np.repeat(mesh.v[np.newaxis], batch_size, axis=0)).type(dtype) + self.f = torch.from_numpy(mesh.f.astype(np.int32)) + + def forward(self, root_orient, trans): + if root_orient is None: root_orient = self.root_orient + if trans is None: trans = self.trans + verts = torch.bmm(self.rigid_v, batch_rodrigues(root_orient)) + trans.view(-1,1,3) + + res = {} + res['v'] = verts + res['f'] = self.f + + class result_meta(object): pass + + res_class = result_meta() + for k, v in res.items(): + res_class.__setattr__(k, v) + res = res_class + + return res diff --git a/mogen/datasets/human_body_prior/models/__init__.py b/mogen/datasets/human_body_prior/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..172e1b84eb010c48afa5d7e0d142218c829ef0fb --- /dev/null +++ b/mogen/datasets/human_body_prior/models/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 \ No newline at end of file diff --git a/mogen/datasets/human_body_prior/models/ik_engine.py b/mogen/datasets/human_body_prior/models/ik_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..fd742bb302dc20188fc335e6329ae56d787270e1 --- /dev/null +++ b/mogen/datasets/human_body_prior/models/ik_engine.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. 
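+#
+# Illustrative usage sketch for the IK_Engine defined in this file; the VPoser
+# directory path and the source_kpts_model object are placeholders:
+#
+#   ik_engine = IK_Engine(vposer_expr_dir='path/to/vposer_expr_dir',
+#                         data_loss=torch.nn.SmoothL1Loss(reduction='sum'),
+#                         optimizer_args={'type': 'LBFGS', 'max_iter': 300},
+#                         verbosity=1)
+#   free_vars = ik_engine(source_kpts=source_kpts_model, target_kpts=target_kpts)
+#   # free_vars is a dict with 'pose_body', 'poZ_body', 'betas', 'trans', 'root_orient'
+#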
+# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2021.02.12 + +from typing import List, Dict + +from psbody.mesh import Mesh +from body_visualizer.tools.psbody_mesh_tools import rotateXYZ, points_to_cubes, points_to_spheres + + +from torch import nn +import torch + +from human_body_prior.tools.model_loader import load_model + +import numpy as np + +from body_visualizer.tools.vis_tools import colors +from human_body_prior.tools.omni_tools import copy2cpu as c2c +from psbody.mesh import MeshViewers + +from human_body_prior.tools.omni_tools import log2file + +from human_body_prior.models.vposer_model import VPoser +from human_body_prior.tools.omni_tools import flatten_list + + +def visualize(points, bm_f, mvs, kpts_colors, verbosity=2, logger=None): + from human_body_prior.tools.omni_tools import log2file + + if logger is None: logger = log2file() + + def view(opt_objs, body_v, virtual_markers, opt_it): + if verbosity <= 0: return + opt_objs_cpu = {k: c2c(v) for k, v in opt_objs.items()} + + total_loss = np.sum([np.sum(v) for k, v in opt_objs_cpu.items()]) + message = 'it {} -- [total loss = {:.2e}] - {}'.format(opt_it, total_loss, ' | '.join(['%s = %2.2e' % (k, np.sum(v)) for k, v in opt_objs_cpu.items()])) + logger(message) + if verbosity>1: + bs = body_v.shape[0] + np.random.seed(100) + frame_ids = list(range(bs)) if bs <= len(mvs) else np.random.choice(bs , size=len(mvs), replace=False).tolist() + if bs > len(mvs): message += ' -- [frame_ids: {}]'.format(frame_ids) + for dispId, fId in enumerate(frame_ids): # check for the number of frames in mvs and show a randomly picked number of frames in body if there is more to show than row*cols available + new_body_v = rotateXYZ(body_v[fId], [-90,0,0]) + + orig_mrk_mesh = points_to_spheres(rotateXYZ(c2c(points[fId]), [-90,0,0]), radius=0.01, color=kpts_colors) + virtual_markers_mesh = points_to_cubes(rotateXYZ(virtual_markers[fId], [-90,0,0]), radius=0.01, color=kpts_colors) + new_body_mesh = Mesh(new_body_v, bm_f, vc=colors['grey']) + + # linev = rotateXYZ(np.hstack((c2c(points[fId]), virtual_markers[fId])).reshape((-1, 3)), [-90,0,0]) + # linee = np.arange(len(linev)).reshape((-1, 2)) + # ll = Lines(v=linev, e=linee) + # ll.vc = (ll.v * 0. 
+ 1) * np.array([0.00, 0.00, 1.00]) + # mvs[dispId].set_dynamic_lines([ll]) + + # orig_mrk_mesh = points_to_spheres(data_pc, radius=0.01, vc=colors['blue']) + mvs[dispId].set_dynamic_meshes([orig_mrk_mesh, virtual_markers_mesh]) + mvs[dispId].set_static_meshes([new_body_mesh]) + + mvs[0].set_titlebar(message) + # if out_dir is not None: mv.save_snapshot(os.path.join(out_dir, '%05d_it_%.5d.png' %(frame_id, opt_it))) + return view + + +class AdamInClosure(): + def __init__(self, var_list, lr, max_iter=100, tolerance_change=1e-5): + self.optimizer = torch.optim.Adam(var_list, lr) + self.max_iter = max_iter + self.tolerance_change = tolerance_change + + + def step(self, closure): + prev_loss = None + for it in range(self.max_iter): + loss = closure() + self.optimizer.step() + if prev_loss is None: + prev_loss = loss + continue + if torch.isnan(loss): + # breakpoint() + break + if abs(loss - prev_loss) < self.tolerance_change: + print('abs(loss - prev_loss) < self.tolerance_change') + break + + def zero_grad(self): + self.optimizer.zero_grad() + +def ik_fit(optimizer, source_kpts_model, static_vars, vp_model, extra_params={}, on_step=None, gstep=0): + + data_loss = extra_params.get('data_loss', torch.nn.SmoothL1Loss(reduction='mean')) + # data_loss = + # data_loss = torch.nn.L1Loss(reduction='mean')#change with SmoothL1 + + def fit(weights, free_vars): + + fit.gstep += 1 + optimizer.zero_grad() + + free_vars['pose_body'] = vp_model.decode(free_vars['poZ_body'])['pose_body'].contiguous().view(-1, 63) + nonan_mask = torch.isnan(free_vars['poZ_body']).sum(-1) == 0 + + opt_objs = {} + + res = source_kpts_model(free_vars) + + opt_objs['data'] = data_loss(res['source_kpts'], static_vars['target_kpts']) + + opt_objs['betas'] = torch.pow(free_vars['betas'][nonan_mask],2).sum() + opt_objs['poZ_body'] = torch.pow(free_vars['poZ_body'][nonan_mask],2).sum() + + + opt_objs = {k: opt_objs[k]*v for k, v in weights.items() if k in opt_objs.keys()} + loss_total = torch.sum(torch.stack(list(opt_objs.values()))) + # breakpoint() + + loss_total.backward() + + if on_step is not None: + on_step(opt_objs, c2c(res['body'].v), c2c(res['source_kpts']), fit.gstep) + + fit.free_vars = {k:v for k,v in free_vars.items()}# if k in IK_Engine.fields_to_optimize} + # fit.nonan_mask = nonan_mask + fit.final_loss = loss_total + + return loss_total + + fit.gstep = gstep + fit.final_loss = None + fit.free_vars = {} + # fit.nonan_mask = None + return fit + +class IK_Engine(nn.Module): + + + def __init__(self, + vposer_expr_dir: str, + data_loss, + optimizer_args: dict={'type':'ADAM'}, + stepwise_weights: List[Dict]=[{'data': 10., 'poZ_body': .01, 'betas': .5}], + display_rc: tuple = (2,1), + verbosity: int = 1, + logger=None, + ): + ''' + + :param vposer_expr_dir: The vposer directory that holds the settings and model snapshot + :param data_loss: should be a pytorch callable (source, target) that returns the accumulated loss + :param optimizer_args: arguments for optimizers + :param stepwise_weights: list of dictionaries. each list element defines weights for one full step of optimization + if a weight value is left out, its respective object item will be removed as well. imagine optimizing without data term! + :param display_rc: number of row and columns in case verbosity > 1 + :param verbosity: 0: silent, 1: text, 2: text/visual. 
running 2 over ssh would need extra work + :param logger: an instance of human_body_prior.tools.omni_tools.log2file + ''' + + + super(IK_Engine, self).__init__() + + assert isinstance(stepwise_weights, list), ValueError('stepwise_weights should be a list of dictionaries.') + assert np.all(['data' in l for l in stepwise_weights]), ValueError('The term data should be available in every weight of anealed optimization step: {}'.format(stepwise_weights)) + + self.data_loss = torch.nn.SmoothL1Loss(reduction='mean') if data_loss is None else data_loss + + self.stepwise_weights = stepwise_weights + self.verbosity = verbosity + self.optimizer_args = optimizer_args + + self.logger = log2file() if logger is None else logger + + + if verbosity>1: + mvs = MeshViewers(display_rc, keepalive=True) + self.mvs = flatten_list(mvs) + self.mvs[0].set_background_color(colors['white']) + else: + self.mvs=None + + self.vp_model, _ = load_model(vposer_expr_dir, + model_code=VPoser, + remove_words_in_model_weights='vp_model.', + disable_grad=True) + + + def forward(self, source_kpts, target_kpts, initial_body_params={}): + ''' + source_kpts is a function that given body parameters computes source key points that should match target key points + Try to reconstruct the bps signature by optimizing the body_poZ + ''' + # if self.rt_ps.verbosity > 0: self.logger('Processing {} frames'.format(points.shape[0])) + + bs = target_kpts.shape[0] + + + on_step = visualize(target_kpts, + kpts_colors=source_kpts.kpts_colors, + bm_f=source_kpts.bm_f, + mvs=self.mvs, + verbosity=self.verbosity, + logger=self.logger) + + comp_device = target_kpts.device + # comp_device = self.vp_model.named_parameters().__next__()[1].device + if 'pose_body' not in initial_body_params: + initial_body_params['pose_body'] = torch.zeros([bs, 63], device=comp_device, dtype=torch.float, requires_grad=False) + if 'trans' not in initial_body_params: + initial_body_params['trans'] = torch.zeros([bs, 3], device=comp_device, dtype=torch.float, requires_grad=False) + if 'betas' not in initial_body_params: + initial_body_params['betas'] = torch.zeros([bs, 10], device=comp_device, dtype=torch.float, requires_grad=False) + if 'root_orient' not in initial_body_params: + initial_body_params['root_orient'] = torch.zeros([bs, 3], device=comp_device, dtype=torch.float, requires_grad=False) + + initial_body_params['poZ_body'] = self.vp_model.encode(initial_body_params['pose_body']).mean + + free_vars = {k: torch.nn.Parameter(v.detach(), requires_grad=True) for k,v in initial_body_params.items() if k in ['betas', 'trans', 'poZ_body', 'root_orient']} + static_vars = { + 'target_kpts': target_kpts, + # 'trans': initial_body_params['trans'].detach(), + # 'betas': initial_body_params['betas'].detach(), + # 'poZ_body': initial_body_params['poZ_body'].detach() + } + + if self.optimizer_args['type'].upper() == 'LBFGS': + optimizer = torch.optim.LBFGS(list(free_vars.values()), + lr=self.optimizer_args.get('lr', 1), + max_iter=self.optimizer_args.get('max_iter', 100), + tolerance_change=self.optimizer_args.get('tolerance_change', 1e-5), + max_eval=self.optimizer_args.get('max_eval', None), + history_size=self.optimizer_args.get('history_size', 100), + line_search_fn='strong_wolfe') + + elif self.optimizer_args['type'].upper() == 'ADAM': + optimizer = AdamInClosure(list(free_vars.values()), + lr=self.optimizer_args.get('lr', 1e-3), + max_iter=self.optimizer_args.get('max_iter', 100), + tolerance_change=self.optimizer_args.get('tolerance_change', 1e-5), + ) + else: + raise 
ValueError('optimizer_type not recognized.') + + gstep = 0 + closure = ik_fit(optimizer, + source_kpts_model=source_kpts, + static_vars=static_vars, + vp_model=self.vp_model, + extra_params={'data_loss': self.data_loss}, + on_step=on_step, + gstep=gstep) + # try: + + for wts in self.stepwise_weights: + optimizer.step(lambda: closure(wts, free_vars)) + free_vars = closure.free_vars + # except: + # + # pass + + # if closure.final_loss is None or torch.isnan(closure.final_loss) or torch.any(torch.isnan(free_vars['trans'])): + # if self.verbosity > 0: + # self.logger('NaN observed in the optimization results. you might want to restart the refinment procedure.') + # breakpoint() + # return None + + return closure.free_vars#, closure.nonan_mask diff --git a/mogen/datasets/human_body_prior/models/model_components.py b/mogen/datasets/human_body_prior/models/model_components.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca4bdc47afb7be8490f7424a1ea4f88bc27e211 --- /dev/null +++ b/mogen/datasets/human_body_prior/models/model_components.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 + +from torch import nn + +class View(nn.Module): + def __init__(self, *args): + super(View, self).__init__() + self.shape = args + self._name = 'reshape' + + def forward(self, x): + return x.view(self.shape) + +class BatchFlatten(nn.Module): + def __init__(self): + super(BatchFlatten, self).__init__() + self._name = 'batch_flatten' + + def forward(self, x): + return x.view(x.shape[0], -1) \ No newline at end of file diff --git a/mogen/datasets/human_body_prior/models/vposer_model.py b/mogen/datasets/human_body_prior/models/vposer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..74739b024d7a367604b6bc971818d54037557d5a --- /dev/null +++ b/mogen/datasets/human_body_prior/models/vposer_model.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. 
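+#
+# Illustrative usage sketch for the VPoser model defined in this file; in this
+# repository the instance is normally created via
+# human_body_prior.tools.model_loader.load_model:
+#
+#   q_z = vposer.encode(pose_body)        # pose_body: N x 63 axis-angle -> Normal distribution
+#   z = q_z.rsample()                     # latent codes, N x latentD
+#   rec = vposer.decode(z)['pose_body']   # N x 21 x 3 axis-angle body pose
+#   samples = vposer.sample_poses(num_poses=4)['pose_body']
+#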
+# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 + +import numpy as np +import torch +from human_body_prior.models.model_components import BatchFlatten +from human_body_prior.tools.rotation_tools import matrot2aa +from torch import nn +from torch.nn import functional as F + + +class ContinousRotReprDecoder(nn.Module): + def __init__(self): + super(ContinousRotReprDecoder, self).__init__() + + def forward(self, module_input): + reshaped_input = module_input.view(-1, 3, 2) + + b1 = F.normalize(reshaped_input[:, :, 0], dim=1) + + dot_prod = torch.sum(b1 * reshaped_input[:, :, 1], dim=1, keepdim=True) + b2 = F.normalize(reshaped_input[:, :, 1] - dot_prod * b1, dim=-1) + b3 = torch.cross(b1, b2, dim=1) + + return torch.stack([b1, b2, b3], dim=-1) + + +class NormalDistDecoder(nn.Module): + def __init__(self, num_feat_in, latentD): + super(NormalDistDecoder, self).__init__() + + self.mu = nn.Linear(num_feat_in, latentD) + self.logvar = nn.Linear(num_feat_in, latentD) + + def forward(self, Xout): + return torch.distributions.normal.Normal(self.mu(Xout), F.softplus(self.logvar(Xout))) + + +class VPoser(nn.Module): + def __init__(self, model_ps): + super(VPoser, self).__init__() + + num_neurons, self.latentD = model_ps.model_params.num_neurons, model_ps.model_params.latentD + + self.num_joints = 21 + n_features = self.num_joints * 3 + + self.encoder_net = nn.Sequential( + BatchFlatten(), + nn.BatchNorm1d(n_features), + nn.Linear(n_features, num_neurons), + nn.LeakyReLU(), + nn.BatchNorm1d(num_neurons), + nn.Dropout(0.1), + nn.Linear(num_neurons, num_neurons), + nn.Linear(num_neurons, num_neurons), + NormalDistDecoder(num_neurons, self.latentD) + ) + + self.decoder_net = nn.Sequential( + nn.Linear(self.latentD, num_neurons), + nn.LeakyReLU(), + nn.Dropout(0.1), + nn.Linear(num_neurons, num_neurons), + nn.LeakyReLU(), + nn.Linear(num_neurons, self.num_joints * 6), + ContinousRotReprDecoder(), + ) + + def encode(self, pose_body): + ''' + :param Pin: Nx(numjoints*3) + :param rep_type: 'matrot'/'aa' for matrix rotations or axis-angle + :return: + ''' + return self.encoder_net(pose_body) + + def decode(self, Zin): + bs = Zin.shape[0] + + prec = self.decoder_net(Zin) + + return { + 'pose_body': matrot2aa(prec.view(-1, 3, 3)).view(bs, -1, 3), + 'pose_body_matrot': prec.view(bs, -1, 9) + } + + + def forward(self, pose_body): + ''' + :param Pin: aa: Nx1xnum_jointsx3 / matrot: Nx1xnum_jointsx9 + :param input_type: matrot / aa for matrix rotations or axis angles + :param output_type: matrot / aa + :return: + ''' + + q_z = self.encode(pose_body) + q_z_sample = q_z.rsample() + decode_results = self.decode(q_z_sample) + decode_results.update({'poZ_body_mean': q_z.mean, 'poZ_body_std': q_z.scale, 'q_z': q_z}) + return decode_results + + def sample_poses(self, num_poses, seed=None): + np.random.seed(seed) + + some_weight = [a for a in self.parameters()][0] + dtype = some_weight.dtype + device = some_weight.device + self.eval() + with torch.no_grad(): + Zgen = torch.tensor(np.random.normal(0., 1., size=(num_poses, self.latentD)), dtype=dtype, device=device) + + return self.decode(Zgen) diff --git a/mogen/datasets/human_body_prior/tools/__init__.py b/mogen/datasets/human_body_prior/tools/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..377b56b6106405d0b15f6d13e0b8dcc67e3f9973 --- /dev/null +++ b/mogen/datasets/human_body_prior/tools/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 diff --git a/mogen/datasets/human_body_prior/tools/angle_continuous_repres.py b/mogen/datasets/human_body_prior/tools/angle_continuous_repres.py new file mode 100644 index 0000000000000000000000000000000000000000..829637e77f4da537bac94913981ea94738bab5ab --- /dev/null +++ b/mogen/datasets/human_body_prior/tools/angle_continuous_repres.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. 
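+#
+# Illustrative note for the helpers defined in this file: bgs() maps a batch of
+# 6D rotation representations (B x 3 x 2) to rotation matrices (B x 3 x 3) via
+# Gram-Schmidt, and geodesic_loss_R compares two batches of rotation matrices:
+#
+#   R_pred = bgs(torch.randn(8, 3, 2))
+#   loss = geodesic_loss_R(reduction='mean')(R_pred, torch.eye(3).expand(8, 3, 3))
+#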
+# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 +import torch.nn.functional as F +import torch +from torch import nn + +import numpy as np + +# numpy implementation of yi zhou's method +def norm(v): + return v/np.linalg.norm(v) + +def gs(M): + a1 = M[:,0] + a2 = M[:,1] + b1 = norm(a1) + b2 = norm((a2-np.dot(b1,a2)*b1)) + b3 = np.cross(b1,b2) + return np.vstack([b1,b2,b3]).T + +# input sz bszx3x2 +def bgs(d6s): + + bsz = d6s.shape[0] + b1 = F.normalize(d6s[:,:,0], p=2, dim=1) + a2 = d6s[:,:,1] + c = torch.bmm(b1.view(bsz,1,-1),a2.view(bsz,-1,1)).view(bsz,1)*b1 + b2 = F.normalize(a2-c,p=2,dim=1) + b3=torch.cross(b1,b2,dim=1) + return torch.stack([b1,b2,b3],dim=1).permute(0,2,1) + + +class geodesic_loss_R(nn.Module): + def __init__(self, reduction='batchmean'): + super(geodesic_loss_R, self).__init__() + + self.reduction = reduction + self.eps = 1e-6 + + # batch geodesic loss for rotation matrices + def bgdR(self,m1,m2): + batch = m1.shape[0] + m = torch.bmm(m1, m2.transpose(1, 2)) # batch*3*3 + + cos = (m[:, 0, 0] + m[:, 1, 1] + m[:, 2, 2] - 1) / 2 + cos = torch.min(cos, m1.new(np.ones(batch))) + cos = torch.max(cos, m1.new(np.ones(batch)) * -1) + + return torch.acos(cos) + + def forward(self, ypred, ytrue): + theta = self.bgdR(ypred,ytrue) + if self.reduction == 'mean': + return torch.mean(theta) + if self.reduction == 'batchmean': + breakpoint() + return torch.mean(torch.sum(theta, dim=theta.shape[1:])) + + else: + return theta \ No newline at end of file diff --git a/mogen/datasets/human_body_prior/tools/configurations.py b/mogen/datasets/human_body_prior/tools/configurations.py new file mode 100644 index 0000000000000000000000000000000000000000..f7447584666657baa0d569d54b72a0b292956711 --- /dev/null +++ b/mogen/datasets/human_body_prior/tools/configurations.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. 
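+#
+# Illustrative usage of the helpers defined in this file; the .yaml path is a
+# placeholder:
+#
+#   ps = load_config('model_conf.yaml', batch_size=32)   # DotMap of YAML + overrides
+#   dump_config(ps, 'model_conf_out.yaml')               # written back as YAML
+#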
+# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 +from dotmap import DotMap +import os +import yaml + +def load_config(default_ps_fname=None, **kwargs): + if isinstance(default_ps_fname, str): + assert os.path.exists(default_ps_fname), FileNotFoundError(default_ps_fname) + assert default_ps_fname.lower().endswith('.yaml'), NotImplementedError('Only .yaml files are accepted.') + default_ps = yaml.safe_load(open(default_ps_fname, 'r')) + else: + default_ps = {} + + default_ps.update(kwargs) + + return DotMap(default_ps, _dynamic=False) + +def dump_config(data, fname): + ''' + dump current configuration to an ini file + :param fname: + :return: + ''' + with open(fname, 'w') as file: + yaml.dump(data.toDict(), file) + return fname diff --git a/mogen/datasets/human_body_prior/tools/model_loader.py b/mogen/datasets/human_body_prior/tools/model_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3acad1683cd0010e6f32be3536f88a9a069771 --- /dev/null +++ b/mogen/datasets/human_body_prior/tools/model_loader.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. 
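+#
+# Illustrative usage of load_model() defined in this file; the expression
+# directory is a placeholder and must contain snapshots/*.ckpt plus a *.yaml
+# config, as expected by exprdir2model():
+#
+#   from human_body_prior.models.vposer_model import VPoser
+#   vp_model, ps = load_model('path/to/vposer_expr_dir', model_code=VPoser,
+#                             remove_words_in_model_weights='vp_model.',
+#                             disable_grad=True)
+#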
+# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: Nima Ghorbani +# 2018.01.02 + +import os, glob +import numpy as np +from human_body_prior.tools.configurations import load_config, dump_config +import os.path as osp + +def exprdir2model(expr_dir): + + if not os.path.exists(expr_dir): raise ValueError('Could not find the experiment directory: %s' % expr_dir) + + model_snapshots_dir = osp.join(expr_dir, 'snapshots') + available_ckpts = sorted(glob.glob(osp.join(model_snapshots_dir, '*.ckpt')), key=osp.getmtime) + assert len(available_ckpts) > 0, ValueError('No checck points found at {}'.format(model_snapshots_dir)) + trained_weigths_fname = available_ckpts[-1] + + model_ps_fname = glob.glob(osp.join('/', '/'.join(trained_weigths_fname.split('/')[:-2]), '*.yaml')) + if len(model_ps_fname) == 0: + model_ps_fname = glob.glob(osp.join('/'.join(trained_weigths_fname.split('/')[:-2]), '*.yaml')) + + model_ps_fname = model_ps_fname[0] + model_ps = load_config(default_ps_fname=model_ps_fname) + + model_ps.logging.best_model_fname = trained_weigths_fname + + return model_ps, trained_weigths_fname + + +def load_model(expr_dir, model_code=None, remove_words_in_model_weights=None, load_only_ps=False, disable_grad=True, custom_ps = None): + ''' + + :param expr_dir: + :param model_code: an imported module + from supercap.train.supercap_smpl import SuperCap, then pass SuperCap to this function + :param if True will load the model definition used for training, and not the one in current repository + :return: + ''' + import importlib + import torch + + model_ps, trained_weigths_fname = exprdir2model(expr_dir) + if load_only_ps: return model_ps + if custom_ps is not None: model_ps = custom_ps + assert model_code is not None, ValueError('mode_code should be provided') + model_instance = model_code(model_ps) + if disable_grad: # i had to do this. torch.no_grad() couldnt achieve what i was looking for + for param in model_instance.parameters(): + param.requires_grad = False + state_dict = torch.load(trained_weigths_fname)['state_dict'] + if remove_words_in_model_weights is not None: + words = '{}'.format(remove_words_in_model_weights) + state_dict = {k.replace(words, '') if k.startswith(words) else k: v for k, v in state_dict.items()} + + ## keys that were in the model trained file and not in the current model + instance_model_keys = list(model_instance.state_dict().keys()) + trained_model_keys = list(state_dict.keys()) + wts_in_model_not_in_file = set(instance_model_keys).difference(set(trained_model_keys)) + ## keys that are in the current model not in the training weights + wts_in_file_not_in_model = set(trained_model_keys).difference(set(instance_model_keys)) + # assert len(wts_in_model_not_in_file) == 0, ValueError('Some model weights are not present in the pretrained file. {}'.format(wts_in_model_not_in_file)) + + state_dict = {k:v for k, v in state_dict.items() if k in instance_model_keys} + model_instance.load_state_dict(state_dict, strict=False) # Todo fix the issues so that we can set the strict to true. 
The body model uses unnecessary registered buffers + model_instance.eval() + + return model_instance, model_ps + + diff --git a/mogen/datasets/human_body_prior/tools/omni_tools.py b/mogen/datasets/human_body_prior/tools/omni_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..9e38f7c758f2a99d599d32302f7d71fa037a3744 --- /dev/null +++ b/mogen/datasets/human_body_prior/tools/omni_tools.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2018.01.02 +import numpy as np +import random +import torch +import os +import sys +import os.path as osp + +def copy2cpu(tensor): + if isinstance(tensor, np.ndarray): return tensor + return tensor.detach().cpu().numpy() + +def create_list_chunks(list_, group_size, overlap_size, cut_smaller_batches=True): + if cut_smaller_batches: + return [list_[i:i + group_size] for i in range(0, len(list_), group_size - overlap_size) if len(list_[i:i + group_size])==group_size] + else: + return [list_[i:i + group_size] for i in range(0, len(list_), group_size - overlap_size)] + + +def trainable_params_count(params): + return sum([p.numel() for p in params if p.requires_grad]) + +def flatten_list(l): + return [item for sublist in l for item in sublist] + +def get_support_data_dir(current_fname=__file__): + support_data_dir = osp.abspath(current_fname) + support_data_dir_split = support_data_dir.split('/') + support_data_dir = '/'.join(support_data_dir_split[:support_data_dir_split.index('src')]) + support_data_dir = osp.join(support_data_dir, 'support_data') + assert osp.exists(support_data_dir) + return support_data_dir + +def make_deterministic(seed): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + +def id_generator(size=13): + import string + import random + chars = string.ascii_uppercase + string.digits + return ''.join(random.choice(chars) for _ in range(size)) + +def logger_sequencer(logger_list, prefix=None): + def post_text(text): + if prefix is not None: text = '{} -- '.format(prefix) + text + for logger_call in logger_list: logger_call(text) + return post_text + +class log2file(): + def __init__(self,logpath=None, prefix='', auto_newline = True, write2file_only=False): + if logpath is not None: + makepath(logpath, isfile=True) + self.fhandle = open(logpath,'a+') + else: + self.fhandle = None + + self.prefix = prefix + self.auto_newline = auto_newline + self.write2file_only = write2file_only + + def __call__(self, text): + if text is None: 
return + if self.prefix != '': text = '{} -- '.format(self.prefix) + text + # breakpoint() + if self.auto_newline: + if not text.endswith('\n'): + text = text + '\n' + if not self.write2file_only: sys.stderr.write(text) + if self.fhandle is not None: + self.fhandle.write(text) + self.fhandle.flush() + + +def makepath(*args, **kwargs): + ''' + if the path does not exist make it + :param desired_path: can be path to a file or a folder name + :return: + ''' + isfile = kwargs.get('isfile', False) + import os + desired_path = os.path.join(*args) + if isfile: + if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path)) + else: + if not os.path.exists(desired_path): os.makedirs(desired_path) + return desired_path + +def matrot2axisangle(matrots): + ''' + :param matrots: N*T*num_joints*9 + :return: N*T*num_joints*3 + ''' + import cv2 + N = matrots.shape[0] + T = matrots.shape[1] + n_joints = matrots.shape[2] + out_axisangle = [] + for tIdx in range(T): + T_axisangle = [] + for mIdx in range(N): + cur_axisangle = [] + for jIdx in range(n_joints): + cur_axisangle.append(cv2.Rodrigues(matrots[mIdx, tIdx, jIdx:jIdx + 1, :].reshape(3, 3))[0].T) + T_axisangle.append(np.vstack(cur_axisangle)[np.newaxis]) + out_axisangle.append(np.vstack(T_axisangle).reshape([N,1, -1,3])) + return np.concatenate(out_axisangle, axis=1) + +def axisangle2matrots(axisangle): + ''' + :param matrots: N*1*num_joints*3 + :return: N*num_joints*9 + ''' + import cv2 + batch_size = axisangle.shape[0] + axisangle = axisangle.reshape([batch_size,1,-1,3]) + out_matrot = [] + for mIdx in range(axisangle.shape[0]): + cur_axisangle = [] + for jIdx in range(axisangle.shape[2]): + a = cv2.Rodrigues(axisangle[mIdx, 0, jIdx:jIdx + 1, :].reshape(1, 3))[0].T + cur_axisangle.append(a) + + out_matrot.append(np.array(cur_axisangle).reshape([batch_size,1,-1,9])) + return np.vstack(out_matrot) + + +def apply_mesh_tranfsormations_(meshes, transf): + ''' + apply inplace translations to meshes + :param meshes: list of trimesh meshes + :param transf: + :return: + ''' + for i in range(len(meshes)): + meshes[i] = meshes[i].apply_transform(transf) \ No newline at end of file diff --git a/mogen/datasets/human_body_prior/tools/rotation_tools.py b/mogen/datasets/human_body_prior/tools/rotation_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..0f5fa4aab0a3cc54dcc409708b3421353bc5e0b8 --- /dev/null +++ b/mogen/datasets/human_body_prior/tools/rotation_tools.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. 
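+#
+# Illustrative round trip with the converters defined in this file:
+#
+#   aa = torch.rand(10, 3) * 0.3     # axis-angle rotations, N x 3
+#   R = aa2matrot(aa)                # N x 3 x 3 rotation matrices
+#   aa_rec = matrot2aa(R)            # back to N x 3 (up to angle wrapping)
+#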
+# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 +import numpy as np + +from torch.nn import functional as F +from human_body_prior.tools import tgm_conversion as tgm +import torch + +def local2global_pose(local_pose, kintree): + bs = local_pose.shape[0] + + local_pose = local_pose.view(bs, -1, 3, 3) + + global_pose = local_pose.clone() + + for jId in range(len(kintree)): + parent_id = kintree[jId] + if parent_id >= 0: + global_pose[:, jId] = torch.matmul(global_pose[:, parent_id], global_pose[:, jId]) + + return global_pose + +def em2euler(em): + ''' + + :param em: rotation in expo-map (3,) + :return: rotation in euler angles (3,) + ''' + from transforms3d.euler import axangle2euler + + theta = np.sqrt((em ** 2).sum()) + axis = em / theta + return np.array(axangle2euler(axis, theta)) + + +def euler2em(ea): + ''' + + :param ea: rotation in euler angles (3,) + :return: rotation in expo-map (3,) + ''' + from transforms3d.euler import euler2axangle + axis, theta = euler2axangle(*ea) + return np.array(axis*theta) + + +def remove_zrot(pose): + noZ = em2euler(pose[:3].copy()) + noZ[2] = 0 + pose[:3] = euler2em(noZ).copy() + return pose + +def matrot2aa(pose_matrot): + ''' + :param pose_matrot: Nx3x3 + :return: Nx3 + ''' + bs = pose_matrot.size(0) + homogen_matrot = F.pad(pose_matrot, [0,1]) + pose = tgm.rotation_matrix_to_angle_axis(homogen_matrot) + return pose + +def aa2matrot(pose): + ''' + :param Nx3 + :return: pose_matrot: Nx3x3 + ''' + bs = pose.size(0) + num_joints = pose.size(1)//3 + pose_body_matrot = tgm.angle_axis_to_rotation_matrix(pose)[:, :3, :3].contiguous()#.view(bs, num_joints*9) + return pose_body_matrot + +def noisy_zrot(rot_in): + ''' + + :param rot_in: np.array Nx3 rotations in axis-angle representation + :return: + will add a degree from a full circle to the zrotations + ''' + is_batched = False + if rot_in.ndim == 2: is_batched = True + if not is_batched: + rot_in = rot_in[np.newaxis] + + rnd_zrot = np.random.uniform(-np.pi, np.pi) + rot_out = [] + for bId in range(len(rot_in)): + pose_cpu = rot_in[bId] + pose_euler = em2euler(pose_cpu) + + pose_euler[2] += rnd_zrot + + pose_aa = euler2em(pose_euler) + rot_out.append(pose_aa.copy()) + + return np.array(rot_out) + +def rotate_points_xyz(mesh_v, Rxyz): + ''' + + :param mesh_v: Nxnum_vx3 + :param Rxyz: Nx3 + :return: + ''' + + mesh_v_rotated = [] + + for fId in range(mesh_v.shape[0]): + angle = np.radians(Rxyz[fId, 0]) + rx = np.array([ + [1., 0., 0. ], + [0., np.cos(angle), -np.sin(angle)], + [0., np.sin(angle), np.cos(angle) ] + ]) + + angle = np.radians(Rxyz[fId, 1]) + ry = np.array([ + [np.cos(angle), 0., np.sin(angle)], + [0., 1., 0. ], + [-np.sin(angle), 0., np.cos(angle)] + ]) + + angle = np.radians(Rxyz[fId, 2]) + rz = np.array([ + [np.cos(angle), -np.sin(angle), 0. ], + [np.sin(angle), np.cos(angle), 0. ], + [0., 0., 1. 
] + ]) + mesh_v_rotated.append(rz.dot(ry.dot(rx.dot(mesh_v[fId].T))).T) + + return np.array(mesh_v_rotated) \ No newline at end of file diff --git a/mogen/datasets/human_body_prior/tools/tgm_conversion.py b/mogen/datasets/human_body_prior/tools/tgm_conversion.py new file mode 100644 index 0000000000000000000000000000000000000000..0e51eaaa675ef623fc886f3c7f9e03bf606110cb --- /dev/null +++ b/mogen/datasets/human_body_prior/tools/tgm_conversion.py @@ -0,0 +1,527 @@ +''' +This is a ripped code from an version of torchgeometry now called Kornia. Since Kornia has a +know bug: https://github.com/kornia/kornia/issues/317#issuecomment-751305910 +in converting rotation representations we use this code until the original bug in Kornia is addressed +''' + +import torch +import torch.nn as nn + +__all__ = [ + # functional api + "pi", + "rad2deg", + "deg2rad", + "convert_points_from_homogeneous", + "convert_points_to_homogeneous", + "angle_axis_to_rotation_matrix", + "rotation_matrix_to_angle_axis", + "rotation_matrix_to_quaternion", + "quaternion_to_angle_axis", + "angle_axis_to_quaternion", + "rtvec_to_pose", + # layer api + "RadToDeg", + "DegToRad", + "ConvertPointsFromHomogeneous", + "ConvertPointsToHomogeneous", +] + + +"""Constant with number pi +""" +pi = torch.Tensor([3.14159265358979323846]) + + +def rad2deg(tensor): + r"""Function that converts angles from radians to degrees. + + See :class:`~torchgeometry.RadToDeg` for details. + + Args: + tensor (Tensor): Tensor of arbitrary shape. + + Returns: + Tensor: Tensor with same shape as input. + + Example: + >>> input = tgm.pi * torch.rand(1, 3, 3) + >>> output = tgm.rad2deg(input) + """ + if not torch.is_tensor(tensor): + raise TypeError("Input type is not a torch.Tensor. Got {}" + .format(type(tensor))) + + return 180. * tensor / pi.to(tensor.device).type(tensor.dtype) + + +def deg2rad(tensor): + r"""Function that converts angles from degrees to radians. + + See :class:`~torchgeometry.DegToRad` for details. + + Args: + tensor (Tensor): Tensor of arbitrary shape. + + Returns: + Tensor: Tensor with same shape as input. + + Examples:: + + >>> input = 360. * torch.rand(1, 3, 3) + >>> output = tgm.deg2rad(input) + """ + if not torch.is_tensor(tensor): + raise TypeError("Input type is not a torch.Tensor. Got {}" + .format(type(tensor))) + + return tensor * pi.to(tensor.device).type(tensor.dtype) / 180. + + +def convert_points_from_homogeneous(points): + r"""Function that converts points from homogeneous to Euclidean space. + + See :class:`~torchgeometry.ConvertPointsFromHomogeneous` for details. + + Examples:: + + >>> input = torch.rand(2, 4, 3) # BxNx3 + >>> output = tgm.convert_points_from_homogeneous(input) # BxNx2 + """ + if not torch.is_tensor(points): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(points))) + if len(points.shape) < 2: + raise ValueError("Input must be at least a 2D tensor. Got {}".format( + points.shape)) + + return points[..., :-1] / points[..., -1:] + + +def convert_points_to_homogeneous(points): + r"""Function that converts points from Euclidean to homogeneous space. + + See :class:`~torchgeometry.ConvertPointsToHomogeneous` for details. + + Examples:: + + >>> input = torch.rand(2, 4, 3) # BxNx3 + >>> output = tgm.convert_points_to_homogeneous(input) # BxNx4 + """ + if not torch.is_tensor(points): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(points))) + if len(points.shape) < 2: + raise ValueError("Input must be at least a 2D tensor. 
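The two homogeneous-coordinate helpers are inverses of one another whenever the last coordinate is non-zero: `convert_points_to_homogeneous` appends a constant 1 via `F.pad`, and `convert_points_from_homogeneous` divides by that coordinate and drops it. A quick round-trip check (shapes are arbitrary):

```python
import torch
import torch.nn as nn

pts = torch.rand(2, 4, 3)                                  # BxNx3 Euclidean points

homo = nn.functional.pad(pts, (0, 1), "constant", 1.0)     # BxNx4, as *_to_homogeneous returns
back = homo[..., :-1] / homo[..., -1:]                     # BxNx3, as *_from_homogeneous returns

assert torch.allclose(pts, back)
```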
Got {}".format( + points.shape)) + + return nn.functional.pad(points, (0, 1), "constant", 1.0) + + +def angle_axis_to_rotation_matrix(angle_axis): + """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix + + Args: + angle_axis (Tensor): tensor of 3d vector of axis-angle rotations. + + Returns: + Tensor: tensor of 4x4 rotation matrices. + + Shape: + - Input: :math:`(N, 3)` + - Output: :math:`(N, 4, 4)` + + Example: + >>> input = torch.rand(1, 3) # Nx3 + >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx4x4 + """ + def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6): + # We want to be careful to only evaluate the square root if the + # norm of the angle_axis vector is greater than zero. Otherwise + # we get a division by zero. + k_one = 1.0 + theta = torch.sqrt(theta2) + wxyz = angle_axis / (theta + eps) + wx, wy, wz = torch.chunk(wxyz, 3, dim=1) + cos_theta = torch.cos(theta) + sin_theta = torch.sin(theta) + + r00 = cos_theta + wx * wx * (k_one - cos_theta) + r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) + r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) + r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta + r11 = cos_theta + wy * wy * (k_one - cos_theta) + r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) + r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) + r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) + r22 = cos_theta + wz * wz * (k_one - cos_theta) + rotation_matrix = torch.cat( + [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1) + return rotation_matrix.view(-1, 3, 3) + + def _compute_rotation_matrix_taylor(angle_axis): + rx, ry, rz = torch.chunk(angle_axis, 3, dim=1) + k_one = torch.ones_like(rx) + rotation_matrix = torch.cat( + [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1) + return rotation_matrix.view(-1, 3, 3) + + # stolen from ceres/rotation.h + + _angle_axis = torch.unsqueeze(angle_axis, dim=1) + theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2)) + theta2 = torch.squeeze(theta2, dim=1) + + # compute rotation matrices + rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2) + rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis) + + # create mask to handle both cases + eps = 1e-6 + mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device) + mask_pos = (mask).type_as(theta2) + mask_neg = (mask == False).type_as(theta2) # noqa + + # create output pose matrix + batch_size = angle_axis.shape[0] + rotation_matrix = torch.eye(4).to(angle_axis.device).type_as(angle_axis) + rotation_matrix = rotation_matrix.view(1, 4, 4).repeat(batch_size, 1, 1) + # fill output matrix with masked values + rotation_matrix[..., :3, :3] = \ + mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor + return rotation_matrix # Nx4x4 + + +def rtvec_to_pose(rtvec): + """ + Convert axis-angle rotation and translation vector to 4x4 pose matrix + + Args: + rtvec (Tensor): Rodrigues vector transformations + + Returns: + Tensor: transformation matrices + + Shape: + - Input: :math:`(N, 6)` + - Output: :math:`(N, 4, 4)` + + Example: + >>> input = torch.rand(3, 6) # Nx6 + >>> output = tgm.rtvec_to_pose(input) # Nx4x4 + """ + assert rtvec.shape[-1] == 6, 'rtvec=[rx, ry, rz, tx, ty, tz]' + pose = angle_axis_to_rotation_matrix(rtvec[..., :3]) + pose[..., :3, 3] = rtvec[..., 3:] + return pose + + +def rotation_matrix_to_angle_axis(rotation_matrix): + """Convert 3x4 rotation matrix to Rodrigues vector + + Args: + rotation_matrix (Tensor): rotation matrix. 
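`angle_axis_to_rotation_matrix` returns homogeneous 4x4 matrices, and `rtvec_to_pose` simply writes the translation half of the 6-vector into the last column. A small property check, assuming this module is importable under the path added by this diff (adjust the import to your package layout; the test values are random):

```python
import torch
from mogen.datasets.human_body_prior.tools import tgm_conversion as tgm

rtvec = torch.rand(5, 6)                      # rows of [rx, ry, rz, tx, ty, tz]
pose = tgm.rtvec_to_pose(rtvec)               # Nx4x4

R = pose[:, :3, :3]
eye = torch.eye(3).expand(5, 3, 3)
assert torch.allclose(R @ R.transpose(1, 2), eye, atol=1e-5)    # orthonormal rotation block
assert torch.allclose(torch.det(R), torch.ones(5), atol=1e-5)   # determinant +1
assert torch.allclose(pose[:, :3, 3], rtvec[:, 3:])             # translation copied verbatim
```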
+ + Returns: + Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 3)` + + Example: + >>> input = torch.rand(2, 3, 4) # Nx4x4 + >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3 + """ + # todo add check that matrix is a valid rotation matrix + quaternion = rotation_matrix_to_quaternion(rotation_matrix) + return quaternion_to_angle_axis(quaternion) + + +def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6): + """Convert 3x4 rotation matrix to 4d quaternion vector + + This algorithm is based on algorithm described in + https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201 + + Args: + rotation_matrix (Tensor): the rotation matrix to convert. + + Return: + Tensor: the rotation in quaternion + + Shape: + - Input: :math:`(N, 3, 4)` + - Output: :math:`(N, 4)` + + Example: + >>> input = torch.rand(4, 3, 4) # Nx3x4 + >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4 + """ + if not torch.is_tensor(rotation_matrix): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(rotation_matrix))) + + if len(rotation_matrix.shape) > 3: + raise ValueError( + "Input size must be a three dimensional tensor. Got {}".format( + rotation_matrix.shape)) + if not rotation_matrix.shape[-2:] == (3, 4): + raise ValueError( + "Input size must be a N x 3 x 4 tensor. Got {}".format( + rotation_matrix.shape)) + + rmat_t = torch.transpose(rotation_matrix, 1, 2) + + mask_d2 = rmat_t[:, 2, 2] < eps + + mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1] + mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1] + + t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1) + t0_rep = t0.repeat(4, 1).t() + + t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2] + q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] + rmat_t[:, 1, 0], + t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1) + t1_rep = t1.repeat(4, 1).t() + + t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0], + rmat_t[:, 2, 0] + rmat_t[:, 0, 2], + rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1) + t2_rep = t2.repeat(4, 1).t() + + t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2] + q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1], + rmat_t[:, 2, 0] - rmat_t[:, 0, 2], + rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1) + t3_rep = t3.repeat(4, 1).t() + + mask_c0 = mask_d2 * mask_d0_d1 + mask_c1 = mask_d2 * torch.logical_not(mask_d0_d1) + mask_c2 = torch.logical_not(mask_d2) * mask_d0_nd1 + mask_c3 = torch.logical_not(mask_d2) * torch.logical_not(mask_d0_nd1) + mask_c0 = mask_c0.view(-1, 1).type_as(q0) + mask_c1 = mask_c1.view(-1, 1).type_as(q1) + mask_c2 = mask_c2.view(-1, 1).type_as(q2) + mask_c3 = mask_c3.view(-1, 1).type_as(q3) + + q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3 + q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 + # noqa + t2_rep * mask_c2 + t3_rep * mask_c3) # noqa + q *= 0.5 + return q + + +def quaternion_to_angle_axis(quaternion) -> torch.Tensor: + """Convert quaternion vector to angle axis of rotation. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (torch.Tensor): tensor with quaternions. + + Return: + torch.Tensor: tensor with angle axis of rotation. 
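Note that `rotation_matrix_to_quaternion` expects an Nx3x4 input, so the usual path is to slice the 4x4 output of `angle_axis_to_rotation_matrix` down to its first three rows. A round-trip sketch under the same import assumption as above; since the recovered quaternion is only defined up to sign, the check compares rotation matrices rather than quaternions:

```python
import torch
from mogen.datasets.human_body_prior.tools import tgm_conversion as tgm

aa = torch.rand(4, 3)                                        # axis-angle with |theta| < pi
rotmat = tgm.angle_axis_to_rotation_matrix(aa)               # Nx4x4
quat = tgm.rotation_matrix_to_quaternion(rotmat[:, :3, :])   # Nx4, ordered (w, x, y, z)
aa_rec = tgm.quaternion_to_angle_axis(quat)                  # Nx3

rotmat_rec = tgm.angle_axis_to_rotation_matrix(aa_rec)
assert torch.allclose(rotmat, rotmat_rec, atol=1e-4)
```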
+ + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = torch.rand(2, 4) # Nx4 + >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not torch.is_tensor(quaternion): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(quaternion))) + + if not quaternion.shape[-1] == 4: + raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}" + .format(quaternion.shape)) + # unpack input and compute conversion + q1 = quaternion[..., 1] + q2 = quaternion[..., 2] + q3 = quaternion[..., 3] + sin_squared_theta = q1 * q1 + q2 * q2 + q3 * q3 + + sin_theta = torch.sqrt(sin_squared_theta) + cos_theta = quaternion[..., 0] + two_theta = 2.0 * torch.where( + cos_theta < 0.0, + torch.atan2(-sin_theta, -cos_theta), + torch.atan2(sin_theta, cos_theta)) + + k_pos = two_theta / sin_theta + k_neg = 2.0 * torch.ones_like(sin_theta) + k = torch.where(sin_squared_theta > 0.0, k_pos, k_neg) + + angle_axis = torch.zeros_like(quaternion)[..., :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + +# based on: +# https://github.com/facebookresearch/QuaterNet/blob/master/common/quaternion.py#L138 + + +def angle_axis_to_quaternion(angle_axis) -> torch.Tensor: + """Convert an angle axis to a quaternion. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + angle_axis (torch.Tensor): tensor with angle axis. + + Return: + torch.Tensor: tensor with quaternion. + + Shape: + - Input: :math:`(*, 3)` where `*` means, any number of dimensions + - Output: :math:`(*, 4)` + + Example: + >>> angle_axis = torch.rand(2, 4) # Nx4 + >>> quaternion = tgm.angle_axis_to_quaternion(angle_axis) # Nx3 + """ + if not torch.is_tensor(angle_axis): + raise TypeError("Input type is not a torch.Tensor. Got {}".format( + type(angle_axis))) + + if not angle_axis.shape[-1] == 3: + raise ValueError("Input must be a tensor of shape Nx3 or 3. Got {}" + .format(angle_axis.shape)) + # unpack input and compute conversion + a0 = angle_axis[..., 0:1] + a1 = angle_axis[..., 1:2] + a2 = angle_axis[..., 2:3] + theta_squared = a0 * a0 + a1 * a1 + a2 * a2 + + theta = torch.sqrt(theta_squared) + half_theta = theta * 0.5 + + mask = theta_squared > 0.0 + ones = torch.ones_like(half_theta) + + k_neg = 0.5 * ones + k_pos = torch.sin(half_theta) / theta + k = torch.where(mask, k_pos, k_neg) + w = torch.where(mask, torch.cos(half_theta), ones) + + quaternion = torch.zeros_like(angle_axis) + quaternion[..., 0:1] += a0 * k + quaternion[..., 1:2] += a1 * k + quaternion[..., 2:3] += a2 * k + return torch.cat([w, quaternion], dim=-1) + +# TODO: add below funtionalities +# - pose_to_rtvec + + +# layer api + + +class RadToDeg(nn.Module): + r"""Creates an object that converts angles from radians to degrees. + + Args: + tensor (Tensor): Tensor of arbitrary shape. + + Returns: + Tensor: Tensor with same shape as input. + + Examples:: + + >>> input = tgm.pi * torch.rand(1, 3, 3) + >>> output = tgm.RadToDeg()(input) + """ + + def __init__(self): + super(RadToDeg, self).__init__() + + def forward(self, input): + return rad2deg(input) + + +class DegToRad(nn.Module): + r"""Function that converts angles from degrees to radians. + + Args: + tensor (Tensor): Tensor of arbitrary shape. + + Returns: + Tensor: Tensor with same shape as input. + + Examples:: + + >>> input = 360. 
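Both quaternion helpers special-case (near-)zero rotations to avoid dividing by the angle; at exactly zero the expected outputs are the identity quaternion and the zero vector. A quick check of that edge case, under the same import assumption:

```python
import torch
from mogen.datasets.human_body_prior.tools import tgm_conversion as tgm

zero_aa = torch.zeros(1, 3)
quat = tgm.angle_axis_to_quaternion(zero_aa)     # identity quaternion (1, 0, 0, 0)
assert torch.allclose(quat, torch.tensor([[1.0, 0.0, 0.0, 0.0]]))

aa_back = tgm.quaternion_to_angle_axis(quat)     # back to the zero rotation
assert torch.allclose(aa_back, zero_aa)
```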
* torch.rand(1, 3, 3) + >>> output = tgm.DegToRad()(input) + """ + + def __init__(self): + super(DegToRad, self).__init__() + + def forward(self, input): + return deg2rad(input) + + +class ConvertPointsFromHomogeneous(nn.Module): + r"""Creates a transformation that converts points from homogeneous to + Euclidean space. + + Args: + points (Tensor): tensor of N-dimensional points. + + Returns: + Tensor: tensor of N-1-dimensional points. + + Shape: + - Input: :math:`(B, D, N)` or :math:`(D, N)` + - Output: :math:`(B, D, N + 1)` or :math:`(D, N + 1)` + + Examples:: + + >>> input = torch.rand(2, 4, 3) # BxNx3 + >>> transform = tgm.ConvertPointsFromHomogeneous() + >>> output = transform(input) # BxNx2 + """ + + def __init__(self): + super(ConvertPointsFromHomogeneous, self).__init__() + + def forward(self, input): + return convert_points_from_homogeneous(input) + + +class ConvertPointsToHomogeneous(nn.Module): + r"""Creates a transformation to convert points from Euclidean to + homogeneous space. + + Args: + points (Tensor): tensor of N-dimensional points. + + Returns: + Tensor: tensor of N+1-dimensional points. + + Shape: + - Input: :math:`(B, D, N)` or :math:`(D, N)` + - Output: :math:`(B, D, N + 1)` or :math:`(D, N + 1)` + + Examples:: + + >>> input = torch.rand(2, 4, 3) # BxNx3 + >>> transform = tgm.ConvertPointsToHomogeneous() + >>> output = transform(input) # BxNx4 + """ + + def __init__(self): + super(ConvertPointsToHomogeneous, self).__init__() + + def forward(self, input): + return convert_points_to_homogeneous(input) \ No newline at end of file diff --git a/mogen/datasets/human_body_prior/train/README.md b/mogen/datasets/human_body_prior/train/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2caa2c84d42912533bbd3d3b4d9bf673af4b167d --- /dev/null +++ b/mogen/datasets/human_body_prior/train/README.md @@ -0,0 +1,41 @@ +# Train VPoser from Scratch +To train your own VPoser with new configuration duplicate the provided **V02_05** folder while setting a new experiment ID +and change the settings as you desire. +First you would need to download the +[AMASS](https://amass.is.tue.mpg.de/) dataset, then following the [data preparation tutorial](../data/README.md) +prepare the data for training. +Following is a code snippet for training that can be found in the [example training experiment](https://github.com/nghorbani/human_body_prior/blob/master/src/human_body_prior/train/V02_05/V02_05.py): + +```python +import glob +import os.path as osp + +from human_body_prior.tools.configurations import load_config +from human_body_prior.train.vposer_trainer import train_vposer_once + +def main(): + expr_id = 'V02_05' + + default_ps_fname = glob.glob(osp.join(osp.dirname(__file__), '*.yaml'))[0] + + vp_ps = load_config(default_ps_fname) + + vp_ps.train_parms.batch_size = 128 + + vp_ps.general.expr_id = expr_id + + total_jobs = [] + total_jobs.append(vp_ps.toDict().copy()) + + print('#training_jobs to be done: {}'.format(len(total_jobs))) + if len(total_jobs) == 0: + print('No jobs to be done') + return + + for job in total_jobs: + train_vposer_once(job) +``` +The above code uses yaml configuration files to handle experiment settings. +It loads the default settings in *.yaml* and overloads it with your new args. + +The training code, will dump a log file along with tensorboard readable events file. 
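If you need to override more than the batch size, the same pattern applies: load the YAML, mutate the nested config object, and hand the dict to `train_vposer_once`. A minimal sketch (the experiment id, epoch count, and AMASS path below are placeholders, not values shipped with the repository):

```python
from human_body_prior.tools.configurations import load_config
from human_body_prior.train.vposer_trainer import train_vposer_once

vp_ps = load_config('V02_05/V02_05.yaml')                    # path to the experiment YAML

vp_ps.general.expr_id = 'V02_99_myrun'                       # placeholder experiment id
vp_ps.train_parms.num_epochs = 50                            # placeholder override
vp_ps.data_parms.amass_dir = '/path/to/amass/smplx_neutral'  # placeholder data path

train_vposer_once(vp_ps.toDict().copy())
```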
\ No newline at end of file diff --git a/mogen/datasets/human_body_prior/train/V02_05/V02_05.py b/mogen/datasets/human_body_prior/train/V02_05/V02_05.py new file mode 100644 index 0000000000000000000000000000000000000000..aff9edcad55ddb9bcf1d7ebf08e2064d5fd4f901 --- /dev/null +++ b/mogen/datasets/human_body_prior/train/V02_05/V02_05.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 + +import glob +import os.path as osp + +from human_body_prior.tools.configurations import load_config +from human_body_prior.train.vposer_trainer import train_vposer_once + +def main(): + expr_id = 'V02_05' + + default_ps_fname = glob.glob(osp.join(osp.dirname(__file__), '*.yaml'))[0] + + vp_ps = load_config(default_ps_fname) + + vp_ps.train_parms.batch_size = 128 + + vp_ps.general.expr_id = expr_id + + total_jobs = [] + total_jobs.append(vp_ps.toDict().copy()) + + print('#training_jobs to be done: {}'.format(len(total_jobs))) + if len(total_jobs) == 0: + print('No jobs to be done') + return + + for job in total_jobs: + train_vposer_once(job) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/mogen/datasets/human_body_prior/train/V02_05/V02_05.yaml b/mogen/datasets/human_body_prior/train/V02_05/V02_05.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c1591391e61c3c8284108b7fd6258a2c6f2cf6c --- /dev/null +++ b/mogen/datasets/human_body_prior/train/V02_05/V02_05.yaml @@ -0,0 +1,84 @@ +--- +body_model: + gender: neutral + bm_fname: ../../../../support_data/dowloads/models/smplx/neutral/model.npz + +general: + verbosity: 0 + expr_id: + dataset_id: V02_03 #SMPLx neutral + rnd_seed: 100 + work_basedir: ../../../../support_data/training/training_experiments + dataset_basedir: ../../../../support_data/training/data + +logging: + expr_msg: + num_bodies_to_display: 25 + work_dir: + dataset_dir: + render_during_training: False + best_model_fname: + +train_parms: + batch_size: + num_epochs: 100 + restore_optimizer: False + gen_optimizer: + type: Adam + args: + lr: 0.001 + weight_decay: 0.00001 + lr_scheduler: + type: ReduceLROnPlateau + args: + # metrics: val_loss + verbose: true + patience: 5 + early_stopping: + monitor: val_loss + min_delta: 0.0 + patience: 10 + verbose: True + mode: min + keep_extra_loss_terms_until_epoch: 15 + loss_weights: + loss_kl_wt: 0.005 + loss_rec_wt: 4 + loss_matrot_wt: 2 + loss_jtr_wt: 2 + + +data_parms: + num_workers: 5 # Used for dataloaders + amass_dir: support_data/dowloads/amass/smplx_neutral + num_timeseq_frames: 1 + amass_splits: + vald: +# - HumanEva +# - MPI_HDM05 +# - SFU +# - MPI_mosh + - 
BMLrub_vald + train: + - CMU + - BMLrub_train +# - MPI_Limits +# - TotalCapture +# - Eyes_Japan_Dataset +# - KIT +# - BMLrub +# - EKUT +# - TCD_handMocap +# - ACCAD +# - BMLmovi + test: + - BMLrub_test +# - Transitions_mocap +# - SSM_synced +# - DFaust_67 + + +model_params: + num_neurons : 512 + latentD : 32 + diff --git a/mogen/datasets/human_body_prior/train/V02_05/__init__.py b/mogen/datasets/human_body_prior/train/V02_05/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..377b56b6106405d0b15f6d13e0b8dcc67e3f9973 --- /dev/null +++ b/mogen/datasets/human_body_prior/train/V02_05/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 diff --git a/mogen/datasets/human_body_prior/train/__init__.py b/mogen/datasets/human_body_prior/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d95fab9b48be9d43467cd1e8d77f25c5a397f9 --- /dev/null +++ b/mogen/datasets/human_body_prior/train/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2018.01.02 diff --git a/mogen/datasets/human_body_prior/train/vposer_trainer.py b/mogen/datasets/human_body_prior/train/vposer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..87589146279de46336849068ed74b4fa448ffc1c --- /dev/null +++ b/mogen/datasets/human_body_prior/train/vposer_trainer.py @@ -0,0 +1,337 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. 
+# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 + +# from pytorch_lightning import Trainer + +import glob +import os +import os.path as osp +from datetime import datetime as dt +from pytorch_lightning.plugins import DDPPlugin + +import numpy as np +import pytorch_lightning as pl +import torch +from human_body_prior.body_model.body_model import BodyModel +from human_body_prior.data.dataloader import VPoserDS +from human_body_prior.data.prepare_data import dataset_exists +from human_body_prior.data.prepare_data import prepare_vposer_datasets +from human_body_prior.models.vposer_model import VPoser +from human_body_prior.tools.angle_continuous_repres import geodesic_loss_R +from human_body_prior.tools.configurations import load_config, dump_config +from human_body_prior.tools.omni_tools import copy2cpu as c2c +from human_body_prior.tools.omni_tools import get_support_data_dir +from human_body_prior.tools.omni_tools import log2file +from human_body_prior.tools.omni_tools import make_deterministic +from human_body_prior.tools.omni_tools import makepath +from human_body_prior.tools.rotation_tools import aa2matrot +from human_body_prior.visualizations.training_visualization import vposer_trainer_renderer +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.callbacks.early_stopping import EarlyStopping + +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.core import LightningModule +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.utilities import rank_zero_only +from torch import optim as optim_module +from torch.optim import lr_scheduler as lr_sched_module +from torch.utils.data import DataLoader + + +class VPoserTrainer(LightningModule): + """ + + It includes all data loading and train / val logic., and it is used for both training and testing models. 
+ """ + + def __init__(self, _config): + super(VPoserTrainer, self).__init__() + + _support_data_dir = get_support_data_dir() + + vp_ps = load_config(**_config) + + make_deterministic(vp_ps.general.rnd_seed) + + self.expr_id = vp_ps.general.expr_id + self.dataset_id = vp_ps.general.dataset_id + + self.work_dir = vp_ps.logging.work_dir = makepath(vp_ps.general.work_basedir, self.expr_id) + self.dataset_dir = vp_ps.logging.dataset_dir = osp.join(vp_ps.general.dataset_basedir, vp_ps.general.dataset_id) + + self._log_prefix = '[{}]'.format(self.expr_id) + self.text_logger = log2file(prefix=self._log_prefix) + + self.seq_len = vp_ps.data_parms.num_timeseq_frames + + self.vp_model = VPoser(vp_ps) + + with torch.no_grad(): + + self.bm_train = BodyModel(vp_ps.body_model.bm_fname) + + if vp_ps.logging.render_during_training: + self.renderer = vposer_trainer_renderer(self.bm_train, vp_ps.logging.num_bodies_to_display) + else: + self.renderer = None + + self.example_input_array = {'pose_body':torch.ones(vp_ps.train_parms.batch_size, 63),} + self.vp_ps = vp_ps + + def forward(self, pose_body): + + return self.vp_model(pose_body) + + def _get_data(self, split_name): + + assert split_name in ('train', 'vald', 'test') + + split_name = split_name.replace('vald', 'vald') + + assert dataset_exists(self.dataset_dir), FileNotFoundError('Dataset does not exist dataset_dir = {}'.format(self.dataset_dir)) + dataset = VPoserDS(osp.join(self.dataset_dir, split_name), data_fields = ['pose_body']) + + assert len(dataset) != 0, ValueError('Dataset has nothing in it!') + + return DataLoader(dataset, + batch_size=self.vp_ps.train_parms.batch_size, + shuffle=True if split_name == 'train' else False, + num_workers=self.vp_ps.data_parms.num_workers, + pin_memory=True) + + @rank_zero_only + def on_train_start(self): + if self.global_rank != 0: return + self.train_starttime = dt.now().replace(microsecond=0) + + ######## make a backup of vposer + git_repo_dir = os.path.abspath(__file__).split('/') + git_repo_dir = '/'.join(git_repo_dir[:git_repo_dir.index('human_body_prior') + 1]) + starttime = dt.strftime(self.train_starttime, '%Y_%m_%d_%H_%M_%S') + archive_path = makepath(self.work_dir, 'code', 'vposer_{}.tar.gz'.format(starttime), isfile=True) + cmd = 'cd %s && git ls-files -z | xargs -0 tar -czf %s' % (git_repo_dir, archive_path) + os.system(cmd) + ######## + self.text_logger('Created a git archive backup at {}'.format(archive_path)) + dump_config(self.vp_ps, osp.join(self.work_dir, '{}.yaml'.format(self.expr_id))) + + def train_dataloader(self): + return self._get_data('train') + + def val_dataloader(self): + return self._get_data('vald') + + def configure_optimizers(self): + params_count = lambda params: sum(p.numel() for p in params if p.requires_grad) + + gen_params = [a[1] for a in self.vp_model.named_parameters() if a[1].requires_grad] + gen_optimizer_class = getattr(optim_module, self.vp_ps.train_parms.gen_optimizer.type) + gen_optimizer = gen_optimizer_class(gen_params, **self.vp_ps.train_parms.gen_optimizer.args) + + self.text_logger('Total Trainable Parameters Count in vp_model is %2.2f M.' 
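`configure_optimizers` resolves the optimizer class (and, just below, the LR scheduler class) by name with `getattr`, so the `type`/`args` pairs in the YAML map directly onto `torch.optim` and `torch.optim.lr_scheduler`. A stripped-down sketch of that lookup using the values from `V02_05.yaml` (a toy parameter list stands in for the VPoser weights):

```python
import torch
from torch import optim as optim_module
from torch.optim import lr_scheduler as lr_sched_module

params = [torch.nn.Parameter(torch.zeros(3))]        # stand-in for vp_model parameters

opt_cfg = {'type': 'Adam', 'args': {'lr': 0.001, 'weight_decay': 0.00001}}
sched_cfg = {'type': 'ReduceLROnPlateau', 'args': {'verbose': True, 'patience': 5}}

optimizer = getattr(optim_module, opt_cfg['type'])(params, **opt_cfg['args'])
scheduler = getattr(lr_sched_module, sched_cfg['type'])(optimizer, **sched_cfg['args'])
```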
% (params_count(gen_params) * 1e-6)) + + lr_sched_class = getattr(lr_sched_module, self.vp_ps.train_parms.lr_scheduler.type) + + gen_lr_scheduler = lr_sched_class(gen_optimizer, **self.vp_ps.train_parms.lr_scheduler.args) + + schedulers = [ + { + 'scheduler': gen_lr_scheduler, + 'monitor': 'val_loss', + 'interval': 'epoch', + 'frequency': 1 + }, + ] + return [gen_optimizer], schedulers + + def _compute_loss(self, dorig, drec): + l1_loss = torch.nn.L1Loss(reduction='mean') + geodesic_loss = geodesic_loss_R(reduction='mean') + + bs, latentD = drec['poZ_body_mean'].shape + device = drec['poZ_body_mean'].device + + loss_kl_wt = self.vp_ps.train_parms.loss_weights.loss_kl_wt + loss_rec_wt = self.vp_ps.train_parms.loss_weights.loss_rec_wt + loss_matrot_wt = self.vp_ps.train_parms.loss_weights.loss_matrot_wt + loss_jtr_wt = self.vp_ps.train_parms.loss_weights.loss_jtr_wt + + # q_z = torch.distributions.normal.Normal(drec['mean'], drec['std']) + q_z = drec['q_z'] + # dorig['fullpose'] = torch.cat([dorig['root_orient'], dorig['pose_body']], dim=-1) + + # Reconstruction loss - L1 on the output mesh + with torch.no_grad(): + bm_orig = self.bm_train(pose_body=dorig['pose_body']) + + bm_rec = self.bm_train(pose_body=drec['pose_body'].contiguous().view(bs, -1)) + + v2v = l1_loss(bm_rec.v, bm_orig.v) + + # KL loss + p_z = torch.distributions.normal.Normal( + loc=torch.zeros((bs, latentD), device=device, requires_grad=False), + scale=torch.ones((bs, latentD), device=device, requires_grad=False)) + weighted_loss_dict = { + 'loss_kl':loss_kl_wt * torch.mean(torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=[1])), + 'loss_mesh_rec': loss_rec_wt * v2v + } + + if (self.current_epoch < self.vp_ps.train_parms.keep_extra_loss_terms_until_epoch): + # breakpoint() + weighted_loss_dict['matrot'] = loss_matrot_wt * geodesic_loss(drec['pose_body_matrot'].view(-1,3,3), aa2matrot(dorig['pose_body'].view(-1, 3))) + weighted_loss_dict['jtr'] = loss_jtr_wt * l1_loss(bm_rec.Jtr, bm_orig.Jtr) + + weighted_loss_dict['loss_total'] = torch.stack(list(weighted_loss_dict.values())).sum() + + with torch.no_grad(): + unweighted_loss_dict = {'v2v': torch.sqrt(torch.pow(bm_rec.v-bm_orig.v, 2).sum(-1)).mean()} + unweighted_loss_dict['loss_total'] = torch.cat( + list({k: v.view(-1) for k, v in unweighted_loss_dict.items()}.values()), dim=-1).sum().view(1) + + return {'weighted_loss': weighted_loss_dict, 'unweighted_loss': unweighted_loss_dict} + + def training_step(self, batch, batch_idx, optimizer_idx=None): + + drec = self(batch['pose_body'].view(-1, 63)) + + loss = self._compute_loss(batch, drec) + + train_loss = loss['weighted_loss']['loss_total'] + + tensorboard_logs = {'train_loss': train_loss} + progress_bar = {k: c2c(v) for k, v in loss['weighted_loss'].items()} + return {'loss': train_loss, 'progress_bar':progress_bar, 'log': tensorboard_logs} + + def validation_step(self, batch, batch_idx): + + drec = self(batch['pose_body'].view(-1, 63)) + + loss = self._compute_loss(batch, drec) + val_loss = loss['unweighted_loss']['loss_total'] + + if self.renderer is not None and self.global_rank == 0 and batch_idx % 500==0 and np.random.rand()>0.5: + out_fname = makepath(self.work_dir, 'renders/vald_rec_E{:03d}_It{:04d}_val_loss_{:.2f}.png'.format(self.current_epoch, batch_idx, val_loss.item()), isfile=True) + self.renderer([batch, drec], out_fname = out_fname) + dgen = self.vp_model.sample_poses(self.vp_ps.logging.num_bodies_to_display) + out_fname = makepath(self.work_dir, 
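The KL term in `_compute_loss` above is the divergence between the encoder posterior `q_z` and a standard normal prior, summed over the latent dimensions and averaged over the batch. A self-contained sketch of just that term (the batch size and latent width here are arbitrary):

```python
import torch

bs, latentD = 8, 32
mean = torch.randn(bs, latentD)
std = torch.rand(bs, latentD) + 0.1          # keep the scale strictly positive

q_z = torch.distributions.normal.Normal(mean, std)
p_z = torch.distributions.normal.Normal(torch.zeros(bs, latentD),
                                         torch.ones(bs, latentD))

# sum KL over latent dims, average over the batch, then scale by loss_kl_wt
loss_kl = torch.mean(torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=1))
```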
'renders/vald_gen_E{:03d}_I{:04d}.png'.format(self.current_epoch, batch_idx), isfile=True) + self.renderer([dgen], out_fname = out_fname) + + + progress_bar = {'v2v': val_loss} + return {'val_loss': c2c(val_loss), 'progress_bar': progress_bar, 'log': progress_bar} + + def validation_epoch_end(self, outputs): + metrics = {'val_loss': np.nanmean(np.concatenate([v['val_loss'] for v in outputs])) } + + if self.global_rank == 0: + + self.text_logger('Epoch {}: {}'.format(self.current_epoch, ', '.join('{}:{:.2f}'.format(k, v) for k, v in metrics.items()))) + self.text_logger('lr is {}'.format([pg['lr'] for opt in self.trainer.optimizers for pg in opt.param_groups])) + + metrics = {k: torch.as_tensor(v) for k, v in metrics.items()} + + return {'val_loss': metrics['val_loss'], 'log': metrics} + + + @rank_zero_only + def on_train_end(self): + + self.train_endtime = dt.now().replace(microsecond=0) + endtime = dt.strftime(self.train_endtime, '%Y_%m_%d_%H_%M_%S') + elapsedtime = self.train_endtime - self.train_starttime + self.vp_ps.logging.best_model_fname = self.trainer.checkpoint_callback.best_model_path + + self.text_logger('Epoch {} - Finished training at {} after {}'.format(self.current_epoch, endtime, elapsedtime)) + self.text_logger('best_model_fname: {}'.format(self.vp_ps.logging.best_model_fname)) + + dump_config(self.vp_ps, osp.join(self.work_dir, '{}_{}.yaml'.format(self.expr_id, self.dataset_id))) + self.hparams = self.vp_ps.toDict() + + @rank_zero_only + def prepare_data(self): + '''' Similar to standard AMASS dataset preparation pipeline: + Donwload npz file, corresponding to body data from https://amass.is.tue.mpg.de/ and place them under amass_dir + ''' + self.text_logger = log2file(makepath(self.work_dir, '{}.log'.format(self.expr_id), isfile=True), prefix=self._log_prefix) + + prepare_vposer_datasets(self.dataset_dir, self.vp_ps.data_parms.amass_splits, self.vp_ps.data_parms.amass_dir, logger=self.text_logger) + + +def create_expr_message(ps): + expr_msg = '[{}] batch_size = {}.'.format(ps.general.expr_id, ps.train_parms.batch_size) + + return expr_msg + + +def train_vposer_once(_config): + + resume_training_if_possible = True + + model = VPoserTrainer(_config) + model.vp_ps.logging.expr_msg = create_expr_message(model.vp_ps) + # model.text_logger(model.vp_ps.logging.expr_msg.replace(". 
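In `train_vposer_once` below, the callbacks come straight from the config: the `train_parms.early_stopping` block of the YAML is splatted into pytorch-lightning's `EarlyStopping`, and the newest `.ckpt` under `snapshots/` is reused for resuming. A short sketch of those two pieces, assuming a pytorch-lightning 1.x setup like the one this trainer targets (the snapshots path is a placeholder):

```python
import glob
import os
import os.path as osp

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# train_parms.early_stopping from V02_05.yaml, expanded as keyword arguments.
early_stopping_cfg = dict(monitor='val_loss', min_delta=0.0, patience=10,
                          verbose=True, mode='min')
early_stop_callback = EarlyStopping(**early_stopping_cfg)

# Resume from the most recently written checkpoint, if one exists.
snapshots_dir = 'work_dir/snapshots'                       # placeholder path
ckpts = sorted(glob.glob(osp.join(snapshots_dir, '*.ckpt')), key=os.path.getmtime)
resume_from_checkpoint = ckpts[-1] if ckpts else None
```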
", '.\n')) + dump_config(model.vp_ps, osp.join(model.work_dir, '{}.yaml'.format(model.expr_id))) + + logger = TensorBoardLogger(model.work_dir, name='tensorboard') + lr_monitor = LearningRateMonitor() + + snapshots_dir = osp.join(model.work_dir, 'snapshots') + checkpoint_callback = ModelCheckpoint( + dirpath=makepath(snapshots_dir, isfile=True), + filename="%s_{epoch:02d}_{val_loss:.2f}" % model.expr_id, + save_top_k=1, + verbose=True, + monitor='val_loss', + mode='min', + ) + early_stop_callback = EarlyStopping(**model.vp_ps.train_parms.early_stopping) + + resume_from_checkpoint = None + if resume_training_if_possible: + available_ckpts = sorted(glob.glob(osp.join(snapshots_dir, '*.ckpt')), key=os.path.getmtime) + if len(available_ckpts)>0: + resume_from_checkpoint = available_ckpts[-1] + model.text_logger('Resuming the training from {}'.format(resume_from_checkpoint)) + + trainer = pl.Trainer(gpus=1, + weights_summary='top', + distributed_backend = 'ddp', + # replace_sampler_ddp=False, + # accumulate_grad_batches=4, + # profiler=False, + # overfit_batches=0.05, + # fast_dev_run = True, + # limit_train_batches=0.02, + # limit_val_batches=0.02, + # num_sanity_val_steps=2, + plugins=[DDPPlugin(find_unused_parameters=False)], + + callbacks=[lr_monitor, early_stop_callback, checkpoint_callback], + + max_epochs=model.vp_ps.train_parms.num_epochs, + logger=logger, + resume_from_checkpoint=resume_from_checkpoint + ) + + trainer.fit(model) diff --git a/mogen/datasets/human_body_prior/visualizations/__init__.py b/mogen/datasets/human_body_prior/visualizations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..377b56b6106405d0b15f6d13e0b8dcc67e3f9973 --- /dev/null +++ b/mogen/datasets/human_body_prior/visualizations/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 diff --git a/mogen/datasets/human_body_prior/visualizations/training_visualization.py b/mogen/datasets/human_body_prior/visualizations/training_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..f7c9d745faeda2614b06ecabdd03239a3ad617c6 --- /dev/null +++ b/mogen/datasets/human_body_prior/visualizations/training_visualization.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG), +# acting on behalf of its Max Planck Institute for Intelligent Systems and the +# Max Planck Institute for Biological Cybernetics. All rights reserved. +# +# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. 
(MPG) is holder of all proprietary rights +# on this computer program. You can only use this computer program if you have closed a license agreement +# with MPG or you get the right to use the computer program from someone who is authorized to grant you that right. +# Any use of the computer program without a valid license is prohibited and liable to prosecution. +# Contact: ps-license@tuebingen.mpg.de +# +# +# If you use this code in a research publication please consider citing the following: +# +# Expressive Body Capture: 3D Hands, Face, and Body from a Single Image +# +# +# Code Developed by: +# Nima Ghorbani +# +# 2020.12.12 + +def pyrenderer(imw=2048, imh=2048): + + from body_visualizer.mesh.mesh_viewer import MeshViewer + import cv2 + + import numpy as np + import trimesh + + try: + mv = MeshViewer(width=imw, height=imh, use_offscreen=True) + except: + import os + os.environ['PYOPENGL_PLATFORM'] = 'egl' + os.environ['EGL_DEVICE_ID'] = os.environ['GPU_DEVICE_ORDINAL'].split(',')[0] + + mv = MeshViewer(width=imw, height=imh, use_offscreen=True) + + mv.set_cam_trans([0, -0.5, 2.]) + + def render_an_image(meshes): + n_all = len(meshes) + nc = int(np.sqrt(n_all)) + + out_image = np.zeros([1, 1, 1, mv.width, mv.height, 4]) + + scale_percent = 100./nc + width = int(mv.width * scale_percent / 100) + height = int(mv.height * scale_percent / 100) + dim = (width, height) + + for rId in range(nc): + for cId in range(nc): + i = (nc*rId) + cId + if i>len(meshes): break + + mesh = meshes[i] + + # mesh.apply_transform(trimesh.transformations.rotation_matrix(np.radians(-90), (1, 0, 0))) + mesh.vertices -= np.median(np.array(mesh.vertices), axis=0) + mv.set_dynamic_meshes([mesh]) + img = mv.render(render_wireframe=False, RGBA=True) + img_resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA) + + out_image[0, 0, 0, (rId*width):((rId+1)*width), (cId*height):((cId+1)*height)] = cv2.cvtColor(img_resized, cv2.COLOR_BGRA2RGBA) + + return out_image.astype(np.uint8) + + return render_an_image + +def vposer_trainer_renderer(bm, num_bodies_to_display=5): + import numpy as np + import trimesh + import torch + + from body_visualizer.tools.vis_tools import imagearray2file, colors + from human_body_prior.tools.omni_tools import copy2cpu as c2c + from human_body_prior.tools.omni_tools import makepath + from trimesh import Trimesh as Mesh + from trimesh.util import concatenate as mesh_cat + + renderer = pyrenderer(1024, 1024) + + faces = c2c(bm.f) + + def render_once(body_parms, body_colors=[colors['grey'], colors['brown-light']], out_fname=None): + ''' + + :param body_parms: list of dictionaries of body parameters. 
+ :param body_colors: list of np arrays of color rgb values + :param movie_outpath: a mp4 path + :return: + ''' + + if out_fname is not None: makepath(out_fname, isfile=True) + assert len(body_parms) <= len(body_colors), ValueError('Not enough colors provided for #{} body_parms'.format(len(body_parms))) + + bs = body_parms[0]['pose_body'].shape[0] + + body_ids = np.random.choice(bs, num_bodies_to_display) + + body_evals = [c2c(bm(root_orient=v['root_orient'].view(bs, -1) if 'root_orient' in v else torch.zeros(bs, 3).type_as(v['pose_body']), + pose_body=v['pose_body'].contiguous().view(bs, -1)).v) for v in body_parms] + num_verts = body_evals[0].shape[1] + + render_meshes = [] + for bId in body_ids: + concat_cur_meshes = None + for body, body_color in zip(body_evals, body_colors): + cur_body_mesh = Mesh(body[bId], faces, vertex_colors=np.ones([num_verts, 3]) * body_color) + concat_cur_meshes = cur_body_mesh if concat_cur_meshes is None else mesh_cat(concat_cur_meshes, cur_body_mesh) + render_meshes.append(concat_cur_meshes) + + img = renderer(render_meshes) + + if out_fname is not None: imagearray2file(img, out_fname, fps=10) + + + return + + return render_once diff --git a/mogen/datasets/motionverse_dataset.py b/mogen/datasets/motionverse_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e8747f526fdf678f065e5eb107949550d489d69b --- /dev/null +++ b/mogen/datasets/motionverse_dataset.py @@ -0,0 +1,828 @@ +import copy +import os +import pickle as pkl +from typing import Optional, Union, List + +import numpy as np +import torch +import torch.nn as nn +import json +from torch.utils.data import ConcatDataset, Dataset, WeightedRandomSampler +from .builder import DATASETS +from .pipelines import Compose, RetargetSkeleton +import random +import pytorch3d.transforms as geometry +from scipy.ndimage import gaussian_filter +# from mogen.core.evaluation import build_evaluator +# from mogen.core.evaluation.utils import compute_similarity_transform, transform_pose_sequence +from mogen.models.builder import build_submodule +from .utils import copy_repr_data, extract_repr_data, move_repr_data, recover_from_ric + +class SingleMotionVerseDataset(Dataset): + """ + A dataset class for handling single MotionVerse datasets. + + Args: + dataset_name (str): Name of the dataset and task to load. + data_prefix (str): Path to the directory containing the dataset. + ann_file (str): Path to the annotation file. + pipeline (list): A list of transformations to apply on the data. + mode (str): the mode of current work. Choices: ['pretrain', 'train', 'test']. + eval_cfg (dict): Configuration for evaluation metrics. 
+ """ + + def __init__(self, + dataset_path: Optional[str] = None, + task_name: Optional[str] = None, + data_prefix: Optional[str] = None, + ann_file: Optional[str] = None, + pipeline: Optional[List[dict]] = None, + + # for text2motion and speech2gesture + tgt_min_motion_length: int = 20, + tgt_max_motion_length: int = 200, + + # for video2motion + v2m_window_size: int = 20, + + # for motion prediction + mp_input_length: int = 50, + mp_output_length: int = 25, + mp_stride_step: int = 5, + + # for general test + test_rotation_type: str = 'h3d_rot', + target_framerate: float = 20, + eval_cfg: Optional[dict] = None, + test_mode: Optional[bool] = False): + data_prefix = os.path.join(data_prefix, 'datasets', dataset_path) + self.dataset_path = dataset_path + assert task_name in ['mocap', 't2m', 'v2m', 's2g', 'm2d'] + self.task_name = task_name + self.dataset_name = dataset_path + '_' + task_name + + # define subdirectories + self.meta_dir = os.path.join(data_prefix, 'metas') + self.motion_dir = os.path.join(data_prefix, 'motions') + self.eval_motion_dir = os.path.join(data_prefix, 'eval_motions') + self.text_dir = os.path.join(data_prefix, 'texts') + self.text_feat_dir = os.path.join(data_prefix, 'text_feats') + self.speech_dir = os.path.join(data_prefix, 'speeches') + self.speech_feat_dir = os.path.join(data_prefix, 'speech_feats') + self.music_dir = os.path.join(data_prefix, 'musics') + self.music_feat_dir = os.path.join(data_prefix, 'music_feats') + self.video_feat_dir = os.path.join(data_prefix, 'video_feats') + self.anno_file = os.path.join(data_prefix, 'splits', ann_file) + + self.pipeline = Compose(pipeline) + + self.tgt_min_motion_length = tgt_min_motion_length + self.tgt_max_motion_length = tgt_max_motion_length + + self.v2m_window_size = v2m_window_size + + self.mp_input_length = mp_input_length + self.mp_output_length = mp_output_length + self.mp_stride_step = mp_stride_step + + self.target_framerate = target_framerate + self.test_rotation_type = test_rotation_type + self.test_mode = test_mode + self.load_annotations() + self.eval_cfg = copy.deepcopy(eval_cfg) + if self.test_mode: + self.prepare_evaluation() + + def __len__(self) -> int: + """Return the length of the current dataset.""" + if self.test_mode: + return len(self.eval_indexes) + return len(self.name_list) + + def __getitem__(self, idx: int) -> dict: + """Prepare data for the given index.""" + if self.test_mode: + idx = self.eval_indexes[idx] + return self.prepare_data(idx) + + def load_annotations(self): + if self.task_name == 'mocap': + self.load_annotations_mocap() + elif self.task_name == 't2m': + self.load_annotations_t2m() + elif self.task_name == 'v2m': + self.load_annotations_v2m() + elif self.task_name == 's2g': + self.load_annotations_s2g() + elif self.task_name == 'm2d': + self.load_annotations_m2d() + else: + raise NotImplementedError() + + def load_annotations_mocap(self): + if self.test_mode: + self.name_list = [] + self.src_start_frame = [] + self.src_end_frame = [] + self.tgt_start_frame = [] + self.tgt_end_frame = [] + tgt_motion_length = self.mp_input_length + self.mp_output_length + for name in open(self.anno_file): + name = name.strip() + meta_path = os.path.join(self.meta_dir, name + ".json") + meta_data = json.load(open(meta_path)) + num_frames = meta_data['num_frames'] + downrate = int(meta_data['framerate'] / self.target_framerate + 0.1) + if num_frames < (self.mp_input_length + self.mp_output_length) * downrate: + continue + lim = num_frames // downrate - tgt_motion_length + for start_frame in 
range(0, lim, self.mp_stride_step): + self.name_list.append(name) + self.src_start_frame.append((start_frame + 1) * downrate) + self.src_end_frame.append((start_frame + tgt_motion_length + 1) * downrate) + self.tgt_start_frame.append(start_frame + self.mp_input_length) + self.tgt_end_frame.append(start_frame + tgt_motion_length) + else: + self.name_list = [] + for name in open(self.anno_file): + name = name.strip() + self.name_list.append(name) + + def load_annotations_t2m(self): + self.name_list = [] + self.text_idx = [] + for name in open(self.anno_file): + name = name.strip() + meta_path = os.path.join(self.meta_dir, name + ".json") + meta_data = json.load(open(meta_path)) + downrate = int(meta_data['framerate'] / self.target_framerate + 0.1) + text_path = os.path.join(self.text_dir, name + ".json") + text_data = json.load(open(text_path)) + for i, anno in enumerate(text_data): + start_frame = anno['start_frame'] // downrate + end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate + num_frame = end_frame - start_frame + if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length: + continue + if len(anno['body_text']) > 0: + self.name_list.append(name) + self.text_idx.append(i) + + def load_annotations_v2m(self): + if not self.test_mode: + self.name_list = [] + for name in open(self.anno_file): + name = name.strip() + self.name_list.append(name) + else: + self.name_list = [] + self.start_frame = [] + self.end_frame = [] + self.valid_start_frame = [] + self.valid_end_frame = [] + for name in open(self.anno_file): + name = name.strip() + meta_path = os.path.join(self.meta_dir, name + ".json") + meta_data = json.load(open(meta_path)) + num_frames = meta_data['num_frames'] + assert num_frames >= self.v2m_window_size + cur_idx = 0 + while cur_idx < num_frames: + if cur_idx + self.v2m_window_size < num_frames: + self.name_list.append(name) + self.start_frame.append(cur_idx) + self.end_frame.append(cur_idx + self.v2m_window_size) + self.valid_start_frame.append(cur_idx) + self.valid_end_frame.append(cur_idx + self.v2m_window_size) + cur_idx += self.v2m_window_size + else: + self.name_list.append(name) + self.start_frame.append(num_frames - self.v2m_window_size) + self.end_frame.append(num_frames) + self.valid_start_frame.append(cur_idx) + self.valid_end_frame.append(num_frames) + break + + def load_annotations_s2g(self): + self.name_list = [] + self.speech_idx = [] + for name in open(self.anno_file): + name = name.strip() + meta_path = os.path.join(self.meta_dir, name + ".json") + meta_data = json.load(open(meta_path)) + downrate = int(meta_data['framerate'] / self.target_framerate + 0.1) + speech_path = os.path.join(self.speech_dir, name + ".json") + speech_data = json.load(open(speech_path)) + for i, anno in enumerate(speech_data): + start_frame = anno['start_frame'] // downrate + end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate + num_frame = end_frame - start_frame + if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length: + continue + self.name_list.append(name) + self.speech_idx.append(i) + + def load_annotations_m2d(self): + self.name_list = [] + self.music_idx = [] + for name in open(self.anno_file): + name = name.strip() + meta_path = os.path.join(self.meta_dir, name + ".json") + meta_data = json.load(open(meta_path)) + downrate = int(meta_data['framerate'] / self.target_framerate + 0.1) + music_path = os.path.join(self.music_dir, name + ".json") + music_data = json.load(open(music_path)) + for 
i, anno in enumerate(music_data): + start_frame = anno['start_frame'] // downrate + end_frame = min(anno['end_frame'], meta_data['num_frames']) // downrate + num_frame = end_frame - start_frame + if num_frame < self.tgt_min_motion_length or num_frame > self.tgt_max_motion_length: + continue + self.name_list.append(name) + self.music_idx.append(i) + + def prepare_data_base(self, idx: int) -> dict: + results = {} + name = self.name_list[idx] + results['motion_path'] = os.path.join(self.motion_dir, name + ".npz") + meta_path = os.path.join(self.meta_dir, name + ".json") + meta_data = json.load(open(meta_path)) + meta_data['dataset_name'] = self.dataset_name + results['meta_data'] = meta_data + results['meta_data']['sample_idx'] = idx + results.update({ + 'text_word_feat': np.zeros((77, 1024)).astype(np.float32), + 'text_seq_feat': np.zeros((1024)).astype(np.float32), + 'text_cond': 0, + 'music_word_feat': np.zeros((229, 768)).astype(np.float32), + 'music_seq_feat': np.zeros((1024)).astype(np.float32), + 'music_cond': 0, + 'speech_word_feat': np.zeros((229, 768)).astype(np.float32), + 'speech_seq_feat': np.zeros((1024)).astype(np.float32), + 'speech_cond': 0, + 'video_seq_feat': np.zeros((1024)).astype(np.float32), + 'video_cond': 0, + }) + return results + + def prepare_data(self, idx: int) -> dict: + if self.task_name == 'mocap': + results = self.prepare_data_mocap(idx) + elif self.task_name == 't2m': + results = self.prepare_data_t2m(idx) + elif self.task_name == 'v2m': + results = self.prepare_data_v2m(idx) + elif self.task_name == 's2g': + results = self.prepare_data_s2g(idx) + elif self.task_name == 'm2d': + results = self.prepare_data_m2d(idx) + else: + raise NotImplementedError() + results = self.pipeline(results) + return results + + def prepare_data_mocap(self, idx: int) -> dict: + results = self.prepare_data_base(idx) + if self.test_mode: + results['meta_data']['start_frame'] = self.src_start_frame[idx] + results['meta_data']['end_frame'] = self.src_end_frame[idx] + results['context_mask'] = np.concatenate( + (np.ones((self.mp_input_length - 1)), np.zeros((self.mp_output_length))), + axis=-1 + ) + return results + + def prepare_data_t2m(self, idx: int) -> dict: + results = self.prepare_data_base(idx) + name = self.name_list[idx] + text_idx = self.text_idx[idx] + text_path = os.path.join(self.text_dir, name + ".json") + text_data = json.load(open(text_path))[text_idx] + text_feat_path = os.path.join(self.text_feat_dir, name + ".pkl") + text_feat_data = pkl.load(open(text_feat_path, "rb")) + text_list = text_data['body_text'] + tid = np.random.randint(len(text_list)) + text = text_list[tid] + text_word_feat = text_feat_data['text_word_feats'][text_idx][tid] + text_seq_feat = text_feat_data['text_seq_feats'][text_idx][tid] + assert text_word_feat.shape[0] == 77 + assert text_word_feat.shape[1] == 1024 + assert text_seq_feat.shape[0] == 1024 + + if self.test_mode: + motion_path = os.path.join(self.eval_motion_dir, name + ".npy") + motion_data = np.load(motion_path) + assert not np.isnan(motion_data).any() + downrate = int(results['meta_data']['framerate'] / self.target_framerate + 0.1) + start_frame = text_data['start_frame'] // downrate + end_frame = text_data['end_frame'] // downrate + motion_data = motion_data[start_frame: end_frame] + results['meta_data']['framerate'] = self.target_framerate + results['meta_data']['rotation_type'] = self.test_rotation_type + assert motion_data.shape[0] > 0 + if 'body_tokens' in text_data: + token = text_data['body_tokens'][tid] + else: + token = "" 
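Every loader above maps source frames to the 20 fps target with the same rounded ratio, `downrate = int(framerate / target_framerate + 0.1)`, and then rescales annotation segments by `start_frame // downrate` and `end_frame // downrate`. A tiny worked example of that arithmetic (the frame numbers are made up):

```python
framerate, target_framerate = 60.0, 20.0
downrate = int(framerate / target_framerate + 0.1)    # 3; the +0.1 guards against 2.999...

start_frame, end_frame, num_frames = 150, 640, 600    # made-up annotation values
start = start_frame // downrate                       # 50
end = min(end_frame, num_frames) // downrate          # 200
num_frame = end - start                               # 150 frames at 20 fps
assert (start, end, num_frame) == (50, 200, 150)
```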
+ text_cond = 1 + results.update({ + 'motion': motion_data, + 'text_word_feat': text_word_feat, + 'text_seq_feat': text_seq_feat, + 'text_cond': text_cond, + 'text': text, + 'token': token + }) + else: + results['meta_data']['start_frame'] = text_data['start_frame'] + results['meta_data']['end_frame'] = text_data['end_frame'] + text_cond = 1 + results.update({ + 'text_word_feat': text_word_feat, + 'text_seq_feat': text_seq_feat, + 'text_cond': text_cond + }) + return results + + def prepare_data_v2m(self, idx: int) -> dict: + results = self.prepare_data_base(idx) + name = self.name_list[idx] + video_feat_path = os.path.join(self.video_feat_dir, name + ".pkl") + video_feat_data = pkl.load(open(video_feat_path, "rb")) + video_word_feat = video_feat_data['video_word_feats'] + video_seq_feat = video_feat_data['video_seq_feats'] + assert video_word_feat.shape[0] == results['meta_data']['num_frames'] + assert video_word_feat.shape[1] == 1024 + assert video_seq_feat.shape[0] == 1024 + video_cond = 1 + if self.test_mode: + results['meta_data']['start_frame'] = self.start_frame[idx] + results['meta_data']['end_frame'] = self.end_frame[idx] + motion_path = os.path.join(self.eval_motion_dir, name + ".npy") + motion_data = np.load(motion_path) + assert not np.isnan(motion_data).any() + + start_frame = self.start_frame[idx] + end_frame = self.end_frame[idx] + motion_data = motion_data[start_frame: end_frame] + video_word_feat = video_word_feat[start_frame: end_frame] + results['meta_data']['framerate'] = self.target_framerate + results['meta_data']['rotation_type'] = self.test_rotation_type + assert motion_data.shape[0] > 0 + results.update({ + 'motion': motion_data, + 'video_word_feat': video_word_feat, + 'video_seq_feat': video_seq_feat, + 'video_cond': video_cond + }) + else: + results.update({ + 'video_word_feat': video_word_feat, + 'video_seq_feat': video_seq_feat, + 'video_cond': video_cond + }) + return results + + def prepare_data_s2g(self, idx: int) -> dict: + results = self.prepare_data_base(idx) + name = self.name_list[idx] + speech_idx = self.speech_idx[idx] + speech_path = os.path.join(self.speech_dir, name + ".json") + speech_data = json.load(open(speech_path))[speech_idx] + speech_feat_path = os.path.join(self.speech_feat_dir, name + ".pkl") + speech_feat_data = pkl.load(open(speech_feat_path, "rb")) + try: + speech_word_feat = speech_feat_data['audio_word_feats'][speech_idx] + speech_seq_feat = speech_feat_data['audio_seq_feats'][speech_idx] + except: + speech_word_feat = speech_feat_data['speech_word_feats'][speech_idx] + speech_seq_feat = speech_feat_data['speech_seq_feats'][speech_idx] + del speech_feat_data + assert speech_word_feat.shape[0] == 229 + assert speech_word_feat.shape[1] == 768 + assert speech_seq_feat.shape[0] == 1024 + + results['meta_data']['start_frame'] = speech_data['start_frame'] + results['meta_data']['end_frame'] = speech_data['end_frame'] + speech_cond = 1 + results.update({ + 'speech_word_feat': speech_word_feat, + 'speech_seq_feat': speech_seq_feat, + 'speech_cond': speech_cond + }) + if self.test_mode: + results['meta_data']['framerate'] = self.target_framerate + results['meta_data']['rotation_type'] = self.test_rotation_type + eval_data_path = os.path.join(self.eval_motion_dir, name + ".npz") + eval_data = np.load(eval_data_path) + motion_data = eval_data["bvh_rot_beat141"] + sem_data = eval_data["sem"] + wav_data = eval_data["wave16k"] + assert not np.isnan(motion_data).any() + + start_frame = results['meta_data']['start_frame'] + end_frame = 
results['meta_data']['end_frame'] + wav_start_frame = start_frame / results['meta_data']['framerate'] * 16000 + wav_end_frame = end_frame / results['meta_data']['framerate'] * 16000 + motion_data = motion_data[start_frame: end_frame] + sem_data = sem_data[start_frame: end_frame] + wav_data = wav_data[wav_start_frame: wav_end_frame] + assert motion_data.shape[0] > 0 + results.update({ + 'motion': motion_data, + 'sem_score': sem_data, + 'wav_feat': wav_data + }) + return results + + def prepare_data_m2d(self, idx: int) -> dict: + results = self.prepare_data_base(idx) + name = self.name_list[idx] + music_idx = self.music_idx[idx] + music_path = os.path.join(self.music_dir, name + ".json") + music_data = json.load(open(music_path))[music_idx] + music_feat_path = os.path.join(self.music_feat_dir, name + ".pkl") + music_feat_data = pkl.load(open(music_feat_path, "rb")) + music_word_feat = music_feat_data['audio_word_feats'][music_idx] + music_seq_feat = music_feat_data['audio_seq_feats'][music_idx] + assert music_word_feat.shape[0] == 229 + assert music_word_feat.shape[1] == 768 + assert music_seq_feat.shape[0] == 1024 + + results['meta_data']['start_frame'] = music_data['start_frame'] + results['meta_data']['end_frame'] = music_data['end_frame'] + music_cond = 1 + results.update({ + 'music_word_feat': music_word_feat, + 'music_seq_feat': music_seq_feat, + 'music_cond': music_cond + }) + return results + + def prepare_evaluation(self): + """ + Prepare the dataset for evaluation by initializing evaluators and creating evaluation indexes. + """ + self.evaluators = [] + self.eval_indexes = [] + self.evaluator_model = build_submodule(self.eval_cfg.get('evaluator_model', None)) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if self.evaluator_model is not None: + self.evaluator_model = self.evaluator_model.to(device) + self.evaluator_model.eval() + self.eval_cfg['evaluator_model'] = self.evaluator_model + + for _ in range(self.eval_cfg['replication_times']): + eval_indexes = np.arange(len(self.name_list)) + if self.eval_cfg.get('shuffle_indexes', False): + np.random.shuffle(eval_indexes) + self.eval_indexes.append(eval_indexes) + + for metric in self.eval_cfg['metrics']: + evaluator, self.eval_indexes = build_evaluator( + metric, self.eval_cfg, len(self.name_list), self.eval_indexes) + self.evaluators.append(evaluator) + + self.eval_indexes = np.concatenate(self.eval_indexes) + + def process_outputs(self, results): + return results + + def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict: + """ + Evaluate the model performance based on the results. + + Args: + results (list): A list of result dictionaries. + work_dir (str): Directory where evaluation logs will be stored. + logger: Logger object to record evaluation results (optional). + + Returns: + dict: Dictionary containing evaluation metrics. 
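+
+        Example (illustrative; assumes ``dataset`` and ``results`` already exist):
+            >>> metrics = dataset.evaluate(results, work_dir='work_dirs/demo')
+            >>> print(metrics)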
+ """ + metrics = {} + results = self.process_outputs(results) + for evaluator in self.evaluators: + metrics.update(evaluator.evaluate(results)) + if logger is not None: + logger.info(metrics) + eval_output = os.path.join(work_dir, 'eval_results.log') + with open(eval_output, 'w') as f: + for k, v in metrics.items(): + f.write(k + ': ' + str(v) + '\n') + return metrics + + +def create_single_dataset(cfg: dict): + dataset_path = cfg['dataset_path'] + if dataset_path == 'amass': + return MotionVerseAMASS(**cfg) + elif dataset_path == 'humanml3d': + return MotionVerseH3D(**cfg) + elif dataset_path == 'kitml': + return MotionVerseKIT(**cfg) + elif dataset_path == 'babel': + return MotionVerseBABEL(**cfg) + elif dataset_path == 'motionx': + return MotionVerseMotionX(**cfg) + elif dataset_path == 'humanact12': + return MotionVerseACT12(**cfg) + elif dataset_path == 'uestc': + return MotionVerseUESTC(**cfg) + elif dataset_path == 'ntu': + return MotionVerseNTU(**cfg) + elif dataset_path == 'h36m': + return MotionVerseH36M(**cfg) + elif dataset_path == 'mpi': + return MotionVerseMPI(**cfg) + elif dataset_path == 'pw3d': + return MotionVersePW3D(**cfg) + elif dataset_path == 'aist': + return MotionVerseAIST(**cfg) + elif dataset_path == 'beat': + return MotionVerseBEAT(**cfg) + elif dataset_path == 'tedg': + return MotionVerseTEDG(**cfg) + elif dataset_path == 'tedex': + return MotionVerseTEDEx(**cfg) + elif dataset_path == 's2g3d': + return MotionVerseS2G3D(**cfg) + else: + raise NotImplementedError() + + +@DATASETS.register_module() +class MotionVerse(Dataset): + """ + A dataset class that handles multiple MotionBench datasets. + + Args: + dataset_cfgs (list[str]): List of dataset configurations. + partitions (list[float]): List of partition weights corresponding to the datasets. + num_data (Optional[int]): Number of data samples to load. Defaults to None. + data_prefix (str): Path to the directory containing the dataset. 
+ """ + + def __init__(self, + dataset_cfgs: List[dict], + partitions: List[int], + num_data: Optional[int] = None, + data_prefix: Optional[str] = None): + """Load data from multiple datasets.""" + assert min(partitions) >= 0 + assert len(dataset_cfgs) == len(partitions) + datasets = [] + new_partitions = [] + for idx, cfg in enumerate(dataset_cfgs): + if partitions[idx] == 0: + continue + new_partitions.append(partitions[idx]) + cfg.update({ + 'data_prefix': data_prefix + }) + datasets.append(create_single_dataset(cfg)) + self.dataset = ConcatDataset(datasets) + if num_data is not None: + self.length = num_data + else: + self.length = max(len(ds) for ds in datasets) + partitions = new_partitions + weights = [np.ones(len(ds)) * p / len(ds) for (p, ds) in zip(partitions, datasets)] + weights = np.concatenate(weights, axis=0) + self.weights = weights + self.task_proj = { + 'mocap': 0, + 't2m': 1, + 'v2m': 2, + 's2g': 3, + 'm2d': 4 + } + self.task_idx_list = [] + for ds in datasets: + self.task_idx_list += [self.task_proj[ds.task_name]] * len(ds) + + def __len__(self) -> int: + """Get the size of the dataset.""" + return self.length + + def __getitem__(self, idx: int) -> dict: + """Given an index, sample data from multiple datasets with the specified proportion.""" + return self.dataset[idx] + + def get_task_idx(self, idx: int) -> int: + return self.task_idx_list[idx] + + +@DATASETS.register_module() +class MotionVerseEval(Dataset): + + def __init__(self, + eval_cfgs: dict, + testset: str, + test_mode: bool = True): + """Load data from multiple datasets.""" + assert testset in eval_cfgs + dataset_path, task_name = testset.split('_') + dataset_cfg = eval_cfgs[testset] + dataset_cfg['dataset_path'] = dataset_path + dataset_cfg['task_name'] = task_name + dataset_cfg['test_mode'] = test_mode + self.dataset = create_single_dataset(dataset_cfg) + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> dict: + return self.dataset[idx] + + def load_annotation(self): + self.dataset.load_annotation() + + def prepare_data(self, idx: int) -> dict: + return self.dataset.prepare_data(idx) + + def prepare_evaluation(self): + self.dataset.prepare_evaluation() + + def process_outputs(self, results): + return self.dataset.process_outputs(results) + + def evaluate(self, results: List[dict], work_dir: str, logger=None) -> dict: + return self.dataset.evaluate(results=results, work_dir=work_dir, logger=logger) + + +@DATASETS.register_module() +class MotionVerseAMASS(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'amass' + task_name = kwargs['task_name'] + assert task_name in ['mocap'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseH3D(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'humanml3d' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 't2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseKIT(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'kitml' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 't2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseBABEL(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'babel' + task_name = kwargs['task_name'] + 
assert task_name in ['mocap', 't2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseMotionX(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'motionx' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 't2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseACT12(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'humanact12' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 't2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseUESTC(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'uestc' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 't2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseNTU(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'ntu' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 't2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseH36M(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'h36m' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 'v2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseMPI(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'mpi' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 'v2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVersePW3D(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = '3dpw' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 'v2m'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseAIST(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'aist' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 'm2d'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseBEAT(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'beat' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 's2g'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseTEDG(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'tedg' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 's2g'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseTEDEx(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 'tedex' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 's2g'] + super().__init__(**kwargs) + + +@DATASETS.register_module() +class MotionVerseS2G3D(SingleMotionVerseDataset): + + def __init__(self, **kwargs): + if 'dataset_path' not in kwargs: + kwargs['dataset_path'] = 's2g3d' + task_name = kwargs['task_name'] + assert task_name in ['mocap', 's2g'] + super().__init__(**kwargs) + \ No newline at 
end of file diff --git a/mogen/datasets/paramUtil.py b/mogen/datasets/paramUtil.py new file mode 100644 index 0000000000000000000000000000000000000000..cb5bfbf7372fec43d55a5ca4eb24090a948879d1 --- /dev/null +++ b/mogen/datasets/paramUtil.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright 2022 The IDEA Authors (Shunlin Lu and Ling-Hao Chen). All rights reserved. +# +# For all the datasets, be sure to read and follow their license agreements, +# and cite them accordingly. +# If the unifier is used in your research, please consider to cite as: +# +# @article{humantomato, +# title={HumanTOMATO: Text-aligned Whole-body Motion Generation}, +# author={Lu, Shunlin and Chen, Ling-Hao and Zeng, Ailing and Lin, Jing and Zhang, Ruimao and Zhang, Lei and Shum, Heung-Yeung}, +# journal={arxiv:2310.12978}, +# year={2023} +# } +# +# @InProceedings{Guo_2022_CVPR, +# author = {Guo, Chuan and Zou, Shihao and Zuo, Xinxin and Wang, Sen and Ji, Wei and Li, Xingyu and Cheng, Li}, +# title = {Generating Diverse and Natural 3D Human Motions From Text}, +# booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, +# month = {June}, +# year = {2022}, +# pages = {5152-5161} +# } +# +# Licensed under the IDEA License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/IDEA-Research/HumanTOMATO/blob/main/LICENSE +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. We provide a license to use the code, +# please read the specific details carefully. +# +# ------------------------------------------------------------------------------------------------ +# Copyright (c) Chuan Guo. 
+# ------------------------------------------------------------------------------------------------ +# Portions of this code were adapted from the following open-source project: +# https://github.com/EricGuo5513/HumanML3D +# ------------------------------------------------------------------------------------------------ + + +import numpy as np + +# Define a kinematic tree for the skeletal struture +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_body_offsets = np.array([[0,0,0], + [1,0,0], + [-1,0,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,0,1], + [0,0,1], + [0,1,0], + [1,0,0], + [-1,0,0], + [0,0,1], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0]]) + +t2m_raw_hand_offsets = np.array([[1, 0, 0], # left_index1 + [1, 0, 0], # left_index2 + [1, 0, 0], # left_index3 + [1, 0, 0], # left_middle1 + [1, 0, 0], # left_middle2 + [1, 0, 0], # left_middle3 + [1, 0, 0], # left_pinky1 + [1, 0, 0], # left_pinky2 + [1, 0, 0], # left_pinky3 + [1, 0, 0], # left_ring1 + [1, 0, 0], # left_ring2 + [1, 0, 0], # left_ring3 + [1, 0, 0], # left_thumb1 + [1, 0, 0], # left_thumb2 + [1, 0, 0], # left_thumb3 + [-1, 0, 0], # right_index1 + [-1, 0, 0], # right_index2 + [-1, 0, 0], # right_index3 + [-1, 0, 0], # right_middle1 + [-1, 0, 0], # right_middle2 + [-1, 0, 0], # right_middle3 + [-1, 0, 0], # right_pinky1 + [-1, 0, 0], # right_pinky2 + [-1, 0, 0], # right_pinky3 + [-1, 0, 0], # right_ring1 + [-1, 0, 0], # right_ring2 + [-1, 0, 0], # right_ring3 + [-1, 0, 0], # right_thumb1 + [-1, 0, 0], # right_thumb2 + [-1, 0, 0],]) # right_thumb3 + +t2m_raw_offsets = np.concatenate( + (t2m_raw_body_offsets, t2m_raw_hand_offsets), axis=0) +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + +t2m_body_hand_kinematic_chain = t2m_kinematic_chain + t2m_left_hand_chain + t2m_right_hand_chain + +kit_tgt_skel_id = '03950' + +t2m_tgt_skel_id = '000021' diff --git a/mogen/datasets/pipelines/__init__.py b/mogen/datasets/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..00149bb6027108a17a398206e23f197661cebc7a --- /dev/null +++ b/mogen/datasets/pipelines/__init__.py @@ -0,0 +1,30 @@ +from .compose import Compose +from .formatting import ( + Collect, + ToTensor, + Transpose, + WrapFieldsToLists, + to_tensor +) +from .siamese_motion import ProcessSiameseMotion, SwapSiameseMotion +from .transforms import Crop, Normalize, RandomCrop +from .motionverse import ( + LoadMotion, + RetargetSkeleton, + MotionDownsample, + PutOnFloor, + MoveToOrigin, + RotateToZ, + KeypointsToTomato, + RandomCropKeypoints, + MaskedCrop, + MaskedRandomCrop +) + +__all__ = [ + 'Compose', 'to_tensor', 'Transpose', 'Collect', 'WrapFieldsToLists', + 'ToTensor', 'Crop', 'RandomCrop', 'Normalize', 'SwapSiameseMotion', + 'ProcessSiameseMotion', 'LoadMotion', 
'RetargetSkeleton', 'MotionDownsample', + 'PutOnFloor', 'MoveToOrigin', 'RotateToZ', 'KeypointsToTomato', 'RandomCropKeypoints', + 'MaskedCrop', 'MaskedRandomCrop' +] diff --git a/mogen/datasets/pipelines/compose.py b/mogen/datasets/pipelines/compose.py new file mode 100644 index 0000000000000000000000000000000000000000..21960b2225aaaffb342edc90385f85a21d31390a --- /dev/null +++ b/mogen/datasets/pipelines/compose.py @@ -0,0 +1,42 @@ +from collections.abc import Sequence + +from mmcv.utils import build_from_cfg + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Compose(object): + """Compose a data pipeline with a sequence of transforms. + + Args: + transforms (list[dict | callable]): + Either config dicts of transforms or transform objects. + """ + + def __init__(self, transforms): + assert isinstance(transforms, Sequence) + self.transforms = [] + for transform in transforms: + if isinstance(transform, dict): + transform = build_from_cfg(transform, PIPELINES) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict, but got' + f' {type(transform)}') + + def __call__(self, data): + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += f'\n {t}' + format_string += '\n)' + return format_string diff --git a/mogen/datasets/pipelines/formatting.py b/mogen/datasets/pipelines/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..febc74322095caa243502ce4c2a4984ce4152c41 --- /dev/null +++ b/mogen/datasets/pipelines/formatting.py @@ -0,0 +1,135 @@ +from collections.abc import Sequence + +import mmcv +import numpy as np +import torch +from mmcv.parallel import DataContainer as DC + +from ..builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + """ + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError( + f'Type {type(data)} cannot be converted to tensor.' + 'Supported types are: `numpy.ndarray`, `torch.Tensor`, ' + '`Sequence`, `int` and `float`') + + +@PIPELINES.register_module() +class ToTensor(object): + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class Transpose(object): + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@PIPELINES.register_module() +class Collect(object): + """Collect data from the loader relevant to the specific task. + + This is usually the last stage of the data loader pipeline. 
+ + Args: + keys (Sequence[str]): Keys of results to be collected in ``data``. + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[motion_metas]``. + Default: ``('filename', 'ori_filename', + 'ori_shape', 'motion_shape', 'motion_mask')`` + + Returns: + dict: The result dict contains the following keys + - keys in``self.keys`` + - ``motion_metas`` if available + """ + + def __init__(self, + keys, + meta_keys=('filename', 'ori_filename', 'ori_shape', + 'motion_shape', 'motion_mask')): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + data = {} + motion_meta = {} + for key in self.meta_keys: + if key in results: + motion_meta[key] = results[key] + data['motion_metas'] = DC(motion_meta, cpu_only=True) + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, meta_keys={self.meta_keys})' + + +@PIPELINES.register_module() +class WrapFieldsToLists(object): + """Wrap fields of the data dictionary into lists for evaluation. + + This class can be used as a last step of a test or validation + pipeline for single image evaluation or inference. + + Example: + >>> test_pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + >>> dict(type='ImageToTensor', keys=['img']), + >>> dict(type='Collect', keys=['img']), + >>> dict(type='WrapIntoLists') + >>> ] + """ + + def __call__(self, results): + # Wrap dict fields into lists + for key, val in results.items(): + results[key] = [val] + return results + + def __repr__(self): + return f'{self.__class__.__name__}()' diff --git a/mogen/datasets/pipelines/motionverse.py b/mogen/datasets/pipelines/motionverse.py new file mode 100644 index 0000000000000000000000000000000000000000..e10ae21b8fd76abfcfa1288d43fd4d38ce7f530c --- /dev/null +++ b/mogen/datasets/pipelines/motionverse.py @@ -0,0 +1,707 @@ +import random +from typing import Optional, Union, List +import pytorch3d.transforms as geometry + +import numpy as np +import torch + +from ..builder import PIPELINES +from ..skeleton import Skeleton +from ..paramUtil import t2m_raw_offsets, t2m_body_hand_kinematic_chain +from ..quaternion import ( + qbetween_np, + qrot_np, + qinv_np, + qmul_np, + quaternion_to_cont6d_np +) + + +@PIPELINES.register_module() +class LoadMotion(object): + r"""Load motion data from a file. + + This pipeline component loads motion data from a specified file path provided in the `results` dictionary. + It randomly selects a motion type from the given `motion_types` list and updates the `results` dictionary + with the corresponding motion data based on the selected motion type. + + Args: + motion_types (List[str]): A list of motion types to choose from. Possible values include: + - `'tomato_repr'`: Load the 'tomato_repr' motion representation. + - `'smpl_rot'`: Load SMPL rotation data. + - `'bvh_rot'`: Load BVH rotation data. + - `'h3d_rot'`: Calculate H3D rotation data(in another pipeline). + max_size (int): The maximum size of the cropped motion + sequence (inclusive). This is only used for `tomato_repr` + """ + + def __init__(self, motion_types: List[str], max_size: int = -1): + self.motion_types = motion_types + self.max_size = max_size + + def __call__(self, results): + """Load motion data and update the results dictionary. 
+ + Args: + results (dict): A dictionary containing the key `'motion_path'`, which specifies the path to the motion data file. + + Returns: + dict: The updated results dictionary with loaded motion data and the selected motion type. + """ + # Load motion data from the specified file + motion_path = results['motion_path'] + motion_data = np.load(motion_path) + + # Randomly select a motion type from the provided list + motion_type = np.random.choice(self.motion_types) + results['motion_type'] = motion_type + + if motion_type == 'tomato_repr': + motion = motion_data['tomato_repr'] + length = motion.shape[0] + assert self.max_size != -1 + actual_length = min(self.max_size, length) + padding_length = self.max_size - actual_length + if padding_length > 0: + D = motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + motion = np.concatenate([motion, padding_zeros], axis=0) + else: + motion = motion[:actual_length] + results['motion_length'] = actual_length + results['motion'] = motion + results['motion_shape'] = motion.shape + motion_mask = torch.cat( + (torch.ones(actual_length), + torch.zeros(padding_length)), + dim=0).numpy() + motion_mask = np.expand_dims(motion_mask, axis=1) + motion_mask = np.repeat(motion_mask, 10, axis=1) + meta_data = results['meta_data'] + if not meta_data['has_root']: + motion_mask[:, 0] = 0 + if not meta_data['has_head']: + motion_mask[:, 1] = 0 + if not meta_data['has_stem']: + motion_mask[:, 2] = 0 + if not meta_data['has_larm']: + motion_mask[:, 3] = 0 + if not meta_data['has_rarm']: + motion_mask[:, 4] = 0 + if not meta_data['has_lleg']: + motion_mask[:, 5] = 0 + if not meta_data['has_rleg']: + motion_mask[:, 6] = 0 + if not meta_data['has_lhnd']: + motion_mask[:, 7] = 0 + if not meta_data['has_rhnd']: + motion_mask[:, 8] = 0 + if not meta_data['has_face']: + motion_mask[:, 9] = 0 + results['motion_mask'] = motion_mask + results['meta_data']['rotation_type'] = 'h3d_rot' + if not 'video_word_feat' in results: + results['video_word_feat'] = np.zeros((self.max_size, 1024)) + else: + keypoints3d = motion_data['keypoints3d'] + if keypoints3d.shape[0] == 0: + print(results['motion_path']) + start_frame = results['meta_data'].get('start_frame', 0) + end_frame = results['meta_data'].get('end_frame', keypoints3d.shape[0]) + keypoints3d = keypoints3d.reshape(keypoints3d.shape[0], -1, 3) + if keypoints3d.shape[1] == 24: + keypoints3d = np.concatenate( + (keypoints3d[:, :22, :], np.zeros((keypoints3d.shape[0], 30, 3))), + axis=1 + ) + elif keypoints3d.shape[1] == 22: + keypoints3d = np.concatenate( + (keypoints3d, np.zeros((keypoints3d.shape[0], 30, 3))), + axis=1 + ) + keypoints3d = keypoints3d[start_frame: end_frame] + assert not np.isnan(keypoints3d).any() + results['keypoints3d'] = keypoints3d + if motion_type == 'smpl_rot': + results['rotation'] = motion_data['smpl_rot'][start_frame: end_frame] + assert not np.isnan(results['rotation']).any() + results['meta_data']['rotation_type'] = 'smpl_rot' + elif motion_type == 'bvh_rot': + results['rotation'] = motion_data['bvh_rot'][start_frame: end_frame] + assert not np.isnan(results['rotation']).any() + results['meta_data']['rotation_type'] = 'bvh_rot' + else: + results['meta_data']['rotation_type'] = 'h3d_rot' + if 'expression' in motion_data: + results['expression'] = motion_data['expression'][start_frame: end_frame] + assert not np.isnan(results['expression']).any() + if 'video_word_feat' in results: + results['video_word_feat'] = results['video_word_feat'][start_frame: end_frame] + else: + 
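+                # No per-frame video features were provided for this sample;
+                # use zero placeholders aligned with the keypoint sequence.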
results['video_word_feat'] = np.zeros((keypoints3d.shape[0], 1024)) + return results + + def __repr__(self): + """Return a string representation of the class.""" + return f"{self.__class__.__name__}(motion_types={self.motion_types}, max_size={self.max_size})" + + +@PIPELINES.register_module() +class RandomCropKeypoints(object): + r"""Random crop keypoints sequences. + + Args: + min_size (int or None): The minimum size of the cropped motion + sequence (inclusive). + max_size (int or None): The maximum size of the cropped motion + sequence (inclusive). + """ + + def __init__(self, + min_size: Optional[Union[int, None]] = None, + max_size: Optional[Union[int, None]] = None): + self.min_size = min_size + self.max_size = max_size + assert self.min_size is not None + assert self.max_size is not None + + def __call__(self, results): + keypoints3d = results['keypoints3d'] + length = len(keypoints3d) + crop_size = random.randint(self.min_size, self.max_size) + if length > crop_size: + idx = random.randint(0, length - crop_size) + keypoints3d = keypoints3d[idx: idx + crop_size] + if 'rotation' in results: + results['rotation'] = results['rotation'][idx: idx + crop_size] + if 'expression' in results: + results['expression'] = results['expression'][idx: idx + crop_size] + if 'video_word_feat' in results: + results['video_word_feat'] = results['video_word_feat'][idx: idx + crop_size] + results['keypoints3d'] = keypoints3d + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(min_size={self.min_size}' + repr_str += f', max_size={self.max_size})' + return repr_str + + +@PIPELINES.register_module() +class RetargetSkeleton(object): + """Retarget motion data to a target skeleton. + + Adjusts motion data from a source skeleton to match a target skeleton structure by scaling and retargeting. + + Args: + tgt_skel_file (str): Path to the file containing the target skeleton data. + + Note: + This code is adapted from: + https://github.com/EricGuo5513/HumanML3D/blob/main/motion_representation.ipynb. + """ + + def __init__(self, tgt_skel_file: str): + skeleton_data = np.load(tgt_skel_file) + skeleton_data = skeleton_data.reshape(len(skeleton_data), -1, 3) + skeleton_data = torch.from_numpy(skeleton_data) + self.raw_offsets = torch.from_numpy(t2m_raw_offsets) + tgt_skeleton = Skeleton(self.raw_offsets, t2m_body_hand_kinematic_chain, 'cpu') + self.tgt_offsets = tgt_skeleton.get_offsets_joints(skeleton_data[0]) + self.tgt_skel_file = tgt_skel_file + + def __call__(self, results): + """Retarget the motion data to the target skeleton. + + Args: + results (dict): Contains 'keypoints3d' with source motion data. + + Returns: + dict: Updated results with retargeted motion data in 'keypoints3d'. 
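+
+        Example (illustrative; the skeleton file and ``joints`` array are
+            placeholders):
+            >>> retarget = RetargetSkeleton(tgt_skel_file='path/to/skeleton.npy')
+            >>> out = retarget({'keypoints3d': joints,
+            ...                 'meta_data': {'has_lhnd': True}})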
+ """ + positions = results['keypoints3d'] + src_skel = Skeleton(self.raw_offsets, t2m_body_hand_kinematic_chain, 'cpu') + src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0])) + src_offset = src_offset.numpy() + tgt_offset = self.tgt_offsets.numpy() + + # Calculate scale ratio based on leg lengths + l_idx1, l_idx2 = 5, 8 # Indices for leg joints + eps = 1e-5 + src_leg_len = np.linalg.norm(src_offset[l_idx1]) + np.linalg.norm(src_offset[l_idx2]) + tgt_leg_len = np.linalg.norm(tgt_offset[l_idx1]) + np.linalg.norm(tgt_offset[l_idx2]) + if src_leg_len < eps: + a_idx1, a_idx2 = 19, 21 + src_arm_len = np.linalg.norm(src_offset[a_idx1]) + np.linalg.norm(src_offset[a_idx2]) + tgt_arm_len = np.linalg.norm(tgt_offset[a_idx1]) + np.linalg.norm(tgt_offset[a_idx2]) + if src_arm_len < eps: + scale_rt = 1.0 + else: + scale_rt = tgt_arm_len / src_arm_len + else: + scale_rt = tgt_leg_len / src_leg_len + + # Scale root positions + src_root_pos = positions[:, 0] + tgt_root_pos = src_root_pos * scale_rt + + # Perform inverse kinematics to get rotation parameters + face_joint_idx = [2, 1, 17, 16] # Indices for face-related joints + quat_params = src_skel.inverse_kinematics_np(positions, face_joint_idx) + # Set offsets to target skeleton and perform forward kinematics + src_skel.set_offset(self.tgt_offsets) + new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos) + new_joints[np.isnan(new_joints)] = 0 + + if not results['meta_data'].get('has_lhnd', False): + new_joints[:, 22:, :] = 0 + results['keypoints3d'] = new_joints + + return results + + def __repr__(self): + return f"{self.__class__.__name__}(tgt_skel_file='{self.tgt_skel_file}')" + + +@PIPELINES.register_module() +class MotionDownsample(object): + + def __init__(self, framerate_list: List[float]): + self.framerate_list = framerate_list + + def __call__(self, results): + framerate = np.random.choice(self.framerate_list) + downsample_rate = int(results['meta_data']['framerate'] / framerate) + results['meta_data']['framerate'] = framerate + if 'keypoints3d' in results: + results['keypoints3d'] = results['keypoints3d'][::downsample_rate] + if 'rotation' in results: + results['rotation'] = results['rotation'][::downsample_rate] + if 'expression' in results: + results['expression'] = results['expression'][::downsample_rate] + if 'video_word_feat' in results: + results['video_word_feat'] = results['video_word_feat'] [::downsample_rate] + return results + + def __repr__(self): + return f"{self.__class__.__name__}()" + + +@PIPELINES.register_module() +class PutOnFloor(object): + """Shift motion data so that the skeleton stands on the floor. + + This pipeline component adjusts the motion data by translating it vertically, + ensuring that the lowest point of the skeleton aligns with the floor level (y=0). + + Note: + This code is adapted from: + https://github.com/EricGuo5513/HumanML3D/blob/main/motion_representation.ipynb. + """ + + def __init__(self): + pass # No initialization parameters required + + def __call__(self, results): + """Adjust the motion data to place the skeleton on the floor. + + Args: + results (dict): Contains 'keypoints3d' with motion data. + + Returns: + dict: Updated results with adjusted 'keypoints3d'. 
+ """ + positions = results['keypoints3d'] + # Calculate the minimum y-coordinate among the first 22 joints over all frames + floor_height = positions[:, :22, 1].min() + # Shift the y-coordinates so that the lowest point is at y=0 + positions[:, :, 1] -= floor_height + results['keypoints3d'] = positions + return results + + def __repr__(self): + return f"{self.__class__.__name__}()" + + +@PIPELINES.register_module() +class MoveToOrigin(object): + """Translate motion data so the root joint starts at the origin. + + This pipeline component adjusts the motion data by translating it so that + the initial position of the root joint aligns with the origin. + + Note: + This code is adapted from: + https://github.com/EricGuo5513/HumanML3D/blob/main/motion_representation.ipynb. + """ + + def __init__(self, origin: str): + assert origin in ['xz', 'xyz'] + if origin == 'xz': + self.weight = np.array([1, 0, 1]) + elif origin == 'xyz': + self.weight = np.array([1, 1, 1]) + + def __call__(self, results): + """Adjust the motion data to move the root joint to the origin. + + Args: + results (dict): Contains 'keypoints3d' with motion data. + + Returns: + dict: Updated results with adjusted 'keypoints3d'. + """ + positions = results['keypoints3d'] + # Get the initial root joint position (frame 0, joint 0) + root_pos_init = positions[0, 0] + + root_pos_init = root_pos_init * self.weight + positions = positions - root_pos_init + results['keypoints3d'] = positions + return results + + def __repr__(self): + return f"{self.__class__.__name__}()" + + +@PIPELINES.register_module() +class RotateToZ(object): + """Rotate motion data so the initial facing direction aligns with the Z-axis. + + This pipeline component rotates the motion data such that the character's initial + facing direction is aligned with the positive Z-axis, standardizing the orientation + of the motion data. + + Note: + This code is adapted from: + https://github.com/EricGuo5513/HumanML3D/blob/main/motion_representation.ipynb. + """ + + def __init__(self): + pass # No initialization parameters required + + def __call__(self, results): + """Rotate the motion data to align the initial facing direction with the Z-axis. + + Args: + results (dict): Contains 'keypoints3d' with motion data. + + Returns: + dict: Updated results with rotated 'keypoints3d'. 
+ """ + positions = results['keypoints3d'] + # Indices for specific joints used to determine facing direction + face_joint_idx = [2, 1, 17, 16] # Right hip, left hip, right shoulder, left shoulder + r_hip, l_hip, sdr_r, sdr_l = face_joint_idx + + # Calculate the initial across vector from hips and shoulders + pos_init = positions[0] + across1 = pos_init[r_hip] - pos_init[l_hip] + across2 = pos_init[sdr_r] - pos_init[sdr_l] + across = across1 + across2 + eps = 1e-8 + across = across / (np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis] + eps) + + # Calculate the initial forward vector using a cross product with the up vector + forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + forward_init = forward_init / (np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis] + eps) + + # Compute the rotation quaternion between the initial forward vector and target vector (Z-axis) + target_vector = np.array([[0, 0, 1]]) + root_quat_init = qbetween_np(forward_init, target_vector) + + # Apply the rotation to all joints across all frames + root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init + positions = qrot_np(root_quat_init, positions) + positions[np.isnan(positions)] = 0 + results['keypoints3d'] = positions + return results + + def __repr__(self): + return f"{self.__class__.__name__}()" + + +@PIPELINES.register_module() +class KeypointsToTomato(object): + """Convert keypoint motion data to the TOMATO representation. + + This pipeline component transforms 3D keypoints into the TOMATO motion representation, + suitable for motion generation models. + + Note: + Adapted from: + https://github.com/EricGuo5513/HumanML3D/blob/main/motion_representation.ipynb. + """ + + def __init__(self, smooth_forward=True): + self.raw_offsets = torch.from_numpy(t2m_raw_offsets) + self.smooth_forward = smooth_forward + + def get_cont6d_params(self, positions): + """Compute continuous 6D rotation parameters and root velocities.""" + skel = Skeleton(self.raw_offsets, t2m_body_hand_kinematic_chain, "cpu") + face_joint_idx = [2, 1, 17, 16] + + quat_params = skel.inverse_kinematics_np(positions, face_joint_idx, smooth_forward=self.smooth_forward) + quat_params[np.isnan(quat_params)] = 0 + cont_6d_params = quaternion_to_cont6d_np(quat_params) + r_rot = quat_params[:, 0].copy() + + velocity = (positions[1:, 0] - positions[:-1, 0]).copy() + velocity = qrot_np(r_rot[1:], velocity) + + r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1])) + + return cont_6d_params, r_velocity, velocity, r_rot + + def get_rifke(self, r_rot, positions): + """Compute rotation-invariant joint positions.""" + positions[..., 0] -= positions[:, 0:1, 0] + positions[..., 2] -= positions[:, 0:1, 2] + positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions) + return positions + + def __call__(self, results): + """Convert keypoints to TOMATO motion representation.""" + positions = results['keypoints3d'] + global_positions = positions.copy() + + cont_6d_params, r_velocity, velocity, r_rot = self.get_cont6d_params(positions) + positions = self.get_rifke(r_rot, positions) + + root_y = positions[:-1, 0, 1:2] + r_velocity = np.arcsin(r_velocity[:, 2:3]) + l_velocity = velocity[:, [0, 2]] + + root_data = np.concatenate([r_velocity, l_velocity, root_y], axis=-1) + + rot_data = cont_6d_params[:-1, 1:].reshape(len(cont_6d_params) - 1, -1) + motion_type = results['motion_type'] + if motion_type == 'smpl_rot': + rot_data = results['rotation'][:-1] + if rot_data.shape[1] == 72: + num_frames = rot_data.shape[0] 
+ rot_data = rot_data.reshape((num_frames, 24, 3)) + rot_data = np.concatenate(( + rot_data[:, 1: 22, :], + np.zeros((num_frames, 30, 3)) + ), axis=1) + rot_data = torch.from_numpy(rot_data) + rot_data = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(rot_data)) + rot_data = rot_data.numpy().reshape((num_frames, -1)) + elif rot_data.shape[1] == 156: + num_frames = rot_data.shape[0] + rot_data = rot_data.reshape((num_frames, 52, 3))[:, 1:, :] + rot_data = torch.from_numpy(rot_data) + rot_data = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(rot_data)) + rot_data = rot_data.numpy().reshape((num_frames, -1)) + else: + raise NotImplementedError() + ric_data = positions[:-1, 1:].reshape(len(positions) - 1, -1) + + local_vel = qrot_np( + np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1), + global_positions[1:] - global_positions[:-1] + ) + local_vel = local_vel.reshape(len(local_vel), -1) + if not results['meta_data']['has_lhnd']: + ric_data[:, 21 * 3: 51 * 3] = 0 + rot_data[:, 21 * 6: 51 * 6] = 0 + local_vel[:, 22 * 3: 52 * 3] = 0 + data = np.concatenate([root_data, ric_data, rot_data, local_vel], axis=-1) + + if 'expression' in results: + data = np.concatenate([data, results['expression'][:-1]], axis=-1) + else: + data = np.concatenate([data, np.zeros((data.shape[0], 50))], axis=-1) + + if 'video_word_feat' in results: + results['video_word_feat'] = results['video_word_feat'][:-1] + + data[np.isnan(data)] = 0 + results['motion'] = data + return results + + def __repr__(self): + return f"{self.__class__.__name__}()" + + +@PIPELINES.register_module() +class MaskedRandomCrop(object): + r"""Masked Random crop motion sequences. Each sequence will be padded with zeros + to the maximum length. + + Args: + min_size (int or None): The minimum size of the cropped motion + sequence (inclusive). + max_size (int or None): The maximum size of the cropped motion + sequence (inclusive). 
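+        pad_size (int or None): Length to which each sequence is zero-padded
+            after cropping. Defaults to ``max_size`` when not given.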
+ """ + + def __init__(self, + min_size: Optional[Union[int, None]] = None, + max_size: Optional[Union[int, None]] = None, + pad_size: Optional[Union[int, None]] = None): + self.min_size = min_size + self.max_size = max_size + assert self.min_size is not None + assert self.max_size is not None + self.pad_size = max_size if pad_size is None else pad_size + + def __call__(self, results): + motion = results['motion'] + length = len(motion) + crop_size = random.randint(self.min_size, self.max_size) + if length > crop_size: + idx = random.randint(0, length - crop_size) + motion = motion[idx: idx + crop_size] + results['motion_length'] = crop_size + if 'video_word_feat' in results: + results['video_word_feat'] = results['video_word_feat'][idx: idx + crop_size] + else: + results['motion_length'] = length + padding_length = self.pad_size - min(crop_size, length) + if padding_length > 0: + D = motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + motion = np.concatenate([motion, padding_zeros], axis=0) + if 'video_word_feat' in results: + D = results['video_word_feat'].shape[1] + results['video_word_feat'] = np.concatenate( + [results['video_word_feat'], np.zeros((padding_length, D))], + axis=0 + ) + results['motion'] = motion + results['motion_shape'] = motion.shape + if length >= self.pad_size and crop_size == self.pad_size: + motion_mask = torch.ones(self.pad_size).numpy() + else: + motion_mask = torch.cat( + (torch.ones(min(length, crop_size)), + torch.zeros(self.pad_size - min(length, crop_size))), + dim=0).numpy() + motion_mask = np.expand_dims(motion_mask, axis=1) + motion_mask = np.repeat(motion_mask, 10, axis=1) + meta_data = results['meta_data'] + if not meta_data['has_root']: + motion_mask[:, 0] = 0 + if not meta_data['has_head']: + motion_mask[:, 1] = 0 + if not meta_data['has_stem']: + motion_mask[:, 2] = 0 + if not meta_data['has_larm']: + motion_mask[:, 3] = 0 + if not meta_data['has_rarm']: + motion_mask[:, 4] = 0 + if not meta_data['has_lleg']: + motion_mask[:, 5] = 0 + if not meta_data['has_rleg']: + motion_mask[:, 6] = 0 + if not meta_data['has_lhnd']: + motion_mask[:, 7] = 0 + if not meta_data['has_rhnd']: + motion_mask[:, 8] = 0 + if not meta_data['has_face']: + motion_mask[:, 9] = 0 + results['motion_mask'] = motion_mask + assert len(motion) == self.pad_size + + if 'video_word_feat' in results: + assert len(results['video_word_feat']) == self.pad_size + else: + results['video_word_feat'] = np.zeros((self.pad_size, 1024)) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(min_size={self.min_size}' + repr_str += f', max_size={self.max_size})' + return repr_str + + +@PIPELINES.register_module() +class MaskedCrop(object): + r"""Masked crop motion sequences. Each sequence will be padded with zeros + to the maximum length. + + Args: + min_size (int or None): The minimum size of the cropped motion + sequence (inclusive). + max_size (int or None): The maximum size of the cropped motion + sequence (inclusive). 
+ """ + + def __init__(self, + crop_size: Optional[Union[int, None]] = None, + pad_size: Optional[Union[int, None]] = None): + self.crop_size = crop_size + assert self.crop_size is not None + if pad_size is None: + self.pad_size = self.crop_size + else: + self.pad_size = pad_size + + def __call__(self, results): + motion = results['motion'] + length = len(motion) + crop_size = self.crop_size + pad_size = self.pad_size + if length > crop_size: + idx = random.randint(0, length - crop_size) + motion = motion[idx: idx + crop_size] + results['motion_length'] = crop_size + if 'video_word_feat' in results: + results['video_word_feat'] = results['video_word_feat'][idx: idx + crop_size] + else: + results['motion_length'] = length + actual_length = min(crop_size, length) + padding_length = pad_size - actual_length + if padding_length > 0: + D = motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + motion = np.concatenate([motion, padding_zeros], axis=0) + if 'video_word_feat' in results: + D = results['video_word_feat'].shape[1] + results['video_word_feat'] = np.concatenate( + [results['video_word_feat'], np.zeros((padding_length, D))], + axis=0 + ) + results['motion'] = motion + results['motion_shape'] = motion.shape + motion_mask = torch.cat( + (torch.ones(actual_length), + torch.zeros(padding_length)), + dim=0).numpy() + motion_mask = np.expand_dims(motion_mask, axis=1) + motion_mask = np.repeat(motion_mask, 10, axis=1) + meta_data = results['meta_data'] + if not meta_data['has_root']: + motion_mask[:, 0] = 0 + if not meta_data['has_head']: + motion_mask[:, 1] = 0 + if not meta_data['has_stem']: + motion_mask[:, 2] = 0 + if not meta_data['has_larm']: + motion_mask[:, 3] = 0 + if not meta_data['has_rarm']: + motion_mask[:, 4] = 0 + if not meta_data['has_lleg']: + motion_mask[:, 5] = 0 + if not meta_data['has_rleg']: + motion_mask[:, 6] = 0 + if not meta_data['has_lhnd']: + motion_mask[:, 7] = 0 + if not meta_data['has_rhnd']: + motion_mask[:, 8] = 0 + if not meta_data['has_face']: + motion_mask[:, 9] = 0 + results['motion_mask'] = motion_mask + assert len(motion) == pad_size + if 'video_word_feat' in results: + assert len(results['video_word_feat']) == pad_size + else: + results['video_word_feat'] = np.zeros((pad_size, 1024)) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size}' + return repr_str \ No newline at end of file diff --git a/mogen/datasets/pipelines/siamese_motion.py b/mogen/datasets/pipelines/siamese_motion.py new file mode 100644 index 0000000000000000000000000000000000000000..e201d0fef3dd8e7f35db180a8f2898ee0b067ede --- /dev/null +++ b/mogen/datasets/pipelines/siamese_motion.py @@ -0,0 +1,175 @@ +import numpy as np +import torch + +from ..builder import PIPELINES +from ..quaternion import qbetween_np, qinv_np, qmul_np, qrot_np + +face_joint_indx = [2, 1, 17, 16] +fid_l = [7, 10] +fid_r = [8, 11] + +trans_matrix = torch.Tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], + [0.0, -1.0, 0.0]]) + + +def rigid_transform(relative, data): + + global_positions = data[..., :22 * 3].reshape(data.shape[:-1] + (22, 3)) + global_vel = data[..., 22 * 3:22 * 6].reshape(data.shape[:-1] + (22, 3)) + + relative_rot = relative[0] + relative_t = relative[1:3] + relative_r_rot_quat = np.zeros(global_positions.shape[:-1] + (4, )) + relative_r_rot_quat[..., 0] = np.cos(relative_rot) + relative_r_rot_quat[..., 2] = np.sin(relative_rot) + global_positions = qrot_np(qinv_np(relative_r_rot_quat), global_positions) + 
global_positions[..., [0, 2]] += relative_t + data[..., :22 * 3] = global_positions.reshape(data.shape[:-1] + (-1, )) + global_vel = qrot_np(qinv_np(relative_r_rot_quat), global_vel) + data[..., 22 * 3:22 * 6] = global_vel.reshape(data.shape[:-1] + (-1, )) + + return data + + +@PIPELINES.register_module() +class SwapSiameseMotion(object): + r"""Swap motion sequences. + + Args: + prob (float): The probability of swapping siamese motions + """ + + def __init__(self, prob=0.5): + self.prob = prob + assert prob >= 0 and prob <= 1.0 + + def __call__(self, results): + if np.random.rand() <= self.prob: + motion1 = results['motion1'] + motion2 = results['motion2'] + results['motion1'] = motion2 + results['motion2'] = motion1 + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(prob={self.prob})' + return repr_str + + +@PIPELINES.register_module() +class ProcessSiameseMotion(object): + r"""Process siamese motion sequences. + The code is borrowed from + https://github.com/tr3e/InterGen/blob/master/utils/utils.py + """ + + def __init__(self, feet_threshold, prev_frames, n_joints, prob): + self.feet_threshold = feet_threshold + self.prev_frames = prev_frames + self.n_joints = n_joints + self.prob = prob + + def process_single_motion(self, motion): + feet_thre = self.feet_threshold + prev_frames = self.prev_frames + n_joints = self.n_joints + '''Uniform Skeleton''' + # positions = uniform_skeleton(positions, tgt_offsets) + + positions = motion[:, :n_joints * 3].reshape(-1, n_joints, 3) + rotations = motion[:, n_joints * 3:] + + positions = np.einsum("mn, tjn->tjm", trans_matrix, positions) + '''Put on Floor''' + floor_height = positions.min(axis=0).min(axis=0)[1] + positions[:, :, 1] -= floor_height + '''XZ at origin''' + root_pos_init = positions[prev_frames] + root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1]) + positions = positions - root_pose_init_xz + '''All initially face Z+''' + r_hip, l_hip, sdr_r, sdr_l = face_joint_indx + across = root_pos_init[r_hip] - root_pos_init[l_hip] + across = across / np.sqrt((across**2).sum(axis=-1))[..., np.newaxis] + + # forward (3,), rotate around y-axis + forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + # forward (3,) + forward_init = forward_init / np.sqrt((forward_init**2).sum(axis=-1)) + forward_init = forward_init[..., np.newaxis] + + target = np.array([[0, 0, 1]]) + root_quat_init = qbetween_np(forward_init, target) + root_quat_init_for_all = \ + np.ones(positions.shape[:-1] + (4,)) * root_quat_init + + positions = qrot_np(root_quat_init_for_all, positions) + """ Get Foot Contacts """ + + def foot_detect(positions, thres): + velfactor, heightfactor = \ + np.array([thres, thres]), np.array([0.12, 0.05]) + + feet_l_x = \ + (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2 + feet_l_y = \ + (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2 + feet_l_z = \ + (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2 + feet_l_h = positions[:-1, fid_l, 1] + feet_l_sum = feet_l_x + feet_l_y + feet_l_z + feet_l = ((feet_l_sum < velfactor) & (feet_l_h < heightfactor)) + feet_l = feet_l.astype(np.float32) + + feet_r_x = \ + (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2 + feet_r_y = \ + (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2 + feet_r_z = \ + (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2 + feet_r_h = positions[:-1, fid_r, 1] + feet_r_sum = feet_r_x + feet_r_y + feet_r_z + feet_r = ((feet_r_sum < velfactor) & (feet_r_h < heightfactor)) + feet_r = 
feet_r.astype(np.float32) + return feet_l, feet_r + + feet_l, feet_r = foot_detect(positions, feet_thre) + '''Get Joint Rotation Representation''' + rot_data = rotations + '''Get Joint Rotation Invariant Position Represention''' + joint_positions = positions.reshape(len(positions), -1) + joint_vels = positions[1:] - positions[:-1] + joint_vels = joint_vels.reshape(len(joint_vels), -1) + + data = joint_positions[:-1] + data = np.concatenate([data, joint_vels], axis=-1) + data = np.concatenate([data, rot_data[:-1]], axis=-1) + data = np.concatenate([data, feet_l, feet_r], axis=-1) + + return data, root_quat_init, root_pose_init_xz[None] + + def __call__(self, results): + motion1, root_quat_init1, root_pos_init1 = \ + self.process_single_motion(results['motion1']) + motion2, root_quat_init2, root_pos_init2 = \ + self.process_single_motion(results['motion2']) + r_relative = qmul_np(root_quat_init2, qinv_np(root_quat_init1)) + angle = np.arctan2(r_relative[:, 2:3], r_relative[:, 0:1]) + + xz = qrot_np(root_quat_init1, root_pos_init2 - root_pos_init1)[:, + [0, 2]] + relative = np.concatenate([angle, xz], axis=-1)[0] + motion2 = rigid_transform(relative, motion2) + if np.random.rand() <= self.prob: + motion2, motion1 = motion1, motion2 + motion = np.concatenate((motion1, motion2), axis=-1) + results['motion'] = motion + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f'(feet_threshold={self.feet_threshold})' + repr_str += f'(feet_threshold={self.feet_threshold})' + repr_str += f'(n_joints={self.n_joints})' + repr_str += f'(prob={self.prob})' + return repr_str diff --git a/mogen/datasets/pipelines/transforms.py b/mogen/datasets/pipelines/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d305aa0018feda73089226ce114c04d3924e3baa --- /dev/null +++ b/mogen/datasets/pipelines/transforms.py @@ -0,0 +1,172 @@ +import random +from typing import Optional, Union + +import numpy as np +import torch + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Crop(object): + r"""Crop motion sequences. + + Args: + crop_size (int): The size of the cropped motion sequence. + """ + + def __init__(self, crop_size: Optional[Union[int, None]] = None): + self.crop_size = crop_size + assert self.crop_size is not None + + def __call__(self, results): + motion = results['motion'] + length = len(motion) + if length >= self.crop_size: + idx = random.randint(0, length - self.crop_size) + motion = motion[idx:idx + self.crop_size] + results['motion_length'] = self.crop_size + else: + padding_length = self.crop_size - length + D = motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + motion = np.concatenate([motion, padding_zeros], axis=0) + results['motion_length'] = length + assert len(motion) == self.crop_size + results['motion'] = motion + results['motion_shape'] = motion.shape + if length >= self.crop_size: + results['motion_mask'] = torch.ones(self.crop_size).numpy() + else: + results['motion_mask'] = torch.cat( + (torch.ones(length), + torch.zeros(self.crop_size - length))).numpy() + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})' + return repr_str + + +@PIPELINES.register_module() +class PairCrop(object): + r"""Crop motion sequences. + + Args: + crop_size (int): The size of the cropped motion sequence. 
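+
+    Note:
+        Both ``motion`` and ``raw_motion`` are cropped with the same offset and
+        padded to the same length, so the pair stays frame-aligned.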
+ """ + + def __init__(self, crop_size: Optional[Union[int, None]] = None): + self.crop_size = crop_size + assert self.crop_size is not None + + def __call__(self, results): + motion = results['motion'] + raw_motion = results['raw_motion'] + length = len(motion) + if length >= self.crop_size: + idx = random.randint(0, length - self.crop_size) + motion = motion[idx:idx + self.crop_size] + raw_motion = raw_motion[idx:idx + self.crop_size] + results['motion_length'] = self.crop_size + else: + padding_length = self.crop_size - length + D = motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + motion = np.concatenate([motion, padding_zeros], axis=0) + D = raw_motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + raw_motion = np.concatenate([raw_motion, padding_zeros], axis=0) + results['motion_length'] = length + assert len(motion) == self.crop_size + assert len(raw_motion) == self.crop_size + results['motion'] = motion + results['raw_motion'] = raw_motion + results['motion_shape'] = motion.shape + if length >= self.crop_size: + results['motion_mask'] = torch.ones(self.crop_size).numpy() + else: + results['motion_mask'] = torch.cat( + (torch.ones(length), + torch.zeros(self.crop_size - length))).numpy() + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})' + return repr_str + + +@PIPELINES.register_module() +class RandomCrop(object): + r"""Random crop motion sequences. Each sequence will be padded with zeros + to the maximum length. + + Args: + min_size (int or None): The minimum size of the cropped motion + sequence (inclusive). + max_size (int or None): The maximum size of the cropped motion + sequence (inclusive). + """ + + def __init__(self, + min_size: Optional[Union[int, None]] = None, + max_size: Optional[Union[int, None]] = None): + self.min_size = min_size + self.max_size = max_size + assert self.min_size is not None + assert self.max_size is not None + + def __call__(self, results): + motion = results['motion'] + length = len(motion) + crop_size = random.randint(self.min_size, self.max_size) + if length > crop_size: + idx = random.randint(0, length - crop_size) + motion = motion[idx:idx + crop_size] + results['motion_length'] = crop_size + else: + results['motion_length'] = length + padding_length = self.max_size - min(crop_size, length) + if padding_length > 0: + D = motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + motion = np.concatenate([motion, padding_zeros], axis=0) + results['motion'] = motion + results['motion_shape'] = motion.shape + if length >= self.max_size and crop_size == self.max_size: + results['motion_mask'] = torch.ones(self.max_size).numpy() + else: + results['motion_mask'] = torch.cat( + (torch.ones(min(length, crop_size)), + torch.zeros(self.max_size - min(length, crop_size))), + dim=0).numpy() + assert len(motion) == self.max_size + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(min_size={self.min_size}' + repr_str += f', max_size={self.max_size})' + return repr_str + + +@PIPELINES.register_module() +class Normalize(object): + """Normalize motion sequences. + + Args: + mean_path (str): Path of mean file. + std_path (str): Path of std file. 
+ """ + + def __init__(self, mean_path, std_path, eps=1e-9, keys=['motion']): + self.mean = np.load(mean_path) + self.std = np.load(std_path) + self.eps = eps + self.keys = keys + + def __call__(self, results): + for k in self.keys: + motion = results[k] + motion = (motion - self.mean) / (self.std + self.eps) + results[k] = motion + return results diff --git a/mogen/datasets/quaternion.py b/mogen/datasets/quaternion.py new file mode 100644 index 0000000000000000000000000000000000000000..91f98b34630af1876535573174183d76dac095bb --- /dev/null +++ b/mogen/datasets/quaternion.py @@ -0,0 +1,450 @@ +# Copyright (c) 2018-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import numpy as np +import torch + +_EPS4 = np.finfo(np.float32).eps * 4.0 + +_FLOAT_EPS = np.finfo(np.float32).eps + +# PyTorch-backed implementations + + +def qinv(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qinv_np(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return qinv(torch.from_numpy(q).float()).numpy() + + +def qnormalize(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + return q / torch.norm(q, dim=-1, keepdim=True) + + +def qmul(q, r): + """ + Multiply quaternion(s) q with quaternion(s) r. + Expects two equally-sized tensors of shape (*, 4), + where * denotes any number of dimensions. + Returns q*r as a tensor of shape (*, 4). + """ + assert q.shape[-1] == 4 + assert r.shape[-1] == 4 + + original_shape = q.shape + + # Compute outer product + terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) + + w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] + x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] + y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] + z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] + return torch.stack((w, x, y, z), dim=1).view(original_shape) + + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). + """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def qeuler(q, order, epsilon=0, deg=True): + """ + Convert quaternion(s) q to Euler angles. + Expects a tensor of shape (*, 4), where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
+ """ + assert q.shape[-1] == 4 + + original_shape = list(q.shape) + original_shape[-1] = 3 + q = q.view(-1, 4) + + q0 = q[:, 0] + q1 = q[:, 1] + q2 = q[:, 2] + q3 = q[:, 3] + + if order == 'xyz': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin( + torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + elif order == 'yzx': + x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin( + torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) + elif order == 'zxy': + x = torch.asin( + torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'xzy': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) + z = torch.asin( + torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) + elif order == 'yxz': + x = torch.asin( + torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) + y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) + z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) + elif order == 'zyx': + x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) + y = torch.asin( + torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) + z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) + else: + raise + + if deg: + return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi + else: + return torch.stack((x, y, z), dim=1).view(original_shape) + + +# Numpy-backed implementations + + +def qmul_np(q, r): + q = torch.from_numpy(q).contiguous().float() + r = torch.from_numpy(r).contiguous().float() + return qmul(q, r).numpy() + + +def qrot_np(q, v): + q = torch.from_numpy(q).contiguous().float() + v = torch.from_numpy(v).contiguous().float() + return qrot(q, v).numpy() + + +def qeuler_np(q, order, epsilon=0, use_gpu=False): + if use_gpu: + q = torch.from_numpy(q).cuda().float() + return qeuler(q, order, epsilon).cpu().numpy() + else: + q = torch.from_numpy(q).contiguous().float() + return qeuler(q, order, epsilon).numpy() + + +def qfix(q): + """ + Enforce quaternion continuity across the time dimension by selecting + the representation (q or -q) with minimal distance (or, equivalently, + maximal dot product) + between two consecutive frames. + + Expects a tensor of shape (L, J, 4), where L is the sequence length and + J is the number of joints. + Returns a tensor of the same shape. + """ + assert len(q.shape) == 3 + assert q.shape[-1] == 4 + + result = q.copy() + dot_products = np.sum(q[1:] * q[:-1], axis=2) + mask = dot_products < 0 + mask = (np.cumsum(mask, axis=0) % 2).astype(bool) + result[1:][mask] *= -1 + return result + + +def euler2quat(e, order, deg=True): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.view(-1, 3) + + # if euler angles in degrees + if deg: + e = e * np.pi / 180. 
+ + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = torch.stack((torch.cos(x / 2), torch.sin( + x / 2), torch.zeros_like(x), torch.zeros_like(x)), + dim=1) + ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin( + y / 2), torch.zeros_like(y)), + dim=1) + rz = torch.stack((torch.cos( + z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), + dim=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.view(original_shape) + + +def expmap_to_quaternion(e): + """ + Convert axis-angle rotations (aka exponential maps) to quaternions. + Stable formula from "Practical Parameterization of Rotations + Using the Exponential Map". + Expects a tensor of shape (*, 3), where * denotes any number of dimensions. + Returns a tensor of shape (*, 4). + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + e = e.reshape(-1, 3) + + theta = np.linalg.norm(e, axis=1).reshape(-1, 1) + w = np.cos(0.5 * theta).reshape(-1, 1) + xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e + return np.concatenate((w, xyz), axis=1).reshape(original_shape) + + +def euler_to_quaternion(e, order): + """ + Convert Euler angles to quaternions. + """ + assert e.shape[-1] == 3 + + original_shape = list(e.shape) + original_shape[-1] = 4 + + e = e.reshape(-1, 3) + + x = e[:, 0] + y = e[:, 1] + z = e[:, 2] + + rx = np.stack( + (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), + axis=1) + ry = np.stack( + (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), + axis=1) + rz = np.stack( + (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), + axis=1) + + result = None + for coord in order: + if coord == 'x': + r = rx + elif coord == 'y': + r = ry + elif coord == 'z': + r = rz + else: + raise + if result is None: + result = r + else: + result = qmul_np(result, r) + + # Reverse antipodal representation to have a non-negative "w" + if order in ['xyz', 'yzx', 'zxy']: + result *= -1 + + return result.reshape(original_shape) + + +def quaternion_to_matrix(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + Returns: + Rotation matrices as tensor of shape (..., 3, 3). 
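+
+    Example (added for illustration; the identity quaternion gives the identity
+    matrix)::
+
+        >>> quaternion_to_matrix(torch.tensor([1.0, 0.0, 0.0, 0.0]))
+        tensor([[1., 0., 0.],
+                [0., 1., 0.],
+                [0., 0., 1.]])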
+ """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def quaternion_to_matrix_np(quaternions): + q = torch.from_numpy(quaternions).contiguous().float() + return quaternion_to_matrix(q).numpy() + + +def quaternion_to_cont6d_np(quaternions): + rotation_mat = quaternion_to_matrix_np(quaternions) + cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], + axis=-1) + return cont_6d + + +def quaternion_to_cont6d(quaternions): + rotation_mat = quaternion_to_matrix(quaternions) + cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1) + return cont_6d + + +def cont6d_to_matrix(cont6d): + assert cont6d.shape[-1] == 6, "The last dimension must be 6" + x_raw = cont6d[..., 0:3] + y_raw = cont6d[..., 3:6] + + x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True) + z = torch.cross(x, y_raw, dim=-1) + z = z / torch.norm(z, dim=-1, keepdim=True) + + y = torch.cross(z, x, dim=-1) + + x = x[..., None] + y = y[..., None] + z = z[..., None] + + mat = torch.cat([x, y, z], dim=-1) + return mat + + +def cont6d_to_matrix_np(cont6d): + q = torch.from_numpy(cont6d).contiguous().float() + return cont6d_to_matrix(q).numpy() + + +def qpow(q0, t, dtype=torch.float): + ''' q0 : tensor of quaternions + t: tensor of powers + ''' + q0 = qnormalize(q0) + theta0 = torch.acos(q0[..., 0]) + + # if theta0 is close to zero, add epsilon to avoid NaNs + mask = (theta0 <= 10e-10) * (theta0 >= -10e-10) + theta0 = (1 - mask) * theta0 + mask * 10e-10 + v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1) + + if isinstance(t, torch.Tensor): + q = torch.zeros(t.shape + q0.shape) + theta = t.view(-1, 1) * theta0.view(1, -1) + else: # if t is a number + q = torch.zeros(q0.shape) + theta = t * theta0 + + q[..., 0] = torch.cos(theta) + q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1) + + return q.to(dtype) + + +def qslerp(q0, q1, t): + ''' + q0: starting quaternion + q1: ending quaternion + t: array of points along the way + + Returns: + Tensor of Slerps: t.shape + q0.shape + ''' + + q0 = qnormalize(q0) + q1 = qnormalize(q1) + q_ = qpow(qmul(q1, qinv(q0)), t) + qq = q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape) + qq = qq.expand(t.shape + q0.shape).contiguous() + return qmul(q_, qq) + + +def qbetween(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v = torch.cross(v0, v1) + t = (v0**2).sum(dim=-1, keepdim=True) * (v1**2) + t = t.sum(dim=-1, keepdim=True) + w = torch.sqrt(t) + (v0 * v1).sum(dim=-1, keepdim=True) + return qnormalize(torch.cat([w, v], dim=-1)) + + +def qbetween_np(v0, v1): + ''' + find the quaternion used to rotate v0 to v1 + ''' + assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)' + assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)' + + v0 = torch.from_numpy(v0).float() + v1 = torch.from_numpy(v1).float() + return qbetween(v0, v1).numpy() + + +def lerp(p0, p1, t): + if not isinstance(t, torch.Tensor): + t = torch.Tensor([t]) + + new_shape = t.shape + p0.shape + new_view_t = t.shape + torch.Size([1] * len(p0.shape)) + new_view_p = 
torch.Size([1] * len(t.shape)) + p0.shape + p0 = p0.view(new_view_p).expand(new_shape) + p1 = p1.view(new_view_p).expand(new_shape) + t = t.view(new_view_t).expand(new_shape) + + return p0 + t * (p1 - p0) diff --git a/mogen/datasets/samplers/__init__.py b/mogen/datasets/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f60f9d4169c4e337c3bf414b154da259e17debd --- /dev/null +++ b/mogen/datasets/samplers/__init__.py @@ -0,0 +1,4 @@ +from .distributed_sampler import DistributedSampler, DistributedWeightedRandomSampler +from .batch_sampler import MonoTaskBatchSampler + +__all__ = ['DistributedSampler', 'MonoTaskBatchSampler', 'DistributedWeightedRandomSampler'] diff --git a/mogen/datasets/samplers/batch_sampler.py b/mogen/datasets/samplers/batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac8e6d936d9cccf8ffc0f34fde516de153ec795 --- /dev/null +++ b/mogen/datasets/samplers/batch_sampler.py @@ -0,0 +1,57 @@ +from typing import Iterator, List + +from torch.utils.data import BatchSampler, Sampler + + +class MonoTaskBatchSampler(BatchSampler): + + def __init__(self, + sampler: Sampler, + batch_size: int, + num_tasks: int, + drop_last: bool = False) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + self._task_buckets = [[] for _ in range(num_tasks)] + self.num_tasks = num_tasks + + def __iter__(self) -> Iterator[List[int]]: + for idx in self.sampler: + bucket_id = self.sampler.dataset.get_task_idx(idx) + bucket = self._task_buckets[bucket_id] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + + # yield the rest data and reset the bucket + left_data = [] + for i in range(self.num_tasks): + if len(self._task_buckets[i]) > 0: + left_data.append(self._task_buckets[i]) + + self._task_buckets = [[] for _ in range(self.num_tasks)] + for data in left_data: + yield data + # while len(left_data) > 0: + # if len(left_data) <= self.batch_size: + # if not self.drop_last: + # yield left_data[:] + # left_data = [] + # else: + # yield left_data[:self.batch_size] + # left_data = left_data[self.batch_size:] + + def __len__(self) -> int: + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size \ No newline at end of file diff --git a/mogen/datasets/samplers/distributed_sampler.py b/mogen/datasets/samplers/distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..3beb2ed7260dc1d52d600901060c7276e2956bff --- /dev/null +++ b/mogen/datasets/samplers/distributed_sampler.py @@ -0,0 +1,103 @@ +import torch +from torch.utils.data import DistributedSampler as _DistributedSampler +from typing import Optional, Union, Iterator +import numpy as np + + +class DistributedSampler(_DistributedSampler): + """ + A custom distributed sampler that supports shuffling, round-up of the sample size, + and ensures deterministic shuffling across epochs. + + Args: + dataset: The dataset from which samples are drawn. + num_replicas: Optional; the number of processes participating in the distributed training. 
+ rank: Optional; the rank of the current process among num_replicas. + shuffle: Optional; whether to shuffle the dataset every epoch. Defaults to True. + round_up: Optional; whether to round up the total size to make it divisible among replicas. + Defaults to True. + + Attributes: + shuffle (bool): Whether to shuffle the dataset. + round_up (bool): Whether to round up the total size to make it evenly divisible among replicas. + total_size (int): The total number of samples. + """ + + def __init__(self, + dataset: torch.utils.data.Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + round_up: bool = True): + super().__init__(dataset, num_replicas=num_replicas, rank=rank) + self.shuffle = shuffle + self.round_up = round_up + if self.round_up: + self.total_size = self.num_samples * self.num_replicas + else: + self.total_size = len(self.dataset) + + def __iter__(self) -> Iterator[int]: + """ + Returns an iterator over the indices of the dataset, shuffled if required, + with optional rounding up to make the number of samples divisible among replicas. + + Returns: + Iterator[int]: An iterator over the indices for the current rank. + """ + # deterministically shuffle based on epoch + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + if self.round_up: + assert len(indices) == self.num_samples + + return iter(indices) + + +class DistributedWeightedRandomSampler(_DistributedSampler): + def __init__(self, + dataset: torch.utils.data.Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + round_up: bool = True): + super().__init__(dataset, num_replicas=num_replicas, rank=rank) + self.shuffle = shuffle + self.round_up = round_up + if self.round_up: + self.total_size = self.num_samples * self.num_replicas + else: + self.total_size = len(self.dataset) + + def __iter__(self) -> Iterator[int]: + weights = self.dataset.weights + indices = np.random.choice(len(weights), size=len(self.dataset), replace=True, p=weights) + indices = indices.tolist() + + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + if self.round_up: + assert len(indices) == self.num_samples + + return iter(indices) \ No newline at end of file diff --git a/mogen/datasets/skeleton.py b/mogen/datasets/skeleton.py new file mode 100644 index 0000000000000000000000000000000000000000..2c37f3a521c94a30fae63238a8cc9c3531d80710 --- /dev/null +++ b/mogen/datasets/skeleton.py @@ -0,0 +1,208 @@ +# ------------------------------------------------------------------------------------------------ +# Copyright (c) Chuan Guo. 
+# ------------------------------------------------------------------------------------------------ +# This code were adapted from the following open-source project: +# https://github.com/EricGuo5513/HumanML3D +# ------------------------------------------------------------------------------------------------ + +from .quaternion import * +import scipy.ndimage.filters as filters + +class Skeleton(object): + def __init__(self, offset, kinematic_tree, device): + self.device = device + self._raw_offset_np = offset.numpy() + self._raw_offset = offset.clone().detach().to(device).float() + self._kinematic_tree = kinematic_tree + self._offset = None + self._parents = [0] * len(self._raw_offset) + self._parents[0] = -1 + for chain in self._kinematic_tree: + for j in range(1, len(chain)): + self._parents[chain[j]] = chain[j-1] + + def njoints(self): + return len(self._raw_offset) + + def offset(self): + return self._offset + + def set_offset(self, offsets): + self._offset = offsets.clone().detach().to(self.device).float() + + def kinematic_tree(self): + return self._kinematic_tree + + def parents(self): + return self._parents + + # joints (batch_size, joints_num, 3) + def get_offsets_joints_batch(self, joints): + assert len(joints.shape) == 3 + _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone() + for i in range(1, self._raw_offset.shape[0]): + _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i] + + self._offset = _offsets.detach() + return _offsets + + # joints (joints_num, 3) + def get_offsets_joints(self, joints): + assert len(joints.shape) == 2 + _offsets = self._raw_offset.clone() + for i in range(1, self._raw_offset.shape[0]): + # print(joints.shape) + _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i] + + self._offset = _offsets.detach() + return _offsets + + # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder + # joints (batch_size, joints_num, 3) + def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False): + assert len(face_joint_idx) == 4 + '''Get Forward Direction''' + l_hip, r_hip, sdr_r, sdr_l = face_joint_idx + across1 = joints[:, r_hip] - joints[:, l_hip] + across2 = joints[:, sdr_r] - joints[:, sdr_l] + across = across1 + across2 + eps = 1e-8 + across = across / (np.sqrt((across**2).sum(axis=-1))[:, np.newaxis] + eps) + # print(across1.shape, across2.shape) + + # forward (batch_size, 3) + forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1) + if smooth_forward: + forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest') + # forward (batch_size, 3) + eps = 1e-8 + forward = forward / (np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis] + eps) + + '''Get Root Rotation''' + target = np.array([[0,0,1]]).repeat(len(forward), axis=0) + root_quat = qbetween_np(forward, target) + + '''Inverse Kinematics''' + # quat_params (batch_size, joints_num, 4) + # print(joints.shape[:-1]) + quat_params = np.zeros(joints.shape[:-1] + (4,)) + # print(quat_params.shape) + root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + quat_params[:, 0] = root_quat + # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]]) + for chain in self._kinematic_tree: + R = root_quat + for j in range(len(chain) - 1): + # (batch, 3) + u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0) + # print(u.shape) + # (batch, 3) + v = joints[:, chain[j+1]] - joints[:, chain[j]] + eps = 1e-8 + v = v / 
(np.sqrt((v**2).sum(axis=-1))[:, np.newaxis] + eps) + # print(u.shape, v.shape) + rot_u_v = qbetween_np(u, v) + + R_loc = qmul_np(qinv_np(R), rot_u_v) + quat_params[:,chain[j + 1], :] = R_loc + R = qmul_np(R, R_loc) + + return quat_params + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device) + for i in range(1, len(chain)): + R = qmul(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]] + return joints + + # Be sure root joint is at the beginning of kinematic chains + def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True): + # quat_params (batch_size, joints_num, 4) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(quat_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(quat_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + R = quat_params[:, 0] + else: + R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0) + for i in range(1, len(chain)): + R = qmul_np(R, quat_params[:, chain[i]]) + offset_vec = offsets[:, chain[i]] + joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]] + return joints + + def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = self._offset.expand(cont6d_params.shape[0], -1, -1) + offsets = offsets.numpy() + joints = np.zeros(cont6d_params.shape[:-1] + (3,)) + joints[:, 0] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix_np(cont6d_params[:, 0]) + else: + matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0) + for i in range(1, len(chain)): + matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]][..., np.newaxis] + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True): + # cont6d_params (batch_size, joints_num, 6) + # joints (batch_size, joints_num, 3) + # root_pos (batch_size, 3) + if skel_joints is not None: + # skel_joints = torch.from_numpy(skel_joints) + offsets = self.get_offsets_joints_batch(skel_joints) + if len(self._offset.shape) == 2: + offsets = 
self._offset.expand(cont6d_params.shape[0], -1, -1) + joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device) + joints[..., 0, :] = root_pos + for chain in self._kinematic_tree: + if do_root_R: + matR = cont6d_to_matrix(cont6d_params[:, 0]) + else: + matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device) + for i in range(1, len(chain)): + matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]])) + offset_vec = offsets[:, chain[i]].unsqueeze(-1) + # print(matR.shape, offset_vec.shape) + joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]] + return joints + + + + + diff --git a/mogen/datasets/text_motion_dataset.py b/mogen/datasets/text_motion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bd22498aec9db2eea6ce3928f9c76c8fdc3637b2 --- /dev/null +++ b/mogen/datasets/text_motion_dataset.py @@ -0,0 +1,169 @@ +import copy +import os +import os.path +from typing import Optional, Union, List, Dict + +import numpy as np +import torch +import json + +from .base_dataset import BaseMotionDataset +from .builder import DATASETS + + +@DATASETS.register_module() +class TextMotionDataset(BaseMotionDataset): + """ + TextMotion dataset for handling motion data paired with text descriptions. + + Args: + data_prefix (str): Path to the base directory containing the dataset. + pipeline (list): List of data transformations to apply. + dataset_name (Optional[str]): Name of the dataset. + fixed_length (Optional[int]): Fixed length of data samples (if applicable). + ann_file (Optional[str]): Path to the annotation file. + motion_dir (Optional[str]): Path to the directory containing motion data. + text_dir (Optional[str]): Path to the directory containing text data. + token_dir (Optional[str]): Path to the directory containing token data. + clip_feat_dir (Optional[str]): Path to the directory containing clip feature data. + meta_dir (Optional[str]): Path to the directory containing metadata. + eval_cfg (Optional[dict]): Configuration for evaluation metrics. + test_mode (Optional[bool]): Whether the dataset is in test mode. Defaults to False. + siamese_mode (Optional[bool]): Whether to use Siamese mode (motion1 vs. motion2 comparison). Defaults to False. + tcomb_mode (Optional[bool]): Mode for specific processing (tcomb). Defaults to False. + fine_mode (Optional[bool]): Whether to use fine-grained text processing. Defaults to False. + balanced_sampling (Optional[int]): Number of categories for balanced sampling. If not None, enables balanced sampling. 
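+
+    Example (illustrative config sketch; the directory and file names are
+    placeholders and `train_pipeline` is assumed to be defined elsewhere)::
+
+        dataset = dict(
+            type='TextMotionDataset',
+            data_prefix='data',
+            dataset_name='human_ml3d',
+            ann_file='train.txt',
+            motion_dir='motions',
+            text_dir='texts',
+            pipeline=train_pipeline)
+        # text_dir (and the token/clip-feature/meta dirs, when given) resolve to
+        # data/datasets/human_ml3d/<dir> inside __init__.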
+ """ + + def __init__(self, + data_prefix: str, + pipeline: List[Dict], + dataset_name: Optional[Union[str, None]] = None, + fixed_length: Optional[Union[int, None]] = None, + ann_file: Optional[Union[str, None]] = None, + motion_dir: Optional[Union[str, None]] = None, + text_dir: Optional[Union[str, None]] = None, + token_dir: Optional[Union[str, None]] = None, + clip_feat_dir: Optional[Union[str, None]] = None, + meta_dir: Optional[Union[str, None]] = None, + eval_cfg: Optional[Union[dict, None]] = None, + test_mode: Optional[bool] = False, + siamese_mode: Optional[bool] = False, + tcomb_mode: Optional[bool] = False, + fine_mode: Optional[bool] = False, + balanced_sampling: Optional[Union[int, None]] = None): + self.text_dir = os.path.join(data_prefix, 'datasets', dataset_name, text_dir) + self.token_dir = os.path.join(data_prefix, 'datasets', dataset_name, token_dir) if token_dir else None + self.clip_feat_dir = os.path.join(data_prefix, 'datasets', dataset_name, clip_feat_dir) if clip_feat_dir else None + self.meta_dir = os.path.join(data_prefix, 'datasets', dataset_name, meta_dir) if meta_dir else None + self.siamese_mode = siamese_mode + self.tcomb_mode = tcomb_mode + self.fine_mode = fine_mode + self.balanced_sampling = balanced_sampling is not None + + if self.balanced_sampling: + self.category_list = [[] for _ in range(balanced_sampling)] + + super(TextMotionDataset, self).__init__( + data_prefix=data_prefix, + pipeline=pipeline, + dataset_name=dataset_name, + fixed_length=fixed_length, + ann_file=ann_file, + motion_dir=motion_dir, + eval_cfg=eval_cfg, + test_mode=test_mode + ) + + def load_anno(self, idx: int, name: str) -> Dict: + """ + Load a single annotation based on the given index and name. + + Args: + idx (int): Index of the data sample. + name (str): Name of the data sample (typically used as a file identifier). + + Returns: + dict: A dictionary containing the loaded data and relevant information. 
+ """ + results = {} + if self.siamese_mode: + motion_path = os.path.join(self.motion_dir, name + '.npz') + motion_data = np.load(motion_path) + results['motion1'] = motion_data['motion1'] + results['motion2'] = motion_data['motion2'] + assert results['motion1'].shape == results['motion2'].shape + else: + motion_path = os.path.join(self.motion_dir, name + '.npy') + motion_data = np.load(motion_path) + results['motion'] = motion_data + + if self.fine_mode: + text_path = os.path.join(self.text_dir, name + '.json') + text_data = json.load(open(text_path)) + for entry in text_data: + entry.pop('start_frame', None) + entry.pop('end_frame', None) + entry.pop('num_frames', None) + results['text'] = text_data + else: + text_path = os.path.join(self.text_dir, name + '.txt') + results['text'] = [line.strip() for line in open(text_path, 'r')] + + if self.token_dir: + token_path = os.path.join(self.token_dir, name + '.txt') + results['token'] = [line.strip() for line in open(token_path, 'r')] + + if self.clip_feat_dir: + clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy') + results['clip_feat_path'] = clip_feat_path + # if self.fine_mode: + # results['clip_feat_path'] = clip_feat_path + # else: + # clip_feat = torch.from_numpy(np.load(clip_feat_path)) + # if len(clip_feat.shape) == 2: + # clip_feat = clip_feat.unsqueeze(0) + # results['clip_feat'] = clip_feat + + if self.meta_dir: + score_path = os.path.join(self.meta_dir, name + '_score.npy') + results['score'] = torch.from_numpy(np.load(score_path)) + + if self.balanced_sampling: + assert self.meta_dir is not None + category_path = os.path.join(self.meta_dir, name + '.json') + category = json.load(open(category_path))['category'] + self.category_list[category].append(idx) + + return results + + def prepare_data(self, idx: int) -> Dict: + """ + Prepare raw data for the given index. + + Args: + idx (int): Index of the data sample. + + Returns: + dict: Processed data after applying the pipeline. 
+ """ + results = copy.deepcopy(self.data_infos[idx]) + text_list = results['text'] + selected_idx = np.random.randint(0, len(text_list)) + results['text'] = text_list[selected_idx] + + if 'clip_feat' in results: + results['clip_feat'] = results['clip_feat'][selected_idx] + + if 'clip_feat_path' in results: + clip_feat = torch.from_numpy(np.load(results['clip_feat_path'])) + if len(clip_feat.shape) == 2: + clip_feat = clip_feat.unsqueeze(0) + results['clip_feat'] = clip_feat[selected_idx] + + if 'token' in results: + results['token'] = results['token'][selected_idx] + + results['dataset_name'] = self.dataset_name + results = self.pipeline(results) + return results diff --git a/mogen/datasets/utils.py b/mogen/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ed85a755cdb61a0613660535f2035a4207941241 --- /dev/null +++ b/mogen/datasets/utils.py @@ -0,0 +1,301 @@ +import os +import json +import numpy as np +import torch +from imagebind import data +from imagebind.models.imagebind_model import ModalityType +from mogen.datasets.human_body_prior.body_model.body_model import BodyModel +from mogen.datasets.quaternion import qrot, qinv +from pytorch3d.transforms import axis_angle_to_matrix + + +def recover_root_rot_pos(data): + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4, )).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3, )).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + + +def recover_from_ric(data, joints_num): + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + '''Add Y-axis rotation to local joints''' + rot = qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4, )) + positions = qrot(rot, positions) + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + '''Concate root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + return positions + + +def create_data_item(meta_data, + root_dir, + basename, + tomato_repr=None, + keypoints3d=None, + expression=None, + smpl_rot=None, + bvh_rot=None): + assert os.path.exists(root_dir) + meta_dir = os.path.join(root_dir, 'metas') + motion_dir = os.path.join(root_dir, 'motions') + os.makedirs(meta_dir, exist_ok=True) + os.makedirs(motion_dir, exist_ok=True) + + motion_data = {} + + if tomato_repr is not None: + motion_data['tomato_repr'] = tomato_repr + if keypoints3d is not None: + motion_data['keypoints3d'] = keypoints3d + num_frames = keypoints3d.shape[0] + keypoints3d = keypoints3d.reshape((num_frames, -1)) + if expression is not None: + motion_data['expression'] = expression + if smpl_rot is not None: + motion_data['smpl_rot'] = smpl_rot + if bvh_rot is not None: + motion_data['bvh_rot'] = bvh_rot + + motion_path = os.path.join(motion_dir, basename + '.npz') + meta_path = os.path.join(meta_dir, basename + '.json') + np.savez_compressed(motion_path, **motion_data) + json.dump(meta_data, 
open(meta_path, 'w'), indent=4) + + +def extract_text_feature(model, text, device): + text_list = text + inputs = { + ModalityType.TEXT: data.load_and_transform_text(text_list, device), + } + with torch.no_grad(): + text_word_feat, text_seq_feat = model(inputs) + return text_word_feat, text_seq_feat + + +def extract_image_feature(model, image_paths, device): + inputs = { + ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device), + } + with torch.no_grad(): + _, embeddings = model(inputs) + return embeddings + + +def extract_audio_feature(model, audio_paths, device): + inputs = { + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device, clips_per_video=1), + } + with torch.no_grad(): + audio_word_feat, audio_seq_feat = model(inputs) + return audio_word_feat, audio_seq_feat + + +def copy_repr_data(src_data, src_idx, num_src_joints, tgt_data, tgt_idx, num_tgt_joints): + # ric_data + tgt_base1 = 4 + (tgt_idx - 1) * 3 + src_base1 = 4 + (src_idx - 1) * 3 + tgt_data[:, tgt_base1: tgt_base1 + 3] = \ + src_data[:, src_base1: src_base1 + 3] + # rot_data + tgt_base2 = 4 + (num_tgt_joints - 1) * 3 + (tgt_idx - 1) * 6 + src_base2 = 4 + (num_src_joints - 1) * 3 + (src_idx - 1) * 6 + tgt_data[:, tgt_base2: tgt_base2 + 6] = \ + src_data[:, src_base2: src_base2 + 6] + # local velocity + tgt_base3 = 4 + (num_tgt_joints - 1) * 9 + tgt_idx * 3 + src_base3 = 4 + (num_src_joints - 1) * 9 + src_idx * 3 + tgt_data[:, tgt_base3: tgt_base3 + 3] = \ + src_data[:, src_base3: src_base3 + 3] + + +def extract_repr_data(data, idx, num_joints): + assert idx > 0 + base1 = 4 + (idx - 1) * 3 + ric_data = data[:, base1: base1 + 3] + base2 = 4 + (num_joints - 1) * 3 + (idx - 1) * 6 + rot_data = data[:, base2: base2 + 6] + base3 = 4 + (num_joints - 1) * 9 + idx * 3 + local_vel = data[:, base3: base3 + 3] + if isinstance(data, torch.Tensor): + output = torch.cat((ric_data, rot_data, local_vel), dim=-1) + else: + output = np.concatenate((ric_data, rot_data, local_vel), axis=-1) + return output + + +def move_repr_data(data, idx, num_joints, output): + assert idx > 0 + assert data.shape[1] == 12 + base1 = 4 + (idx - 1) * 3 + output[:, base1: base1 + 3] = data[:, :3] + base2 = 4 + (num_joints - 1) * 3 + (idx - 1) * 6 + output[:, base2: base2 + 6] = data[:, 3: 9] + base3 = 4 + (num_joints - 1) * 9 + idx * 3 + output[:, base3: base3 + 3] = data[:, 9:] + + +def estimate_repr_data(data, idx1, idx2, tgt, ratio, num_joints): + # direction: same as idx1 + # position: |idx1 - tgt| / |idx1 - idx2| = ratio + assert 0 <= ratio <= 1, "ratio should be between 0 and 1" + assert 1 <= idx1 <= num_joints, "idx1 out of range" + assert 1 <= idx2 <= num_joints, "idx2 out of range" + assert 1 <= tgt <= num_joints, "tgt out of range" + + # ric data + base1 = 4 + (idx1 - 1) * 3 + base2 = 4 + (idx2 - 1) * 3 + baset = 4 + (tgt - 1) * 3 + pose1 = data[:, base1: base1 + 3] + pose2 = data[:, base2: base2 + 3] + poset = pose1 * (1 - ratio) + pose2 * ratio + data[:, baset: baset + 3] = poset + + # rot_data + base1 = 4 + (num_joints - 1) * 3 + (idx1 - 1) * 6 + baset = 4 + (num_joints - 1) * 3 + (tgt - 1) * 6 + data[:, baset: baset + 6] = data[:, base1: base1 + 6] + + # local velocity + base1 = 4 + (num_joints - 1) * 9 + idx1 * 3 + base2 = 4 + (num_joints - 1) * 9 + idx2 * 3 + baset = 4 + (num_joints - 1) * 9 + tgt * 3 + vel1 = data[:, base1: base1 + 3] + vel2 = data[:, base2: base2 + 3] + velt = vel1 * (1 - ratio) + vel2 * ratio + data[:, baset: baset + 3] = velt + + +class BodyModelWrapper: + + def __init__(self, 
device): + file_path = os.path.abspath(os.path.dirname(__file__)) + body_model_dir = os.path.join(file_path, '../../data/motionverse/body_models') + male_bm_path = os.path.join(body_model_dir, 'smplh/male/model.npz') + male_dmpl_path = os.path.join(body_model_dir, 'dmpls/male/model.npz') + female_bm_path = os.path.join(body_model_dir, 'smplh/female/model.npz') + female_dmpl_path = os.path.join(body_model_dir, 'dmpls/female/model.npz') + neutral_bm_path = os.path.join(body_model_dir, 'smplh/neutral/model.npz') + neutral_dmpl_path = os.path.join(body_model_dir, 'dmpls/neutral/model.npz') + + self.num_betas = 10 # number of body parameters + self.num_dmpls = 8 # number of DMPL parameters + + self.male_bm = BodyModel( + bm_fname=male_bm_path, + num_betas=self.num_betas, + num_dmpls=self.num_dmpls, + dmpl_fname=male_dmpl_path).to(device) + self.female_bm = BodyModel( + bm_fname=female_bm_path, + num_betas=self.num_betas, + num_dmpls=self.num_dmpls, + dmpl_fname=female_dmpl_path).to(device) + self.neutral_bm = BodyModel( + bm_fname=neutral_bm_path, + num_betas=self.num_betas, + num_dmpls=self.num_dmpls, + dmpl_fname=neutral_dmpl_path).to(device) + self.device = device + + def process_smplh(self, smplh_data, downsample=1): + poses = smplh_data['poses'][::downsample] + trans = smplh_data['trans'][::downsample] + betas = smplh_data['betas'] + if len(betas.shape) == 1: + betas = betas[:self.num_betas][np.newaxis] + betas = np.repeat(betas, repeats=len(trans), axis=0) + else: + betas = betas[:, :self.num_betas] + body_parms = { + 'root_orient': torch.Tensor(poses[:, :3]).to(self.device), + 'pose_body': torch.Tensor(poses[:, 3:66]).to(self.device), + 'pose_hand': torch.Tensor(poses[:, 66:]).to(self.device), + 'trans': torch.Tensor(trans).to(self.device), + 'betas': torch.Tensor(betas).to(self.device), + } + gender = smplh_data.get('gender', 'neutral') + if gender == 'male' or gender == 'm': + bm = self.male_bm + elif gender == 'female' or gender == 'f': + bm = self.female_bm + else: + bm = self.neutral_bm + with torch.no_grad(): + body = bm(**body_parms) + pose_seq_np = body.Jtr.detach().cpu().numpy() + return pose_seq_np + + +def ang2joint(p3d0, pose, + parent={0: -1, 1: 0, 2: 0, 3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7, 11: 8, 12: 9, 13: 9, 14: 9, + 15: 12, 16: 13, 17: 14, 18: 16, 19: 17, 20: 18, 21: 19, 22: 20, 23: 21}): + """ + + :param p3d0:[batch_size, joint_num, 3] + :param pose:[batch_size, joint_num, 3] + :param parent: + :return: + """ + def with_zeros(x): + """ + Append a [0, 0, 0, 1] tensor to a [3, 4] tensor. + + Parameter: + --------- + x: Tensor to be appended. 
+ + Return: + ------ + Tensor after appending of shape [4,4] + + """ + ones = torch.tensor( + [[[0.0, 0.0, 0.0, 1.0]]], dtype=torch.float + ).expand(x.shape[0], -1, -1).to(x.device) + ret = torch.cat((x, ones), dim=1) + return ret + batch_num = p3d0.shape[0] + jnum = len(parent.keys()) + J = p3d0 + R_cube_big = axis_angle_to_matrix(pose.contiguous().view(-1, 1, 3)).reshape(batch_num, -1, 3, 3) + results = [] + results.append( + with_zeros(torch.cat((R_cube_big[:, 0], torch.reshape(J[:, 0, :], (-1, 3, 1))), dim=2)) + ) + for i in range(1, jnum): + results.append( + torch.matmul( + results[parent[i]], + with_zeros( + torch.cat( + (R_cube_big[:, i], torch.reshape(J[:, i, :] - J[:, parent[i], :], (-1, 3, 1))), + dim=2 + ) + ) + ) + ) + + stacked = torch.stack(results, dim=1) + J_transformed = stacked[:, :, :3, 3] + return J_transformed \ No newline at end of file diff --git a/mogen/models/__init__.py b/mogen/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eeefc46ef7d490566e96a3311eff270f347ca098 --- /dev/null +++ b/mogen/models/__init__.py @@ -0,0 +1,7 @@ +from .architectures import * # noqa: F401,F403 +from .attentions import * # noqa: F401,F403 +from .builder import * # noqa: F401,F403 +from .losses import * # noqa: F401,F403 +# from .eval_models import * # noqa: F401,F403 +from .transformers import * # noqa: F401,F403 +from .utils import * # noqa: F401,F403 diff --git a/mogen/models/architectures/__init__.py b/mogen/models/architectures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17045f93d82d9f3a6417c4942800fc678f7ec4aa --- /dev/null +++ b/mogen/models/architectures/__init__.py @@ -0,0 +1,5 @@ +from .diffusion_architecture import MotionDiffusion, UnifiedMotionDiffusion +from .vae_architecture import MotionVAE + + +__all__ = ['MotionVAE', 'MotionDiffusion', 'UnifiedMotionDiffusion'] diff --git a/mogen/models/architectures/base_architecture.py b/mogen/models/architectures/base_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..68e06b6a19634fdc1d3c4a084f9b9ec1dbf538bc --- /dev/null +++ b/mogen/models/architectures/base_architecture.py @@ -0,0 +1,155 @@ +from collections import OrderedDict +import torch +import torch.distributed as dist +from mmcv.runner import BaseModule +from typing import Dict, Tuple, List + + +def to_cpu(x: torch.Tensor) -> torch.Tensor: + """Move a tensor to CPU and detach it from the computation graph. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The tensor detached and moved to CPU. + """ + if isinstance(x, torch.Tensor): + return x.detach().cpu() + return x + + +class BaseArchitecture(BaseModule): + """Base class for mogen architecture. + + Args: + init_cfg (dict, optional): Initialization config for the module. + """ + + def __init__(self, init_cfg: dict = None): + super(BaseArchitecture, self).__init__(init_cfg) + + def forward_train(self, **kwargs): + """Forward computation during training.""" + pass + + def forward_test(self, **kwargs): + """Forward computation during testing.""" + pass + + def _parse_losses(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, float]]: + """Parse the raw outputs (losses) of the network. + + Args: + losses (dict): Raw output of the network, which usually contains + losses and other necessary information. 
+ + Returns: + tuple[Tensor, dict]: (loss, log_vars) + - loss is the loss tensor which may be a weighted sum of all losses, + - log_vars contains all the variables to be logged. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data: Dict, optimizer: torch.optim.Optimizer) -> Dict: + """The iteration step during training. + + This method defines an iteration step during training, excluding backpropagation + and optimizer updating, which are handled by an optimizer hook. + + Args: + data (dict): The output of the dataloader. + optimizer (torch.optim.Optimizer): The optimizer object (unused). + + Returns: + dict: A dictionary containing the loss, log_vars for logging, and the number of samples. + - ``loss``: A tensor for backpropagation, which may be a weighted sum of multiple losses. + - ``log_vars``: All the variables to be logged. + - ``num_samples``: The number of samples in the batch. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data['motion'])) + return outputs + + def val_step(self, data: Dict, optimizer: torch.optim.Optimizer = None) -> Dict: + """The iteration step during validation. + + Args: + data (dict): The output of the dataloader. + optimizer (torch.optim.Optimizer, optional): The optimizer object (unused). + + Returns: + dict: A dictionary containing the loss, log_vars for logging, and the number of samples. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data['motion'])) + return outputs + + def forward(self, **kwargs): + """Forward computation based on the training or testing mode.""" + if self.training: + return self.forward_train(**kwargs) + else: + return self.forward_test(**kwargs) + + def split_results(self, results: Dict[str, torch.Tensor]) -> List[Dict]: + """Split batched results into individual outputs. + + Args: + results (dict): The batched results from the model containing 'motion', 'pred_motion', etc. + + Returns: + list: A list of dictionaries where each dictionary contains results for a single instance. 
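+
+        Example (illustrative)::
+
+            outputs = self.split_results(results)
+            # len(outputs) == batch size; each item holds per-sample tensors such
+            # as 'motion', 'pred_motion', 'motion_length' and 'motion_mask',
+            # detached, moved to CPU, and zeroed beyond the valid length.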
+ """ + B = results['motion'].shape[0] + output = [] + for i in range(B): + batch_output = dict() + batch_output['motion'] = to_cpu(results['motion'][i]) + batch_output['pred_motion'] = to_cpu(results['pred_motion'][i]) + batch_output['motion_length'] = to_cpu(results['motion_length'][i]) + batch_output['motion'][batch_output['motion_length']:, :] = 0 + batch_output['motion_mask'] = to_cpu(results['motion_mask'][i]) + if 'pred_motion_length' in results: + batch_output['pred_motion_length'] = to_cpu(results['pred_motion_length'][i]) + else: + batch_output['pred_motion_length'] = to_cpu(results['motion_length'][i]) + batch_output['pred_motion'][batch_output['pred_motion_length']:, :] = 0 + if 'pred_motion_mask' in results: + batch_output['pred_motion_mask'] = to_cpu(results['pred_motion_mask'][i]) + else: + batch_output['pred_motion_mask'] = to_cpu(results['motion_mask'][i]) + if 'motion_metas' in results: + motion_metas = results['motion_metas'][i] + if 'text' in motion_metas: + batch_output['text'] = motion_metas['text'] + if 'token' in motion_metas: + batch_output['token'] = motion_metas['token'] + if 'meta_data' in motion_metas and 'category_id' in motion_metas['meta_data']: + batch_output['category_id'] = motion_metas['meta_data']['category_id'] + batch_output['motion_metas'] = motion_metas + output.append(batch_output) + return output diff --git a/mogen/models/architectures/diffusion_architecture.py b/mogen/models/architectures/diffusion_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfc0eb04d536f566db4a31d8df130a7a0d0e70f --- /dev/null +++ b/mogen/models/architectures/diffusion_architecture.py @@ -0,0 +1,428 @@ +import torch +import torch.nn.functional as F +import numpy as np +from typing import Optional, List, Dict, Union + +from ..builder import ARCHITECTURES, build_loss, build_submodule +from ..utils.gaussian_diffusion import create_named_schedule_sampler, build_diffusion +from ..utils.mask_helper import expand_mask_to_all +from .base_architecture import BaseArchitecture + + +def set_requires_grad(nets: Union[torch.nn.Module, List[torch.nn.Module]], requires_grad: bool = False): + """Set requires_grad for all the networks. + + Args: + nets (nn.Module | list[nn.Module]): A list of networks or a single network. + requires_grad (bool): Whether the networks require gradients or not. + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + +@ARCHITECTURES.register_module() +class MotionDiffusion(BaseArchitecture): + """ + Motion Diffusion architecture for modeling and generating motion sequences using diffusion models. + + Args: + dataset_name (Optional[str]): Name of the dataset being used (e.g., 'kit_ml', 'human_ml3d'). + model (dict): Configuration for the submodule (e.g., the motion generation model). + loss_recon (dict): Configuration for the reconstruction loss. + loss_reduction (str): Specifies the reduction method for the loss. Defaults to 'frame'. + use_loss_score (bool): Whether to use a scoring mechanism for loss calculation. Defaults to False. + diffusion_train (dict): Configuration for the diffusion model during training. + diffusion_test (dict): Configuration for the diffusion model during testing. + sampler_type (str): The type of sampler to use. Defaults to 'uniform'. + init_cfg (dict): Initialization config for the module. + inference_type (str): Type of inference to use ('ddpm' or 'ddim'). Defaults to 'ddpm'. 
+ """ + + def __init__(self, + dataset_name: Optional[str] = None, + model: dict = None, + loss_recon: dict = None, + loss_reduction: str = "frame", + use_loss_score: bool = False, + diffusion_train: dict = None, + diffusion_test: dict = None, + sampler_type: str = 'uniform', + init_cfg: dict = None, + inference_type: str = 'ddpm', + **kwargs): + super().__init__(init_cfg=init_cfg, **kwargs) + self.model = build_submodule(model) + self.loss_recon = build_loss(loss_recon) + self.diffusion_train = build_diffusion(diffusion_train) + self.diffusion_test = build_diffusion(diffusion_test) + self.sampler = create_named_schedule_sampler(sampler_type, self.diffusion_train) + self.inference_type = inference_type + self.loss_reduction = loss_reduction + self.use_loss_score = use_loss_score + self.dataset_name = dataset_name + + if self.dataset_name == "kit_ml": + self.mean = np.load("data/datasets/kit_ml/mean.npy") + self.std = np.load("data/datasets/kit_ml/std.npy") + elif self.dataset_name == "human_ml3d": + self.mean = np.load("data/datasets/human_ml3d/mean.npy") + self.std = np.load("data/datasets/human_ml3d/std.npy") + elif self.dataset_name is not None: + raise NotImplementedError() + + + def forward(self, **kwargs) -> Union[Dict, List]: + """Forward pass of the model. + + Depending on whether the model is in training mode, this method performs the forward pass + during training or inference, and calculates the relevant losses. + + Args: + **kwargs: Keyword arguments containing the input data for the model. + + Returns: + dict or list: The calculated losses during training or the generated motion during inference. + """ + motion = kwargs['motion'].float() + motion_mask = kwargs['motion_mask'].float() + motion_length = kwargs['motion_length'] + num_intervals = kwargs.get('num_intervals', 1) + sample_idx = kwargs.get('sample_idx', None) + clip_feat = kwargs.get('clip_feat', None) + B, T = motion.shape[:2] + text = [kwargs['motion_metas'][i]['text'] for i in range(B)] + + if self.training: + t, _ = self.sampler.sample(B, motion.device) + output = self.diffusion_train.training_losses( + model=self.model, + x_start=motion, + t=t, + model_kwargs={ + 'motion_mask': motion_mask, + 'motion_length': motion_length, + 'text': text, + 'clip_feat': clip_feat, + 'sample_idx': sample_idx, + 'num_intervals': num_intervals + } + ) + pred, target = output['pred'], output['target'] + recon_loss = self.loss_recon(pred, target, reduction_override='none') + + if self.use_loss_score: + loss_score = kwargs['score'] + recon_loss = recon_loss * loss_score.view(B, 1, -1) + + recon_loss = recon_loss.mean(dim=-1) * motion_mask + recon_loss_batch = recon_loss.sum(dim=1) / motion_mask.sum(dim=1) + recon_loss_frame = recon_loss.sum() / motion_mask.sum() + + if self.loss_reduction == "frame": + recon_loss = recon_loss_frame + else: + recon_loss = recon_loss_batch + + if hasattr(self.sampler, "update_with_local_losses"): + self.sampler.update_with_local_losses(t, recon_loss_batch) + + loss = {'recon_loss': recon_loss.mean()} + if hasattr(self.model, 'aux_loss'): + loss.update(self.model.aux_loss()) + return loss + + else: + dim_pose = kwargs['motion'].shape[-1] + model_kwargs = self.model.get_precompute_condition( + device=motion.device, text=text, **kwargs + ) + model_kwargs.update({ + 'motion_mask': motion_mask, + 'sample_idx': sample_idx, + 'motion_length': motion_length, + 'num_intervals': num_intervals + }) + + inference_kwargs = kwargs.get('inference_kwargs', {}) + if self.inference_type == 'ddpm': + output = 
self.diffusion_test.p_sample_loop( + self.model, (B, T, dim_pose), clip_denoised=False, progress=False, + model_kwargs=model_kwargs, **inference_kwargs + ) + else: + output = self.diffusion_test.ddim_sample_loop( + self.model, (B, T, dim_pose), clip_denoised=False, progress=False, + model_kwargs=model_kwargs, eta=0, **inference_kwargs + ) + + results = kwargs + if getattr(self.model, "post_process") is not None: + output = self.model.post_process(output) + + results['pred_motion'] = output + results = self.split_results(results) + return results + + +@ARCHITECTURES.register_module() +class UnifiedMotionDiffusion(BaseArchitecture): + """ + Unified Motion Diffusion architecture for generating motion sequences using diffusion models. + + Args: + model (dict): Configuration for the motion generation model. + loss_recon (dict): Configuration for the reconstruction loss. + loss_reduction (str): Specifies the reduction method for the loss. Defaults to 'frame'. + random_mask (float): Probability or scaling factor for applying random masking. Defaults to 0. + diffusion_train (dict): Configuration for the diffusion model during training. + diffusion_test (dict): Configuration for the diffusion model during testing. + sampler_type (str): The type of sampler to use. Defaults to 'uniform'. + init_cfg (dict): Initialization config for the module. + inference_type (str): Type of inference to use ('ddpm' or 'ddim'). Defaults to 'ddpm'. + body_scale (float): Scaling factor for the body motion mask. Defaults to 1.0. + hand_scale (float): Scaling factor for the hand motion mask. Defaults to 1.0. + face_scale (float): Scaling factor for the face motion mask. Defaults to 1.0. + """ + + def __init__(self, + model: dict = None, + loss_recon: dict = None, + loss_reduction: str = "frame", + random_mask: float = 0, + diffusion_train: dict = None, + diffusion_test_dict: dict = None, + sampler_type: str = 'uniform', + init_cfg: dict = None, + inference_type: str = 'ddpm', + body_scale: float = 1.0, + hand_scale: float = 1.0, + face_scale: float = 1.0, + train_repeat: int = 1, + loss_weight: str = None, + **kwargs): + super().__init__(init_cfg=init_cfg, **kwargs) + self.model = build_submodule(model) + self.loss_recon = build_loss(loss_recon) + self.diffusion_train = build_diffusion(diffusion_train) + self.diffusion_test_dict = diffusion_test_dict + self.sampler = create_named_schedule_sampler(sampler_type, self.diffusion_train) + self.inference_type = inference_type + self.loss_reduction = loss_reduction + self.random_mask = random_mask + self.body_scale = body_scale + self.hand_scale = hand_scale + self.face_scale = face_scale + self.train_repeat = train_repeat + self.loss_weight = None + if init_cfg is not None: + self.init_weights() + + def repeat_data(self, **kwargs): + if self.train_repeat == 1: + return kwargs + N = self.train_repeat + motion = kwargs['motion'].float().repeat(N, 1, 1) + B = motion.shape[0] + kwargs['motion'] = motion + + motion_mask = kwargs['motion_mask'].float().repeat(N, 1, 1) + kwargs['motion_mask'] = motion_mask + + motion_length = kwargs['motion_length'].repeat(N, 1) + kwargs['motion_length'] = motion_length + + motion_metas = kwargs['motion_metas'] * N + kwargs['motion_metas'] = motion_metas + + if 'text_seq_feat' in kwargs: + kwargs['text_seq_feat'] = kwargs['text_seq_feat'].repeat(N, 1) + if 'text_word_feat' in kwargs: + kwargs['text_word_feat'] = kwargs['text_word_feat'].repeat(N, 1, 1) + if 'text_cond' in kwargs: + kwargs['text_cond'] = kwargs['text_cond'].repeat(N, 1) + + if 
'music_seq_feat' in kwargs: + kwargs['music_seq_feat'] = kwargs['music_seq_feat'].repeat(N, 1) + if 'music_word_feat' in kwargs: + kwargs['music_word_feat'] = kwargs['music_word_feat'].repeat(N, 1, 1) + if 'music_cond' in kwargs: + kwargs['music_cond'] = kwargs['music_cond'].repeat(N, 1) + + if 'speech_seq_feat' in kwargs: + kwargs['speech_seq_feat'] = kwargs['speech_seq_feat'].repeat(N, 1) + if 'speech_word_feat' in kwargs: + kwargs['speech_word_feat'] = kwargs['speech_word_feat'].repeat(N, 1, 1) + if 'speech_cond' in kwargs: + kwargs['speech_cond'] = kwargs['speech_cond'].repeat(N, 1) + + if 'video_seq_feat' in kwargs: + kwargs['video_seq_feat'] = kwargs['video_seq_feat'].repeat(N, 1) + if 'video_word_feat' in kwargs: + kwargs['video_word_feat'] = kwargs['video_word_feat'].repeat(N, 1, 1) + if 'video_cond' in kwargs: + kwargs['video_cond'] = kwargs['video_cond'].repeat(N, 1) + return kwargs + + + def forward(self, **kwargs) -> Dict: + """Forward pass for training or inference in the unified motion diffusion model. + + Args: + **kwargs: Keyword arguments containing the input data for the model. + + Returns: + dict: The calculated losses during training or the generated motion during inference. + """ + if self.training: + kwargs = self.repeat_data(**kwargs) + + motion = kwargs['motion'].float() + B, T = motion.shape[:2] + motion_mask = kwargs['motion_mask'].float() + motion_length = kwargs['motion_length'] + num_intervals = kwargs.get('num_intervals', 1) + sample_idx = kwargs.get('sample_idx', None) + motion_metas = kwargs['motion_metas'] + + # Conditioning features (text, music, speech, video) + text_word_feat = kwargs.get('text_word_feat', None) + text_seq_feat = kwargs.get('text_seq_feat', None) + text_cond = kwargs.get('text_cond', torch.zeros(B).type_as(motion)) + + music_word_feat = kwargs.get('music_word_feat', None) + music_seq_feat = kwargs.get('music_seq_feat', None) + music_cond = kwargs.get('music_cond', torch.zeros(B).type_as(motion)) + + speech_word_feat = kwargs.get('speech_word_feat', None) + speech_seq_feat = kwargs.get('speech_seq_feat', None) + speech_cond = kwargs.get('speech_cond', torch.zeros(B).type_as(motion)) + + video_word_feat = kwargs.get('video_word_feat', None) + video_seq_feat = kwargs.get('video_seq_feat', None) + video_cond = kwargs.get('video_cond', torch.zeros(B).type_as(motion)) + + if self.training: + # Random masking during training + t, _ = self.sampler.sample(B, motion.device) + + # rand_mask = torch.rand_like(motion_mask) + # new_motion_mask = motion_mask.clone() + # threshold = torch.rand(B).type_as(rand_mask) + # threshold = threshold.view(B, 1, 1).repeat(1, T, 10) + # new_motion_mask[rand_mask < threshold] = 0 + # motion_mask = new_motion_mask + + output = self.diffusion_train.training_losses( + model=self.model, + x_start=motion, + t=t, + model_kwargs={ + 'motion_mask': motion_mask, + 'motion_length': motion_length, + 'num_intervals': num_intervals, + 'motion_metas': motion_metas, + 'text_word_feat': text_word_feat, + 'text_seq_feat': text_seq_feat, + 'text_cond': text_cond, + 'music_word_feat': music_word_feat, + 'music_seq_feat': music_seq_feat, + 'music_cond': music_cond, + 'speech_word_feat': speech_word_feat, + 'speech_seq_feat': speech_seq_feat, + 'speech_cond': speech_cond, + 'video_word_feat': video_word_feat, + 'video_seq_feat': video_seq_feat, + 'video_cond': video_cond, + }) + pred, target = output['pred'], output['target'] + recon_loss = self.loss_recon(pred, target, reduction_override='none') + # Apply expanded motion mask + 
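# Illustrative aside: `repeat_data` above duplicates every batched tensor
# `train_repeat` times along dim 0, so each motion clip is trained against
# several independently sampled diffusion timesteps in one step. A minimal,
# generic sketch of the same idea (names and shapes here are hypothetical):
import torch

def repeat_batch(batch: dict, n: int) -> dict:
    """Repeat batched tensors (and per-sample lists) n times along dim 0."""
    out = {}
    for key, value in batch.items():
        if torch.is_tensor(value):
            reps = (n,) + (1,) * (value.dim() - 1)   # repeat dim 0, keep the rest
            out[key] = value.repeat(*reps)
        elif isinstance(value, list):
            out[key] = value * n                     # e.g. per-sample meta dicts
        else:
            out[key] = value
    return out

# A batch of 2 clips becomes a batch of 6 when n=3:
# repeat_batch({"motion": torch.randn(2, 196, 669)}, 3)["motion"].shape == (6, 196, 669)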
motion_mask = expand_mask_to_all( + motion_mask, self.body_scale, self.hand_scale, self.face_scale + ) + if self.loss_weight is not None: + loss_weight = torch.from_numpy(self.loss_weight).type_as(motion_mask) + dataset_idx = self.model.dataset_idx + loss_weight = loss_weight.index_select(0, dataset_idx).unsqueeze(1) + motion_mask = motion_mask * loss_weight + recon_loss = (recon_loss * motion_mask).sum(dim=-1) + motion_mask = motion_mask.sum(dim=-1) + else: + recon_loss = (recon_loss * motion_mask).mean(dim=-1) + motion_mask = motion_mask.mean(dim=-1) + + recon_loss_batch = recon_loss.sum(dim=1) / motion_mask.sum(dim=1) + recon_loss_frame = recon_loss.sum() / motion_mask.sum() + + # Determine final reconstruction loss + if self.loss_reduction == "frame": + recon_loss = recon_loss_frame + else: + recon_loss = recon_loss_batch + + if hasattr(self.sampler, "update_with_local_losses"): + self.sampler.update_with_local_losses(t, recon_loss_batch) + + loss = {'recon_loss': recon_loss.mean()} + + # Add auxiliary loss if applicable + if hasattr(self.model, 'aux_loss'): + loss.update(self.model.aux_loss()) + + return loss + else: + # Inference (DDPM or DDIM sampling) + dim_pose = 669 # Fixed dimension for the motion output + model_kwargs = self.model.get_precompute_condition( + device=motion.device, **kwargs + ) + model_kwargs.update({ + 'motion_mask': motion_mask, + 'sample_idx': sample_idx, + 'motion_length': motion_length, + 'num_intervals': num_intervals, + 'motion_metas': motion_metas, + 'text_word_feat': text_word_feat, + 'text_seq_feat': text_seq_feat, + 'text_cond': text_cond, + 'music_word_feat': music_word_feat, + 'music_seq_feat': music_seq_feat, + 'music_cond': music_cond, + 'speech_word_feat': speech_word_feat, + 'speech_seq_feat': speech_seq_feat, + 'speech_cond': speech_cond, + 'video_word_feat': video_word_feat, + 'video_seq_feat': video_seq_feat, + 'video_cond': video_cond, + }) + inference_kwargs = kwargs.get('inference_kwargs', {}) + inference_kwargs['gt_motion'] = motion + inference_kwargs['context_mask'] = kwargs.get('context_mask', None) + dataset_name = motion_metas[0]['meta_data']['dataset_name'] + diffusion_test_cfg = self.diffusion_test_dict['base'] + diffusion_test_cfg.update(dict(respace=self.diffusion_test_dict[dataset_name])) + diffusion_test = build_diffusion(diffusion_test_cfg) + if self.inference_type == 'ddpm': + output = diffusion_test.p_sample_loop( + self.model, (B, T, dim_pose), clip_denoised=False, + progress=False, model_kwargs=model_kwargs, **inference_kwargs + ) + else: + output = diffusion_test.ddim_sample_loop( + self.model, (B, T, dim_pose), clip_denoised=False, + progress=False, model_kwargs=model_kwargs, eta=0, + **inference_kwargs + ) + + results = kwargs + if getattr(self.model, "post_process") is not None: + output = self.model.post_process(output) + + results['pred_motion'] = output + results = self.split_results(results) + + return results diff --git a/mogen/models/architectures/vae_architecture.py b/mogen/models/architectures/vae_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..9a1f4455194a43e9fcb2f686f13bed2c00aa438e --- /dev/null +++ b/mogen/models/architectures/vae_architecture.py @@ -0,0 +1,112 @@ +import torch + +from ..builder import ARCHITECTURES, build_loss, build_submodule +from .base_architecture import BaseArchitecture + + +@ARCHITECTURES.register_module() +class PoseVAE(BaseArchitecture): + + def __init__(self, + encoder=None, + decoder=None, + loss_recon=None, + kl_div_loss_weight=None, + 
init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg, **kwargs) + self.encoder = build_submodule(encoder) + self.decoder = build_submodule(decoder) + self.loss_recon = build_loss(loss_recon) + self.kl_div_loss_weight = kl_div_loss_weight + + def reparameterize(self, mu, logvar): + std = torch.exp(logvar / 2) + + eps = std.data.new(std.size()).normal_() + latent_code = eps.mul(std).add_(mu) + return latent_code + + def encode(self, pose): + mu, logvar = self.encoder(pose) + return mu + + def forward(self, **kwargs): + motion = kwargs['motion'].float() + B, T = motion.shape[:2] + pose = motion.reshape(B * T, -1) + pose = pose[:, :-4] + + mu, logvar = self.encoder(pose) + z = self.reparameterize(mu, logvar) + pred = self.decoder(z) + + loss = dict() + recon_loss = self.loss_recon(pred, pose, reduction_override='none') + loss['recon_loss'] = recon_loss + if self.kl_div_loss_weight is not None: + loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) + + return loss + + +@ARCHITECTURES.register_module() +class MotionVAE(BaseArchitecture): + + def __init__(self, + encoder=None, + decoder=None, + loss_recon=None, + kl_div_loss_weight=None, + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg, **kwargs) + self.encoder = build_submodule(encoder) + self.decoder = build_submodule(decoder) + self.loss_recon = build_loss(loss_recon) + self.kl_div_loss_weight = kl_div_loss_weight + + def sample(self, std=1, latent_code=None): + if latent_code is not None: + z = latent_code + else: + z = torch.randn(1, 7, self.decoder.latent_dim).cuda() * std + output = self.decoder(z) + if self.use_normalization: + output = output * self.motion_std + output = output + self.motion_mean + return output + + def reparameterize(self, mu, logvar): + std = torch.exp(logvar / 2) + + eps = std.data.new(std.size()).normal_() + latent_code = eps.mul(std).add_(mu) + return latent_code + + def encode(self, motion, motion_mask): + mu, logvar = self.encoder(motion, motion_mask) + return self.reparameterize(mu, logvar) + + def decode(self, z, motion_mask): + return self.decoder(z, motion_mask) + + def forward(self, **kwargs): + motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'] + B, T = motion.shape[:2] + + mu, logvar = self.encoder(motion, motion_mask) + z = self.reparameterize(mu, logvar) + pred = self.decoder(z, motion_mask) + + loss = dict() + recon_loss = self.loss_recon(pred, motion, reduction_override='none') + recon_loss = recon_loss.mean(dim=-1) * motion_mask + recon_loss = recon_loss.sum() / motion_mask.sum() + loss['recon_loss'] = recon_loss + if self.kl_div_loss_weight is not None: + loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) + + return loss diff --git a/mogen/models/attentions/__init__.py b/mogen/models/attentions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d866087894afa068a6a51533d38eb4303d6ac0 --- /dev/null +++ b/mogen/models/attentions/__init__.py @@ -0,0 +1,15 @@ +from .base_attention import BaseMixedAttention +from .efficient_attention import (EfficientCrossAttention, + EfficientMixedAttention, + EfficientSelfAttention) +from .fine_attention import SAMI +from .art_attention import ArtAttention +from .semantics_modulated import (DualSemanticsModulatedAttention, + SemanticsModulatedAttention) + +__all__ = [ + 'EfficientSelfAttention', 'EfficientCrossAttention', + 
'EfficientMixedAttention', 'SemanticsModulatedAttention', + 'DualSemanticsModulatedAttention', 'BaseMixedAttention', 'SAMI', + 'ArtAttention' +] diff --git a/mogen/models/attentions/art_attention.py b/mogen/models/attentions/art_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..af8f9b511e03f57e4d239809dba39a3f5a95f3c3 --- /dev/null +++ b/mogen/models/attentions/art_attention.py @@ -0,0 +1,476 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Dict, Any + +from ..builder import ATTENTIONS +from ..utils.stylization_block import StylizationBlock + +from tutel import moe as tutel_moe +from tutel import net + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + + Args: + module (nn.Module): The input PyTorch module. + + Returns: + nn.Module: The module with zeroed parameters. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class MOE(nn.Module): + """ + Mixture of Experts (MoE) layer with support for time embeddings and optional framerate conditioning. + + Args: + num_experts (int): Number of experts in the MoE layer. + topk (int): Number of top experts selected per input. + input_dim (int): Dimensionality of the input features. + ffn_dim (int): Dimensionality of the feed-forward network (FFN) used inside each expert. + output_dim (int): Dimensionality of the output features. + num_heads (int): Number of attention heads used in the model. + max_seq_len (int): Maximum sequence length for the input data. + gate_type (str): Type of gating mechanism used for MoE (e.g., "topk"). + gate_noise (float): Noise added to the gating mechanism for improved exploration. + framerate (bool, optional): Whether to use framerate-based embedding. Defaults to False. + embedding (bool, optional): Whether to use positional embeddings. Defaults to True. + + Attributes: + proj (nn.Linear): Linear projection layer applied after MoE processing. + activation (nn.GELU): Activation function used in the feed-forward layers. + model (tutel_moe.moe_layer): The Mixture of Experts layer. + embedding (torch.nn.Parameter): Positional or framerate-based embedding for input data. + aux_loss (torch.Tensor): Auxiliary loss from MoE layer for load balancing across experts. 
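# Illustrative sketch: the tutel `moe_layer` configured below is a production
# implementation of top-k expert routing. The core idea, ignoring capacity
# limits and the load-balancing auxiliary loss, can be written in plain
# PyTorch (everything here is hypothetical, not the tutel API):
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyTopKMoE(nn.Module):
    def __init__(self, dim: int, ffn_dim: int, num_experts: int, k: int):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(), nn.Linear(ffn_dim, dim))
            for _ in range(num_experts)])
        self.k = k

    def forward(self, x):                        # x: (tokens, dim)
        topv, topi = self.gate(x).topk(self.k, dim=-1)
        weights = F.softmax(topv, dim=-1)        # renormalise over the k chosen experts
        y = torch.zeros_like(x)
        for slot in range(self.k):
            for e, expert in enumerate(self.experts):
                picked = topi[:, slot] == e      # tokens whose slot-th choice is expert e
                if picked.any():
                    y[picked] += weights[picked, slot].unsqueeze(-1) * expert(x[picked])
        return y

# y = TinyTopKMoE(dim=64, ffn_dim=256, num_experts=4, k=2)(torch.randn(10, 64))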
+ """ + + def __init__(self, num_experts: int, topk: int, input_dim: int, ffn_dim: int, output_dim: int, + num_heads: int, max_seq_len: int, gate_type: str, gate_noise: float, embedding: bool = True): + super().__init__() + + # Linear projection layer to project from input_dim to output_dim + self.proj = nn.Linear(input_dim, output_dim) + # Activation function (GELU) + self.activation = nn.GELU() + + # Initialize Tutel MoE layer with gating and expert setup + try: + data_group = net.create_groups_from_world(group_count=1).data_group + except: + data_group = None + + self.model = tutel_moe.moe_layer( + gate_type={ + 'type': gate_type, + 'k': topk, + 'fp32_gate': True, + 'gate_noise': gate_noise, + 'capacity_factor': 1.5 # Capacity factor to allow extra room for expert routing + }, + experts={ + 'type': 'ffn', # Feed-forward expert type + 'count_per_node': num_experts, + 'hidden_size_per_expert': ffn_dim, + 'activation_fn': lambda x: F.gelu(x) # Activation inside experts + }, + model_dim=input_dim, + batch_prioritized_routing=True, # Prioritize routing based on batch size + is_gshard_loss=False, # Whether to use GShard loss for load balancing + group=data_group + ) + + # Determine whether to use positional embedding or framerate embedding + self.use_embedding = embedding + if self.use_embedding: + self.embedding = nn.Parameter(torch.randn(1, max_seq_len, num_heads, input_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the MoE layer with optional framerate embedding. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, H, D), where + B is the batch size, + T is the sequence length, + H is the number of attention heads, + D is the dimensionality of each head. + + Returns: + torch.Tensor: Output tensor of shape (B, T, H, output_dim), where output_dim is the projected dimensionality. + """ + B, T, H, D = x.shape + + # Apply positional or framerate-based embedding + if self.use_embedding: + # Default positional embedding + x = x + self.embedding[:, :T, :, :] + + # Flatten the input for MoE processing + x = x.reshape(-1, D) + + # Pass through the Mixture of Experts layer and apply the projection + y = self.proj(self.activation(self.model(x))) + + # Auxiliary loss for expert load balancing + self.aux_loss = self.model.l_aux + + # Reshape the output back to (B, T, H, output_dim) + y = y.reshape(B, T, H, -1) + + return y + + +def get_ffn(latent_dim: int, ffn_dim: int) -> nn.Sequential: + """ + Create a feed-forward network (FFN) block. + + Args: + latent_dim (int): Input/output dimension of the FFN. + ffn_dim (int): Hidden dimension of the FFN. + + Returns: + nn.Sequential: A sequential block consisting of two linear layers and a GELU activation in between. + """ + return nn.Sequential(nn.Linear(latent_dim, ffn_dim), nn.GELU(), nn.Linear(ffn_dim, latent_dim)) + + +@ATTENTIONS.register_module() +class ArtAttention(nn.Module): + """ + ArtAttention module for attending to multi-modal inputs (e.g., text, music, speech, video) + and generating time-dependent motion features using a Mixture of Experts (MoE) mechanism. + + Args: + latent_dim (int): Dimensionality of the latent representation. + num_heads (int): Number of attention heads. + num_experts (int): Number of experts in the Mixture of Experts. + topk (int): Number of top experts selected by the gating mechanism. + gate_type (str): Type of gating mechanism for the MoE layer. + gate_noise (float): Noise level for the gating mechanism. 
+ ffn_dim (int): Dimensionality of the feed-forward network inside the MoE. + time_embed_dim (int): Dimensionality of the time embedding for stylization. + max_seq_len (int): Maximum length of the motion sequence. + dropout (float): Dropout rate applied to the output of the MoE and attention layers. + motion_moe_dropout (float): Dropout rate applied to the motion MoE. + has_text (bool): Whether the input includes text features. + has_music (bool): Whether the input includes music features. + has_speech (bool): Whether the input includes speech features. + has_video (bool): Whether the input includes video features. + norm (str): Type of normalization layer to use ('LayerNorm' or 'RMSNorm'). + + Inputs: + - x (torch.Tensor): Tensor of shape (B, T, D), where B is the batch size, + T is the sequence length, and D is the dimensionality of the input motion data. + - emb (torch.Tensor): Time embedding for stylization, of shape (B, T, time_embed_dim). + - src_mask (torch.Tensor): Mask for the input data, of shape (B, T). + - motion_length (torch.Tensor): Tensor of shape (B,) representing the motion length. + - num_intervals (int): Number of intervals for processing the motion data. + - text_cond (torch.Tensor, optional): Conditioning mask for text data, of shape (B, 1). + - text_word_out (torch.Tensor, optional): Word features for text, of shape (B, M, latent_dim). + - music_cond (torch.Tensor, optional): Conditioning mask for music data, of shape (B, 1). + - music_word_out (torch.Tensor, optional): Word features for music, of shape (B, M, latent_dim). + - speech_cond (torch.Tensor, optional): Conditioning mask for speech data, of shape (B, 1). + - speech_word_out (torch.Tensor, optional): Word features for speech, of shape (B, M, latent_dim). + - video_cond (torch.Tensor, optional): Conditioning mask for video data, of shape (B, 1). + - video_word_out (torch.Tensor, optional): Word features for video, of shape (B, M, latent_dim). + - duration (torch.Tensor, optional): Duration of each motion sequence, of shape (B,). + + Outputs: + - y (torch.Tensor): The final attended output, with the same shape as input x (B, T, D). 
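# Illustrative sketch of the input shapes listed above, as dummy tensors with
# hypothetical sizes (B=2 clips, T=196 frames, H=10 heads, latent_dim=64,
# M=77 conditioning tokens; time_embed_dim assumed 512). Note that the
# forward pass views src_mask as (B, T, H, 1), so a per-part mask of shape
# (B, T, H) is assumed here:
import torch
B, T, H, latent_dim, M = 2, 196, 10, 64, 77
x = torch.randn(B, T, H * latent_dim)        # motion features (B, T, D)
emb = torch.randn(B, T, 512)                 # time embedding, per the docstring above
src_mask = torch.ones(B, T, H)               # 1 = valid frame / body part
motion_length = torch.full((B,), T)          # frames per clip, shape (B,)
duration = torch.ones(B)                     # per-sequence duration, shape (B,)
text_cond = torch.ones(B, 1)                 # 1 = text conditioning active
text_word_out = torch.randn(B, M, latent_dim)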
+ """ + def __init__(self, + latent_dim, + num_heads, + num_experts, + topk, + gate_type, + gate_noise, + ffn_dim, + time_embed_dim, + max_seq_len, + dropout, + num_datasets, + has_text=False, + has_music=False, + has_speech=False, + has_video=False, + norm="LayerNorm"): + super().__init__() + self.latent_dim = latent_dim + self.num_heads = num_heads + self.max_seq_len = max_seq_len + + # Choose normalization type + if norm == "LayerNorm": + Norm = nn.LayerNorm + + # Parameters for time-related functions + self.sigma = nn.Parameter(torch.Tensor([100])) # Sigma for softmax-based time weighting + self.time = torch.arange(max_seq_len) + + # Normalization for motion features + self.norm = Norm(latent_dim * 10) + + # MoE for motion data + self.motion_moe = MOE(num_experts, topk, latent_dim, latent_dim * 4, + 5 * latent_dim, num_heads, max_seq_len, + gate_type, gate_noise) + self.motion_moe_dropout = nn.Dropout(p=dropout) # Dropout for motion MoE + self.key_motion_scale = nn.Parameter(torch.Tensor([1.0])) + + # Default keys and values + self.num_datasets = num_datasets + self.key_dataset = nn.Parameter(torch.randn(num_datasets, 48, 10, latent_dim)) + self.key_dataset_scale = nn.Parameter(torch.Tensor([1.0])) + self.value_dataset = nn.Parameter(torch.randn(num_datasets, 48, 10, latent_dim)) + + self.key_rotation = nn.Parameter(torch.randn(3, 16, 10, latent_dim)) + self.value_rotation = nn.Parameter(torch.randn(3, 16, 10, latent_dim)) + self.key_rotation_scale = nn.Parameter(torch.Tensor([1.0])) + + # Conditional MoE layers for each modality (if applicable) + self.has_text = has_text + self.has_music = has_music + self.has_speech = has_speech + self.has_video = has_video + + if has_text or has_music or has_speech or has_video: + self.cond_moe = MOE(num_experts, topk, latent_dim, latent_dim * 4, + 2 * latent_dim, num_heads, max_seq_len, + gate_type, gate_noise, embedding=False) + if has_text: + self.norm_text = Norm(latent_dim * 10) + self.key_text_scale = nn.Parameter(torch.Tensor([1.0])) + if has_music: + self.norm_music = Norm(latent_dim * 10) + self.key_music_scale = nn.Parameter(torch.Tensor([1.0])) + if has_speech: + self.norm_speech = Norm(latent_dim * 10) + self.key_speech_scale = nn.Parameter(torch.Tensor([1.0])) + if has_video: + self.norm_video = Norm(latent_dim * 10) + self.key_video_scale = nn.Parameter(torch.Tensor([1.0])) + + # Template functions for Taylor expansion (state, velocity, acceleration, jerk) + self.template_s = get_ffn(latent_dim, ffn_dim) + self.template_v = get_ffn(latent_dim, ffn_dim) + self.template_a = get_ffn(latent_dim, ffn_dim) + self.template_j = get_ffn(latent_dim, ffn_dim) + self.template_t = nn.Sequential(nn.Linear(latent_dim, ffn_dim), + nn.GELU(), nn.Linear(ffn_dim, 1)) + self.t_sigma = nn.Parameter(torch.Tensor([1])) # Sigma for Taylor expansion + + # Final projection with stylization block + self.proj_out = StylizationBlock(latent_dim * num_heads, + time_embed_dim, dropout) + + def forward(self, + x: torch.Tensor, + emb: torch.Tensor, + src_mask: torch.Tensor, + motion_length: torch.Tensor, + num_intervals: int, + text_cond: Optional[torch.Tensor] = None, + text_word_out: Optional[torch.Tensor] = None, + music_cond: Optional[torch.Tensor] = None, + music_word_out: Optional[torch.Tensor] = None, + speech_cond: Optional[torch.Tensor] = None, + speech_word_out: Optional[torch.Tensor] = None, + video_cond: Optional[torch.Tensor] = None, + video_word_out: Optional[torch.Tensor] = None, + duration: Optional[torch.Tensor] = None, + dataset_idx: 
Optional[torch.Tensor] = None, + rotation_idx: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Forward pass for the ArtAttention module, handling multi-modal inputs. + + Args: + x (torch.Tensor): Input motion data of shape (B, T, D). + emb (torch.Tensor): Time embedding for stylization. + src_mask (torch.Tensor): Source mask for the input data. + motion_length (torch.Tensor): Length of the motion data. + num_intervals (int): Number of intervals for motion data. + text_cond (torch.Tensor, optional): Conditioning mask for text data. + text_word_out (torch.Tensor, optional): Text word output features. + music_cond (torch.Tensor, optional): Conditioning mask for music data. + music_word_out (torch.Tensor, optional): Music word output features. + speech_cond (torch.Tensor, optional): Conditioning mask for speech data. + speech_word_out (torch.Tensor, optional): Speech word output features. + video_cond (torch.Tensor, optional): Conditioning mask for video data. + video_word_out (torch.Tensor, optional): Video word output features. + duration (torch.Tensor, optional): Duration of each motion sequence. + + Returns: + y (torch.Tensor): The attended multi-modal motion features. + """ + + B, T, D = x.shape # Batch size (B), Time steps (T), Feature dimension (D) + H = self.num_heads + L = self.latent_dim + + # Pass motion data through MoE + motion_feat = self.motion_moe(self.norm(x).reshape(B, T, H, -1)) + motion_feat = self.motion_moe_dropout(motion_feat) + + # Reshape motion data for attention + x = x.reshape(B, T, H, -1) + + # Apply source mask and compute attention over motion features + src_mask = src_mask.view(B, T, H, 1) + body_value = motion_feat[:, :, :, :L] * src_mask + body_key = motion_feat[:, :, :, L: 2 * L] + (1 - src_mask) * -1000000 + body_key = F.softmax(body_key, dim=2) + body_query = F.softmax(motion_feat[:, :, :, 2 * L: 3 * L], dim=-1) + body_attention = torch.einsum('bnhd,bnhl->bndl', body_key, body_value) + body_feat = torch.einsum('bndl,bnhd->bnhl', body_attention, body_query) + body_feat = body_feat.reshape(B, T, D) + + # Key and value attention for motion + key_motion = motion_feat[:, :, :, 3 * L: 4 * L].contiguous() + key_motion = key_motion.view(B, T, H, -1) + key_motion = (key_motion + (1 - src_mask) * -1000000) / self.key_motion_scale + + value_motion = motion_feat[:, :, :, 4 * L:].contiguous() * src_mask + value_motion = value_motion.view(B, T, H, -1) + + # Process multi-modal conditioning (text, music, speech, video) + key_dataset = self.key_dataset.index_select(0, dataset_idx) / self.key_dataset_scale + value_dataset = self.value_dataset.index_select(0, dataset_idx) + key_rotation = self.key_rotation.index_select(0, rotation_idx) / self.key_rotation_scale + value_rotation = self.value_rotation.index_select(0, rotation_idx) + key = torch.cat((key_motion, key_dataset, key_rotation), dim=1) + value = torch.cat((value_motion, value_dataset, value_rotation), dim=1) + N = 64 + if self.has_text and text_word_out is not None and torch.sum(text_cond) > 0: + M = text_word_out.shape[1] + text_feat = self.norm_text(text_word_out).reshape(B, M, H, -1) + text_feat = self.cond_moe(text_feat) + key_text = text_feat[:, :, :, :L].contiguous() + key_text = key_text + (1 - text_cond.view(B, 1, 1, 1)) * -1000000 + key_text = key_text / self.key_text_scale + key = torch.cat((key, key_text), dim=1) + value_text = text_feat[:, :, :, L:].contiguous() + value_text = value_text * text_cond.view(B, 1, 1, 1) + value = torch.cat((value, value_text), dim=1) + N += M + + if 
self.has_music and music_word_out is not None and torch.sum(music_cond) > 0: + M = music_word_out.shape[1] + music_feat = self.norm_music(music_word_out).reshape(B, M, H, -1) + music_feat = self.cond_moe(music_feat) + key_music = music_feat[:, :, :, :L].contiguous() + key_music = key_music + (1 - music_cond.view(B, 1, 1, 1)) * -1000000 + key_music = key_music / self.key_music_scale + key = torch.cat((key, key_music), dim=1) + value_music = music_feat[:, :, :, L:].contiguous() + value_music = value_music * music_cond.view(B, 1, 1, 1) + value = torch.cat((value, value_music), dim=1) + N += M + + if self.has_speech and speech_word_out is not None and torch.sum(speech_cond) > 0: + M = speech_word_out.shape[1] + speech_feat = self.norm_speech(speech_word_out).reshape(B, M, H, -1) + speech_feat = self.cond_moe(speech_feat) + key_speech = speech_feat[:, :, :, :L].contiguous() + key_speech = key_speech + (1 - speech_cond.view(B, 1, 1, 1)) * -1000000 + key_speech = key_speech / self.key_speech_scale + key = torch.cat((key, key_speech), dim=1) + value_speech = speech_feat[:, :, :, L:].contiguous() + value_speech = value_speech * speech_cond.view(B, 1, 1, 1) + value = torch.cat((value, value_speech), dim=1) + N += M + + if self.has_video and video_word_out is not None and torch.sum(video_cond) > 0: + M = video_word_out.shape[1] + video_feat = self.norm_video(video_word_out).reshape(B, M, H, -1) + video_feat = self.cond_moe(video_feat) + key_video = video_feat[:, :, :, :L].contiguous() + key_video = key_video + (1 - video_cond.view(B, 1, 1, 1)) * -1000000 + key_video = key_video + (1 - src_mask) * -1000000 + key_video = key_video / self.key_video_scale + key = torch.cat((key, key_video), dim=1) + value_video = video_feat[:, :, :, L:].contiguous() + value_video = value_video * video_cond.view(B, 1, 1, 1) * src_mask + value= torch.cat((value, value_video), dim=1) + N += M + + key = F.softmax(key, dim=1) + # B, H, d, l + template = torch.einsum('bnhd,bnhl->bhdl', key, value) + template_t_feat = self.template_t(template) + template_t = torch.sigmoid(template_t_feat / self.t_sigma) + template_t = template_t * motion_length.view(B, 1, 1, 1) + template_t = template_t * duration.view(B, 1, 1, 1) + org_t = self.time[:T].type_as(x) + + # Handle time-based calculations + NI = num_intervals + t = org_t.clone().view(1, 1, -1, 1, 1).repeat(B // NI, NI, 1, 1, 1) + t = t * duration.view(B // NI, NI, 1, 1, 1) + template_t = template_t.view(-1, NI, H, L) + motion_length = motion_length.view(-1, NI) + for b_ix in range(B // NI): + sum_frames = 0 + for i in range(NI): + t[b_ix, i] += sum_frames * float(duration[b_ix]) + template_t[b_ix, i] += sum_frames * float(duration[b_ix]) + sum_frames += motion_length[b_ix, i] + template_t = template_t.permute(0, 2, 1, 3) + template_t = template_t.unsqueeze(1).repeat(1, NI, 1, 1, 1) + template_t = template_t.reshape(B, 1, H, -1) + time_delta = t.view(B, -1, 1, 1) - template_t + time_sqr = time_delta * time_delta + time_coef = F.softmax(-time_sqr, dim=-1) + + template = template.view(-1, NI, H, L, L) + template = template.permute(0, 2, 1, 3, 4).unsqueeze(1) + template = template.repeat(1, NI, 1, 1, 1, 1) + template = template.reshape(B, H, -1, L) + + # Taylor expansion for motion + template_s = template + self.template_s(template) # state + template_v = template + self.template_v(template) # velocity + template_a = template + self.template_a(template) # acceleration + template_j = template + self.template_j(template) # jerk + template_t = template_t.view(B, H, -1, 1) + template_a0 = 
template_s - template_v * template_t + \ + template_a * template_t * template_t - \ + template_j * template_t * template_t * template_t + template_a1 = template_v - 2 * template_a * template_t + \ + 3 * template_j * template_t * template_t + template_a2 = template_a - 3 * template_j * template_t + template_a3 = template_j + a0 = torch.einsum('bnhd,bhdl->bnhl', time_coef, + template_a0).reshape(B, T, D) + a1 = torch.einsum('bnhd,bhdl->bnhl', time_coef, + template_a1).reshape(B, T, D) + a2 = torch.einsum('bnhd,bhdl->bnhl', time_coef, + template_a2).reshape(B, T, D) + a3 = torch.einsum('bnhd,bhdl->bnhl', time_coef, + template_a3).reshape(B, T, D) + t = t.view(B, -1, 1) + y_t = a0 + a1 * t + a2 * t * t + a3 * t * t * t + y_s = body_feat + y = x.reshape(B, T, D) + self.proj_out(y_s + y_t, emb) + + if self.training: + # Add auxiliary losses during training + self.aux_loss = self.motion_moe.aux_loss + if self.has_text or self.has_music or self.has_speech or self.has_video: + if hasattr(self.cond_moe, 'aux_loss') and self.cond_moe.aux_loss is not None: + self.aux_loss += self.cond_moe.aux_loss + self.cond_moe.aux_loss = None + mu = template_t_feat.squeeze(-1).mean(dim=-1) + logvar = torch.log(template_t_feat.squeeze(-1).std(dim=-1)) + logvar[logvar > 1000000] = 0 + logvar[logvar < -1000000] = 0 + self.kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + + return y diff --git a/mogen/models/attentions/base_attention.py b/mogen/models/attentions/base_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..99bae7fc5c6e70f036d8af6bac1df579dc353841 --- /dev/null +++ b/mogen/models/attentions/base_attention.py @@ -0,0 +1,216 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Dict, Any + +from ..builder import ATTENTIONS +from ..utils.stylization_block import StylizationBlock + + +@ATTENTIONS.register_module() +class BaseMixedAttention(nn.Module): + """ + Base class for Mixed Attention, combining text and motion attention. + + Args: + latent_dim (int): Dimension of the latent space for motion input. + text_latent_dim (int): Dimension of the latent space for text input. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. + time_embed_dim (int): Dimension of the time embedding. + """ + + def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float, time_embed_dim: int): + super().__init__() + self.num_heads = num_heads + + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + + self.query = nn.Linear(latent_dim, latent_dim) + self.key_text = nn.Linear(text_latent_dim, latent_dim) + self.value_text = nn.Linear(text_latent_dim, latent_dim) + self.key_motion = nn.Linear(latent_dim, latent_dim) + self.value_motion = nn.Linear(latent_dim, latent_dim) + + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor, + cond_type: torch.Tensor, **kwargs: Dict[str, Any]) -> torch.Tensor: + """ + Forward pass of Mixed Attention. + + Args: + x (torch.Tensor): Input motion tensor of shape [B, T, D]. + xf (torch.Tensor): Input text tensor of shape [B, N, L]. + emb (torch.Tensor): Time embedding tensor of shape [B, D]. + src_mask (torch.Tensor): Source mask tensor of shape [B, T]. + cond_type (torch.Tensor): Conditioning type tensor of shape [B]. 
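# Illustrative sketch (referring to the Taylor-style template block in
# ArtAttention above): template_a0..template_a3 re-centre a cubic anchored at
# template time t0,
#     y(t) = s + v*(t - t0) + a*(t - t0)**2 + j*(t - t0)**3,
# into plain polynomial coefficients so that y(t) = a0 + a1*t + a2*t**2 + a3*t**3
# (here a and j act as polynomial coefficients rather than true 2nd/3rd
# derivatives with 1/2! and 1/3! factors). A scalar check of that identity:
s, v, a, j, t0, t = 1.0, 0.5, -0.2, 0.05, 3.0, 7.0
a0 = s - v * t0 + a * t0 ** 2 - j * t0 ** 3
a1 = v - 2 * a * t0 + 3 * j * t0 ** 2
a2 = a - 3 * j * t0
a3 = j
lhs = s + v * (t - t0) + a * (t - t0) ** 2 + j * (t - t0) ** 3
rhs = a0 + a1 * t + a2 * t ** 2 + a3 * t ** 3
assert abs(lhs - rhs) < 1e-9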
+ + Returns: + torch.Tensor: Output of the mixed attention module. + """ + B, T, D = x.shape + N = xf.shape[1] + x.shape[1] + H = self.num_heads + + query = self.query(self.norm(x)).view(B, T, H, -1) + + # Text conditioning type + text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1) + text_cond_type = text_cond_type.repeat(1, xf.shape[1], 1) + + key = torch.cat( + (self.key_text(self.text_norm(xf)), self.key_motion(self.norm(x))), + dim=1 + ).view(B, N, H, -1) + + attention = torch.einsum('bnhl,bmhl->bnmh', query, key) + + motion_mask = src_mask.view(B, 1, T, 1) + text_mask = text_cond_type.view(B, 1, -1, 1) + mask = torch.cat((text_mask, motion_mask), dim=2) + attention = attention + (1 - mask) * -1000000 # Masking for softmax + attention = F.softmax(attention, dim=2) + + value = torch.cat( + (self.value_text(self.text_norm(xf)) * text_cond_type, self.value_motion(self.norm(x)) * src_mask), + dim=1 + ).view(B, N, H, -1) + + y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) + y = x + self.proj_out(y, emb) + + return y + + +@ATTENTIONS.register_module() +class BaseSelfAttention(nn.Module): + """ + Base class for Self-Attention mechanism. + + Args: + latent_dim (int): Dimension of the latent space. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. + time_embed_dim (Optional[int]): Dimension of the time embedding (optional). + """ + + def __init__(self, latent_dim: int, num_heads: int, dropout: float, time_embed_dim: Optional[int] = None): + super().__init__() + self.num_heads = num_heads + + self.norm = nn.LayerNorm(latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(latent_dim, latent_dim) + self.value = nn.Linear(latent_dim, latent_dim) + + self.dropout = nn.Dropout(dropout) + self.time_embed_dim = time_embed_dim + if time_embed_dim is not None: + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x: torch.Tensor, src_mask: Optional[torch.Tensor] = None, emb: Optional[torch.Tensor] = None, **kwargs: Dict[str, Any]) -> torch.Tensor: + """ + Forward pass of Self-Attention. + + Args: + x (torch.Tensor): Input tensor of shape [B, T, D]. + emb (torch.Tensor): Time embedding tensor of shape [B, D]. + src_mask (torch.Tensor): Source mask tensor of shape [B, T]. + + Returns: + torch.Tensor: Output of the self-attention module. + """ + B, T, D = x.shape + H = self.num_heads + + query = self.query(self.norm(x)).view(B, T, H, -1) + key = self.key(self.norm(x)).view(B, T, H, -1) + + attention = torch.einsum('bnhl,bmhl->bnmh', query, key) + if src_mask is not None: + mask = src_mask.view(B, 1, T, 1) + attention = attention + (1 - mask) * -1000000 # Masking for softmax + attention = F.softmax(attention, dim=2) + + if src_mask is not None: + value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1) + else: + value = self.value(self.norm(x)).view(B, T, H, -1) + y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) + if self.time_embed_dim is None: + y = x + y + else: + y = x + self.proj_out(y, emb) + return y + + +@ATTENTIONS.register_module() +class BaseCrossAttention(nn.Module): + """ + Base class for Cross-Attention mechanism, attending over text and motion inputs. + + Args: + latent_dim (int): Dimension of the latent space for motion input. + text_latent_dim (int): Dimension of the latent space for text input. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. 
+ time_embed_dim (int): Dimension of the time embedding. + """ + + def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float, time_embed_dim: int): + super().__init__() + self.num_heads = num_heads + + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(text_latent_dim, latent_dim) + self.value = nn.Linear(text_latent_dim, latent_dim) + + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor, + cond_type: Optional[torch.Tensor] = None, **kwargs: Dict[str, Any]) -> torch.Tensor: + """ + Forward pass of Cross-Attention. + + Args: + x (torch.Tensor): Input motion tensor of shape [B, T, D]. + xf (torch.Tensor): Input text tensor of shape [B, N, L]. + emb (torch.Tensor): Time embedding tensor of shape [B, D]. + src_mask (torch.Tensor): Source mask tensor of shape [B, T]. + cond_type (Optional[torch.Tensor]): Conditioning type tensor of shape [B]. Defaults to None. + + Returns: + torch.Tensor: Output of the cross-attention module. + """ + B, T, D = x.shape + N = xf.shape[1] + H = self.num_heads + + query = self.query(self.norm(x)).view(B, T, H, -1) + + if cond_type is None: + text_cond_type = 1 + mask = 1 + else: + text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1) + text_cond_type = text_cond_type.repeat(1, xf.shape[1], 1) + mask = text_cond_type.view(B, 1, -1, 1) + + key = self.key(self.text_norm(xf)).view(B, N, H, -1) + + attention = torch.einsum('bnhl,bmhl->bnmh', query, key) + attention = attention + (1 - mask) * -1000000 # Masking for softmax + attention = F.softmax(attention, dim=2) + + value = (self.value(self.text_norm(xf)) * text_cond_type).view(B, N, H, -1) + y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) + y = x + self.proj_out(y, emb) + + return y diff --git a/mogen/models/attentions/efficient_attention.py b/mogen/models/attentions/efficient_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f11a501c8164b916800b5f6b7f717a9639207f --- /dev/null +++ b/mogen/models/attentions/efficient_attention.py @@ -0,0 +1,246 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Dict, Any + +from ..builder import ATTENTIONS +from ..utils.stylization_block import StylizationBlock + + +@ATTENTIONS.register_module() +class EfficientSelfAttention(nn.Module): + """ + Efficient Self-Attention mechanism for motion generation tasks. + + Args: + latent_dim (int): Dimension of the latent space. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. + time_embed_dim (Optional[int]): Dimension of the time embedding (optional). 
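# Illustrative sketch: the `+ (1 - mask) * -1000000` pattern used by these
# attention modules is additive masking, pushing masked logits to a large
# negative value so that softmax assigns them (numerically) zero weight:
import torch
import torch.nn.functional as F
logits = torch.tensor([2.0, 1.0, 3.0])
mask = torch.tensor([1.0, 1.0, 0.0])                       # third position is padding
weights = F.softmax(logits + (1 - mask) * -1000000, dim=0)
# weights ~= [0.73, 0.27, 0.00]; the padded slot receives no attention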
+ """ + + def __init__(self, + latent_dim: int, + num_heads: int, + dropout: float, + time_embed_dim: Optional[int] = None): + super().__init__() + self.num_heads = num_heads + self.norm = nn.LayerNorm(latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(latent_dim, latent_dim) + self.value = nn.Linear(latent_dim, latent_dim) + self.dropout = nn.Dropout(dropout) + self.time_embed_dim = time_embed_dim + if time_embed_dim is not None: + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, + x: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + emb: Optional[torch.Tensor] = None, + **kwargs: Dict[str, Any]) -> torch.Tensor: + """ + Forward pass of Efficient Self-Attention. + + Args: + x (torch.Tensor): Input tensor of shape [B, T, D]. + src_mask (Optional[torch.Tensor]): Source mask of shape [B, T] (optional). + emb (Optional[torch.Tensor]): Time embedding tensor of shape [B, D] (optional). + + Returns: + torch.Tensor: Output of the self-attention module. + """ + B, T, D = x.shape + H = self.num_heads + + query = self.query(self.norm(x)) + + if src_mask is None: + key = self.key(self.norm(x)) + else: + key = self.key(self.norm(x)) + (1 - src_mask) * -1000000 + + query = F.softmax(query.view(B, T, H, -1), dim=-1) + key = F.softmax(key.view(B, T, H, -1), dim=1) + + if src_mask is None: + value = self.value(self.norm(x)).view(B, T, H, -1) + else: + value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1) + + attention = torch.einsum('bnhd,bnhl->bhdl', key, value) + y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) + + if self.time_embed_dim is None: + y = x + y + else: + y = x + self.proj_out(y, emb) + + return y + + +@ATTENTIONS.register_module() +class EfficientCrossAttention(nn.Module): + """ + Efficient Cross-Attention mechanism, attending to text and motion inputs. + + Args: + latent_dim (int): Dimension of the latent space for motion input. + text_latent_dim (int): Dimension of the latent space for text input. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. + time_embed_dim (int): Dimension of the time embedding. + """ + + def __init__(self, + latent_dim: int, + text_latent_dim: int, + num_heads: int, + dropout: float, + time_embed_dim: Optional[int] = None): + super().__init__() + self.num_heads = num_heads + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(text_latent_dim, latent_dim) + self.value = nn.Linear(text_latent_dim, latent_dim) + self.dropout = nn.Dropout(dropout) + if time_embed_dim is not None: + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + else: + self.proj_out = None + + def forward(self, + x: torch.Tensor, + xf: torch.Tensor, + emb: Optional[torch.Tensor] = None, + cond_type: Optional[torch.Tensor] = None, + **kwargs: Dict[str, Any]) -> torch.Tensor: + """ + Forward pass of Efficient Cross-Attention. + + Args: + x (torch.Tensor): Input motion tensor of shape [B, T, D]. + xf (torch.Tensor): Input text tensor of shape [B, N, L]. + emb (torch.Tensor): Time embedding tensor of shape [B, D]. + cond_type (Optional[torch.Tensor]): Conditioning type tensor (optional). + + Returns: + torch.Tensor: Output of the cross-attention module. 
+ """ + B, T, D = x.shape + N = xf.shape[1] + H = self.num_heads + + query = self.query(self.norm(x)) + + key = self.key(self.text_norm(xf)) + query = F.softmax(query.view(B, T, H, -1), dim=-1) + + if cond_type is None: + key = F.softmax(key.view(B, N, H, -1), dim=1) + value = self.value(self.text_norm(xf)).view(B, N, H, -1) + else: + text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1) + text_cond_type = text_cond_type.repeat(1, xf.shape[1], 1) + key = key + (1 - text_cond_type) * -1000000 + key = F.softmax(key.view(B, N, H, -1), dim=1) + value = self.value(self.text_norm(xf) * text_cond_type).view(B, N, H, -1) + + attention = torch.einsum('bnhd,bnhl->bhdl', key, value) + y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) + if self.proj_out is not None: + y = x + self.proj_out(y, emb) + else: + y = x + y + return y + + +@ATTENTIONS.register_module() +class EfficientMixedAttention(nn.Module): + """ + Efficient Mixed Attention, combining text and motion attention. + + Args: + latent_dim (int): Dimension of the latent space for motion input. + text_latent_dim (int): Dimension of the latent space for text input. + num_heads (int): Number of attention heads. + dropout (float): Dropout probability. + time_embed_dim (int): Dimension of the time embedding. + """ + + def __init__(self, + latent_dim: int, + text_latent_dim: int, + num_heads: int, + dropout: float, + time_embed_dim: Optional[int] = None): + super().__init__() + self.num_heads = num_heads + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + + self.query = nn.Linear(latent_dim, latent_dim) + self.key_text = nn.Linear(text_latent_dim, latent_dim) + self.value_text = nn.Linear(text_latent_dim, latent_dim) + self.key_motion = nn.Linear(latent_dim, latent_dim) + self.value_motion = nn.Linear(latent_dim, latent_dim) + + self.dropout = nn.Dropout(dropout) + if time_embed_dim is not None: + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + else: + self.proj_out = None + + def forward(self, + x: torch.Tensor, + xf: torch.Tensor, + src_mask: torch.Tensor, + emb: Optional[torch.Tensor] = None, + cond_type: Optional[torch.Tensor] = None, + **kwargs: Dict[str, Any]) -> torch.Tensor: + """ + Forward pass of Efficient Mixed Attention. + + Args: + x (torch.Tensor): Input motion tensor of shape [B, T, D]. + xf (torch.Tensor): Input text tensor of shape [B, N, L]. + emb (torch.Tensor): Time embedding tensor of shape [B, D]. + src_mask (torch.Tensor): Source mask tensor of shape [B, T]. + cond_type (torch.Tensor): Conditioning type tensor. + + Returns: + torch.Tensor: Output of the mixed attention module. 
+ """ + B, T, D = x.shape + N = xf.shape[1] + x.shape[1] + H = self.num_heads + + query = self.query(self.norm(x)).view(B, T, H, -1) + + text_cond_type = (cond_type % 10 > 0).float() + src_mask = src_mask.view(B, T, 1) + + key_text = self.key_text(self.text_norm(xf)) + key_text = key_text + (1 - text_cond_type) * -1000000 + key_motion = self.key_motion(self.norm(x)) + (1 - src_mask) * -1000000 + key = torch.cat((key_text, key_motion), dim=1) + + query = F.softmax(query.view(B, T, H, -1), dim=-1) + key = self.dropout(F.softmax(key.view(B, N, H, -1), dim=1)) + + value = torch.cat( + (self.value_text(self.text_norm(xf)) * text_cond_type, self.value_motion(self.norm(x)) * src_mask), + dim=1 + ).view(B, N, H, -1) + + attention = torch.einsum('bnhd,bnhl->bhdl', key, value) + y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) + + if self.proj_out is not None: + y = x + self.proj_out(y, emb) + else: + y = x + y + return y diff --git a/mogen/models/attentions/fine_attention.py b/mogen/models/attentions/fine_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..536ad26f7f7593a1917425dd522a37f70ec8c339 --- /dev/null +++ b/mogen/models/attentions/fine_attention.py @@ -0,0 +1,328 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Dict, Any + +from ..builder import ATTENTIONS +from ..utils.stylization_block import StylizationBlock + +try: + from tutel import moe as tutel_moe + from tutel import net +except ImportError: + pass + + +class MOE(nn.Module): + """ + Mixture of Experts (MoE) layer implementation using the Tutel MoE library. + + Args: + num_experts (int): Number of experts. + topk (int): Number of top experts to route tokens to. + input_dim (int): Input dimension of the MoE layer. + ffn_dim (int): Feed-forward network dimension for each expert. + output_dim (int): Output dimension of the MoE layer. + num_heads (int): Number of attention heads. + max_seq_len (int): Maximum sequence length. + gate_type (str): Type of gating mechanism (e.g., 'top_k'). + gate_noise (float): Noise factor for the gating mechanism. + """ + + def __init__(self, num_experts: int, topk: int, input_dim: int, ffn_dim: int, output_dim: int, + num_heads: int, max_seq_len: int, gate_type: str, gate_noise: float): + super().__init__() + self.proj = nn.Linear(input_dim, output_dim) + self.activation = nn.GELU() + + try: + data_group = net.create_groups_from_world(group_count=1).data_group + except Exception: + data_group = None + + self.model = tutel_moe.moe_layer( + gate_type={ + 'type': gate_type, + 'k': topk, + 'fp32_gate': True, + 'gate_noise': gate_noise, + 'capacity_factor': 1.5 + }, + experts={ + 'type': 'ffn', + 'count_per_node': num_experts, + 'hidden_size_per_expert': ffn_dim, + 'activation_fn': lambda x: F.gelu(x) + }, + model_dim=input_dim, + batch_prioritized_routing=True, + is_gshard_loss=False, + group=data_group + ) + self.embedding = nn.Parameter(torch.randn(1, max_seq_len, num_heads, input_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MOE layer. + + Args: + x (torch.Tensor): Input tensor of shape [B, T, H, D]. + + Returns: + torch.Tensor: Output tensor of shape [B, T, H, D]. 
+ """ + B, T, H, D = x.shape + x = x + self.embedding[:, :T, :, :] + x = x.reshape(-1, D) + y = self.proj(self.activation(self.model(x))) + self.aux_loss = self.model.l_aux + y = y.reshape(B, T, H, -1) + return y + + +def get_ffn(latent_dim: int, ffn_dim: int) -> nn.Sequential: + """ + Create a feed-forward network (FFN) block. + + Args: + latent_dim (int): Input/output dimension of the FFN. + ffn_dim (int): Hidden dimension of the FFN. + + Returns: + nn.Sequential: A sequential block consisting of two linear layers and a GELU activation in between. + """ + return nn.Sequential(nn.Linear(latent_dim, ffn_dim), nn.GELU(), nn.Linear(ffn_dim, latent_dim)) + + +@ATTENTIONS.register_module() +class SAMI(nn.Module): + """ + SAMI: Self-Attention-based MoE Integration model for motion generation. + + Args: + latent_dim (int): Dimension of the latent space for motion input. + text_latent_dim (int): Dimension of the latent space for text input. + num_heads (int): Number of motion attention heads. + num_text_heads (int): Number of text attention heads. + num_experts (int): Number of experts for MoE. + topk (int): Number of top experts to route tokens to. + gate_type (str): Type of gating mechanism. + gate_noise (float): Noise factor for the gating mechanism. + ffn_dim (int): Dimension of the feed-forward network. + time_embed_dim (int): Dimension of the time embedding. + max_seq_len (int): Maximum sequence length for motion data. + max_text_seq_len (int): Maximum sequence length for text data. + dropout (float): Dropout probability. + norm (str): Type of normalization ('LayerNorm'). + att_balance (bool): Whether to balance attention weights between motion and text. + fine_mode (bool): Whether to use fine-grained features. + mask_cond (float): Masking condition for fine-tuning. 
+ """ + + def __init__(self, + latent_dim: int, + text_latent_dim: int, + num_heads: int, + num_text_heads: int, + num_experts: int, + topk: int, + gate_type: str, + gate_noise: float, + ffn_dim: int, + time_embed_dim: int, + max_seq_len: int, + max_text_seq_len: int, + dropout: float, + norm: str = "LayerNorm", + att_balance: bool = False, + fine_mode: bool = False, + mask_cond: float = 0): + super().__init__() + self.latent_dim = latent_dim + self.num_heads = num_heads + self.num_text_heads = num_text_heads + self.max_seq_len = max_seq_len + + # Normalization + Norm = nn.LayerNorm + self.norm = Norm(latent_dim) + self.text_norm = Norm(text_latent_dim) + + # MoE Layers for motion and text + self.sigma = nn.Parameter(torch.Tensor([100])) + self.time = torch.arange(max_seq_len) / max_seq_len + self.text_moe = MOE(num_experts, topk, text_latent_dim, text_latent_dim * 4, 2 * latent_dim, + num_text_heads, max_text_seq_len, gate_type, gate_noise) + self.motion_moe = MOE(num_experts, topk, latent_dim, latent_dim * 4, 3 * latent_dim, + num_heads, max_seq_len, gate_type, gate_noise) + + # Key-motion and attention blocks + self.key_motion = nn.Parameter(torch.randn(max_seq_len, latent_dim)) + self.body_weight = nn.Parameter(torch.randn(num_heads, num_heads)) + + # Feedforward networks for state, velocity, acceleration, and jerk + self.template_s = get_ffn(latent_dim, ffn_dim) + self.template_v = get_ffn(latent_dim, ffn_dim) + self.template_a = get_ffn(latent_dim, ffn_dim) + self.template_j = get_ffn(latent_dim, ffn_dim) + + # Time embedding block + self.template_t = nn.Sequential(nn.Linear(latent_dim, ffn_dim), nn.GELU(), nn.Linear(ffn_dim, 1)) + self.t_sigma = nn.Parameter(torch.Tensor([1])) + + # Output projection + self.proj_out = StylizationBlock(latent_dim * num_heads, time_embed_dim, dropout) + self.att_balance = att_balance + if self.att_balance: + self.motion_coef = nn.Parameter(torch.Tensor([0])) + self.text_coef = nn.Parameter(torch.Tensor([0])) + + self.fine_mode = fine_mode + self.mask_cond = mask_cond + + def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor, + cond_type: torch.Tensor, motion_length: torch.Tensor, num_intervals: int, **kwargs: Dict[str, Any]) -> torch.Tensor: + """ + Forward pass of SAMI. + + Args: + x (torch.Tensor): Motion input tensor of shape [B, T, D]. + xf (torch.Tensor): Text input tensor of shape [B, N, P]. + emb (torch.Tensor): Time embedding tensor. + src_mask (torch.Tensor): Source mask tensor of shape [B, T]. + cond_type (torch.Tensor): Conditioning type tensor of shape [B, ?]. + motion_length (torch.Tensor): Motion length tensor. + num_intervals (int): Number of intervals for the motion. + + Returns: + torch.Tensor: Output tensor after motion and text MoE integration. 
+ """ + B, T, D = x.shape + N = xf.shape[1] + x.shape[1] + H = self.num_heads + L = self.latent_dim + + x = x.reshape(B, T, H, -1) + if self.fine_mode: + text_feat = xf.reshape(B, self.num_text_heads, xf.shape[1], xf.shape[2]).permute(0, 2, 1, 3) + else: + text_feat = xf.reshape(B, xf.shape[1], self.num_text_heads, -1) + + # MoE Layers for text and motion features + text_feat = self.text_moe(self.text_norm(text_feat)) + motion_feat = self.motion_moe(self.norm(x)) + + # Weighted combination of motion features + body_weight = F.softmax(self.body_weight, dim=1) + body_value = motion_feat[:, :, :, :L] + body_feat = torch.einsum('hl,bnld->bnhd', body_weight, body_value) + body_feat = body_feat.reshape(B, T, D) + + # Apply the source mask and combine key-text and key-motion + src_mask = src_mask.view(B, T, 1, 1) + key_text = text_feat[:, :, :, :L].contiguous() + + # Handle conditional types and masks + if self.fine_mode: + text_cond_type = torch.cat((cond_type[:, :7, :] % 10 > self.mask_cond, cond_type[:, 7:8, :] % 10 > 0), 1).float().unsqueeze(-1) + text_cond_type = text_cond_type.permute(0, 2, 1, 3) + text_cond_type = text_cond_type.repeat(1, key_text.shape[1], 1, 1) + else: + text_cond_type = (cond_type % 10 > 0).float().unsqueeze(-1) + + key_text = key_text + (1 - text_cond_type) * -1000000 + if self.num_text_heads == 1: + key_text = key_text.repeat(1, 1, H, 1) + + key_motion = motion_feat[:, :, :, L:2 * L].contiguous() + key_motion = key_motion + (1 - src_mask) * -1000000 + + # Attention balance between motion and text + if self.att_balance: + motion_coef = torch.sigmoid(self.motion_coef) + text_coef = torch.sigmoid(self.text_coef) + key_motion = F.softmax(key_motion, dim=1) * motion_coef + key_text = F.softmax(key_text, dim=1) * text_coef + sum_coef = motion_coef.repeat(B) + text_coef.repeat(B) * text_cond_type.view(B) + sum_coef = sum_coef.view(B, 1, 1, 1) + key_motion = key_motion / sum_coef + key_text = key_text / sum_coef + key = torch.cat((key_text, key_motion), dim=1) + else: + key = torch.cat((key_text, key_motion), dim=1) + key = F.softmax(key.view(B, N, H, -1), dim=1) + + # Value combination for text and motion + value_text = text_feat[:, :, :, L:].contiguous() * text_cond_type + if self.num_text_heads == 1: + value_text = value_text.repeat(1, 1, H, 1) + value_motion = motion_feat[:, :, :, 2 * L:].contiguous() * src_mask + value = torch.cat((value_text, value_motion), dim=1).view(B, N, H, -1) + + # Calculate the attention-weighted template + template = torch.einsum('bnhd,bnhl->bhdl', key, value) + template_t_feat = self.template_t(template) + template_t = torch.sigmoid(template_t_feat / self.t_sigma) + template_t = template_t * motion_length.view(B, 1, 1, 1) + template_t = template_t / self.max_seq_len + + org_t = self.time[:T].type_as(x) + + # Handle time intervals for the motion + NI = num_intervals + t = org_t.clone().view(1, 1, -1, 1, 1).repeat(B // NI, NI, 1, 1, 1) + template_t = template_t.view(-1, NI, H, L) + motion_length = motion_length.view(-1, NI) + for b_ix in range(B // NI): + sum_frames = 0 + for i in range(NI): + t[b_ix, i] += sum_frames / self.max_seq_len + template_t[b_ix, i] += sum_frames / self.max_seq_len + sum_frames += motion_length[b_ix, i] + template_t = template_t.permute(0, 2, 1, 3).unsqueeze(1).repeat(1, NI, 1, 1, 1) + template_t = template_t.reshape(B, 1, H, -1) + + time_delta = t.view(B, -1, 1, 1) - template_t + time_delta = time_delta * self.max_seq_len + time_sqr = time_delta * time_delta + time_coef = F.softmax(-time_sqr / self.sigma, dim=-1) + + # 
Reshape and repeat templates for Taylor expansion + template = template.view(-1, NI, H, L, L) + template = template.permute(0, 2, 1, 3, 4).unsqueeze(1) + template = template.repeat(1, NI, 1, 1, 1, 1) + template = template.reshape(B, H, -1, L) + + # Taylor expansion for state (s), velocity (v), acceleration (a), jerk (j) + template_s = self.template_s(template) + template_v = self.template_v(template) + template_a = self.template_a(template) + template_j = self.template_j(template) + + template_t = template_t.view(B, H, -1, 1) + template_a0 = template_s - template_v * template_t + template_a * template_t * template_t - template_j * template_t * template_t * template_t + template_a1 = template_v - 2 * template_a * template_t + 3 * template_j * template_t * template_t + template_a2 = template_a - 3 * template_j * template_t + template_a3 = template_j + + # Calculate the final time-dependent output using the Taylor expansion + a0 = torch.einsum('bnhd,bhdl->bnhl', time_coef, template_a0).reshape(B, T, D) + a1 = torch.einsum('bnhd,bhdl->bnhl', time_coef, template_a1).reshape(B, T, D) + a2 = torch.einsum('bnhd,bhdl->bnhl', time_coef, template_a2).reshape(B, T, D) + a3 = torch.einsum('bnhd,bhdl->bnhl', time_coef, template_a3).reshape(B, T, D) + + t = t.view(B, -1, 1) + y_t = a0 + a1 * t + a2 * t * t + a3 * t * t * t + + # Combine with body features and output the final result + y_s = body_feat + y = x.reshape(B, T, D) + self.proj_out(y_s + y_t, emb) + + if self.training: + self.aux_loss = self.text_moe.aux_loss + self.motion_moe.aux_loss + mu = template_t_feat.squeeze(-1).mean(dim=-1) + logvar = torch.log(template_t_feat.squeeze(-1).std(dim=-1)) + self.kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + + return y + diff --git a/mogen/models/attentions/semantics_modulated.py b/mogen/models/attentions/semantics_modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..85240be7d4476b176c4439c102790b16171ee47f --- /dev/null +++ b/mogen/models/attentions/semantics_modulated.py @@ -0,0 +1,246 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Dict, Any + +from ..builder import ATTENTIONS +from ..utils.stylization_block import StylizationBlock + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + + Args: + module (nn.Module): The input PyTorch module. + + Returns: + nn.Module: The module with zeroed parameters. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +@ATTENTIONS.register_module() +class SemanticsModulatedAttention(nn.Module): + """ + Semantics-modulated attention module that integrates motion, text, and retrieval features into attention computation. + + Args: + latent_dim (int): Dimensionality of the latent (motion) features. + text_latent_dim (int): Dimensionality of the text features. + num_heads (int): Number of attention heads. + dropout (float): Dropout rate. + time_embed_dim (int): Dimensionality of time embeddings. 
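# Illustrative sketch: the kl_loss above (and kl_div_loss in the VAE
# architectures earlier) is the closed-form KL divergence between a diagonal
# Gaussian N(mu, exp(logvar)) and the standard normal,
#     KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
# which is zero exactly when mu = 0 and logvar = 0 (unit variance):
import torch

def kl_to_standard_normal(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# kl_to_standard_normal(torch.zeros(4), torch.zeros(4))  -> tensor(0.)
# kl_to_standard_normal(torch.ones(4), torch.zeros(4))   -> tensor(2.)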
+ """ + + def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float, time_embed_dim: int): + super().__init__() + self.num_heads = num_heads + + # Layer Normalization for motion and text features + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + + # Linear projections for queries, keys, and values + self.query = nn.Linear(latent_dim, latent_dim) + self.key_text = nn.Linear(text_latent_dim, latent_dim) + self.value_text = nn.Linear(text_latent_dim, latent_dim) + self.key_motion = nn.Linear(latent_dim, latent_dim) + self.value_motion = nn.Linear(latent_dim, latent_dim) + + # Retrieval feature processing (motion and text) + self.retr_norm1 = nn.LayerNorm(2 * latent_dim) + self.retr_norm2 = nn.LayerNorm(latent_dim) + self.key_retr = nn.Linear(2 * latent_dim, latent_dim) + self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim)) + + # Dropout and output projection + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor, + cond_type: torch.Tensor, re_dict: dict) -> torch.Tensor: + """ + Forward pass of SemanticsModulatedAttention. + + Args: + x (torch.Tensor): Motion features of shape (B, T, D). + xf (torch.Tensor): Text features of shape (B, N, L). + emb (torch.Tensor): Time embedding. + src_mask (torch.Tensor): Source mask for the input motion features. + cond_type (torch.Tensor): Condition type tensor. + re_dict (dict): Dictionary containing retrieval motion, text, and mask data. + + Returns: + torch.Tensor: Output tensor after attention modulation, shape (B, T, D). + """ + B, T, D = x.shape + re_motion = re_dict['re_motion'] + re_text = re_dict['re_text'] + re_mask = re_dict['re_mask'].reshape(B, -1, 1) + N = xf.shape[1] + x.shape[1] + re_motion.shape[1] * re_motion.shape[2] # Total number of attention keys + + H = self.num_heads + query = self.query(self.norm(x)) # Query from motion features + + # Key and Value from text and retrieval features + text_cond_type = (cond_type % 10 > 0).float() + retr_cond_type = (cond_type // 10 > 0).float() + re_text = re_text.repeat(1, 1, re_motion.shape[2], 1) + re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D) + + # Calculate keys for text, retrieval, and motion + key_text = self.key_text(self.text_norm(xf)) + (1 - text_cond_type) * -1000000 + key_retr = self.key_retr(self.retr_norm1(re_feat_key)) + (1 - retr_cond_type) * -1000000 + (1 - re_mask) * -1000000 + key_motion = self.key_motion(self.norm(x)) + (1 - src_mask) * -1000000 + + key = torch.cat((key_text, key_retr, key_motion), dim=1) # Concatenate all keys + + query = F.softmax(query.view(B, T, H, -1), dim=-1) + key = F.softmax(key.view(B, N, H, -1), dim=1) + + # Value computation from text, retrieval, and motion features + re_feat_value = re_motion.reshape(B, -1, D) + value_text = self.value_text(self.text_norm(xf)) * text_cond_type + value_retr = self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask + value_motion = self.value_motion(self.norm(x)) * src_mask + value = torch.cat((value_text, value_retr, value_motion), dim=1).view(B, N, H, -1) + + # Attention computation and output projection + attention = torch.einsum('bnhd,bnhl->bhdl', key, value) + y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) + y = x + self.proj_out(y, emb) + return y + + + +@ATTENTIONS.register_module() +class 
DualSemanticsModulatedAttention(nn.Module): + """ + Dual semantics-modulated attention module that handles two streams of motion features and integrates + them with text and retrieval features. + + Args: + latent_dim (int): Dimensionality of the latent (motion) features. + text_latent_dim (int): Dimensionality of the text features. + num_heads (int): Number of attention heads. + dropout (float): Dropout rate. + time_embed_dim (int): Dimensionality of time embeddings. + """ + + def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float, time_embed_dim: int): + super().__init__() + self.num_heads = num_heads + self.latent_dim = latent_dim + + # Layer Normalization for motion and text features + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + + # Linear projections for queries, keys, and values + self.query = nn.Linear(latent_dim, latent_dim) + self.key_text = nn.Linear(text_latent_dim, latent_dim) + self.value_text = nn.Linear(text_latent_dim, latent_dim) + self.key_motion = nn.Linear(latent_dim, latent_dim) + self.value_motion = nn.Linear(latent_dim, latent_dim) + self.key_inter = nn.Linear(latent_dim, latent_dim) + self.value_inter = nn.Linear(latent_dim, latent_dim) + + # Retrieval feature processing (motion and text) + self.retr_norm1 = nn.LayerNorm(2 * latent_dim) + self.retr_norm2 = nn.LayerNorm(latent_dim) + self.key_retr = nn.Linear(2 * latent_dim, latent_dim) + self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim)) + + # Dropout and output projection + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor, + cond_type: torch.Tensor, re_dict: dict) -> torch.Tensor: + """ + Forward pass of DualSemanticsModulatedAttention. + + Args: + x (torch.Tensor): Motion features of shape (B, T, 2*D). + xf (torch.Tensor): Text features of shape (B, N, L). + emb (torch.Tensor): Time embedding. + src_mask (torch.Tensor): Source mask for the input motion features. + cond_type (torch.Tensor): Condition type tensor. + re_dict (dict): Dictionary containing retrieval motion, text, and mask data. + + Returns: + torch.Tensor: Output tensor after dual attention modulation, shape (B, T, 2*D). 
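+
+        Example (illustrative shapes only; the two motion streams are stacked
+        along the channel axis and everything else mirrors
+        ``SemanticsModulatedAttention`` above):
+            >>> import torch
+            >>> attn = DualSemanticsModulatedAttention(
+            ...     latent_dim=64, text_latent_dim=32, num_heads=4,
+            ...     dropout=0.1, time_embed_dim=128)
+            >>> x = torch.randn(2, 16, 128)        # (B, T, 2 * D) with D = 64
+            >>> xf = torch.randn(2, 8, 32)
+            >>> emb = torch.randn(2, 128)
+            >>> src_mask = torch.ones(2, 16, 1)
+            >>> cond_type = torch.full((2, 1, 1), 99)
+            >>> re_dict = dict(re_motion=torch.randn(2, 2, 4, 64),
+            ...                re_text=torch.randn(2, 2, 1, 64),
+            ...                re_mask=torch.ones(2, 2, 4))
+            >>> y = attn(x, xf, emb, src_mask, cond_type, re_dict)  # (2, 16, 128)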
+ """ + x1 = x[:, :, :self.latent_dim].contiguous() + x2 = x[:, :, self.latent_dim:].contiguous() + B, T, D = x1.shape + re_motion = re_dict['re_motion'] + re_text = re_dict['re_text'] + re_mask = re_dict['re_mask'].reshape(B, -1, 1) + N = xf.shape[1] + x.shape[1] * 2 + re_motion.shape[1] * re_motion.shape[2] + + H = self.num_heads + + # Query computation for both streams + query1 = self.query(self.norm(x1)) + query2 = self.query(self.norm(x2)) + + # Retrieval key/value feature preparation + text_cond_type = (cond_type % 10 > 0).float() + retr_cond_type = (cond_type // 10 > 0).float() + re_text = re_text.repeat(1, 1, re_motion.shape[2], 1) + re_feat_key = torch.cat((re_motion, re_text), dim=-1) + re_feat_key = re_feat_key.reshape(B, -1, 2 * D) + + # Keys for text, retrieval, and motion + key_text = self.key_text(self.text_norm(xf)) + (1 - text_cond_type) * -1000000 + key_retr = self.key_retr(self.retr_norm1(re_feat_key)) + (1 - retr_cond_type) * -1000000 + (1 - re_mask) * -1000000 + key_motion1 = self.key_motion(self.norm(x1)) + (1 - src_mask) * -1000000 + key_motion2 = self.key_motion(self.norm(x2)) + (1 - src_mask) * -1000000 + + # Cross-attention keys for inter-stream communication + key_inter1 = self.key_inter(self.norm(x2)) + (1 - src_mask) * -1000000 + key_inter2 = self.key_inter(self.norm(x1)) + (1 - src_mask) * -1000000 + + # Concatenate all keys for the two streams + key1 = torch.cat((key_text, key_retr, key_motion1, key_inter1), dim=1) + key2 = torch.cat((key_text, key_retr, key_motion2, key_inter2), dim=1) + + # Softmax over queries and keys + query1 = F.softmax(query1.view(B, T, H, -1), dim=-1) + query2 = F.softmax(query2.view(B, T, H, -1), dim=-1) + key1 = F.softmax(key1.view(B, N, H, -1), dim=1) + key2 = F.softmax(key2.view(B, N, H, -1), dim=1) + + # Value computation for text, retrieval, and motion + re_feat_value = re_motion.reshape(B, -1, D) + value_text = self.value_text(self.text_norm(xf)) * text_cond_type + value_retr = self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask + value_motion1 = self.value_motion(self.norm(x1)) * src_mask + value_motion2 = self.value_motion(self.norm(x2)) * src_mask + + # Inter-stream value exchange + value_inter1 = self.value_inter(self.norm(x2)) * src_mask + value_inter2 = self.value_inter(self.norm(x1)) * src_mask + + # Concatenate values for both streams + value1 = torch.cat((value_text, value_retr, value_motion1, value_inter1), dim=1).view(B, N, H, -1) + value2 = torch.cat((value_text, value_retr, value_motion2, value_inter2), dim=1).view(B, N, H, -1) + + # Compute attention outputs for both streams + attention1 = torch.einsum('bnhd,bnhl->bhdl', key1, value1) + attention2 = torch.einsum('bnhd,bnhl->bhdl', key2, value2) + + # Apply attention to queries and compute final output + y1 = torch.einsum('bnhd,bhdl->bnhl', query1, attention1).reshape(B, T, D) + y2 = torch.einsum('bnhd,bhdl->bnhl', query2, attention2).reshape(B, T, D) + + # Combine both streams and apply output projection + y1 = x1 + self.proj_out(y1, emb) + y2 = x2 + self.proj_out(y2, emb) + y = torch.cat((y1, y2), dim=-1) + + return y diff --git a/mogen/models/builder.py b/mogen/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..08756798f3b98a69ba75c8748735f1335e30e34f --- /dev/null +++ b/mogen/models/builder.py @@ -0,0 +1,36 @@ +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.utils import Registry + + +def build_from_cfg(cfg, registry, default_args=None): + if cfg is None: + return None + return 
MMCV_MODELS.build_func(cfg, registry, default_args) + + +MODELS = Registry('models', parent=MMCV_MODELS, build_func=build_from_cfg) + +LOSSES = MODELS +ARCHITECTURES = MODELS +SUBMODULES = MODELS +ATTENTIONS = MODELS + + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + + +def build_architecture(cfg): + """Build framework.""" + return ARCHITECTURES.build(cfg) + + +def build_submodule(cfg): + """Build submodule.""" + return SUBMODULES.build(cfg) + + +def build_attention(cfg): + """Build attention.""" + return ATTENTIONS.build(cfg) diff --git a/mogen/models/losses/__init__.py b/mogen/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c183d649149095948f9991c81b6713e2e112ce --- /dev/null +++ b/mogen/models/losses/__init__.py @@ -0,0 +1,8 @@ +from .mse_loss import MSELoss, KinematicLoss +from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss, + weighted_loss) + +__all__ = [ + 'convert_to_one_hot', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', + 'MSELoss', 'KinematicLoss' +] diff --git a/mogen/models/losses/mse_loss.py b/mogen/models/losses/mse_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..742ea3d94e1db99138043eb05af661f3743683c4 --- /dev/null +++ b/mogen/models/losses/mse_loss.py @@ -0,0 +1,175 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import weighted_loss + + +def gmof(x, sigma): + """Geman-McClure error function. + + Args: + x (torch.Tensor): The input tensor. + sigma (float): The sigma value used in the calculation. + + Returns: + torch.Tensor: The computed Geman-McClure error. + """ + x_squared = x**2 + sigma_squared = sigma**2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + + +@weighted_loss +def mse_loss(pred, target): + """Wrapper for Mean Squared Error (MSE) loss. + + Args: + pred (torch.Tensor): Predicted values. + target (torch.Tensor): Ground truth values. + + Returns: + torch.Tensor: MSE loss. + """ + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def smooth_l1_loss(pred, target): + """Wrapper for Smooth L1 loss. + + Args: + pred (torch.Tensor): Predicted values. + target (torch.Tensor): Ground truth values. + + Returns: + torch.Tensor: Smooth L1 loss. + """ + return F.smooth_l1_loss(pred, target, reduction='none') + + +@weighted_loss +def l1_loss(pred, target): + """Wrapper for L1 loss. + + Args: + pred (torch.Tensor): Predicted values. + target (torch.Tensor): Ground truth values. + + Returns: + torch.Tensor: L1 loss. + """ + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss_with_gmof(pred, target, sigma): + """Extended MSE Loss with Geman-McClure function applied. + + Args: + pred (torch.Tensor): Predicted values. + target (torch.Tensor): Ground truth values. + sigma (float): The sigma value for the Geman-McClure function. + + Returns: + torch.Tensor: The loss value. + """ + loss = F.mse_loss(pred, target, reduction='none') + loss = gmof(loss, sigma) + return loss + + +@LOSSES.register_module() +class MSELoss(nn.Module): + """Mean Squared Error (MSE) Loss. + + Args: + reduction (str, optional): The method to reduce the loss to a scalar. + Options are 'none', 'mean', and 'sum'. Defaults to 'mean'. + loss_weight (float, optional): Weight of the loss. Defaults to 1.0. 
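+
+    Example (illustrative values):
+        >>> import torch
+        >>> criterion = MSELoss(reduction='mean', loss_weight=1.0)
+        >>> pred = torch.zeros(2, 4)
+        >>> target = torch.ones(2, 4)
+        >>> criterion(pred, target)
+        tensor(1.)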
+ """ + + def __init__(self, reduction='mean', loss_weight=1.0): + super().__init__() + assert reduction in (None, 'none', 'mean', 'sum') + self.reduction = 'none' if reduction is None else reduction + self.loss_weight = loss_weight + + def forward(self, pred, target, weight=None, avg_factor=None, reduction_override=None): + """Forward function to compute loss. + + Args: + pred (torch.Tensor): Predictions. + target (torch.Tensor): Ground truth. + weight (torch.Tensor, optional): Optional weight per sample. + avg_factor (int, optional): Factor for averaging the loss. + reduction_override (str, optional): Option to override reduction method. + + Returns: + torch.Tensor: Calculated loss. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = reduction_override if reduction_override else self.reduction + loss = self.loss_weight * mse_loss(pred, target, weight, reduction=reduction, avg_factor=avg_factor) + return loss + + +@LOSSES.register_module() +class KinematicLoss(nn.Module): + """Kinematic Loss for hierarchical motion prediction. + + Args: + reduction (str, optional): Reduction method ('none', 'mean', or 'sum'). + loss_type (str, optional): The type of loss to use ('mse', 'smooth_l1', 'l1'). + loss_weight (list[float], optional): List of weights for each stage of the hierarchy. + """ + + def __init__(self, reduction='mean', loss_type='mse', loss_weight=[1.0]): + super().__init__() + assert reduction in (None, 'none', 'mean', 'sum') + self.reduction = 'none' if reduction is None else reduction + self.loss_weight = loss_weight + self.num_stages = len(loss_weight) + + # Select loss function based on loss_type + if loss_type == 'mse': + self.loss_func = mse_loss + elif loss_type == 'smooth_l1': + self.loss_func = smooth_l1_loss + elif loss_type == 'l1': + self.loss_func = l1_loss + else: + raise ValueError(f"Unknown loss type: {loss_type}") + + def forward(self, pred, target, weight=None, avg_factor=None, reduction_override=None): + """Forward function for hierarchical kinematic loss. + + Args: + pred (torch.Tensor): The prediction tensor. + target (torch.Tensor): The target tensor. + weight (torch.Tensor, optional): Weights for each prediction. Defaults to None. + avg_factor (int, optional): Factor to average the loss. Defaults to None. + reduction_override (str, optional): Override reduction method. Defaults to None. + + Returns: + torch.Tensor: The calculated hierarchical loss. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = reduction_override if reduction_override else self.reduction + + total_loss = 0 + pred_t = pred.clone() + target_t = target.clone() + + # Apply loss function across stages + for i in range(self.num_stages): + stage_loss = self.loss_weight[i] * self.loss_func( + pred_t, target_t, weight, reduction=reduction, avg_factor=avg_factor) + total_loss += stage_loss + + # Compute differences between consecutive frames + pred_t = torch.cat((pred_t[:, :1, :], pred_t[:, 1:] - pred_t[:, :-1]), dim=1) + target_t = torch.cat((target_t[:, :1, :], target_t[:, 1:] - target_t[:, :-1]), dim=1) + + return total_loss diff --git a/mogen/models/losses/utils.py b/mogen/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bedb55051964620c3f5df5b3159c4c44720f3159 --- /dev/null +++ b/mogen/models/losses/utils.py @@ -0,0 +1,109 @@ +import functools + +import torch +import torch.nn.functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. 
+ Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + Return: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + :Example: + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper + + +def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: + """This function converts target class indices to one-hot vectors, given + the number of classes. + Args: + targets (Tensor): The ground truth label of the prediction + with shape (N, 1) + classes (int): the number of classes. + Returns: + Tensor: Processed loss values. 
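+    A quick doctest-style illustration (targets chosen arbitrarily):
+    :Example:
+    >>> import torch
+    >>> targets = torch.tensor([[0], [2]])
+    >>> convert_to_one_hot(targets, 3)
+    tensor([[1, 0, 0],
+            [0, 0, 1]])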
+ """ + assert (torch.max(targets).item() + < classes), 'Class Index must be less than number of classes' + one_hot_targets = torch.zeros((targets.shape[0], classes), + dtype=torch.long, + device=targets.device) + one_hot_targets.scatter_(1, targets.long(), 1) + return one_hot_targets diff --git a/mogen/models/transformers/__init__.py b/mogen/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc1654aedac1ef535ee6c2a3b50b739666b8ba35 --- /dev/null +++ b/mogen/models/transformers/__init__.py @@ -0,0 +1,14 @@ +from .actor import ACTORDecoder, ACTOREncoder +from .finemogen import FineMoGenTransformer +from .intergen import InterCLIP +from .mdm import MDMTransformer +from .momatmogen import MoMatMoGenTransformer +from .motiondiffuse import MotionDiffuseTransformer +from .remodiffuse import ReMoDiffuseTransformer +from .large_motion_model import LargeMotionModel + +__all__ = [ + 'ACTOREncoder', 'ACTORDecoder', 'MotionDiffuseTransformer', + 'ReMoDiffuseTransformer', 'MDMTransformer', 'FineMoGenTransformer', + 'InterCLIP', 'MoMatMoGenTransformer', 'LargeMotionModel' +] diff --git a/mogen/models/transformers/actor.py b/mogen/models/transformers/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..0b9508353788acd347e419ca48b2f032dd86d647 --- /dev/null +++ b/mogen/models/transformers/actor.py @@ -0,0 +1,298 @@ +import torch +from mmcv.runner import BaseModule +from torch import nn +from typing import Optional + +from mogen.models.utils.mlp import build_MLP +from mogen.models.utils.position_encoding import (LearnedPositionalEncoding, + SinusoidalPositionalEncoding) + +from ..builder import SUBMODULES + + +@SUBMODULES.register_module() +class ACTOREncoder(BaseModule): + """ACTOR Encoder module for motion data. + + Args: + max_seq_len (Optional[int]): Maximum sequence length for positional encoding. + njoints (Optional[int]): Number of joints for motion input. Defaults to None. + nfeats (Optional[int]): Number of features for each joint. Defaults to None. + input_feats (Optional[int]): Total input feature dimension. Defaults to None. + latent_dim (Optional[int]): Latent feature dimension. + condition_dim (Optional[int]): Dimension of condition features. Defaults to None. + num_heads (Optional[int]): Number of heads in the Transformer encoder. + ff_size (Optional[int]): Feedforward network size in the Transformer. + num_layers (Optional[int]): Number of layers in the Transformer encoder. + activation (Optional[str]): Activation function for the Transformer. + dropout (Optional[float]): Dropout probability. + use_condition (Optional[bool]): Whether to use conditioning inputs. + num_class (Optional[int]): Number of classes for conditional encoding. Defaults to None. + use_final_proj (Optional[bool]): Whether to apply a final projection layer. + output_var (Optional[bool]): Whether to output a variance along with mean. + pos_embedding (Optional[str]): Type of positional encoding ('sinusoidal' or 'learned'). + init_cfg (Optional[dict]): Initialization configuration. 
+ """ + + def __init__(self, + max_seq_len: Optional[int] = 16, + njoints: Optional[int] = None, + nfeats: Optional[int] = None, + input_feats: Optional[int] = None, + latent_dim: Optional[int] = 256, + condition_dim: Optional[int] = None, + num_heads: Optional[int] = 4, + ff_size: Optional[int] = 1024, + num_layers: Optional[int] = 8, + activation: Optional[str] = 'gelu', + dropout: Optional[float] = 0.1, + use_condition: Optional[bool] = False, + num_class: Optional[int] = None, + use_final_proj: Optional[bool] = False, + output_var: Optional[bool] = False, + pos_embedding: Optional[str] = 'sinusoidal', + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + + # If input_feats is not provided, compute it from njoints and nfeats + self.njoints = njoints + self.nfeats = nfeats + if input_feats is None: + assert self.njoints is not None and self.nfeats is not None + self.input_feats = njoints * nfeats + else: + self.input_feats = input_feats + + # Initialize parameters + self.max_seq_len = max_seq_len + self.latent_dim = latent_dim + self.condition_dim = condition_dim + self.use_condition = use_condition + self.num_class = num_class + self.use_final_proj = use_final_proj + self.output_var = output_var + + # Linear embedding layer for skeleton input features + self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim) + + # If using conditional inputs, set up layers for conditional processing + if self.use_condition: + if num_class is None: + self.mu_layer = build_MLP(self.condition_dim, self.latent_dim) + if self.output_var: + self.sigma_layer = build_MLP(self.condition_dim, self.latent_dim) + else: + self.mu_layer = nn.Parameter(torch.randn(num_class, self.latent_dim)) + if self.output_var: + self.sigma_layer = nn.Parameter(torch.randn(num_class, self.latent_dim)) + else: + if self.output_var: + self.query = nn.Parameter(torch.randn(2, self.latent_dim)) # Query for mu and sigma + else: + self.query = nn.Parameter(torch.randn(1, self.latent_dim)) # Query for mu only + + # Positional encoding setup + if pos_embedding == 'sinusoidal': + self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout) + else: + self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len + 2) + + # Transformer encoder layers + seqTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=self.latent_dim, nhead=num_heads, dim_feedforward=ff_size, dropout=dropout, activation=activation) + self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=num_layers) + + def forward(self, motion: torch.Tensor, motion_mask: Optional[torch.Tensor] = None, condition: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass for ACTOR Encoder. + + Args: + motion (torch.Tensor): Input motion data of shape (B, T, njoints, nfeats). + motion_mask (Optional[torch.Tensor]): Mask for valid motion data. Defaults to None. + condition (Optional[torch.Tensor]): Conditional input. Defaults to None. + + Returns: + torch.Tensor: Encoded latent representation. 
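+
+        Example (a minimal sketch; the dimensions are made up and the default
+        sinusoidal positional encoding is assumed):
+            >>> import torch
+            >>> enc = ACTOREncoder(input_feats=263, latent_dim=256)
+            >>> motion = torch.randn(2, 60, 263)
+            >>> motion_mask = torch.ones(2, 60)
+            >>> z = enc(motion, motion_mask=motion_mask)  # latent of shape (2, 256)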
+ """ + # Get batch size (B) and sequence length (T) + B, T = motion.shape[:2] + + # Flatten motion input into (B, T, input_feats) + motion = motion.view(B, T, -1) + + # Embed the motion input features into latent space + feature = self.skelEmbedding(motion) + + # Handle conditional inputs, concatenating condition queries + if self.use_condition: + if self.output_var: + if self.num_class is None: + sigma_query = self.sigma_layer(condition) + else: + sigma_query = self.sigma_layer[condition.long()] + sigma_query = sigma_query.view(B, 1, -1) + feature = torch.cat((sigma_query, feature), dim=1) + + if self.num_class is None: + mu_query = self.mu_layer(condition).view(B, 1, -1) + else: + mu_query = self.mu_layer[condition.long()].view(B, 1, -1) + feature = torch.cat((mu_query, feature), dim=1) + else: + query = self.query.view(1, -1, self.latent_dim).repeat(B, 1, 1) + feature = torch.cat((query, feature), dim=1) + + # If outputting variance, adjust the mask accordingly + if self.output_var: + motion_mask = torch.cat((torch.zeros(B, 2).to(motion.device), 1 - motion_mask), dim=1).bool() + else: + motion_mask = torch.cat((torch.zeros(B, 1).to(motion.device), 1 - motion_mask), dim=1).bool() + + # Positional encoding and transformer encoder processing + feature = feature.permute(1, 0, 2).contiguous() # Permute for transformer + feature = self.pos_encoder(feature) + feature = self.seqTransEncoder(feature, src_key_padding_mask=motion_mask) + + # Apply final projection if required + if self.use_final_proj: + mu = self.final_mu(feature[0]) + if self.output_var: + sigma = self.final_sigma(feature[1]) + return mu, sigma + return mu + else: + if self.output_var: + return feature[0], feature[1] + else: + return feature[0] + + +@SUBMODULES.register_module() +class ACTORDecoder(BaseModule): + """ACTOR Decoder module for motion generation. + + Args: + max_seq_len (Optional[int]): Maximum sequence length. + njoints (Optional[int]): Number of joints for motion input. Defaults to None. + nfeats (Optional[int]): Number of features for each joint. Defaults to None. + input_feats (Optional[int]): Total input feature dimension. Defaults to None. + input_dim (Optional[int]): Input feature dimension. + latent_dim (Optional[int]): Latent feature dimension. + condition_dim (Optional[int]): Dimension of condition features. Defaults to None. + num_heads (Optional[int]): Number of heads in the Transformer decoder. + ff_size (Optional[int]): Feedforward network size in the Transformer. + num_layers (Optional[int]): Number of layers in the Transformer decoder. + activation (Optional[str]): Activation function for the Transformer. + dropout (Optional[float]): Dropout probability. + use_condition (Optional[bool]): Whether to use conditioning inputs. + num_class (Optional[int]): Number of classes for conditional encoding. Defaults to None. + pos_embedding (Optional[str]): Type of positional encoding ('sinusoidal' or 'learned'). + init_cfg (Optional[dict]): Initialization configuration. 
+ """ + + def __init__(self, + max_seq_len: Optional[int] = 16, + njoints: Optional[int] = None, + nfeats: Optional[int] = None, + input_feats: Optional[int] = None, + input_dim: Optional[int] = 256, + latent_dim: Optional[int] = 256, + condition_dim: Optional[int] = None, + num_heads: Optional[int] = 4, + ff_size: Optional[int] = 1024, + num_layers: Optional[int] = 8, + activation: Optional[str] = 'gelu', + dropout: Optional[float] = 0.1, + use_condition: Optional[bool] = False, + num_class: Optional[int] = None, + pos_embedding: Optional[str] = 'sinusoidal', + init_cfg: Optional[dict] = None): + super().__init__(init_cfg=init_cfg) + + # If input_dim is different from latent_dim, we need a linear transformation + if input_dim != latent_dim: + self.linear = nn.Linear(input_dim, latent_dim) + else: + self.linear = nn.Identity() + + # Setting parameters for the number of joints, features, and sequence length + self.njoints = njoints + self.nfeats = nfeats + if input_feats is None: + assert self.njoints is not None and self.nfeats is not None + self.input_feats = njoints * nfeats + else: + self.input_feats = input_feats + + # Model configuration parameters + self.max_seq_len = max_seq_len + self.input_dim = input_dim + self.latent_dim = latent_dim + self.condition_dim = condition_dim + self.use_condition = use_condition + self.num_class = num_class + + # If using condition input, initialize condition bias + if self.use_condition: + if num_class is None: + self.condition_bias = build_MLP(condition_dim, latent_dim) + else: + self.condition_bias = nn.Parameter(torch.randn(num_class, latent_dim)) + + # Initialize positional encoding method + if pos_embedding == 'sinusoidal': + self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout) + else: + self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len) + + # Transformer Decoder layer definition + seqTransDecoderLayer = nn.TransformerDecoderLayer( + d_model=self.latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + + # Define the transformer decoder with multiple layers + self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer, num_layers=num_layers) + + # Final output layer to produce the pose from latent features + self.final = nn.Linear(self.latent_dim, self.input_feats) + + def forward(self, input: torch.Tensor, motion_mask: Optional[torch.Tensor] = None, condition: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass for ACTOR Decoder. + + Args: + input (torch.Tensor): Input tensor from the encoder, shape (B, latent_dim). + motion_mask (Optional[torch.Tensor]): Mask for motion data, shape (B, T). Defaults to None. + condition (Optional[torch.Tensor]): Conditional input, shape (B, condition_dim). Defaults to None. + + Returns: + torch.Tensor: Output pose predictions of shape (B, T, njoints * nfeats). 
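+
+        Example (a minimal sketch with made-up dimensions; the decoder is
+        assumed to be paired with an ``ACTOREncoder`` of matching latent size):
+            >>> import torch
+            >>> dec = ACTORDecoder(input_feats=263, input_dim=256,
+            ...                    latent_dim=256, max_seq_len=60)
+            >>> z = torch.randn(2, 256)                 # encoder latent
+            >>> motion_mask = torch.ones(2, 60)
+            >>> pose = dec(z, motion_mask=motion_mask)  # (2, 60, 263)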
+ """ + B = input.shape[0] # Get batch size + T = self.max_seq_len # Max sequence length for decoding + + # Transform input to latent space if needed + input = self.linear(input) + + # Add condition bias to input if using conditional inputs + if self.use_condition: + if self.num_class is None: + condition = self.condition_bias(condition) + else: + condition = self.condition_bias[condition.long()].squeeze(1) + input = input + condition + + # Positional encoding for query + query = self.pos_encoder.pe[:T, :].view(T, 1, -1).repeat(1, B, 1) + + # Prepare input and pass through Transformer Decoder + input = input.view(1, B, -1) # Prepare input shape for decoder + feature = self.seqTransDecoder( + tgt=query, memory=input, tgt_key_padding_mask=(1 - motion_mask).bool()) + + # Final layer to produce pose from latent features + pose = self.final(feature).permute(1, 0, 2).contiguous() + + return pose + diff --git a/mogen/models/transformers/finemogen.py b/mogen/models/transformers/finemogen.py new file mode 100644 index 0000000000000000000000000000000000000000..f9631f1bc5a2246eecab6658d833a91ddb6a50c7 --- /dev/null +++ b/mogen/models/transformers/finemogen.py @@ -0,0 +1,547 @@ +import numpy as np +import torch +from torch import nn + +from typing import Optional, Dict, List + +from mogen.models.utils.misc import zero_module + +from ..builder import SUBMODULES, build_attention +from ..utils.stylization_block import StylizationBlock +from .motion_transformer import MotionTransformer + + +def get_kit_slice(idx: int) -> List[int]: + """ + Get the slice indices for the KIT skeleton. + + Args: + idx (int): The index of the skeleton part. + + Returns: + List[int]: Slice indices for the specified skeleton part. + """ + if idx == 0: + return [0, 1, 2, 3, 184, 185, 186, 247, 248, 249, 250] + return [ + 4 + (idx - 1) * 3, + 4 + (idx - 1) * 3 + 1, + 4 + (idx - 1) * 3 + 2, + 64 + (idx - 1) * 6, + 64 + (idx - 1) * 6 + 1, + 64 + (idx - 1) * 6 + 2, + 64 + (idx - 1) * 6 + 3, + 64 + (idx - 1) * 6 + 4, + 64 + (idx - 1) * 6 + 5, + 184 + idx * 3, + 184 + idx * 3 + 1, + 184 + idx * 3 + 2, + ] + + +def get_t2m_slice(idx: int) -> List[int]: + """ + Get the slice indices for the T2M skeleton. + + Args: + idx (int): The index of the skeleton part. + + Returns: + List[int]: Slice indices for the specified skeleton part. + """ + if idx == 0: + return [0, 1, 2, 3, 193, 194, 195, 259, 260, 261, 262] + return [ + 4 + (idx - 1) * 3, + 4 + (idx - 1) * 3 + 1, + 4 + (idx - 1) * 3 + 2, + 67 + (idx - 1) * 6, + 67 + (idx - 1) * 6 + 1, + 67 + (idx - 1) * 6 + 2, + 67 + (idx - 1) * 6 + 3, + 67 + (idx - 1) * 6 + 4, + 67 + (idx - 1) * 6 + 5, + 193 + idx * 3, + 193 + idx * 3 + 1, + 193 + idx * 3 + 2, + ] + + +def get_part_slice(idx_list: List[int], func) -> List[int]: + """ + Get the slice indices for a list of indices. + + Args: + idx_list (List[int]): List of part indices. + func (Callable): Function to get slice indices for each part. + + Returns: + List[int]: Concatenated list of slice indices for the parts. + """ + result = [] + for idx in idx_list: + result.extend(func(idx)) + return result + + +class PoseEncoder(nn.Module): + """ + Pose Encoder to process motion data and encode body parts into latent representations. 
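+
+    Example (illustrative; a HumanML3D-style motion vector has 263 channels
+    and the eight per-part embeddings are concatenated on the last axis):
+        >>> import torch
+        >>> enc = PoseEncoder(dataset_name="human_ml3d", latent_dim=64,
+        ...                   input_dim=263)
+        >>> motion = torch.randn(2, 196, 263)
+        >>> enc(motion).shape
+        torch.Size([2, 196, 512])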
+ """ + + def __init__(self, + dataset_name: str = "human_ml3d", + latent_dim: int = 64, + input_dim: int = 263): + super().__init__() + self.dataset_name = dataset_name + if dataset_name == "human_ml3d": + func = get_t2m_slice + self.head_slice = get_part_slice([12, 15], func) + self.stem_slice = get_part_slice([3, 6, 9], func) + self.larm_slice = get_part_slice([14, 17, 19, 21], func) + self.rarm_slice = get_part_slice([13, 16, 18, 20], func) + self.lleg_slice = get_part_slice([2, 5, 8, 11], func) + self.rleg_slice = get_part_slice([1, 4, 7, 10], func) + self.root_slice = get_part_slice([0], func) + self.body_slice = get_part_slice([_ for _ in range(22)], func) + elif dataset_name == "kit_ml": + func = get_kit_slice + self.head_slice = get_part_slice([4], func) + self.stem_slice = get_part_slice([1, 2, 3], func) + self.larm_slice = get_part_slice([8, 9, 10], func) + self.rarm_slice = get_part_slice([5, 6, 7], func) + self.lleg_slice = get_part_slice([16, 17, 18, 19, 20], func) + self.rleg_slice = get_part_slice([11, 12, 13, 14, 15], func) + self.root_slice = get_part_slice([0], func) + self.body_slice = get_part_slice([_ for _ in range(21)], func) + else: + raise ValueError() + + self.head_embed = nn.Linear(len(self.head_slice), latent_dim) + self.stem_embed = nn.Linear(len(self.stem_slice), latent_dim) + self.larm_embed = nn.Linear(len(self.larm_slice), latent_dim) + self.rarm_embed = nn.Linear(len(self.rarm_slice), latent_dim) + self.lleg_embed = nn.Linear(len(self.lleg_slice), latent_dim) + self.rleg_embed = nn.Linear(len(self.rleg_slice), latent_dim) + self.root_embed = nn.Linear(len(self.root_slice), latent_dim) + self.body_embed = nn.Linear(len(self.body_slice), latent_dim) + + assert len(set(self.body_slice)) == input_dim + + def forward(self, motion: torch.Tensor) -> torch.Tensor: + """ + Forward pass for encoding the motion into body part embeddings. + + Args: + motion (torch.Tensor): Input motion tensor of shape (B, T, D). + + Returns: + torch.Tensor: Concatenated latent representations of body parts. + """ + head_feat = self.head_embed(motion[:, :, self.head_slice].contiguous()) + stem_feat = self.stem_embed(motion[:, :, self.stem_slice].contiguous()) + larm_feat = self.larm_embed(motion[:, :, self.larm_slice].contiguous()) + rarm_feat = self.rarm_embed(motion[:, :, self.rarm_slice].contiguous()) + lleg_feat = self.lleg_embed(motion[:, :, self.lleg_slice].contiguous()) + rleg_feat = self.rleg_embed(motion[:, :, self.rleg_slice].contiguous()) + root_feat = self.root_embed(motion[:, :, self.root_slice].contiguous()) + body_feat = self.body_embed(motion[:, :, self.body_slice].contiguous()) + feat = torch.cat((head_feat, stem_feat, larm_feat, rarm_feat, + lleg_feat, rleg_feat, root_feat, body_feat), + dim=-1) + return feat + + +class PoseDecoder(nn.Module): + """ + Pose Decoder to decode the latent representations of body parts back into motion. 
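+
+    Example (illustrative; the input is the ``8 * latent_dim`` feature
+    produced by ``PoseEncoder`` above):
+        >>> import torch
+        >>> dec = PoseDecoder(dataset_name="human_ml3d", latent_dim=64,
+        ...                   output_dim=263)
+        >>> feat = torch.randn(2, 196, 8 * 64)
+        >>> dec(feat).shape
+        torch.Size([2, 196, 263])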
+ """ + + def __init__(self, + dataset_name: str = "human_ml3d", + latent_dim: int = 64, + output_dim: int = 263): + super().__init__() + self.dataset_name = dataset_name + self.latent_dim = latent_dim + self.output_dim = output_dim + if dataset_name == "human_ml3d": + func = get_t2m_slice + self.head_slice = get_part_slice([12, 15], func) + self.stem_slice = get_part_slice([3, 6, 9], func) + self.larm_slice = get_part_slice([14, 17, 19, 21], func) + self.rarm_slice = get_part_slice([13, 16, 18, 20], func) + self.lleg_slice = get_part_slice([2, 5, 8, 11], func) + self.rleg_slice = get_part_slice([1, 4, 7, 10], func) + self.root_slice = get_part_slice([0], func) + self.body_slice = get_part_slice([_ for _ in range(22)], func) + elif dataset_name == "kit_ml": + func = get_kit_slice + self.head_slice = get_part_slice([4], func) + self.stem_slice = get_part_slice([1, 2, 3], func) + self.larm_slice = get_part_slice([8, 9, 10], func) + self.rarm_slice = get_part_slice([5, 6, 7], func) + self.lleg_slice = get_part_slice([16, 17, 18, 19, 20], func) + self.rleg_slice = get_part_slice([11, 12, 13, 14, 15], func) + self.root_slice = get_part_slice([0], func) + self.body_slice = get_part_slice([_ for _ in range(21)], func) + else: + raise ValueError() + + self.head_out = nn.Linear(latent_dim, len(self.head_slice)) + self.stem_out = nn.Linear(latent_dim, len(self.stem_slice)) + self.larm_out = nn.Linear(latent_dim, len(self.larm_slice)) + self.rarm_out = nn.Linear(latent_dim, len(self.rarm_slice)) + self.lleg_out = nn.Linear(latent_dim, len(self.lleg_slice)) + self.rleg_out = nn.Linear(latent_dim, len(self.rleg_slice)) + self.root_out = nn.Linear(latent_dim, len(self.root_slice)) + self.body_out = nn.Linear(latent_dim, len(self.body_slice)) + + def forward(self, motion: torch.Tensor) -> torch.Tensor: + """ + Forward pass to decode the latent body part features back to motion. + + Args: + motion (torch.Tensor): Input tensor of shape (B, T, D). + + Returns: + torch.Tensor: Output motion tensor of shape (B, T, output_dim). + """ + B, T = motion.shape[:2] + D = self.latent_dim + head_feat = self.head_out(motion[:, :, :D].contiguous()) + stem_feat = self.stem_out(motion[:, :, D:2 * D].contiguous()) + larm_feat = self.larm_out(motion[:, :, 2 * D:3 * D].contiguous()) + rarm_feat = self.rarm_out(motion[:, :, 3 * D:4 * D].contiguous()) + lleg_feat = self.lleg_out(motion[:, :, 4 * D:5 * D].contiguous()) + rleg_feat = self.rleg_out(motion[:, :, 5 * D:6 * D].contiguous()) + root_feat = self.root_out(motion[:, :, 6 * D:7 * D].contiguous()) + body_feat = self.body_out(motion[:, :, 7 * D:].contiguous()) + output = torch.zeros(B, T, self.output_dim).type_as(motion) + output[:, :, self.head_slice] = head_feat + output[:, :, self.stem_slice] = stem_feat + output[:, :, self.larm_slice] = larm_feat + output[:, :, self.rarm_slice] = rarm_feat + output[:, :, self.lleg_slice] = lleg_feat + output[:, :, self.rleg_slice] = rleg_feat + output[:, :, self.root_slice] = root_feat + output = (output + body_feat) / 2.0 + return output + + +class SFFN(nn.Module): + """ + A Stylized Feed-Forward Network (SFFN) module for transformer layers. + + Args: + latent_dim (int): Dimensionality of the input. + ffn_dim (int): Dimensionality of the feed-forward layer. + dropout (float): Dropout probability. + time_embed_dim (int): Dimensionality of the time embedding. + norm (str): Normalization type ('None'). + activation (str): Activation function ('GELU'). 
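+
+    Example (illustrative; note that ``latent_dim`` is the per-part width, so
+    the input tensor is expected to carry ``8 * latent_dim`` channels):
+        >>> import torch
+        >>> ffn = SFFN(latent_dim=64, ffn_dim=256, dropout=0.1,
+        ...            time_embed_dim=128)
+        >>> x = torch.randn(2, 196, 8 * 64)
+        >>> emb = torch.randn(2, 128)
+        >>> y = ffn(x, emb)  # (2, 196, 512)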
+ """ + + def __init__(self, + latent_dim: int, + ffn_dim: int, + dropout: float, + time_embed_dim: int, + norm: str = "None", + activation: str = "GELU", + **kwargs): + super().__init__() + self.linear1_list = nn.ModuleList() + self.linear2_list = nn.ModuleList() + + channel_mul = 1 + if activation == "GELU": + self.activation = nn.GELU() + + for i in range(8): + self.linear1_list.append(nn.Linear(latent_dim, ffn_dim * channel_mul)) + self.linear2_list.append(nn.Linear(ffn_dim, latent_dim)) + + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim * 8, time_embed_dim, dropout) + + if norm == "None": + self.norm = nn.Identity() + + def forward(self, x: torch.Tensor, emb: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Forward pass of the SFFN layer. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, D). + emb (torch.Tensor): Embedding tensor for time step modulation. + + Returns: + torch.Tensor: Output tensor of shape (B, T, D). + """ + B, T, D = x.shape + x = self.norm(x) + x = x.reshape(B, T, 8, -1) + output = [] + for i in range(8): + feat = x[:, :, i].contiguous() + feat = self.dropout(self.activation(self.linear1_list[i](feat))) + feat = self.linear2_list[i](feat) + output.append(feat) + y = torch.cat(output, dim=-1) + y = x.reshape(B, T, D) + self.proj_out(y, emb) + return y + + +class DecoderLayer(nn.Module): + """ + A transformer decoder layer with cross-attention and feed-forward network (SFFN). + + Args: + ca_block_cfg (Optional[Dict]): Configuration for the cross-attention block. + ffn_cfg (Optional[Dict]): Configuration for the feed-forward network (SFFN). + """ + + def __init__(self, ca_block_cfg: Optional[Dict] = None, ffn_cfg: Optional[Dict] = None): + super().__init__() + self.ca_block = build_attention(ca_block_cfg) + self.ffn = SFFN(**ffn_cfg) + + def forward(self, **kwargs) -> torch.Tensor: + """ + Forward pass of the decoder layer. + + Args: + kwargs: Keyword arguments for attention and feed-forward layers. + + Returns: + torch.Tensor: Output of the decoder layer. + """ + if self.ca_block is not None: + x = self.ca_block(**kwargs) + kwargs.update({'x': x}) + if self.ffn is not None: + x = self.ffn(**kwargs) + return x + + +@SUBMODULES.register_module() +class FineMoGenTransformer(MotionTransformer): + """ + A transformer model for motion generation using fine-grained control with Diffusion. + + Args: + scale_func_cfg (Optional[Dict]): Configuration for scaling function. + pose_encoder_cfg (Optional[Dict]): Configuration for the PoseEncoder. + pose_decoder_cfg (Optional[Dict]): Configuration for the PoseDecoder. + moe_route_loss_weight (float): Weight for the Mixture of Experts (MoE) routing loss. + template_kl_loss_weight (float): Weight for the KL loss in template generation. + fine_mode (bool): Whether to enable fine mode for control over body parts. 
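+
+    Example config (a hypothetical sketch; every value is a placeholder chosen
+    for illustration, and the fields inherited from ``MotionTransformer``,
+    e.g. ``latent_dim``, ``num_layers``, ``ca_block_cfg``, are omitted):
+        >>> transformer_cfg = dict(
+        ...     type='FineMoGenTransformer',
+        ...     pose_encoder_cfg=dict(dataset_name='human_ml3d',
+        ...                           latent_dim=64, input_dim=263),
+        ...     pose_decoder_cfg=dict(dataset_name='human_ml3d',
+        ...                           latent_dim=64, output_dim=263),
+        ...     scale_func_cfg=dict(scale=6.5),
+        ...     moe_route_loss_weight=1.0,
+        ...     template_kl_loss_weight=0.0001,
+        ...     fine_mode=False)  # typically consumed through the SUBMODULES registry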
+ """ + + def __init__(self, + scale_func_cfg: Optional[Dict] = None, + pose_encoder_cfg: Optional[Dict] = None, + pose_decoder_cfg: Optional[Dict] = None, + moe_route_loss_weight: float = 1.0, + template_kl_loss_weight: float = 0.0001, + fine_mode: bool = False, + **kwargs): + super().__init__(**kwargs) + self.scale_func_cfg = scale_func_cfg + self.joint_embed = PoseEncoder(**pose_encoder_cfg) + self.out = zero_module(PoseDecoder(**pose_decoder_cfg)) + self.moe_route_loss_weight = moe_route_loss_weight + self.template_kl_loss_weight = template_kl_loss_weight + self.mean = np.load("data/datasets/kit_ml/mean.npy") + self.std = np.load("data/datasets/kit_ml/std.npy") + self.fine_mode = fine_mode + + def build_temporal_blocks(self, sa_block_cfg: Optional[Dict], ca_block_cfg: Optional[Dict], ffn_cfg: Optional[Dict]): + """ + Build temporal decoder blocks for the model. + + Args: + sa_block_cfg (Optional[Dict]): Configuration for self-attention blocks. + ca_block_cfg (Optional[Dict]): Configuration for cross-attention blocks. + ffn_cfg (Optional[Dict]): Configuration for feed-forward networks. + """ + self.temporal_decoder_blocks = nn.ModuleList() + for i in range(self.num_layers): + if isinstance(ffn_cfg, list): + ffn_cfg_block = ffn_cfg[i] + else: + ffn_cfg_block = ffn_cfg + self.temporal_decoder_blocks.append(DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg_block)) + + def scale_func(self, timestep: int) -> Dict[str, float]: + """ + Scaling function for text and none coefficient based on timestep. + + Args: + timestep (int): Current diffusion timestep. + + Returns: + Dict[str, float]: Scaling factors for text and non-text conditioning. + """ + scale = self.scale_func_cfg['scale'] + w = (1 - (1000 - timestep) / 1000) * scale + 1 + return {'text_coef': w, 'none_coef': 1 - w} + + def aux_loss(self) -> Dict[str, torch.Tensor]: + """ + Auxiliary loss computation for MoE routing and KL loss. + + Returns: + Dict[str, torch.Tensor]: Computed auxiliary losses. + """ + aux_loss = 0 + kl_loss = 0 + for module in self.temporal_decoder_blocks: + if hasattr(module.ca_block, 'aux_loss'): + aux_loss = aux_loss + module.ca_block.aux_loss + if hasattr(module.ca_block, 'kl_loss'): + kl_loss = kl_loss + module.ca_block.kl_loss + losses = {} + if aux_loss > 0: + losses['moe_route_loss'] = aux_loss * self.moe_route_loss_weight + if kl_loss > 0: + losses['template_kl_loss'] = kl_loss * self.template_kl_loss_weight + return losses + + def get_precompute_condition(self, + text: Optional[str] = None, + motion_length: Optional[torch.Tensor] = None, + xf_out: Optional[torch.Tensor] = None, + re_dict: Optional[Dict] = None, + device: Optional[torch.device] = None, + sample_idx: Optional[int] = None, + clip_feat: Optional[torch.Tensor] = None, + **kwargs) -> Dict[str, torch.Tensor]: + """ + Precompute conditioning features for text or other modalities. + + Args: + text (Optional[str]): Text input for conditioning. + motion_length (Optional[torch.Tensor]): Length of the motion sequence. + xf_out (Optional[torch.Tensor]): Precomputed text features. + re_dict (Optional[Dict]): Additional features dictionary. + device (Optional[torch.device]): Target device for the model. + sample_idx (Optional[int]): Sample index for specific conditioning. + clip_feat (Optional[torch.Tensor]): Precomputed CLIP features. + + Returns: + Dict[str, torch.Tensor]: Precomputed conditioning features. 
+ """ + if xf_out is None: + xf_out = self.encode_text(text, clip_feat, device) + output = {'xf_out': xf_out} + return output + + def post_process(self, motion: torch.Tensor) -> torch.Tensor: + """ + Post-process motion data by unnormalizing if necessary. + + Args: + motion (torch.Tensor): Input motion data. + + Returns: + torch.Tensor: Processed motion data. + """ + if self.post_process_cfg is not None: + if self.post_process_cfg.get("unnormalized_infer", False): + mean = torch.from_numpy(np.load(self.post_process_cfg['mean_path'])).type_as(motion) + std = torch.from_numpy(np.load(self.post_process_cfg['std_path'])).type_as(motion) + motion = motion * std + mean + return motion + + def forward_train(self, + h: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + emb: Optional[torch.Tensor] = None, + xf_out: Optional[torch.Tensor] = None, + motion_length: Optional[torch.Tensor] = None, + num_intervals: int = 1, + **kwargs) -> torch.Tensor: + """ + Forward pass during training. + + Args: + h (torch.Tensor): Input tensor of shape (B, T, D). + src_mask (Optional[torch.Tensor]): Source mask tensor. + emb (Optional[torch.Tensor]): Time embedding tensor. + xf_out (Optional[torch.Tensor]): Precomputed text features. + motion_length (Optional[torch.Tensor]): Lengths of motion sequences. + num_intervals (int): Number of intervals for processing. + + Returns: + torch.Tensor: Output tensor of shape (B, T, D). + """ + B, T = h.shape[0], h.shape[1] + cond_type = torch.randint(0, 100, size=(B, 1, 1)).repeat(1, 8, 1).to(h.device) if self.fine_mode else torch.randint(0, 100, size=(B, 1, 1)).to(h.device) + for module in self.temporal_decoder_blocks: + h = module(x=h, + xf=xf_out, + emb=emb, + src_mask=src_mask, + cond_type=cond_type, + motion_length=motion_length, + num_intervals=num_intervals) + + output = self.out(h).view(B, T, -1).contiguous() + return output + + def forward_test(self, + h: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + emb: Optional[torch.Tensor] = None, + xf_out: Optional[torch.Tensor] = None, + timesteps: Optional[torch.Tensor] = None, + motion_length: Optional[torch.Tensor] = None, + num_intervals: int = 1, + **kwargs) -> torch.Tensor: + """ + Forward pass during inference. + + Args: + h (torch.Tensor): Input tensor of shape (B, T, D). + src_mask (Optional[torch.Tensor]): Source mask tensor. + emb (Optional[torch.Tensor]): Time embedding tensor. + xf_out (Optional[torch.Tensor]): Precomputed text features. + timesteps (Optional[torch.Tensor]): Diffusion timesteps. + motion_length (Optional[torch.Tensor]): Lengths of motion sequences. + num_intervals (int): Number of intervals for processing. + + Returns: + torch.Tensor: Output tensor of shape (B, T, D). 
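+
+        Note (summary of the logic below): the batch is duplicated so that one
+        half is denoised with the text condition switched on and the other
+        half with no condition at all; the two predictions are then blended
+        with the timestep-dependent ``text_coef`` / ``none_coef`` weights
+        returned by ``scale_func``, a classifier-free-guidance-style
+        combination.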
+ """ + B, T = h.shape[0], h.shape[1] + text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1 + none_cond_type = torch.zeros(B, 1, 1).to(h.device) + + all_cond_type = torch.cat((text_cond_type, none_cond_type), dim=0) + h = h.repeat(2, 1, 1) + xf_out = xf_out.repeat(2, 1, 1) + emb = emb.repeat(2, 1) + src_mask = src_mask.repeat(2, 1, 1) + motion_length = motion_length.repeat(2, 1) + for module in self.temporal_decoder_blocks: + h = module(x=h, + xf=xf_out, + emb=emb, + src_mask=src_mask, + cond_type=all_cond_type, + motion_length=motion_length, + num_intervals=num_intervals) + out = self.out(h).view(2 * B, T, -1).contiguous() + out_text = out[:B].contiguous() + out_none = out[B:].contiguous() + + coef_cfg = self.scale_func(int(timesteps[0])) + text_coef = coef_cfg['text_coef'] + none_coef = coef_cfg['none_coef'] + output = out_text * text_coef + out_none * none_coef + return output + + diff --git a/mogen/models/transformers/intergen.py b/mogen/models/transformers/intergen.py new file mode 100644 index 0000000000000000000000000000000000000000..6ba04d6f7681d3d74cd9f714c5f434f31ac87bb8 --- /dev/null +++ b/mogen/models/transformers/intergen.py @@ -0,0 +1,204 @@ +import clip +import numpy as np +import torch +import torch.nn as nn +from mmcv.runner import BaseModule + +from mogen.models.utils.misc import set_requires_grad + +from ..builder import SUBMODULES + +loss_ce = nn.CrossEntropyLoss() + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.0, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.arange(0, d_model, 2).float() * \ + (-np.log(10000.0) / d_model) + div_term = torch.exp(div_term) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + # pe = pe.unsqueeze(0)#.transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[1], :].unsqueeze(0) + return self.dropout(x) + + +class MotionEncoder(nn.Module): + + def __init__(self, input_dim, latent_dim, ff_size, num_layers, num_heads, + dropout, activation): + super().__init__() + + self.input_feats = input_dim + self.latent_dim = latent_dim + self.ff_size = ff_size + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout = dropout + self.activation = activation + + self.query_token = nn.Parameter(torch.randn(1, self.latent_dim)) + + self.embed_motion = nn.Linear(self.input_feats * 2, self.latent_dim) + self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, + self.dropout, + max_len=2000) + + seqTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + self.transformer = nn.TransformerEncoder(seqTransEncoderLayer, + num_layers=self.num_layers) + self.out_ln = nn.LayerNorm(self.latent_dim) + self.out = nn.Linear(self.latent_dim, 512) + + def forward(self, motion, motion_mask): + x, mask = motion, motion_mask + B, T = x.shape[:2] + + x = x.reshape(B, T, 2, -1)[..., :-4].reshape(B, T, -1) + + x_emb = self.embed_motion(x) + + idx = torch.zeros(B, dtype=torch.long, device=x.device) + emb = torch.cat([self.query_token[idx][:, None], x_emb], dim=1) + + seq_mask = (mask > 0.5) + token_mask = torch.ones((B, 1), dtype=bool, device=x.device) + valid_mask = torch.cat([token_mask, 
seq_mask], dim=1) + + h = self.sequence_pos_encoder(emb) + + h = h.permute(1, 0, 2) + h = self.transformer(h, src_key_padding_mask=~valid_mask).permute( + 1, 0, 2) + h = self.out_ln(h) + motion_emb = self.out(h[:, 0]) + + return motion_emb + + +@SUBMODULES.register_module() +class InterCLIP(BaseModule): + + def __init__(self, + input_dim=258, + latent_dim=1024, + ff_size=2048, + num_layers=8, + num_heads=8, + dropout=0.1, + activation="gelu", + init_cfg=None): + super().__init__() + self.latent_dim = latent_dim + self.motion_encoder = MotionEncoder(input_dim=input_dim, + latent_dim=latent_dim, + ff_size=ff_size, + num_layers=num_layers, + num_heads=num_heads, + dropout=dropout, + activation=activation) + + self.latent_dim = self.latent_dim + + clip_model, _ = clip.load("ViT-L/14@336px", device="cpu", jit=False) + + self.token_embedding = clip_model.token_embedding + self.positional_embedding = clip_model.positional_embedding + self.dtype = clip_model.dtype + self.latent_scale = nn.Parameter(torch.Tensor([1])) + + set_requires_grad(self.token_embedding, False) + + textTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=768, + nhead=8, + dim_feedforward=ff_size, + dropout=0.1, + activation="gelu") + self.textTransEncoder = nn.TransformerEncoder(textTransEncoderLayer, + num_layers=8) + self.text_ln = nn.LayerNorm(768) + self.out = nn.Linear(768, 512) + + self.clip_training = "text_" + self.l1_criterion = torch.nn.L1Loss(reduction='mean') + assert init_cfg['type'] == 'Pretrained' + self.load_pretrained(init_cfg['checkpoint']) + + def compute_loss(self, batch): + losses = {} + losses["total"] = 0 + + # compute clip losses + batch = self.encode_text(batch) + batch = self.encode_motion(batch) + + mixed_clip_loss, clip_losses = self.compute_clip_losses(batch) + losses.update(clip_losses) + losses["total"] += mixed_clip_loss + + return losses["total"], losses + + def generate_src_mask(self, T, length): + B = length.shape[0] + src_mask = torch.ones(B, T) + for i in range(B): + for j in range(length[i], T): + src_mask[i, j] = 0 + return src_mask + + def encode_motion(self, + motion, + motion_length=None, + motion_mask=None, + **kwargs): + motion_emb = self.motion_encoder(motion, motion_mask) + motion_emb = motion_emb / motion_emb.norm(dim=-1, keepdim=True) + motion_emb = motion_emb * self.latent_scale + return motion_emb + + def encode_text(self, text, device=None, **kwargs): + raw_text = text + with torch.no_grad(): + text = clip.tokenize(raw_text, truncate=True).to(device) + x = self.token_embedding(text).type(self.dtype) + pe_tokens = x + self.positional_embedding.type(self.dtype) + + pe_tokens = pe_tokens.permute(1, 0, 2) + out = self.textTransEncoder(pe_tokens) + out = out.permute(1, 0, 2) + + out = self.text_ln(out) + + out = out[torch.arange(x.shape[0]), text.argmax(dim=-1)] + out = self.out(out) + + text_emb = out + text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True) + text_emb = text_emb * self.latent_scale + + return text_emb + + def load_pretrained(self, ckpt_path): + checkpoint = torch.load(ckpt_path, map_location="cpu") + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + if "model" in k: + state_dict[k.replace("model.", "")] = state_dict.pop(k) + self.load_state_dict(state_dict, strict=True) diff --git a/mogen/models/transformers/large_motion_model.py b/mogen/models/transformers/large_motion_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9b613b2f12d9ca599632e9fa6cef09f39cf29308 --- /dev/null +++ 
b/mogen/models/transformers/large_motion_model.py @@ -0,0 +1,832 @@ +import numpy as np +import torch +from torch import nn +import random +from typing import Optional, List, Dict + +from mogen.models.utils.misc import zero_module +from ..builder import SUBMODULES, build_attention +from ..utils.stylization_block import StylizationBlock +from .motion_transformer import MotionTransformer +from mogen.models.utils.position_encoding import timestep_embedding +from scipy.ndimage import gaussian_filter + + +def get_tomato_slice(idx: int) -> List[int]: + """Return specific slices for the tomato dataset.""" + if idx == 0: + result = [0, 1, 2, 3, 463, 464, 465] + else: + result = [ + 4 + (idx - 1) * 3, + 4 + (idx - 1) * 3 + 1, + 4 + (idx - 1) * 3 + 2, + 157 + (idx - 1) * 6, + 157 + (idx - 1) * 6 + 1, + 157 + (idx - 1) * 6 + 2, + 157 + (idx - 1) * 6 + 3, + 157 + (idx - 1) * 6 + 4, + 157 + (idx - 1) * 6 + 5, + 463 + idx * 3, + 463 + idx * 3 + 1, + 463 + idx * 3 + 2, + ] + return result + + +def get_part_slice(idx_list: List[int], func) -> List[int]: + """Return a list of slices by applying the provided function.""" + result = [] + for idx in idx_list: + result.extend(func(idx)) + return result + + +class SinglePoseEncoder(nn.Module): + """Encoder module for individual pose, separating different body parts.""" + + def __init__(self, latent_dim: int = 64): + super().__init__() + func = get_tomato_slice + self.root_slice = get_part_slice([0], func) + self.head_slice = get_part_slice([12, 15], func) + self.stem_slice = get_part_slice([3, 6, 9], func) + self.larm_slice = get_part_slice([14, 17, 19, 21], func) + self.rarm_slice = get_part_slice([13, 16, 18, 20], func) + self.lleg_slice = get_part_slice([2, 5, 8, 11], func) + self.rleg_slice = get_part_slice([1, 4, 7, 10], func) + self.lhnd_slice = get_part_slice(range(22, 37), func) + self.rhnd_slice = get_part_slice(range(37, 52), func) + self.face_slice = range(619, 669) + + # Initialize linear layers for each body part embedding + self.root_embed = nn.Linear(len(self.root_slice), latent_dim) + self.head_embed = nn.Linear(len(self.head_slice), latent_dim) + self.stem_embed = nn.Linear(len(self.stem_slice), latent_dim) + self.larm_embed = nn.Linear(len(self.larm_slice), latent_dim) + self.rarm_embed = nn.Linear(len(self.rarm_slice), latent_dim) + self.lleg_embed = nn.Linear(len(self.lleg_slice), latent_dim) + self.rleg_embed = nn.Linear(len(self.rleg_slice), latent_dim) + self.lhnd_embed = nn.Linear(len(self.lhnd_slice), latent_dim) + self.rhnd_embed = nn.Linear(len(self.rhnd_slice), latent_dim) + self.face_embed = nn.Linear(len(self.face_slice), latent_dim) + + def forward(self, motion: torch.Tensor) -> torch.Tensor: + """Forward pass to embed different parts of the motion tensor.""" + root_feat = self.root_embed(motion[:, :, self.root_slice].contiguous()) + head_feat = self.head_embed(motion[:, :, self.head_slice].contiguous()) + stem_feat = self.stem_embed(motion[:, :, self.stem_slice].contiguous()) + larm_feat = self.larm_embed(motion[:, :, self.larm_slice].contiguous()) + rarm_feat = self.rarm_embed(motion[:, :, self.rarm_slice].contiguous()) + lleg_feat = self.lleg_embed(motion[:, :, self.lleg_slice].contiguous()) + rleg_feat = self.rleg_embed(motion[:, :, self.rleg_slice].contiguous()) + lhnd_feat = self.lhnd_embed(motion[:, :, self.lhnd_slice].contiguous()) + rhnd_feat = self.rhnd_embed(motion[:, :, self.rhnd_slice].contiguous()) + face_feat = self.face_embed(motion[:, :, self.face_slice].contiguous()) + + # Concatenate all embeddings + feat = 
torch.cat((root_feat, head_feat, stem_feat, + larm_feat, rarm_feat, lleg_feat, rleg_feat, + lhnd_feat, rhnd_feat, face_feat), dim=-1) + return feat + + +class PoseEncoder(nn.Module): + """Encoder for multi-dataset scenarios, handling different datasets.""" + + def __init__(self, latent_dim: int, num_datasets: int): + super().__init__() + self.models = nn.ModuleList() + self.num_datasets = num_datasets + self.latent_dim = latent_dim + + # Initialize single pose encoders for each dataset + for _ in range(num_datasets): + self.models.append(SinglePoseEncoder(latent_dim=latent_dim)) + + def forward(self, motion: torch.Tensor, dataset_idx: torch.Tensor) -> torch.Tensor: + """Forward pass for multi-dataset encoding.""" + B, T = motion.shape[:2] + output = torch.zeros(B, T, 10 * self.latent_dim).type_as(motion) + num_finish = 0 + + # Process each dataset's motion separately + for i in range(self.num_datasets): + batch_motion = motion[dataset_idx == i] + if len(batch_motion) == 0: + continue + num_finish += len(batch_motion) + batch_motion = self.models[i](batch_motion) + output[dataset_idx == i] = batch_motion + assert num_finish == B + return output + + +class SinglePoseDecoder(nn.Module): + """Decoder module for individual pose, reconstructing body parts.""" + + def __init__(self, latent_dim: int = 64, output_dim: int = 669): + super().__init__() + self.latent_dim = latent_dim + self.output_dim = output_dim + func = get_tomato_slice + self.root_slice = get_part_slice([0], func) + self.head_slice = get_part_slice([12, 15], func) + self.stem_slice = get_part_slice([3, 6, 9], func) + self.larm_slice = get_part_slice([14, 17, 19, 21], func) + self.rarm_slice = get_part_slice([13, 16, 18, 20], func) + self.lleg_slice = get_part_slice([2, 5, 8, 11], func) + self.rleg_slice = get_part_slice([1, 4, 7, 10], func) + self.lhnd_slice = get_part_slice(range(22, 37), func) + self.rhnd_slice = get_part_slice(range(37, 52), func) + self.face_slice = range(619, 669) + + # Initialize linear layers for each body part output + self.root_out = nn.Linear(latent_dim, len(self.root_slice)) + self.head_out = nn.Linear(latent_dim, len(self.head_slice)) + self.stem_out = nn.Linear(latent_dim, len(self.stem_slice)) + self.larm_out = nn.Linear(latent_dim, len(self.larm_slice)) + self.rarm_out = nn.Linear(latent_dim, len(self.rarm_slice)) + self.lleg_out = nn.Linear(latent_dim, len(self.lleg_slice)) + self.rleg_out = nn.Linear(latent_dim, len(self.rleg_slice)) + self.lhnd_out = nn.Linear(latent_dim, len(self.lhnd_slice)) + self.rhnd_out = nn.Linear(latent_dim, len(self.rhnd_slice)) + self.face_out = nn.Linear(latent_dim, len(self.face_slice)) + + + def forward(self, motion: torch.Tensor) -> torch.Tensor: + """Forward pass to decode body parts from latent representation.""" + B, T = motion.shape[:2] + D = self.latent_dim + + # Decode each part using corresponding linear layer + root_feat = self.root_out(motion[:, :, :D].contiguous()) + head_feat = self.head_out(motion[:, :, D: 2 * D].contiguous()) + stem_feat = self.stem_out(motion[:, :, 2 * D: 3 * D].contiguous()) + larm_feat = self.larm_out(motion[:, :, 3 * D: 4 * D].contiguous()) + rarm_feat = self.rarm_out(motion[:, :, 4 * D: 5 * D].contiguous()) + lleg_feat = self.lleg_out(motion[:, :, 5 * D: 6 * D].contiguous()) + rleg_feat = self.rleg_out(motion[:, :, 6 * D: 7 * D].contiguous()) + lhnd_feat = self.lhnd_out(motion[:, :, 7 * D: 8 * D].contiguous()) + rhnd_feat = self.rhnd_out(motion[:, :, 8 * D: 9 * D].contiguous()) + face_feat = self.face_out(motion[:, :, 9 * 
D:].contiguous()) + + # Combine outputs into final tensor + output = torch.zeros(B, T, self.output_dim).type_as(motion) + output[:, :, self.root_slice] = root_feat + output[:, :, self.head_slice] = head_feat + output[:, :, self.stem_slice] = stem_feat + output[:, :, self.larm_slice] = larm_feat + output[:, :, self.rarm_slice] = rarm_feat + output[:, :, self.lleg_slice] = lleg_feat + output[:, :, self.rleg_slice] = rleg_feat + output[:, :, self.lhnd_slice] = lhnd_feat + output[:, :, self.rhnd_slice] = rhnd_feat + output[:, :, self.face_slice] = face_feat + + return output + + +class PoseDecoder(nn.Module): + """Decoder for multi-dataset scenarios, handling different datasets.""" + + def __init__(self, latent_dim: int, output_dim: int, num_datasets: int): + super().__init__() + self.models = nn.ModuleList() + self.num_datasets = num_datasets + self.latent_dim = latent_dim + self.output_dim = output_dim + + # Initialize single pose decoders for each dataset + for _ in range(num_datasets): + self.models.append( + SinglePoseDecoder(latent_dim=latent_dim, output_dim=output_dim) + ) + + def forward(self, motion: torch.Tensor, dataset_idx: torch.Tensor) -> torch.Tensor: + """Forward pass for multi-dataset decoding.""" + B, T = motion.shape[:2] + output = torch.zeros(B, T, self.output_dim).type_as(motion) + num_finish = 0 + + # Process each dataset's motion separately + for i in range(self.num_datasets): + batch_motion = motion[dataset_idx == i] + if len(batch_motion) == 0: + continue + num_finish += len(batch_motion) + batch_motion = self.models[i](batch_motion) + output[dataset_idx == i] = batch_motion + assert num_finish == B + return output + + +class SFFN(nn.Module): + """SFFN module with multiple linear layers, acting on different parts of the input.""" + + def __init__(self, + latent_dim: int, + ffn_dim: int, + dropout: float, + time_embed_dim: int, + activation: str = "GELU"): + super().__init__() + self.linear1_list = nn.ModuleList() + self.linear2_list = nn.ModuleList() + + if activation == "GELU": + self.activation = nn.GELU() + self.linear1 = nn.Linear(latent_dim * 10, ffn_dim * 10) + self.linear2 = nn.Linear(ffn_dim * 10, latent_dim * 10) + + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim * 10, time_embed_dim, dropout) + + def forward(self, x: torch.Tensor, emb: torch.Tensor, **kwargs) -> torch.Tensor: + """Forward pass for SFFN, applying stylization block.""" + B, T, D = x.shape + y = self.linear2(self.dropout(self.activation(self.linear1(x)))) + y = x.reshape(B, T, D) + self.proj_out(y, emb) + + return y + + +class FFN(nn.Module): + """Feed-forward network with GELU activation and dropout.""" + + def __init__(self, latent_dim: int, ffn_dim: int, dropout: float): + super().__init__() + self.linear1 = nn.Linear(latent_dim, ffn_dim) + self.linear2 = nn.Linear(ffn_dim, latent_dim) + self.activation = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """Forward pass with normalization and residual connection.""" + y = self.linear2(self.dropout(self.activation(self.linear1(x)))) + y = x + y + return y + + +class DecoderLayer(nn.Module): + """Decoder layer consisting of conditional attention block and SFFN.""" + + def __init__(self, ca_block_cfg: Optional[Dict] = None, ffn_cfg: Optional[Dict] = None): + super().__init__() + self.ca_block = build_attention(ca_block_cfg) if ca_block_cfg else None + self.ffn = SFFN(**ffn_cfg) if ffn_cfg else None + + def forward(self, **kwargs) -> torch.Tensor: + 
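A minimal shape sketch of the part-based encoder/decoder pair defined above, assuming the 669-dimensional TOMATO feature layout that get_tomato_slice indexes into; the batch and sequence sizes are illustrative only:

import torch

latent_part_dim = 64
encoder = SinglePoseEncoder(latent_dim=latent_part_dim)
decoder = SinglePoseDecoder(latent_dim=latent_part_dim, output_dim=669)

motion = torch.randn(2, 16, 669)   # (B, T, D) motion features
feat = encoder(motion)             # (2, 16, 640): ten 64-d chunks, one per body part
recon = decoder(feat)              # (2, 16, 669): each part written back into its slice
print(feat.shape, recon.shape)
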
"""Forward pass for the decoder layer.""" + if self.ca_block is not None: + x = self.ca_block(**kwargs) + kwargs.update({'x': x}) + if self.ffn is not None: + x = self.ffn(**kwargs) + return x + + +class EncoderLayer(nn.Module): + """Encoder layer consisting of self-attention block and FFN.""" + + def __init__(self, sa_block_cfg: Optional[Dict] = None, ffn_cfg: Optional[Dict] = None): + super().__init__() + self.sa_block = build_attention(sa_block_cfg) if sa_block_cfg else None + self.ffn = FFN(**ffn_cfg) if ffn_cfg else None + + def forward(self, **kwargs) -> torch.Tensor: + """Forward pass for the encoder layer.""" + if self.sa_block is not None: + x = self.sa_block(**kwargs) + kwargs.update({'x': x}) + if self.ffn is not None: + x = self.ffn(**kwargs) + return x + +class Transformer(nn.Module): + """Transformer model with self-attention and feed-forward network layers.""" + + def __init__(self, + input_dim: int = 1024, + latent_dim: int = 1024, + num_heads: int = 10, + num_layers: int = 4, + max_seq_len: int = 300, + stride: int = 1, + dropout: float = 0): + super().__init__() + self.blocks = nn.ModuleList() + self.proj_in = nn.Linear(input_dim, latent_dim) + self.embedding = nn.Parameter(torch.randn(1, max_seq_len, latent_dim)) + self.latent_dim = latent_dim + self.stride = stride + self.num_heads = num_heads + self.dropout = dropout + + sa_block_cfg = dict( + type='EfficientSelfAttention', + latent_dim=latent_dim, + num_heads=num_heads, + dropout=dropout + ) + + ffn_cfg = dict( + latent_dim=latent_dim, + ffn_dim=latent_dim * 4, + dropout=dropout + ) + for _ in range(num_layers): + self.blocks.append( + EncoderLayer(sa_block_cfg=sa_block_cfg, ffn_cfg=ffn_cfg) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through transformer layers.""" + x = x[:, ::self.stride, :] + x = self.proj_in(x) + T = x.shape[1] + x = x + self.embedding[:, :T, :] + # Apply each encoder layer + for block in self.blocks: + x = block(x=x) + + return x + + +@SUBMODULES.register_module() +class LargeMotionModel(MotionTransformer): + """Large motion model with optional multi-modal conditioning (text, music, video, etc.).""" + + def __init__(self, + num_parts: int = 10, + latent_part_dim: int = 64, + num_cond_layers: int = 2, + num_datasets: int = 27, + guidance_cfg: Optional[Dict] = None, + moe_route_loss_weight: float = 1.0, + template_kl_loss_weight: float = 0.0001, + dataset_names: Optional[List[str]] = None, + text_input_dim: Optional[int] = None, + music_input_dim: Optional[int] = None, + speech_input_dim: Optional[int] = None, + video_input_dim: Optional[int] = None, + music_input_stride: Optional[int] = 3, + speech_input_stride: Optional[int] = 3, + cond_drop_rate: float = 0, + random_mask: float = 0, + dropout: float = 0, + **kwargs): + kwargs['latent_dim'] = latent_part_dim * num_parts + self.num_parts = num_parts + self.latent_part_dim = latent_part_dim + self.num_datasets = num_datasets + self.dropout = dropout + + super().__init__(**kwargs) + self.guidance_cfg = guidance_cfg + + self.joint_embed = PoseEncoder( + latent_dim=self.latent_part_dim, + num_datasets=self.num_datasets) + self.out = zero_module(PoseDecoder( + latent_dim=self.latent_part_dim, + output_dim=self.input_feats, + num_datasets=self.num_datasets)) + + self.dataset_proj = {name: i for i, name in enumerate(dataset_names or [])} + self.rotation_proj = {'h3d_rot': 0, 'smpl_rot': 1, 'bvh_rot': 2} + + self.moe_route_loss_weight = moe_route_loss_weight + self.template_kl_loss_weight = template_kl_loss_weight + 
self.cond_drop_rate = cond_drop_rate + + # Conditional transformers for multi-modal inputs + self.text_cond = text_input_dim is not None + self.music_cond = music_input_dim is not None + self.speech_cond = speech_input_dim is not None + self.video_cond = video_input_dim is not None + + if self.text_cond: + self.text_transformer = Transformer( + input_dim=text_input_dim, + latent_dim=self.latent_dim, + num_heads=self.num_parts, + num_layers=num_cond_layers, + dropout=self.dropout) + if self.music_cond: + self.music_transformer = Transformer( + input_dim=music_input_dim, + latent_dim=self.latent_dim, + num_heads=self.num_parts, + num_layers=num_cond_layers, + dropout=self.dropout, + stride=music_input_stride) + if self.speech_cond: + self.speech_transformer = Transformer( + input_dim=speech_input_dim, + latent_dim=self.latent_dim, + num_heads=self.num_parts, + num_layers=num_cond_layers, + dropout=self.dropout, + stride=speech_input_stride) + if self.video_cond: + self.video_transformer = Transformer( + input_dim=video_input_dim, + latent_dim=self.latent_dim, + num_heads=self.num_parts, + num_layers=num_cond_layers, + dropout=self.dropout) + + self.mask_token = nn.Parameter(torch.randn(self.num_parts, self.latent_part_dim)) + self.clean_token = nn.Parameter(torch.randn(self.num_parts, self.latent_part_dim)) + self.random_mask = random_mask + + def build_temporal_blocks(self, + sa_block_cfg: Optional[Dict] = None, + ca_block_cfg: Optional[Dict] = None, + ffn_cfg: Optional[Dict] = None): + """Build temporal decoder blocks with attention and feed-forward networks.""" + self.temporal_decoder_blocks = nn.ModuleList() + ca_block_cfg['latent_dim'] = self.latent_part_dim + ca_block_cfg['num_heads'] = self.num_parts + ca_block_cfg['ffn_dim'] = self.latent_part_dim * 4 + ca_block_cfg['time_embed_dim'] = self.time_embed_dim + ca_block_cfg['max_seq_len'] = self.max_seq_len + ca_block_cfg['dropout'] = self.dropout + for _ in range(self.num_layers): + ffn_cfg_block = dict( + latent_dim=self.latent_part_dim, + ffn_dim=self.latent_part_dim * 4, + dropout=self.dropout, + time_embed_dim=self.time_embed_dim + ) + self.temporal_decoder_blocks.append( + DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg_block) + ) + + def scale_func(self, timestep: torch.Tensor, dataset_name: str) -> torch.Tensor: + """Scale function for diffusion, adjusting weights based on timestep.""" + guidance_cfg = self.guidance_cfg[dataset_name] + if guidance_cfg['type'] == 'constant': + w = torch.ones_like(timestep).float() * guidance_cfg['scale'] + elif guidance_cfg['type'] == 'linear': + scale = guidance_cfg['scale'] + w = (1 - (1000 - timestep) / 1000) * scale + 1 + else: + raise NotImplementedError() + return w + + def aux_loss(self) -> Dict[str, torch.Tensor]: + """Compute auxiliary and KL losses for multi-modal routing.""" + aux_loss = 0 + kl_loss = 0 + for module in self.temporal_decoder_blocks: + if hasattr(module.ca_block, 'aux_loss'): + aux_loss += module.ca_block.aux_loss + if hasattr(module.ca_block, 'kl_loss'): + kl_loss += module.ca_block.kl_loss + losses = {} + if aux_loss > 0: + losses['moe_route_loss'] = aux_loss * self.moe_route_loss_weight + if kl_loss > 0: + losses['template_kl_loss'] = kl_loss * self.template_kl_loss_weight + return losses + + def get_precompute_condition(self, + text_word_feat: Optional[torch.Tensor] = None, + text_word_out: Optional[torch.Tensor] = None, + text_cond: Optional[torch.Tensor] = None, + music_word_feat: Optional[torch.Tensor] = None, + music_word_out: Optional[torch.Tensor] = 
None, + music_cond: Optional[torch.Tensor] = None, + speech_word_feat: Optional[torch.Tensor] = None, + speech_word_out: Optional[torch.Tensor] = None, + speech_cond: Optional[torch.Tensor] = None, + video_word_feat: Optional[torch.Tensor] = None, + video_word_out: Optional[torch.Tensor] = None, + video_cond: Optional[torch.Tensor] = None, + **kwargs) -> Dict[str, torch.Tensor]: + """Precompute conditions for various modalities (text, music, speech, video).""" + output = {} + if self.text_cond and text_word_feat is not None: + text_word_feat = text_word_feat.float() + if text_word_out is None: + if text_cond is None or torch.sum(text_cond) == 0: + latent_dim = self.text_transformer.latent_dim + B, N = text_word_feat.shape[:2] + text_word_out = torch.zeros(B, N, latent_dim).type_as(text_word_feat) + else: + text_word_out = self.text_transformer(text_word_feat) + output['text_word_out'] = text_word_out + if self.music_cond and music_word_feat is not None: + music_word_feat = music_word_feat.float() + if music_word_out is None: + if music_cond is None or torch.sum(music_cond) == 0: + latent_dim = self.music_transformer.latent_dim + B, N = music_word_feat.shape[:2] + music_word_out = torch.zeros(B, N, latent_dim).type_as(music_word_feat) + else: + music_word_out = self.music_transformer(music_word_feat) + output['music_word_out'] = music_word_out + if self.speech_cond and speech_word_feat is not None: + speech_word_feat = speech_word_feat.float() + if speech_word_out is None: + if speech_cond is None or torch.sum(speech_cond) == 0: + latent_dim = self.speech_transformer.latent_dim + B, N = speech_word_feat.shape[:2] + speech_word_out = torch.zeros(B, N, latent_dim).type_as(speech_word_feat) + else: + speech_word_out = self.speech_transformer(speech_word_feat) + output['speech_word_out'] = speech_word_out + if self.video_cond and video_word_feat is not None: + video_word_feat = video_word_feat.float() + if video_word_out is None: + if video_cond is None or torch.sum(video_cond) == 0: + latent_dim = self.video_transformer.latent_dim + B, N = video_word_feat.shape[:2] + video_word_out = torch.zeros(B, N, latent_dim).type_as(video_word_feat) + else: + video_word_out = self.video_transformer(video_word_feat) + output['video_word_out'] = video_word_out + return output + + def post_process(self, motion: torch.Tensor) -> torch.Tensor: + """Post-process motion data (e.g., unnormalization).""" + if self.post_process_cfg is not None and self.post_process_cfg.get("unnormalized_infer", False): + mean = torch.from_numpy(np.load(self.post_process_cfg['mean_path'])).type_as(motion) + std = torch.from_numpy(np.load(self.post_process_cfg['std_path'])).type_as(motion) + motion = motion * std + mean + return motion + + def forward_train(self, + h: torch.Tensor, + src_mask: torch.Tensor, + emb: torch.Tensor, + timesteps: torch.Tensor, + motion_length: Optional[torch.Tensor] = None, + text_word_out: Optional[torch.Tensor] = None, + text_cond: Optional[torch.Tensor] = None, + music_word_out: Optional[torch.Tensor] = None, + music_cond: Optional[torch.Tensor] = None, + speech_word_out: Optional[torch.Tensor] = None, + speech_cond: Optional[torch.Tensor] = None, + video_word_out: Optional[torch.Tensor] = None, + video_cond: Optional[torch.Tensor] = None, + num_intervals: int = 1, + duration: Optional[torch.Tensor] = None, + dataset_idx: Optional[torch.Tensor] = None, + rotation_idx: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """Forward pass for training, applying multi-modal conditions.""" + B, T = 
h.shape[:2] + # Apply conditional masking if needed + if self.text_cond and text_cond is not None: + text_cond_mask = torch.rand(B).type_as(h) + text_cond[text_cond_mask < self.cond_drop_rate] = 0 + if self.music_cond and music_cond is not None: + music_cond_mask = torch.rand(B).type_as(h) + music_cond[music_cond_mask < self.cond_drop_rate] = 0 + if self.speech_cond and speech_cond is not None: + speech_cond_mask = torch.rand(B).type_as(h) + speech_cond[speech_cond_mask < self.cond_drop_rate] = 0 + if self.video_cond and video_cond is not None: + video_cond_mask = torch.rand(B).type_as(h) + video_cond[video_cond_mask < self.cond_drop_rate] = 0 + + # Apply each temporal decoder block + for module in self.temporal_decoder_blocks: + h = module(x=h, + emb=emb, + src_mask=src_mask, + motion_length=motion_length, + text_cond=text_cond, + text_word_out=text_word_out, + music_cond=music_cond, + music_word_out=music_word_out, + speech_cond=speech_cond, + speech_word_out=speech_word_out, + video_cond=video_cond, + video_word_out=video_word_out, + num_intervals=num_intervals, + duration=duration, + dataset_idx=dataset_idx, + rotation_idx=rotation_idx) + + # Output layer + output = self.out(h, dataset_idx).view(B, T, -1).contiguous() + return output + + def forward_test(self, + h: torch.Tensor, + src_mask: torch.Tensor, + emb: torch.Tensor, + timesteps: torch.Tensor, + motion_length: torch.Tensor, + text_word_out: Optional[torch.Tensor] = None, + text_cond: Optional[torch.Tensor] = None, + music_word_out: Optional[torch.Tensor] = None, + music_cond: Optional[torch.Tensor] = None, + speech_word_out: Optional[torch.Tensor] = None, + speech_cond: Optional[torch.Tensor] = None, + video_word_out: Optional[torch.Tensor] = None, + video_cond: Optional[torch.Tensor] = None, + num_intervals: int = 1, + duration: Optional[torch.Tensor] = None, + dataset_idx: Optional[torch.Tensor] = None, + rotation_idx: Optional[torch.Tensor] = None, + dataset_name: Optional[str] = 'humanml3d_t2m', + **kwargs) -> torch.Tensor: + """Forward pass for testing, including scaling and conditional fusion.""" + B, T = h.shape[:2] + # Duplicate tensors for conditional and non-conditional cases + h = h.repeat(2, 1, 1) + emb = emb.repeat(2, 1) + src_mask = src_mask.repeat(2, 1, 1, 1) + motion_length = motion_length.repeat(2, 1) + duration = duration.repeat(2) + + # dataset_idx_att = [self.dataset_proj['all'] for i in range(B)] + # dataset_idx_att = torch.tensor(dataset_idx_att, dtype=torch.long).to(h.device) + # dataset_idx_att = torch.cat((dataset_idx, dataset_idx_att)) + dataset_idx = dataset_idx.repeat(2) + rotation_idx = rotation_idx.repeat(2) + + if self.text_cond and text_cond is not None and text_word_out is not None: + text_cond = text_cond.repeat(2, 1) + text_cond[B:] = 0 + text_word_out = text_word_out.repeat(2, 1, 1) + if self.music_cond and music_cond is not None and music_word_out is not None: + music_cond = music_cond.repeat(2, 1) + music_cond[B:] = 0 + music_word_out = music_word_out.repeat(2, 1, 1) + if self.speech_cond and speech_cond is not None and speech_word_out is not None: + speech_cond = speech_cond.repeat(2, 1) + speech_cond[B:] = 0 + speech_word_out = speech_word_out.repeat(2, 1, 1) + if self.video_cond and video_cond is not None and video_word_out is not None: + video_cond = video_cond.repeat(2, 1) + video_cond[B:] = 0 + video_word_out = video_word_out.repeat(2, 1, 1) + + # Apply each temporal decoder block + for module in self.temporal_decoder_blocks: + h = module(x=h, + emb=emb, + src_mask=src_mask, + 
motion_length=motion_length, + text_cond=text_cond, + text_word_out=text_word_out, + music_cond=music_cond, + music_word_out=music_word_out, + speech_cond=speech_cond, + speech_word_out=speech_word_out, + video_cond=video_cond, + video_word_out=video_word_out, + num_intervals=num_intervals, + duration=duration, + dataset_idx=dataset_idx, + rotation_idx=rotation_idx) + + # Process the output from conditional and non-conditional branches + output = self.out(h, dataset_idx).view(2 * B, T, -1).contiguous() + scale = self.scale_func(timesteps, dataset_name).view(-1, 1, 1) + output_cond = output[:B].contiguous() + output_none = output[B:].contiguous() + + # Fuse conditional and non-conditional outputs + output = output_cond * scale + output_none * (1 - scale) + return output + + def create_mask_from_length(self, T: int, motion_length: torch.Tensor) -> torch.Tensor: + """Create a binary mask based on motion length.""" + B = motion_length.shape[0] + src_mask = torch.zeros(B, T) + for bix in range(B): + src_mask[bix, :int(motion_length[bix])] = 1 + return src_mask + + def forward(self, + motion: torch.Tensor, + timesteps: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + motion_length: Optional[torch.Tensor] = None, + num_intervals: int = 1, + motion_metas: Optional[List[Dict]] = None, + text_seq_feat: Optional[torch.Tensor] = None, + text_word_feat: Optional[torch.Tensor] = None, + text_cond: Optional[torch.Tensor] = None, + music_seq_feat: Optional[torch.Tensor] = None, + music_word_feat: Optional[torch.Tensor] = None, + music_cond: Optional[torch.Tensor] = None, + speech_seq_feat: Optional[torch.Tensor] = None, + speech_word_feat: Optional[torch.Tensor] = None, + speech_cond: Optional[torch.Tensor] = None, + video_seq_feat: Optional[torch.Tensor] = None, + video_word_feat: Optional[torch.Tensor] = None, + video_cond: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """Unified forward pass for both training and testing.""" + B, T = motion.shape[:2] + # Precompute conditioning features + conditions = self.get_precompute_condition( + motion_length=motion_length, + text_seq_feat=text_seq_feat, + text_word_feat=text_word_feat, + text_cond=text_cond, + music_seq_feat=music_seq_feat, + music_word_feat=music_word_feat, + music_cond=music_cond, + speech_seq_feat=speech_seq_feat, + speech_word_feat=speech_word_feat, + speech_cond=speech_cond, + video_seq_feat=video_seq_feat, + video_word_feat=video_word_feat, + video_cond=video_cond, + device=motion.device, + **kwargs + ) + if self.training: + new_motion_mask = motion_mask.clone() + rand_mask = torch.rand_like(motion_mask) + threshold = torch.rand(B).type_as(rand_mask) + threshold = threshold.view(B, 1, 1).repeat(1, T, self.num_parts) + new_motion_mask[rand_mask < threshold] = 0 + motion_mask = new_motion_mask + else: + t = int(timesteps[0]) + + motion_mask = motion_mask.view(B, T, 10, 1) + + # Temporal embedding + emb = self.time_embed(timestep_embedding(timesteps, self.latent_dim)) + + # Prepare duration and framerate embeddings + duration = [] + for meta in motion_metas: + framerate = meta['meta_data']['framerate'] + duration.append(1.0 / framerate) + + duration = torch.tensor(duration, dtype=motion.dtype).to(motion.device) + + # Dataset index embedding + dataset_idx = [] + for i in range(B): + dataset_name = motion_metas[i]['meta_data']['dataset_name'] + if torch.rand(1).item() < 0.1 and self.training: + dataset_name = 'all' + idx = self.dataset_proj[dataset_name] + dataset_idx.append(idx) + dataset_idx = 
torch.tensor(dataset_idx, dtype=torch.long).to(motion.device) + self.dataset_idx = dataset_idx.clone().detach() + + # Rotation index embedding + rotation_idx = [self.rotation_proj[meta['meta_data']['rotation_type']] for meta in motion_metas] + rotation_idx = torch.tensor(rotation_idx, dtype=torch.long).to(motion.device) + + # Embed motion with pose encoder + h = self.joint_embed(motion, dataset_idx) + h = h.view(B, T, 10, -1) * motion_mask + (1 - motion_mask) * self.mask_token + h = h.view(B, T, -1) + + # Source mask based on motion length + src_mask = self.create_mask_from_length(T, motion_length).to(motion.device) + src_mask = src_mask.view(B, T, 1, 1).repeat(1, 1, 10, 1) + + # Training or testing forward + if self.training: + output = self.forward_train( + h=h, + emb=emb, + src_mask=src_mask, + timesteps=timesteps, + motion_length=motion_length, + text_cond=text_cond, + music_cond=music_cond, + speech_cond=speech_cond, + video_cond=video_cond, + num_intervals=num_intervals, + duration=duration, + dataset_idx=dataset_idx, + rotation_idx=rotation_idx, + **conditions + ) + else: + output = self.forward_test( + h=h, + emb=emb, + src_mask=src_mask, + timesteps=timesteps, + motion_length=motion_length, + text_cond=text_cond, + music_cond=music_cond, + speech_cond=speech_cond, + video_cond=video_cond, + num_intervals=num_intervals, + duration=duration, + dataset_idx=dataset_idx, + rotation_idx=rotation_idx, + dataset_name=dataset_name, + **conditions + ) + + return output diff --git a/mogen/models/transformers/mdm.py b/mogen/models/transformers/mdm.py new file mode 100644 index 0000000000000000000000000000000000000000..0786f2f1d7e76ffea1688c748f164a96c40fe6d7 --- /dev/null +++ b/mogen/models/transformers/mdm.py @@ -0,0 +1,227 @@ +import clip +import numpy as np +import torch +import torch.nn as nn + +from ..builder import SUBMODULES + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp32""" + + def _convert_weights_to_fp32(m): + if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)): + m.weight.data = m.weight.data.float() + if m.bias is not None: + m.bias.data = m.bias.data.float() + + if isinstance(m, nn.MultiheadAttention): + attr_list = [f"{s}_proj_weight" for s in ["in", "q", "k", "v"]] + attr_list += ["in_proj_bias", "bias_k", "bias_v"] + for attr in attr_list: + tensor = getattr(m, attr) + if tensor is not None: + tensor.data = tensor.data.float() + + for name in ["text_projection", "proj"]: + if hasattr(m, name): + attr = getattr(m, name) + if attr is not None: + attr.data = attr.data.float() + + model.apply(_convert_weights_to_fp32) + + +@SUBMODULES.register_module() +class MDMTransformer(nn.Module): + + def __init__(self, + input_feats=263, + latent_dim=256, + ff_size=1024, + num_layers=8, + num_heads=4, + dropout=0.1, + activation="gelu", + clip_dim=512, + clip_version=None, + guide_scale=1.0, + cond_mask_prob=0.1, + use_official_ckpt=False, + **kwargs): + super().__init__() + + self.latent_dim = latent_dim + self.ff_size = ff_size + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout = dropout + self.activation = activation + self.clip_dim = clip_dim + self.input_feats = input_feats + self.guide_scale = guide_scale + self.use_official_ckpt = use_official_ckpt + + self.cond_mask_prob = cond_mask_prob + self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim) + self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, + self.dropout) + + seqTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=self.latent_dim, 
+ nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + + self.seqTransEncoder = nn.TransformerEncoder( + seqTransEncoderLayer, num_layers=self.num_layers) + + self.embed_timestep = TimestepEmbedder(self.latent_dim, + self.sequence_pos_encoder) + + self.embed_text = nn.Linear(self.clip_dim, self.latent_dim) + self.clip_version = clip_version + self.clip_model = self.load_and_freeze_clip(clip_version) + + self.poseFinal = nn.Linear(self.latent_dim, self.input_feats) + + def load_and_freeze_clip(self, clip_version): + clip_model, _ = clip.load(clip_version, device='cpu', jit=False) + clip.model.convert_weights(clip_model) + + clip_model.eval() + for p in clip_model.parameters(): + p.requires_grad = False + + return clip_model + + def mask_cond(self, cond, force_mask=False): + bs = cond.shape[0] + if force_mask: + return torch.zeros_like(cond) + elif self.training and self.cond_mask_prob > 0.: + mask = torch.ones(bs, device=cond.device) * self.cond_mask_prob + # 1-> use null_cond, 0-> use real cond + mask = torch.bernoulli(mask).view(bs, 1) + return cond * (1. - mask) + else: + return cond + + def encode_text(self, raw_text): + device = next(self.parameters()).device + max_text_len = 20 + if max_text_len is not None: + default_context_length = 77 + context_length = max_text_len + 2 # start_token + 20 + end_token + assert context_length < default_context_length + texts = clip.tokenize(raw_text, + context_length=context_length, + truncate=True).to(device) + zero_pad = torch.zeros( + [texts.shape[0], default_context_length - context_length], + dtype=texts.dtype, + device=texts.device) + texts = torch.cat([texts, zero_pad], dim=1) + return self.clip_model.encode_text(texts).float() + + def get_precompute_condition(self, text, device=None, **kwargs): + if not self.training and device == torch.device('cpu'): + convert_weights(self.clip_model) + text_feat = self.encode_text(text) + return {'text_feat': text_feat} + + def post_process(self, motion): + assert len(motion.shape) == 3 + if self.use_official_ckpt: + motion[:, :, :4] = motion[:, :, :4] * 25 + return motion + + def forward(self, motion, timesteps, text_feat=None, **kwargs): + """ + motion: B, T, D + timesteps: [batch_size] (int) + """ + B, T, D = motion.shape + if text_feat is None: + enc_text = self.get_precompute_condition(**kwargs)['text_feat'] + else: + enc_text = text_feat + if self.training: + # T, B, D + motion = self.poseEmbedding(motion).permute(1, 0, 2) + + emb = self.embed_timestep(timesteps) # [1, bs, d] + emb += self.embed_text(self.mask_cond(enc_text, force_mask=False)) + + xseq = self.sequence_pos_encoder(torch.cat((emb, motion), axis=0)) + output = self.seqTransEncoder(xseq)[1:] + + # B, T, D + output = self.poseFinal(output).permute(1, 0, 2) + return output + else: + # T, B, D + motion = self.poseEmbedding(motion).permute(1, 0, 2) + + emb = self.embed_timestep(timesteps) # [1, bs, d] + emb_uncond = emb + \ + self.embed_text(self.mask_cond(enc_text, force_mask=True)) + emb_text = emb + \ + self.embed_text(self.mask_cond(enc_text, force_mask=False)) + + xseq = self.sequence_pos_encoder( + torch.cat((emb_uncond, motion), axis=0)) + xseq_text = self.sequence_pos_encoder( + torch.cat((emb_text, motion), axis=0)) + output = self.seqTransEncoder(xseq)[1:] + output_text = self.seqTransEncoder(xseq_text)[1:] + # B, T, D + output = self.poseFinal(output).permute(1, 0, 2) + output_text = self.poseFinal(output_text).permute(1, 0, 2) + scale = self.guide_scale + output = output + 
scale * (output_text - output) + return output + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.arange(0, d_model, 2).float() * \ + (-np.log(10000.0) / d_model) + div_term = torch.exp(div_term) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +class TimestepEmbedder(nn.Module): + + def __init__(self, latent_dim, sequence_pos_encoder): + super().__init__() + self.latent_dim = latent_dim + self.sequence_pos_encoder = sequence_pos_encoder + + time_embed_dim = self.latent_dim + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + def forward(self, timesteps): + output = self.time_embed(self.sequence_pos_encoder.pe[timesteps]) + output = output.permute(1, 0, 2) + return output diff --git a/mogen/models/transformers/momatmogen.py b/mogen/models/transformers/momatmogen.py new file mode 100644 index 0000000000000000000000000000000000000000..5d486ca21985c77e622bc6e86019617480afd260 --- /dev/null +++ b/mogen/models/transformers/momatmogen.py @@ -0,0 +1,268 @@ +import torch +from torch import nn +from typing import Optional + +from mogen.models.utils.misc import zero_module +from mogen.models.utils.position_encoding import timestep_embedding +from mogen.models.utils.stylization_block import StylizationBlock + +from ..builder import SUBMODULES, build_attention +from .remodiffuse import ReMoDiffuseTransformer + + +class FFN(nn.Module): + """ + A feed-forward network (FFN) with optional stylization block. + + Args: + latent_dim (int): The dimension of the input and output latent space. + ffn_dim (int): The dimension of the hidden feed-forward network. + dropout (float): The dropout rate to apply after activation. + time_embed_dim (int): The dimension of the time embedding. + """ + def __init__(self, latent_dim: int, ffn_dim: int, dropout: float, time_embed_dim: int): + super().__init__() + self.latent_dim = latent_dim + self.linear1 = nn.Linear(latent_dim, ffn_dim) + self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) + self.activation = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x: torch.Tensor, emb: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Forward pass of the FFN layer. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, latent_dim*2). + emb (torch.Tensor): Time embedding tensor. + + Returns: + torch.Tensor: Output tensor after FFN and stylization block. + """ + x1 = x[:, :, :self.latent_dim].contiguous() + x2 = x[:, :, self.latent_dim:].contiguous() + y1 = self.linear2(self.dropout(self.activation(self.linear1(x1)))) + y1 = x1 + self.proj_out(y1, emb) + y2 = self.linear2(self.dropout(self.activation(self.linear1(x2)))) + y2 = x2 + self.proj_out(y2, emb) + y = torch.cat((y1, y2), dim=-1) + return y + + +class DecoderLayer(nn.Module): + """ + A single decoder layer consisting of a cross-attention block and a feed-forward network (FFN). 
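As a quick numeric check of the guidance step in MDMTransformer.forward above, output + scale * (output_text - output) is the usual classifier-free-guidance mix (1 - scale) * output + scale * output_text. The tensors and the guide scale below are made-up values for illustration:

import torch

out_uncond = torch.tensor([0.2, -0.1])
out_text = torch.tensor([0.8, 0.3])
scale = 2.5                                    # guide scale; values > 1 extrapolate past the text branch
guided = out_uncond + scale * (out_text - out_uncond)
assert torch.allclose(guided, (1 - scale) * out_uncond + scale * out_text)
print(guided)                                  # tensor([1.7000, 0.9000])
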
+ + Args: + ca_block_cfg (Optional[dict]): Configuration for the cross-attention block. + ffn_cfg (Optional[dict]): Configuration for the feed-forward network. + """ + def __init__(self, ca_block_cfg: Optional[dict] = None, ffn_cfg: Optional[dict] = None): + super().__init__() + self.ca_block = build_attention(ca_block_cfg) + self.ffn = FFN(**ffn_cfg) + + def forward(self, **kwargs) -> torch.Tensor: + """ + Forward pass of the decoder layer. + + Args: + **kwargs: Arguments passed to the cross-attention and FFN layers. + + Returns: + torch.Tensor: Output tensor after passing through the layer. + """ + if self.ca_block is not None: + x = self.ca_block(**kwargs) + kwargs.update({'x': x}) + if self.ffn is not None: + x = self.ffn(**kwargs) + return x + + +@SUBMODULES.register_module() +class MoMatMoGenTransformer(ReMoDiffuseTransformer): + """ + MoMatMoGenTransformer is a motion generation transformer model, which uses ReMoDiffuse as a base. + + Args: + ReMoDiffuseTransformer: Base transformer class. + """ + def build_temporal_blocks(self, sa_block_cfg: Optional[dict], ca_block_cfg: Optional[dict], ffn_cfg: Optional[dict]): + """ + Build temporal decoder blocks using the provided configurations. + + Args: + sa_block_cfg (Optional[dict]): Self-attention block configuration. + ca_block_cfg (Optional[dict]): Cross-attention block configuration. + ffn_cfg (Optional[dict]): Feed-forward network configuration. + """ + self.temporal_decoder_blocks = nn.ModuleList() + for i in range(self.num_layers): + self.temporal_decoder_blocks.append( + DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg)) + + def forward(self, + motion: torch.Tensor, + timesteps: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Forward pass for motion generation. + + Args: + motion (torch.Tensor): Input motion tensor of shape (B, T, D). + timesteps (torch.Tensor): Timestep embeddings. + motion_mask (Optional[torch.Tensor]): Motion mask, if any. + + Returns: + torch.Tensor: Output tensor after processing the motion data. + """ + T = motion.shape[1] + conditions = self.get_precompute_condition(device=motion.device, + **kwargs) + if len(motion_mask.shape) == 2: + src_mask = motion_mask.clone().unsqueeze(-1) + else: + src_mask = motion_mask.clone() + + if self.time_embedding_type == 'sinusoidal': + emb = self.time_embed( + timestep_embedding(timesteps, self.latent_dim)) + else: + emb = self.time_embed(self.time_tokens(timesteps)) + + if self.use_text_proj: + emb = emb + conditions['xf_proj'] + + motion1 = motion[:, :, :self.input_feats].contiguous() + motion2 = motion[:, :, self.input_feats:].contiguous() + h1 = self.joint_embed(motion1) + h2 = self.joint_embed(motion2) + if self.use_pos_embedding: + h1 = h1 + self.sequence_embedding.unsqueeze(0)[:, :T, :] + h2 = h2 + self.sequence_embedding.unsqueeze(0)[:, :T, :] + h = torch.cat((h1, h2), dim=-1) + + if self.training: + output = self.forward_train(h=h, + src_mask=src_mask, + emb=emb, + timesteps=timesteps, + **conditions) + else: + output = self.forward_test(h=h, + src_mask=src_mask, + emb=emb, + timesteps=timesteps, + **conditions) + if self.use_residual_connection: + output = motion + output + return output + + def forward_train(self, + h: Optional[torch.Tensor] = None, + src_mask: Optional[torch.Tensor] = None, + emb: Optional[torch.Tensor] = None, + xf_out: Optional[torch.Tensor] = None, + re_dict: Optional[dict] = None, + **kwargs) -> torch.Tensor: + """ + Training forward pass for the motion generation transformer. 
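A small sketch of the two-person feature layout consumed by MoMatMoGenTransformer.forward above: the channel dimension stacks both characters, the same joint_embed weights are applied to each half, and the halves are concatenated again before entering the decoder blocks. The sizes below are illustrative, not the model's configured dimensions:

import torch
import torch.nn as nn

input_feats, latent_dim = 263, 512
joint_embed = nn.Linear(input_feats, latent_dim)    # shared by both persons

motion = torch.randn(2, 60, 2 * input_feats)        # (B, T, [person A | person B])
h1 = joint_embed(motion[:, :, :input_feats])        # person A -> (2, 60, 512)
h2 = joint_embed(motion[:, :, input_feats:])        # person B -> (2, 60, 512)
h = torch.cat((h1, h2), dim=-1)                     # (2, 60, 1024)
print(h.shape)
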
+ + Args: + h (Optional[torch.Tensor]): Input tensor. + src_mask (Optional[torch.Tensor]): Source mask. + emb (Optional[torch.Tensor]): Embedding tensor. + xf_out (Optional[torch.Tensor]): Output of the cross-attention block. + re_dict (Optional[dict]): Dictionary for recurrent features. + + Returns: + torch.Tensor: Output tensor after processing. + """ + B, T = h.shape[0], h.shape[1] + cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device) + for module in self.temporal_decoder_blocks: + h = module(x=h, + xf=xf_out, + emb=emb, + src_mask=src_mask, + cond_type=cond_type, + re_dict=re_dict) + + out1 = self.out(h[:, :, :self.latent_dim].contiguous()) + out1 = out1.view(B, T, -1).contiguous() + out2 = self.out(h[:, :, self.latent_dim:].contiguous()) + out2 = out2.view(B, T, -1).contiguous() + output = torch.cat((out1, out2), dim=-1) + return output + + def forward_test(self, + h: Optional[torch.Tensor] = None, + src_mask: Optional[torch.Tensor] = None, + emb: Optional[torch.Tensor] = None, + xf_out: Optional[torch.Tensor] = None, + re_dict: Optional[dict] = None, + timesteps: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Testing forward pass for the motion generation transformer. + + Args: + h (Optional[torch.Tensor]): Input tensor. + src_mask (Optional[torch.Tensor]): Source mask. + emb (Optional[torch.Tensor]): Embedding tensor. + xf_out (Optional[torch.Tensor]): Output of the cross-attention block. + re_dict (Optional[dict]): Dictionary for recurrent features. + timesteps (Optional[torch.Tensor]): Timestep embeddings. + + Returns: + torch.Tensor: Output tensor after processing. + """ + B, T = h.shape[0], h.shape[1] + both_cond_type = torch.zeros(B, 1, 1).to(h.device) + 99 + text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1 + retr_cond_type = torch.zeros(B, 1, 1).to(h.device) + 10 + none_cond_type = torch.zeros(B, 1, 1).to(h.device) + + all_cond_type = torch.cat( + (both_cond_type, text_cond_type, retr_cond_type, none_cond_type), + dim=0) + h = h.repeat(4, 1, 1) + xf_out = xf_out.repeat(4, 1, 1) + emb = emb.repeat(4, 1) + src_mask = src_mask.repeat(4, 1, 1) + if re_dict['re_motion'].shape[0] != h.shape[0]: + re_dict['re_motion'] = re_dict['re_motion'].repeat(4, 1, 1, 1) + re_dict['re_text'] = re_dict['re_text'].repeat(4, 1, 1, 1) + re_dict['re_mask'] = re_dict['re_mask'].repeat(4, 1, 1) + + for module in self.temporal_decoder_blocks: + h = module(x=h, + xf=xf_out, + emb=emb, + src_mask=src_mask, + cond_type=all_cond_type, + re_dict=re_dict) + + out1 = self.out(h[:, :, :self.latent_dim].contiguous()) + out1 = out1.view(4 * B, T, -1).contiguous() + out2 = self.out(h[:, :, self.latent_dim:].contiguous()) + out2 = out2.view(4 * B, T, -1).contiguous() + out = torch.cat((out1, out2), dim=-1) + out_both = out[:B].contiguous() + out_text = out[B:2 * B].contiguous() + out_retr = out[2 * B:3 * B].contiguous() + out_none = out[3 * B:].contiguous() + + coef_cfg = self.scale_func(int(timesteps[0])) + both_coef = coef_cfg['both_coef'] + text_coef = coef_cfg['text_coef'] + retr_coef = coef_cfg['retr_coef'] + none_coef = coef_cfg['none_coef'] + output = out_both * both_coef + output += out_text * text_coef + output += out_retr * retr_coef + output += out_none * none_coef + return output diff --git a/mogen/models/transformers/motion_transformer.py b/mogen/models/transformers/motion_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f142dd608f66a262efe43c5a6187cab6916fb04e --- /dev/null +++ b/mogen/models/transformers/motion_transformer.py @@ 
-0,0 +1,299 @@ +from abc import ABCMeta, abstractmethod +import clip +import torch +from torch import nn +from mmcv.runner import BaseModule + +from ..builder import build_attention +from mogen.models.utils.position_encoding import ( + timestep_embedding +) +from mogen.models.utils.stylization_block import StylizationBlock +from mogen.models.utils.misc import set_requires_grad, zero_module + + +class CLIPWrapper: + + def __init__(self, clip_model): + self.clip_model = clip_model + self.device = "cpu" + + def __call__(self, **kwargs): + return self.clip_model(**kwargs) + + def encode_text(self, text): + if text.is_cuda and self.device == "cpu": + self.clip_model = self.clip_model.cuda() + self.device = "cuda" + if not text.is_cuda and self.device == "cuda": + self.clip_model = self.clip_model.cpu() + self.device = "cpu" + return self.clip_model.encode_text(text) + + def to(self, device): + self.clip_model = self.clip_model.to(device) + + +class FFN(nn.Module): + + def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim=None): + super().__init__() + self.linear1 = nn.Linear(latent_dim, ffn_dim) + self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) + self.activation = nn.GELU() + self.dropout = nn.Dropout(dropout) + if time_embed_dim is not None: + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + else: + self.proj_out = None + + def forward(self, x, emb=None, **kwargs): + y = self.linear2(self.dropout(self.activation(self.linear1(x)))) + if self.proj_out is not None: + y = x + self.proj_out(y, emb) + else: + y = x + y + return y + + +class DecoderLayer(nn.Module): + + def __init__(self, + sa_block_cfg=None, + ca_block_cfg=None, + ffn_cfg=None): + super().__init__() + self.sa_block = build_attention(sa_block_cfg) + self.ca_block = build_attention(ca_block_cfg) + self.ffn = FFN(**ffn_cfg) + + def forward(self, **kwargs): + if self.sa_block is not None: + x = self.sa_block(**kwargs) + kwargs.update({'x': x}) + if self.ca_block is not None: + x = self.ca_block(**kwargs) + kwargs.update({'x': x}) + if self.ffn is not None: + x = self.ffn(**kwargs) + return x + + +class MotionTransformer(BaseModule, metaclass=ABCMeta): + def __init__(self, + input_feats, + max_seq_len=240, + latent_dim=512, + time_embed_dim=2048, + num_layers=8, + sa_block_cfg=None, + ca_block_cfg=None, + ffn_cfg=None, + text_encoder=None, + use_pos_embedding=True, + use_residual_connection=False, + time_embedding_type='sinusoidal', + post_process_cfg=None, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.input_feats = input_feats + self.max_seq_len = max_seq_len + self.latent_dim = latent_dim + self.num_layers = num_layers + self.time_embed_dim = time_embed_dim + self.use_pos_embedding = use_pos_embedding + if self.use_pos_embedding: + self.sequence_embedding = nn.Parameter(torch.randn(max_seq_len, latent_dim)) + self.build_text_encoder(text_encoder) + + # Input Embedding + self.joint_embed = nn.Linear(self.input_feats, self.latent_dim) + + self.time_embedding_type = time_embedding_type + if time_embedding_type != 'none': + if time_embedding_type == 'learnable': + self.time_tokens = nn.Embedding(1000, self.latent_dim) + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, self.time_embed_dim), + nn.SiLU(), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + self.build_temporal_blocks(sa_block_cfg, ca_block_cfg, ffn_cfg) + + # Output Module + self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats)) + self.use_residual_connection = 
use_residual_connection + self.post_process_cfg = post_process_cfg + + def build_temporal_blocks(self, sa_block_cfg, ca_block_cfg, ffn_cfg): + self.temporal_decoder_blocks = nn.ModuleList() + for i in range(self.num_layers): + self.temporal_decoder_blocks.append( + DecoderLayer( + sa_block_cfg=sa_block_cfg, + ca_block_cfg=ca_block_cfg, + ffn_cfg=ffn_cfg + ) + ) + + def build_text_encoder(self, text_encoder): + if text_encoder is None: + self.use_text_proj = False + return + text_latent_dim = text_encoder['latent_dim'] + num_text_layers = text_encoder.get('num_layers', 0) + text_ff_size = text_encoder.get('ff_size', 2048) + pretrained_model = text_encoder['pretrained_model'] + text_num_heads = text_encoder.get('num_heads', 4) + dropout = text_encoder.get('dropout', 0) + activation = text_encoder.get('activation', 'gelu') + self.use_text_proj = text_encoder.get('use_text_proj', False) + + if pretrained_model == 'clip': + clip_model, _ = clip.load('ViT-B/32', "cpu") + set_requires_grad(clip_model, False) + self.clip = CLIPWrapper(clip_model) + if text_latent_dim != 512: + self.text_pre_proj = nn.Linear(512, text_latent_dim) + else: + self.text_pre_proj = nn.Identity() + else: + raise NotImplementedError() + + if num_text_layers > 0: + self.use_text_finetune = True + textTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=text_latent_dim, + nhead=text_num_heads, + dim_feedforward=text_ff_size, + dropout=dropout, + activation=activation) + self.textTransEncoder = nn.TransformerEncoder( + textTransEncoderLayer, + num_layers=num_text_layers) + else: + self.use_text_finetune = False + self.text_ln = nn.LayerNorm(text_latent_dim) + if self.use_text_proj: + self.text_proj = nn.Sequential( + nn.Linear(text_latent_dim, self.time_embed_dim) + ) + + def encode_text(self, text, clip_feat, device): + B = len(text) + if type(text[0]) is dict: + knames = ["head", "stem", "left_arm", "right_arm", "left_leg", "right_leg", "pelvis", "all"] + new_text = [] + for item in text: + for kname in knames: + new_text.append(item[kname]) + text = new_text + text = clip.tokenize(text, truncate=True).to(device) + if clip_feat is None: + with torch.no_grad(): + if isinstance(self.clip, CLIPWrapper): + self.clip.to(device) + dtype = self.clip.clip_model.dtype + # [batch_size, n_ctx, d_model] + x = self.clip.clip_model.token_embedding(text).type(dtype) + x = x + self.clip.clip_model.positional_embedding.type(dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.clip.clip_model.transformer(x) + x = self.clip.clip_model.ln_final(x).type(dtype) + else: + dtype = self.clip.dtype + # [batch_size, n_ctx, d_model] + x = self.clip.token_embedding(text).type(dtype) + x = x + self.clip.positional_embedding.type(dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.clip.transformer(x) + x = self.clip.ln_final(x).type(dtype) + else: + x = clip_feat.float().to(device) + if len(x.shape) == 4: + x = x.permute(1, 0, 2, 3) + x = x.reshape([x.shape[0], x.shape[1] * x.shape[2], x.shape[3]]) + else: + x = x.permute(1, 0, 2) + + # T, B, D + x = self.text_pre_proj(x) + xf_out = self.textTransEncoder(x) + xf_out = self.text_ln(xf_out) + if self.use_text_proj: + xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])]) + # B, T, D + xf_out = xf_out.permute(1, 0, 2) + return xf_proj, xf_out + else: + xf_out = xf_out.permute(1, 0, 2) + return xf_out + + @abstractmethod + def get_precompute_condition(self, **kwargs): + pass + + @abstractmethod + def forward_train(self, h, src_mask, emb, **kwargs): + pass + + 
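The pooling index text.argmax(dim=-1) in encode_text above relies on a property of CLIP's tokenizer: the end-of-text token has the largest id in each row, so the argmax position is the EOT slot, and that token's feature is taken as the sentence embedding. A minimal check, assuming the clip package is installed:

import clip

tokens = clip.tokenize(["a person walks forward"])   # (1, 77) tensor of token ids
eot_pos = tokens.argmax(dim=-1)                      # position of the end-of-text token
print(int(eot_pos[0]), int(tokens[0, eot_pos[0]]))   # EOT index and its (maximal) token id
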
@abstractmethod + def forward_test(self, h, src_mask, emb, **kwargs): + pass + + def forward(self, + motion, + timesteps=None, + motion_mask=None, + motion_length=None, + num_intervals=1, + **kwargs): + """ + motion: B, T, D + """ + B, T = motion.shape[0], motion.shape[1] + conditions = self.get_precompute_condition(device=motion.device, + motion_length=motion_length, + **kwargs) + if len(motion_mask.shape) == 2: + src_mask = motion_mask.clone().unsqueeze(-1) + else: + src_mask = motion_mask.clone() + + if self.time_embedding_type != 'none': + if self.time_embedding_type == 'sinusoidal': + emb = self.time_embed(timestep_embedding(timesteps, self.latent_dim)) + else: + emb = self.time_embed(self.time_tokens(timesteps)) + + if self.use_text_proj: + emb = emb + conditions['xf_proj'] + else: + emb = None + # B, T, latent_dim + h = self.joint_embed(motion) + if self.use_pos_embedding: + h = h + self.sequence_embedding.unsqueeze(0)[:, :T, :] + + if self.training: + output = self.forward_train( + h=h, + src_mask=src_mask, + emb=emb, + timesteps=timesteps, + motion_length=motion_length, + num_intervals=num_intervals, + motion=motion, + **conditions) + else: + output = self.forward_test( + h=h, + src_mask=src_mask, + emb=emb, + timesteps=timesteps, + motion_length=motion_length, + num_intervals=num_intervals, + **conditions) + if self.use_residual_connection: + output = motion + output + return output diff --git a/mogen/models/transformers/motiondiffuse.py b/mogen/models/transformers/motiondiffuse.py new file mode 100644 index 0000000000000000000000000000000000000000..e3aa3b63989384f86ecc250d96c105ce344e1082 --- /dev/null +++ b/mogen/models/transformers/motiondiffuse.py @@ -0,0 +1,173 @@ +import numpy as np +import torch + +from typing import Optional, Dict, List + +from ..builder import SUBMODULES +from .motion_transformer import MotionTransformer + + +@SUBMODULES.register_module() +class MotionDiffuseTransformer(MotionTransformer): + """ + MotionDiffuseTransformer is a subclass of DiffusionTransformer designed for motion generation. + It uses a diffusion-based approach with optional guidance during training and inference. + + Args: + guidance_cfg (dict, optional): Configuration for guidance during inference and training. + 'type' can be 'constant' or dynamically calculated based on timesteps. + kwargs: Additional keyword arguments for the DiffusionTransformer base class. + """ + + def __init__(self, guidance_cfg: Optional[dict] = None, **kwargs): + """ + Initialize the MotionDiffuseTransformer. + + Args: + guidance_cfg (Optional[dict]): Configuration for the guidance. + kwargs: Additional arguments passed to the base class. + """ + super().__init__(**kwargs) + self.guidance_cfg = guidance_cfg + + def scale_func(self, timestep: int) -> dict: + """ + Compute the scaling coefficients for text-based guidance and no-guidance. + + Args: + timestep (int): The current diffusion timestep. + + Returns: + dict: A dictionary containing 'text_coef' and 'none_coef' that control the mix of text-conditioned and + non-text-conditioned outputs. 
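For the non-constant case handled in the body of scale_func just below, w = (1 - (1000 - timestep) / 1000) * scale + 1 simplifies to w = 1 + scale * timestep / 1000: guidance is strongest at the noisiest timesteps and decays to plain conditioning (w = 1) at timestep 0. A quick tabulation with an illustrative scale:

scale = 3.5
for t in (999, 500, 100, 0):
    w = (1 - (1000 - t) / 1000) * scale + 1
    print(t, w)    # roughly 4.5, 2.75, 1.35, 1.0

Since forward_test later mixes out_text * text_coef + out_none * none_coef with none_coef = 1 - w, any w above 1 extrapolates beyond the text-conditioned branch.
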
+ """ + if self.guidance_cfg['type'] == 'constant': + w = self.guidance_cfg['scale'] + return {'text_coef': w, 'none_coef': 1 - w} + else: + scale = self.guidance_cfg['scale'] + w = (1 - (1000 - timestep) / 1000) * scale + 1 + output = {'text_coef': w, 'none_coef': 1 - w} + return output + + def get_precompute_condition(self, + text: Optional[torch.Tensor] = None, + xf_proj: Optional[torch.Tensor] = None, + xf_out: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + clip_feat: Optional[torch.Tensor] = None, + **kwargs) -> dict: + """ + Precompute the conditions for text-based guidance using a text encoder. + + Args: + text (Optional[torch.Tensor]): The input text data. + xf_proj (Optional[torch.Tensor]): Precomputed text projection. + xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder. + device (Optional[torch.device]): The device on which the model is running. + clip_feat (Optional[torch.Tensor]): CLIP features for text guidance. + kwargs: Additional keyword arguments. + + Returns: + dict: A dictionary containing the text projection and output from the encoder. + """ + if xf_out is None: + if self.use_text_proj: + xf_proj, xf_out = self.encode_text(text, clip_feat, device) + else: + xf_out = self.encode_text(text, clip_feat, device) + return {'xf_proj': xf_proj, 'xf_out': xf_out} + + def post_process(self, motion: torch.Tensor) -> torch.Tensor: + """ + Post-process the generated motion data by re-normalizing it using mean and standard deviation. + + Args: + motion (torch.Tensor): The generated motion data. + + Returns: + torch.Tensor: Post-processed motion data. + """ + if self.post_process_cfg is not None: + if self.post_process_cfg.get("unnormalized_infer", False): + mean = torch.from_numpy(np.load(self.post_process_cfg['mean_path'])) + mean = mean.type_as(motion) + std = torch.from_numpy(np.load(self.post_process_cfg['std_path'])) + std = std.type_as(motion) + motion = motion * std + mean + return motion + + def forward_train(self, + h: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + emb: Optional[torch.Tensor] = None, + xf_out: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Forward pass during training. + + Args: + h (torch.Tensor): Input motion tensor of shape (B, T, D). + src_mask (Optional[torch.Tensor]): Source mask for masking the input. + emb (torch.Tensor): Time-step embeddings. + xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder. + kwargs: Additional keyword arguments. + + Returns: + torch.Tensor: Output motion data after processing by the temporal decoder blocks. + """ + B, T = h.shape[0], h.shape[1] + if self.guidance_cfg is None: + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask) + else: + cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device) + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=cond_type) + output = self.out(h).view(B, T, -1).contiguous() + return output + + def forward_test(self, + h: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + emb: Optional[torch.Tensor] = None, + xf_out: Optional[torch.Tensor] = None, + timesteps: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Forward pass during testing/inference. + + Args: + h (torch.Tensor): Input motion tensor of shape (B, T, D). + src_mask (Optional[torch.Tensor]): Source mask for masking the input. + emb (torch.Tensor): Time-step embeddings. 
+ xf_out (Optional[torch.Tensor]): Precomputed output from the text encoder. + timesteps (Optional[torch.Tensor]): Current diffusion timesteps. + kwargs: Additional keyword arguments. + + Returns: + torch.Tensor: Output motion data after processing by the temporal decoder blocks. + """ + B, T = h.shape[0], h.shape[1] + if self.guidance_cfg is None: + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask) + output = self.out(h).view(B, T, -1).contiguous() + else: + text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1 + none_cond_type = torch.zeros(B, 1, 1).to(h.device) + all_cond_type = torch.cat((text_cond_type, none_cond_type), dim=0) + h = h.repeat(2, 1, 1) + xf_out = xf_out.repeat(2, 1, 1) + emb = emb.repeat(2, 1) + src_mask = src_mask.repeat(2, 1, 1) + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=all_cond_type) + out = self.out(h).view(2 * B, T, -1).contiguous() + out_text = out[:B].contiguous() + out_none = out[B:].contiguous() + coef_cfg = self.scale_func(int(timesteps[0])) + text_coef = coef_cfg['text_coef'] + none_coef = coef_cfg['none_coef'] + output = out_text * text_coef + out_none * none_coef + return output diff --git a/mogen/models/transformers/remodiffuse.py b/mogen/models/transformers/remodiffuse.py new file mode 100644 index 0000000000000000000000000000000000000000..564e79dd9966f566bf8bcd7660c03653280ed328 --- /dev/null +++ b/mogen/models/transformers/remodiffuse.py @@ -0,0 +1,541 @@ +import random +import clip +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch import Tensor +from typing import List, Dict, Optional, Union + +from mogen.models.utils.misc import zero_module + +from ..builder import SUBMODULES, build_attention +from .motion_transformer import MotionTransformer + + +class FFN(nn.Module): + """ + Feed-forward network (FFN) used in the transformer layers. + It consists of two linear layers with a GELU activation in between. + + Args: + latent_dim (int): Input dimension of the FFN. + ffn_dim (int): Hidden dimension of the FFN. + dropout (float): Dropout rate applied after activation. + """ + + def __init__(self, latent_dim: int, ffn_dim: int, dropout: float): + super().__init__() + self.linear1 = nn.Linear(latent_dim, ffn_dim) + self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) + self.activation = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor, **kwargs) -> Tensor: + """ + Forward pass for the FFN. + + Args: + x (Tensor): Input tensor of shape (B, T, D). + + Returns: + Tensor: Output tensor after the FFN, of shape (B, T, D). + """ + y = self.linear2(self.dropout(self.activation(self.linear1(x)))) + y = x + y + return y + + +class EncoderLayer(nn.Module): + """ + Encoder layer consisting of self-attention and feed-forward network. + + Args: + sa_block_cfg (Optional[dict]): Configuration for the self-attention block. + ca_block_cfg (Optional[dict]): Configuration for the cross-attention block (if applicable). + ffn_cfg (dict): Configuration for the feed-forward network. + """ + + def __init__(self, sa_block_cfg: Optional[dict] = None, ca_block_cfg: Optional[dict] = None, ffn_cfg: dict = None): + super().__init__() + self.sa_block = build_attention(sa_block_cfg) + self.ffn = FFN(**ffn_cfg) + + def forward(self, **kwargs) -> Tensor: + """ + Forward pass for the encoder layer. 
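`MotionDiffuseTransformer.forward_test` above applies classifier-free guidance: it stacks a text-conditioned and an unconditioned copy of the batch, decodes both in a single pass, splits the result, and blends the two halves with the coefficients from `scale_func`. The sketch below isolates that pattern with a dummy decoder so the batching and blending are easy to follow; the batch/sequence/feature sizes and the `scale` value are illustrative only.

```python
import torch

def blend_guidance(out_text, out_none, timestep, scale=3.5):
    # Same dynamic schedule as the non-constant branch of scale_func:
    # the guidance weight is largest at the noisiest timesteps and decays to 1 as t -> 0.
    w = (1 - (1000 - timestep) / 1000) * scale + 1
    return out_text * w + out_none * (1 - w)

B, T, D = 2, 196, 669                      # illustrative batch / length / feature sizes
h = torch.randn(B, T, D)

# Stack a conditioned (cond_type = 1) and an unconditioned (cond_type = 0) copy,
# run the "decoder" once, then split the output back into its two halves.
cond_type = torch.cat([torch.ones(B, 1, 1), torch.zeros(B, 1, 1)], dim=0)
decoded = h.repeat(2, 1, 1) * cond_type    # stand-in for the temporal decoder blocks
out_text, out_none = decoded[:B], decoded[B:]

print(blend_guidance(out_text, out_none, timestep=50).shape)  # torch.Size([2, 196, 669])
```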
+ + Args: + kwargs: Dictionary containing the input tensor (x) and other related parameters. + + Returns: + Tensor: Output tensor after the encoder layer. + """ + if self.sa_block is not None: + x = self.sa_block(**kwargs) + kwargs.update({'x': x}) + if self.ffn is not None: + x = self.ffn(**kwargs) + return x + + +class RetrievalDatabase(nn.Module): + """ + Retrieval database for retrieving motions and text features based on given captions. + + Args: + num_retrieval (int): Number of retrievals for each caption. + topk (int): Number of top results to consider. + retrieval_file (str): Path to the retrieval file containing text, motion, and length data. + latent_dim (Optional[int]): Dimension of the latent space. + output_dim (Optional[int]): Output dimension of the retrieved features. + num_layers (Optional[int]): Number of layers in the text encoder. + num_motion_layers (Optional[int]): Number of layers in the motion encoder. + kinematic_coef (Optional[float]): Coefficient for scaling kinematic similarity. + max_seq_len (Optional[int]): Maximum sequence length. + num_heads (Optional[int]): Number of attention heads. + ff_size (Optional[int]): Feed-forward size for the transformer layers. + stride (Optional[int]): Stride for downsampling motion data. + sa_block_cfg (Optional[dict]): Configuration for the self-attention block. + ffn_cfg (Optional[dict]): Configuration for the feed-forward network. + dropout (Optional[float]): Dropout rate. + """ + + def __init__(self, + num_retrieval: int, + topk: int, + retrieval_file: str, + latent_dim: Optional[int] = 512, + output_dim: Optional[int] = 512, + num_layers: Optional[int] = 2, + num_motion_layers: Optional[int] = 4, + kinematic_coef: Optional[float] = 0.1, + max_seq_len: Optional[int] = 196, + num_heads: Optional[int] = 8, + ff_size: Optional[int] = 1024, + stride: Optional[int] = 4, + sa_block_cfg: Optional[dict] = None, + ffn_cfg: Optional[dict] = None, + dropout: Optional[float] = 0): + super().__init__() + self.num_retrieval = num_retrieval + self.topk = topk + self.latent_dim = latent_dim + self.stride = stride + self.kinematic_coef = kinematic_coef + self.num_layers = num_layers + self.num_motion_layers = num_motion_layers + self.max_seq_len = max_seq_len + + # Load data from the retrieval file + data = np.load(retrieval_file) + self.text_features = torch.Tensor(data['text_features']) + self.captions = data['captions'] + self.motions = data['motions'] + self.m_lengths = data['m_lengths'] + self.clip_seq_features = data['clip_seq_features'] + self.train_indexes = data.get('train_indexes', None) + self.test_indexes = data.get('test_indexes', None) + + self.latent_dim = latent_dim + self.output_dim = output_dim + self.motion_proj = nn.Linear(self.motions.shape[-1], self.latent_dim) + self.motion_pos_embedding = nn.Parameter( + torch.randn(max_seq_len, self.latent_dim)) + self.motion_encoder_blocks = nn.ModuleList() + + # Build motion encoder blocks + for i in range(num_motion_layers): + self.motion_encoder_blocks.append( + EncoderLayer(sa_block_cfg=sa_block_cfg, ffn_cfg=ffn_cfg)) + + # Transformer for encoding text + TransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation="gelu") + self.text_encoder = nn.TransformerEncoder(TransEncoderLayer, + num_layers=num_layers) + self.results = {} + + def extract_text_feature(self, text: str, clip_model: nn.Module, device: torch.device) -> Tensor: + """ + Extract text features from CLIP model. 
+ + Args: + text (str): Input text caption. + clip_model (nn.Module): CLIP model for encoding the text. + device (torch.device): Device for computation. + + Returns: + Tensor: Extracted text features of shape (1, 512). + """ + text = clip.tokenize([text], truncate=True).to(device) + with torch.no_grad(): + text_features = clip_model.encode_text(text) + return text_features + + def encode_text(self, text: List[str], device: torch.device) -> Tensor: + """ + Encode text using the CLIP model's text encoder. + + Args: + text (List[str]): List of input text captions. + device (torch.device): Device for computation. + + Returns: + Tensor: Encoded text features of shape (B, T, D). + """ + with torch.no_grad(): + text = clip.tokenize(text, truncate=True).to(device) + x = self.clip.token_embedding(text).type(self.clip.dtype) + + x = x + self.clip.positional_embedding.type(self.clip.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.clip.transformer(x) + x = self.clip.ln_final(x).type(self.clip.dtype) + + # B, T, D + xf_out = x.permute(1, 0, 2) + return xf_out + + def retrieve(self, caption: str, length: int, clip_model: nn.Module, device: torch.device, idx: Optional[int] = None) -> List[int]: + """ + Retrieve motions and text features based on a given caption. + + Args: + caption (str): Input text caption. + length (int): Length of the corresponding motion sequence. + clip_model (nn.Module): CLIP model for encoding the text. + device (torch.device): Device for computation. + idx (Optional[int]): Index for retrieval (if provided). + + Returns: + List[int]: List of indexes for the retrieved motions. + """ + value = hash(caption) + if value in self.results: + return self.results[value] + text_feature = self.extract_text_feature(caption, clip_model, device) + + rel_length = torch.LongTensor(self.m_lengths).to(device) + rel_length = torch.abs(rel_length - length) + rel_length = rel_length / torch.clamp(rel_length, min=length) + semantic_score = F.cosine_similarity(self.text_features.to(device), + text_feature) + kinematic_score = torch.exp(-rel_length * self.kinematic_coef) + score = semantic_score * kinematic_score + indexes = torch.argsort(score, descending=True) + data = [] + cnt = 0 + for idx in indexes: + caption, m_length = self.captions[idx], self.m_lengths[idx] + if not self.training or m_length != length: + cnt += 1 + data.append(idx.item()) + if cnt == self.num_retrieval: + self.results[value] = data + return data + assert False + + def generate_src_mask(self, T: int, length: List[int]) -> Tensor: + """ + Generate source mask for the motion sequences based on the motion lengths. + + Args: + T (int): Maximum sequence length. + length (List[int]): List of motion lengths for each sample. + + Returns: + Tensor: A binary mask tensor of shape (B, T), where `B` is the batch size, + and `T` is the maximum sequence length. Mask values are 1 for valid positions + and 0 for padded positions. + """ + B = len(length) + src_mask = torch.ones(B, T) + for i in range(B): + for j in range(length[i], T): + src_mask[i, j] = 0 + return src_mask + + def forward(self, captions: List[str], lengths: List[int], clip_model: nn.Module, device: torch.device, idx: Optional[List[int]] = None) -> Dict[str, Tensor]: + """ + Forward pass for retrieving motion sequences and text features. + + Args: + captions (List[str]): List of input text captions. + lengths (List[int]): List of corresponding motion lengths. + clip_model (nn.Module): CLIP model for encoding the text. + device (torch.device): Device for computation. 
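`RetrievalDatabase.retrieve` ranks the stored samples by the product of a CLIP text-similarity term and a length-similarity term, then keeps the top `num_retrieval` hits. The sketch below reproduces just that scoring rule on random features so the ranking behaviour can be inspected in isolation; the feature dimension, lengths, and database size are made up, and the real method additionally caches results per caption and skips exact-length matches during training.

```python
import torch
import torch.nn.functional as F

def retrieval_scores(db_text_feats, db_lengths, query_feat, query_length,
                     kinematic_coef=0.1):
    # Relative length difference, normalised the same way as in the diff above.
    rel = torch.abs(db_lengths.float() - query_length)
    rel = rel / torch.clamp(rel, min=query_length)
    semantic = F.cosine_similarity(db_text_feats, query_feat)    # (N,)
    kinematic = torch.exp(-kinematic_coef * rel)                 # (N,)
    return semantic * kinematic

db_feats = torch.randn(1000, 512)               # stand-in for the stored CLIP text features
db_lens = torch.randint(40, 196, (1000,))
query = torch.randn(1, 512)

scores = retrieval_scores(db_feats, db_lens, query, query_length=120)
top = torch.argsort(scores, descending=True)[:2]   # num_retrieval = 2
print(top)
```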
+ idx (Optional[List[int]]): Optional list of indices for retrieval. + + Returns: + Dict[str, Tensor]: Dictionary containing retrieved text and motion features. + - re_text: Retrieved text features of shape (B, num_retrieval, T, D). + - re_motion: Retrieved motion features of shape (B, num_retrieval, T, D). + - re_mask: Source mask for the retrieved motion of shape (B, num_retrieval, T). + - raw_motion: Raw motion features of shape (B, T, motion_dim). + - raw_motion_length: Motion sequence lengths (before any stride). + - raw_motion_mask: Raw binary mask for valid motion positions of shape (B, T). + """ + B = len(captions) + all_indexes = [] + for b_ix in range(B): + length = int(lengths[b_ix]) + if idx is None: + batch_indexes = self.retrieve(captions[b_ix], length, clip_model, device) + else: + batch_indexes = self.retrieve(captions[b_ix], length, clip_model, device, idx[b_ix]) + all_indexes.extend(batch_indexes) + + all_indexes = np.array(all_indexes) + all_motions = torch.Tensor(self.motions[all_indexes]).to(device) + all_m_lengths = torch.Tensor(self.m_lengths[all_indexes]).long() + + # Generate masks and positional encodings + T = all_motions.shape[1] + src_mask = self.generate_src_mask(T, all_m_lengths).to(device) + raw_src_mask = src_mask.clone() + re_motion = self.motion_proj(all_motions) + self.motion_pos_embedding.unsqueeze(0) + + for module in self.motion_encoder_blocks: + re_motion = module(x=re_motion, src_mask=src_mask.unsqueeze(-1)) + + re_motion = re_motion.view(B, self.num_retrieval, T, -1).contiguous() + re_motion = re_motion[:, :, ::self.stride, :].contiguous() # Apply stride + src_mask = src_mask[:, ::self.stride].contiguous() + src_mask = src_mask.view(B, self.num_retrieval, -1).contiguous() + + # Process text sequences + T = 77 # CLIP's max token length + all_text_seq_features = torch.Tensor(self.clip_seq_features[all_indexes]).to(device) + all_text_seq_features = all_text_seq_features.permute(1, 0, 2) + re_text = self.text_encoder(all_text_seq_features) + re_text = re_text.permute(1, 0, 2) + re_text = re_text.view(B, self.num_retrieval, T, -1).contiguous() + re_text = re_text[:, :, -1:, :].contiguous() # Use the last token only for each sequence + + re_dict = { + 're_text': re_text, + 're_motion': re_motion, + 're_mask': src_mask, + 'raw_motion': all_motions, + 'raw_motion_length': all_m_lengths, + 'raw_motion_mask': raw_src_mask + } + return re_dict + + +@SUBMODULES.register_module() +class ReMoDiffuseTransformer(MotionTransformer): + """ + Transformer model for motion retrieval and diffusion. + + Args: + retrieval_cfg (dict): Configuration for the retrieval database. + scale_func_cfg (dict): Configuration for scaling functions. + kwargs: Additional arguments for the base DiffusionTransformer. + """ + + def __init__(self, retrieval_cfg: dict, scale_func_cfg: dict, **kwargs): + super().__init__(**kwargs) + self.database = RetrievalDatabase(**retrieval_cfg) + self.scale_func_cfg = scale_func_cfg + + def scale_func(self, timestep: int) -> Dict[str, float]: + """ + Scale function for adjusting the guidance between text and retrieval. + + Args: + timestep (int): Current diffusion timestep. + + Returns: + Dict[str, float]: Scaling coefficients for different guidance types. + - both_coef: Coefficient for both text and retrieval guidance. + - text_coef: Coefficient for text-only guidance. + - retr_coef: Coefficient for retrieval-only guidance. + - none_coef: Coefficient for no guidance. 
+ """ + coarse_scale = self.scale_func_cfg['coarse_scale'] + w = (1 - (1000 - timestep) / 1000) * coarse_scale + 1 + if timestep > 100: + if random.randint(0, 1) == 0: + output = { + 'both_coef': w, + 'text_coef': 0, + 'retr_coef': 1 - w, + 'none_coef': 0 + } + else: + output = { + 'both_coef': 0, + 'text_coef': w, + 'retr_coef': 0, + 'none_coef': 1 - w + } + else: + both_coef = self.scale_func_cfg['both_coef'] + text_coef = self.scale_func_cfg['text_coef'] + retr_coef = self.scale_func_cfg['retr_coef'] + none_coef = 1 - both_coef - text_coef - retr_coef + output = { + 'both_coef': both_coef, + 'text_coef': text_coef, + 'retr_coef': retr_coef, + 'none_coef': none_coef + } + return output + + def get_precompute_condition(self, + text: Optional[str] = None, + motion_length: Optional[Tensor] = None, + xf_out: Optional[Tensor] = None, + re_dict: Optional[Dict] = None, + device: Optional[torch.device] = None, + sample_idx: Optional[Tensor] = None, + clip_feat: Optional[Tensor] = None, + **kwargs) -> Dict[str, Union[Tensor, Dict]]: + """ + Precompute conditions for both text and retrieval-guided diffusion. + + Args: + text (Optional[str]): Input text string for guidance. + motion_length (Optional[Tensor]): Lengths of the motion sequences. + xf_out (Optional[Tensor]): Encoded text feature (if precomputed). + re_dict (Optional[Dict]): Dictionary of retrieval results (if precomputed). + device (Optional[torch.device]): Device to perform computation on. + sample_idx (Optional[Tensor]): Sample indices for retrieval. + clip_feat (Optional[Tensor]): Clip features (if used). + + Returns: + Dict[str, Union[Tensor, Dict]]: Dictionary containing encoded features and retrieval results. + """ + if xf_out is None: + xf_out = self.encode_text(text, clip_feat, device) + output = {'xf_out': xf_out} + if re_dict is None: + re_dict = self.database(text, motion_length, self.clip, device, idx=sample_idx) + output['re_dict'] = re_dict + return output + + def post_process(self, motion: Tensor) -> Tensor: + """ + Post-process the generated motion by normalizing or un-normalizing it. + + Args: + motion (Tensor): Generated motion data. + + Returns: + Tensor: Post-processed motion data. + """ + if self.post_process_cfg is not None: + if self.post_process_cfg.get("unnormalized_infer", False): + mean = torch.from_numpy(np.load(self.post_process_cfg['mean_path'])).type_as(motion) + std = torch.from_numpy(np.load(self.post_process_cfg['std_path'])).type_as(motion) + motion = motion * std + mean + return motion + + def forward_train(self, + h: Tensor, + src_mask: Tensor, + emb: Tensor, + xf_out: Optional[Tensor] = None, + re_dict: Optional[Dict] = None, + **kwargs) -> Tensor: + """ + Forward training pass for motion retrieval and diffusion model. + + Args: + h (Tensor): Input motion features of shape (B, T, D). + src_mask (Tensor): Mask for the motion data of shape (B, T, 1). + emb (Tensor): Embedding tensor for timesteps. + xf_out (Optional[Tensor]): Precomputed text features. + re_dict (Optional[Dict]): Dictionary of retrieval features. + + Returns: + Tensor: Output motion data of shape (B, T, D). 
+ """ + B, T = h.shape[0], h.shape[1] + cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device) + for module in self.temporal_decoder_blocks: + h = module(x=h, + xf=xf_out, + emb=emb, + src_mask=src_mask, + cond_type=cond_type, + re_dict=re_dict) + + output = self.out(h).view(B, T, -1).contiguous() + return output + + def forward_test(self, + h: Tensor, + src_mask: Tensor, + emb: Tensor, + xf_out: Optional[Tensor] = None, + re_dict: Optional[Dict] = None, + timesteps: Optional[Tensor] = None, + **kwargs) -> Tensor: + """ + Forward testing pass for motion retrieval and diffusion model. This method handles + multiple conditional types such as both text and retrieval-based guidance. + + Args: + h (Tensor): Input motion features of shape (B, T, D). + src_mask (Tensor): Mask for the motion data of shape (B, T, 1). + emb (Tensor): Embedding tensor for timesteps. + xf_out (Optional[Tensor]): Precomputed text features. + re_dict (Optional[Dict]): Dictionary of retrieval features. + timesteps (Optional[Tensor]): Tensor containing current timesteps in the diffusion process. + + Returns: + Tensor: Output motion data after applying multiple conditional types, of shape (B, T, D). + """ + B, T = h.shape[0], h.shape[1] + + # Define condition types for different guidance types + both_cond_type = torch.zeros(B, 1, 1).to(h.device) + 99 + text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1 + retr_cond_type = torch.zeros(B, 1, 1).to(h.device) + 10 + none_cond_type = torch.zeros(B, 1, 1).to(h.device) + + # Concatenate all conditional types and repeat inputs for different guidance modes + all_cond_type = torch.cat((both_cond_type, text_cond_type, retr_cond_type, none_cond_type), dim=0) + h = h.repeat(4, 1, 1) + xf_out = xf_out.repeat(4, 1, 1) + emb = emb.repeat(4, 1) + src_mask = src_mask.repeat(4, 1, 1) + + # Repeat retrieval features if necessary + if re_dict['re_motion'].shape[0] != h.shape[0]: + re_dict['re_motion'] = re_dict['re_motion'].repeat(4, 1, 1, 1) + re_dict['re_text'] = re_dict['re_text'].repeat(4, 1, 1, 1) + re_dict['re_mask'] = re_dict['re_mask'].repeat(4, 1, 1) + + # Pass through the temporal decoder blocks + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=all_cond_type, re_dict=re_dict) + + # Retrieve output features and handle different guidance coefficients + out = self.out(h).view(4 * B, T, -1).contiguous() + out_both = out[:B].contiguous() + out_text = out[B:2 * B].contiguous() + out_retr = out[2 * B:3 * B].contiguous() + out_none = out[3 * B:].contiguous() + + # Apply scaling coefficients based on the timestep + coef_cfg = self.scale_func(int(timesteps[0])) + both_coef = coef_cfg['both_coef'] + text_coef = coef_cfg['text_coef'] + retr_coef = coef_cfg['retr_coef'] + none_coef = coef_cfg['none_coef'] + + # Compute the final output by blending the different guidance outputs + output = out_both * both_coef + output += out_text * text_coef + output += out_retr * retr_coef + output += out_none * none_coef + + return output + diff --git a/mogen/models/utils/__init__.py b/mogen/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mogen/models/utils/gaussian_diffusion.py b/mogen/models/utils/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..90ed8807bacd3c212b0b0fdef2a1434bfd943394 --- /dev/null +++ b/mogen/models/utils/gaussian_diffusion.py @@ -0,0 +1,1369 @@ +# flake8: noqa +""" +This code 
is borrowed from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py +""" + +import enum +import math +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + elif name == "pretrain": + return PretrainSampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size, ), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class PretrainSampler(ScheduleSampler): + + def __init__(self, diffusion): + self.diffusion = diffusion + # self._weights = np.ones([diffusion.num_timesteps]) + t = np.arange(diffusion.num_timesteps) + self._weights = np.cos(t / 2000 * np.pi) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. 
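`ScheduleSampler.sample` draws timesteps in proportion to `weights()` and returns importance weights `1 / (N * p[t])`, so the reweighted loss keeps the same expectation as uniform timestep sampling. The stand-alone sketch below runs that rule with the cosine weighting that `PretrainSampler` uses (which favours small, low-noise timesteps); the 1000-step horizon and batch size are illustrative.

```python
import numpy as np
import torch

def sample_timesteps(weights: np.ndarray, batch_size: int, device="cpu"):
    # Importance-sample timestep indices; the returned weights 1 / (N * p)
    # keep the reweighted loss an unbiased estimate of the uniform objective.
    p = weights / weights.sum()
    idx = np.random.choice(len(p), size=(batch_size,), p=p)
    ts = torch.from_numpy(idx).long().to(device)
    w = torch.from_numpy(1.0 / (len(p) * p[idx])).float().to(device)
    return ts, w

num_timesteps = 1000
t = np.arange(num_timesteps)
pretrain_weights = np.cos(t / 2000 * np.pi)   # same curve as PretrainSampler._weights

ts, w = sample_timesteps(pretrain_weights, batch_size=4)
print(ts, w)   # small (low-noise) t values are drawn more often and get weights < 1
```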
+ batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [ + th.zeros(max_bs).to(local_ts) for bs in batch_sizes + ] + loss_batches = [ + th.zeros(max_bs).to(local_losses) for bs in batch_sizes + ] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) + for x in y[:bs] + ] + losses = [ + x.item() for y, bs in zip(loss_batches, batch_sizes) + for x in y[:bs] + ] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + + ((mean1 - mean2)**2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. 
+ """ + return 0.5 * ( + 1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, + th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, + beta_end, + num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
+ """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps, ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - + 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) / + (1.0 - self.alphas_cumprod)) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = (betas * + np.sqrt(self.alphas_cumprod_prev) / + (1.0 - self.alphas_cumprod)) + self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) * + np.sqrt(alphas) / + (1.0 - self.alphas_cumprod)) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
+ """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, + x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, + t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * + x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * + x_t) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, + x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == + posterior_log_variance_clipped.shape[0] == x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B, ) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE + ]: + assert model_output.shape == (B, 2 * C, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. 
+ frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log( + np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, + x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, + xprev=model_output)) + model_mean = model_output + elif self.model_mean_type in [ + ModelMeanType.START_X, ModelMeanType.EPSILON + ]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, + eps=model_output)) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert (model_mean.shape == model_log_variance.shape == + pred_xstart.shape == x.shape) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, + x_t.shape) * x_t - + _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, + x_t.shape) * eps) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) + * xprev - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + x_t.shape) * x_t) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, + x_t.shape) * x_t - + pred_xstart) / _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = (p_mean_var["mean"].float() + + p_mean_var["variance"] * gradient.float()) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. 
+ + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + pre_seq=None, + transl_req=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + # concat seq + if pre_seq is not None: + T = pre_seq.shape[1] + noise = th.randn_like(pre_seq) + x_t = self.q_sample(pre_seq, t, noise=noise) + x[:, :T, :] = x_t + + if transl_req is not None: + for item in transl_req: + noise = th.randn(2).type_as(x) + transl = th.Tensor(item[1:]).type_as(x) + x_t = self.q_sample(transl, t, noise=noise) + x[:, :2, item[0]] = x_t + + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, + out, + x, + t, + model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp( + 0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + pre_seq=None, + transl_req=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. 
+ """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + pre_seq=pre_seq, + transl_req=transl_req, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + pre_seq=None, + transl_req=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample(model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + pre_seq=pre_seq, + transl_req=transl_req) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + pre_seq=None, + gt_motion=None, + context_mask=None, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + if pre_seq is not None: + T = pre_seq.shape[1] + noise = th.randn_like(pre_seq) + x_t = self.q_sample(pre_seq, t, noise=noise) + x[:, :T, :] = x_t + if context_mask is not None: + B, T = gt_motion.shape[:2] + noise = th.randn_like(gt_motion) * 0 + x_t = self.q_sample(gt_motion, t, noise=noise) + context_mask = context_mask.view(B, T, 1) + x = x_t * context_mask + (1 - context_mask) * x + x = x.float() + + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, + out, + x, + t, + model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, + x.shape) + sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * + th.sqrt(1 - alpha_bar / alpha_bar_prev)) + # Equation 12. + noise = th.randn_like(x) + mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + pre_seq=None, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. 
+ """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) + * x - out["pred_xstart"]) / _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, + x.shape) + + # Equation 12. reversed + mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + pre_seq=None, + gt_motion=None, + context_mask=None, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + pre_seq=pre_seq, + gt_motion=gt_motion, + context_mask=context_mask + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + pre_seq=None, + gt_motion=None, + context_mask=None, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + pre_seq=pre_seq, + gt_motion=gt_motion, + context_mask=context_mask + ) + yield out + img = out["sample"] + + def _vb_terms_bpd(self, + model, + x_start, + x_t, + t, + clip_denoised=True, + model_kwargs=None): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. 
+ """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, + x_t, + t, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], + out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, + model, + x_start, + t, + model_kwargs=None, + noise=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, + C, + dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], + dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. 
+ terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: + self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, + t=t)[0], + ModelMeanType.START_X: + x_start, + ModelMeanType.EPSILON: + noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat( + (target - model_output)**2).view(-1, 1).mean(-1) + # if "vb" in terms: + # terms["loss"] = terms["mse"] + terms["vb"] + # else: + # terms["loss"] = terms["mse"] + terms["target"] = target + terms["pred"] = model_output + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, + device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, + logvar1=qt_log_variance, + mean2=0.0, + logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, + model, + x_start, + clip_denoised=True, + model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, + out["pred_xstart"]) + mse.append(mean_flat((eps - noise)**2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. 
+    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+    """
+    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    while len(res.shape) < len(broadcast_shape):
+        res = res[..., None]
+    return res.expand(broadcast_shape)
+
+
+def space_timesteps(num_timesteps, section_counts):
+    """
+    Create a set of timesteps to use from an original diffusion process,
+    given the number of timesteps we want to take from equally-sized portions
+    of the original process.
+
+    For example, if there are 300 timesteps and the section counts are [10,15,20],
+    then the first 100 timesteps are strided to be 10 timesteps, the second 100
+    are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+    :param num_timesteps: the number of diffusion steps in the original
+                          process to divide up.
+    :param section_counts: either a list of numbers, or a string containing
+                           comma-separated numbers, indicating the step count
+                           per section. As a special case, use "ddimN" where N
+                           is a number of steps to use the striding from the
+                           DDIM paper.
+    :return: a set of diffusion steps from the original process to use.
+    """
+    if isinstance(section_counts, str):
+        if section_counts.startswith("ddim"):
+            desired_count = int(section_counts[len("ddim"):])
+            for i in range(1, num_timesteps):
+                if len(range(0, num_timesteps, i)) == desired_count:
+                    return set(range(0, num_timesteps, i))
+            raise ValueError(
+                f"cannot create exactly {desired_count} steps with an integer stride"
+            )
+        elif section_counts == "fast27":
+            # steps = space_timesteps(num_timesteps, "10,10,3,2,2")
+            # steps = space_timesteps(num_timesteps, "30,30,16,12,12")
+            steps = space_timesteps(num_timesteps, "15,15,8,6,6")
+
+            # Help reduce DDIM artifacts from noisiest timesteps.
+            steps.remove(num_timesteps - 1)
+            steps.add(num_timesteps - 3)
+            return steps
+        section_counts = [int(x) for x in section_counts.split(",")]
+    size_per = num_timesteps // len(section_counts)
+    extra = num_timesteps % len(section_counts)
+    start_idx = 0
+    all_steps = []
+    for i, section_count in enumerate(section_counts):
+        size = size_per + (1 if i < extra else 0)
+        if size < section_count:
+            raise ValueError(
+                f"cannot divide section of {size} steps into {section_count}")
+        if section_count <= 1:
+            frac_stride = 1
+        else:
+            frac_stride = (size - 1) / (section_count - 1)
+        cur_idx = 0.0
+        taken_steps = []
+        for _ in range(section_count):
+            taken_steps.append(start_idx + round(cur_idx))
+            cur_idx += frac_stride
+        all_steps += taken_steps
+        start_idx += size
+    return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+    """
+    A diffusion process which can skip steps in a base diffusion process.
+
+    :param use_timesteps: a collection (sequence or set) of timesteps from the
+                          original diffusion process to retain.
+    :param kwargs: the kwargs to create the base diffusion process.
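+
+    Example (illustrative sketch, mirroring how ``build_diffusion`` below
+    constructs the respaced process; ``betas`` is assumed to come from
+    ``get_named_beta_schedule``)::
+
+        diffusion = SpacedDiffusion(
+            use_timesteps=space_timesteps(1000, "ddim50"),
+            betas=betas,
+            model_mean_type=ModelMeanType.EPSILON,
+            model_var_type=ModelVarType.FIXED_SMALL,
+            loss_type=LossType.MSE)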
+ """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, + **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, + **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, + **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + +class _WrappedModel: + + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, + device=ts.device, + dtype=ts.dtype) + new_ts = map_tensor[ts] + return self.model(x, new_ts, **kwargs) + + +def build_diffusion(cfg: dict) -> GaussianDiffusion: + """Build diffusion model based on the configuration. + + Args: + cfg (dict): Configuration dictionary containing the diffusion parameters. + + Returns: + GaussianDiffusion: The built diffusion model. 
+ """ + beta_scheduler = cfg['beta_scheduler'] + diffusion_steps = cfg['diffusion_steps'] + betas = get_named_beta_schedule(beta_scheduler, diffusion_steps) + + model_mean_type = { + 'start_x': ModelMeanType.START_X, + 'previous_x': ModelMeanType.PREVIOUS_X, + 'epsilon': ModelMeanType.EPSILON + }[cfg['model_mean_type']] + + model_var_type = { + 'learned': ModelVarType.LEARNED, + 'fixed_small': ModelVarType.FIXED_SMALL, + 'fixed_large': ModelVarType.FIXED_LARGE, + 'learned_range': ModelVarType.LEARNED_RANGE + }[cfg['model_var_type']] + + if cfg.get('respace', None) is not None: + diffusion = SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, cfg['respace']), + betas=betas, + model_mean_type=model_mean_type, + model_var_type=model_var_type, + loss_type=LossType.MSE) + else: + diffusion = GaussianDiffusion( + betas=betas, + model_mean_type=model_mean_type, + model_var_type=model_var_type, + loss_type=LossType.MSE) + + return diffusion \ No newline at end of file diff --git a/mogen/models/utils/imagebind_wrapper.py b/mogen/models/utils/imagebind_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..80671a46a3ef528d77ce251f8bffb504959245e3 --- /dev/null +++ b/mogen/models/utils/imagebind_wrapper.py @@ -0,0 +1,95 @@ +from imagebind import data +import torch +from imagebind.models import imagebind_model +from imagebind.models.imagebind_model import ModalityType +import os +import numpy as np +from tqdm import tqdm +import json +import pickle + + + + +class FeatureExtractor(imagebind_model.ImageBindModel): + + def forward(self, inputs): + outputs = {} + for modality_key, modality_value in inputs.items(): + reduce_list = ( + modality_value.ndim >= 5 + ) # Audio and Video inputs consist of multiple clips + if reduce_list: + B, S = modality_value.shape[:2] + modality_value = modality_value.reshape( + B * S, *modality_value.shape[2:] + ) + + if modality_value is not None: + modality_value = self.modality_preprocessors[modality_key]( + **{modality_key: modality_value} + ) + trunk_inputs = modality_value["trunk"] + head_inputs = modality_value["head"] + modality_value = self.modality_trunks[modality_key](**trunk_inputs) + word_feat = modality_value + seq_feat = self.modality_heads[modality_key]( + word_feat, **head_inputs + ) + seq_feat = self.modality_postprocessors[modality_key]( + seq_feat + ) + return word_feat, seq_feat + + +def imagebind_huge(pretrained=False): + model = FeatureExtractor( + vision_embed_dim=1280, + vision_num_blocks=32, + vision_num_heads=16, + text_embed_dim=1024, + text_num_blocks=24, + text_num_heads=16, + out_embed_dim=1024, + audio_drop_path=0.1, + imu_drop_path=0.7, + ) + + if pretrained: + file_path = os.path.abspath(os.path.dirname(__file__)) + ckpt_dir = os.path.join(file_path, '../../../data/motionverse/pretrained') + ckpt_path = os.path.join(ckpt_dir, 'imagebind_huge.pth') + if not os.path.exists(ckpt_path): + print( + "Downloading imagebind weights to motionverse/pretrained/imagebind_huge.pth ..." 
+ ) + os.makedirs(ckpt_dir, exist_ok=True) + torch.hub.download_url_to_file( + "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth", + ckpt_path, + progress=True, + ) + + model.load_state_dict(torch.load(ckpt_path)) + return model + + +def extract_text_feature(text, model, device): + text_list = text + inputs = { + ModalityType.TEXT: data.load_and_transform_text(text_list, device), + } + with torch.no_grad(): + text_word_feat, text_seq_feat = model(inputs) + return text_word_feat, text_seq_feat + + +def extract_audio_feature(audio_paths, model, device): + inputs = { + ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device) + } + inputs['audio'] = inputs['audio'][:, :1] + with torch.no_grad(): + audio_word_feat, audio_seq_feat = model(inputs) + return audio_word_feat, audio_seq_feat + diff --git a/mogen/models/utils/mae.py b/mogen/models/utils/mae.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe6f9007315ac602f258147f2de1537eba37e39 --- /dev/null +++ b/mogen/models/utils/mae.py @@ -0,0 +1,24 @@ +import torch + + +def create_mask_sequence(mask_cfg, seq_len): + type_name = mask_cfg['type'] + if type_name == 'raster order': + num_tokens = mask_cfg['num_tokens'] + idx_list = [] + all_idx = torch.arange(seq_len) + for i in range(0, seq_len, num_tokens): + idx_list.append(all_idx[i: i + num_tokens]) + return idx_list + elif type_name == 'random order': + num_tokens = mask_cfg['num_tokens'] + idx_list = [] + all_idx = torch.randperm(seq_len) + for i in range(0, seq_len, num_tokens): + idx_list.append(all_idx[i: i + num_tokens]) + return idx_list + elif type_name == 'single': + idx_list = [torch.arange(seq_len)] + return idx_list + else: + raise NotImplementedError() diff --git a/mogen/models/utils/mask_helper.py b/mogen/models/utils/mask_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..1e96b5465a01bedd9b730936c2eed5fd0fbf69ac --- /dev/null +++ b/mogen/models/utils/mask_helper.py @@ -0,0 +1,66 @@ +import torch + +def get_tomato_slice(idx): + if idx == 0: + result = [0, 1, 2, 3, 463, 464, 465] + else: + result = [ + 4 + (idx - 1) * 3, + 4 + (idx - 1) * 3 + 1, + 4 + (idx - 1) * 3 + 2, + 157 + (idx - 1) * 6, + 157 + (idx - 1) * 6 + 1, + 157 + (idx - 1) * 6 + 2, + 157 + (idx - 1) * 6 + 3, + 157 + (idx - 1) * 6 + 4, + 157 + (idx - 1) * 6 + 5, + 463 + idx * 3, + 463 + idx * 3 + 1, + 463 + idx * 3 + 2, + ] + return result + + +def get_part_slice(idx_list, func): + result = [] + for idx in idx_list: + result.extend(func(idx)) + return result + + +def expand_mask_to_all(mask, body_scale, hand_scale, face_scale): + func = get_tomato_slice + root_slice = get_part_slice([0], func) + head_slice = get_part_slice([12, 15], func) + stem_slice = get_part_slice([3, 6, 9], func) + larm_slice = get_part_slice([14, 17, 19, 21], func) + rarm_slice = get_part_slice([13, 16, 18, 20], func) + lleg_slice = get_part_slice([2, 5, 8, 11], func) + rleg_slice = get_part_slice([1, 4, 7, 10], func) + lhnd_slice = get_part_slice(range(22, 37), func) + rhnd_slice = get_part_slice(range(37, 52), func) + face_slice = range(619, 669) + B, T = mask.shape[0], mask.shape[1] + mask = mask.view(B, T, -1) + all_mask = torch.zeros(B, T, 669).type_as(mask) + all_mask[:, :, root_slice] = mask[:, :, 0].unsqueeze(-1).repeat(1, 1, len(root_slice)) + all_mask[:, :, head_slice] = mask[:, :, 1].unsqueeze(-1).repeat(1, 1, len(head_slice)) + all_mask[:, :, stem_slice] = mask[:, :, 2].unsqueeze(-1).repeat(1, 1, len(stem_slice)) + all_mask[:, :, larm_slice] = mask[:, :, 
3].unsqueeze(-1).repeat(1, 1, len(larm_slice))
+    all_mask[:, :, rarm_slice] = mask[:, :, 4].unsqueeze(-1).repeat(1, 1, len(rarm_slice))
+    all_mask[:, :, lleg_slice] = mask[:, :, 5].unsqueeze(-1).repeat(1, 1, len(lleg_slice))
+    all_mask[:, :, rleg_slice] = mask[:, :, 6].unsqueeze(-1).repeat(1, 1, len(rleg_slice))
+    all_mask[:, :, lhnd_slice] = mask[:, :, 7].unsqueeze(-1).repeat(1, 1, len(lhnd_slice))
+    all_mask[:, :, rhnd_slice] = mask[:, :, 8].unsqueeze(-1).repeat(1, 1, len(rhnd_slice))
+    all_mask[:, :, face_slice] = mask[:, :, 9].unsqueeze(-1).repeat(1, 1, len(face_slice))
+    all_mask[:, :, root_slice] *= body_scale
+    all_mask[:, :, head_slice] *= body_scale
+    all_mask[:, :, stem_slice] *= body_scale
+    all_mask[:, :, larm_slice] *= body_scale
+    all_mask[:, :, rarm_slice] *= body_scale
+    all_mask[:, :, lleg_slice] *= body_scale
+    all_mask[:, :, rleg_slice] *= body_scale
+    all_mask[:, :, lhnd_slice] *= hand_scale
+    all_mask[:, :, rhnd_slice] *= hand_scale
+    all_mask[:, :, face_slice] *= face_scale
+    return all_mask
diff --git a/mogen/models/utils/misc.py b/mogen/models/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d97faeea8c760b8bc15002cb26b514865a86a7e0
--- /dev/null
+++ b/mogen/models/utils/misc.py
@@ -0,0 +1,23 @@
+def set_requires_grad(nets, requires_grad=False):
+    """Set requires_grad for all the networks.
+
+    Args:
+        nets (nn.Module | list[nn.Module]): A list of networks or a single
+            network.
+        requires_grad (bool): Whether the networks require gradients or not.
+    """
+    if not isinstance(nets, list):
+        nets = [nets]
+    for net in nets:
+        if net is not None:
+            for param in net.parameters():
+                param.requires_grad = requires_grad
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+ """ + for p in module.parameters(): + p.detach().zero_() + return module diff --git a/mogen/models/utils/mlp.py b/mogen/models/utils/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..731117e481e987eb2c397332b88a60088591e7ec --- /dev/null +++ b/mogen/models/utils/mlp.py @@ -0,0 +1,13 @@ +import torch.nn as nn + + +def build_MLP(dim_list, latent_dim): + model_list = [] + prev = dim_list[0] + for cur in dim_list[1:]: + model_list.append(nn.Linear(prev, cur)) + model_list.append(nn.GELU()) + prev = cur + model_list.append(nn.Linear(prev, latent_dim)) + model = nn.Sequential(*model_list) + return model diff --git a/mogen/models/utils/position_encoding.py b/mogen/models/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..16c44e70050efb4b3f590f845e6fb5f9f2d6e69a --- /dev/null +++ b/mogen/models/utils/position_encoding.py @@ -0,0 +1,60 @@ +import math + +import numpy as np +import torch +import torch.nn as nn + + +class SinusoidalPositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(SinusoidalPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.arange(0, d_model, 2).float() + div_term = div_term * (-np.log(10000.0) / d_model) + div_term = torch.exp(div_term) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + # T, 1, D + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:x.shape[0]] + return self.dropout(x) + + +class LearnedPositionalEncoding(nn.Module): + + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(LearnedPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + self.pe = nn.Parameter(torch.randn(max_len, 1, d_model)) + + def forward(self, x): + x = x + self.pe[:x.shape[0]] + return self.dropout(x) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + idx = torch.arange(start=0, end=half, dtype=torch.float32) + freqs = torch.exp(-math.log(max_period) * idx / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/mogen/models/utils/stylization_block.py b/mogen/models/utils/stylization_block.py new file mode 100644 index 0000000000000000000000000000000000000000..a221a5d8122cb6e145377b705dbd32410dc6bc94 --- /dev/null +++ b/mogen/models/utils/stylization_block.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class StylizationBlock(nn.Module): + + def __init__(self, latent_dim, time_embed_dim, dropout): + super().__init__() + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(time_embed_dim, 2 * latent_dim), + ) + self.norm = nn.LayerNorm(latent_dim) + self.out_layers = nn.Sequential( + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Linear(latent_dim, latent_dim)), + ) + + def forward(self, h, emb): + """ + h: B, T, D + emb: B, D + """ + # B, 1, 2D + emb_out = self.emb_layers(emb).unsqueeze(1) + # scale: B, 1, D / shift: B, 1, D + scale, shift = torch.chunk(emb_out, 2, dim=2) + h = self.norm(h) * (1 + scale) + shift + h = self.out_layers(h) + return h diff --git a/mogen/models/utils/word_vectorizer.py b/mogen/models/utils/word_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e3c8192275d08a309133fd2b1ea91b3f6df8b1 --- /dev/null +++ b/mogen/models/utils/word_vectorizer.py @@ -0,0 +1,88 @@ +import pickle +from os.path import join as pjoin + +import numpy as np + +POS_enumerator = { + 'VERB': 0, + 'NOUN': 1, + 'DET': 2, + 'ADP': 3, + 'NUM': 4, + 'AUX': 5, + 'PRON': 6, + 'ADJ': 7, + 'ADV': 8, + 'Loc_VIP': 9, + 'Body_VIP': 10, + 'Obj_VIP': 11, + 'Act_VIP': 12, + 'Desc_VIP': 13, + 'OTHER': 14, +} + +Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', + 'forward', 'back', 'backward', 'up', 'down', 'straight', 'curve') + +Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', + 'waist', 'eye', 'knee', 'shoulder', 'thigh') + +Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', + 'handrail', 'baseball', 'basketball') + +Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', + 'throw', 'hop', 'dance', 'jump', 'turn', 'stumble', 'dance', + 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', + 'stroll', 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', + 'lean', 'rotate', 'spin', 'spread', 'climb') + +Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', + 'happy', 'angry', 'sad', 'happily', 'angrily', 'sadly') + +VIP_dict = { + 'Loc_VIP': Loc_list, + 'Body_VIP': Body_list, + 'Obj_VIP': Obj_List, + 'Act_VIP': Act_list, + 'Desc_VIP': Desc_list, +} + + +class WordVectorizer(object): + + def __init__(self, meta_root, prefix): + vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix)) + words = pickle.load( + open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb')) + word2idx = pickle.load( + open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb')) + self.word2vec = {w: vectors[word2idx[w]] for w in words} + + def _get_pos_ohot(self, pos): + pos_vec = np.zeros(len(POS_enumerator)) + if pos in POS_enumerator: + pos_vec[POS_enumerator[pos]] = 1 + else: + pos_vec[POS_enumerator['OTHER']] = 1 + return pos_vec + + def __len__(self): + return len(self.word2vec) + + def __getitem__(self, item): + word, pos = item.split('/') + if word in self.word2vec: + word_vec = self.word2vec[word] + vip_pos = None + for key, values in VIP_dict.items(): + if word in values: + vip_pos = key + break + if vip_pos is not None: + pos_vec = self._get_pos_ohot(vip_pos) + else: + pos_vec = self._get_pos_ohot(pos) + else: + word_vec = self.word2vec['unk'] + pos_vec = self._get_pos_ohot('OTHER') + return word_vec, pos_vec diff --git a/mogen/utils/__init__.py b/mogen/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84de592ef57caa2bf002d7e730759712b646ce38 
--- /dev/null +++ b/mogen/utils/__init__.py @@ -0,0 +1,13 @@ +from mogen.utils.collect_env import collect_env +from mogen.utils.dist_utils import DistOptimizerHook, allreduce_grads +from mogen.utils.logger import get_root_logger +from mogen.utils.misc import multi_apply, torch_to_numpy +from mogen.utils.path_utils import (Existence, check_input_path, + check_path_existence, check_path_suffix, + prepare_output_path) + +__all__ = [ + 'collect_env', 'DistOptimizerHook', 'allreduce_grads', 'get_root_logger', + 'multi_apply', 'torch_to_numpy', 'Existence', 'check_input_path', + 'check_path_existence', 'check_path_suffix', 'prepare_output_path' +] diff --git a/mogen/utils/collect_env.py b/mogen/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..b2eb1b5dc861605b0d38be74cb668a9fa110877d --- /dev/null +++ b/mogen/utils/collect_env.py @@ -0,0 +1,16 @@ +from mmcv.utils import collect_env as collect_base_env +from mmcv.utils import get_git_hash + +import mogen + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['mogen'] = mogen.__version__ + '+' + get_git_hash()[:7] + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') diff --git a/mogen/utils/dist_utils.py b/mogen/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c914b22a9be62b3bde00f5ec352fdcf9d4145e1 --- /dev/null +++ b/mogen/utils/dist_utils.py @@ -0,0 +1,56 @@ +from collections import OrderedDict + +import torch.distributed as dist +from mmcv.runner import OptimizerHook +from torch._utils import (_flatten_dense_tensors, _take_tensors, + _unflatten_dense_tensors) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + world_size = dist.get_world_size() + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) + + +class DistOptimizerHook(OptimizerHook): + + def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + + def after_train_iter(self, runner): + runner.optimizer.zero_grad() + runner.outputs['loss'].backward() + if self.grad_clip is not None: + self.clip_grads(runner.model.parameters()) + runner.optimizer.step() diff --git a/mogen/utils/logger.py b/mogen/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ce105bfb3896e67bbf8c7d1c82a655840dda52 --- /dev/null +++ b/mogen/utils/logger.py @@ -0,0 +1,7 @@ +import logging + +from mmcv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + return get_logger('mogen', log_file, 
log_level)
diff --git a/mogen/utils/misc.py b/mogen/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae2a036e5fdfa8c880e8bf4e6ccd350f757a2b8a
--- /dev/null
+++ b/mogen/utils/misc.py
@@ -0,0 +1,14 @@
+from functools import partial
+
+import torch
+
+
+def multi_apply(func, *args, **kwargs):
+    pfunc = partial(func, **kwargs) if kwargs else func
+    map_results = map(pfunc, *args)
+    return tuple(map(list, zip(*map_results)))
+
+
+def torch_to_numpy(x):
+    assert isinstance(x, torch.Tensor)
+    return x.detach().cpu().numpy()
diff --git a/mogen/utils/path_utils.py b/mogen/utils/path_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..36fb7f69e79d691a81eb5b6faad592e961878bea
--- /dev/null
+++ b/mogen/utils/path_utils.py
@@ -0,0 +1,232 @@
+import os
+import warnings
+from enum import Enum
+from pathlib import Path
+from typing import List, Union
+
+try:
+    from typing import Literal
+except ImportError:
+    from typing_extensions import Literal
+
+
+def check_path_suffix(path_str: str,
+                      allowed_suffix: Union[str, List[str]] = '') -> bool:
+    """Check whether the suffix of the path is allowed.
+
+    Args:
+        path_str (str):
+            Path to check.
+        allowed_suffix (Union[str, List[str]], optional):
+            What extension names are allowed.
+            Offer a list like ['.jpg', '.jpeg'].
+            When it's [], any suffix is accepted.
+            Include '' to allow directories.
+            Defaults to ''.
+
+    Returns:
+        bool:
+            True: suffix test passed
+            False: suffix test failed
+    """
+    if isinstance(allowed_suffix, str):
+        allowed_suffix = [allowed_suffix]
+    pathinfo = Path(path_str)
+    suffix = pathinfo.suffix.lower()
+    if len(allowed_suffix) == 0:
+        return True
+    if pathinfo.is_dir():
+        if '' in allowed_suffix:
+            return True
+        else:
+            return False
+    else:
+        for index, tmp_suffix in enumerate(allowed_suffix):
+            if not tmp_suffix.startswith('.'):
+                tmp_suffix = '.' + tmp_suffix
+            allowed_suffix[index] = tmp_suffix.lower()
+        if suffix in allowed_suffix:
+            return True
+        else:
+            return False
+
+
+class Existence(Enum):
+    """State of file existence."""
+    FileExist = 0
+    DirectoryExistEmpty = 1
+    DirectoryExistNotEmpty = 2
+    MissingParent = 3
+    DirectoryNotExist = 4
+    FileNotExist = 5
+
+
+def check_path_existence(
+    path_str: str,
+    path_type: Literal['file', 'dir', 'auto'] = 'auto',
+) -> Existence:
+    """Check whether a file or a directory exists at the expected path.
+
+    Args:
+        path_str (str):
+            Path to check.
+        path_type (Literal['file', 'dir', 'auto'], optional):
+            What kind of file do we expect at the path.
+            Choose among `file`, `dir`, `auto`.
+            Defaults to 'auto'.
+
+    Raises:
+        KeyError: if `path_type` conflicts with `path_str`
+
+    Returns:
+        Existence:
+            0. FileExist: file at path_str exists.
+            1. DirectoryExistEmpty: folder at path_str exists and is empty.
+            2. DirectoryExistNotEmpty: folder at path_str exists and is not empty.
+            3. MissingParent: its parent doesn't exist.
+            4. DirectoryNotExist: expect a folder at path_str, but not found.
+            5. FileNotExist: expect a file at path_str, but not found.
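+
+    Example:
+        Illustrative only; the value returned depends on what is on disk::
+
+            existence = check_path_existence('work_dirs', path_type='dir')
+            if existence == Existence.DirectoryNotExist:
+                os.mkdir('work_dirs')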
+ """ + path_type = path_type.lower() + assert path_type in {'file', 'dir', 'auto'} + pathinfo = Path(path_str) + if not pathinfo.parent.is_dir(): + return Existence.MissingParent + suffix = pathinfo.suffix.lower() + if path_type == 'dir' or\ + path_type == 'auto' and suffix == '': + if pathinfo.is_dir(): + if len(os.listdir(path_str)) == 0: + return Existence.DirectoryExistEmpty + else: + return Existence.DirectoryExistNotEmpty + else: + return Existence.DirectoryNotExist + elif path_type == 'file' or\ + path_type == 'auto' and suffix != '': + if pathinfo.is_file(): + return Existence.FileExist + elif pathinfo.is_dir(): + if len(os.listdir(path_str)) == 0: + return Existence.DirectoryExistEmpty + else: + return Existence.DirectoryExistNotEmpty + if path_str.endswith('/'): + return Existence.DirectoryNotExist + else: + return Existence.FileNotExist + + +def prepare_output_path(output_path: str, + allowed_suffix: List[str] = [], + tag: str = 'output file', + path_type: Literal['file', 'dir', 'auto'] = 'auto', + overwrite: bool = True) -> None: + """Check output folder or file. + + Args: + output_path (str): could be folder or file. + allowed_suffix (List[str], optional): + Check the suffix of `output_path`. If folder, should be [] or ['']. + If could both be folder or file, should be [suffixs..., '']. + Defaults to []. + tag (str, optional): The `string` tag to specify the output type. + Defaults to 'output file'. + path_type (Literal[, optional): + Choose `file` for file and `dir` for folder. + Choose `auto` if allowed to be both. + Defaults to 'auto'. + overwrite (bool, optional): + Whether overwrite the existing file or folder. + Defaults to True. + + Raises: + FileNotFoundError: suffix does not match. + FileExistsError: file or folder already exists and `overwrite` is + False. + + Returns: + None + """ + if path_type.lower() == 'dir': + allowed_suffix = [] + exist_result = check_path_existence(output_path, path_type=path_type) + if exist_result == Existence.MissingParent: + warnings.warn( + f'The parent folder of {tag} does not exist: {output_path},' + + f' will make dir {Path(output_path).parent.absolute().__str__()}') + os.makedirs(Path(output_path).parent.absolute().__str__(), + exist_ok=True) + + elif exist_result == Existence.DirectoryNotExist: + os.mkdir(output_path) + print(f'Making directory {output_path} for saving results.') + elif exist_result == Existence.FileNotExist: + suffix_matched = \ + check_path_suffix(output_path, allowed_suffix=allowed_suffix) + if not suffix_matched: + raise FileNotFoundError( + f'The {tag} should be {", ".join(allowed_suffix)}: ' + f'{output_path}.') + elif exist_result == Existence.FileExist: + if not overwrite: + raise FileExistsError( + f'{output_path} exists (set overwrite = True to overwrite).') + else: + print(f'Overwriting {output_path}.') + elif exist_result == Existence.DirectoryExistEmpty: + pass + elif exist_result == Existence.DirectoryExistNotEmpty: + if not overwrite: + raise FileExistsError( + f'{output_path} is not empty (set overwrite = ' + 'True to overwrite the files).') + else: + print(f'Overwriting {output_path} and its files.') + else: + raise FileNotFoundError(f'No Existence type for {output_path}.') + + +def check_input_path( + input_path: str, + allowed_suffix: List[str] = [], + tag: str = 'input file', + path_type: Literal['file', 'dir', 'auto'] = 'auto', +): + """Check input folder or file. + + Args: + input_path (str): input folder or file path. + allowed_suffix (List[str], optional): + Check the suffix of `input_path`. 
If folder, should be [] or ['']. + If could both be folder or file, should be [suffixs..., '']. + Defaults to []. + tag (str, optional): The `string` tag to specify the output type. + Defaults to 'output file'. + path_type (Literal[, optional): + Choose `file` for file and `directory` for folder. + Choose `auto` if allowed to be both. + Defaults to 'auto'. + + Raises: + FileNotFoundError: file does not exists or suffix does not match. + + Returns: + None + """ + if path_type.lower() == 'dir': + allowed_suffix = [] + exist_result = check_path_existence(input_path, path_type=path_type) + + if exist_result in [ + Existence.FileExist, Existence.DirectoryExistEmpty, + Existence.DirectoryExistNotEmpty + ]: + suffix_matched = \ + check_path_suffix(input_path, allowed_suffix=allowed_suffix) + if not suffix_matched: + raise FileNotFoundError( + f'The {tag} should be {", ".join(allowed_suffix)}:' + + f'{input_path}.') + else: + raise FileNotFoundError(f'The {tag} does not exist: {input_path}.') diff --git a/mogen/utils/plot_utils.py b/mogen/utils/plot_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d6e3bb155db5bc80bce2e1e1290eefeca81660 --- /dev/null +++ b/mogen/utils/plot_utils.py @@ -0,0 +1,130 @@ +import matplotlib +import matplotlib.pyplot as plt +import mpl_toolkits.mplot3d.axes3d as p3 +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import numpy as np +import io +import imageio +from textwrap import wrap +import torch +import moviepy.editor as mpe +from scipy.io import wavfile +import os + +def plot_3d_motion(out_path, joints, kinematic_chain, title=None, ground=True, figsize=(10, 10), fps=120): + matplotlib.use('Agg') + + data = joints.copy().reshape(len(joints), -1, 3) + frame_number = data.shape[0] + + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + height_offset = MINS[1] + data[:, :, 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + colors = ['red', 'blue', 'black', 'red', 'blue', + 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue', + 'darkred', 'darkred', 'darkred', 'darkred', 'darkred'] + + def update(index): + + def init(): + ax.set_xlim3d([-0.8, 0.8]) + ax.set_ylim3d([0, 1.6]) + ax.set_zlim3d([0, 1.6]) + ax.grid(False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + ## Plot a plane XZ + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + fig = plt.figure(figsize=figsize, dpi=96) + ax = fig.add_subplot(111, projection='3d') + + if title is not None : + wraped_title = '\n'.join(wrap(title, 40)) + fig.suptitle(wraped_title, fontsize=16) + + ax.view_init(elev=130, azim=-90) + init() + + # ax.cla() + + if ground: + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + if index > 1: + ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]), + trajec[:index, 1] - trajec[index, 1], linewidth=1.0, + color='blue') + + + for i, (chain, color) in enumerate(zip(kinematic_chain, colors)): + if i < 5: + linewidth = 4.0 + else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, + color=color) + + # ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=0, + # color=color, 
marker="o", markersize=linewidth*1.5, markerfacecolor="g", markeredgecolor="g") + + # for i in range(data[index].shape[0]): + # ax.text(data[index][i][0], data[index][i][1], data[index][i][2], str(i)) + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + io_buf = io.BytesIO() + fig.savefig(io_buf, format='raw', dpi=96) + io_buf.seek(0) + arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8), + newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)) + io_buf.close() + plt.close() + return arr + + out = [] + for i in range(frame_number) : + out.append(update(i)) + out = np.stack(out, axis=0) + out = np.array(torch.from_numpy(out)) + imageio.mimsave(out_path, out, fps=fps) + + +def add_audio(out_path, audio_paths): + filename, ext = os.path.splitext(out_path) + in_path = filename + "_tmp" + ext + os.system(f"cp {out_path} {in_path}") + my_clip = mpe.VideoFileClip(in_path) + if len(audio_paths) > 1: + audio_clips = [] + for path in audio_paths: + audio_clips.append(mpe.AudioFileClip(path)) + final_audio = mpe.concatenate_audio(audio_clips) + else: + final_audio = mpe.AudioFileClip(audio_paths[0]) + final_clip = my_clip.set_audio(final_audio) + final_clip.write_videofile(out_path) + # os.system(f'rm -f {in_path}') + +def get_audio_length(audio_path): + sample_rate, data = wavfile.read(audio_path) + len_data = len(data) + t = len_data / sample_rate # duration in floats + return t \ No newline at end of file diff --git a/mogen/version.py b/mogen/version.py new file mode 100644 index 0000000000000000000000000000000000000000..145b162bd4d8f254f6cc8fe33e789f9e4ba3b0bb --- /dev/null +++ b/mogen/version.py @@ -0,0 +1,25 @@ +__version__ = '0.0.1' + + +def parse_version_info(version_str): + """Parse a version string into a tuple. + Args: + version_str (str): The version string. + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). 
+ """ + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] diff --git a/requirements/docs.txt b/requirements/docs.txt new file mode 100644 index 0000000000000000000000000000000000000000..a30452a83fe42d922a2e035d7361c0a0f9c5fe44 --- /dev/null +++ b/requirements/docs.txt @@ -0,0 +1,10 @@ +docutils +myst-parser +git+https://github.com/pytorch/pytorch_sphinx_theme.git +sphinx +sphinx-copybutton +sphinx_markdown_tables +sphinx_rtd_theme +mmcv +torch +torchvision \ No newline at end of file diff --git a/requirements/mogen.txt b/requirements/mogen.txt new file mode 100644 index 0000000000000000000000000000000000000000..41d5c28aea026402f2e60f78a11eec2dfc4e8209 --- /dev/null +++ b/requirements/mogen.txt @@ -0,0 +1,24 @@ +numpy==1.23.1 +ftfy +regex +tqdm +scipy +matplotlib +pandas +imageio +imageio-ffmpeg==0.4.9 +git+https://github.com/openai/CLIP.git +pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d +timm==0.6.7 +einops +fvcore +eva-decord==0.6.1 +iopath +types-regex +mayavi +cartopy +gdown +git+https://github.com/omimo/PyMO.git +ipython +librosa +imageio-ffmpeg==0.4.9 \ No newline at end of file diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000000000000000000000000000000000000..f58844a4536cf7a58a1cfdc4c724f2063dcbb860 --- /dev/null +++ b/tools/test.py @@ -0,0 +1,114 @@ +import warnings +warnings.filterwarnings("ignore", category=UserWarning) + +import argparse +import os +import os.path as osp + +import mmcv +import torch +from mmcv import DictAction +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, + wrap_fp16_model) + +from mogen.apis import multi_gpu_test, single_gpu_test +from mogen.datasets import build_dataloader, build_dataset +from mogen.models import build_architecture + + +def parse_args(): + parser = argparse.ArgumentParser(description='mogen evaluation') + parser.add_argument('config', help='test config file path') + parser.add_argument('--work-dir', + help='the dir to save evaluation results') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--out', help='output result file') + parser.add_argument('--gpu_collect', + action='store_true', + help='whether to use gpu to collect results') + parser.add_argument('--tmpdir', help='tmp dir for writing some results') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file.') + parser.add_argument('--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--device', + choices=['cpu', 'cuda'], + default='cuda', + help='device used for testing') + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def main(): + args = parse_args() + + cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + 
# set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + cfg.data.test.test_mode = True + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + + # build the dataloader + dataset = build_dataset(cfg.data.test) + # the extra round_up data will be removed during gpu/cpu collect + data_loader = build_dataloader(dataset, + samples_per_gpu=cfg.data.samples_per_gpu, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False, + round_up=False) + + # build the model and load checkpoint + model = build_architecture(cfg.model) + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + wrap_fp16_model(model) + load_checkpoint(model, args.checkpoint, map_location='cpu') + + if not distributed: + if args.device == 'cpu': + model = model.cpu() + else: + model = MMDataParallel(model, device_ids=[0]) + outputs = single_gpu_test(model, data_loader) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False) + outputs = multi_gpu_test(model, data_loader, args.tmpdir, + args.gpu_collect) + + rank, _ = get_dist_info() + if rank == 0: + mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) + results = dataset.evaluate(outputs, args.work_dir) + for k, v in results.items(): + print(f'\n{k} : {v:.4f}') + + if args.out and rank == 0: + print(f'\nwriting results to {args.out}') + mmcv.dump(results, args.out) + + +if __name__ == '__main__': + main() diff --git a/tools/train.py b/tools/train.py new file mode 100644 index 0000000000000000000000000000000000000000..426b5715b49ce260735bc91b483e9db6a1444673 --- /dev/null +++ b/tools/train.py @@ -0,0 +1,149 @@ +import warnings +warnings.filterwarnings("ignore", category=UserWarning) + +import argparse +import copy +import os +import os.path as osp +import time + +import mmcv +import torch +from mmcv import Config, DictAction +from mmcv.runner import get_dist_info, init_dist + +from mogen.apis import set_random_seed, train_model +from mogen.datasets import build_dataset +from mogen.models import build_architecture +from mogen.utils import collect_env, get_root_logger + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument('--resume-from', + help='the checkpoint file to resume from') + parser.add_argument( + '--no-validate', + action='store_true', + help='whether not to evaluate the checkpoint during training') + group_gpus = parser.add_mutually_exclusive_group() + group_gpus.add_argument('--device', help='device used for training') + group_gpus.add_argument('--gpus', + type=int, + help='number of gpus to use ' + '(only applicable to non-distributed training)') + group_gpus.add_argument('--gpu-ids', + type=int, + nargs='+', + help='ids of gpus to use ' + '(only applicable to non-distributed training)') + parser.add_argument('--seed', type=int, default=None, help='random seed') + parser.add_argument( + '--deterministic', + action='store_true', + help='whether to set deterministic options for CUDNN backend.') + parser.add_argument('--options', + nargs='+', + action=DictAction, + help='arguments in dict') + parser.add_argument('--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + 
help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.options is not None: + cfg.merge_from_dict(args.options) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + if args.resume_from is not None: + cfg.resume_from = args.resume_from + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids + else: + cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + + # create work_dir + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # dump config + cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + if args.seed is not None: + logger.info(f'Set random seed to {args.seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(args.seed, deterministic=args.deterministic) + cfg.seed = args.seed + meta['seed'] = args.seed + + model = build_architecture(cfg.model) + model.init_weights() + + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + val_dataset.pipeline = cfg.data.train.pipeline + datasets.append(build_dataset(val_dataset)) + # add an attribute for visualization convenience + train_model(model, + datasets, + cfg, + distributed=distributed, + validate=(not args.no_validate), + timestamp=timestamp, + device='cpu' if args.device == 'cpu' else 'cuda', + meta=meta) + + +if __name__ == '__main__': + main()