diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e62cd38287d8869177618d966460a6dc90a479df --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +**__pycache__** \ No newline at end of file diff --git a/README.md b/README.md index 4c9ca0495e9e6384e0ac2bbc66655419b0ac26f9..9fc0514e54007172e0ebba8a73f390d422ac6b97 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,9 @@ ---- -title: ReMoDiffuse -emoji: 📚 -colorFrom: red -colorTo: gray +title: MotionDiffuse +emoji: 🏢 +colorFrom: blue +colorTo: red sdk: gradio -sdk_version: 3.43.2 +sdk_version: 3.44.1 app_file: app.py pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +license: mit \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0a56da44c94fc8c07d45815dc32b96267dcbf7bb --- /dev/null +++ b/app.py @@ -0,0 +1,123 @@ +import os +import sys +import gradio as gr + +os.makedirs("outputs", exist_ok=True) +sys.path.insert(0, '.') + +import argparse +import os.path as osp +import mmcv +import numpy as np +import torch +from mogen.models import build_architecture +from mmcv.runner import load_checkpoint +from mmcv.parallel import MMDataParallel +from mogen.utils.plot_utils import ( + recover_from_ric, + plot_3d_motion, + t2m_kinematic_chain +) +from scipy.ndimage import gaussian_filter +from IPython.display import Image + + +def motion_temporal_filter(motion, sigma=1): + motion = motion.reshape(motion.shape[0], -1) + for i in range(motion.shape[1]): + motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") + return motion.reshape(motion.shape[0], -1, 3) + + +def plot_t2m(data, result_path, npy_path, caption): + joint = recover_from_ric(torch.from_numpy(data).float(), 22).numpy() + joint = motion_temporal_filter(joint, sigma=2.5) + plot_3d_motion(result_path, t2m_kinematic_chain, joint, title=caption, fps=20) + if npy_path is not None: + np.save(npy_path, joint) + +def create_remodiffuse(): + config_path = "configs/remodiffuse/remodiffuse_t2m.py" + ckpt_path = "logs/remodiffuse/remodiffuse_t2m/latest.pth" + cfg = mmcv.Config.fromfile(config_path) + model = build_architecture(cfg.model) + load_checkpoint(model, ckpt_path, map_location='cpu') + model.cpu() + model.eval() + return model + +def create_motiondiffuse(): + config_path = "configs/motiondiffuse/motiondiffuse_t2m.py" + ckpt_path = "logs/motiondiffuse/motiondiffuse_t2m/latest.pth" + cfg = mmcv.Config.fromfile(config_path) + model = build_architecture(cfg.model) + load_checkpoint(model, ckpt_path, map_location='cpu') + model.cpu() + model.eval() + return model + +def create_mdm(): + config_path = "configs/mdm/mdm_t2m_official.py" + ckpt_path = "logs/mdm/mdm_t2m/latest.pth" + cfg = mmcv.Config.fromfile(config_path) + model = build_architecture(cfg.model) + load_checkpoint(model, ckpt_path, map_location='cpu') + model.cpu() + model.eval() + return model + +model_remodiffuse = create_remodiffuse() +# model_motiondiffuse = create_motiondiffuse() +# model_mdm = create_mdm() + +mean_path = "data/datasets/human_ml3d/mean.npy" +std_path = "data/datasets/human_ml3d/std.npy" +mean = np.load(mean_path) +std = np.load(std_path) + + +def show_generation_result(model, text, motion_length, result_path): + device = 'cpu' + motion = torch.zeros(1, motion_length, 263).to(device) + motion_mask = torch.ones(1, motion_length).to(device) + motion_length = torch.Tensor([motion_length]).long().to(device) + model = 
model.to(device) + input = { + 'motion': motion, + 'motion_mask': motion_mask, + 'motion_length': motion_length, + 'motion_metas': [{'text': text}], + } + + all_pred_motion = [] + with torch.no_grad(): + input['inference_kwargs'] = {} + output_list = [] + output = model(**input)[0]['pred_motion'] + pred_motion = output.cpu().detach().numpy() + pred_motion = pred_motion * std + mean + + plot_t2m(pred_motion, result_path, None, text) + +def generate(prompt, length): + if not os.path.exists("outputs"): + os.mkdir("outputs") + result_path = "outputs/" + str(hash(prompt)) + ".mp4" + show_generation_result(model_remodiffuse, prompt, length, result_path) + return result_path + +demo = gr.Interface( + fn=generate, + inputs=["text", gr.Slider(20, 196, value=60)], + examples=[ + ["the man throws a punch with each hand.", 58], + ["a person spins quickly and takes off running.", 29], + ["a person quickly waves with their right hand", 46], + ["a person performing a slight bow", 89], + ], + outputs="video", + title="ReMoDiffuse: Retrieval-Augmented Motion Diffusion Model", + description="This is an interactive demo for ReMoDiffuse. For more information, feel free to visit our project page(https://mingyuan-zhang.github.io/projects/ReMoDiffuse.html).") + +demo.queue() +demo.launch() \ No newline at end of file diff --git a/configs/_base_/datasets/human_ml3d_bs128.py b/configs/_base_/datasets/human_ml3d_bs128.py new file mode 100644 index 0000000000000000000000000000000000000000..1f0653bd1f188717d7b44c810f916506d2c38d91 --- /dev/null +++ b/configs/_base_/datasets/human_ml3d_bs128.py @@ -0,0 +1,60 @@ +# dataset settings +data_keys = ['motion', 'motion_mask', 'motion_length', 'clip_feat'] +meta_keys = ['text', 'token'] +train_pipeline = [ + dict( + type='Normalize', + mean_path='data/datasets/human_ml3d/mean.npy', + std_path='data/datasets/human_ml3d/std.npy'), + dict(type='Crop', crop_size=196), + dict(type='ToTensor', keys=data_keys), + dict(type='Collect', keys=data_keys, meta_keys=meta_keys) +] + +data = dict( + samples_per_gpu=128, + workers_per_gpu=1, + train=dict( + type='RepeatDataset', + dataset=dict( + type='TextMotionDataset', + dataset_name='human_ml3d', + data_prefix='data', + pipeline=train_pipeline, + ann_file='train.txt', + motion_dir='motions', + text_dir='texts', + token_dir='tokens', + clip_feat_dir='clip_feats', + ), + times=200 + ), + test=dict( + type='TextMotionDataset', + dataset_name='human_ml3d', + data_prefix='data', + pipeline=train_pipeline, + ann_file='test.txt', + motion_dir='motions', + text_dir='texts', + token_dir='tokens', + clip_feat_dir='clip_feats', + eval_cfg=dict( + shuffle_indexes=True, + replication_times=20, + replication_reduction='statistics', + text_encoder_name='human_ml3d', + text_encoder_path='data/evaluators/human_ml3d/finest.tar', + motion_encoder_name='human_ml3d', + motion_encoder_path='data/evaluators/human_ml3d/finest.tar', + metrics=[ + dict(type='R Precision', batch_size=32, top_k=3), + dict(type='Matching Score', batch_size=32), + dict(type='FID'), + dict(type='Diversity', num_samples=300), + dict(type='MultiModality', num_samples=100, num_repeats=30, num_picks=10) + ] + ), + test_mode=True + ) +) \ No newline at end of file diff --git a/configs/_base_/datasets/kit_ml_bs128.py b/configs/_base_/datasets/kit_ml_bs128.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e6872851b7f64191459d432e7019bf98ea215e --- /dev/null +++ b/configs/_base_/datasets/kit_ml_bs128.py @@ -0,0 +1,60 @@ +# dataset settings +data_keys = ['motion', 
'motion_mask', 'motion_length', 'clip_feat'] +meta_keys = ['text', 'token'] +train_pipeline = [ + dict(type='Crop', crop_size=196), + dict( + type='Normalize', + mean_path='data/datasets/kit_ml/mean.npy', + std_path='data/datasets/kit_ml/std.npy'), + dict(type='ToTensor', keys=data_keys), + dict(type='Collect', keys=data_keys, meta_keys=meta_keys) +] + +data = dict( + samples_per_gpu=128, + workers_per_gpu=1, + train=dict( + type='RepeatDataset', + dataset=dict( + type='TextMotionDataset', + dataset_name='kit_ml', + data_prefix='data', + pipeline=train_pipeline, + ann_file='train.txt', + motion_dir='motions', + text_dir='texts', + token_dir='tokens', + clip_feat_dir='clip_feats', + ), + times=100 + ), + test=dict( + type='TextMotionDataset', + dataset_name='kit_ml', + data_prefix='data', + pipeline=train_pipeline, + ann_file='test.txt', + motion_dir='motions', + text_dir='texts', + token_dir='tokens', + clip_feat_dir='clip_feats', + eval_cfg=dict( + shuffle_indexes=True, + replication_times=20, + replication_reduction='statistics', + text_encoder_name='kit_ml', + text_encoder_path='data/evaluators/kit_ml/finest.tar', + motion_encoder_name='kit_ml', + motion_encoder_path='data/evaluators/kit_ml/finest.tar', + metrics=[ + dict(type='R Precision', batch_size=32, top_k=3), + dict(type='Matching Score', batch_size=32), + dict(type='FID'), + dict(type='Diversity', num_samples=300), + dict(type='MultiModality', num_samples=50, num_repeats=30, num_picks=10) + ] + ), + test_mode=True + ) +) \ No newline at end of file diff --git a/configs/mdm/mdm_t2m_official.py b/configs/mdm/mdm_t2m_official.py new file mode 100644 index 0000000000000000000000000000000000000000..2e93e0df235e07f249caad221d43b86530f604f9 --- /dev/null +++ b/configs/mdm/mdm_t2m_official.py @@ -0,0 +1,67 @@ +_base_ = ['../_base_/datasets/human_ml3d_bs128.py'] + +# checkpoint saving +checkpoint_config = dict(interval=1) + +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] + +# optimizer +optimizer = dict(type='Adam', lr=1e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[]) +runner = dict(type='EpochBasedRunner', max_epochs=50) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +input_feats = 263 +max_seq_len = 196 +latent_dim = 512 +time_embed_dim = 2048 +text_latent_dim = 256 +ff_size = 1024 +num_layers = 8 +num_heads = 4 +dropout = 0.1 +cond_mask_prob = 0.1 +# model settings +model = dict( + type='MotionDiffusion', + model=dict( + type='MDMTransformer', + input_feats=input_feats, + latent_dim=latent_dim, + ff_size=ff_size, + num_layers=num_layers, + num_heads=num_heads, + dropout=dropout, + time_embed_dim=time_embed_dim, + cond_mask_prob=cond_mask_prob, + guide_scale=2.5, + clip_version='ViT-B/32', + use_official_ckpt=True + ), + loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), + diffusion_train=dict( + beta_scheduler='cosine', + diffusion_steps=1000, + model_mean_type='start_x', + model_var_type='fixed_small', + ), + diffusion_test=dict( + beta_scheduler='cosine', + diffusion_steps=1000, + model_mean_type='start_x', + model_var_type='fixed_small', + ), + inference_type='ddpm' +) \ No newline at end of file diff --git a/configs/motiondiffuse/motiondiffuse_kit.py b/configs/motiondiffuse/motiondiffuse_kit.py new file mode 100644 index 0000000000000000000000000000000000000000..5925be3298ea617ac9166736aa50608925b6e209 --- 
/dev/null +++ b/configs/motiondiffuse/motiondiffuse_kit.py @@ -0,0 +1,89 @@ +_base_ = ['../_base_/datasets/kit_ml_bs128.py'] + +# checkpoint saving +checkpoint_config = dict(interval=1) + +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] + +# optimizer +optimizer = dict(type='Adam', lr=2e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[]) +runner = dict(type='EpochBasedRunner', max_epochs=50) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +input_feats = 251 +max_seq_len = 196 +latent_dim = 512 +time_embed_dim = 2048 +text_latent_dim = 256 +ff_size = 1024 +num_heads = 8 +dropout = 0 +# model settings +model = dict( + type='MotionDiffusion', + model=dict( + type='MotionDiffuseTransformer', + input_feats=input_feats, + max_seq_len=max_seq_len, + latent_dim=latent_dim, + time_embed_dim=time_embed_dim, + num_layers=8, + sa_block_cfg=dict( + type='EfficientSelfAttention', + latent_dim=latent_dim, + num_heads=num_heads, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + ca_block_cfg=dict( + type='EfficientCrossAttention', + latent_dim=latent_dim, + text_latent_dim=text_latent_dim, + num_heads=num_heads, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + ffn_cfg=dict( + latent_dim=latent_dim, + ffn_dim=ff_size, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + text_encoder=dict( + pretrained_model='clip', + latent_dim=text_latent_dim, + num_layers=4, + num_heads=4, + ff_size=2048, + dropout=dropout, + use_text_proj=True + ) + ), + loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), + diffusion_train=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='epsilon', + model_var_type='fixed_small', + ), + diffusion_test=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='epsilon', + model_var_type='fixed_small', + ), + inference_type='ddpm' +) diff --git a/configs/motiondiffuse/motiondiffuse_t2m.py b/configs/motiondiffuse/motiondiffuse_t2m.py new file mode 100644 index 0000000000000000000000000000000000000000..96caef20ee51774c151a1f6003ec643216849e20 --- /dev/null +++ b/configs/motiondiffuse/motiondiffuse_t2m.py @@ -0,0 +1,90 @@ +_base_ = ['../_base_/datasets/human_ml3d_bs128.py'] + +# checkpoint saving +checkpoint_config = dict(interval=1) + +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] + +# optimizer +optimizer = dict(type='Adam', lr=2e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='step', step=[]) +runner = dict(type='EpochBasedRunner', max_epochs=50) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +input_feats = 263 +max_seq_len = 196 +latent_dim = 512 +time_embed_dim = 2048 +text_latent_dim = 256 +ff_size = 1024 +num_heads = 8 +dropout = 0 +# model settings +model = dict( + type='MotionDiffusion', + model=dict( + type='MotionDiffuseTransformer', + input_feats=input_feats, + max_seq_len=max_seq_len, + latent_dim=latent_dim, + time_embed_dim=time_embed_dim, + num_layers=8, + sa_block_cfg=dict( + type='EfficientSelfAttention', + latent_dim=latent_dim, + num_heads=num_heads, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + ca_block_cfg=dict( + type='EfficientCrossAttention', + latent_dim=latent_dim, + 
text_latent_dim=text_latent_dim, + num_heads=num_heads, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + ffn_cfg=dict( + latent_dim=latent_dim, + ffn_dim=ff_size, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + text_encoder=dict( + pretrained_model='clip', + latent_dim=text_latent_dim, + num_layers=4, + num_heads=4, + ff_size=2048, + dropout=dropout, + use_text_proj=True + ) + ), + loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), + diffusion_train=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='epsilon', + model_var_type='fixed_small', + ), + diffusion_test=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='epsilon', + model_var_type='fixed_small', + ), + inference_type='ddpm' +) +data = dict(samples_per_gpu=128) \ No newline at end of file diff --git a/configs/remodiffuse/remodiffuse_kit.py b/configs/remodiffuse/remodiffuse_kit.py new file mode 100644 index 0000000000000000000000000000000000000000..3943b755ec611b56af16fa83a46d6d6f59ae7355 --- /dev/null +++ b/configs/remodiffuse/remodiffuse_kit.py @@ -0,0 +1,141 @@ +_base_ = ['../_base_/datasets/kit_ml_bs128.py'] + +# checkpoint saving +checkpoint_config = dict(interval=1) + +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] + +# optimizer +optimizer = dict(type='Adam', lr=2e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False) +runner = dict(type='EpochBasedRunner', max_epochs=20) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +input_feats = 251 +max_seq_len = 196 +latent_dim = 512 +time_embed_dim = 2048 +text_latent_dim = 256 +ff_size = 1024 +num_heads = 8 +dropout = 0 + +def scale_func(timestep): + import random + w = (1 - (1000 - timestep) / 1000) * 4.0 + 1 + if timestep > 100: + if random.randint(0, 1) == 0: + output = { + 'both_coef': w, + 'text_coef': 0, + 'retr_coef': 1 - w, + 'none_coef': 0 + } + else: + output = { + 'both_coef': 0, + 'text_coef': w, + 'retr_coef': 0, + 'none_coef': 1 - w + } + else: + both_coef = 0.78123 + text_coef = 0.39284 + retr_coef = -0.12475 + none_coef = 1 - both_coef - text_coef - retr_coef + output = { + 'both_coef': both_coef, + 'text_coef': text_coef, + 'retr_coef': retr_coef, + 'none_coef': none_coef + } + return output + +# model settings +model = dict( + type='MotionDiffusion', + model=dict( + type='ReMoDiffuseTransformer', + input_feats=input_feats, + max_seq_len=max_seq_len, + latent_dim=latent_dim, + time_embed_dim=time_embed_dim, + num_layers=4, + ca_block_cfg=dict( + type='SemanticsModulatedAttention', + latent_dim=latent_dim, + text_latent_dim=text_latent_dim, + num_heads=num_heads, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + ffn_cfg=dict( + latent_dim=latent_dim, + ffn_dim=ff_size, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + text_encoder=dict( + pretrained_model='clip', + latent_dim=text_latent_dim, + num_layers=2, + ff_size=2048, + dropout=dropout, + use_text_proj=False + ), + retrieval_cfg=dict( + num_retrieval=2, + stride=4, + num_layers=2, + num_motion_layers=2, + kinematic_coef=0.1, + topk=2, + retrieval_file='data/database/kit_text_train.npz', + latent_dim=latent_dim, + output_dim=latent_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + ff_size=ff_size, + dropout=dropout, + ffn_cfg=dict( + latent_dim=latent_dim, + ffn_dim=ff_size, + 
dropout=dropout, + ), + sa_block_cfg=dict( + type='EfficientSelfAttention', + latent_dim=latent_dim, + num_heads=num_heads, + dropout=dropout + ), + ), + scale_func=scale_func + ), + loss_recon=dict(type='MSELoss', loss_weight=1, reduction='none'), + diffusion_train=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='start_x', + model_var_type='fixed_large', + ), + diffusion_test=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='start_x', + model_var_type='fixed_large', + respace='15,15,8,6,6', + ), + inference_type='ddim' +) \ No newline at end of file diff --git a/configs/remodiffuse/remodiffuse_t2m.py b/configs/remodiffuse/remodiffuse_t2m.py new file mode 100644 index 0000000000000000000000000000000000000000..6b06d0e5140183074b3859c155c4e86f348c3e8d --- /dev/null +++ b/configs/remodiffuse/remodiffuse_t2m.py @@ -0,0 +1,141 @@ +_base_ = ['../_base_/datasets/human_ml3d_bs128.py'] + +# checkpoint saving +checkpoint_config = dict(interval=1) + +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] + +# optimizer +optimizer = dict(type='Adam', lr=2e-4) +optimizer_config = dict(grad_clip=None) +# learning policy +lr_config = dict(policy='CosineAnnealing', min_lr_ratio=2e-5, by_epoch=False) +runner = dict(type='EpochBasedRunner', max_epochs=40) + +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) + +input_feats = 263 +max_seq_len = 196 +latent_dim = 512 +time_embed_dim = 2048 +text_latent_dim = 256 +ff_size = 1024 +num_heads = 8 +dropout = 0 + +def scale_func(timestep): + import random + w = (1 - (1000 - timestep) / 1000) * 6.5 + 1 + if timestep > 100: + if random.randint(0, 1) == 0: + output = { + 'both_coef': w, + 'text_coef': 0, + 'retr_coef': 1 - w, + 'none_coef': 0 + } + else: + output = { + 'both_coef': 0, + 'text_coef': w, + 'retr_coef': 0, + 'none_coef': 1 - w + } + else: + both_coef = 0.52351 + text_coef = -0.28419 + retr_coef = 2.39872 + none_coef = 1 - both_coef - text_coef - retr_coef + output = { + 'both_coef': both_coef, + 'text_coef': text_coef, + 'retr_coef': retr_coef, + 'none_coef': none_coef + } + return output + +# model settings +model = dict( + type='MotionDiffusion', + model=dict( + type='ReMoDiffuseTransformer', + input_feats=input_feats, + max_seq_len=max_seq_len, + latent_dim=latent_dim, + time_embed_dim=time_embed_dim, + num_layers=4, + ca_block_cfg=dict( + type='SemanticsModulatedAttention', + latent_dim=latent_dim, + text_latent_dim=text_latent_dim, + num_heads=num_heads, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + ffn_cfg=dict( + latent_dim=latent_dim, + ffn_dim=ff_size, + dropout=dropout, + time_embed_dim=time_embed_dim + ), + text_encoder=dict( + pretrained_model='clip', + latent_dim=text_latent_dim, + num_layers=2, + ff_size=2048, + dropout=dropout, + use_text_proj=False + ), + retrieval_cfg=dict( + num_retrieval=2, + stride=4, + num_layers=2, + num_motion_layers=2, + kinematic_coef=0.1, + topk=2, + retrieval_file='data/database/t2m_text_train.npz', + latent_dim=latent_dim, + output_dim=latent_dim, + max_seq_len=max_seq_len, + num_heads=num_heads, + ff_size=ff_size, + dropout=dropout, + ffn_cfg=dict( + latent_dim=latent_dim, + ffn_dim=ff_size, + dropout=dropout, + ), + sa_block_cfg=dict( + type='EfficientSelfAttention', + latent_dim=latent_dim, + num_heads=num_heads, + dropout=dropout + ), + ), + scale_func=scale_func + ), + loss_recon=dict(type='MSELoss', 
loss_weight=1, reduction='none'), + diffusion_train=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='start_x', + model_var_type='fixed_large', + ), + diffusion_test=dict( + beta_scheduler='linear', + diffusion_steps=1000, + model_mean_type='start_x', + model_var_type='fixed_large', + respace='15,15,8,6,6', + ), + inference_type='ddim' +) \ No newline at end of file diff --git a/data/database/t2m_text_train.npz b/data/database/t2m_text_train.npz new file mode 100644 index 0000000000000000000000000000000000000000..c809ad6b549e2508804fbc1e81e629e49ae8ed85 --- /dev/null +++ b/data/database/t2m_text_train.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae3575b686e29623f9e1715345b052726650f53c5bfcc770d9fb87a827a60249 +size 1462801786 diff --git a/data/datasets/human_ml3d/mean.npy b/data/datasets/human_ml3d/mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..9fb7b482060b4987346d05a383f99339ff17e628 --- /dev/null +++ b/data/datasets/human_ml3d/mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d73483a5b53e017b4044fe363164d7c185082a02ae7f69525ea70c5ccfd4a85 +size 1180 diff --git a/data/datasets/human_ml3d/std.npy b/data/datasets/human_ml3d/std.npy new file mode 100644 index 0000000000000000000000000000000000000000..dddadc19ce19f76f7bdf9204e6eee33ab8e060a2 --- /dev/null +++ b/data/datasets/human_ml3d/std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6a6d720e004b6da18e8033d739de6078cbc7c1c8fad0ff62eee86f173e4430a2 +size 1180 diff --git a/data/datasets/kit_ml/mean.npy b/data/datasets/kit_ml/mean.npy new file mode 100644 index 0000000000000000000000000000000000000000..c1f076c473eaabf4e6c0144d3e6db8b6a3c7e976 --- /dev/null +++ b/data/datasets/kit_ml/mean.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e23fac51db2215ab5666324226be48f27efd6a6e7b22ebd17c28e0f056a7c22 +size 2136 diff --git a/data/datasets/kit_ml/std.npy b/data/datasets/kit_ml/std.npy new file mode 100644 index 0000000000000000000000000000000000000000..02a4c81095a331998ae0c95e3b01dc48c6d37b77 --- /dev/null +++ b/data/datasets/kit_ml/std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:296a60656cea07e65ee64512d73d47c0412df0698b35194116330661be32fa90 +size 2136 diff --git a/logs/mdm/mdm_t2m/latest.pth b/logs/mdm/mdm_t2m/latest.pth new file mode 100644 index 0000000000000000000000000000000000000000..9d7d40f80d173975aac39ed6b9d55c05cba36842 --- /dev/null +++ b/logs/mdm/mdm_t2m/latest.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8810255fb8df9eed6211537de9826f07ff73862f367cbf91532d84fd4c9a497e +size 81791550 diff --git a/logs/motiondiffuse/motiondiffuse_t2m/latest.pth b/logs/motiondiffuse/motiondiffuse_t2m/latest.pth new file mode 100644 index 0000000000000000000000000000000000000000..f99949e3546a8c693c006cc1548e1d1bca90ee21 --- /dev/null +++ b/logs/motiondiffuse/motiondiffuse_t2m/latest.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:521baa6ba60865710bc75b99f393b133e45dc18083229a2258a16e5dc65f904a +size 348728194 diff --git a/logs/remodiffuse/remodiffuse_t2m/latest.pth b/logs/remodiffuse/remodiffuse_t2m/latest.pth new file mode 100644 index 0000000000000000000000000000000000000000..7482b29b9bd9e18db5c6969839f2bbf0899b7021 --- /dev/null +++ b/logs/remodiffuse/remodiffuse_t2m/latest.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:aaa34b3328942769478e96283678424c95c4b817ca6f7162c4cf1fc512d4951b +size 187939375 diff --git a/mogen/__init__.py b/mogen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19f5ad7413f044b75f67d10413fc48c4148f56a4 --- /dev/null +++ b/mogen/__init__.py @@ -0,0 +1,56 @@ +import warnings + +import mmcv +from packaging.version import parse + +from .version import __version__ + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. + Args: + version_str (str): The version string. + length (int): The maximum number of version levels. Default: 4. + Returns: + tuple[int]: The version info in digits (integers). + """ + version = parse(version_str) + assert version.release, f'failed to parse version {version_str}' + release = list(version.release) + release = release[:length] + if len(release) < length: + release = release + [0] * (length - len(release)) + if version.is_prerelease: + mapping = {'a': -3, 'b': -2, 'rc': -1} + val = -4 + # version.pre can be None + if version.pre: + if version.pre[0] not in mapping: + warnings.warn(f'unknown prerelease version {version.pre[0]}, ' + 'version checking may go wrong') + else: + val = mapping[version.pre[0]] + release.extend([val, version.pre[-1]]) + else: + release.extend([val, 0]) + + elif version.is_postrelease: + release.extend([1, version.post]) + else: + release.extend([0, 0]) + return tuple(release) + + +mmcv_minimum_version = '1.4.2' +mmcv_maximum_version = '1.9.0' +mmcv_version = digit_version(mmcv.__version__) + + +assert (mmcv_version >= digit_version(mmcv_minimum_version) + and mmcv_version <= digit_version(mmcv_maximum_version)), \ + f'MMCV=={mmcv.__version__} is used but incompatible. ' \ + f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 
+ +__all__ = ['__version__', 'digit_version'] \ No newline at end of file diff --git a/mogen/apis/__init__.py b/mogen/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3f8ebf85df85924cc84e47bf3af210b40cd246 --- /dev/null +++ b/mogen/apis/__init__.py @@ -0,0 +1,13 @@ +from mogen.apis import test, train +from mogen.apis.test import ( + collect_results_cpu, + collect_results_gpu, + multi_gpu_test, + single_gpu_test, +) +from mogen.apis.train import set_random_seed, train_model + +__all__ = [ + 'collect_results_cpu', 'collect_results_gpu', 'multi_gpu_test', + 'single_gpu_test', 'set_random_seed', 'train_model' +] \ No newline at end of file diff --git a/mogen/apis/test.py b/mogen/apis/test.py new file mode 100644 index 0000000000000000000000000000000000000000..300dbedb0a319c54846db5c7923853e7eb20bd36 --- /dev/null +++ b/mogen/apis/test.py @@ -0,0 +1,160 @@ +import os.path as osp +import pickle +import shutil +import tempfile +import time + +import mmcv +import torch +import torch.distributed as dist +from mmcv.runner import get_dist_info + + +def single_gpu_test(model, data_loader): + """Test with single gpu.""" + model.eval() + results = [] + dataset = data_loader.dataset + prog_bar = mmcv.ProgressBar(len(dataset)) + for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + + batch_size = len(result) + if isinstance(result, list): + results.extend(result) + else: + results.append(result) + + batch_size = data['motion'].size(0) + for _ in range(batch_size): + prog_bar.update() + return results + + +def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): + """Test model with multiple gpus. + This method tests model with multiple gpus and collects the results + under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' + it encodes results to gpu tensors and use gpu communication for results + collection. On cpu mode it saves the results on different gpus to 'tmpdir' + and collects them by the rank 0 worker. + Args: + model (nn.Module): Model to be tested. + data_loader (nn.Dataloader): Pytorch data loader. + tmpdir (str): Path of directory to save the temporary results from + different gpus under cpu mode. + gpu_collect (bool): Option to use either gpu or cpu to collect results. + Returns: + list: The prediction results. + """ + model.eval() + results = [] + dataset = data_loader.dataset + rank, world_size = get_dist_info() + if rank == 0: + # Check if tmpdir is valid for cpu_collect + if (not gpu_collect) and (tmpdir is not None and osp.exists(tmpdir)): + raise OSError((f'The tmpdir {tmpdir} already exists.', + ' Since tmpdir will be deleted after testing,', + ' please make sure you specify an empty one.')) + prog_bar = mmcv.ProgressBar(len(dataset)) + time.sleep(2) # This line can prevent deadlock problem in some cases. 
+ for i, data in enumerate(data_loader): + with torch.no_grad(): + result = model(return_loss=False, **data) + if isinstance(result, list): + results.extend(result) + else: + results.append(result) + + if rank == 0: + batch_size = data['motion'].size(0) + for _ in range(batch_size * world_size): + prog_bar.update() + + # collect results from all ranks + if gpu_collect: + results = collect_results_gpu(results, len(dataset)) + else: + results = collect_results_cpu(results, len(dataset), tmpdir) + return results + + +def collect_results_cpu(result_part, size, tmpdir=None): + """Collect results in cpu.""" + rank, world_size = get_dist_info() + # create a tmp dir if it is not specified + if tmpdir is None: + MAX_LEN = 512 + # 32 is whitespace + dir_tensor = torch.full((MAX_LEN, ), + 32, + dtype=torch.uint8, + device='cuda') + if rank == 0: + mmcv.mkdir_or_exist('.dist_test') + tmpdir = tempfile.mkdtemp(dir='.dist_test') + tmpdir = torch.tensor( + bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda') + dir_tensor[:len(tmpdir)] = tmpdir + dist.broadcast(dir_tensor, 0) + tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() + else: + mmcv.mkdir_or_exist(tmpdir) + # dump the part result to the dir + mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl')) + dist.barrier() + # collect all parts + if rank != 0: + return None + else: + # load results of all parts from tmp dir + part_list = [] + for i in range(world_size): + part_file = osp.join(tmpdir, f'part_{i}.pkl') + part_result = mmcv.load(part_file) + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + # remove tmp dir + shutil.rmtree(tmpdir) + return ordered_results + + +def collect_results_gpu(result_part, size): + """Collect results in gpu.""" + rank, world_size = get_dist_info() + # dump result part to tensor with pickle + part_tensor = torch.tensor( + bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda') + # gather all result part tensor shape + shape_tensor = torch.tensor(part_tensor.shape, device='cuda') + shape_list = [shape_tensor.clone() for _ in range(world_size)] + dist.all_gather(shape_list, shape_tensor) + # padding result part tensor to max length + shape_max = torch.tensor(shape_list).max() + part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda') + part_send[:shape_tensor[0]] = part_tensor + part_recv_list = [ + part_tensor.new_zeros(shape_max) for _ in range(world_size) + ] + # gather all result part + dist.all_gather(part_recv_list, part_send) + + if rank == 0: + part_list = [] + for recv, shape in zip(part_recv_list, shape_list): + part_result = pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()) + part_list.append(part_result) + # sort the results + ordered_results = [] + for res in zip(*part_list): + ordered_results.extend(list(res)) + # the dataloader may pad some samples + ordered_results = ordered_results[:size] + return ordered_results \ No newline at end of file diff --git a/mogen/apis/train.py b/mogen/apis/train.py new file mode 100644 index 0000000000000000000000000000000000000000..eea9de794c0ceb1df71370800f6838b807487ad2 --- /dev/null +++ b/mogen/apis/train.py @@ -0,0 +1,165 @@ +import random +import warnings + +import numpy as np +import torch +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import ( + DistSamplerSeedHook, + Fp16OptimizerHook, + 
OptimizerHook, + build_runner, +) + +from mogen.core.distributed_wrapper import DistributedDataParallelWrapper +from mogen.core.evaluation import DistEvalHook, EvalHook +from mogen.core.optimizer import build_optimizers +from mogen.datasets import build_dataloader, build_dataset +from mogen.utils import get_root_logger + + +def set_random_seed(seed, deterministic=False): + """Set random seed. + Args: + seed (int): Seed to be used. + deterministic (bool): Whether to set the deterministic option for + CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` + to True and `torch.backends.cudnn.benchmark` to False. + Default: False. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if deterministic: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +def train_model(model, + dataset, + cfg, + distributed=False, + validate=False, + timestamp=None, + device='cuda', + meta=None): + """Main api for training model.""" + logger = get_root_logger(cfg.log_level) + + # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] + + data_loaders = [ + build_dataloader( + ds, + cfg.data.samples_per_gpu, + cfg.data.workers_per_gpu, + # cfg.gpus will be ignored if distributed + num_gpus=len(cfg.gpu_ids), + dist=distributed, + round_up=True, + seed=cfg.seed) for ds in dataset + ] + + # determine whether to use an adversarial training process or not + use_adverserial_train = cfg.get('use_adversarial_train', False) + + # put model on gpus + if distributed: + find_unused_parameters = cfg.get('find_unused_parameters', True) + # Sets the `find_unused_parameters` parameter in + # torch.nn.parallel.DistributedDataParallel + if use_adverserial_train: + # Use DistributedDataParallelWrapper for adversarial training + model = DistributedDataParallelWrapper( + model, + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + if device == 'cuda': + model = MMDataParallel( + model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) + elif device == 'cpu': + model = model.cpu() + else: + raise ValueError(F'unsupported device name {device}.') + + # build runner + optimizer = build_optimizers(model, cfg.optimizer) + + if cfg.get('runner') is None: + cfg.runner = { + 'type': 'EpochBasedRunner', + 'max_epochs': cfg.total_epochs + } + warnings.warn( + 'config is now expected to have a `runner` section, ' + 'please set `runner` in your config.', UserWarning) + + runner = build_runner( + cfg.runner, + default_args=dict( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + + # an ugly workaround to make the .log and .log.json filenames the same + runner.timestamp = timestamp + + if use_adverserial_train: + # The optimizer step process is included in the train_step function + # of the model, so the runner should NOT include optimizer hook.
+ optimizer_config = None + else: + # fp16 setting + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + optimizer_config = Fp16OptimizerHook( + **cfg.optimizer_config, **fp16_cfg, distributed=distributed) + elif distributed and 'type' not in cfg.optimizer_config: + optimizer_config = OptimizerHook(**cfg.optimizer_config) + else: + optimizer_config = cfg.optimizer_config + + # register hooks + runner.register_training_hooks( + cfg.lr_config, + optimizer_config, + cfg.checkpoint_config, + cfg.log_config, + cfg.get('momentum_config', None), + custom_hooks_config=cfg.get('custom_hooks', None)) + if distributed: + runner.register_hook(DistSamplerSeedHook()) + + # register eval hooks + if validate: + val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) + val_dataloader = build_dataloader( + val_dataset, + samples_per_gpu=cfg.data.samples_per_gpu, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False, + round_up=True) + eval_cfg = cfg.get('evaluation', {}) + eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' + eval_hook = DistEvalHook if distributed else EvalHook + runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) + + if cfg.resume_from: + runner.resume(cfg.resume_from) + elif cfg.load_from: + runner.load_checkpoint(cfg.load_from) + runner.run(data_loaders, cfg.workflow) \ No newline at end of file diff --git a/mogen/core/__init__.py b/mogen/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mogen/core/distributed_wrapper.py b/mogen/core/distributed_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5f4a29fe0f5dbcb53bb5a94f928a3a554e518e --- /dev/null +++ b/mogen/core/distributed_wrapper.py @@ -0,0 +1,136 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel +from mmcv.parallel.scatter_gather import scatter_kwargs +from torch.cuda._utils import _get_device_index + + +@MODULE_WRAPPERS.register_module() +class DistributedDataParallelWrapper(nn.Module): + """A DistributedDataParallel wrapper for models in the 3D mesh estimation task. + + In the 3D mesh estimation task, there is a need to wrap different modules in + the models with separate DistributedDataParallel. Otherwise, it will cause + errors for GAN training. + More specifically, the GAN model usually has two sub-modules: + generator and discriminator. If we wrap both of them in one + standard DistributedDataParallel, it will cause errors during training, + because when we update the parameters of the generator (or discriminator), + the parameters of the discriminator (or generator) are not updated, which is + not allowed for DistributedDataParallel. + So we design this wrapper to separately wrap DistributedDataParallel + for generator and discriminator. + In this wrapper, we perform two operations: + 1. Wrap the modules in the models with separate MMDistributedDataParallel. + Note that only modules with parameters will be wrapped. + 2. Do scatter operation for 'forward', 'train_step' and 'val_step'. + Note that the arguments of this wrapper are the same as those in + `torch.nn.parallel.distributed.DistributedDataParallel`. + Args: + module (nn.Module): Module that needs to be wrapped. + device_ids (list[int | `torch.device`]): Same as that in + `torch.nn.parallel.distributed.DistributedDataParallel`.
+ dim (int, optional): Same as that in the official scatter function in + pytorch. Defaults to 0. + broadcast_buffers (bool): Same as that in + `torch.nn.parallel.distributed.DistributedDataParallel`. + Defaults to False. + find_unused_parameters (bool, optional): Same as that in + `torch.nn.parallel.distributed.DistributedDataParallel`. + Traverse the autograd graph of all tensors contained in returned + value of the wrapped module’s forward function. Defaults to False. + kwargs (dict): Other arguments used in + `torch.nn.parallel.distributed.DistributedDataParallel`. + """ + + def __init__(self, + module, + device_ids, + dim=0, + broadcast_buffers=False, + find_unused_parameters=False, + **kwargs): + super().__init__() + assert len(device_ids) == 1, ( + 'Currently, DistributedDataParallelWrapper only supports one' + 'single CUDA device for each process.' + f'The length of device_ids must be 1, but got {len(device_ids)}.') + self.module = module + self.dim = dim + self.to_ddp( + device_ids=device_ids, + dim=dim, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + **kwargs) + self.output_device = _get_device_index(device_ids[0], True) + + def to_ddp(self, device_ids, dim, broadcast_buffers, + find_unused_parameters, **kwargs): + """Wrap models with separate MMDistributedDataParallel. + + It only wraps the modules with parameters. + """ + for name, module in self.module._modules.items(): + if next(module.parameters(), None) is None: + module = module.cuda() + elif all(not p.requires_grad for p in module.parameters()): + module = module.cuda() + else: + module = MMDistributedDataParallel( + module.cuda(), + device_ids=device_ids, + dim=dim, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + **kwargs) + self.module._modules[name] = module + + def scatter(self, inputs, kwargs, device_ids): + """Scatter function. + + Args: + inputs (Tensor): Input Tensor. + kwargs (dict): Args for + ``mmcv.parallel.scatter_gather.scatter_kwargs``. + device_ids (int): Device id. + """ + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def forward(self, *inputs, **kwargs): + """Forward function. + + Args: + inputs (tuple): Input data. + kwargs (dict): Args for + ``mmcv.parallel.scatter_gather.scatter_kwargs``. + """ + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + return self.module(*inputs[0], **kwargs[0]) + + def train_step(self, *inputs, **kwargs): + """Train step function. + + Args: + inputs (Tensor): Input Tensor. + kwargs (dict): Args for + ``mmcv.parallel.scatter_gather.scatter_kwargs``. + """ + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.train_step(*inputs[0], **kwargs[0]) + return output + + def val_step(self, *inputs, **kwargs): + """Validation step function. + + Args: + inputs (tuple): Input data. + kwargs (dict): Args for ``scatter_kwargs``. 
+ """ + inputs, kwargs = self.scatter(inputs, kwargs, + [torch.cuda.current_device()]) + output = self.module.val_step(*inputs[0], **kwargs[0]) + return output \ No newline at end of file diff --git a/mogen/core/evaluation/__init__.py b/mogen/core/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dacade8461ba66f20eb7d8440eeda9e56f3007ec --- /dev/null +++ b/mogen/core/evaluation/__init__.py @@ -0,0 +1,4 @@ +from mogen.core.evaluation.eval_hooks import DistEvalHook, EvalHook +from mogen.core.evaluation.builder import build_evaluator + +__all__ = ["DistEvalHook", "EvalHook", "build_evaluator"] \ No newline at end of file diff --git a/mogen/core/evaluation/builder.py b/mogen/core/evaluation/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..76c89d451fe727489e69f6fc10b6fbda9d93b6e6 --- /dev/null +++ b/mogen/core/evaluation/builder.py @@ -0,0 +1,29 @@ +import copy +import numpy as np +from mmcv.utils import Registry +from .evaluators.precision_evaluator import PrecisionEvaluator +from .evaluators.matching_score_evaluator import MatchingScoreEvaluator +from .evaluators.fid_evaluator import FIDEvaluator +from .evaluators.diversity_evaluator import DiversityEvaluator +from .evaluators.multimodality_evaluator import MultiModalityEvaluator + +EVALUATORS = Registry('evaluators') + +EVALUATORS.register_module(name='R Precision', module=PrecisionEvaluator) +EVALUATORS.register_module(name='Matching Score', module=MatchingScoreEvaluator) +EVALUATORS.register_module(name='FID', module=FIDEvaluator) +EVALUATORS.register_module(name='Diversity', module=DiversityEvaluator) +EVALUATORS.register_module(name='MultiModality', module=MultiModalityEvaluator) + + +def build_evaluator(metric, eval_cfg, data_len, eval_indexes): + cfg = copy.deepcopy(eval_cfg) + cfg.update(metric) + cfg.pop('metrics') + cfg['data_len'] = data_len + cfg['eval_indexes'] = eval_indexes + evaluator = EVALUATORS.build(cfg) + if evaluator.append_indexes is not None: + for i in range(eval_cfg['replication_times']): + eval_indexes[i] = np.concatenate((eval_indexes[i], evaluator.append_indexes[i]), axis=0) + return evaluator, eval_indexes diff --git a/mogen/core/evaluation/eval_hooks.py b/mogen/core/evaluation/eval_hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..b46f5a57fb629add10450f897cf25e49dab2f002 --- /dev/null +++ b/mogen/core/evaluation/eval_hooks.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import tempfile +import warnings + +from mmcv.runner import DistEvalHook as BaseDistEvalHook +from mmcv.runner import EvalHook as BaseEvalHook + +mogen_GREATER_KEYS = [] +mogen_LESS_KEYS = [] + + +class EvalHook(BaseEvalHook): + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + test_fn=None, + greater_keys=mogen_GREATER_KEYS, + less_keys=mogen_LESS_KEYS, + **eval_kwargs): + if test_fn is None: + from mogen.apis import single_gpu_test + test_fn = single_gpu_test + + # remove "gpu_collect" from eval_kwargs + if 'gpu_collect' in eval_kwargs: + warnings.warn( + '"gpu_collect" will be deprecated in EvalHook.' + 'Please remove it from the config.', DeprecationWarning) + _ = eval_kwargs.pop('gpu_collect') + + # update "save_best" according to "key_indicator" and remove the + # latter from eval_kwargs + if 'key_indicator' in eval_kwargs or isinstance(save_best, bool): + warnings.warn( + '"key_indicator" will be deprecated in EvalHook.' 
+ 'Please use "save_best" to specify the metric key,' + 'e.g., save_best="pa-mpjpe".', DeprecationWarning) + + key_indicator = eval_kwargs.pop('key_indicator', None) + if save_best is True and key_indicator is None: + raise ValueError('key_indicator should not be None, when ' + 'save_best is set to True.') + save_best = key_indicator + + super().__init__(dataloader, start, interval, by_epoch, save_best, + rule, test_fn, greater_keys, less_keys, **eval_kwargs) + + def evaluate(self, runner, results): + + with tempfile.TemporaryDirectory() as tmp_dir: + eval_res = self.dataloader.dataset.evaluate( + results, + work_dir=tmp_dir, + logger=runner.logger, + **self.eval_kwargs) + + for name, val in eval_res.items(): + runner.log_buffer.output[name] = val + runner.log_buffer.ready = True + + if self.save_best is not None: + if self.key_indicator == 'auto': + self._init_rule(self.rule, list(eval_res.keys())[0]) + + return eval_res[self.key_indicator] + + return None + + +class DistEvalHook(BaseDistEvalHook): + + def __init__(self, + dataloader, + start=None, + interval=1, + by_epoch=True, + save_best=None, + rule=None, + test_fn=None, + greater_keys=mogen_GREATER_KEYS, + less_keys=mogen_LESS_KEYS, + broadcast_bn_buffer=True, + tmpdir=None, + gpu_collect=False, + **eval_kwargs): + + if test_fn is None: + from mogen.apis import multi_gpu_test + test_fn = multi_gpu_test + + # update "save_best" according to "key_indicator" and remove the + # latter from eval_kwargs + if 'key_indicator' in eval_kwargs or isinstance(save_best, bool): + warnings.warn( + '"key_indicator" will be deprecated in EvalHook.' + 'Please use "save_best" to specify the metric key,' + 'e.g., save_best="pa-mpjpe".', DeprecationWarning) + + key_indicator = eval_kwargs.pop('key_indicator', None) + if save_best is True and key_indicator is None: + raise ValueError('key_indicator should not be None, when ' + 'save_best is set to True.') + save_best = key_indicator + + super().__init__(dataloader, start, interval, by_epoch, save_best, + rule, test_fn, greater_keys, less_keys, + broadcast_bn_buffer, tmpdir, gpu_collect, + **eval_kwargs) + + def evaluate(self, runner, results): + """Evaluate the results. + Args: + runner (:obj:`mmcv.Runner`): The underlying training runner. + results (list): Output results.
""" + with tempfile.TemporaryDirectory() as tmp_dir: + eval_res = self.dataloader.dataset.evaluate( + results, + work_dir=tmp_dir, + logger=runner.logger, + **self.eval_kwargs) + + for name, val in eval_res.items(): + runner.log_buffer.output[name] = val + runner.log_buffer.ready = True + + if self.save_best is not None: + if self.key_indicator == 'auto': + # infer from eval_results + self._init_rule(self.rule, list(eval_res.keys())[0]) + return eval_res[self.key_indicator] + + return None \ No newline at end of file diff --git a/mogen/core/evaluation/evaluators/__init__.py b/mogen/core/evaluation/evaluators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mogen/core/evaluation/evaluators/base_evaluator.py b/mogen/core/evaluation/evaluators/base_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..e63ed45a550e879a57bef9cb9fff251d73181169 --- /dev/null +++ b/mogen/core/evaluation/evaluators/base_evaluator.py @@ -0,0 +1,144 @@ +import torch +import numpy as np +from ..utils import get_metric_statistics + + +class BaseEvaluator(object): + + def __init__(self, + batch_size=None, + drop_last=False, + replication_times=1, + replication_reduction='statistics', + eval_begin_idx=None, + eval_end_idx=None): + self.batch_size = batch_size + self.drop_last = drop_last + self.replication_times = replication_times + self.replication_reduction = replication_reduction + assert replication_reduction in ['statistics', 'mean', 'concat'] + self.eval_begin_idx = eval_begin_idx + self.eval_end_idx = eval_end_idx + + def evaluate(self, results): + total_len = len(results) + partial_len = total_len // self.replication_times + all_metrics = [] + for replication_idx in range(self.replication_times): + partial_results = results[ + replication_idx * partial_len: (replication_idx + 1) * partial_len] + if self.batch_size is not None: + batch_metrics = [] + for batch_start in range(self.eval_begin_idx, self.eval_end_idx, self.batch_size): + batch_results = partial_results[batch_start: batch_start + self.batch_size] + if len(batch_results) < self.batch_size and self.drop_last: + continue + batch_metrics.append(self.single_evaluate(batch_results)) + all_metrics.append(self.concat_batch_metrics(batch_metrics)) + else: + batch_results = partial_results[self.eval_begin_idx: self.eval_end_idx] + all_metrics.append(self.single_evaluate(batch_results)) + all_metrics = np.stack(all_metrics, axis=0) + if self.replication_reduction == 'statistics': + values = get_metric_statistics(all_metrics, self.replication_times) + elif self.replication_reduction == 'mean': + values = np.mean(all_metrics, axis=0) + elif self.replication_reduction == 'concat': + values = all_metrics + return self.parse_values(values) + + def prepare_results(self, results): + text = [] + pred_motion = [] + pred_motion_length = [] + pred_motion_mask = [] + motion = [] + motion_length = [] + motion_mask = [] + token = [] + # count the maximum motion length + T = max([result['motion'].shape[0] for result in results]) + for result in results: + cur_motion = result['motion'] + if cur_motion.shape[0] < T: + padding_values = torch.zeros((T - cur_motion.shape[0], cur_motion.shape[1])) + padding_values = padding_values.type_as(cur_motion) + cur_motion = torch.cat([cur_motion, padding_values], dim=0) + motion.append(cur_motion) + cur_pred_motion = result['pred_motion'] + if cur_pred_motion.shape[0] < T: + padding_values = torch.zeros((T - 
cur_pred_motion.shape[0], cur_pred_motion.shape[1])) + padding_values = padding_values.type_as(cur_pred_motion) + cur_pred_motion = torch.cat([cur_pred_motion, padding_values], dim=0) + pred_motion.append(cur_pred_motion) + cur_motion_mask = result['motion_mask'] + if cur_motion_mask.shape[0] < T: + padding_values = torch.zeros((T - cur_motion_mask.shape[0])) + padding_values = padding_values.type_as(cur_motion_mask) + cur_motion_mask= torch.cat([cur_motion_mask, padding_values], dim=0) + motion_mask.append(cur_motion_mask) + cur_pred_motion_mask = result['pred_motion_mask'] + if cur_pred_motion_mask.shape[0] < T: + padding_values = torch.zeros((T - cur_pred_motion_mask.shape[0])) + padding_values = padding_values.type_as(cur_pred_motion_mask) + cur_pred_motion_mask= torch.cat([cur_pred_motion_mask, padding_values], dim=0) + pred_motion_mask.append(cur_pred_motion_mask) + motion_length.append(result['motion_length'].item()) + pred_motion_length.append(result['pred_motion_length'].item()) + if 'text' in result.keys(): + text.append(result['text']) + if 'token' in result.keys(): + token.append(result['token']) + + motion = torch.stack(motion, dim=0) + pred_motion = torch.stack(pred_motion, dim=0) + motion_mask = torch.stack(motion_mask, dim=0) + pred_motion_mask = torch.stack(pred_motion_mask, dim=0) + motion_length = torch.Tensor(motion_length).to(motion.device).long() + pred_motion_length = torch.Tensor(pred_motion_length).to(motion.device).long() + output = { + 'pred_motion': pred_motion, + 'pred_motion_mask': pred_motion_mask, + 'pred_motion_length': pred_motion_length, + 'motion': motion, + 'motion_mask': motion_mask, + 'motion_length': motion_length, + 'text': text, + 'token': token + } + return output + + def to_device(self, device): + for model in self.model_list: + model.to(device) + + def motion_encode(self, motion, motion_length, motion_mask, device): + N = motion.shape[0] + motion_emb = [] + batch_size = 32 + cur_idx = 0 + with torch.no_grad(): + while cur_idx < N: + cur_motion = motion[cur_idx: cur_idx + batch_size].to(device) + cur_motion_length = motion_length[cur_idx: cur_idx + batch_size].to(device) + cur_motion_mask = motion_mask[cur_idx: cur_idx + batch_size].to(device) + cur_motion_emb = self.motion_encoder(cur_motion, cur_motion_length, cur_motion_mask) + motion_emb.append(cur_motion_emb) + cur_idx += batch_size + motion_emb = torch.cat(motion_emb, dim=0) + return motion_emb + + def text_encode(self, text, token, device): + N = len(text) + text_emb = [] + batch_size = 32 + cur_idx = 0 + with torch.no_grad(): + while cur_idx < N: + cur_text = text[cur_idx: cur_idx + batch_size] + cur_token = token[cur_idx: cur_idx + batch_size] + cur_text_emb = self.text_encoder(cur_text, cur_token, device) + text_emb.append(cur_text_emb) + cur_idx += batch_size + text_emb = torch.cat(text_emb, dim=0) + return text_emb \ No newline at end of file diff --git a/mogen/core/evaluation/evaluators/diversity_evaluator.py b/mogen/core/evaluation/evaluators/diversity_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..3bce77edf856a77ec7d5d1b0ec7f9fb0fc1a2f82 --- /dev/null +++ b/mogen/core/evaluation/evaluators/diversity_evaluator.py @@ -0,0 +1,52 @@ +import numpy as np +import torch + +from ..get_model import get_motion_model +from .base_evaluator import BaseEvaluator +from ..utils import calculate_diversity + + +class DiversityEvaluator(BaseEvaluator): + + def __init__(self, + data_len=0, + motion_encoder_name=None, + motion_encoder_path=None, + num_samples=300, + 
batch_size=None, + drop_last=False, + replication_times=1, + replication_reduction='statistics', + **kwargs): + super().__init__( + replication_times=replication_times, + replication_reduction=replication_reduction, + batch_size=batch_size, + drop_last=drop_last, + eval_begin_idx=0, + eval_end_idx=data_len + ) + self.num_samples = num_samples + self.append_indexes = None + self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) + self.model_list = [self.motion_encoder] + + def single_evaluate(self, results): + results = self.prepare_results(results) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + motion = results['motion'] + pred_motion = results['pred_motion'] + pred_motion_length = results['pred_motion_length'] + pred_motion_mask = results['pred_motion_mask'] + self.motion_encoder.to(device) + self.motion_encoder.eval() + with torch.no_grad(): + pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() + diversity = calculate_diversity(pred_motion_emb, self.num_samples) + return diversity + + def parse_values(self, values): + metrics = {} + metrics['Diversity (mean)'] = values[0] + metrics['Diversity (conf)'] = values[1] + return metrics diff --git a/mogen/core/evaluation/evaluators/fid_evaluator.py b/mogen/core/evaluation/evaluators/fid_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..627910d0ddc0f81f55c436120ce383837927e100 --- /dev/null +++ b/mogen/core/evaluation/evaluators/fid_evaluator.py @@ -0,0 +1,58 @@ +import numpy as np +import torch + +from ..get_model import get_motion_model +from .base_evaluator import BaseEvaluator +from ..utils import ( + calculate_activation_statistics, + calculate_frechet_distance) + + +class FIDEvaluator(BaseEvaluator): + + def __init__(self, + data_len=0, + motion_encoder_name=None, + motion_encoder_path=None, + batch_size=None, + drop_last=False, + replication_times=1, + replication_reduction='statistics', + **kwargs): + super().__init__( + replication_times=replication_times, + replication_reduction=replication_reduction, + batch_size=batch_size, + drop_last=drop_last, + eval_begin_idx=0, + eval_end_idx=data_len + ) + self.append_indexes = None + self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) + self.model_list = [self.motion_encoder] + + def single_evaluate(self, results): + results = self.prepare_results(results) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + pred_motion = results['pred_motion'] + + pred_motion_length = results['pred_motion_length'] + pred_motion_mask = results['pred_motion_mask'] + motion = results['motion'] + motion_length = results['motion_length'] + motion_mask = results['motion_mask'] + self.motion_encoder.to(device) + self.motion_encoder.eval() + with torch.no_grad(): + pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() + gt_motion_emb = self.motion_encode(motion, motion_length, motion_mask, device).cpu().detach().numpy() + gt_mu, gt_cov = calculate_activation_statistics(gt_motion_emb) + pred_mu, pred_cov = calculate_activation_statistics(pred_motion_emb) + fid = calculate_frechet_distance(gt_mu, gt_cov, pred_mu, pred_cov) + return fid + + def parse_values(self, values): + metrics = {} + metrics['FID (mean)'] = values[0] + metrics['FID (conf)'] = values[1] + return metrics diff --git a/mogen/core/evaluation/evaluators/matching_score_evaluator.py 
b/mogen/core/evaluation/evaluators/matching_score_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d34414e4e7395a5fc2eba488f31a3fb9c5914c --- /dev/null +++ b/mogen/core/evaluation/evaluators/matching_score_evaluator.py @@ -0,0 +1,71 @@ +import numpy as np +import torch + +from ..get_model import get_motion_model, get_text_model +from .base_evaluator import BaseEvaluator +from ..utils import calculate_top_k, euclidean_distance_matrix + + +class MatchingScoreEvaluator(BaseEvaluator): + + def __init__(self, + data_len=0, + text_encoder_name=None, + text_encoder_path=None, + motion_encoder_name=None, + motion_encoder_path=None, + top_k=3, + batch_size=32, + drop_last=False, + replication_times=1, + replication_reduction='statistics', + **kwargs): + super().__init__( + replication_times=replication_times, + replication_reduction=replication_reduction, + batch_size=batch_size, + drop_last=drop_last, + eval_begin_idx=0, + eval_end_idx=data_len + ) + self.append_indexes = None + self.text_encoder = get_text_model(text_encoder_name, text_encoder_path) + self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) + self.top_k = top_k + self.model_list = [self.text_encoder, self.motion_encoder] + + def single_evaluate(self, results): + results = self.prepare_results(results) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + motion = results['motion'] + pred_motion = results['pred_motion'] + pred_motion_length = results['pred_motion_length'] + pred_motion_mask = results['pred_motion_mask'] + text = results['text'] + token = results['token'] + self.text_encoder.to(device) + self.motion_encoder.to(device) + self.text_encoder.eval() + self.motion_encoder.eval() + with torch.no_grad(): + word_emb = self.text_encode(text, token, device=device).cpu().detach().numpy() + motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() + dist_mat = euclidean_distance_matrix(word_emb, motion_emb) + matching_score = dist_mat.trace() + all_size = word_emb.shape[0] + return matching_score, all_size + + def concat_batch_metrics(self, batch_metrics): + matching_score_sum = 0 + all_size = 0 + for batch_matching_score, batch_all_size in batch_metrics: + matching_score_sum += batch_matching_score + all_size += batch_all_size + matching_score = matching_score_sum / all_size + return matching_score + + def parse_values(self, values): + metrics = {} + metrics['Matching Score (mean)'] = values[0] + metrics['Matching Score (conf)'] = values[1] + return metrics diff --git a/mogen/core/evaluation/evaluators/multimodality_evaluator.py b/mogen/core/evaluation/evaluators/multimodality_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..16f0b18bd7f60f2991d77e93f72f376c2f877998 --- /dev/null +++ b/mogen/core/evaluation/evaluators/multimodality_evaluator.py @@ -0,0 +1,63 @@ +import numpy as np +import torch + +from ..get_model import get_motion_model +from .base_evaluator import BaseEvaluator +from ..utils import calculate_multimodality + + +class MultiModalityEvaluator(BaseEvaluator): + + def __init__(self, + data_len=0, + motion_encoder_name=None, + motion_encoder_path=None, + num_samples=100, + num_repeats=30, + num_picks=10, + batch_size=None, + drop_last=False, + replication_times=1, + replication_reduction='statistics', + **kwargs): + super().__init__( + replication_times=replication_times, + replication_reduction=replication_reduction, + batch_size=batch_size, + 
drop_last=drop_last, + eval_begin_idx=data_len, + eval_end_idx=data_len + num_samples * num_repeats + ) + self.num_samples = num_samples + self.num_repeats = num_repeats + self.num_picks = num_picks + self.append_indexes = [] + for i in range(replication_times): + append_indexes = [] + selected_indexs = np.random.choice(data_len, self.num_samples) + for index in selected_indexs: + append_indexes = append_indexes + [index] * self.num_repeats + self.append_indexes.append(np.array(append_indexes)) + self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) + self.model_list = [self.motion_encoder] + + def single_evaluate(self, results): + results = self.prepare_results(results) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + motion = results['motion'] + pred_motion = results['pred_motion'] + pred_motion_length = results['pred_motion_length'] + pred_motion_mask = results['pred_motion_mask'] + self.motion_encoder.to(device) + self.motion_encoder.eval() + with torch.no_grad(): + pred_motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() + pred_motion_emb = pred_motion_emb.reshape((self.num_samples, self.num_repeats, -1)) + multimodality = calculate_multimodality(pred_motion_emb, self.num_picks) + return multimodality + + def parse_values(self, values): + metrics = {} + metrics['MultiModality (mean)'] = values[0] + metrics['MultiModality (conf)'] = values[1] + return metrics diff --git a/mogen/core/evaluation/evaluators/precision_evaluator.py b/mogen/core/evaluation/evaluators/precision_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..09063a7ad778e559216a3746564007ee30a6516d --- /dev/null +++ b/mogen/core/evaluation/evaluators/precision_evaluator.py @@ -0,0 +1,74 @@ +import numpy as np +import torch + +from ..get_model import get_motion_model, get_text_model +from .base_evaluator import BaseEvaluator +from ..utils import calculate_top_k, euclidean_distance_matrix + + +class PrecisionEvaluator(BaseEvaluator): + + def __init__(self, + data_len=0, + text_encoder_name=None, + text_encoder_path=None, + motion_encoder_name=None, + motion_encoder_path=None, + top_k=3, + batch_size=32, + drop_last=False, + replication_times=1, + replication_reduction='statistics', + **kwargs): + super().__init__( + replication_times=replication_times, + replication_reduction=replication_reduction, + batch_size=batch_size, + drop_last=drop_last, + eval_begin_idx=0, + eval_end_idx=data_len + ) + self.append_indexes = None + self.text_encoder = get_text_model(text_encoder_name, text_encoder_path) + self.motion_encoder = get_motion_model(motion_encoder_name, motion_encoder_path) + self.top_k = top_k + self.model_list = [self.text_encoder, self.motion_encoder] + + def single_evaluate(self, results): + results = self.prepare_results(results) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + motion = results['motion'] + pred_motion = results['pred_motion'] + pred_motion_length = results['pred_motion_length'] + pred_motion_mask = results['pred_motion_mask'] + text = results['text'] + token = results['token'] + self.text_encoder.to(device) + self.motion_encoder.to(device) + self.text_encoder.eval() + self.motion_encoder.eval() + with torch.no_grad(): + word_emb = self.text_encode(text, token, device=device).cpu().detach().numpy() + motion_emb = self.motion_encode(pred_motion, pred_motion_length, pred_motion_mask, device).cpu().detach().numpy() + dist_mat = 
euclidean_distance_matrix(word_emb, motion_emb) + argsmax = np.argsort(dist_mat, axis=1) + top_k_mat = calculate_top_k(argsmax, top_k=self.top_k) + top_k_count = top_k_mat.sum(axis=0) + all_size = word_emb.shape[0] + return top_k_count, all_size + + def concat_batch_metrics(self, batch_metrics): + top_k_count = 0 + all_size = 0 + for batch_top_k_count, batch_all_size in batch_metrics: + top_k_count += batch_top_k_count + all_size += batch_all_size + R_precision = top_k_count / all_size + return R_precision + + def parse_values(self, values): + metrics = {} + for top_k in range(self.top_k): + metrics['R_precision Top %d (mean)' % (top_k + 1)] = values[0][top_k] + metrics['R_precision Top %d (conf)' % (top_k + 1)] = values[1][top_k] + return metrics diff --git a/mogen/core/evaluation/get_model.py b/mogen/core/evaluation/get_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7f444a35f617ae156f7aa04ac3844b6dacb4ef --- /dev/null +++ b/mogen/core/evaluation/get_model.py @@ -0,0 +1,46 @@ +from mogen.models import build_submodule + + +def get_motion_model(name, ckpt_path): + if name == 'kit_ml': + model = build_submodule(dict( + type='T2MMotionEncoder', + input_size=251, + movement_hidden_size=512, + movement_latent_size=512, + motion_hidden_size=1024, + motion_latent_size=512, + )) + else: + model = build_submodule(dict( + type='T2MMotionEncoder', + input_size=263, + movement_hidden_size=512, + movement_latent_size=512, + motion_hidden_size=1024, + motion_latent_size=512, + )) + model.load_pretrained(ckpt_path) + return model + +def get_text_model(name, ckpt_path): + if name == 'kit_ml': + model = build_submodule(dict( + type='T2MTextEncoder', + word_size=300, + pos_size=15, + hidden_size=512, + output_size=512, + max_text_len=20 + )) + else: + model = build_submodule(dict( + type='T2MTextEncoder', + word_size=300, + pos_size=15, + hidden_size=512, + output_size=512, + max_text_len=20 + )) + model.load_pretrained(ckpt_path) + return model diff --git a/mogen/core/evaluation/utils.py b/mogen/core/evaluation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e91d4acc5692fa134ad2cec2664c6af56ed351b1 --- /dev/null +++ b/mogen/core/evaluation/utils.py @@ -0,0 +1,130 @@ +import numpy as np +from scipy import linalg + + +def get_metric_statistics(values, replication_times): + mean = np.mean(values, axis=0) + std = np.std(values, axis=0) + conf_interval = 1.96 * std / np.sqrt(replication_times) + return mean, conf_interval + + +# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train +def euclidean_distance_matrix(matrix1, matrix2): + """ + Params: + -- matrix1: N1 x D + -- matrix2: N2 x D + Returns: + -- dist: N1 x N2 + dist[i, j] == distance(matrix1[i], matrix2[j]) + """ + assert matrix1.shape[1] == matrix2.shape[1] + d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train) + d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1) + d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, ) + dists = np.sqrt(d1 + d2 + d3) # broadcasting + return dists + + +def calculate_top_k(mat, top_k): + size = mat.shape[0] + gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1) + bool_mat = (mat == gt_mat) + correct_vec = False + top_k_list = [] + for i in range(top_k): +# print(correct_vec, bool_mat[:, i]) + correct_vec = (correct_vec | bool_mat[:, i]) + # print(correct_vec) + top_k_list.append(correct_vec[:, None]) + top_k_mat = np.concatenate(top_k_list, axis=1) + return top_k_mat + + +def 
calculate_activation_statistics(activations): + """ + Params: + -- activation: num_samples x dim_feat + Returns: + -- mu: dim_feat + -- sigma: dim_feat x dim_feat + """ + mu = np.mean(activations, axis=0) + cov = np.cov(activations, rowvar=False) + return mu, cov + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + + np.trace(sigma2) - 2 * tr_covmean) + + +def calculate_diversity(activation, diversity_times): + assert len(activation.shape) == 2 + assert activation.shape[0] > diversity_times + num_samples = activation.shape[0] + + first_indices = np.random.choice(num_samples, diversity_times, replace=False) + second_indices = np.random.choice(num_samples, diversity_times, replace=False) + dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1) + return dist.mean() + + +def calculate_multimodality(activation, multimodality_times): + assert len(activation.shape) == 3 + assert activation.shape[1] > multimodality_times + num_per_sent = activation.shape[1] + + first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) + second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False) + dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2) + return dist.mean() diff --git a/mogen/core/optimizer/__init__.py b/mogen/core/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..193050bbf7a02fe1ad06e783ea0c010bed2b7484 --- /dev/null +++ b/mogen/core/optimizer/__init__.py @@ -0,0 +1,3 @@ +from .builder import OPTIMIZERS, build_optimizers + +__all__ = ['build_optimizers', 'OPTIMIZERS'] \ No newline at end of file diff --git a/mogen/core/optimizer/builder.py b/mogen/core/optimizer/builder.py new 
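A minimal sketch of how the statistics and Frechet-distance helpers above combine into an FID score. The embedding arrays and their size are made up for illustration; in the evaluators the embeddings come from the pretrained motion encoder.

import numpy as np
from mogen.core.evaluation.utils import (
    calculate_activation_statistics, calculate_frechet_distance)

# Hypothetical embeddings: many samples, small feature dim keeps the sketch fast.
gt_emb = np.random.randn(1000, 32)
pred_emb = np.random.randn(1000, 32)

gt_mu, gt_cov = calculate_activation_statistics(gt_emb)        # mean + covariance of real motions
pred_mu, pred_cov = calculate_activation_statistics(pred_emb)  # mean + covariance of generated motions
fid = calculate_frechet_distance(gt_mu, gt_cov, pred_mu, pred_cov)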
file mode 100644 index 0000000000000000000000000000000000000000..c0c75616b6efc4a82ea725cd2d9f871eade213a9 --- /dev/null +++ b/mogen/core/optimizer/builder.py @@ -0,0 +1,52 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.runner import build_optimizer +from mmcv.utils import Registry + +OPTIMIZERS = Registry('optimizers') + + +def build_optimizers(model, cfgs): + """Build multiple optimizers from configs. If `cfgs` contains several dicts + for optimizers, then a dict for each constructed optimizers will be + returned. If `cfgs` only contains one optimizer config, the constructed + optimizer itself will be returned. For example, + + 1) Multiple optimizer configs: + + .. code-block:: python + + optimizer_cfg = dict( + model1=dict(type='SGD', lr=lr), + model2=dict(type='SGD', lr=lr)) + + The return dict is + ``dict('model1': torch.optim.Optimizer, 'model2': torch.optim.Optimizer)`` + + 2) Single optimizer config: + + .. code-block:: python + + optimizer_cfg = dict(type='SGD', lr=lr) + + The return is ``torch.optim.Optimizer``. + + Args: + model (:obj:`nn.Module`): The model with parameters to be optimized. + cfgs (dict): The config dict of the optimizer. + + Returns: + dict[:obj:`torch.optim.Optimizer`] | :obj:`torch.optim.Optimizer`: + The initialized optimizers. + """ + optimizers = {} + if hasattr(model, 'module'): + model = model.module + # determine whether 'cfgs' has several dicts for optimizers + if all(isinstance(v, dict) for v in cfgs.values()): + for key, cfg in cfgs.items(): + cfg_ = cfg.copy() + module = getattr(model, key) + optimizers[key] = build_optimizer(module, cfg_) + return optimizers + + return build_optimizer(model, cfgs) \ No newline at end of file diff --git a/mogen/datasets/__init__.py b/mogen/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..403118ccaaf5c746530aa2240f81d794d38c13d3 --- /dev/null +++ b/mogen/datasets/__init__.py @@ -0,0 +1,11 @@ +from .base_dataset import BaseMotionDataset +from .text_motion_dataset import TextMotionDataset +from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset +from .pipelines import Compose +from .samplers import DistributedSampler + + +__all__ = [ + 'BaseMotionDataset', 'TextMotionDataset', 'DATASETS', 'PIPELINES', 'build_dataloader', + 'build_dataset', 'Compose', 'DistributedSampler' +] \ No newline at end of file diff --git a/mogen/datasets/base_dataset.py b/mogen/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d99fe2eb63bcd874171e3931c21eb7126e9faed5 --- /dev/null +++ b/mogen/datasets/base_dataset.py @@ -0,0 +1,117 @@ +import os +import copy +from typing import Optional, Union + +import numpy as np +from torch.utils.data import Dataset + +from .pipelines import Compose +from .builder import DATASETS +from mogen.core.evaluation import build_evaluator + + +@DATASETS.register_module() +class BaseMotionDataset(Dataset): + """Base motion dataset. + Args: + data_prefix (str): the prefix of data path. + pipeline (list): a list of dict, where each element represents + a operation defined in `mogen.datasets.pipelines`. + ann_file (str | None, optional): the annotation file. When ann_file is + str, the subclass is expected to read from the ann_file. When + ann_file is None, the subclass is expected to read according + to data_prefix. + test_mode (bool): in train mode or test mode. Default: None. + dataset_name (str | None, optional): the name of dataset. It is used + to identify the type of evaluation metric. 
Default: None. + """ + + def __init__(self, + data_prefix: str, + pipeline: list, + dataset_name: Optional[Union[str, None]] = None, + fixed_length: Optional[Union[int, None]] = None, + ann_file: Optional[Union[str, None]] = None, + motion_dir: Optional[Union[str, None]] = None, + eval_cfg: Optional[Union[dict, None]] = None, + test_mode: Optional[bool] = False): + super(BaseMotionDataset, self).__init__() + + self.data_prefix = data_prefix + self.pipeline = Compose(pipeline) + self.dataset_name = dataset_name + self.fixed_length = fixed_length + self.ann_file = os.path.join(data_prefix, 'datasets', dataset_name, ann_file) + self.motion_dir = os.path.join(data_prefix, 'datasets', dataset_name, motion_dir) + self.eval_cfg = copy.deepcopy(eval_cfg) + self.test_mode = test_mode + + self.load_annotations() + if self.test_mode: + self.prepare_evaluation() + + def load_anno(self, name): + motion_path = os.path.join(self.motion_dir, name + '.npy') + motion_data = np.load(motion_path) + return {'motion': motion_data} + + + def load_annotations(self): + """Load annotations from ``ann_file`` to ``data_infos``""" + self.data_infos = [] + for line in open(self.ann_file, 'r').readlines(): + line = line.strip() + self.data_infos.append(self.load_anno(line)) + + + def prepare_data(self, idx: int): + """"Prepare raw data for the f'{idx'}-th data.""" + results = copy.deepcopy(self.data_infos[idx]) + results['dataset_name'] = self.dataset_name + results['sample_idx'] = idx + return self.pipeline(results) + + def __len__(self): + """Return the length of current dataset.""" + if self.test_mode: + return len(self.eval_indexes) + elif self.fixed_length is not None: + return self.fixed_length + return len(self.data_infos) + + def __getitem__(self, idx: int): + """Prepare data for the ``idx``-th data. + As for video dataset, we can first parse raw data for each frame. Then + we combine annotations from all frames. This interface is used to + simplify the logic of video dataset and other special datasets. 
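The index handling here works as follows: in test mode the dataset length is len(eval_indexes) and every requested index is first mapped through that table (built in prepare_evaluation below, where evaluators may also append extra sampled indices); with fixed_length set, indices simply wrap around. A tiny illustration with invented numbers:

import numpy as np

data_len, replication_times = 5, 2
eval_indexes = np.concatenate([np.random.permutation(data_len)
                               for _ in range(replication_times)])
# test mode:  __len__() == len(eval_indexes) == 10
#             __getitem__(i) loads data_infos[eval_indexes[i]]
# train mode with fixed_length=1000:
#             __len__() == 1000 and __getitem__(i) loads data_infos[i % data_len]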
+ """ + if self.test_mode: + idx = self.eval_indexes[idx] + elif self.fixed_length is not None: + idx = idx % len(self.data_infos) + return self.prepare_data(idx) + + def prepare_evaluation(self): + self.evaluators = [] + self.eval_indexes = [] + for _ in range(self.eval_cfg['replication_times']): + eval_indexes = np.arange(len(self.data_infos)) + if self.eval_cfg.get('shuffle_indexes', False): + np.random.shuffle(eval_indexes) + self.eval_indexes.append(eval_indexes) + for metric in self.eval_cfg['metrics']: + evaluator, self.eval_indexes = build_evaluator( + metric, self.eval_cfg, len(self.data_infos), self.eval_indexes) + self.evaluators.append(evaluator) + + self.eval_indexes = np.concatenate(self.eval_indexes) + + def evaluate(self, results, work_dir, logger=None): + metrics = {} + device = results[0]['motion'].device + for evaluator in self.evaluators: + evaluator.to_device(device) + metrics.update(evaluator.evaluate(results)) + if logger is not None: + logger.info(metrics) + return metrics diff --git a/mogen/datasets/builder.py b/mogen/datasets/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f77d18503d7e432aa699a4613637deb86c2e6d --- /dev/null +++ b/mogen/datasets/builder.py @@ -0,0 +1,113 @@ +import platform +import random +from functools import partial +from typing import Optional, Union + +import numpy as np +from mmcv.parallel import collate +from mmcv.runner import get_dist_info +from mmcv.utils import Registry, build_from_cfg +from torch.utils.data import DataLoader +from torch.utils.data.dataset import Dataset + +from .samplers import DistributedSampler + +if platform.system() != 'Windows': + # https://github.com/pytorch/pytorch/issues/973 + import resource + rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) + base_soft_limit = rlimit[0] + hard_limit = rlimit[1] + soft_limit = min(max(4096, base_soft_limit), hard_limit) + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + +DATASETS = Registry('dataset') +PIPELINES = Registry('pipeline') + + +def build_dataset(cfg: Union[dict, list, tuple], + default_args: Optional[Union[dict, None]] = None): + """"Build dataset by the given config.""" + from .dataset_wrappers import ( + ConcatDataset, + RepeatDataset, + ) + if isinstance(cfg, (list, tuple)): + dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) + elif cfg['type'] == 'RepeatDataset': + dataset = RepeatDataset( + build_dataset(cfg['dataset'], default_args), cfg['times']) + else: + dataset = build_from_cfg(cfg, DATASETS, default_args) + + return dataset + + +def build_dataloader(dataset: Dataset, + samples_per_gpu: int, + workers_per_gpu: int, + num_gpus: Optional[int] = 1, + dist: Optional[bool] = True, + shuffle: Optional[bool] = True, + round_up: Optional[bool] = True, + seed: Optional[Union[int, None]] = None, + persistent_workers: Optional[bool] = True, + **kwargs): + """Build PyTorch DataLoader. + In distributed training, each GPU/process has a dataloader. + In non-distributed training, there is only one dataloader for all GPUs. + Args: + dataset (:obj:`Dataset`): A PyTorch dataset. + samples_per_gpu (int): Number of training samples on each GPU, i.e., + batch size of each GPU. + workers_per_gpu (int): How many subprocesses to use for data loading + for each GPU. + num_gpus (int, optional): Number of GPUs. Only used in non-distributed + training. + dist (bool, optional): Distributed training/test or not. Default: True. + shuffle (bool, optional): Whether to shuffle the data at every epoch. 
+ Default: True. + round_up (bool, optional): Whether to round up the length of dataset by + adding extra samples to make it evenly divisible. Default: True. + kwargs: any keyword argument to be used to initialize DataLoader + Returns: + DataLoader: A PyTorch dataloader. + """ + rank, world_size = get_dist_info() + if dist: + sampler = DistributedSampler( + dataset, world_size, rank, shuffle=shuffle, round_up=round_up) + shuffle = False + batch_size = samples_per_gpu + num_workers = workers_per_gpu + else: + sampler = None + batch_size = num_gpus * samples_per_gpu + num_workers = num_gpus * workers_per_gpu + + init_fn = partial( + worker_init_fn, num_workers=num_workers, rank=rank, + seed=seed) if seed is not None else None + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=num_workers, + collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), + pin_memory=False, + shuffle=shuffle, + worker_init_fn=init_fn, + persistent_workers=persistent_workers, + **kwargs) + + return data_loader + + +def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): + """Init random seed for each worker.""" + # The seed of each worker equals to + # num_worker * rank + worker_id + user_seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) \ No newline at end of file diff --git a/mogen/datasets/dataset_wrappers.py b/mogen/datasets/dataset_wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..a86d0923d698c176188a85ef7e81266b7e899ddc --- /dev/null +++ b/mogen/datasets/dataset_wrappers.py @@ -0,0 +1,42 @@ +from torch.utils.data.dataset import ConcatDataset as _ConcatDataset +from torch.utils.data.dataset import Dataset + +from .builder import DATASETS + + +@DATASETS.register_module() +class ConcatDataset(_ConcatDataset): + """A wrapper of concatenated dataset. + Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but + add `get_cat_ids` function. + Args: + datasets (list[:obj:`Dataset`]): A list of datasets. + """ + + def __init__(self, datasets: list): + super(ConcatDataset, self).__init__(datasets) + + +@DATASETS.register_module() +class RepeatDataset(object): + """A wrapper of repeated dataset. + The length of repeated dataset will be `times` larger than the original + dataset. This is useful when the data loading time is long but the dataset + is small. Using RepeatDataset can reduce the data loading time between + epochs. + Args: + dataset (:obj:`Dataset`): The dataset to be repeated. + times (int): Repeat times. 
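A short usage sketch tying these pieces together: a RepeatDataset-wrapped config is built with build_dataset and fed to build_dataloader. The dataset name, file names and batch sizes below are assumptions for illustration and presuppose the corresponding files on disk.

from mogen.datasets import build_dataset, build_dataloader

cfg = dict(
    type='RepeatDataset',
    times=10,                        # virtual epoch is 10x the raw dataset length
    dataset=dict(
        type='TextMotionDataset',    # assumed layout under data_prefix/datasets/<dataset_name>/
        dataset_name='human_ml3d',
        data_prefix='data',
        pipeline=[],                 # normally Normalize / Crop / ToTensor / Collect
        ann_file='train.txt',
        motion_dir='motions',
        text_dir='texts'))

dataset = build_dataset(cfg)
loader = build_dataloader(dataset,
                          samples_per_gpu=64,
                          workers_per_gpu=2,
                          dist=False,   # single process; set dist=True under a distributed launcher
                          shuffle=True,
                          seed=0)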
+ """ + + def __init__(self, dataset: Dataset, times: int): + self.dataset = dataset + self.times = times + + self._ori_len = len(self.dataset) + + def __getitem__(self, idx: int): + return self.dataset[idx % self._ori_len] + + def __len__(self): + return self.times * self._ori_len \ No newline at end of file diff --git a/mogen/datasets/pipelines/__init__.py b/mogen/datasets/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2700cf741142781074d53366d1f606fcddbf7933 --- /dev/null +++ b/mogen/datasets/pipelines/__init__.py @@ -0,0 +1,18 @@ +from .compose import Compose +from .formatting import ( + to_tensor, + ToTensor, + Transpose, + Collect, + WrapFieldsToLists +) +from .transforms import ( + Crop, + RandomCrop, + Normalize +) + +__all__ = [ + 'Compose', 'to_tensor', 'Transpose', 'Collect', 'WrapFieldsToLists', 'ToTensor', + 'Crop', 'RandomCrop', 'Normalize' +] \ No newline at end of file diff --git a/mogen/datasets/pipelines/compose.py b/mogen/datasets/pipelines/compose.py new file mode 100644 index 0000000000000000000000000000000000000000..60a4dfc361b48789b6eba656a2208a98e10bcf39 --- /dev/null +++ b/mogen/datasets/pipelines/compose.py @@ -0,0 +1,42 @@ +from collections.abc import Sequence + +from mmcv.utils import build_from_cfg + +from ..builder import PIPELINES + + +@PIPELINES.register_module() +class Compose(object): + """Compose a data pipeline with a sequence of transforms. + + Args: + transforms (list[dict | callable]): + Either config dicts of transforms or transform objects. + """ + + def __init__(self, transforms): + assert isinstance(transforms, Sequence) + self.transforms = [] + for transform in transforms: + if isinstance(transform, dict): + transform = build_from_cfg(transform, PIPELINES) + self.transforms.append(transform) + elif callable(transform): + self.transforms.append(transform) + else: + raise TypeError('transform must be callable or a dict, but got' + f' {type(transform)}') + + def __call__(self, data): + for t in self.transforms: + data = t(data) + if data is None: + return None + return data + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += f'\n {t}' + format_string += '\n)' + return format_string \ No newline at end of file diff --git a/mogen/datasets/pipelines/formatting.py b/mogen/datasets/pipelines/formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..893ea99feac0226356c0f783b1abb8641851ed13 --- /dev/null +++ b/mogen/datasets/pipelines/formatting.py @@ -0,0 +1,134 @@ +from collections.abc import Sequence + +import mmcv +import numpy as np +import torch +from mmcv.parallel import DataContainer as DC +from PIL import Image + +from ..builder import PIPELINES + + +def to_tensor(data): + """Convert objects of various python types to :obj:`torch.Tensor`. + + Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, + :class:`Sequence`, :class:`int` and :class:`float`. + """ + if isinstance(data, torch.Tensor): + return data + elif isinstance(data, np.ndarray): + return torch.from_numpy(data) + elif isinstance(data, Sequence) and not mmcv.is_str(data): + return torch.tensor(data) + elif isinstance(data, int): + return torch.LongTensor([data]) + elif isinstance(data, float): + return torch.FloatTensor([data]) + else: + raise TypeError( + f'Type {type(data)} cannot be converted to tensor.' 
+ 'Supported types are: `numpy.ndarray`, `torch.Tensor`, ' + '`Sequence`, `int` and `float`') + + +@PIPELINES.register_module() +class ToTensor(object): + + def __init__(self, keys): + self.keys = keys + + def __call__(self, results): + for key in self.keys: + results[key] = to_tensor(results[key]) + return results + + def __repr__(self): + return self.__class__.__name__ + f'(keys={self.keys})' + + +@PIPELINES.register_module() +class Transpose(object): + + def __init__(self, keys, order): + self.keys = keys + self.order = order + + def __call__(self, results): + for key in self.keys: + results[key] = results[key].transpose(self.order) + return results + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, order={self.order})' + + +@PIPELINES.register_module() +class Collect(object): + """Collect data from the loader relevant to the specific task. + + This is usually the last stage of the data loader pipeline. + + Args: + keys (Sequence[str]): Keys of results to be collected in ``data``. + meta_keys (Sequence[str], optional): Meta keys to be converted to + ``mmcv.DataContainer`` and collected in ``data[motion_metas]``. + Default: ``('filename', 'ori_filename', 'ori_shape', 'motion_shape', 'motion_mask')`` + + Returns: + dict: The result dict contains the following keys + - keys in``self.keys`` + - ``motion_metas`` if available + """ + + def __init__(self, + keys, + meta_keys=('filename', 'ori_filename', 'ori_shape', 'motion_shape', 'motion_mask')): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + data = {} + motion_meta = {} + for key in self.meta_keys: + if key in results: + motion_meta[key] = results[key] + data['motion_metas'] = DC(motion_meta, cpu_only=True) + for key in self.keys: + data[key] = results[key] + return data + + def __repr__(self): + return self.__class__.__name__ + \ + f'(keys={self.keys}, meta_keys={self.meta_keys})' + + +@PIPELINES.register_module() +class WrapFieldsToLists(object): + """Wrap fields of the data dictionary into lists for evaluation. + + This class can be used as a last step of a test or validation + pipeline for single image evaluation or inference. + + Example: + >>> test_pipeline = [ + >>> dict(type='LoadImageFromFile'), + >>> dict(type='Normalize', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + to_rgb=True), + >>> dict(type='ImageToTensor', keys=['img']), + >>> dict(type='Collect', keys=['img']), + >>> dict(type='WrapIntoLists') + >>> ] + """ + + def __call__(self, results): + # Wrap dict fields into lists + for key, val in results.items(): + results[key] = [val] + return results + + def __repr__(self): + return f'{self.__class__.__name__}()' \ No newline at end of file diff --git a/mogen/datasets/pipelines/transforms.py b/mogen/datasets/pipelines/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..72a127c10dd6f888c8bee0d4bfb44fdba9da3e87 --- /dev/null +++ b/mogen/datasets/pipelines/transforms.py @@ -0,0 +1,120 @@ +import math +import random + +import mmcv +import numpy as np + +from ..builder import PIPELINES +import torch +from typing import Optional, Tuple, Union + + +@PIPELINES.register_module() +class Crop(object): + r"""Crop motion sequences. + + Args: + crop_size (int): The size of the cropped motion sequence. 
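A small usage sketch of the crop-or-pad behaviour implemented below, assuming the package is importable; the shapes are illustrative (196 frames of 263-dim HumanML3D-style features):

import numpy as np
from mogen.datasets.pipelines import Crop

crop = Crop(crop_size=196)
out = crop({'motion': np.random.randn(60, 263).astype(np.float32)})
# out['motion'].shape   == (196, 263)  -- zero-padded up to crop_size
# out['motion_length']  == 60          -- the original, unpadded length
# out['motion_mask']    == 60 ones followed by 136 zeros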
+ """ + def __init__(self, + crop_size: Optional[Union[int, None]] = None): + self.crop_size = crop_size + assert self.crop_size is not None + + def __call__(self, results): + motion = results['motion'] + length = len(motion) + if length >= self.crop_size: + idx = random.randint(0, length - self.crop_size) + motion = motion[idx: idx + self.crop_size] + results['motion_length'] = self.crop_size + else: + padding_length = self.crop_size - length + D = motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + motion = np.concatenate([motion, padding_zeros], axis=0) + results['motion_length'] = length + assert len(motion) == self.crop_size + results['motion'] = motion + results['motion_shape'] = motion.shape + if length >= self.crop_size: + results['motion_mask'] = torch.ones(self.crop_size).numpy() + else: + results['motion_mask'] = torch.cat( + (torch.ones(length), torch.zeros(self.crop_size - length))).numpy() + return results + + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(crop_size={self.crop_size})' + return repr_str + +@PIPELINES.register_module() +class RandomCrop(object): + r"""Random crop motion sequences. Each sequence will be padded with zeros to the maximum length. + + Args: + min_size (int or None): The minimum size of the cropped motion sequence (inclusive). + max_size (int or None): The maximum size of the cropped motion sequence (inclusive). + """ + def __init__(self, + min_size: Optional[Union[int, None]] = None, + max_size: Optional[Union[int, None]] = None): + self.min_size = min_size + self.max_size = max_size + assert self.min_size is not None + assert self.max_size is not None + + def __call__(self, results): + motion = results['motion'] + length = len(motion) + crop_size = random.randint(self.min_size, self.max_size) + if length > crop_size: + idx = random.randint(0, length - crop_size) + motion = motion[idx: idx + crop_size] + results['motion_length'] = crop_size + else: + results['motion_length'] = length + padding_length = self.max_size - min(crop_size, length) + if padding_length > 0: + D = motion.shape[1:] + padding_zeros = np.zeros((padding_length, *D), dtype=np.float32) + motion = np.concatenate([motion, padding_zeros], axis=0) + results['motion'] = motion + results['motion_shape'] = motion.shape + if length >= self.max_size and crop_size == self.max_size: + results['motion_mask'] = torch.ones(self.max_size).numpy() + else: + results['motion_mask'] = torch.cat(( + torch.ones(min(length, crop_size)), + torch.zeros(self.max_size - min(length, crop_size))), dim=0).numpy() + assert len(motion) == self.max_size + return results + + + def __repr__(self): + repr_str = self.__class__.__name__ + f'(min_size={self.min_size}' + repr_str += f', max_size={self.max_size})' + return repr_str + +@PIPELINES.register_module() +class Normalize(object): + """Normalize motion sequences. + + Args: + mean_path (str): Path of mean file. + std_path (str): Path of std file. 
+ """ + + def __init__(self, mean_path, std_path, eps=1e-9): + self.mean = np.load(mean_path) + self.std = np.load(std_path) + self.eps = eps + + def __call__(self, results): + motion = results['motion'] + motion = (motion - self.mean) / (self.std + self.eps) + results['motion'] = motion + results['motion_norm_mean'] = self.mean + results['motion_norm_std'] = self.std + return results diff --git a/mogen/datasets/samplers/__init__.py b/mogen/datasets/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bb671281d781258dae235ac8a4b7371fa55cac9f --- /dev/null +++ b/mogen/datasets/samplers/__init__.py @@ -0,0 +1,3 @@ +from .distributed_sampler import DistributedSampler + +__all__ = ['DistributedSampler'] \ No newline at end of file diff --git a/mogen/datasets/samplers/distributed_sampler.py b/mogen/datasets/samplers/distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a46a086e7b2e99afbb02004bf9ff03c84e13a1c8 --- /dev/null +++ b/mogen/datasets/samplers/distributed_sampler.py @@ -0,0 +1,42 @@ +import torch +from torch.utils.data import DistributedSampler as _DistributedSampler + + +class DistributedSampler(_DistributedSampler): + + def __init__(self, + dataset, + num_replicas=None, + rank=None, + shuffle=True, + round_up=True): + super().__init__(dataset, num_replicas=num_replicas, rank=rank) + self.shuffle = shuffle + self.round_up = round_up + if self.round_up: + self.total_size = self.num_samples * self.num_replicas + else: + self.total_size = len(self.dataset) + + def __iter__(self): + # deterministically shuffle based on epoch + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = torch.arange(len(self.dataset)).tolist() + + # add extra samples to make it evenly divisible + if self.round_up: + indices = ( + indices * + int(self.total_size / len(indices) + 1))[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + if self.round_up: + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/mogen/datasets/text_motion_dataset.py b/mogen/datasets/text_motion_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d0920264f3e21ed0e0158b0f159886091d8f36be --- /dev/null +++ b/mogen/datasets/text_motion_dataset.py @@ -0,0 +1,93 @@ +import json +import os +import os.path +from abc import ABCMeta +from collections import OrderedDict +from typing import Any, List, Optional, Union + +import mmcv +import copy +import numpy as np +import torch +import torch.distributed as dist +from mmcv.runner import get_dist_info + +from .base_dataset import BaseMotionDataset +from .builder import DATASETS + + +@DATASETS.register_module() +class TextMotionDataset(BaseMotionDataset): + """TextMotion dataset. + + Args: + text_dir (str): Path to the directory containing the text files. 
+ """ + def __init__(self, + data_prefix: str, + pipeline: list, + dataset_name: Optional[Union[str, None]] = None, + fixed_length: Optional[Union[int, None]] = None, + ann_file: Optional[Union[str, None]] = None, + motion_dir: Optional[Union[str, None]] = None, + text_dir: Optional[Union[str, None]] = None, + token_dir: Optional[Union[str, None]] = None, + clip_feat_dir: Optional[Union[str, None]] = None, + eval_cfg: Optional[Union[dict, None]] = None, + fine_mode: Optional[bool] = False, + test_mode: Optional[bool] = False): + self.text_dir = os.path.join(data_prefix, 'datasets', dataset_name, text_dir) + if token_dir is not None: + self.token_dir = os.path.join(data_prefix, 'datasets', dataset_name, token_dir) + else: + self.token_dir = None + if clip_feat_dir is not None: + self.clip_feat_dir = os.path.join(data_prefix, 'datasets', dataset_name, clip_feat_dir) + else: + self.clip_feat_dir = None + self.fine_mode = fine_mode + super(TextMotionDataset, self).__init__( + data_prefix=data_prefix, + pipeline=pipeline, + dataset_name=dataset_name, + fixed_length=fixed_length, + ann_file=ann_file, + motion_dir=motion_dir, + eval_cfg=eval_cfg, + test_mode=test_mode) + + def load_anno(self, name): + results = super().load_anno(name) + text_path = os.path.join(self.text_dir, name + '.txt') + text_data = [] + for line in open(text_path, 'r'): + text_data.append(line.strip()) + results['text'] = text_data + if self.token_dir is not None: + token_path = os.path.join(self.token_dir, name + '.txt') + token_data = [] + for line in open(token_path, 'r'): + token_data.append(line.strip()) + results['token'] = token_data + if self.clip_feat_dir is not None: + clip_feat_path = os.path.join(self.clip_feat_dir, name + '.npy') + clip_feat = torch.from_numpy(np.load(clip_feat_path)) + results['clip_feat'] = clip_feat + return results + + def prepare_data(self, idx: int): + """"Prepare raw data for the f'{idx'}-th data.""" + results = copy.deepcopy(self.data_infos[idx]) + text_list = results['text'] + idx = np.random.randint(0, len(text_list)) + if self.fine_mode: + results['text'] = json.loads(text_list[idx]) + else: + results['text'] = text_list[idx] + if 'clip_feat' in results.keys(): + results['clip_feat'] = results['clip_feat'][idx] + if 'token' in results.keys(): + results['token'] = results['token'][idx] + results['dataset_name'] = self.dataset_name + results['sample_idx'] = idx + return self.pipeline(results) diff --git a/mogen/models/__init__.py b/mogen/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7f46d5b1b6da726ea259d4fbdd15a012a5713338 --- /dev/null +++ b/mogen/models/__init__.py @@ -0,0 +1,7 @@ +from .architectures import * +from .losses import * +from .rnns import * +from .transformers import * +from .attentions import * +from .builder import * +from .utils import * \ No newline at end of file diff --git a/mogen/models/architectures/__init__.py b/mogen/models/architectures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7b46c9e9a48422bbd89b86519c1e06f2636935 --- /dev/null +++ b/mogen/models/architectures/__init__.py @@ -0,0 +1,6 @@ +from .vae_architecture import MotionVAE +from .diffusion_architecture import MotionDiffusion + +__all__ = [ + 'MotionVAE', 'MotionDiffusion' +] \ No newline at end of file diff --git a/mogen/models/architectures/base_architecture.py b/mogen/models/architectures/base_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e9e4cfa63ecdaf116f4dc7028a983ba004f466 
--- /dev/null +++ b/mogen/models/architectures/base_architecture.py @@ -0,0 +1,135 @@ +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import torch +import torch.distributed as dist +from mmcv.runner import BaseModule + + +def to_cpu(x): + if isinstance(x, torch.Tensor): + return x.detach().cpu() + return x + + +class BaseArchitecture(BaseModule): + """Base class for mogen architecture.""" + + def __init__(self, init_cfg=None): + super(BaseArchitecture, self).__init__(init_cfg) + + def forward_train(self, **kwargs): + pass + + def forward_test(self, **kwargs): + pass + + def _parse_losses(self, losses): + """Parse the raw outputs (losses) of the network. + Args: + losses (dict): Raw output of the network, which usually contain + losses and other necessary information. + Returns: + tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \ + which may be a weighted sum of all losses, log_vars contains \ + all the variables to be sent to the logger. + """ + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError( + f'{loss_name} is not a tensor or list of tensors') + + loss = sum(_value for _key, _value in log_vars.items() + if 'loss' in _key) + + log_vars['loss'] = loss + for loss_name, loss_value in log_vars.items(): + # reduce loss when distributed training + if dist.is_available() and dist.is_initialized(): + loss_value = loss_value.data.clone() + dist.all_reduce(loss_value.div_(dist.get_world_size())) + log_vars[loss_name] = loss_value.item() + + return loss, log_vars + + def train_step(self, data, optimizer): + """The iteration step during training. + This method defines an iteration step during training, except for the + back propagation and optimizer updating, which are done in an optimizer + hook. Note that in some complicated cases or models, the whole process + including back propagation and optimizer updating is also defined in + this method, such as GAN. + Args: + data (dict): The output of dataloader. + optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of + runner is passed to ``train_step()``. This argument is unused + and reserved. + Returns: + dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \ + ``num_samples``. + - ``loss`` is a tensor for back propagation, which can be a + weighted sum of multiple losses. + - ``log_vars`` contains all the variables to be sent to the + logger. + - ``num_samples`` indicates the batch size (when the model is + DDP, it means the batch size on each GPU), which is used for + averaging the logs. + """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['motion'])) + + return outputs + + def val_step(self, data, optimizer=None): + """The iteration step during validation. + This method shares the same signature as :func:`train_step`, but used + during val epochs. Note that the evaluation after training epochs is + not implemented with this method, but an evaluation hook. 
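How _parse_losses above collapses a loss dict into one scalar plus log variables, shown on a made-up example (the distributed all-reduce and .item() conversion are omitted):

import torch

losses = {
    'recon_loss': torch.tensor(0.8),
    'kl_div_loss': torch.tensor([0.1, 0.3]),   # extra dims are mean-reduced per key
}
log_vars = {k: v.mean() for k, v in losses.items()}
total = sum(v for k, v in log_vars.items() if 'loss' in k)   # every key containing 'loss' is summed
log_vars['loss'] = total
# total == 0.8 + 0.2 == 1.0; this is the tensor train_step backpropagates through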
+ """ + losses = self(**data) + loss, log_vars = self._parse_losses(losses) + + outputs = dict( + loss=loss, log_vars=log_vars, num_samples=len(data['motion'])) + + return outputs + + def forward(self, **kwargs): + if self.training: + return self.forward_train(**kwargs) + else: + return self.forward_test(**kwargs) + + def split_results(self, results): + B = results['motion'].shape[0] + output = [] + for i in range(B): + batch_output = dict() + batch_output['motion'] = to_cpu(results['motion'][i]) + batch_output['pred_motion'] = to_cpu(results['pred_motion'][i]) + batch_output['motion_length'] = to_cpu(results['motion_length'][i]) + batch_output['motion_mask'] = to_cpu(results['motion_mask'][i]) + if 'pred_motion_length' in results.keys(): + batch_output['pred_motion_length'] = to_cpu(results['pred_motion_length'][i]) + else: + batch_output['pred_motion_length'] = to_cpu(results['motion_length'][i]) + if 'pred_motion_mask' in results: + batch_output['pred_motion_mask'] = to_cpu(results['pred_motion_mask'][i]) + else: + batch_output['pred_motion_mask'] = to_cpu(results['motion_mask'][i]) + if 'motion_metas' in results.keys(): + motion_metas = results['motion_metas'][i] + if 'text' in motion_metas.keys(): + batch_output['text'] = motion_metas['text'] + if 'token' in motion_metas.keys(): + batch_output['token'] = motion_metas['token'] + output.append(batch_output) + return output diff --git a/mogen/models/architectures/diffusion_architecture.py b/mogen/models/architectures/diffusion_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..124dee34ebeb578d35c46dc072ed28fd6e7cbb95 --- /dev/null +++ b/mogen/models/architectures/diffusion_architecture.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_architecture import BaseArchitecture +from ..builder import ( + ARCHITECTURES, + build_architecture, + build_submodule, + build_loss +) +from ..utils.gaussian_diffusion import ( + GaussianDiffusion, get_named_beta_schedule, create_named_schedule_sampler, + ModelMeanType, ModelVarType, LossType, space_timesteps, SpacedDiffusion +) + +def build_diffusion(cfg): + beta_scheduler = cfg['beta_scheduler'] + diffusion_steps = cfg['diffusion_steps'] + + betas = get_named_beta_schedule(beta_scheduler, diffusion_steps) + model_mean_type = { + 'start_x': ModelMeanType.START_X, + 'previous_x': ModelMeanType.PREVIOUS_X, + 'epsilon': ModelMeanType.EPSILON + }[cfg['model_mean_type']] + model_var_type = { + 'learned': ModelVarType.LEARNED, + 'fixed_small': ModelVarType.FIXED_SMALL, + 'fixed_large': ModelVarType.FIXED_LARGE, + 'learned_range': ModelVarType.LEARNED_RANGE + }[cfg['model_var_type']] + if cfg.get('respace', None) is not None: + diffusion = SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, cfg['respace']), + betas=betas, + model_mean_type=model_mean_type, + model_var_type=model_var_type, + loss_type=LossType.MSE + ) + else: + diffusion = GaussianDiffusion( + betas=betas, + model_mean_type=model_mean_type, + model_var_type=model_var_type, + loss_type=LossType.MSE) + return diffusion + + +@ARCHITECTURES.register_module() +class MotionDiffusion(BaseArchitecture): + + def __init__(self, + model=None, + loss_recon=None, + diffusion_train=None, + diffusion_test=None, + init_cfg=None, + inference_type='ddpm', + **kwargs): + super().__init__(init_cfg=init_cfg, **kwargs) + self.model = build_submodule(model) + self.loss_recon = build_loss(loss_recon) + self.diffusion_train = build_diffusion(diffusion_train) + 
self.diffusion_test = build_diffusion(diffusion_test) + self.sampler = create_named_schedule_sampler('uniform', self.diffusion_train) + self.inference_type = inference_type + + def forward(self, **kwargs): + motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'].float() + sample_idx = kwargs.get('sample_idx', None) + clip_feat = kwargs.get('clip_feat', None) + B, T = motion.shape[:2] + text = [] + for i in range(B): + text.append(kwargs['motion_metas'][i]['text']) + + if self.training: + t, _ = self.sampler.sample(B, motion.device) + output = self.diffusion_train.training_losses( + model=self.model, + x_start=motion, + t=t, + model_kwargs={ + 'motion_mask': motion_mask, + 'motion_length': kwargs['motion_length'], + 'text': text, + 'clip_feat': clip_feat, + 'sample_idx': sample_idx} + ) + pred, target = output['pred'], output['target'] + recon_loss = self.loss_recon(pred, target, reduction_override='none') + recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum() + loss = {'recon_loss': recon_loss} + return loss + else: + dim_pose = kwargs['motion'].shape[-1] + model_kwargs = self.model.get_precompute_condition(device=motion.device, text=text, **kwargs) + model_kwargs['motion_mask'] = motion_mask + model_kwargs['sample_idx'] = sample_idx + inference_kwargs = kwargs.get('inference_kwargs', {}) + if self.inference_type == 'ddpm': + output = self.diffusion_test.p_sample_loop( + self.model, + (B, T, dim_pose), + clip_denoised=False, + progress=False, + model_kwargs=model_kwargs, + **inference_kwargs + ) + else: + output = self.diffusion_test.ddim_sample_loop( + self.model, + (B, T, dim_pose), + clip_denoised=False, + progress=False, + model_kwargs=model_kwargs, + eta=0, + **inference_kwargs + ) + if getattr(self.model, "post_process") is not None: + output = self.model.post_process(output) + results = kwargs + results['pred_motion'] = output + results = self.split_results(results) + return results + diff --git a/mogen/models/architectures/vae_architecture.py b/mogen/models/architectures/vae_architecture.py new file mode 100644 index 0000000000000000000000000000000000000000..83248f894caf511a3d75cbb16303a75566b017f0 --- /dev/null +++ b/mogen/models/architectures/vae_architecture.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_architecture import BaseArchitecture +from ..builder import ( + ARCHITECTURES, + build_architecture, + build_submodule, + build_loss +) + + +@ARCHITECTURES.register_module() +class PoseVAE(BaseArchitecture): + + def __init__(self, + encoder=None, + decoder=None, + loss_recon=None, + kl_div_loss_weight=None, + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg, **kwargs) + self.encoder = build_submodule(encoder) + self.decoder = build_submodule(decoder) + self.loss_recon = build_loss(loss_recon) + self.kl_div_loss_weight = kl_div_loss_weight + + def reparameterize(self, mu, logvar): + std = torch.exp(logvar / 2) + + eps = std.data.new(std.size()).normal_() + latent_code = eps.mul(std).add_(mu) + return latent_code + + def encode(self, pose): + mu, logvar = self.encoder(pose) + return mu + + def forward(self, **kwargs): + motion = kwargs['motion'].float() + B, T = motion.shape[:2] + pose = motion.reshape(B * T, -1) + pose = pose[:, :-4] + + mu, logvar = self.encoder(pose) + z = self.reparameterize(mu, logvar) + pred = self.decoder(z) + + loss = dict() + recon_loss = self.loss_recon(pred, pose, reduction_override='none') + loss['recon_loss'] = recon_loss + if 
self.kl_div_loss_weight is not None: + loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) + + return loss + + +@ARCHITECTURES.register_module() +class MotionVAE(BaseArchitecture): + + def __init__(self, + encoder=None, + decoder=None, + loss_recon=None, + kl_div_loss_weight=None, + init_cfg=None, + **kwargs): + super().__init__(init_cfg=init_cfg, **kwargs) + self.encoder = build_submodule(encoder) + self.decoder = build_submodule(decoder) + self.loss_recon = build_loss(loss_recon) + self.kl_div_loss_weight = kl_div_loss_weight + + def sample(self, std=1, latent_code=None): + if latent_code is not None: + z = latent_code + else: + z = torch.randn(1, 7, self.decoder.latent_dim).cuda() * std + output = self.decoder(z) + if self.use_normalization: + output = output * self.motion_std + output = output + self.motion_mean + return output + + def reparameterize(self, mu, logvar): + std = torch.exp(logvar / 2) + + eps = std.data.new(std.size()).normal_() + latent_code = eps.mul(std).add_(mu) + return latent_code + + def encode(self, motion, motion_mask): + mu, logvar = self.encoder(motion, motion_mask) + return self.reparameterize(mu, logvar) + + def decode(self, z, motion_mask): + return self.decoder(z, motion_mask) + + def forward(self, **kwargs): + motion, motion_mask = kwargs['motion'].float(), kwargs['motion_mask'] + B, T = motion.shape[:2] + + mu, logvar = self.encoder(motion, motion_mask) + z = self.reparameterize(mu, logvar) + pred = self.decoder(z, motion_mask) + + loss = dict() + recon_loss = self.loss_recon(pred, motion, reduction_override='none') + recon_loss = (recon_loss.mean(dim=-1) * motion_mask).sum() / motion_mask.sum() + loss['recon_loss'] = recon_loss + if self.kl_div_loss_weight is not None: + loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + loss['kl_div_loss'] = (loss_kl * self.kl_div_loss_weight) + + return loss \ No newline at end of file diff --git a/mogen/models/attentions/__init__.py b/mogen/models/attentions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50a70f1054cb8c3acd1e1a5d8b7a4cbc05d1aad6 --- /dev/null +++ b/mogen/models/attentions/__init__.py @@ -0,0 +1,6 @@ +from .efficient_attention import ( + EfficientSelfAttention, + EfficientCrossAttention +) +from .semantics_modulated import SemanticsModulatedAttention +from .base_attention import BaseMixedAttention \ No newline at end of file diff --git a/mogen/models/attentions/base_attention.py b/mogen/models/attentions/base_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a78dccb1fd53836946d5bb0bdd1d2e8d70b0fbb3 --- /dev/null +++ b/mogen/models/attentions/base_attention.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..utils.stylization_block import StylizationBlock +from ..builder import ATTENTIONS + + +@ATTENTIONS.register_module() +class BaseMixedAttention(nn.Module): + + def __init__(self, latent_dim, + text_latent_dim, + num_heads, + dropout, + time_embed_dim): + super().__init__() + self.num_heads = num_heads + + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + + self.query = nn.Linear(latent_dim, latent_dim) + self.key_text = nn.Linear(text_latent_dim, latent_dim) + self.value_text = nn.Linear(text_latent_dim, latent_dim) + self.key_motion = nn.Linear(latent_dim, latent_dim) + self.value_motion = nn.Linear(latent_dim, latent_dim) + + self.dropout = nn.Dropout(dropout) + 
self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x, xf, emb, src_mask, cond_type, **kwargs): + """ + x: B, T, D + xf: B, N, L + """ + B, T, D = x.shape + N = xf.shape[1] + x.shape[1] + H = self.num_heads + # B, T, D + query = self.query(self.norm(x)).view(B, T, H, -1) + # B, N, D + text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1) + key = torch.cat(( + self.key_text(self.text_norm(xf)), + self.key_motion(self.norm(x)) + ), dim=1).view(B, N, H, -1) + + attention = torch.einsum('bnhl,bmhl->bnmh', query, key) + motion_mask = src_mask.view(B, 1, T, 1) + text_mask = text_cond_type.view(B, 1, -1, 1) + mask = torch.cat((text_mask, motion_mask), dim=2) + attention = attention + (1 - mask) * -1000000 + attention = F.softmax(attention, dim=2) + + value = torch.cat(( + self.value_text(self.text_norm(xf)) * text_cond_type, + self.value_motion(self.norm(x)) * src_mask, + ), dim=1).view(B, N, H, -1) + + y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) + y = x + self.proj_out(y, emb) + return y + + +@ATTENTIONS.register_module() +class BaseSelfAttention(nn.Module): + + def __init__(self, latent_dim, + num_heads, + dropout, + time_embed_dim): + super().__init__() + self.num_heads = num_heads + + self.norm = nn.LayerNorm(latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(latent_dim, latent_dim) + self.value = nn.Linear(latent_dim, latent_dim) + + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x, emb, src_mask, **kwargs): + """ + x: B, T, D + """ + B, T, D = x.shape + H = self.num_heads + # B, T, D + query = self.query(self.norm(x)).view(B, T, H, -1) + # B, N, D + key = self.key(self.norm(x)).view(B, T, H, -1) + + attention = torch.einsum('bnhl,bmhl->bnmh', query, key) + mask = src_mask.view(B, 1, T, 1) + attention = attention + (1 - mask) * -1000000 + attention = F.softmax(attention, dim=2) + value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1) + y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) + y = x + self.proj_out(y, emb) + return y + + +@ATTENTIONS.register_module() +class BaseCrossAttention(nn.Module): + + def __init__(self, latent_dim, + text_latent_dim, + num_heads, + dropout, + time_embed_dim): + super().__init__() + self.num_heads = num_heads + + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(text_latent_dim, latent_dim) + self.value = nn.Linear(text_latent_dim, latent_dim) + + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x, xf, emb, src_mask, cond_type, **kwargs): + """ + x: B, T, D + xf: B, N, L + """ + B, T, D = x.shape + N = xf.shape[1] + H = self.num_heads + # B, T, D + query = self.query(self.norm(x)).view(B, T, H, -1) + # B, N, D + text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1) + key = self.key(self.text_norm(xf)).view(B, N, H, -1) + attention = torch.einsum('bnhl,bmhl->bnmh', query, key) + mask = text_cond_type.view(B, 1, -1, 1) + attention = attention + (1 - mask) * -1000000 + attention = F.softmax(attention, dim=2) + + value = (self.value(self.text_norm(xf)) * text_cond_type).view(B, N, H, -1) + y = torch.einsum('bnmh,bmhl->bnhl', attention, value).reshape(B, T, D) + y = x + self.proj_out(y, emb) + return y 
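The attention modules above share one masking idiom: invalid key positions receive a large negative bias before the softmax so their weights vanish. A stripped-down sketch with made-up shapes:

import torch
import torch.nn.functional as F

B, T, H, HD = 2, 4, 2, 8
query = torch.randn(B, T, H, HD)
key = torch.randn(B, T, H, HD)
src_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)  # 1 = valid frame

attention = torch.einsum('bnhl,bmhl->bnmh', query, key)              # (B, T_query, T_key, H)
attention = attention + (1 - src_mask.view(B, 1, T, 1)) * -1000000   # bias out padded keys
attention = F.softmax(attention, dim=2)                              # normalize over keys
assert attention[1, 0, 2:, :].max() < 1e-6                           # padded frames get ~0 weight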
diff --git a/mogen/models/attentions/efficient_attention.py b/mogen/models/attentions/efficient_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..77a07cf9598b74deaf23d0c1d24b1dbf0d8f52f2 --- /dev/null +++ b/mogen/models/attentions/efficient_attention.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..utils.stylization_block import StylizationBlock +from ..builder import ATTENTIONS + + +@ATTENTIONS.register_module() +class EfficientSelfAttention(nn.Module): + + def __init__(self, latent_dim, num_heads, dropout, time_embed_dim=None): + super().__init__() + self.num_heads = num_heads + self.norm = nn.LayerNorm(latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(latent_dim, latent_dim) + self.value = nn.Linear(latent_dim, latent_dim) + self.dropout = nn.Dropout(dropout) + self.time_embed_dim = time_embed_dim + if time_embed_dim is not None: + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x, src_mask, emb=None, **kwargs): + """ + x: B, T, D + """ + B, T, D = x.shape + H = self.num_heads + # B, T, D + query = self.query(self.norm(x)) + # B, T, D + key = (self.key(self.norm(x)) + (1 - src_mask) * -1000000) + query = F.softmax(query.view(B, T, H, -1), dim=-1) + key = F.softmax(key.view(B, T, H, -1), dim=1) + # B, T, H, HD + value = (self.value(self.norm(x)) * src_mask).view(B, T, H, -1) + # B, H, HD, HD + attention = torch.einsum('bnhd,bnhl->bhdl', key, value) + y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) + if self.time_embed_dim is None: + y = x + y + else: + y = x + self.proj_out(y, emb) + return y + + +@ATTENTIONS.register_module() +class EfficientCrossAttention(nn.Module): + + def __init__(self, latent_dim, text_latent_dim, num_heads, dropout, time_embed_dim): + super().__init__() + self.num_heads = num_heads + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(text_latent_dim, latent_dim) + self.value = nn.Linear(text_latent_dim, latent_dim) + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x, xf, emb, cond_type=None, **kwargs): + """ + x: B, T, D + xf: B, N, L + """ + B, T, D = x.shape + N = xf.shape[1] + H = self.num_heads + # B, T, D + query = self.query(self.norm(x)) + # B, N, D + key = self.key(self.text_norm(xf)) + query = F.softmax(query.view(B, T, H, -1), dim=-1) + if cond_type is None: + key = F.softmax(key.view(B, N, H, -1), dim=1) + # B, N, H, HD + value = self.value(self.text_norm(xf)).view(B, N, H, -1) + else: + text_cond_type = ((cond_type % 10) > 0).float().view(B, 1, 1).repeat(1, xf.shape[1], 1) + key = key + (1 - text_cond_type) * -1000000 + key = F.softmax(key.view(B, N, H, -1), dim=1) + value = self.value(self.text_norm(xf) * text_cond_type).view(B, N, H, -1) + # B, H, HD, HD + attention = torch.einsum('bnhd,bnhl->bhdl', key, value) + y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) + y = x + self.proj_out(y, emb) + return y diff --git a/mogen/models/attentions/semantics_modulated.py b/mogen/models/attentions/semantics_modulated.py new file mode 100644 index 0000000000000000000000000000000000000000..e883a706466d21191ebc207dba0d619a6fe0858d --- /dev/null +++ b/mogen/models/attentions/semantics_modulated.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +import torch.nn.functional 
as F +from ..utils.stylization_block import StylizationBlock +from ..builder import ATTENTIONS + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +@ATTENTIONS.register_module() +class SemanticsModulatedAttention(nn.Module): + + def __init__(self, latent_dim, + text_latent_dim, + num_heads, + dropout, + time_embed_dim): + super().__init__() + self.num_heads = num_heads + + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + + self.query = nn.Linear(latent_dim, latent_dim) + self.key_text = nn.Linear(text_latent_dim, latent_dim) + self.value_text = nn.Linear(text_latent_dim, latent_dim) + self.key_motion = nn.Linear(latent_dim, latent_dim) + self.value_motion = nn.Linear(latent_dim, latent_dim) + + self.retr_norm1 = nn.LayerNorm(2 * latent_dim) + self.retr_norm2 = nn.LayerNorm(latent_dim) + self.key_retr = nn.Linear(2 * latent_dim, latent_dim) + self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim)) + + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x, xf, emb, src_mask, cond_type, re_dict=None): + """ + x: B, T, D + xf: B, N, L + """ + B, T, D = x.shape + re_motion = re_dict['re_motion'] + re_text = re_dict['re_text'] + re_mask = re_dict['re_mask'] + re_mask = re_mask.reshape(B, -1, 1) + N = xf.shape[1] + x.shape[1] + re_motion.shape[1] * re_motion.shape[2] + H = self.num_heads + # B, T, D + query = self.query(self.norm(x)) + # B, N, D + text_cond_type = (cond_type % 10 > 0).float() + retr_cond_type = (cond_type // 10 > 0).float() + re_text = re_text.repeat(1, 1, re_motion.shape[2], 1) + re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D) + key = torch.cat(( + self.key_text(self.text_norm(xf)) + (1 - text_cond_type) * -1000000, + self.key_retr(self.retr_norm1(re_feat_key)) + (1 - retr_cond_type) * -1000000 + (1 - re_mask) * -1000000, + self.key_motion(self.norm(x)) + (1 - src_mask) * -1000000 + ), dim=1) + query = F.softmax(query.view(B, T, H, -1), dim=-1) + key = F.softmax(key.view(B, N, H, -1), dim=1) + # B, N, H, HD + re_feat_value = re_motion.reshape(B, -1, D) + value = torch.cat(( + self.value_text(self.text_norm(xf)) * text_cond_type, + self.value_retr(self.retr_norm2(re_feat_value)) * retr_cond_type * re_mask, + self.value_motion(self.norm(x)) * src_mask, + ), dim=1).view(B, N, H, -1) + # B, H, HD, HD + attention = torch.einsum('bnhd,bnhl->bhdl', key, value) + y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D) + y = x + self.proj_out(y, emb) + return y \ No newline at end of file diff --git a/mogen/models/builder.py b/mogen/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..bff54963371ed5fba8c7a52f29233c2838d9e45d --- /dev/null +++ b/mogen/models/builder.py @@ -0,0 +1,32 @@ +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.utils import Registry + + +def build_from_cfg(cfg, registry, default_args=None): + if cfg is None: + return None + return MMCV_MODELS.build_func(cfg, registry, default_args) + + +MODELS = Registry('models', parent=MMCV_MODELS, build_func=build_from_cfg) + +LOSSES = MODELS +ARCHITECTURES = MODELS +SUBMODULES = MODELS +ATTENTIONS = MODELS + +def build_loss(cfg): + """Build loss.""" + return LOSSES.build(cfg) + +def build_architecture(cfg): + """Build framework.""" + return ARCHITECTURES.build(cfg) + +def build_submodule(cfg): + """Build 
submodule.""" + return SUBMODULES.build(cfg) + +def build_attention(cfg): + """Build attention.""" + return ATTENTIONS.build(cfg) diff --git a/mogen/models/losses/__init__.py b/mogen/models/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c893bf204bf8500b85a583ba82b693236cc6a8e8 --- /dev/null +++ b/mogen/models/losses/__init__.py @@ -0,0 +1,13 @@ +from .mse_loss import MSELoss +from .utils import ( + convert_to_one_hot, + reduce_loss, + weight_reduce_loss, + weighted_loss, +) + + +__all__ = [ + 'convert_to_one_hot', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', + 'MSELoss' +] \ No newline at end of file diff --git a/mogen/models/losses/mse_loss.py b/mogen/models/losses/mse_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..585538fb7167d1c876642cff987ba1180617eb28 --- /dev/null +++ b/mogen/models/losses/mse_loss.py @@ -0,0 +1,70 @@ +import torch.nn as nn +import torch.nn.functional as F + +from ..builder import LOSSES +from .utils import weighted_loss + + +def gmof(x, sigma): + """Geman-McClure error function.""" + x_squared = x**2 + sigma_squared = sigma**2 + return (sigma_squared * x_squared) / (sigma_squared + x_squared) + + +@weighted_loss +def mse_loss(pred, target): + """Warpper of mse loss.""" + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss_with_gmof(pred, target, sigma): + """Extended MSE Loss with GMOF.""" + loss = F.mse_loss(pred, target, reduction='none') + loss = gmof(loss, sigma) + return loss + + +@LOSSES.register_module() +class MSELoss(nn.Module): + """MSELoss. + Args: + reduction (str, optional): The method that reduces the loss to a + scalar. Options are "none", "mean" and "sum". + loss_weight (float, optional): The weight of the loss. Defaults to 1.0 + """ + + def __init__(self, reduction='mean', loss_weight=1.0): + super().__init__() + assert reduction in (None, 'none', 'mean', 'sum') + reduction = 'none' if reduction is None else reduction + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function of loss. + Args: + pred (torch.Tensor): The prediction. + target (torch.Tensor): The learning target of the prediction. + weight (torch.Tensor, optional): Weight of the loss for each + prediction. Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + Returns: + torch.Tensor: The calculated loss + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + loss = self.loss_weight * mse_loss( + pred, target, weight, reduction=reduction, avg_factor=avg_factor) + return loss \ No newline at end of file diff --git a/mogen/models/losses/utils.py b/mogen/models/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..417b6c54d5bfc48290d6c51183c47c18ba9b54b7 --- /dev/null +++ b/mogen/models/losses/utils.py @@ -0,0 +1,109 @@ +import functools + +import torch +import torch.nn.functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are "none", "mean" and "sum". + Return: + Tensor: Reduced loss tensor. 
+ """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + elif reduction_enum == 2: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): + """Apply element-wise weight and reduce loss. + Args: + loss (Tensor): Element-wise loss. + weight (Tensor): Element-wise weights. + reduction (str): Same as built-in losses of PyTorch. + avg_factor (float): Average factor when computing the mean of losses. + Returns: + Tensor: Processed loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + loss = loss * weight + + # if avg_factor is not specified, just reduce the loss + if avg_factor is None: + loss = reduce_loss(loss, reduction) + else: + # if reduction is mean, then average the loss by avg_factor + if reduction == 'mean': + loss = loss.sum() / avg_factor + # if reduction is 'none', then do nothing, otherwise raise an error + elif reduction != 'none': + raise ValueError('avg_factor can not be used with reduction="sum"') + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + avg_factor=None, **kwargs)`. + :Example: + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, avg_factor=2) + tensor(1.5000) + """ + + @functools.wraps(loss_func) + def wrapper(pred, + target, + weight=None, + reduction='mean', + avg_factor=None, + **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction, avg_factor) + return loss + + return wrapper + + +def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor: + """This function converts target class indices to one-hot vectors, given + the number of classes. + Args: + targets (Tensor): The ground truth label of the prediction + with shape (N, 1) + classes (int): the number of classes. + Returns: + Tensor: Processed loss values. 
+ """ + assert (torch.max(targets).item() < + classes), 'Class Index must be less than number of classes' + one_hot_targets = torch.zeros((targets.shape[0], classes), + dtype=torch.long, + device=targets.device) + one_hot_targets.scatter_(1, targets.long(), 1) + return one_hot_targets diff --git a/mogen/models/rnns/__init__.py b/mogen/models/rnns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b730b23d6538b262966d3ba3134a94964b588c33 --- /dev/null +++ b/mogen/models/rnns/__init__.py @@ -0,0 +1 @@ +from .t2m_bigru import T2MMotionEncoder, T2MTextEncoder \ No newline at end of file diff --git a/mogen/models/rnns/t2m_bigru.py b/mogen/models/rnns/t2m_bigru.py new file mode 100644 index 0000000000000000000000000000000000000000..d1df3be54e83a0a1b3ddcfbc6d923169648fb565 --- /dev/null +++ b/mogen/models/rnns/t2m_bigru.py @@ -0,0 +1,260 @@ +import torch +import torch.nn as nn +import numpy as np +import time +import math +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +import torch.nn.functional as F +from ..builder import SUBMODULES + +from mogen.models.utils.word_vectorizer import WordVectorizer + + +def init_weight(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.xavier_normal_(m.weight) + # m.bias.data.fill_(0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def reparameterize(mu, logvar): + s_var = logvar.mul(0.5).exp_() + eps = s_var.data.new(s_var.size()).normal_() + return eps.mul(s_var).add_(mu) + + +# batch_size, dimension and position +# output: (batch_size, dim) +def positional_encoding(batch_size, dim, pos): + assert batch_size == pos.shape[0] + positions_enc = np.array([ + [pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)] + for j in range(batch_size) + ], dtype=np.float32) + positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2]) + positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2]) + return torch.from_numpy(positions_enc).float() + + +def get_padding_mask(batch_size, seq_len, cap_lens): + cap_lens = cap_lens.data.tolist() + mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32) + for i, cap_len in enumerate(cap_lens): + mask_2d[i, :, :cap_len] = 0 + return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone() + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, max_len=300): + super(PositionalEncoding, self).__init__() + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + # pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, pos): + return self.pe[pos] + + +@SUBMODULES.register_module() +class T2MMotionEncoder(nn.Module): + + def __init__(self, + input_size, + movement_hidden_size, + movement_latent_size, + motion_hidden_size, + motion_latent_size): + super().__init__() + self.movement_encoder = MovementConvEncoder( + input_size=input_size-4, + hidden_size=movement_hidden_size, + output_size=movement_latent_size) + self.motion_encoder = MotionEncoderBiGRUCo( + input_size=movement_latent_size, + hidden_size=motion_hidden_size, + output_size=motion_latent_size + ) + + def load_pretrained(self, ckpt_path): + checkpoint = torch.load(ckpt_path, map_location='cpu') + 
self.movement_encoder.load_state_dict(checkpoint['movement_encoder']) + self.motion_encoder.load_state_dict(checkpoint['motion_encoder']) + + def forward(self, motion, motion_length, motion_mask): + motion = motion.detach().float() + sort_idx = np.argsort(motion_length.data.tolist())[::-1].copy() + rank_idx = np.empty_like(sort_idx) + rank_idx[sort_idx] = np.arange(len(motion_length)) + motion = motion[sort_idx] + motion_length = motion_length[sort_idx] + + movements = self.movement_encoder(motion[..., :-4]).detach() + m_lens = motion_length // 4 + motion_embedding = self.motion_encoder(movements, m_lens) + motion_embedding_ordered = motion_embedding[rank_idx] + return motion_embedding_ordered + + +@SUBMODULES.register_module() +class T2MTextEncoder(nn.Module): + + def __init__(self, + word_size, + pos_size, + hidden_size, + output_size, + max_text_len): + super().__init__() + self.text_encoder = TextEncoderBiGRUCo( + word_size=word_size, + pos_size=pos_size, + hidden_size=hidden_size, + output_size=output_size, + ) + self.w_vectorizer = WordVectorizer('./data/glove', 'our_vab') + self.max_text_len = max_text_len + + def load_pretrained(self, ckpt_path): + checkpoint = torch.load(ckpt_path, map_location='cpu') + self.text_encoder.load_state_dict(checkpoint['text_encoder']) + + def forward(self, text, token, device): + B = len(text) + pos_one_hot = [] + word_emb = [] + sent_len = [] + for i in range(B): + tokens = token[i].split(" ") + if len(tokens) < self.max_text_len: + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + batch_sent_len = len(tokens) + tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - batch_sent_len) + else: + tokens = tokens[: self.max_text_len] + tokens = ['sos/OTHER'] + tokens + ['eos/OTHER'] + batch_sent_len = len(tokens) + sent_len.append(batch_sent_len) + batch_word_emb = [] + batch_pos_one_hot = [] + for cur_token in tokens: + cur_word_emb, cur_pos_one_hot = self.w_vectorizer[cur_token] + cur_word_emb = torch.from_numpy(cur_word_emb).float() + cur_pos_one_hot = torch.from_numpy(cur_pos_one_hot).float() + batch_word_emb.append(cur_word_emb) + batch_pos_one_hot.append(cur_pos_one_hot) + + batch_word_emb = torch.stack(batch_word_emb, dim=0) + batch_pos_one_hot = torch.stack(batch_pos_one_hot, dim=0) + word_emb.append(batch_word_emb) + pos_one_hot.append(batch_pos_one_hot) + word_emb = torch.stack(word_emb, dim=0).to(device) + pos_one_hot = torch.stack(pos_one_hot, dim=0).to(device) + sent_len = torch.tensor(sent_len, dtype=torch.long).to(device) + text_embedding = self.text_encoder(word_emb, pos_one_hot, sent_len) + return text_embedding + + +class TextEncoderBiGRUCo(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, output_size): + super(TextEncoderBiGRUCo, self).__init__() + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + self.input_emb.apply(init_weight) + self.pos_emb.apply(init_weight) + self.output_net.apply(init_weight) + # self.linear2.apply(init_weight) + # self.batch_size = batch_size + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + num_samples = 
word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True, enforce_sorted=False) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + + +class MovementConvEncoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MovementConvEncoder, self).__init__() + self.main = nn.Sequential( + nn.Conv1d(input_size, hidden_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(hidden_size, output_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + self.out_net.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return self.out_net(outputs) + + +class MotionEncoderBiGRUCo(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MotionEncoderBiGRUCo, self).__init__() + + self.input_emb = nn.Linear(input_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size*2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + self.input_emb.apply(init_weight) + self.output_net.apply(init_weight) + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, inputs, m_lens): + num_samples = inputs.shape[0] + + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = m_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + diff --git a/mogen/models/transformers/__init__.py b/mogen/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17e7fc62dcfe533738a1c652137731d1c5eb7f82 --- /dev/null +++ b/mogen/models/transformers/__init__.py @@ -0,0 +1,5 @@ +from .actor import ACTOREncoder, ACTORDecoder +from .motiondiffuse import MotionDiffuseTransformer +from .remodiffuse import ReMoDiffuseTransformer +from .mdm import MDMTransformer +from .position_encoding import SinusoidalPositionalEncoding, LearnedPositionalEncoding \ No newline at end of file diff --git a/mogen/models/transformers/actor.py b/mogen/models/transformers/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..dee5a05cf5d460901ed0cddca04a9b9701cd89c4 --- /dev/null +++ b/mogen/models/transformers/actor.py @@ -0,0 +1,189 @@ +from cv2 import norm +import torch +from torch import layer_norm, nn +from mmcv.runner import BaseModule +import numpy as np + +from ..builder import SUBMODULES +from .position_encoding import SinusoidalPositionalEncoding, LearnedPositionalEncoding +import math + + +@SUBMODULES.register_module() +class ACTOREncoder(BaseModule): + def __init__(self, + max_seq_len=16, + njoints=None, + nfeats=None, + input_feats=None, + latent_dim=256, + 
output_dim=256, + condition_dim=None, + num_heads=4, + ff_size=1024, + num_layers=8, + activation='gelu', + dropout=0.1, + use_condition=False, + num_class=None, + use_final_proj=False, + output_var=False, + pos_embedding='sinusoidal', + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.njoints = njoints + self.nfeats = nfeats + if input_feats is None: + assert self.njoints is not None and self.nfeats is not None + self.input_feats = njoints * nfeats + else: + self.input_feats = input_feats + self.max_seq_len = max_seq_len + self.latent_dim = latent_dim + self.condition_dim = condition_dim + self.use_condition = use_condition + self.num_class = num_class + self.use_final_proj = use_final_proj + self.output_var = output_var + self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim) + if self.use_condition: + if num_class is None: + self.mu_layer = build_MLP(self.condition_dim, self.latent_dim) + if self.output_var: + self.sigma_layer = build_MLP(self.condition_dim, self.latent_dim) + else: + self.mu_layer = nn.Parameter(torch.randn(num_class, self.latent_dim)) + if self.output_var: + self.sigma_layer = nn.Parameter(torch.randn(num_class, self.latent_dim)) + else: + if self.output_var: + self.query = nn.Parameter(torch.randn(2, self.latent_dim)) + else: + self.query = nn.Parameter(torch.randn(1, self.latent_dim)) + if pos_embedding == 'sinusoidal': + self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout) + else: + self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len + 2) + seqTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=self.latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + self.seqTransEncoder = nn.TransformerEncoder( + seqTransEncoderLayer, + num_layers=num_layers) + + def forward(self, motion, motion_mask=None, condition=None): + B, T = motion.shape[:2] + motion = motion.view(B, T, -1) + feature = self.skelEmbedding(motion) + if self.use_condition: + if self.output_var: + if self.num_class is None: + sigma_query = self.sigma_layer(condition).view(B, 1, -1) + else: + sigma_query = self.sigma_layer[condition.long()].view(B, 1, -1) + feature = torch.cat((sigma_query, feature), dim=1) + if self.num_class is None: + mu_query = self.mu_layer(condition).view(B, 1, -1) + else: + mu_query = self.mu_layer[condition.long()].view(B, 1, -1) + feature = torch.cat((mu_query, feature), dim=1) + else: + query = self.query.view(1, -1, self.latent_dim).repeat(B, 1, 1) + feature = torch.cat((query, feature), dim=1) + if self.output_var: + motion_mask = torch.cat((torch.zeros(B, 2).to(motion.device), 1 - motion_mask), dim=1).bool() + else: + motion_mask = torch.cat((torch.zeros(B, 1).to(motion.device), 1 - motion_mask), dim=1).bool() + feature = feature.permute(1, 0, 2).contiguous() + feature = self.pos_encoder(feature) + feature = self.seqTransEncoder(feature, src_key_padding_mask=motion_mask) + if self.use_final_proj: + mu = self.final_mu(feature[0]) + if self.output_var: + sigma = self.final_sigma(feature[1]) + return mu, sigma + return mu + else: + if self.output_var: + return feature[0], feature[1] + else: + return feature[0] + + +@SUBMODULES.register_module() +class ACTORDecoder(BaseModule): + + def __init__(self, + max_seq_len=16, + njoints=None, + nfeats=None, + input_feats=None, + input_dim=256, + latent_dim=256, + condition_dim=None, + num_heads=4, + ff_size=1024, + num_layers=8, + activation='gelu', + dropout=0.1, + use_condition=False, + num_class=None, + 
pos_embedding='sinusoidal', + init_cfg=None): + super().__init__(init_cfg=init_cfg) + if input_dim != latent_dim: + self.linear = nn.Linear(input_dim, latent_dim) + else: + self.linear = nn.Identity() + self.njoints = njoints + self.nfeats = nfeats + if input_feats is None: + assert self.njoints is not None and self.nfeats is not None + self.input_feats = njoints * nfeats + else: + self.input_feats = input_feats + self.max_seq_len = max_seq_len + self.input_dim = input_dim + self.latent_dim = latent_dim + self.condition_dim = condition_dim + self.use_condition = use_condition + self.num_class = num_class + if self.use_condition: + if num_class is None: + self.condition_bias = build_MLP(condition_dim, latent_dim) + else: + self.condition_bias = nn.Parameter(torch.randn(num_class, latent_dim)) + if pos_embedding == 'sinusoidal': + self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout) + else: + self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len) + seqTransDecoderLayer = nn.TransformerDecoderLayer( + d_model=self.latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation) + self.seqTransDecoder = nn.TransformerDecoder( + seqTransDecoderLayer, + num_layers=num_layers) + + self.final = nn.Linear(self.latent_dim, self.input_feats) + + def forward(self, input, motion_mask=None, condition=None): + B = input.shape[0] + T = self.max_seq_len + input = self.linear(input) + if self.use_condition: + if self.num_class is None: + condition = self.condition_bias(condition) + else: + condition = self.condition_bias[condition.long()].squeeze(1) + input = input + condition + query = self.pos_encoder.pe[:T, :].view(T, 1, -1).repeat(1, B, 1) + input = input.view(1, B, -1) + feature = self.seqTransDecoder(tgt=query, memory=input, tgt_key_padding_mask=(1 - motion_mask).bool()) + pose = self.final(feature).permute(1, 0, 2).contiguous() + return pose diff --git a/mogen/models/transformers/diffusion_transformer.py b/mogen/models/transformers/diffusion_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9a78a960065add060c9543d6ff8f51d56de20c3e --- /dev/null +++ b/mogen/models/transformers/diffusion_transformer.py @@ -0,0 +1,251 @@ +from abc import ABCMeta, abstractmethod +from cv2 import norm +import torch +from torch import layer_norm, nn +import torch.nn.functional as F +from mmcv.runner import BaseModule +import numpy as np + +from ..builder import SUBMODULES, build_attention +from .position_encoding import SinusoidalPositionalEncoding, LearnedPositionalEncoding +from ..utils.stylization_block import StylizationBlock +import math +import clip + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def set_requires_grad(nets, requires_grad=False): + """Set requies_grad for all the networks. + + Args: + nets (nn.Module | list[nn.Module]): A list of networks or a single + network. + requires_grad (bool): Whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class FFN(nn.Module): + + def __init__(self, latent_dim, ffn_dim, dropout, time_embed_dim): + super().__init__() + self.linear1 = nn.Linear(latent_dim, ffn_dim) + self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) + self.activation = nn.GELU() + self.dropout = nn.Dropout(dropout) + self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout) + + def forward(self, x, emb, **kwargs): + y = self.linear2(self.dropout(self.activation(self.linear1(x)))) + y = x + self.proj_out(y, emb) + return y + + +class DecoderLayer(nn.Module): + + def __init__(self, + sa_block_cfg=None, + ca_block_cfg=None, + ffn_cfg=None): + super().__init__() + self.sa_block = build_attention(sa_block_cfg) + self.ca_block = build_attention(ca_block_cfg) + self.ffn = FFN(**ffn_cfg) + + def forward(self, **kwargs): + if self.sa_block is not None: + x = self.sa_block(**kwargs) + kwargs.update({'x': x}) + if self.ca_block is not None: + x = self.ca_block(**kwargs) + kwargs.update({'x': x}) + if self.ffn is not None: + x = self.ffn(**kwargs) + return x + + +class DiffusionTransformer(BaseModule, metaclass=ABCMeta): + def __init__(self, + input_feats, + max_seq_len=240, + latent_dim=512, + time_embed_dim=2048, + num_layers=8, + sa_block_cfg=None, + ca_block_cfg=None, + ffn_cfg=None, + text_encoder=None, + use_cache_for_text=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.input_feats = input_feats + self.max_seq_len = max_seq_len + self.latent_dim = latent_dim + self.num_layers = num_layers + self.time_embed_dim = time_embed_dim + self.sequence_embedding = nn.Parameter(torch.randn(max_seq_len, latent_dim)) + + self.use_cache_for_text = use_cache_for_text + if use_cache_for_text: + self.text_cache = {} + self.build_text_encoder(text_encoder) + + # Input Embedding + self.joint_embed = nn.Linear(self.input_feats, self.latent_dim) + + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, self.time_embed_dim), + nn.SiLU(), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + self.build_temporal_blocks(sa_block_cfg, ca_block_cfg, ffn_cfg) + + # Output Module + self.out = zero_module(nn.Linear(self.latent_dim, self.input_feats)) + + def build_temporal_blocks(self, sa_block_cfg, ca_block_cfg, ffn_cfg): + self.temporal_decoder_blocks = nn.ModuleList() + for i in range(self.num_layers): + self.temporal_decoder_blocks.append( + DecoderLayer( + sa_block_cfg=sa_block_cfg, + ca_block_cfg=ca_block_cfg, + ffn_cfg=ffn_cfg + ) + ) + + def build_text_encoder(self, text_encoder): + + text_latent_dim = 
text_encoder['latent_dim'] + num_text_layers = text_encoder.get('num_layers', 0) + text_ff_size = text_encoder.get('ff_size', 2048) + pretrained_model = text_encoder['pretrained_model'] + text_num_heads = text_encoder.get('num_heads', 4) + dropout = text_encoder.get('dropout', 0) + activation = text_encoder.get('activation', 'gelu') + self.use_text_proj = text_encoder.get('use_text_proj', False) + + if pretrained_model == 'clip': + self.clip, _ = clip.load('ViT-B/32', "cpu") + set_requires_grad(self.clip, False) + if text_latent_dim != 512: + self.text_pre_proj = nn.Linear(512, text_latent_dim) + else: + self.text_pre_proj = nn.Identity() + else: + raise NotImplementedError() + + if num_text_layers > 0: + self.use_text_finetune = True + textTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=text_latent_dim, + nhead=text_num_heads, + dim_feedforward=text_ff_size, + dropout=dropout, + activation=activation) + self.textTransEncoder = nn.TransformerEncoder( + textTransEncoderLayer, + num_layers=num_text_layers) + else: + self.use_text_finetune = False + self.text_ln = nn.LayerNorm(text_latent_dim) + if self.use_text_proj: + self.text_proj = nn.Sequential( + nn.Linear(text_latent_dim, self.time_embed_dim) + ) + + def encode_text(self, text, clip_feat, device): + B = len(text) + text = clip.tokenize(text, truncate=True).to(device) + if clip_feat is None: + with torch.no_grad(): + x = self.clip.token_embedding(text).type(self.clip.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.clip.positional_embedding.type(self.clip.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.clip.transformer(x) + x = self.clip.ln_final(x).type(self.clip.dtype) + else: + x = clip_feat.type(self.clip.dtype).to(device).permute(1, 0, 2) + + # T, B, D + x = self.text_pre_proj(x) + xf_out = self.textTransEncoder(x) + xf_out = self.text_ln(xf_out) + if self.use_text_proj: + xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])]) + # B, T, D + xf_out = xf_out.permute(1, 0, 2) + return xf_proj, xf_out + else: + xf_out = xf_out.permute(1, 0, 2) + return xf_out + + @abstractmethod + def get_precompute_condition(self, **kwargs): + pass + + @abstractmethod + def forward_train(self, h, src_mask, emb, **kwargs): + pass + + @abstractmethod + def forward_test(self, h, src_mask, emb, **kwargs): + pass + + def forward(self, motion, timesteps, motion_mask=None, **kwargs): + """ + motion: B, T, D + """ + B, T = motion.shape[0], motion.shape[1] + conditions = self.get_precompute_condition(device=motion.device, **kwargs) + if len(motion_mask.shape) == 2: + src_mask = motion_mask.clone().unsqueeze(-1) + else: + src_mask = motion_mask.clone() + + if self.use_text_proj: + emb = self.time_embed(timestep_embedding(timesteps, self.latent_dim)) + conditions['xf_proj'] + else: + emb = self.time_embed(timestep_embedding(timesteps, self.latent_dim)) + # B, T, latent_dim + h = self.joint_embed(motion) + h = h + self.sequence_embedding.unsqueeze(0)[:, :T, :] + + if self.training: + return self.forward_train(h=h, src_mask=src_mask, emb=emb, timesteps=timesteps, **conditions) + else: + return self.forward_test(h=h, src_mask=src_mask, emb=emb, timesteps=timesteps, **conditions) diff --git a/mogen/models/transformers/mdm.py b/mogen/models/transformers/mdm.py new file mode 100644 index 0000000000000000000000000000000000000000..4c2d64db7dccf4a5cf3ad1f9b3da4738a121e919 --- /dev/null +++ b/mogen/models/transformers/mdm.py @@ -0,0 +1,212 @@ +import numpy as np +import torch +import torch.nn as nn +import 
torch.nn.functional as F +import clip + +from ..builder import SUBMODULES + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp32""" + + def _convert_weights_to_fp32(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.float() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.float() + + model.apply(_convert_weights_to_fp32) + + +@SUBMODULES.register_module() +class MDMTransformer(nn.Module): + def __init__(self, + input_feats=263, + latent_dim=256, + ff_size=1024, + num_layers=8, + num_heads=4, + dropout=0.1, + activation="gelu", + clip_dim=512, + clip_version=None, + guide_scale=1.0, + cond_mask_prob=0.1, + use_official_ckpt=False, + **kwargs): + super().__init__() + + self.latent_dim = latent_dim + self.ff_size = ff_size + self.num_layers = num_layers + self.num_heads = num_heads + self.dropout = dropout + self.activation = activation + self.clip_dim = clip_dim + self.input_feats = input_feats + self.guide_scale = guide_scale + self.use_official_ckpt = use_official_ckpt + + self.cond_mask_prob = cond_mask_prob + self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim) + self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout) + + seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim, + nhead=self.num_heads, + dim_feedforward=self.ff_size, + dropout=self.dropout, + activation=self.activation) + + self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, + num_layers=self.num_layers) + + self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder) + + + self.embed_text = nn.Linear(self.clip_dim, self.latent_dim) + self.clip_version = clip_version + self.clip_model = self.load_and_freeze_clip(clip_version) + + self.poseFinal = nn.Linear(self.latent_dim, self.input_feats) + + + def load_and_freeze_clip(self, clip_version): + clip_model, clip_preprocess = clip.load(clip_version, device='cpu', + jit=False) # Must set jit=False for training + clip.model.convert_weights( + clip_model) # Actually this line is unnecessary since clip by default already on float16 + + clip_model.eval() + for p in clip_model.parameters(): + p.requires_grad = False + + return clip_model + + + def mask_cond(self, cond, force_mask=False): + bs, d = cond.shape + if force_mask: + return torch.zeros_like(cond) + elif self.training and self.cond_mask_prob > 0.: + mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond + return cond * (1. 
- mask) + else: + return cond + + def encode_text(self, raw_text): + # raw_text - list (batch_size length) of strings with input text prompts + device = next(self.parameters()).device + max_text_len = 20 + if max_text_len is not None: + default_context_length = 77 + context_length = max_text_len + 2 # start_token + 20 + end_token + assert context_length < default_context_length + texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device) + zero_pad = torch.zeros([texts.shape[0], default_context_length-context_length], dtype=texts.dtype, device=texts.device) + texts = torch.cat([texts, zero_pad], dim=1) + return self.clip_model.encode_text(texts).float() + + def get_precompute_condition(self, text, device=None, **kwargs): + if not self.training and device == torch.device('cpu'): + convert_weights(self.clip_model) + text_feat = self.encode_text(text) + return {'text_feat': text_feat} + + def post_process(self, motion): + assert len(motion.shape) == 3 + if self.use_official_ckpt: + motion[:, :, :4] = motion[:, :, :4] * 25 + return motion + + def forward(self, motion, timesteps, text_feat=None, **kwargs): + """ + motion: B, T, D + timesteps: [batch_size] (int) + """ + B, T, D = motion.shape + device = motion.device + if text_feat is None: + enc_text = get_precompute_condition(**kwargs)['text_feat'] + else: + enc_text = text_feat + if self.training: + # T, B, D + motion = self.poseEmbedding(motion).permute(1, 0, 2) + + emb = self.embed_timestep(timesteps) # [1, bs, d] + emb += self.embed_text(self.mask_cond(enc_text, force_mask=False)) + + xseq = self.sequence_pos_encoder(torch.cat((emb, motion), axis=0)) + output = self.seqTransEncoder(xseq)[1:] + + # B, T, D + output = self.poseFinal(output).permute(1, 0, 2) + return output + else: + # T, B, D + motion = self.poseEmbedding(motion).permute(1, 0, 2) + + emb = self.embed_timestep(timesteps) # [1, bs, d] + emb_uncond = emb + self.embed_text(self.mask_cond(enc_text, force_mask=True)) + emb_text = emb + self.embed_text(self.mask_cond(enc_text, force_mask=False)) + + xseq = self.sequence_pos_encoder(torch.cat((emb_uncond, motion), axis=0)) + xseq_text = self.sequence_pos_encoder(torch.cat((emb_text, motion), axis=0)) + output = self.seqTransEncoder(xseq)[1:] + output_text = self.seqTransEncoder(xseq_text)[1:] + # B, T, D + output = self.poseFinal(output).permute(1, 0, 2) + output_text = self.poseFinal(output_text).permute(1, 0, 2) + scale = self.guide_scale + output = output + scale * (output_text - output) + return output + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + + self.register_buffer('pe', pe) + + def forward(self, x): + # not used in the final model + x = x + self.pe[:x.shape[0], :] + return self.dropout(x) + + +class TimestepEmbedder(nn.Module): + def __init__(self, latent_dim, sequence_pos_encoder): + super().__init__() + self.latent_dim = latent_dim + self.sequence_pos_encoder = sequence_pos_encoder + + time_embed_dim = self.latent_dim + self.time_embed = nn.Sequential( + nn.Linear(self.latent_dim, time_embed_dim), + nn.SiLU(), + 
nn.Linear(time_embed_dim, time_embed_dim), + ) + + def forward(self, timesteps): + return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2) diff --git a/mogen/models/transformers/motiondiffuse.py b/mogen/models/transformers/motiondiffuse.py new file mode 100644 index 0000000000000000000000000000000000000000..3c0dad47f166d7a2f0f57337d3ef596e149ee92a --- /dev/null +++ b/mogen/models/transformers/motiondiffuse.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from ..builder import SUBMODULES +from .diffusion_transformer import DiffusionTransformer + + +@SUBMODULES.register_module() +class MotionDiffuseTransformer(DiffusionTransformer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_precompute_condition(self, + text=None, + xf_proj=None, + xf_out=None, + device=None, + clip_feat=None, + **kwargs): + if xf_proj is None or xf_out is None: + xf_proj, xf_out = self.encode_text(text, clip_feat, device) + return {'xf_proj': xf_proj, 'xf_out': xf_out} + + def forward_train(self, h=None, src_mask=None, emb=None, xf_out=None, **kwargs): + B, T = h.shape[0], h.shape[1] + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask) + output = self.out(h).view(B, T, -1).contiguous() + return output + + def forward_test(self, h=None, src_mask=None, emb=None, xf_out=None, **kwargs): + B, T = h.shape[0], h.shape[1] + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask) + output = self.out(h).view(B, T, -1).contiguous() + return output diff --git a/mogen/models/transformers/position_encoding.py b/mogen/models/transformers/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..efdd6e25e6e3289a67d47ba288282d06413d8a4a --- /dev/null +++ b/mogen/models/transformers/position_encoding.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +import numpy as np + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(SinusoidalPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.arange(0, d_model, 2).float() + div_term = div_term * (-np.log(10000.0) / d_model) + div_term = torch.exp(div_term) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + # T, 1, D + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:x.shape[0]] + return self.dropout(x) + + +class LearnedPositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(LearnedPositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + self.pe = nn.Parameter(torch.randn(max_len, 1, d_model)) + + def forward(self, x): + x = x + self.pe[:x.shape[0]] + return self.dropout(x) diff --git a/mogen/models/transformers/remodiffuse.py b/mogen/models/transformers/remodiffuse.py new file mode 100644 index 0000000000000000000000000000000000000000..c3865b4aef05731b5705913da9a0dffcbdc27c09 --- /dev/null +++ b/mogen/models/transformers/remodiffuse.py @@ -0,0 +1,361 @@ +from cv2 import norm +import torch +import torch.nn.functional as F +from torch import layer_norm, nn +import numpy as np +import clip +import random +import math + +from ..builder import SUBMODULES, build_attention 
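+# cond_type convention (used here and in the attention modules above): the
+# ones digit of cond_type enables the text condition (cond_type % 10 > 0) and
+# the tens digit enables the retrieval condition (cond_type // 10 > 0).
+# forward_test below runs the model with cond_type = 99, 1, 10 and 0 to get
+# the "both", "text-only", "retrieval-only" and unconditioned predictions
+# that scale_func then blends, classifier-free-guidance style.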
+from .diffusion_transformer import DiffusionTransformer + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def set_requires_grad(nets, requires_grad=False): + """Set requies_grad for all the networks. + + Args: + nets (nn.Module | list[nn.Module]): A list of networks or a single + network. + requires_grad (bool): Whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + +class FFN(nn.Module): + + def __init__(self, latent_dim, ffn_dim, dropout): + super().__init__() + self.linear1 = nn.Linear(latent_dim, ffn_dim) + self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) + self.activation = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x, **kwargs): + y = self.linear2(self.dropout(self.activation(self.linear1(x)))) + y = x + y + return y + + +class EncoderLayer(nn.Module): + + def __init__(self, + sa_block_cfg=None, + ca_block_cfg=None, + ffn_cfg=None): + super().__init__() + self.sa_block = build_attention(sa_block_cfg) + self.ffn = FFN(**ffn_cfg) + + def forward(self, **kwargs): + if self.sa_block is not None: + x = self.sa_block(**kwargs) + kwargs.update({'x': x}) + if self.ffn is not None: + x = self.ffn(**kwargs) + return x + + +class RetrievalDatabase(nn.Module): + + def __init__(self, + num_retrieval=None, + topk=None, + retrieval_file=None, + latent_dim=512, + output_dim=512, + num_layers=2, + num_motion_layers=4, + kinematic_coef=0.1, + max_seq_len=196, + num_heads=8, + ff_size=1024, + stride=4, + sa_block_cfg=None, + ffn_cfg=None, + dropout=0): + super().__init__() + self.num_retrieval = num_retrieval + self.topk = topk + self.latent_dim = latent_dim + self.stride = stride + self.kinematic_coef = kinematic_coef + self.num_layers = num_layers + self.num_motion_layers = num_motion_layers + self.max_seq_len = max_seq_len + data = np.load(retrieval_file) + self.text_features = torch.Tensor(data['text_features']) + self.captions = data['captions'] + self.motions = data['motions'] + self.m_lengths = data['m_lengths'] + self.clip_seq_features = data['clip_seq_features'] + self.train_indexes = data.get('train_indexes', None) + self.test_indexes = data.get('test_indexes', None) + + self.latent_dim = latent_dim + self.output_dim = output_dim + self.motion_proj = nn.Linear(self.motions.shape[-1], self.latent_dim) + self.motion_pos_embedding = nn.Parameter(torch.randn(max_seq_len, self.latent_dim)) + self.motion_encoder_blocks = nn.ModuleList() + for i in range(num_motion_layers): + self.motion_encoder_blocks.append( + EncoderLayer( + sa_block_cfg=sa_block_cfg, + ffn_cfg=ffn_cfg + ) + ) + TransEncoderLayer = nn.TransformerEncoderLayer( + d_model=self.latent_dim, + nhead=num_heads, + dim_feedforward=ff_size, + dropout=dropout, + activation="gelu") + self.text_encoder = nn.TransformerEncoder( + TransEncoderLayer, + num_layers=num_layers) + self.results = {} + + def extract_text_feature(self, text, clip_model, device): + text = clip.tokenize([text], truncate=True).to(device) + with torch.no_grad(): + text_features = clip_model.encode_text(text) + return text_features + + def encode_text(self, text, device): + with torch.no_grad(): + text = clip.tokenize(text, truncate=True).to(device) + x = self.clip.token_embedding(text).type(self.clip.dtype) # [batch_size, n_ctx, d_model] + + x = x + 
self.clip.positional_embedding.type(self.clip.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.clip.transformer(x) + x = self.clip.ln_final(x).type(self.clip.dtype) + + # B, T, D + xf_out = x.permute(1, 0, 2) + return xf_out + + def retrieve(self, caption, length, clip_model, device, idx=None): + if self.training and self.train_indexes is not None and idx is not None: + idx = idx.item() + indexes = self.train_indexes[idx] + data = [] + cnt = 0 + for retr_idx in indexes: + if retr_idx != idx: + data.append(retr_idx) + cnt += 1 + if cnt == self.topk: + break + random.shuffle(data) + return data[:self.num_retrieval] + + elif not self.training and self.test_indexes is not None and idx is not None: + idx = idx.item() + indexes = self.test_indexes[idx] + data = [] + cnt = 0 + for retr_idx in indexes: + data.append(retr_idx) + cnt += 1 + if cnt == self.topk: + break + # random.shuffle(data) + return data[:self.num_retrieval] + else: + value = hash(caption) + if value in self.results: + return self.results[value] + text_feature = self.extract_text_feature(caption, clip_model, device) + + rel_length = torch.LongTensor(self.m_lengths).to(device) + rel_length = torch.abs(rel_length - length) / torch.clamp(rel_length, min=length) + semantic_score = F.cosine_similarity(self.text_features.to(device), text_feature) + kinematic_score = torch.exp(-rel_length * self.kinematic_coef) + score = semantic_score * kinematic_score + indexes = torch.argsort(score, descending=True) + data = [] + cnt = 0 + for idx in indexes: + caption, motion, m_length = self.captions[idx], self.motions[idx], self.m_lengths[idx] + if not self.training or m_length != length: + cnt += 1 + data.append(idx.item()) + if cnt == self.num_retrieval: + self.results[value] = data + return data + assert False + + def generate_src_mask(self, T, length): + B = len(length) + src_mask = torch.ones(B, T) + for i in range(B): + for j in range(length[i], T): + src_mask[i, j] = 0 + return src_mask + + def forward(self, captions, lengths, clip_model, device, idx=None): + B = len(captions) + all_indexes = [] + for b_ix in range(B): + length = int(lengths[b_ix]) + if idx is None: + batch_indexes = self.retrieve(captions[b_ix], length, clip_model, device) + else: + batch_indexes = self.retrieve(captions[b_ix], length, clip_model, device, idx[b_ix]) + all_indexes.extend(batch_indexes) + all_indexes = np.array(all_indexes) + N = all_indexes.shape[0] + all_motions = torch.Tensor(self.motions[all_indexes]).to(device) + all_m_lengths = torch.Tensor(self.m_lengths[all_indexes]).long() + all_captions = self.captions[all_indexes].tolist() + + T = all_motions.shape[1] + src_mask = self.generate_src_mask(T, all_m_lengths).to(device) + raw_src_mask = src_mask.clone() + re_motion = self.motion_proj(all_motions) + self.motion_pos_embedding.unsqueeze(0) + for module in self.motion_encoder_blocks: + re_motion = module(x=re_motion, src_mask=src_mask.unsqueeze(-1)) + re_motion = re_motion.view(B, self.num_retrieval, T, -1).contiguous() + # stride + re_motion = re_motion[:, :, ::self.stride, :].contiguous() + + src_mask = src_mask[:, ::self.stride].contiguous() + src_mask = src_mask.view(B, self.num_retrieval, -1).contiguous() + + T = 77 + all_text_seq_features = torch.Tensor(self.clip_seq_features[all_indexes]).to(device) + all_text_seq_features = all_text_seq_features.permute(1, 0, 2) + re_text = self.text_encoder(all_text_seq_features) + re_text = re_text.permute(1, 0, 2).view(B, self.num_retrieval, T, -1).contiguous() + re_text = re_text[:, :, -1:, 
:].contiguous() + + # T = re_motion.shape[2] + # re_feat = re_feat.view(B, self.num_retrieval * T, -1).contiguous() + re_dict = dict( + re_text=re_text, + re_motion=re_motion, + re_mask=src_mask, + raw_motion=all_motions, + raw_motion_length=all_m_lengths, + raw_motion_mask=raw_src_mask) + return re_dict + + +@SUBMODULES.register_module() +class ReMoDiffuseTransformer(DiffusionTransformer): + def __init__(self, + retrieval_cfg=None, + scale_func_cfg=None, + **kwargs): + super().__init__(**kwargs) + self.database = RetrievalDatabase(**retrieval_cfg) + self.scale_func_cfg = scale_func_cfg + + def scale_func(self, timestep): + coarse_scale = self.scale_func_cfg['coarse_scale'] + w = (1 - (1000 - timestep) / 1000) * coarse_scale + 1 + if timestep > 100: + if random.randint(0, 1) == 0: + output = { + 'both_coef': w, + 'text_coef': 0, + 'retr_coef': 1 - w, + 'none_coef': 0 + } + else: + output = { + 'both_coef': 0, + 'text_coef': w, + 'retr_coef': 0, + 'none_coef': 1 - w + } + else: + both_coef = self.scale_func_cfg['both_coef'] + text_coef = self.scale_func_cfg['text_coef'] + retr_coef = self.scale_func_cfg['retr_coef'] + none_coef = 1 - both_coef - text_coef - retr_coef + output = { + 'both_coef': both_coef, + 'text_coef': text_coef, + 'retr_coef': retr_coef, + 'none_coef': none_coef + } + return output + + def get_precompute_condition(self, + text=None, + motion_length=None, + xf_out=None, + re_dict=None, + device=None, + sample_idx=None, + clip_feat=None, + **kwargs): + if xf_out is None: + xf_out = self.encode_text(text, clip_feat, device) + output = {'xf_out': xf_out} + if re_dict is None: + re_dict = self.database(text, motion_length, self.clip, device, idx=sample_idx) + output['re_dict'] = re_dict + return output + + def post_process(self, motion): + return motion + + def forward_train(self, h=None, src_mask=None, emb=None, xf_out=None, re_dict=None, **kwargs): + B, T = h.shape[0], h.shape[1] + cond_type = torch.randint(0, 100, size=(B, 1, 1)).to(h.device) + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=cond_type, re_dict=re_dict) + + output = self.out(h).view(B, T, -1).contiguous() + return output + + def forward_test(self, h=None, src_mask=None, emb=None, xf_out=None, re_dict=None, timesteps=None, **kwargs): + B, T = h.shape[0], h.shape[1] + both_cond_type = torch.zeros(B, 1, 1).to(h.device) + 99 + text_cond_type = torch.zeros(B, 1, 1).to(h.device) + 1 + retr_cond_type = torch.zeros(B, 1, 1).to(h.device) + 10 + none_cond_type = torch.zeros(B, 1, 1).to(h.device) + + all_cond_type = torch.cat(( + both_cond_type, text_cond_type, retr_cond_type, none_cond_type + ), dim=0) + h = h.repeat(4, 1, 1) + xf_out = xf_out.repeat(4, 1, 1) + emb = emb.repeat(4, 1) + src_mask = src_mask.repeat(4, 1, 1) + if re_dict['re_motion'].shape[0] != h.shape[0]: + re_dict['re_motion'] = re_dict['re_motion'].repeat(4, 1, 1, 1) + re_dict['re_text'] = re_dict['re_text'].repeat(4, 1, 1, 1) + re_dict['re_mask'] = re_dict['re_mask'].repeat(4, 1, 1) + for module in self.temporal_decoder_blocks: + h = module(x=h, xf=xf_out, emb=emb, src_mask=src_mask, cond_type=all_cond_type, re_dict=re_dict) + out = self.out(h).view(4 * B, T, -1).contiguous() + out_both = out[:B].contiguous() + out_text = out[B: 2 * B].contiguous() + out_retr = out[2 * B: 3 * B].contiguous() + out_none = out[3 * B:].contiguous() + + coef_cfg = self.scale_func(int(timesteps[0])) + both_coef = coef_cfg['both_coef'] + text_coef = coef_cfg['text_coef'] + retr_coef = coef_cfg['retr_coef'] + 
none_coef = coef_cfg['none_coef'] + output = out_both * both_coef + out_text * text_coef + out_retr * retr_coef + out_none * none_coef + return output \ No newline at end of file diff --git a/mogen/models/utils/__init__.py b/mogen/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mogen/models/utils/gaussian_diffusion.py b/mogen/models/utils/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..69fdb85c5bb179e2ef8c06b87b777709abc7f108 --- /dev/null +++ b/mogen/models/utils/gaussian_diffusion.py @@ -0,0 +1,1266 @@ +""" +This code is borrowed from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py +""" + +import enum +import math + +import numpy as np +import torch as th + + +from abc import ABC, abstractmethod +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. 
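+ Per-rank batches are padded to the largest batch size so that all_gather works even when ranks hold different numbers of timesteps.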
+ """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). 
+ logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. 
+ """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. 
+ self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. 
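+ When the variance is learned, the model output has 2 * C channels and is split into the prediction and the variance values.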
+ """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, 2 * C, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the 
gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs + ) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + pre_seq=None, + transl_req=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + # concat seq + if pre_seq is not None: + T = pre_seq.shape[1] + noise = th.randn_like(pre_seq) + x_t = self.q_sample(pre_seq, t, noise=noise) + x[:, :T, :] = x_t + + if transl_req is not None: + for item in transl_req: + noise = th.randn(2).type_as(x) + transl = th.Tensor(item[1:]).type_as(x) + x_t = self.q_sample(transl, t, noise=noise) + x[:, :2, item[0]] = x_t + + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + pre_seq=None, + transl_req=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. 
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + pre_seq=pre_seq, + transl_req=transl_req, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + pre_seq=None, + transl_req=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + pre_seq=pre_seq, + transl_req=transl_req + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + pre_seq=None, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + if pre_seq is not None: + T = pre_seq.shape[1] + noise = th.randn_like(pre_seq) + x_t = self.q_sample(pre_seq, t, noise=noise) + x[:, :T, :] = x_t + + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. 
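+ # x_{t-1} = sqrt(alpha_bar_prev) * x0_pred + sqrt(1 - alpha_bar_prev - sigma^2) * eps, plus sigma * noise everywhere except t == 0.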
+ noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + pre_seq=None, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + pre_seq=None, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + pre_seq=pre_seq, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + pre_seq=None, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + pre_seq=pre_seq, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. 
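+ Nats are converted to bits by dividing by log(2).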
+ """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C * 2, *x_t.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2).view(-1, 1).mean(-1) + # if "vb" in terms: + # terms["loss"] = terms["mse"] + terms["vb"] + # else: + # terms["loss"] = terms["mse"] + terms["target"] = target + terms["pred"] = model_output + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. 
+ + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. 
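+ In that example the call returns 10 + 15 + 20 = 45 steps in total (assuming no rounding collisions).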
+ + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") + elif section_counts == "fast27": + # steps = space_timesteps(num_timesteps, "10,10,3,2,2") + # steps = space_timesteps(num_timesteps, "30,30,16,12,12") + steps = space_timesteps(num_timesteps, "15,15,8,6,6") + + # Help reduce DDIM artifacts from noisiest timesteps. + steps.remove(num_timesteps - 1) + steps.add(num_timesteps - 3) + return steps + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError(f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. 
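+ For example, use_timesteps=space_timesteps(1000, [50]) respaces a 1000-step process down to 50 sampling steps.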
+ """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance(self, model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.original_num_steps) + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + return self.model(x, new_ts, **kwargs) diff --git a/mogen/models/utils/mlp.py b/mogen/models/utils/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..731117e481e987eb2c397332b88a60088591e7ec --- /dev/null +++ b/mogen/models/utils/mlp.py @@ -0,0 +1,13 @@ +import torch.nn as nn + + +def build_MLP(dim_list, latent_dim): + model_list = [] + prev = dim_list[0] + for cur in dim_list[1:]: + model_list.append(nn.Linear(prev, cur)) + model_list.append(nn.GELU()) + prev = cur + model_list.append(nn.Linear(prev, latent_dim)) + model = nn.Sequential(*model_list) + return model diff --git a/mogen/models/utils/stylization_block.py b/mogen/models/utils/stylization_block.py new file mode 100644 index 0000000000000000000000000000000000000000..a221a5d8122cb6e145377b705dbd32410dc6bc94 --- /dev/null +++ b/mogen/models/utils/stylization_block.py @@ -0,0 +1,40 @@ +import torch +import torch.nn as nn + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class StylizationBlock(nn.Module): + + def __init__(self, latent_dim, time_embed_dim, dropout): + super().__init__() + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(time_embed_dim, 2 * latent_dim), + ) + self.norm = nn.LayerNorm(latent_dim) + self.out_layers = nn.Sequential( + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Linear(latent_dim, latent_dim)), + ) + + def forward(self, h, emb): + """ + h: B, T, D + emb: B, D + """ + # B, 1, 2D + emb_out = self.emb_layers(emb).unsqueeze(1) + # scale: B, 1, D / shift: B, 1, D + scale, shift = torch.chunk(emb_out, 2, dim=2) + h = self.norm(h) * (1 + scale) + shift + h = self.out_layers(h) + return h diff --git a/mogen/models/utils/word_vectorizer.py b/mogen/models/utils/word_vectorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..68c5956ff39f840d03c9a352e65291d26e2dfbd4 --- /dev/null +++ b/mogen/models/utils/word_vectorizer.py @@ -0,0 +1,80 @@ +import numpy as np +import pickle +from os.path import join as pjoin + +POS_enumerator = { + 'VERB': 0, + 'NOUN': 1, + 'DET': 2, + 'ADP': 3, + 'NUM': 4, + 'AUX': 5, + 'PRON': 6, + 'ADJ': 7, + 'ADV': 8, + 'Loc_VIP': 9, + 'Body_VIP': 10, + 'Obj_VIP': 11, + 'Act_VIP': 12, + 'Desc_VIP': 13, + 'OTHER': 14, +} + +Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward', + 'up', 'down', 'straight', 'curve') + +Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh') + +Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball') + +Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn', + 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll', + 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb') + +Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily', + 'angrily', 'sadly') + +VIP_dict = { + 'Loc_VIP': Loc_list, + 'Body_VIP': Body_list, + 'Obj_VIP': Obj_List, + 'Act_VIP': Act_list, + 'Desc_VIP': Desc_list, +} + + +class WordVectorizer(object): + def __init__(self, meta_root, prefix): + vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix)) + words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb')) + word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb')) + self.word2vec = {w: vectors[word2idx[w]] for w in words} + + def _get_pos_ohot(self, pos): + pos_vec = np.zeros(len(POS_enumerator)) + if pos in POS_enumerator: + pos_vec[POS_enumerator[pos]] = 1 + else: + pos_vec[POS_enumerator['OTHER']] = 1 + return pos_vec + + def __len__(self): + return len(self.word2vec) + + def __getitem__(self, item): + word, pos = item.split('/') + if word in self.word2vec: + word_vec = self.word2vec[word] + vip_pos = None + for key, values in VIP_dict.items(): + if word in values: + vip_pos = key + break + if vip_pos is not None: + pos_vec = self._get_pos_ohot(vip_pos) + else: + pos_vec = self._get_pos_ohot(pos) + else: + word_vec = self.word2vec['unk'] + pos_vec = self._get_pos_ohot('OTHER') + return word_vec, pos_vec \ No newline at end of file diff --git a/mogen/utils/__init__.py b/mogen/utils/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..468f74e2713bdf539806f0f6e6ce33adaf753536 --- /dev/null +++ b/mogen/utils/__init__.py @@ -0,0 +1,18 @@ +from mogen.utils.collect_env import collect_env +from mogen.utils.dist_utils import DistOptimizerHook, allreduce_grads +from mogen.utils.logger import get_root_logger +from mogen.utils.misc import multi_apply, torch_to_numpy +from mogen.utils.path_utils import ( + Existence, + check_input_path, + check_path_existence, + check_path_suffix, + prepare_output_path, +) + + +__all__ = [ + 'collect_env', 'DistOptimizerHook', 'allreduce_grads', 'get_root_logger', + 'multi_apply', 'torch_to_numpy', 'Existence', 'check_input_path', + 'check_path_existence', 'check_path_suffix', 'prepare_output_path' +] \ No newline at end of file diff --git a/mogen/utils/collect_env.py b/mogen/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..5c27d6966a30b526ec32cf8ad70d33cdb0ed1b72 --- /dev/null +++ b/mogen/utils/collect_env.py @@ -0,0 +1,16 @@ +from mmcv.utils import collect_env as collect_base_env +from mmcv.utils import get_git_hash + +import mogen + + +def collect_env(): + """Collect the information of the running environments.""" + env_info = collect_base_env() + env_info['mogen'] = mogen.__version__ + '+' + get_git_hash()[:7] + return env_info + + +if __name__ == '__main__': + for name, val in collect_env().items(): + print(f'{name}: {val}') \ No newline at end of file diff --git a/mogen/utils/dist_utils.py b/mogen/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe098600d64d43f25181d4ff7ecf1d8f98b02a5 --- /dev/null +++ b/mogen/utils/dist_utils.py @@ -0,0 +1,59 @@ +from collections import OrderedDict + +import torch.distributed as dist +from mmcv.runner import OptimizerHook +from torch._utils import ( + _flatten_dense_tensors, + _take_tensors, + _unflatten_dense_tensors, +) + + +def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): + if bucket_size_mb > 0: + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + buckets = _take_tensors(tensors, bucket_size_bytes) + else: + buckets = OrderedDict() + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + buckets = buckets.values() + + for bucket in buckets: + flat_tensors = _flatten_dense_tensors(bucket) + dist.all_reduce(flat_tensors) + flat_tensors.div_(world_size) + for tensor, synced in zip( + bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + tensor.copy_(synced) + + +def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): + grads = [ + param.grad.data for param in params + if param.requires_grad and param.grad is not None + ] + world_size = dist.get_world_size() + if coalesce: + _allreduce_coalesced(grads, world_size, bucket_size_mb) + else: + for tensor in grads: + dist.all_reduce(tensor.div_(world_size)) + + +class DistOptimizerHook(OptimizerHook): + + def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + + def after_train_iter(self, runner): + runner.optimizer.zero_grad() + runner.outputs['loss'].backward() + if self.grad_clip is not None: + self.clip_grads(runner.model.parameters()) + runner.optimizer.step() \ No newline at end of file diff --git a/mogen/utils/logger.py b/mogen/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..a7bf049be988836c393969a43f5d86ab1560efd6 --- /dev/null +++ 
b/mogen/utils/logger.py @@ -0,0 +1,7 @@ +import logging + +from mmcv.utils import get_logger + + +def get_root_logger(log_file=None, log_level=logging.INFO): + return get_logger('mogen', log_file, log_level) \ No newline at end of file diff --git a/mogen/utils/misc.py b/mogen/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf11972abb6e64e9d5ca93bb2cf1518a6a11b08 --- /dev/null +++ b/mogen/utils/misc.py @@ -0,0 +1,14 @@ +from functools import partial + +import torch + + +def multi_apply(func, *args, **kwargs): + pfunc = partial(func, **kwargs) if kwargs else func + map_results = map(pfunc, *args) + return tuple(map(list, zip(*map_results))) + + +def torch_to_numpy(x): + assert isinstance(x, torch.Tensor) + return x.detach().cpu().numpy() \ No newline at end of file diff --git a/mogen/utils/path_utils.py b/mogen/utils/path_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..903425a2ae931c5cc684e76ac3c65759120a28aa --- /dev/null +++ b/mogen/utils/path_utils.py @@ -0,0 +1,232 @@ +import os +import warnings +from enum import Enum +from pathlib import Path +from typing import List, Union + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + + +def check_path_suffix(path_str: str, + allowed_suffix: Union[str, List[str]] = '') -> bool: + """Check whether the suffix of the path is allowed. + + Args: + path_str (str): + Path to check. + allowed_suffix (List[str], optional): + What extension names are allowed. + Offer a list like ['.jpg', ',jpeg']. + When it's [], all will be received. + Use [''] then directory is allowed. + Defaults to []. + + Returns: + bool: + True: suffix test passed + False: suffix test failed + """ + if isinstance(allowed_suffix, str): + allowed_suffix = [allowed_suffix] + pathinfo = Path(path_str) + suffix = pathinfo.suffix.lower() + if len(allowed_suffix) == 0: + return True + if pathinfo.is_dir(): + if '' in allowed_suffix: + return True + else: + return False + else: + for index, tmp_suffix in enumerate(allowed_suffix): + if not tmp_suffix.startswith('.'): + tmp_suffix = '.' + tmp_suffix + allowed_suffix[index] = tmp_suffix.lower() + if suffix in allowed_suffix: + return True + else: + return False + + +class Existence(Enum): + """State of file existence.""" + FileExist = 0 + DirectoryExistEmpty = 1 + DirectoryExistNotEmpty = 2 + MissingParent = 3 + DirectoryNotExist = 4 + FileNotExist = 5 + + +def check_path_existence( + path_str: str, + path_type: Literal['file', 'dir', 'auto'] = 'auto', +) -> Existence: + """Check whether a file or a directory exists at the expected path. + + Args: + path_str (str): + Path to check. + path_type (Literal[, optional): + What kind of file do we expect at the path. + Choose among `file`, `dir`, `auto`. + Defaults to 'auto'. path_type = path_type.lower() + + Raises: + KeyError: if `path_type` conflicts with `path_str` + + Returns: + Existence: + 0. FileExist: file at path_str exists. + 1. DirectoryExistEmpty: folder at path exists and. + 2. DirectoryExistNotEmpty: folder at path_str exists and not empty. + 3. MissingParent: its parent doesn't exist. + 4. DirectoryNotExist: expect a folder at path_str, but not found. + 5. FileNotExist: expect a file at path_str, but not found. 
+ """ + path_type = path_type.lower() + assert path_type in {'file', 'dir', 'auto'} + pathinfo = Path(path_str) + if not pathinfo.parent.is_dir(): + return Existence.MissingParent + suffix = pathinfo.suffix.lower() + if path_type == 'dir' or\ + path_type == 'auto' and suffix == '': + if pathinfo.is_dir(): + if len(os.listdir(path_str)) == 0: + return Existence.DirectoryExistEmpty + else: + return Existence.DirectoryExistNotEmpty + else: + return Existence.DirectoryNotExist + elif path_type == 'file' or\ + path_type == 'auto' and suffix != '': + if pathinfo.is_file(): + return Existence.FileExist + elif pathinfo.is_dir(): + if len(os.listdir(path_str)) == 0: + return Existence.DirectoryExistEmpty + else: + return Existence.DirectoryExistNotEmpty + if path_str.endswith('/'): + return Existence.DirectoryNotExist + else: + return Existence.FileNotExist + + +def prepare_output_path(output_path: str, + allowed_suffix: List[str] = [], + tag: str = 'output file', + path_type: Literal['file', 'dir', 'auto'] = 'auto', + overwrite: bool = True) -> None: + """Check output folder or file. + + Args: + output_path (str): could be folder or file. + allowed_suffix (List[str], optional): + Check the suffix of `output_path`. If folder, should be [] or ['']. + If could both be folder or file, should be [suffixs..., '']. + Defaults to []. + tag (str, optional): The `string` tag to specify the output type. + Defaults to 'output file'. + path_type (Literal[, optional): + Choose `file` for file and `dir` for folder. + Choose `auto` if allowed to be both. + Defaults to 'auto'. + overwrite (bool, optional): + Whether overwrite the existing file or folder. + Defaults to True. + + Raises: + FileNotFoundError: suffix does not match. + FileExistsError: file or folder already exists and `overwrite` is + False. + + Returns: + None + """ + if path_type.lower() == 'dir': + allowed_suffix = [] + exist_result = check_path_existence(output_path, path_type=path_type) + if exist_result == Existence.MissingParent: + warnings.warn( + f'The parent folder of {tag} does not exist: {output_path},' + + f' will make dir {Path(output_path).parent.absolute().__str__()}') + os.makedirs( + Path(output_path).parent.absolute().__str__(), exist_ok=True) + + elif exist_result == Existence.DirectoryNotExist: + os.mkdir(output_path) + print(f'Making directory {output_path} for saving results.') + elif exist_result == Existence.FileNotExist: + suffix_matched = \ + check_path_suffix(output_path, allowed_suffix=allowed_suffix) + if not suffix_matched: + raise FileNotFoundError( + f'The {tag} should be {", ".join(allowed_suffix)}: ' + f'{output_path}.') + elif exist_result == Existence.FileExist: + if not overwrite: + raise FileExistsError( + f'{output_path} exists (set overwrite = True to overwrite).') + else: + print(f'Overwriting {output_path}.') + elif exist_result == Existence.DirectoryExistEmpty: + pass + elif exist_result == Existence.DirectoryExistNotEmpty: + if not overwrite: + raise FileExistsError( + f'{output_path} is not empty (set overwrite = ' + 'True to overwrite the files).') + else: + print(f'Overwriting {output_path} and its files.') + else: + raise FileNotFoundError(f'No Existence type for {output_path}.') + + +def check_input_path( + input_path: str, + allowed_suffix: List[str] = [], + tag: str = 'input file', + path_type: Literal['file', 'dir', 'auto'] = 'auto', +): + """Check input folder or file. + + Args: + input_path (str): input folder or file path. 
+ allowed_suffix (List[str], optional): + Check the suffix of `input_path`. If folder, should be [] or ['']. + If could both be folder or file, should be [suffixs..., '']. + Defaults to []. + tag (str, optional): The `string` tag to specify the output type. + Defaults to 'output file'. + path_type (Literal[, optional): + Choose `file` for file and `directory` for folder. + Choose `auto` if allowed to be both. + Defaults to 'auto'. + + Raises: + FileNotFoundError: file does not exists or suffix does not match. + + Returns: + None + """ + if path_type.lower() == 'dir': + allowed_suffix = [] + exist_result = check_path_existence(input_path, path_type=path_type) + + if exist_result in [ + Existence.FileExist, Existence.DirectoryExistEmpty, + Existence.DirectoryExistNotEmpty + ]: + suffix_matched = \ + check_path_suffix(input_path, allowed_suffix=allowed_suffix) + if not suffix_matched: + raise FileNotFoundError( + f'The {tag} should be {", ".join(allowed_suffix)}:' + + f'{input_path}.') + else: + raise FileNotFoundError(f'The {tag} does not exist: {input_path}.') \ No newline at end of file diff --git a/mogen/utils/plot_utils.py b/mogen/utils/plot_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..580c442f752a8a0ff85b5014c43cf4ad3fbd53d2 --- /dev/null +++ b/mogen/utils/plot_utils.py @@ -0,0 +1,216 @@ +""" +This code is borrowed from https://github.com/EricGuo5513/text-to-motion +""" + +import torch +import numpy as np + +import math +import matplotlib +import matplotlib.pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +from matplotlib.animation import FuncAnimation, FFMpegFileWriter +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import mpl_toolkits.mplot3d.axes3d as p3 + +# Define a kinematic tree for the skeletal struture +kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]] + +kit_raw_offsets = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, 1] + ] +) + +t2m_raw_offsets = np.array([[0,0,0], + [1,0,0], + [-1,0,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,-1,0], + [0,-1,0], + [0,1,0], + [0,0,1], + [0,0,1], + [0,1,0], + [1,0,0], + [-1,0,0], + [0,0,1], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0], + [0,-1,0]]) + +t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]] +t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]] +t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]] + + +def qinv(q): + assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)' + mask = torch.ones_like(q) + mask[..., 1:] = -mask[..., 1:] + return q * mask + + +def qrot(q, v): + """ + Rotate vector(s) v about the rotation described by quaternion(s) q. + Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, + where * denotes any number of dimensions. + Returns a tensor of shape (*, 3). 
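+ Assumes q is a unit quaternion in (w, x, y, z) order and uses the identity v' = v + 2 * w * (q_vec x v) + 2 * q_vec x (q_vec x v).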
+ """ + assert q.shape[-1] == 4 + assert v.shape[-1] == 3 + assert q.shape[:-1] == v.shape[:-1] + + original_shape = list(v.shape) + # print(q.shape) + q = q.contiguous().view(-1, 4) + v = v.contiguous().view(-1, 3) + + qvec = q[:, 1:] + uv = torch.cross(qvec, v, dim=1) + uuv = torch.cross(qvec, uv, dim=1) + return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) + + +def recover_root_rot_pos(data): + rot_vel = data[..., 0] + r_rot_ang = torch.zeros_like(rot_vel).to(data.device) + '''Get Y-axis rotation from rotation velocity''' + r_rot_ang[..., 1:] = rot_vel[..., :-1] + r_rot_ang = torch.cumsum(r_rot_ang, dim=-1) + + r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device) + r_rot_quat[..., 0] = torch.cos(r_rot_ang) + r_rot_quat[..., 2] = torch.sin(r_rot_ang) + + r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device) + r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3] + '''Add Y-axis rotation to root position''' + r_pos = qrot(qinv(r_rot_quat), r_pos) + + r_pos = torch.cumsum(r_pos, dim=-2) + + r_pos[..., 1] = data[..., 3] + return r_rot_quat, r_pos + + +def recover_from_ric(data, joints_num): + r_rot_quat, r_pos = recover_root_rot_pos(data) + positions = data[..., 4:(joints_num - 1) * 3 + 4] + positions = positions.view(positions.shape[:-1] + (-1, 3)) + + '''Add Y-axis rotation to local joints''' + positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions) + + '''Add root XZ to joints''' + positions[..., 0] += r_pos[..., 0:1] + positions[..., 2] += r_pos[..., 2:3] + + '''Concate root and joints''' + positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2) + + return positions + + +def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4): + matplotlib.use('Agg') + + title_sp = title.split(' ') + if len(title_sp) > 20: + title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])]) + elif len(title_sp) > 10: + title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])]) + + def init(): + ax.set_xlim3d([-radius / 4, radius / 4]) + ax.set_ylim3d([0, radius / 2]) + ax.set_zlim3d([0, radius / 2]) + fig.suptitle(title, fontsize=20) + ax.grid(b=False) + + def plot_xzPlane(minx, maxx, miny, minz, maxz): + verts = [ + [minx, miny, minz], + [minx, miny, maxz], + [maxx, miny, maxz], + [maxx, miny, minz] + ] + xz_plane = Poly3DCollection([verts]) + xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5)) + ax.add_collection3d(xz_plane) + + # (seq_len, joints_num, 3) + data = joints.copy().reshape(len(joints), -1, 3) + fig = plt.figure(figsize=figsize) + ax = p3.Axes3D(fig) + init() + MINS = data.min(axis=0).min(axis=0) + MAXS = data.max(axis=0).max(axis=0) + colors = ['red', 'blue', 'black', 'red', 'blue', + 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue', + 'darkred', 'darkred', 'darkred', 'darkred', 'darkred'] + frame_number = data.shape[0] + + height_offset = MINS[1] + data[:, :, 1] -= height_offset + trajec = data[:, 0, [0, 2]] + + data[..., 0] -= data[:, 0:1, 0] + data[..., 2] -= data[:, 0:1, 2] + + def update(index): + ax.lines = [] + ax.collections = [] + ax.view_init(elev=120, azim=-90) + ax.dist = 7.5 + plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1], + MAXS[2] - trajec[index, 1]) + + if index > 1: + ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]), + trajec[:index, 1] - trajec[index, 1], linewidth=1.0, + color='blue') + + for i, (chain, color) in 
enumerate(zip(kinematic_tree, colors)): + if i < 5: + linewidth = 4.0 + else: + linewidth = 2.0 + ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth, + color=color) + + plt.axis('off') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_zticklabels([]) + + ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False) + ani.save(save_path, fps=fps) + plt.close() diff --git a/mogen/version.py b/mogen/version.py new file mode 100644 index 0000000000000000000000000000000000000000..12cc21f96506948cb046d25424cf0a3a9a8ef40f --- /dev/null +++ b/mogen/version.py @@ -0,0 +1,25 @@ +__version__ = '0.0.1' + + +def parse_version_info(version_str): + """Parse a version string into a tuple. + Args: + version_str (str): The version string. + Returns: + tuple[int | str]: The version info, e.g., "1.3.0" is parsed into + (1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). + """ + version_info = [] + for x in version_str.split('.'): + if x.isdigit(): + version_info.append(int(x)) + elif x.find('rc') != -1: + patch_version = x.split('rc') + version_info.append(int(patch_version[0])) + version_info.append(f'rc{patch_version[1]}') + return tuple(version_info) + + +version_info = parse_version_info(__version__) + +__all__ = ['__version__', 'version_info', 'parse_version_info'] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..caea879a2f3849b676a76a73534479832c026233 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +torch==1.12.1 +mmcv==1.9.0 +gradio==3.44.1 +matplotlib==3.3.1 +ftfy +regex +tqdm +scipy +git+https://github.com/openai/CLIP.git \ No newline at end of file diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..acf03ae0690a6d0a31ac8c01ffc5959e1df97dd2 --- /dev/null +++ b/tools/slurm_test.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +# Copyright (c) OpenMMLab. All rights reserved. + +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +WORK_DIR=$4 +CHECKPOINT=$5 +GPUS=1 +GPUS_PER_NODE=$((${GPUS}<8?${GPUS}:8)) +CPUS_PER_TASK=${CPUS_PER_TASK:-2} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:6} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + ${SRUN_ARGS} \ + python -u tools/test.py ${CONFIG} --work-dir=${WORK_DIR} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} \ No newline at end of file diff --git a/tools/slurm_train.sh b/tools/slurm_train.sh new file mode 100644 index 0000000000000000000000000000000000000000..1ba4d37c15828d01f1e4090b79ba976c2647ac55 --- /dev/null +++ b/tools/slurm_train.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +# Copyright (c) OpenMMLab. All rights reserved. 
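+# Usage sketch (all arguments below are placeholders; adapt them to your cluster): +#   sh tools/slurm_train.sh <PARTITION> <JOB_NAME> <CONFIG> <WORK_DIR> <GPUS> [extra args forwarded to tools/train.py]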
+ +set -x + +PARTITION=$1 +JOB_NAME=$2 +CONFIG=$3 +WORK_DIR=$4 +GPUS=$5 +GPUS_PER_NODE=$((${GPUS}<8?${GPUS}:8)) +CPUS_PER_TASK=${CPUS_PER_TASK:-2} +SRUN_ARGS=${SRUN_ARGS:-""} +PY_ARGS=${@:6} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +srun -p ${PARTITION} \ + --job-name=${JOB_NAME} \ + --gres=gpu:${GPUS_PER_NODE} \ + --ntasks=${GPUS} \ + --ntasks-per-node=${GPUS_PER_NODE} \ + --cpus-per-task=${CPUS_PER_TASK} \ + --kill-on-bad-exit=1 \ + -w SG-IDC2-10-51-5-49 \ + ${SRUN_ARGS} \ + python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} \ No newline at end of file diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8baf3d5cf77d376fcc955def019ca6f80725668e --- /dev/null +++ b/tools/test.py @@ -0,0 +1,119 @@ +import argparse +import os +import os.path as osp + +import mmcv +import torch +from mmcv import DictAction +from mmcv.parallel import MMDataParallel, MMDistributedDataParallel +from mmcv.runner import ( + get_dist_info, + init_dist, + load_checkpoint, + wrap_fp16_model, +) + +from mogen.apis import multi_gpu_test, single_gpu_test +from mogen.datasets import build_dataloader, build_dataset +from mogen.models import build_architecture + + +def parse_args(): + parser = argparse.ArgumentParser(description='mogen evaluation') + parser.add_argument('config', help='test config file path') + parser.add_argument( + '--work-dir', help='the dir to save evaluation results') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--out', help='output result file') + parser.add_argument( + '--gpu_collect', + action='store_true', + help='whether to use gpu to collect results') + parser.add_argument('--tmpdir', help='tmp dir for writing some results') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--device', + choices=['cpu', 'cuda'], + default='cuda', + help='device used for testing') + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def main(): + args = parse_args() + + cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + cfg.data.test.test_mode = True + + # init distributed env first, since logger depends on the dist info. 
+ if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + + # build the dataloader + dataset = build_dataset(cfg.data.test) + # the extra round_up data will be removed during gpu/cpu collect + data_loader = build_dataloader( + dataset, + samples_per_gpu=cfg.data.samples_per_gpu, + workers_per_gpu=cfg.data.workers_per_gpu, + dist=distributed, + shuffle=False, + round_up=False) + + # build the model and load checkpoint + model = build_architecture(cfg.model) + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + wrap_fp16_model(model) + load_checkpoint(model, args.checkpoint, map_location='cpu') + + if not distributed: + if args.device == 'cpu': + model = model.cpu() + else: + model = MMDataParallel(model, device_ids=[0]) + outputs = single_gpu_test(model, data_loader) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False) + outputs = multi_gpu_test(model, data_loader, args.tmpdir, + args.gpu_collect) + + rank, _ = get_dist_info() + if rank == 0: + mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) + results = dataset.evaluate(outputs, args.work_dir) + for k, v in results.items(): + print(f'\n{k} : {v:.4f}') + + if args.out and rank == 0: + print(f'\nwriting results to {args.out}') + mmcv.dump(results, args.out) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tools/train.py b/tools/train.py new file mode 100644 index 0000000000000000000000000000000000000000..7036d3e018ed1a968afcbb7b623cf2d69d6db7f4 --- /dev/null +++ b/tools/train.py @@ -0,0 +1,149 @@ +import argparse +import copy +import os +import os.path as osp +import time + +import mmcv +import torch +from mmcv import Config, DictAction +from mmcv.runner import get_dist_info, init_dist + +from mogen import __version__ +from mogen.apis import set_random_seed, train_model +from mogen.datasets import build_dataset +from mogen.models import build_architecture +from mogen.utils import collect_env, get_root_logger + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a model') + parser.add_argument('config', help='train config file path') + parser.add_argument('--work-dir', help='the dir to save logs and models') + parser.add_argument( + '--resume-from', help='the checkpoint file to resume from') + parser.add_argument( + '--no-validate', + action='store_true', + help='whether not to evaluate the checkpoint during training') + group_gpus = parser.add_mutually_exclusive_group() + group_gpus.add_argument('--device', help='device used for training') + group_gpus.add_argument( + '--gpus', + type=int, + help='number of gpus to use ' + '(only applicable to non-distributed training)') + group_gpus.add_argument( + '--gpu-ids', + type=int, + nargs='+', + help='ids of gpus to use ' + '(only applicable to non-distributed training)') + parser.add_argument('--seed', type=int, default=None, help='random seed') + parser.add_argument( + '--deterministic', + action='store_true', + help='whether to set deterministic options for CUDNN backend.') + parser.add_argument( + '--options', nargs='+', action=DictAction, help='arguments in dict') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + 
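+# Example invocations (a sketch; the config path and work dir are placeholders, not files shipped in this diff): +#   single machine: python tools/train.py path/to/config.py --work-dir work_dirs/example --seed 42 --deterministic +#   slurm cluster: tools/slurm_train.sh wraps this script and passes --launcher="slurm".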
+def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.options is not None: + cfg.merge_from_dict(args.options) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + if args.resume_from is not None: + cfg.resume_from = args.resume_from + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids + else: + cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + + # create work_dir + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # dump config + cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(cfg.work_dir, f'{timestamp}.log') + logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + if args.seed is not None: + logger.info(f'Set random seed to {args.seed}, ' + f'deterministic: {args.deterministic}') + set_random_seed(args.seed, deterministic=args.deterministic) + cfg.seed = args.seed + meta['seed'] = args.seed + + model = build_architecture(cfg.model) + model.init_weights() + + datasets = [build_dataset(cfg.data.train)] + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + val_dataset.pipeline = cfg.data.train.pipeline + datasets.append(build_dataset(val_dataset)) + # add an attribute for visualization convenience + train_model( + model, + datasets, + cfg, + distributed=distributed, + validate=(not args.no_validate), + timestamp=timestamp, + device='cpu' if args.device == 'cpu' else 'cuda', + meta=meta) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/tools/visualize.py b/tools/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..bf2a77b6398cf8cbdd48f9bf250bfb3a40444c27 --- /dev/null +++ b/tools/visualize.py @@ -0,0 +1,123 @@ +import argparse +import os +import os.path as osp +import mmcv +import numpy as np +import torch +from mogen.models import build_architecture +from mmcv.runner import init_dist, load_checkpoint, wrap_fp16_model +from mmcv.parallel import MMDataParallel +from mogen.utils.plot_utils import ( + recover_from_ric, + plot_3d_motion, + t2m_kinematic_chain +) +from scipy.ndimage import gaussian_filter + + +def motion_temporal_filter(motion, sigma=1): + motion = 
motion.reshape(motion.shape[0], -1) + for i in range(motion.shape[1]): + motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest") + return motion.reshape(motion.shape[0], -1, 3) + + +def plot_t2m(data, result_path, npy_path, caption): + joint = recover_from_ric(torch.from_numpy(data).float(), 22).numpy() + joint = motion_temporal_filter(joint, sigma=2.5) + plot_3d_motion(result_path, t2m_kinematic_chain, joint, title=caption, fps=20) + if npy_path is not None: + np.save(npy_path, joint) + + +def parse_args(): + parser = argparse.ArgumentParser(description='mogen evaluation') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--text', help='motion description') + parser.add_argument('--motion_length', type=int, help='expected motion length') + parser.add_argument('--out', help='output animation file') + parser.add_argument('--pose_npy', help='output pose sequence file', default=None) + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--device', + choices=['cpu', 'cuda'], + default='cuda', + help='device used for testing') + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + return args + + +def main(): + args = parse_args() + + cfg = mmcv.Config.fromfile(args.config) + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + cfg.data.test.test_mode = True + + # init distributed env first, since logger depends on the dist info. + if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + + assert args.motion_length >= 16 and args.motion_length <= 196 + + # build the model and load checkpoint + model = build_architecture(cfg.model) + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + wrap_fp16_model(model) + load_checkpoint(model, args.checkpoint, map_location='cpu') + + if args.device == 'cpu': + model = model.cpu() + else: + model = MMDataParallel(model, device_ids=[0]) + model.eval() + + dataset_name = cfg.data.test.dataset_name + assert dataset_name == "human_ml3d" + mean_path = "data/datasets/human_ml3d/mean.npy" + std_path = "data/datasets/human_ml3d/std.npy" + mean = np.load(mean_path) + std = np.load(std_path) + + device = args.device + text = args.text + motion_length = args.motion_length + motion = torch.zeros(1, motion_length, 263).to(device) + motion_mask = torch.ones(1, motion_length).to(device) + motion_length = torch.Tensor([motion_length]).long().to(device) + model = model.to(device) + + input = { + 'motion': motion, + 'motion_mask': motion_mask, + 'motion_length': motion_length, + 'motion_metas': [{'text': text}], + } + + all_pred_motion = [] + with torch.no_grad(): + input['inference_kwargs'] = {} + output_list = [] + output = model(**input)[0]['pred_motion'] + pred_motion = output.cpu().detach().numpy() + pred_motion = pred_motion * std + mean + + plot_t2m(pred_motion, args.out, args.pose_npy, text) + + +if __name__ == '__main__': + main() \ No newline at end of file