wxDai committed on
Commit 6b1e9f7
0 Parent(s):
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +10 -0
  3. LICENSE +25 -0
  4. README.md +14 -0
  5. app.py +258 -0
  6. configs/mld_control.yaml +105 -0
  7. configs/mld_t2m_infer.yaml +72 -0
  8. configs/modules/denoiser.yaml +16 -0
  9. configs/modules/motion_vae.yaml +13 -0
  10. configs/modules/scheduler.yaml +20 -0
  11. configs/modules/text_encoder.yaml +5 -0
  12. configs/modules/traj_encoder.yaml +12 -0
  13. configs/modules_mld/denoiser.yaml +16 -0
  14. configs/modules_mld/motion_vae.yaml +13 -0
  15. configs/modules_mld/scheduler.yaml +23 -0
  16. configs/modules_mld/text_encoder.yaml +5 -0
  17. configs/modules_mld/traj_encoder.yaml +12 -0
  18. configs/motionlcm_control.yaml +105 -0
  19. configs/motionlcm_t2m.yaml +100 -0
  20. demo.py +154 -0
  21. fit.py +134 -0
  22. mld/__init__.py +0 -0
  23. mld/config.py +47 -0
  24. mld/data/HumanML3D.py +79 -0
  25. mld/data/Kit.py +79 -0
  26. mld/data/__init__.py +0 -0
  27. mld/data/base.py +65 -0
  28. mld/data/get_data.py +93 -0
  29. mld/data/humanml/__init__.py +0 -0
  30. mld/data/humanml/common/quaternion.py +29 -0
  31. mld/data/humanml/dataset.py +290 -0
  32. mld/data/humanml/scripts/motion_process.py +51 -0
  33. mld/data/humanml/utils/__init__.py +0 -0
  34. mld/data/humanml/utils/paramUtil.py +62 -0
  35. mld/data/humanml/utils/plot_script.py +98 -0
  36. mld/data/humanml/utils/word_vectorizer.py +82 -0
  37. mld/data/utils.py +38 -0
  38. mld/launch/__init__.py +0 -0
  39. mld/launch/blender.py +23 -0
  40. mld/models/__init__.py +0 -0
  41. mld/models/architectures/__init__.py +0 -0
  42. mld/models/architectures/mld_clip.py +72 -0
  43. mld/models/architectures/mld_denoiser.py +172 -0
  44. mld/models/architectures/mld_traj_encoder.py +78 -0
  45. mld/models/architectures/mld_vae.py +154 -0
  46. mld/models/architectures/t2m_motionenc.py +58 -0
  47. mld/models/architectures/t2m_textenc.py +43 -0
  48. mld/models/architectures/tools/embeddings.py +89 -0
  49. mld/models/metrics/__init__.py +3 -0
  50. mld/models/metrics/cm.py +55 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
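These patterns route large binaries (checkpoints, numpy arrays, archives) through Git LFS. As a quick illustrative check, not part of the commit, Python's fnmatch can tell you which of the simple glob patterns above a given filename would hit; the sample filenames below are only examples.

import fnmatch

# A subset of the patterns from the .gitattributes above.
LFS_PATTERNS = ["*.ckpt", "*.pt", "*.pth", "*.npy", "*.zip", "*tfevents*"]

def tracked_by_lfs(filename: str) -> bool:
    # fnmatch mirrors the plain glob patterns such as "*.pt"; it does not
    # model path-specific rules like saved_model/**/*.
    return any(fnmatch.fnmatch(filename, p) for p in LFS_PATTERNS)

for name in ["mld_humanml.ckpt", "Mean.npy", "app.py"]:
    print(name, "->", "LFS" if tracked_by_lfs(name) else "plain git")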
.gitignore ADDED
@@ -0,0 +1,10 @@
1
+ **/*.pyc
2
+ .idea/
3
+ __pycache__/
4
+
5
+ deps/
6
+ datasets/
7
+ experiments_t2m/
8
+ experiments_t2m_test/
9
+ experiments_control/
10
+ experiments_control_test/
LICENSE ADDED
@@ -0,0 +1,25 @@
1
+ Copyright Tsinghua University and Shanghai AI Laboratory. All Rights Reserved.
2
+
3
+ License for Non-commercial Scientific Research Purposes.
4
+
5
+ For more information see <https://github.com/Dai-Wenxun/MotionLCM>.
6
+ If you use this software, please cite the corresponding publications
7
+ listed on the above website.
8
+
9
+ Permission to use, copy, modify, and distribute this software and its
10
+ documentation for educational, research, and non-profit purposes only.
11
+ Any modification based on this work must be open-source and prohibited
12
+ for commercial, pornographic, military, or surveillance use.
13
+
14
+ The authors grant you a non-exclusive, worldwide, non-transferable,
15
+ non-sublicensable, revocable, royalty-free, and limited license under
16
+ our copyright interests to reproduce, distribute, and create derivative
17
+ works of the text, videos, and codes solely for your non-commercial
18
+ research purposes.
19
+
20
+ You must retain, in the source form of any derivative works that you
21
+ distribute, all copyright, patent, trademark, and attribution notices
22
+ from the source form of this work.
23
+
24
+ For commercial uses of this software, please send email to all people
25
+ in the author list.
README.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ title: MotionLCM
3
+ emoji: 🏃
4
+ colorFrom: yellow
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.24.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ python_version: 3.10.12
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,258 @@
1
+ import os
2
+ import time
3
+ import random
4
+ import datetime
5
+ import os.path as osp
6
+ from functools import partial
7
+
8
+ import torch
9
+ import gradio as gr
10
+ from omegaconf import OmegaConf
11
+
12
+ from mld.config import get_module_config
13
+ from mld.data.get_data import get_datasets
14
+ from mld.models.modeltype.mld import MLD
15
+ from mld.utils.utils import set_seed
16
+ from mld.data.humanml.utils.plot_script import plot_3d_motion
17
+
18
+ WEBSITE = """
19
+ <div class="embed_hidden">
20
+ <h1 style='text-align: center'> MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model </h1>
21
+
22
+ <h2 style='text-align: center'>
23
+ <a href="https://github.com/Dai-Wenxun/" target="_blank"><nobr>Wenxun Dai</nobr><sup>1</sup></a> &emsp;
24
+ <a href="https://lhchen.top/" target="_blank"><nobr>Ling-Hao Chen</nobr></a><sup>1</sup> &emsp;
25
+ <a href="https://wangjingbo1219.github.io/" target="_blank"><nobr>Jingbo Wang</nobr></a><sup>2</sup> &emsp;
26
+ <a href="https://moonsliu.github.io/" target="_blank"><nobr>Jinpeng Liu</nobr></a><sup>1</sup> &emsp;
27
+ <a href="https://daibo.info/" target="_blank"><nobr>Bo Dai</nobr></a><sup>2</sup> &emsp;
28
+ <a href="https://andytang15.github.io/" target="_blank"><nobr>Yansong Tang</nobr></a><sup>1</sup>
29
+ </h2>
30
+
31
+ <h2 style='text-align: center'>
32
+ <nobr><sup>1</sup>Tsinghua University</nobr> &emsp;
33
+ <nobr><sup>2</sup>Shanghai AI Laboratory</nobr>
34
+ </h2>
35
+
36
+ </div>
37
+ """
38
+
39
+ WEBSITE_bottom = """
40
+ <div class="embed_hidden">
41
+ <p>
42
+ Space adapted from <a href="https://huggingface.co/spaces/Mathux/TMR" target="_blank">TMR</a>
43
+ and <a href="https://huggingface.co/spaces/MeYourHint/MoMask" target="_blank">MoMask</a>.
44
+ </p>
45
+ </div>
46
+ """
47
+
48
+ EXAMPLES = [
49
+ "a person does a jump",
50
+ "a person waves both arms in the air.",
51
+ "The person takes 4 steps backwards.",
52
+ "this person bends forward as if to bow.",
53
+ "The person was pushed but did not fall.",
54
+ "a man walks forward in a snake like pattern.",
55
+ "a man paces back and forth along the same line.",
56
+ "with arms out to the sides a person walks forward",
57
+ "A man bends down and picks something up with his right hand.",
58
+ "The man walked forward, spun right on one foot and walked back to his original position.",
59
+ "a person slightly bent over with right hand pressing against the air walks forward slowly"
60
+ ]
61
+
62
+ CSS = """
63
+ .contour_video {
64
+ display: flex;
65
+ flex-direction: column;
66
+ justify-content: center;
67
+ align-items: center;
68
+ z-index: var(--layer-5);
69
+ border-radius: var(--block-radius);
70
+ background: var(--background-fill-primary);
71
+ padding: 0 var(--size-6);
72
+ max-height: var(--size-screen-h);
73
+ overflow: hidden;
74
+ }
75
+ """
76
+
77
+ if not os.path.exists("./experiments_t2m/"):
78
+ os.system("bash prepare/download_pretrained_models.sh")
79
+ if not os.path.exists('./deps/glove/'):
80
+ os.system("bash prepare/download_glove.sh")
81
+ if not os.path.exists('./deps/sentence-t5-large/'):
82
+ os.system("bash prepare/prepare_t5.sh")
83
+ if not os.path.exists('./deps/t2m/'):
84
+ os.system("bash prepare/download_t2m_evaluators.sh")
85
+ if not os.path.exists('./datasets/humanml3d/'):
86
+ os.system("bash prepare/prepare_tiny_humanml3d.sh")
87
+
88
+ DEFAULT_TEXT = "A person is "
89
+ MAX_VIDEOS = 12
90
+ T2M_CFG = "./configs/motionlcm_t2m.yaml"
91
+
92
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
93
+
94
+ cfg = OmegaConf.load(T2M_CFG)
95
+ cfg_model = get_module_config(cfg.model, cfg.model.target)
96
+ cfg = OmegaConf.merge(cfg, cfg_model)
97
+ set_seed(1949)
98
+
99
+ name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
100
+ output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
101
+ vis_dir = osp.join(output_dir, 'samples')
102
+ os.makedirs(output_dir, exist_ok=False)
103
+ os.makedirs(vis_dir, exist_ok=False)
104
+
105
+ state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
106
+ print("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
107
+
108
+ lcm_key = 'denoiser.time_embedding.cond_proj.weight'
109
+ is_lcm = False
110
+ if lcm_key in state_dict:
111
+ is_lcm = True
112
+ time_cond_proj_dim = state_dict[lcm_key].shape[1]
113
+ cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
114
+ print(f'Is LCM: {is_lcm}')
115
+
116
+ cfg.model.is_controlnet = False
117
+
118
+ datasets = get_datasets(cfg, phase="test")[0]
119
+ model = MLD(cfg, datasets)
120
+ model.to(device)
121
+ model.eval()
122
+ model.load_state_dict(state_dict)
123
+
124
+
125
+ @torch.no_grad()
126
+ def generate(text, motion_len, num_videos):
127
+ batch = {"text": [text] * num_videos, "length": [motion_len] * num_videos}
128
+
129
+ s = time.time()
130
+ joints, _ = model(batch)
131
+ runtime = round(time.time() - s, 3)
132
+ runtime_info = f'Generated {len(joints)} motions, runtime: {runtime}s, device: {device}'
133
+ path = []
134
+ for i in range(num_videos):
135
+ uid = random.randrange(999999999)
136
+ video_path = osp.join(vis_dir, f"sample_{uid}.mp4")
137
+ plot_3d_motion(video_path, joints[i].detach().cpu().numpy(), '', fps=20)
138
+ path.append(video_path)
139
+ return path, runtime_info
140
+
141
+
142
+ # HTML component
143
+ def get_video_html(path, video_id, width=700, height=700):
144
+ video_html = f"""
145
+ <video class="contour_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
146
+ autoplay loop disablepictureinpicture id="{video_id}">
147
+ <source src="file/{path}" type="video/mp4">
148
+ Your browser does not support the video tag.
149
+ </video>
150
+ """
151
+ return video_html
152
+
153
+
154
+ def generate_component(generate_function, text, motion_len, num_videos):
155
+ if text == DEFAULT_TEXT or text == "" or text is None:
156
+ return [None for _ in range(MAX_VIDEOS)] + [None]
157
+
158
+ motion_len = max(36, min(int(float(motion_len) * 20), 196))
159
+ paths, info = generate_function(text, motion_len, num_videos)
160
+ htmls = [get_video_html(path, idx) for idx, path in enumerate(paths)]
161
+ htmls = htmls + [None for _ in range(max(0, MAX_VIDEOS - num_videos))]
162
+ return htmls + [info]
163
+
164
+
165
+ theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray")
166
+ generate_and_show = partial(generate_component, generate)
167
+
168
+ with gr.Blocks(css=CSS, theme=theme) as demo:
169
+ gr.HTML(WEBSITE)
170
+ videos = []
171
+
172
+ with gr.Row():
173
+ with gr.Column(scale=3):
174
+ text = gr.Textbox(
175
+ show_label=True,
176
+ label="Text prompt",
177
+ value=DEFAULT_TEXT,
178
+ )
179
+
180
+ with gr.Row():
181
+ with gr.Column(scale=1):
182
+ motion_len = gr.Textbox(
183
+ show_label=True,
184
+ label="Motion length (in seconds, <=9.8s)",
185
+ value=5,
186
+ info="Any length exceeding 9.8s will be restricted to 9.8s.",
187
+ )
188
+ with gr.Column(scale=1):
189
+ num_videos = gr.Radio(
190
+ [1, 4, 8, 12],
191
+ label="Videos",
192
+ value=8,
193
+ info="Number of videos to generate.",
194
+ )
195
+
196
+ gen_btn = gr.Button("Generate", variant="primary")
197
+ clear = gr.Button("Clear", variant="secondary")
198
+
199
+ results = gr.Textbox(show_label=True,
200
+ label='Inference info (runtime and device)',
201
+ info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.',
202
+ interactive=False)
203
+
204
+ with gr.Column(scale=2):
205
+ def generate_example(text, motion_len, num_videos):
206
+ return generate_and_show(text, motion_len, num_videos)
207
+
208
+ examples = gr.Examples(
209
+ examples=[[x, None, None] for x in EXAMPLES],
210
+ inputs=[text, motion_len, num_videos],
211
+ examples_per_page=12,
212
+ run_on_click=False,
213
+ cache_examples=False,
214
+ fn=generate_example,
215
+ outputs=[],
216
+ )
217
+
218
+ for _ in range(3):
219
+ with gr.Row():
220
+ for _ in range(4):
221
+ video = gr.HTML()
222
+ videos.append(video)
223
+
224
+ # gr.HTML(WEBSITE_bottom)
225
+ # connect the examples to the output
226
+ # a bit hacky
227
+ examples.outputs = videos
228
+
229
+ def load_example(example_id):
230
+ processed_example = examples.non_none_processed_examples[example_id]
231
+ return gr.utils.resolve_singleton(processed_example)
232
+
233
+ examples.dataset.click(
234
+ load_example,
235
+ inputs=[examples.dataset],
236
+ outputs=examples.inputs_with_examples, # type: ignore
237
+ show_progress=False,
238
+ postprocess=False,
239
+ queue=False,
240
+ ).then(fn=generate_example, inputs=examples.inputs, outputs=videos + [results])
241
+
242
+ gen_btn.click(
243
+ fn=generate_and_show,
244
+ inputs=[text, motion_len, num_videos],
245
+ outputs=videos + [results],
246
+ )
247
+ text.submit(
248
+ fn=generate_and_show,
249
+ inputs=[text, motion_len, num_videos],
250
+ outputs=videos + [results],
251
+ )
252
+
253
+ def clear_videos():
254
+ return [None for _ in range(MAX_VIDEOS)] + [DEFAULT_TEXT] + [None]
255
+
256
+ clear.click(fn=clear_videos, outputs=videos + [text] + [results])
257
+
258
+ demo.launch()
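Outside of Gradio, the loaded model can be driven directly with the same batch format used by generate() above. A minimal sketch, assuming the cfg/checkpoint setup from app.py has already run; the 20 fps frame rate and the [36, 196] frame clamp mirror generate_component().

# Sketch only: programmatic use of the model loaded above.
seconds = 5.0
motion_len = max(36, min(int(seconds * 20), 196))  # 20 fps, clamped to [36, 196] frames

batch = {"text": ["a person walks forward"], "length": [motion_len]}
with torch.no_grad():
    joints, _ = model(batch)
# For HumanML3D each element should be a (motion_len, 22, 3) joint sequence.
print(joints[0].shape)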
configs/mld_control.yaml ADDED
@@ -0,0 +1,105 @@
1
+ FOLDER: './experiments_control'
2
+ TEST_FOLDER: './experiments_control_test'
3
+
4
+ NAME: 'mld_humanml'
5
+
6
+ TRAIN:
7
+ DATASETS: ['humanml3d']
8
+ BATCH_SIZE: 128
9
+ SPLIT: 'train'
10
+ NUM_WORKERS: 8
11
+ PERSISTENT_WORKERS: true
12
+ SEED_VALUE: 1234
13
+ PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
14
+
15
+ validation_steps: -1
16
+ validation_epochs: 50
17
+ checkpointing_steps: -1
18
+ checkpointing_epochs: 50
19
+ max_train_steps: -1
20
+ max_train_epochs: 1000
21
+ learning_rate: 1e-4
22
+ learning_rate_spatial: 1e-4
23
+ lr_scheduler: "cosine"
24
+ lr_warmup_steps: 1000
25
+ adam_beta1: 0.9
26
+ adam_beta2: 0.999
27
+ adam_weight_decay: 0.0
28
+ adam_epsilon: 1e-08
29
+ max_grad_norm: 1.0
30
+
31
+ EVAL:
32
+ DATASETS: ['humanml3d']
33
+ BATCH_SIZE: 32
34
+ SPLIT: 'test'
35
+ NUM_WORKERS: 12
36
+
37
+ TEST:
38
+ DATASETS: ['humanml3d']
39
+ BATCH_SIZE: 1
40
+ SPLIT: 'test'
41
+ NUM_WORKERS: 12
42
+
43
+ CHECKPOINTS: 'experiments_control/mld_humanml/mld_humanml.ckpt'
44
+
45
+ # Testing Args
46
+ REPLICATION_TIMES: 1
47
+ MM_NUM_SAMPLES: 100
48
+ MM_NUM_REPEATS: 30
49
+ MM_NUM_TIMES: 10
50
+ DIVERSITY_TIMES: 300
51
+ MAX_NUM_SAMPLES: 1024
52
+
53
+ DATASET:
54
+ SMPL_PATH: './deps/smpl'
55
+ WORD_VERTILIZER_PATH: './deps/glove/'
56
+ HUMANML3D:
57
+ PICK_ONE_TEXT: true
58
+ FRAME_RATE: 20.0
59
+ UNIT_LEN: 4
60
+ ROOT: './datasets/humanml3d'
61
+ SPLIT_ROOT: './datasets/humanml3d'
62
+ SAMPLER:
63
+ MAX_LEN: 196
64
+ MIN_LEN: 40
65
+ MAX_TEXT_LEN: 20
66
+
67
+ METRIC:
68
+ DIST_SYNC_ON_STEP: true
69
+ TYPE: ['TM2TMetrics', 'ControlMetrics']
70
+
71
+ model:
72
+ target: 'modules_mld'
73
+ latent_dim: [1, 256]
74
+ guidance_scale: 7.5
75
+ guidance_uncondp: 0.1
76
+
77
+ # ControlNet Args
78
+ is_controlnet: true
79
+ is_controlnet_temporal: false
80
+ training_control_joint: [0]
81
+ testing_control_joint: [0]
82
+ training_density: 'random'
83
+ testing_density: 100
84
+ control_scale: 1.0
85
+ vaeloss: true
86
+ vaeloss_type: 'sum'
87
+ cond_ratio: 1.0
88
+ rot_ratio: 0.0
89
+
90
+ t2m_textencoder:
91
+ dim_word: 300
92
+ dim_pos_ohot: 15
93
+ dim_text_hidden: 512
94
+ dim_coemb_hidden: 512
95
+
96
+ t2m_motionencoder:
97
+ dim_move_hidden: 512
98
+ dim_move_latent: 512
99
+ dim_motion_hidden: 1024
100
+ dim_motion_latent: 512
101
+
102
+ bert_path: './deps/distilbert-base-uncased'
103
+ clip_path: './deps/clip-vit-large-patch14'
104
+ t5_path: './deps/sentence-t5-large'
105
+ t2m_path: './deps/t2m/'
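The top-level YAMLs like this one only carry experiment-level settings; the model.target field ('modules_mld' here) names the folder of per-module YAMLs that get merged in at load time. A minimal sketch of that merge, mirroring the calls already used in app.py and mld/config.py:

from omegaconf import OmegaConf
from mld.config import get_module_config

cfg = OmegaConf.load("configs/mld_control.yaml")
# Pulls configs/modules_mld/*.yaml (denoiser, motion_vae, scheduler, ...) into cfg.model.
cfg_model = get_module_config(cfg.model, cfg.model.target)
cfg = OmegaConf.merge(cfg, cfg_model)

print(cfg.model.denoiser.target)   # mld.models.architectures.mld_denoiser.MldDenoiser
print(cfg.TRAIN.BATCH_SIZE)        # 128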
configs/mld_t2m_infer.yaml ADDED
@@ -0,0 +1,72 @@
1
+ FOLDER: './experiments_t2m'
2
+ TEST_FOLDER: './experiments_t2m_test'
3
+
4
+ NAME: 'mld_humanml'
5
+
6
+ TRAIN:
7
+ DATASETS: ['humanml3d']
8
+ BATCH_SIZE: 1
9
+ NUM_WORKERS: 8
10
+ PERSISTENT_WORKERS: true
11
+ SEED_VALUE: 1234
12
+
13
+ EVAL:
14
+ DATASETS: ['humanml3d']
15
+ BATCH_SIZE: 32
16
+ SPLIT: test
17
+ NUM_WORKERS: 12
18
+
19
+ TEST:
20
+ DATASETS: ['humanml3d']
21
+ SPLIT: test
22
+ BATCH_SIZE: 1
23
+ NUM_WORKERS: 12
24
+
25
+ CHECKPOINTS: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
26
+
27
+ # Testing Args
28
+ REPLICATION_TIMES: 20
29
+ MM_NUM_SAMPLES: 100
30
+ MM_NUM_REPEATS: 30
31
+ MM_NUM_TIMES: 10
32
+ DIVERSITY_TIMES: 300
33
+
34
+ DATASET:
35
+ SMPL_PATH: './deps/smpl'
36
+ WORD_VERTILIZER_PATH: './deps/glove/'
37
+ HUMANML3D:
38
+ PICK_ONE_TEXT: true
39
+ FRAME_RATE: 20.0
40
+ UNIT_LEN: 4
41
+ ROOT: './datasets/humanml3d'
42
+ SPLIT_ROOT: './datasets/humanml3d'
43
+ SAMPLER:
44
+ MAX_LEN: 196
45
+ MIN_LEN: 40
46
+ MAX_TEXT_LEN: 20
47
+
48
+ METRIC:
49
+ DIST_SYNC_ON_STEP: True
50
+ TYPE: ['TM2TMetrics']
51
+
52
+ model:
53
+ target: 'modules_mld'
54
+ latent_dim: [1, 256]
55
+ guidance_scale: 7.5
56
+
57
+ t2m_textencoder:
58
+ dim_word: 300
59
+ dim_pos_ohot: 15
60
+ dim_text_hidden: 512
61
+ dim_coemb_hidden: 512
62
+
63
+ t2m_motionencoder:
64
+ dim_move_hidden: 512
65
+ dim_move_latent: 512
66
+ dim_motion_hidden: 1024
67
+ dim_motion_latent: 512
68
+
69
+ bert_path: './deps/distilbert-base-uncased'
70
+ clip_path: './deps/clip-vit-large-patch14'
71
+ t5_path: './deps/sentence-t5-large'
72
+ t2m_path: './deps/t2m/'
configs/modules/denoiser.yaml ADDED
@@ -0,0 +1,16 @@
1
+ denoiser:
2
+ target: mld.models.architectures.mld_denoiser.MldDenoiser
3
+ params:
4
+ text_encoded_dim: 768
5
+ ff_size: 1024
6
+ num_layers: 9
7
+ num_heads: 4
8
+ dropout: 0.1
9
+ normalize_before: false
10
+ activation: 'gelu'
11
+ flip_sin_to_cos: true
12
+ return_intermediate_dec: false
13
+ position_embedding: 'learned'
14
+ arch: 'trans_enc'
15
+ freq_shift: 0
16
+ latent_dim: ${model.latent_dim}
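Each module YAML follows the same target/params layout, which instantiate_from_config in mld/config.py (added later in this commit) turns into an object: the dotted target path is imported and called with params as keyword arguments. An illustrative sketch, assuming a merged cfg as in the sketch above (latent_dim is resolved from ${model.latent_dim} by OmegaConf interpolation):

from mld.config import instantiate_from_config

denoiser = instantiate_from_config(cfg.model.denoiser)
print(type(denoiser).__name__)   # MldDenoiser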
configs/modules/motion_vae.yaml ADDED
@@ -0,0 +1,13 @@
1
+ motion_vae:
2
+ target: mld.models.architectures.mld_vae.MldVae
3
+ params:
4
+ arch: 'encoder_decoder'
5
+ ff_size: 1024
6
+ num_layers: 9
7
+ num_heads: 4
8
+ dropout: 0.1
9
+ normalize_before: false
10
+ activation: 'gelu'
11
+ position_embedding: 'learned'
12
+ latent_dim: ${model.latent_dim}
13
+ nfeats: ${DATASET.NFEATS}
configs/modules/scheduler.yaml ADDED
@@ -0,0 +1,20 @@
1
+ scheduler:
2
+ target: diffusers.LCMScheduler
3
+ num_inference_timesteps: 1
4
+ params:
5
+ num_train_timesteps: 1000
6
+ beta_start: 0.00085
7
+ beta_end: 0.012
8
+ beta_schedule: 'scaled_linear'
9
+ clip_sample: false
10
+ set_alpha_to_one: false
11
+
12
+ noise_scheduler:
13
+ target: diffusers.DDPMScheduler
14
+ params:
15
+ num_train_timesteps: 1000
16
+ beta_start: 0.00085
17
+ beta_end: 0.012
18
+ beta_schedule: 'scaled_linear'
19
+ variance_type: 'fixed_small'
20
+ clip_sample: false
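Here the sampling scheduler is diffusers' LCMScheduler with a single inference step (num_inference_timesteps: 1), while training noise comes from a DDPMScheduler with matching betas. A hedged sketch of how the params block maps onto the diffusers constructor; the exact wiring inside the MLD model class is not part of this diff.

from diffusers import LCMScheduler

scheduler = LCMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)
# One-step sampling, matching num_inference_timesteps above.
scheduler.set_timesteps(num_inference_steps=1)
print(scheduler.timesteps)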
configs/modules/text_encoder.yaml ADDED
@@ -0,0 +1,5 @@
1
+ text_encoder:
2
+ target: mld.models.architectures.mld_clip.MldTextEncoder
3
+ params:
4
+ last_hidden_state: false # if true, the last hidden state is used as the text embedding
5
+ modelpath: ${model.t5_path}
configs/modules/traj_encoder.yaml ADDED
@@ -0,0 +1,12 @@
1
+ traj_encoder:
2
+ target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
3
+ params:
4
+ ff_size: 1024
5
+ num_layers: 9
6
+ num_heads: 4
7
+ dropout: 0.1
8
+ normalize_before: false
9
+ activation: 'gelu'
10
+ position_embedding: 'learned'
11
+ latent_dim: ${model.latent_dim}
12
+ nfeats: ${DATASET.NJOINTS}
configs/modules_mld/denoiser.yaml ADDED
@@ -0,0 +1,16 @@
1
+ denoiser:
2
+ target: mld.models.architectures.mld_denoiser.MldDenoiser
3
+ params:
4
+ text_encoded_dim: 768
5
+ ff_size: 1024
6
+ num_layers: 9
7
+ num_heads: 4
8
+ dropout: 0.1
9
+ normalize_before: false
10
+ activation: 'gelu'
11
+ flip_sin_to_cos: true
12
+ return_intermediate_dec: false
13
+ position_embedding: 'learned'
14
+ arch: 'trans_enc'
15
+ freq_shift: 0
16
+ latent_dim: ${model.latent_dim}
configs/modules_mld/motion_vae.yaml ADDED
@@ -0,0 +1,13 @@
1
+ motion_vae:
2
+ target: mld.models.architectures.mld_vae.MldVae
3
+ params:
4
+ arch: 'encoder_decoder'
5
+ ff_size: 1024
6
+ num_layers: 9
7
+ num_heads: 4
8
+ dropout: 0.1
9
+ normalize_before: false
10
+ activation: 'gelu'
11
+ position_embedding: 'learned'
12
+ latent_dim: ${model.latent_dim}
13
+ nfeats: ${DATASET.NFEATS}
configs/modules_mld/scheduler.yaml ADDED
@@ -0,0 +1,23 @@
1
+ scheduler:
2
+ target: diffusers.DDIMScheduler
3
+ num_inference_timesteps: 50
4
+ eta: 0.0
5
+ params:
6
+ num_train_timesteps: 1000
7
+ beta_start: 0.00085
8
+ beta_end: 0.012
9
+ beta_schedule: 'scaled_linear'
10
+ clip_sample: false
11
+ # below are for ddim
12
+ set_alpha_to_one: false
13
+ steps_offset: 1
14
+
15
+ noise_scheduler:
16
+ target: diffusers.DDPMScheduler
17
+ params:
18
+ num_train_timesteps: 1000
19
+ beta_start: 0.00085
20
+ beta_end: 0.012
21
+ beta_schedule: 'scaled_linear'
22
+ variance_type: 'fixed_small'
23
+ clip_sample: false
configs/modules_mld/text_encoder.yaml ADDED
@@ -0,0 +1,5 @@
1
+ text_encoder:
2
+ target: mld.models.architectures.mld_clip.MldTextEncoder
3
+ params:
4
+ last_hidden_state: false # if true, the last hidden state is used as the text embedding
5
+ modelpath: ${model.t5_path}
configs/modules_mld/traj_encoder.yaml ADDED
@@ -0,0 +1,12 @@
1
+ traj_encoder:
2
+ target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
3
+ params:
4
+ ff_size: 1024
5
+ num_layers: 9
6
+ num_heads: 4
7
+ dropout: 0.1
8
+ normalize_before: false
9
+ activation: 'gelu'
10
+ position_embedding: 'learned'
11
+ latent_dim: ${model.latent_dim}
12
+ nfeats: ${DATASET.NJOINTS}
configs/motionlcm_control.yaml ADDED
@@ -0,0 +1,105 @@
1
+ FOLDER: './experiments_control'
2
+ TEST_FOLDER: './experiments_control_test'
3
+
4
+ NAME: 'motionlcm_humanml'
5
+
6
+ TRAIN:
7
+ DATASETS: ['humanml3d']
8
+ BATCH_SIZE: 128
9
+ SPLIT: 'train'
10
+ NUM_WORKERS: 8
11
+ PERSISTENT_WORKERS: true
12
+ SEED_VALUE: 1234
13
+ PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
14
+
15
+ validation_steps: -1
16
+ validation_epochs: 50
17
+ checkpointing_steps: -1
18
+ checkpointing_epochs: 50
19
+ max_train_steps: -1
20
+ max_train_epochs: 1000
21
+ learning_rate: 1e-4
22
+ learning_rate_spatial: 1e-4
23
+ lr_scheduler: "cosine"
24
+ lr_warmup_steps: 1000
25
+ adam_beta1: 0.9
26
+ adam_beta2: 0.999
27
+ adam_weight_decay: 0.0
28
+ adam_epsilon: 1e-08
29
+ max_grad_norm: 1.0
30
+
31
+ EVAL:
32
+ DATASETS: ['humanml3d']
33
+ BATCH_SIZE: 32
34
+ SPLIT: 'test'
35
+ NUM_WORKERS: 12
36
+
37
+ TEST:
38
+ DATASETS: ['humanml3d']
39
+ BATCH_SIZE: 1
40
+ SPLIT: 'test'
41
+ NUM_WORKERS: 12
42
+
43
+ CHECKPOINTS: 'experiments_control/motionlcm_humanml/motionlcm_humanml.ckpt'
44
+
45
+ # Testing Args
46
+ REPLICATION_TIMES: 1
47
+ MM_NUM_SAMPLES: 100
48
+ MM_NUM_REPEATS: 30
49
+ MM_NUM_TIMES: 10
50
+ DIVERSITY_TIMES: 300
51
+ MAX_NUM_SAMPLES: 1024
52
+
53
+ DATASET:
54
+ SMPL_PATH: './deps/smpl'
55
+ WORD_VERTILIZER_PATH: './deps/glove/'
56
+ HUMANML3D:
57
+ PICK_ONE_TEXT: true
58
+ FRAME_RATE: 20.0
59
+ UNIT_LEN: 4
60
+ ROOT: './datasets/humanml3d'
61
+ SPLIT_ROOT: './datasets/humanml3d'
62
+ SAMPLER:
63
+ MAX_LEN: 196
64
+ MIN_LEN: 40
65
+ MAX_TEXT_LEN: 20
66
+
67
+ METRIC:
68
+ DIST_SYNC_ON_STEP: true
69
+ TYPE: ['TM2TMetrics', 'ControlMetrics']
70
+
71
+ model:
72
+ target: 'modules'
73
+ latent_dim: [1, 256]
74
+ guidance_scale: 7.5
75
+ guidance_uncondp: 0.0
76
+
77
+ # ControlNet Args
78
+ is_controlnet: true
79
+ is_controlnet_temporal: false
80
+ training_control_joint: [0]
81
+ testing_control_joint: [0]
82
+ training_density: 'random'
83
+ testing_density: 100
84
+ control_scale: 1.0
85
+ vaeloss: true
86
+ vaeloss_type: 'sum'
87
+ cond_ratio: 1.0
88
+ rot_ratio: 0.0
89
+
90
+ t2m_textencoder:
91
+ dim_word: 300
92
+ dim_pos_ohot: 15
93
+ dim_text_hidden: 512
94
+ dim_coemb_hidden: 512
95
+
96
+ t2m_motionencoder:
97
+ dim_move_hidden: 512
98
+ dim_move_latent: 512
99
+ dim_motion_hidden: 1024
100
+ dim_motion_latent: 512
101
+
102
+ bert_path: './deps/distilbert-base-uncased'
103
+ clip_path: './deps/clip-vit-large-patch14'
104
+ t5_path: './deps/sentence-t5-large'
105
+ t2m_path: './deps/t2m/'
configs/motionlcm_t2m.yaml ADDED
@@ -0,0 +1,100 @@
1
+ FOLDER: './experiments_t2m'
2
+ TEST_FOLDER: './experiments_t2m_test'
3
+
4
+ NAME: 'motionlcm_humanml'
5
+
6
+ TRAIN:
7
+ DATASETS: ['humanml3d']
8
+ BATCH_SIZE: 256
9
+ SPLIT: 'train'
10
+ NUM_WORKERS: 8
11
+ PERSISTENT_WORKERS: true
12
+ SEED_VALUE: 1234
13
+ PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
14
+
15
+ validation_steps: -1
16
+ validation_epochs: 50
17
+ checkpointing_steps: -1
18
+ checkpointing_epochs: 50
19
+ max_train_steps: -1
20
+ max_train_epochs: 1000
21
+ learning_rate: 2e-4
22
+ lr_scheduler: "cosine"
23
+ lr_warmup_steps: 1000
24
+ adam_beta1: 0.9
25
+ adam_beta2: 0.999
26
+ adam_weight_decay: 0.0
27
+ adam_epsilon: 1e-08
28
+ max_grad_norm: 1.0
29
+
30
+ # Latent Consistency Distillation Specific Arguments
31
+ w_min: 5.0
32
+ w_max: 15.0
33
+ num_ddim_timesteps: 50
34
+ loss_type: 'huber'
35
+ huber_c: 0.001
36
+ unet_time_cond_proj_dim: 256
37
+ ema_decay: 0.95
38
+
39
+ EVAL:
40
+ DATASETS: ['humanml3d']
41
+ BATCH_SIZE: 32
42
+ SPLIT: 'test'
43
+ NUM_WORKERS: 12
44
+
45
+ TEST:
46
+ DATASETS: ['humanml3d']
47
+ BATCH_SIZE: 1
48
+ SPLIT: 'test'
49
+ NUM_WORKERS: 12
50
+
51
+ CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
52
+
53
+ # Testing Args
54
+ REPLICATION_TIMES: 20
55
+ MM_NUM_SAMPLES: 100
56
+ MM_NUM_REPEATS: 30
57
+ MM_NUM_TIMES: 10
58
+ DIVERSITY_TIMES: 300
59
+
60
+ DATASET:
61
+ SMPL_PATH: './deps/smpl'
62
+ WORD_VERTILIZER_PATH: './deps/glove/'
63
+ HUMANML3D:
64
+ PICK_ONE_TEXT: true
65
+ FRAME_RATE: 20.0
66
+ UNIT_LEN: 4
67
+ ROOT: './datasets/humanml3d'
68
+ SPLIT_ROOT: './datasets/humanml3d'
69
+ SAMPLER:
70
+ MAX_LEN: 196
71
+ MIN_LEN: 40
72
+ MAX_TEXT_LEN: 20
73
+
74
+ METRIC:
75
+ DIST_SYNC_ON_STEP: true
76
+ TYPE: ['TM2TMetrics']
77
+
78
+ model:
79
+ target: 'modules'
80
+ latent_dim: [1, 256]
81
+ guidance_scale: 7.5
82
+ guidance_uncondp: 0.0
83
+ is_controlnet: false
84
+
85
+ t2m_textencoder:
86
+ dim_word: 300
87
+ dim_pos_ohot: 15
88
+ dim_text_hidden: 512
89
+ dim_coemb_hidden: 512
90
+
91
+ t2m_motionencoder:
92
+ dim_move_hidden: 512
93
+ dim_move_latent: 512
94
+ dim_motion_hidden: 1024
95
+ dim_motion_latent: 512
96
+
97
+ bert_path: './deps/distilbert-base-uncased'
98
+ clip_path: './deps/clip-vit-large-patch14'
99
+ t5_path: './deps/sentence-t5-large'
100
+ t2m_path: './deps/t2m/'
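The w_min/w_max pair bounds the guidance scale used during latent consistency distillation; the training script itself is not in this diff, but the usual LCM recipe samples a per-example w uniformly from that range and feeds it to the student through the guidance-embedding projection whose width is unet_time_cond_proj_dim. A sketch under that assumption only:

import torch

w_min, w_max = 5.0, 15.0   # from the config above
bsz = 256                  # TRAIN.BATCH_SIZE

# Assumed distillation detail: one guidance weight per training example.
w = (w_max - w_min) * torch.rand(bsz) + w_min
print(w.min().item(), w.max().item())   # stays inside [5.0, 15.0]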
demo.py ADDED
@@ -0,0 +1,154 @@
1
+ import os
2
+ import pickle
3
+ import sys
4
+ import datetime
5
+ import logging
6
+ import os.path as osp
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+ import torch
11
+
12
+ from mld.config import parse_args
13
+ from mld.data.get_data import get_datasets
14
+ from mld.models.modeltype.mld import MLD
15
+ from mld.utils.utils import set_seed, move_batch_to_device
16
+ from mld.data.humanml.utils.plot_script import plot_3d_motion
17
+ from mld.utils.temos_utils import remove_padding
18
+
19
+
20
+ def load_example_input(text_path: str) -> tuple:
21
+ with open(text_path, "r") as f:
22
+ lines = f.readlines()
23
+
24
+ count = 0
25
+ texts, lens = [], []
26
+ # Strips the newline character
27
+ for line in lines:
28
+ count += 1
29
+ s = line.strip()
30
+ s_l = s.split(" ")[0]
31
+ s_t = s[(len(s_l) + 1):]
32
+ lens.append(int(s_l))
33
+ texts.append(s_t)
34
+ return texts, lens
35
+
36
+
37
+ def main():
38
+ cfg = parse_args()
39
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
40
+ set_seed(cfg.TRAIN.SEED_VALUE)
41
+
42
+ name_time_str = osp.join(cfg.NAME, "demo_" + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
43
+ output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
44
+ vis_dir = osp.join(output_dir, 'samples')
45
+ os.makedirs(output_dir, exist_ok=False)
46
+ os.makedirs(vis_dir, exist_ok=False)
47
+
48
+ stream_handler = logging.StreamHandler(sys.stdout)
49
+ file_handler = logging.FileHandler(osp.join(output_dir, 'output.log'))
50
+ logging.basicConfig(level=logging.INFO,
51
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
52
+ datefmt="%m/%d/%Y %H:%M:%S",
53
+ handlers=[stream_handler, file_handler])
54
+ logger = logging.getLogger(__name__)
55
+
56
+ OmegaConf.save(cfg, osp.join(output_dir, 'config.yaml'))
57
+
58
+ state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
59
+ logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
60
+
61
+ lcm_key = 'denoiser.time_embedding.cond_proj.weight'
62
+ is_lcm = False
63
+ if lcm_key in state_dict:
64
+ is_lcm = True
65
+ time_cond_proj_dim = state_dict[lcm_key].shape[1]
66
+ cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
67
+ logger.info(f'Is LCM: {is_lcm}')
68
+
69
+ cn_key = "controlnet.controlnet_cond_embedding.0.weight"
70
+ is_controlnet = True if cn_key in state_dict else False
71
+ cfg.model.is_controlnet = is_controlnet
72
+ logger.info(f'Is Controlnet: {is_controlnet}')
73
+
74
+ datasets = get_datasets(cfg, phase="test")[0]
75
+ model = MLD(cfg, datasets)
76
+ model.to(device)
77
+ model.eval()
78
+ model.load_state_dict(state_dict)
79
+
80
+ # the example file only supports text-to-motion
81
+ if cfg.example is not None and not is_controlnet:
82
+ text, length = load_example_input(cfg.example)
83
+ for t, l in zip(text, length):
84
+ logger.info(f"{l}: {t}")
85
+
86
+ batch = {"length": length, "text": text}
87
+
88
+ for rep_i in range(cfg.replication):
89
+ with torch.no_grad():
90
+ joints, _ = model(batch)
91
+
92
+ num_samples = len(joints)
93
+ batch_id = 0
94
+ for i in range(num_samples):
95
+ res = dict()
96
+ pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
97
+ res['joints'] = joints[i].detach().cpu().numpy()
98
+ res['text'] = text[i]
99
+ res['length'] = length[i]
100
+ res['hint'] = None
101
+ with open(pkl_path, 'wb') as f:
102
+ pickle.dump(res, f)
103
+ logger.info(f"Motions are generated here:\n{pkl_path}")
104
+
105
+ if not cfg.no_plot:
106
+ plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), text[i], fps=20)
107
+
108
+ else:
109
+ test_dataloader = datasets.test_dataloader()
110
+ for rep_i in range(cfg.replication):
111
+ for batch_id, batch in enumerate(test_dataloader):
112
+ batch = move_batch_to_device(batch, device)
113
+ with torch.no_grad():
114
+ joints, joints_ref = model(batch)
115
+
116
+ num_samples = len(joints)
117
+ text = batch['text']
118
+ length = batch['length']
119
+ if 'hint' in batch:
120
+ hint = batch['hint']
121
+ mask_hint = hint.view(hint.shape[0], hint.shape[1], model.njoints, 3).sum(dim=-1, keepdim=True) != 0
122
+ hint = model.datamodule.denorm_spatial(hint)
123
+ hint = hint.view(hint.shape[0], hint.shape[1], model.njoints, 3) * mask_hint
124
+ hint = remove_padding(hint, lengths=length)
125
+ else:
126
+ hint = None
127
+
128
+ for i in range(num_samples):
129
+ res = dict()
130
+ pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
131
+ res['joints'] = joints[i].detach().cpu().numpy()
132
+ res['text'] = text[i]
133
+ res['length'] = length[i]
134
+ res['hint'] = hint[i].detach().cpu().numpy() if hint is not None else None
135
+ with open(pkl_path, 'wb') as f:
136
+ pickle.dump(res, f)
137
+ logger.info(f"Motions are generated here:\n{pkl_path}")
138
+
139
+ if not cfg.no_plot:
140
+ plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(),
141
+ text[i], fps=20, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
142
+
143
+ if rep_i == 0:
144
+ res['joints'] = joints_ref[i].detach().cpu().numpy()
145
+ with open(pkl_path.replace('.pkl', '_ref.pkl'), 'wb') as f:
146
+ pickle.dump(res, f)
147
+ logger.info(f"Motions are generated here:\n{pkl_path.replace('.pkl', '_ref.pkl')}")
148
+ if not cfg.no_plot:
149
+ plot_3d_motion(pkl_path.replace('.pkl', '_ref.mp4'), joints_ref[i].detach().cpu().numpy(),
150
+ text[i], fps=20, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()
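load_example_input() expects each line of the --example file to start with the motion length in frames, followed by the caption, and demo.py is launched with the --cfg and --example arguments defined in parse_args. A minimal sketch that writes such a file (the path assets/example.txt is hypothetical):

# Write a demo input file in the format load_example_input() parses:
# "<length-in-frames> <caption>" per line.
lines = [
    "196 a person walks forward and then turns around",
    "100 a person waves both arms in the air",
]
with open("assets/example.txt", "w") as f:
    f.write("\n".join(lines) + "\n")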
fit.py ADDED
@@ -0,0 +1,134 @@
1
+ # borrowed from the joints2smpl optimization code: https://github.com/wangsen1312/joints2smpl
2
+ import os
3
+ import argparse
4
+ import pickle
5
+
6
+ import h5py
7
+ import natsort
8
+ import smplx
9
+
10
+ import torch
11
+
12
+ from mld.transforms.joints2rots import config
13
+ from mld.transforms.joints2rots.smplify import SMPLify3D
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("--pkl", type=str, default=None, help="pkl motion file")
17
+ parser.add_argument("--dir", type=str, default=None, help="pkl motion folder")
18
+ parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters")
19
+ parser.add_argument("--cuda", type=bool, default=True, help="enables cuda")
20
+ parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids")
21
+ parser.add_argument("--num_joints", type=int, default=22, help="joint number")
22
+ parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence")
23
+ parser.add_argument("--fix_foot", type=str, default="False", help="fix foot or not")
24
+ opt = parser.parse_args()
25
+ print(opt)
26
+
27
+ if opt.pkl:
28
+ paths = [opt.pkl]
29
+ elif opt.dir:
30
+ paths = []
31
+ file_list = natsort.natsorted(os.listdir(opt.dir))
32
+ for item in file_list:
33
+ if item.endswith('.pkl') and not item.endswith("_mesh.pkl"):
34
+ paths.append(os.path.join(opt.dir, item))
35
+ else:
36
+ raise ValueError(f'{opt.pkl} and {opt.dir} are both None!')
37
+
38
+ for path in paths:
39
+ # load joints
40
+ if os.path.exists(path.replace('.pkl', '_mesh.pkl')):
41
+ print(f"{path} is already fitted! skip!")
42
+ continue
43
+
44
+ with open(path, 'rb') as f:
45
+ data = pickle.load(f)
46
+
47
+ joints = data['joints']
48
+ # load the predefined SMPL body model
49
+ device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu")
50
+ print(config.SMPL_MODEL_DIR)
51
+ smplxmodel = smplx.create(
52
+ config.SMPL_MODEL_DIR,
53
+ model_type="smpl",
54
+ gender="neutral",
55
+ ext="pkl",
56
+ batch_size=joints.shape[0],
57
+ ).to(device)
58
+
59
+ # load the mean pose as original
60
+ smpl_mean_file = config.SMPL_MEAN_FILE
61
+
62
+ file = h5py.File(smpl_mean_file, "r")
63
+ init_mean_pose = (
64
+ torch.from_numpy(file["pose"][:])
65
+ .unsqueeze(0).repeat(joints.shape[0], 1)
66
+ .float()
67
+ .to(device)
68
+ )
69
+ init_mean_shape = (
70
+ torch.from_numpy(file["shape"][:])
71
+ .unsqueeze(0).repeat(joints.shape[0], 1)
72
+ .float()
73
+ .to(device)
74
+ )
75
+ cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)
76
+
77
+ # initialize SMPLify
78
+ smplify = SMPLify3D(
79
+ smplxmodel=smplxmodel,
80
+ batch_size=joints.shape[0],
81
+ joints_category=opt.joint_category,
82
+ num_iters=opt.num_smplify_iters,
83
+ device=device,
84
+ )
85
+ print("initialize SMPLify3D done!")
86
+
87
+ print("Start SMPLify!")
88
+ keypoints_3d = torch.Tensor(joints).to(device).float()
89
+
90
+ if opt.joint_category == "AMASS":
91
+ confidence_input = torch.ones(opt.num_joints)
92
+ # weight the ankles (7, 8) and feet (10, 11) more heavily
93
+ if opt.fix_foot:
94
+ confidence_input[7] = 1.5
95
+ confidence_input[8] = 1.5
96
+ confidence_input[10] = 1.5
97
+ confidence_input[11] = 1.5
98
+ else:
99
+ print("This joint category is not supported!")
100
+
101
+ # ----- from initial to fitting -------
102
+ (
103
+ new_opt_vertices,
104
+ new_opt_joints,
105
+ new_opt_pose,
106
+ new_opt_betas,
107
+ new_opt_cam_t,
108
+ new_opt_joint_loss,
109
+ ) = smplify(
110
+ init_mean_pose.detach(),
111
+ init_mean_shape.detach(),
112
+ cam_trans_zero.detach(),
113
+ keypoints_3d,
114
+ conf_3d=confidence_input.to(device)
115
+ )
116
+
117
+ # fix shape
118
+ betas = torch.zeros_like(new_opt_betas)
119
+ root = keypoints_3d[:, 0, :]
120
+
121
+ output = smplxmodel(
122
+ betas=betas,
123
+ global_orient=new_opt_pose[:, :3],
124
+ body_pose=new_opt_pose[:, 3:],
125
+ transl=root,
126
+ return_verts=True,
127
+ )
128
+ vertices = output.vertices.detach().cpu().numpy()
129
+ data['vertices'] = vertices
130
+
131
+ save_file = path.replace('.pkl', '_mesh.pkl')
132
+ with open(save_file, 'wb') as f:
133
+ pickle.dump(data, f)
134
+ print(f'vertices saved in {save_file}')
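fit.py writes a *_mesh.pkl next to each input pickle, adding a 'vertices' array to the original keys ('joints', 'text', 'length', 'hint'). A minimal sketch of inspecting one result; the file name is hypothetical, and 6890 is the standard SMPL vertex count (assumed here, not stated in the diff).

import pickle

with open("sample_123_mesh.pkl", "rb") as f:   # hypothetical output of fit.py
    data = pickle.load(f)

print(data.keys())              # joints, text, length, hint, vertices
print(data["vertices"].shape)   # expected (num_frames, 6890, 3) for the SMPL body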
mld/__init__.py ADDED
File without changes
mld/config.py ADDED
@@ -0,0 +1,47 @@
1
+ import os
2
+ import importlib
3
+ from typing import Type, TypeVar
4
+ from argparse import ArgumentParser
5
+
6
+ from omegaconf import OmegaConf, DictConfig
7
+
8
+
9
+ def get_module_config(cfg_model: DictConfig, path: str = "modules") -> DictConfig:
10
+ files = os.listdir(f'./configs/{path}/')
11
+ for file in files:
12
+ if file.endswith('.yaml'):
13
+ with open(f'./configs/{path}/' + file, 'r') as f:
14
+ cfg_model.merge_with(OmegaConf.load(f))
15
+ return cfg_model
16
+
17
+
18
+ def get_obj_from_str(string: str, reload: bool = False) -> Type:
19
+ module, cls = string.rsplit(".", 1)
20
+ if reload:
21
+ module_imp = importlib.import_module(module)
22
+ importlib.reload(module_imp)
23
+ return getattr(importlib.import_module(module, package=None), cls)
24
+
25
+
26
+ def instantiate_from_config(config: DictConfig) -> TypeVar:
27
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
28
+
29
+
30
+ def parse_args() -> DictConfig:
31
+ parser = ArgumentParser()
32
+ parser.add_argument("--cfg", type=str, required=True, help="config file")
33
+
34
+ # Demo Args
35
+ parser.add_argument('--example', type=str, required=False, help="txt file with one '<length> <text prompt>' per line")
36
+ parser.add_argument('--no-plot', action="store_true", required=False, help="whether to plot the skeleton-based motion")
37
+ parser.add_argument('--replication', type=int, default=1, help="number of times to repeat sampling")
38
+ args = parser.parse_args()
39
+
40
+ cfg = OmegaConf.load(args.cfg)
41
+ cfg_model = get_module_config(cfg.model, cfg.model.target)
42
+ cfg = OmegaConf.merge(cfg, cfg_model)
43
+
44
+ cfg.example = args.example
45
+ cfg.no_plot = args.no_plot
46
+ cfg.replication = args.replication
47
+ return cfg
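get_obj_from_str() resolves any dotted path, so it also works on standard-library classes, which makes the target/params mechanism easy to sanity-check in isolation (illustrative only, not how the repo uses it):

from omegaconf import OmegaConf
from mld.config import get_obj_from_str, instantiate_from_config

# Resolve a dotted path to the class it names.
print(get_obj_from_str("datetime.timedelta"))        # <class 'datetime.timedelta'>

# instantiate_from_config imports `target` and calls it with `params` as kwargs.
cfg = OmegaConf.create({"target": "datetime.timedelta", "params": {"days": 2}})
print(instantiate_from_config(cfg))                  # 2 days, 0:00:00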
mld/data/HumanML3D.py ADDED
@@ -0,0 +1,79 @@
1
+ import copy
2
+ from typing import Callable, Optional
3
+
4
+ import numpy as np
5
+ from omegaconf import DictConfig
6
+
7
+ import torch
8
+
9
+ from .base import BASEDataModule
10
+ from .humanml.dataset import Text2MotionDatasetV2
11
+ from .humanml.scripts.motion_process import recover_from_ric
12
+
13
+
14
+ class HumanML3DDataModule(BASEDataModule):
15
+
16
+ def __init__(self,
17
+ cfg: DictConfig,
18
+ batch_size: int,
19
+ num_workers: int,
20
+ collate_fn: Optional[Callable] = None,
21
+ persistent_workers: bool = True,
22
+ phase: str = "train",
23
+ **kwargs) -> None:
24
+ super().__init__(batch_size=batch_size,
25
+ num_workers=num_workers,
26
+ collate_fn=collate_fn,
27
+ persistent_workers=persistent_workers)
28
+ self.hparams = copy.deepcopy(kwargs)
29
+ self.name = "humanml3d"
30
+ self.njoints = 22
31
+ if phase == "text_only":
32
+ raise NotImplementedError
33
+ else:
34
+ self.Dataset = Text2MotionDatasetV2
35
+ self.cfg = cfg
36
+
37
+ sample_overrides = {"tiny": True, "progress_bar": False}
38
+ self._sample_set = self.get_sample_set(overrides=sample_overrides)
39
+ self.nfeats = self._sample_set.nfeats
40
+
41
+ def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
42
+ raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
43
+ raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
44
+ hint = hint * raw_std + raw_mean
45
+ return hint
46
+
47
+ def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
48
+ raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
49
+ raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
50
+ hint = (hint - raw_mean) / raw_std
51
+ return hint
52
+
53
+ def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
54
+ mean = torch.tensor(self.hparams['mean']).to(features)
55
+ std = torch.tensor(self.hparams['std']).to(features)
56
+ features = features * std + mean
57
+ return recover_from_ric(features, self.njoints)
58
+
59
+ def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
60
+ # renorm to t2m norms for using t2m evaluators
61
+ ori_mean = torch.tensor(self.hparams['mean']).to(features)
62
+ ori_std = torch.tensor(self.hparams['std']).to(features)
63
+ eval_mean = torch.tensor(self.hparams['mean_eval']).to(features)
64
+ eval_std = torch.tensor(self.hparams['std_eval']).to(features)
65
+ features = features * ori_std + ori_mean
66
+ features = (features - eval_mean) / eval_std
67
+ return features
68
+
69
+ def mm_mode(self, mm_on: bool = True) -> None:
70
+ if mm_on:
71
+ self.is_mm = True
72
+ self.name_list = self.test_dataset.name_list
73
+ self.mm_list = np.random.choice(self.name_list,
74
+ self.cfg.TEST.MM_NUM_SAMPLES,
75
+ replace=False)
76
+ self.test_dataset.name_list = self.mm_list
77
+ else:
78
+ self.is_mm = False
79
+ self.test_dataset.name_list = self.name_list
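norm_spatial() and denorm_spatial() are exact inverses built from the dataset's raw joint mean/std, which is what demo.py relies on when it denormalizes control hints before plotting. A quick round-trip sketch, where dm is assumed to be a HumanML3DDataModule from get_datasets() and hint a control-hint tensor in the layout the dataset stores:

import torch

# denorm_spatial undoes norm_spatial, so a hint survives the round trip
# up to floating-point error (sketch; dm and hint are assumed inputs).
restored = dm.denorm_spatial(dm.norm_spatial(hint))
assert torch.allclose(restored, hint, atol=1e-5)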
mld/data/Kit.py ADDED
@@ -0,0 +1,79 @@
1
+ import copy
2
+ from typing import Callable, Optional
3
+
4
+ import numpy as np
5
+ from omegaconf import DictConfig
6
+
7
+ import torch
8
+
9
+ from .base import BASEDataModule
10
+ from .humanml.dataset import Text2MotionDatasetV2
11
+ from .humanml.scripts.motion_process import recover_from_ric
12
+
13
+
14
+ class KitDataModule(BASEDataModule):
15
+
16
+ def __init__(self,
17
+ cfg: DictConfig,
18
+ batch_size: int,
19
+ num_workers: int,
20
+ collate_fn: Optional[Callable] = None,
21
+ persistent_workers: bool = True,
22
+ phase: str = "train",
23
+ **kwargs) -> None:
24
+ super().__init__(batch_size=batch_size,
25
+ num_workers=num_workers,
26
+ collate_fn=collate_fn,
27
+ persistent_workers=persistent_workers)
28
+ self.hparams = copy.deepcopy(kwargs)
29
+ self.name = 'kit'
30
+ self.njoints = 21
31
+ if phase == 'text_only':
32
+ raise NotImplementedError
33
+ else:
34
+ self.Dataset = Text2MotionDatasetV2
35
+ self.cfg = cfg
36
+
37
+ sample_overrides = {"tiny": True, "progress_bar": False}
38
+ self._sample_set = self.get_sample_set(overrides=sample_overrides)
39
+ self.nfeats = self._sample_set.nfeats
40
+
41
+ def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
42
+ raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
43
+ raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
44
+ hint = hint * raw_std + raw_mean
45
+ return hint
46
+
47
+ def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
48
+ raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
49
+ raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
50
+ hint = (hint - raw_mean) / raw_std
51
+ return hint
52
+
53
+ def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
54
+ mean = torch.tensor(self.hparams['mean']).to(features)
55
+ std = torch.tensor(self.hparams['std']).to(features)
56
+ features = features * std + mean
57
+ return recover_from_ric(features, self.njoints)
58
+
59
+ def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
60
+ # renorm to t2m norms for using t2m evaluators
61
+ ori_mean = torch.tensor(self.hparams['mean']).to(features)
62
+ ori_std = torch.tensor(self.hparams['std']).to(features)
63
+ eval_mean = torch.tensor(self.hparams['mean_eval']).to(features)
64
+ eval_std = torch.tensor(self.hparams['std_eval']).to(features)
65
+ features = features * ori_std + ori_mean
66
+ features = (features - eval_mean) / eval_std
67
+ return features
68
+
69
+ def mm_mode(self, mm_on: bool = True) -> None:
70
+ if mm_on:
71
+ self.is_mm = True
72
+ self.name_list = self.test_dataset.name_list
73
+ self.mm_list = np.random.choice(self.name_list,
74
+ self.cfg.TEST.MM_NUM_SAMPLES,
75
+ replace=False)
76
+ self.test_dataset.name_list = self.mm_list
77
+ else:
78
+ self.is_mm = False
79
+ self.test_dataset.name_list = self.name_list
mld/data/__init__.py ADDED
File without changes
mld/data/base.py ADDED
@@ -0,0 +1,65 @@
1
+ import copy
2
+ from os.path import join as pjoin
3
+ from typing import Any, Callable
4
+
5
+ from torch.utils.data import DataLoader
6
+
7
+ from .humanml.dataset import Text2MotionDatasetV2
8
+
9
+
10
+ class BASEDataModule:
11
+ def __init__(self, collate_fn: Callable, batch_size: int,
12
+ num_workers: int, persistent_workers: bool) -> None:
13
+ super(BASEDataModule, self).__init__()
14
+ self.dataloader_options = {
15
+ "batch_size": batch_size,
16
+ "num_workers": num_workers,
17
+ "collate_fn": collate_fn,
18
+ "persistent_workers": persistent_workers
19
+ }
20
+ self.is_mm = False
21
+
22
+ def get_sample_set(self, overrides: dict) -> Text2MotionDatasetV2:
23
+ sample_params = copy.deepcopy(self.hparams)
24
+ sample_params.update(overrides)
25
+ split_file = pjoin(
26
+ eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"),
27
+ self.cfg.EVAL.SPLIT + ".txt",
28
+ )
29
+ return self.Dataset(split_file=split_file, **sample_params)
30
+
31
+ def __getattr__(self, item: str) -> Any:
32
+ if item.endswith("_dataset") and not item.startswith("_"):
33
+ subset = item[:-len("_dataset")]
34
+ item_c = "_" + item
35
+ if item_c not in self.__dict__:
36
+
37
+ subset = subset.upper() if subset != "val" else "EVAL"
38
+ split = eval(f"self.cfg.{subset}.SPLIT")
39
+ split_file = pjoin(
40
+ eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"),
41
+ eval(f"self.cfg.{subset}.SPLIT") + ".txt",
42
+ )
43
+ self.__dict__[item_c] = self.Dataset(split_file=split_file,
44
+ split=split,
45
+ **self.hparams)
46
+ return getattr(self, item_c)
47
+ classname = self.__class__.__name__
48
+ raise AttributeError(f"'{classname}' object has no attribute '{item}'")
49
+
50
+ def train_dataloader(self) -> DataLoader:
51
+ return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_options)
52
+
53
+ def val_dataloader(self) -> DataLoader:
54
+ dataloader_options = self.dataloader_options.copy()
55
+ dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE
56
+ dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS
57
+ dataloader_options["shuffle"] = False
58
+ return DataLoader(self.val_dataset, **dataloader_options)
59
+
60
+ def test_dataloader(self) -> DataLoader:
61
+ dataloader_options = self.dataloader_options.copy()
62
+ dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
63
+ dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
64
+ dataloader_options["shuffle"] = False
65
+ return DataLoader(self.test_dataset, **dataloader_options)
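The __getattr__ hook is what makes train_dataset / val_dataset / test_dataset lazy: the first access builds the split from cfg and caches it under a leading underscore. The same pattern in miniature, independent of the repo classes (illustration only):

class LazySplits:
    """Tiny illustration of the caching __getattr__ pattern used by BASEDataModule."""

    def make(self, split: str) -> list:
        print(f"building {split} split")
        return [split]  # stand-in for an expensive Dataset construction

    def __getattr__(self, item: str):
        # Only called when normal attribute lookup fails.
        if item.endswith("_dataset") and not item.startswith("_"):
            cached = "_" + item
            if cached not in self.__dict__:
                self.__dict__[cached] = self.make(item[:-len("_dataset")])
            return self.__dict__[cached]
        raise AttributeError(item)

dm = LazySplits()
dm.test_dataset    # prints "building test split" once
dm.test_dataset    # served from the cache, no rebuild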
mld/data/get_data.py ADDED
@@ -0,0 +1,93 @@
1
+ from os.path import join as pjoin
2
+ from typing import Callable, Optional
3
+
4
+ import numpy as np
5
+
6
+ from omegaconf import DictConfig
7
+
8
+ from .humanml.utils.word_vectorizer import WordVectorizer
9
+ from .HumanML3D import HumanML3DDataModule
10
+ from .Kit import KitDataModule
11
+ from .base import BASEDataModule
12
+ from .utils import mld_collate
13
+
14
+
15
+ def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
16
+ name = "t2m" if dataset_name == "humanml3d" else dataset_name
17
+ assert name in ["t2m", "kit"]
18
+ if phase in ["val"]:
19
+ if name == 't2m':
20
+ data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta")
21
+ elif name == 'kit':
22
+ data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta")
23
+ else:
24
+ raise ValueError("Only support t2m and kit")
25
+ mean = np.load(pjoin(data_root, "mean.npy"))
26
+ std = np.load(pjoin(data_root, "std.npy"))
27
+ else:
28
+ data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
29
+ mean = np.load(pjoin(data_root, "Mean.npy"))
30
+ std = np.load(pjoin(data_root, "Std.npy"))
31
+
32
+ return mean, std
33
+
34
+
35
+ def get_WordVectorizer(cfg: DictConfig, phase: str, dataset_name: str) -> Optional[WordVectorizer]:
36
+ if phase not in ["text_only"]:
37
+ if dataset_name.lower() in ["humanml3d", "kit"]:
38
+ return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")
39
+ else:
40
+ raise ValueError("Only support WordVectorizer for HumanML3D")
41
+ else:
42
+ return None
43
+
44
+
45
+ def get_collate_fn(name: str) -> Callable:
46
+ if name.lower() in ["humanml3d", "kit"]:
47
+ return mld_collate
48
+ else:
49
+ raise NotImplementedError
50
+
51
+
52
+ dataset_module_map = {"humanml3d": HumanML3DDataModule, "kit": KitDataModule}
53
+ motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"}
54
+
55
+
56
+ def get_datasets(cfg: DictConfig, phase: str = "train") -> list[BASEDataModule]:
57
+ dataset_names = eval(f"cfg.{phase.upper()}.DATASETS")
58
+ datasets = []
59
+ for dataset_name in dataset_names:
60
+ if dataset_name.lower() in ["humanml3d", "kit"]:
61
+ data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
62
+ mean, std = get_mean_std(phase, cfg, dataset_name)
63
+ mean_eval, std_eval = get_mean_std("val", cfg, dataset_name)
64
+ wordVectorizer = get_WordVectorizer(cfg, phase, dataset_name)
65
+ collate_fn = get_collate_fn(dataset_name)
66
+ dataset = dataset_module_map[dataset_name.lower()](
67
+ cfg=cfg,
68
+ batch_size=cfg.TRAIN.BATCH_SIZE,
69
+ num_workers=cfg.TRAIN.NUM_WORKERS,
70
+ collate_fn=collate_fn,
71
+ persistent_workers=cfg.TRAIN.PERSISTENT_WORKERS,
72
+ mean=mean,
73
+ std=std,
74
+ mean_eval=mean_eval,
75
+ std_eval=std_eval,
76
+ w_vectorizer=wordVectorizer,
77
+ text_dir=pjoin(data_root, "texts"),
78
+ motion_dir=pjoin(data_root, motion_subdir[dataset_name]),
79
+ max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
80
+ min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
81
+ max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
82
+ unit_length=eval(
83
+ f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"),
84
+ model_kwargs=cfg.model
85
+ )
86
+ datasets.append(dataset)
87
+
88
+ elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]:
89
+ raise NotImplementedError
90
+
91
+ cfg.DATASET.NFEATS = datasets[0].nfeats
92
+ cfg.DATASET.NJOINTS = datasets[0].njoints
93
+ return datasets
mld/data/humanml/__init__.py ADDED
File without changes
mld/data/humanml/common/quaternion.py ADDED
1
+ import torch
2
+
3
+
4
+ def qinv(q: torch.Tensor) -> torch.Tensor:
5
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
6
+ mask = torch.ones_like(q)
7
+ mask[..., 1:] = -mask[..., 1:]
8
+ return q * mask
9
+
10
+
11
+ def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
12
+ """
13
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
14
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
15
+ where * denotes any number of dimensions.
16
+ Returns a tensor of shape (*, 3).
17
+ """
18
+ assert q.shape[-1] == 4
19
+ assert v.shape[-1] == 3
20
+ assert q.shape[:-1] == v.shape[:-1]
21
+
22
+ original_shape = list(v.shape)
23
+ q = q.contiguous().view(-1, 4)
24
+ v = v.contiguous().view(-1, 3)
25
+
26
+ qvec = q[:, 1:]
27
+ uv = torch.cross(qvec, v, dim=1)
28
+ uuv = torch.cross(qvec, uv, dim=1)
29
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
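qrot() implements v' = v + 2(w (q_vec x v) + q_vec x (q_vec x v)) for unit quaternions stored in (w, x, y, z) order. A small worked check (sketch, not part of the commit): a 90-degree rotation about z should send the x-axis to the y-axis.

import math
import torch

# Unit quaternion for a 90-degree rotation about the z-axis, (w, x, y, z) order.
half = math.pi / 4
q = torch.tensor([[math.cos(half), 0.0, 0.0, math.sin(half)]])
v = torch.tensor([[1.0, 0.0, 0.0]])   # the x-axis

print(qrot(q, v))   # approximately [[0., 1., 0.]]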
mld/data/humanml/dataset.py ADDED
@@ -0,0 +1,290 @@
1
+ import codecs as cs
2
+ import random
3
+ from os.path import join as pjoin
4
+
5
+ import numpy as np
6
+ from rich.progress import track
7
+
8
+ import torch
9
+ from torch.utils import data
10
+
11
+ from mld.data.humanml.scripts.motion_process import recover_from_ric
12
+ from .utils.word_vectorizer import WordVectorizer
13
+
14
+
15
+ class Text2MotionDatasetV2(data.Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ mean: np.ndarray,
20
+ std: np.ndarray,
21
+ split_file: str,
22
+ w_vectorizer: WordVectorizer,
23
+ max_motion_length: int,
24
+ min_motion_length: int,
25
+ max_text_len: int,
26
+ unit_length: int,
27
+ motion_dir: str,
28
+ text_dir: str,
29
+ tiny: bool = False,
30
+ progress_bar: bool = True,
31
+ **kwargs,
32
+ ) -> None:
33
+ self.w_vectorizer = w_vectorizer
34
+ self.max_motion_length = max_motion_length
35
+ self.min_motion_length = min_motion_length
36
+ self.max_text_len = max_text_len
37
+ self.unit_length = unit_length
38
+
39
+ data_dict = {}
40
+ id_list = []
41
+ with cs.open(split_file, "r") as f:
42
+ for line in f.readlines():
43
+ id_list.append(line.strip())
44
+ self.id_list = id_list
45
+
46
+ if tiny:
47
+ progress_bar = False
48
+ maxdata = 10
49
+ else:
50
+ maxdata = 1e10
51
+
52
+ if progress_bar:
53
+ enumerator = enumerate(
54
+ track(
55
+ id_list,
56
+ f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
57
+ ))
58
+ else:
59
+ enumerator = enumerate(id_list)
60
+ count = 0
61
+ bad_count = 0
62
+ new_name_list = []
63
+ length_list = []
64
+ for i, name in enumerator:
65
+ if count > maxdata:
66
+ break
67
+ try:
68
+ motion = np.load(pjoin(motion_dir, name + ".npy"))
69
+ if (len(motion)) < self.min_motion_length or (len(motion) >= 200):
70
+ bad_count += 1
71
+ continue
72
+ text_data = []
73
+ flag = False
74
+ with cs.open(pjoin(text_dir, name + ".txt")) as f:
75
+ for line in f.readlines():
76
+ text_dict = {}
77
+ line_split = line.strip().split("#")
78
+ caption = line_split[0]
79
+ tokens = line_split[1].split(" ")
80
+ f_tag = float(line_split[2])
81
+ to_tag = float(line_split[3])
82
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
83
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
84
+
85
+ text_dict["caption"] = caption
86
+ text_dict["tokens"] = tokens
87
+ if f_tag == 0.0 and to_tag == 0.0:
88
+ flag = True
89
+ text_data.append(text_dict)
90
+ else:
91
+ try:
92
+ n_motion = motion[int(f_tag * 20):int(to_tag *
93
+ 20)]
94
+ if (len(n_motion)
95
+ ) < self.min_motion_length or (
96
+ (len(n_motion) >= 200)):
97
+ continue
98
+ new_name = (
99
+ random.choice("ABCDEFGHIJKLMNOPQRSTUVW") +
100
+ "_" + name)
101
+ while new_name in data_dict:
102
+ new_name = (random.choice(
103
+ "ABCDEFGHIJKLMNOPQRSTUVW") + "_" +
104
+ name)
105
+ data_dict[new_name] = {
106
+ "motion": n_motion,
107
+ "length": len(n_motion),
108
+ "text": [text_dict],
109
+ }
110
+ new_name_list.append(new_name)
111
+ length_list.append(len(n_motion))
112
+ except Exception:
113
+ print(line_split)
114
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
115
+
116
+ if flag:
117
+ data_dict[name] = {
118
+ "motion": motion,
119
+ "length": len(motion),
120
+ "text": text_data,
121
+ }
122
+ new_name_list.append(name)
123
+ length_list.append(len(motion))
124
+ count += 1
125
+ except Exception:
126
+ pass  # skip samples whose motion or text files are missing or malformed
127
+
128
+ name_list, length_list = zip(
129
+ *sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
130
+
131
+ self.mean = mean
132
+ self.std = std
133
+
134
+ self.mode = None
135
+ model_params = kwargs['model_kwargs']
136
+ if 'is_controlnet' in model_params and model_params.is_controlnet is True:
137
+ if 'test' in split_file or 'val' in split_file:
138
+ self.mode = 'eval'
139
+ else:
140
+ self.mode = 'train'
141
+
142
+ self.t_ctrl = model_params.is_controlnet_temporal
143
+ spatial_norm_path = './datasets/humanml_spatial_norm'
144
+ self.raw_mean = np.load(pjoin(spatial_norm_path, 'Mean_raw.npy'))
145
+ self.raw_std = np.load(pjoin(spatial_norm_path, 'Std_raw.npy'))
146
+
147
+ self.training_control_joint = np.array(model_params.training_control_joint)
148
+ self.testing_control_joint = np.array(model_params.testing_control_joint)
149
+
150
+ self.training_density = model_params.training_density
151
+ self.testing_density = model_params.testing_density
152
+
153
+ self.length_arr = np.array(length_list)
154
+ self.data_dict = data_dict
155
+ self.nfeats = motion.shape[1]
156
+ self.name_list = name_list
157
+
158
+ def __len__(self) -> int:
159
+ return len(self.name_list)
160
+
161
+ def random_mask(self, joints: np.ndarray, n_joints: int = 22) -> np.ndarray:
162
+ choose_joint = self.testing_control_joint
163
+
164
+ length = joints.shape[0]
165
+ density = self.testing_density
166
+ if density in [1, 2, 5]:
167
+ choose_seq_num = density
168
+ else:
169
+ choose_seq_num = int(length * density / 100)
170
+
171
+ if self.t_ctrl:
172
+ choose_seq = np.arange(0, choose_seq_num)
173
+ else:
174
+ choose_seq = np.random.choice(length, choose_seq_num, replace=False)
175
+ choose_seq.sort()
176
+
177
+ mask_seq = np.zeros((length, n_joints, 3)).astype(bool)
178
+
179
+ for cj in choose_joint:
180
+ mask_seq[choose_seq, cj] = True
181
+
182
+ # normalize
183
+ joints = (joints - self.raw_mean.reshape(n_joints, 3)) / self.raw_std.reshape(n_joints, 3)
184
+ joints = joints * mask_seq
185
+ return joints
186
+
187
+ def random_mask_train(self, joints: np.ndarray, n_joints: int = 22) -> np.ndarray:
188
+ if self.t_ctrl:
189
+ choose_joint = self.training_control_joint
190
+ else:
191
+ num_joints = len(self.training_control_joint)
192
+ num_joints_control = 1
193
+ choose_joint = np.random.choice(num_joints, num_joints_control, replace=False)
194
+ choose_joint = self.training_control_joint[choose_joint]
195
+
196
+ length = joints.shape[0]
197
+
198
+ if self.training_density == 'random':
199
+ choose_seq_num = np.random.choice(length - 1, 1) + 1
200
+ else:
201
+ choose_seq_num = int(length * random.uniform(self.training_density[0], self.training_density[1]) / 100)
202
+
203
+ if self.t_ctrl:
204
+ choose_seq = np.arange(0, choose_seq_num)
205
+ else:
206
+ choose_seq = np.random.choice(length, choose_seq_num, replace=False)
207
+ choose_seq.sort()
208
+
209
+ mask_seq = np.zeros((length, n_joints, 3)).astype(bool)
210
+
211
+ for cj in choose_joint:
212
+ mask_seq[choose_seq, cj] = True
213
+
214
+ # normalize
215
+ joints = (joints - self.raw_mean.reshape(n_joints, 3)) / self.raw_std.reshape(n_joints, 3)
216
+ joints = joints * mask_seq
217
+ return joints
218
+
219
+ def __getitem__(self, idx: int) -> tuple:
220
+ data = self.data_dict[self.name_list[idx]]
221
+ motion, m_length, text_list = data["motion"], data["length"], data["text"]
222
+ # Randomly select a caption
223
+ text_data = random.choice(text_list)
224
+ caption, tokens = text_data["caption"], text_data["tokens"]
225
+
226
+ if len(tokens) < self.max_text_len:
227
+ # pad with "unk"
228
+ tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
229
+ sent_len = len(tokens)
230
+ tokens = tokens + ["unk/OTHER"
231
+ ] * (self.max_text_len + 2 - sent_len)
232
+ else:
233
+ # crop
234
+ tokens = tokens[:self.max_text_len]
235
+ tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
236
+ sent_len = len(tokens)
237
+ pos_one_hots = []
238
+ word_embeddings = []
239
+ for token in tokens:
240
+ word_emb, pos_oh = self.w_vectorizer[token]
241
+ pos_one_hots.append(pos_oh[None, :])
242
+ word_embeddings.append(word_emb[None, :])
243
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
244
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
245
+
246
+ # Crop the motion length to a multiple of unit_length and introduce small variations
247
+ if self.unit_length < 10:
248
+ coin2 = np.random.choice(["single", "single", "double"])
249
+ else:
250
+ coin2 = "single"
251
+
252
+ if coin2 == "double":
253
+ m_length = (m_length // self.unit_length - 1) * self.unit_length
254
+ elif coin2 == "single":
255
+ m_length = (m_length // self.unit_length) * self.unit_length
256
+ idx = random.randint(0, len(motion) - m_length)
257
+ motion = motion[idx:idx + m_length]
258
+
259
+ hint = None
260
+ if self.mode is not None:
261
+ n_joints = 22 if motion.shape[-1] == 263 else 21
262
+ # hint is global position of the controllable joints
263
+ joints = recover_from_ric(torch.from_numpy(motion).float(), n_joints)
264
+ joints = joints.numpy()
265
+
266
+ # control any joints at any time
267
+ if self.mode == 'train':
268
+ hint = self.random_mask_train(joints, n_joints)
269
+ else:
270
+ hint = self.random_mask(joints, n_joints)
271
+
272
+ hint = hint.reshape(hint.shape[0], -1)
273
+
274
+ # Z-score normalization
275
+ motion = (motion - self.mean) / self.std
276
+
277
+ # debug check nan
278
+ if np.any(np.isnan(motion)):
279
+ raise ValueError("nan in motion")
280
+
281
+ return (
282
+ word_embeddings,
283
+ pos_one_hots,
284
+ caption,
285
+ sent_len,
286
+ motion,
287
+ m_length,
288
+ "_".join(tokens),
289
+ hint
290
+ )
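For reference, a standalone sketch of the annotation format parsed by the loader above: each line of a *.txt file is "caption # word/POS tokens # start tag # end tag" (the caption below is made up for illustration).

line = "a person walks forward slowly#a/DET person/NOUN walk/VERB forward/ADV slowly/ADV#0.0#0.0"
caption, tokens, f_tag, to_tag = line.strip().split("#")
tokens = tokens.split(" ")
f_tag, to_tag = float(f_tag), float(to_tag)
# f_tag == to_tag == 0.0 means the caption describes the whole clip; otherwise the
# loader keeps only motion[int(f_tag * 20):int(to_tag * 20)] (features are at 20 fps).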
mld/data/humanml/scripts/motion_process.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+
3
+ from ..common.quaternion import qinv, qrot
4
+
5
+
6
+ # Recover global angle and positions for rotation dataset
7
+ # root_rot_velocity (B, seq_len, 1)
8
+ # root_linear_velocity (B, seq_len, 2)
9
+ # root_y (B, seq_len, 1)
10
+ # ric_data (B, seq_len, (joint_num - 1)*3)
11
+ # rot_data (B, seq_len, (joint_num - 1)*6)
12
+ # local_velocity (B, seq_len, joint_num*3)
13
+ # foot contact (B, seq_len, 4)
14
+ def recover_root_rot_pos(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
15
+ rot_vel = data[..., 0]
16
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
17
+ '''Get Y-axis rotation from rotation velocity'''
18
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
19
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
20
+
21
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
22
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
23
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
24
+
25
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
26
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
27
+ '''Add Y-axis rotation to root position'''
28
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
29
+
30
+ r_pos = torch.cumsum(r_pos, dim=-2)
31
+
32
+ r_pos[..., 1] = data[..., 3]
33
+ return r_rot_quat, r_pos
34
+
35
+
36
+ def recover_from_ric(data: torch.Tensor, joints_num: int) -> torch.Tensor:
37
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
38
+ positions = data[..., 4:(joints_num - 1) * 3 + 4]
39
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
40
+
41
+ '''Add Y-axis rotation to local joints'''
42
+ positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
43
+
44
+ '''Add root XZ to joints'''
45
+ positions[..., 0] += r_pos[..., 0:1]
46
+ positions[..., 2] += r_pos[..., 2:3]
47
+
48
+ '''Concat root and joints'''
49
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
50
+
51
+ return positions
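A minimal sketch of recovering joint positions from HumanML3D features, mirroring how the dataset above builds its control hints (the .npy path is a placeholder):

import numpy as np
import torch

from mld.data.humanml.scripts.motion_process import recover_from_ric

features = torch.from_numpy(np.load("example_motion.npy")).float()  # (seq_len, 263) for HumanML3D
n_joints = 22 if features.shape[-1] == 263 else 21
joints = recover_from_ric(features, n_joints)  # (seq_len, n_joints, 3) global joint positions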
mld/data/humanml/utils/__init__.py ADDED
File without changes
mld/data/humanml/utils/paramUtil.py ADDED
@@ -0,0 +1,62 @@
1
+ import numpy as np
2
+
3
+ # Define a kinematic tree for the skeletal structure
4
+ kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
5
+
6
+ kit_raw_offsets = np.array(
7
+ [
8
+ [0, 0, 0],
9
+ [0, 1, 0],
10
+ [0, 1, 0],
11
+ [0, 1, 0],
12
+ [0, 1, 0],
13
+ [1, 0, 0],
14
+ [0, -1, 0],
15
+ [0, -1, 0],
16
+ [-1, 0, 0],
17
+ [0, -1, 0],
18
+ [0, -1, 0],
19
+ [1, 0, 0],
20
+ [0, -1, 0],
21
+ [0, -1, 0],
22
+ [0, 0, 1],
23
+ [0, 0, 1],
24
+ [-1, 0, 0],
25
+ [0, -1, 0],
26
+ [0, -1, 0],
27
+ [0, 0, 1],
28
+ [0, 0, 1]
29
+ ]
30
+ )
31
+
32
+ t2m_raw_offsets = np.array([[0, 0, 0],
33
+ [1, 0, 0],
34
+ [-1, 0, 0],
35
+ [0, 1, 0],
36
+ [0, -1, 0],
37
+ [0, -1, 0],
38
+ [0, 1, 0],
39
+ [0, -1, 0],
40
+ [0, -1, 0],
41
+ [0, 1, 0],
42
+ [0, 0, 1],
43
+ [0, 0, 1],
44
+ [0, 1, 0],
45
+ [1, 0, 0],
46
+ [-1, 0, 0],
47
+ [0, 0, 1],
48
+ [0, -1, 0],
49
+ [0, -1, 0],
50
+ [0, -1, 0],
51
+ [0, -1, 0],
52
+ [0, -1, 0],
53
+ [0, -1, 0]])
54
+
55
+ t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
56
+ [9, 13, 16, 18, 20]]
57
+ t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
58
+ t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
59
+
60
+ kit_tgt_skel_id = '03950'
61
+
62
+ t2m_tgt_skel_id = '000021'
mld/data/humanml/utils/plot_script.py ADDED
@@ -0,0 +1,98 @@
1
+ from textwrap import wrap
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+
6
+ import matplotlib.pyplot as plt
7
+ import mpl_toolkits.mplot3d.axes3d as p3
8
+ from matplotlib.animation import FuncAnimation
9
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
10
+
11
+ import mld.data.humanml.utils.paramUtil as paramUtil
12
+
13
+ skeleton = paramUtil.t2m_kinematic_chain
14
+
15
+
16
+ def plot_3d_motion(save_path: str, joints: np.ndarray, title: str,
17
+ figsize: tuple[int, int] = (3, 3),
18
+ fps: int = 120, radius: int = 3, kinematic_tree: list = skeleton,
19
+ hint: Optional[np.ndarray] = None) -> None:
20
+
21
+ title = '\n'.join(wrap(title, 20))
22
+
23
+ def init():
24
+ ax.set_xlim3d([-radius / 2, radius / 2])
25
+ ax.set_ylim3d([0, radius])
26
+ ax.set_zlim3d([-radius / 3., radius * 2 / 3.])
27
+ fig.suptitle(title, fontsize=10)
28
+ ax.grid(b=False)
29
+
30
+ def plot_xzPlane(minx, maxx, miny, minz, maxz):
31
+ # Plot a plane XZ
32
+ verts = [
33
+ [minx, miny, minz],
34
+ [minx, miny, maxz],
35
+ [maxx, miny, maxz],
36
+ [maxx, miny, minz]
37
+ ]
38
+ xz_plane = Poly3DCollection([verts])
39
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
40
+ ax.add_collection3d(xz_plane)
41
+
42
+ # (seq_len, joints_num, 3)
43
+ data = joints.copy().reshape(len(joints), -1, 3)
44
+
45
+ data *= 1.3 # scale for visualization
46
+ if hint is not None:
47
+ mask = hint.sum(-1) != 0
48
+ hint = hint[mask]
49
+ hint *= 1.3
50
+
51
+ fig = plt.figure(figsize=figsize)
52
+ plt.tight_layout()
53
+ ax = p3.Axes3D(fig)
54
+ init()
55
+ MINS = data.min(axis=0).min(axis=0)
56
+ MAXS = data.max(axis=0).max(axis=0)
57
+ colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00",
58
+ "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00",
59
+ "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", ]
60
+
61
+ frame_number = data.shape[0]
62
+
63
+ height_offset = MINS[1]
64
+ data[:, :, 1] -= height_offset
65
+ if hint is not None:
66
+ hint[..., 1] -= height_offset
67
+ trajec = data[:, 0, [0, 2]]
68
+
69
+ data[..., 0] -= data[:, 0:1, 0]
70
+ data[..., 2] -= data[:, 0:1, 2]
71
+
72
+ def update(index):
73
+ ax.lines = []
74
+ ax.collections = []
75
+ ax.view_init(elev=120, azim=-90)
76
+ ax.dist = 7.5
77
+ plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
78
+ MAXS[2] - trajec[index, 1])
79
+
80
+ if hint is not None:
81
+ ax.scatter(hint[..., 0] - trajec[index, 0], hint[..., 1], hint[..., 2] - trajec[index, 1], color="#80B79A")
82
+
83
+ for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
84
+ if i < 5:
85
+ linewidth = 4.0
86
+ else:
87
+ linewidth = 2.0
88
+ ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
89
+ color=color)
90
+
91
+ plt.axis('off')
92
+ ax.set_xticklabels([])
93
+ ax.set_yticklabels([])
94
+ ax.set_zticklabels([])
95
+
96
+ ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
97
+ ani.save(save_path, fps=fps)
98
+ plt.close()
mld/data/humanml/utils/word_vectorizer.py ADDED
@@ -0,0 +1,82 @@
1
+ import pickle
2
+ from os.path import join as pjoin
3
+
4
+ import numpy as np
5
+
6
+
7
+ POS_enumerator = {
8
+ 'VERB': 0,
9
+ 'NOUN': 1,
10
+ 'DET': 2,
11
+ 'ADP': 3,
12
+ 'NUM': 4,
13
+ 'AUX': 5,
14
+ 'PRON': 6,
15
+ 'ADJ': 7,
16
+ 'ADV': 8,
17
+ 'Loc_VIP': 9,
18
+ 'Body_VIP': 10,
19
+ 'Obj_VIP': 11,
20
+ 'Act_VIP': 12,
21
+ 'Desc_VIP': 13,
22
+ 'OTHER': 14,
23
+ }
24
+
25
+ Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
26
+ 'up', 'down', 'straight', 'curve')
27
+
28
+ Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
29
+
30
+ Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
31
+
32
+ Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
33
+ 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
34
+ 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
35
+
36
+ Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
37
+ 'angrily', 'sadly')
38
+
39
+ VIP_dict = {
40
+ 'Loc_VIP': Loc_list,
41
+ 'Body_VIP': Body_list,
42
+ 'Obj_VIP': Obj_List,
43
+ 'Act_VIP': Act_list,
44
+ 'Desc_VIP': Desc_list,
45
+ }
46
+
47
+
48
+ class WordVectorizer(object):
49
+ def __init__(self, meta_root: str, prefix: str) -> None:
50
+ vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix))
51
+ words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb'))
52
+ word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb'))
53
+ self.word2vec = {w: vectors[word2idx[w]] for w in words}
54
+
55
+ def _get_pos_ohot(self, pos: str) -> np.ndarray:
56
+ pos_vec = np.zeros(len(POS_enumerator))
57
+ if pos in POS_enumerator:
58
+ pos_vec[POS_enumerator[pos]] = 1
59
+ else:
60
+ pos_vec[POS_enumerator['OTHER']] = 1
61
+ return pos_vec
62
+
63
+ def __len__(self) -> int:
64
+ return len(self.word2vec)
65
+
66
+ def __getitem__(self, item: str) -> tuple:
67
+ word, pos = item.split('/')
68
+ if word in self.word2vec:
69
+ word_vec = self.word2vec[word]
70
+ vip_pos = None
71
+ for key, values in VIP_dict.items():
72
+ if word in values:
73
+ vip_pos = key
74
+ break
75
+ if vip_pos is not None:
76
+ pos_vec = self._get_pos_ohot(vip_pos)
77
+ else:
78
+ pos_vec = self._get_pos_ohot(pos)
79
+ else:
80
+ word_vec = self.word2vec['unk']
81
+ pos_vec = self._get_pos_ohot('OTHER')
82
+ return word_vec, pos_vec
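A minimal lookup sketch; the metadata directory and file prefix below are assumptions about how the GloVe files are shipped, not something this diff pins down.

from mld.data.humanml.utils.word_vectorizer import WordVectorizer

# Expects <prefix>_data.npy, <prefix>_words.pkl and <prefix>_idx.pkl under meta_root.
w_vectorizer = WordVectorizer("./glove", "our_vab")
word_emb, pos_oh = w_vectorizer["walk/VERB"]   # GloVe vector plus a 15-dim POS one-hot
word_emb, pos_oh = w_vectorizer["xyzzy/NOUN"]  # unknown words fall back to the 'unk' vector and POS 'OTHER'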
mld/data/utils.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch
2
+
3
+
4
+ def collate_tensors(batch: list) -> torch.Tensor:
5
+ dims = batch[0].dim()
6
+ max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
7
+ size = (len(batch), ) + tuple(max_size)
8
+ canvas = batch[0].new_zeros(size=size)
9
+ for i, b in enumerate(batch):
10
+ sub_tensor = canvas[i]
11
+ for d in range(dims):
12
+ sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
13
+ sub_tensor.add_(b)
14
+ return canvas
15
+
16
+
17
+ def mld_collate(batch: list) -> dict:
18
+ notnone_batches = [b for b in batch if b is not None]
19
+ notnone_batches.sort(key=lambda x: x[3], reverse=True)
20
+ adapted_batch = {
21
+ "motion":
22
+ collate_tensors([torch.tensor(b[4]).float() for b in notnone_batches]),
23
+ "text": [b[2] for b in notnone_batches],
24
+ "length": [b[5] for b in notnone_batches],
25
+ "word_embs":
26
+ collate_tensors([torch.tensor(b[0]).float() for b in notnone_batches]),
27
+ "pos_ohot":
28
+ collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]),
29
+ "text_len":
30
+ collate_tensors([torch.tensor(b[3]) for b in notnone_batches]),
31
+ "tokens": [b[6] for b in notnone_batches],
32
+ }
33
+
34
+ # collate trajectory
35
+ if notnone_batches[0][-1] is not None:
36
+ adapted_batch['hint'] = collate_tensors([torch.tensor(b[-1]).float() for b in notnone_batches])
37
+
38
+ return adapted_batch
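A self-contained sketch of mld_collate with two fake samples laid out like the Text2MotionDatasetV2.__getitem__ tuple (the 300-dim word embeddings are an assumption; the 15-dim POS one-hot matches POS_enumerator):

import numpy as np

from mld.data.utils import mld_collate

def fake_sample(m_length: int) -> tuple:
    # (word_embs, pos_one_hots, caption, sent_len, motion, m_length, tokens, hint)
    return (np.random.randn(22, 300).astype(np.float32),
            np.random.randn(22, 15).astype(np.float32),
            "a person walks forward", 7,
            np.random.randn(m_length, 263).astype(np.float32),
            m_length, "sos/OTHER_walk/VERB_eos/OTHER", None)

batch = mld_collate([fake_sample(64), fake_sample(48)])
print(batch["motion"].shape)  # torch.Size([2, 64, 263]) -- zero-padded to the longest clip
print(batch["length"])        # [64, 48]; a 'hint' key appears only when the dataset returns control hints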
mld/launch/__init__.py ADDED
File without changes
mld/launch/blender.py ADDED
@@ -0,0 +1,23 @@
1
+ # Fix blender path
2
+ import os
3
+ import sys
4
+ from argparse import ArgumentParser
5
+
6
+ sys.path.append(os.path.expanduser("~/.local/lib/python3.9/site-packages"))
7
+
8
+
9
+ # Monkey patch argparse such that
10
+ # blender / python parsing works
11
+ def parse_args(self, args=None, namespace=None):
12
+ if args is not None:
13
+ return self.parse_args_bak(args=args, namespace=namespace)
14
+ try:
15
+ idx = sys.argv.index("--")
16
+ args = sys.argv[idx + 1:] # the list after '--'
17
+ except ValueError as e: # '--' not in the list:
18
+ args = []
19
+ return self.parse_args_bak(args=args, namespace=namespace)
20
+
21
+
22
+ setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args)
23
+ setattr(ArgumentParser, 'parse_args', parse_args)
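A sketch of a Blender-side script that relies on the patch above; the script name and the --npy flag are made up for illustration, only the '--' convention comes from this file.

# Invoked as:  blender --background --python render_motion.py -- --npy results/sample.npy
from argparse import ArgumentParser

import mld.launch.blender  # noqa: F401  (importing applies the parse_args patch)

parser = ArgumentParser()
parser.add_argument("--npy", type=str, required=True)
args = parser.parse_args()  # only sees the tokens after the standalone '--'
print(args.npy)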
mld/models/__init__.py ADDED
File without changes
mld/models/architectures/__init__.py ADDED
File without changes
mld/models/architectures/mld_clip.py ADDED
@@ -0,0 +1,72 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from transformers import AutoModel, AutoTokenizer
5
+ from sentence_transformers import SentenceTransformer
6
+
7
+
8
+ class MldTextEncoder(nn.Module):
9
+
10
+ def __init__(self, modelpath: str, last_hidden_state: bool = False) -> None:
11
+ super().__init__()
12
+
13
+ if 't5' in modelpath:
14
+ self.text_model = SentenceTransformer(modelpath)
15
+ self.tokenizer = self.text_model.tokenizer
16
+ else:
17
+ self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
18
+ self.text_model = AutoModel.from_pretrained(modelpath)
19
+
20
+ self.max_length = self.tokenizer.model_max_length
21
+ if "clip" in modelpath:
22
+ self.text_encoded_dim = self.text_model.config.text_config.hidden_size
23
+ if last_hidden_state:
24
+ self.name = "clip_hidden"
25
+ else:
26
+ self.name = "clip"
27
+ elif "bert" in modelpath:
28
+ self.name = "bert"
29
+ self.text_encoded_dim = self.text_model.config.hidden_size
30
+ elif 't5' in modelpath:
31
+ self.name = 't5'
32
+ else:
33
+ raise ValueError(f"Model {modelpath} not supported")
34
+
35
+ def forward(self, texts: list[str]) -> torch.Tensor:
36
+ # get prompt text embeddings
37
+ if self.name in ["clip", "clip_hidden"]:
38
+ text_inputs = self.tokenizer(
39
+ texts,
40
+ padding="max_length",
41
+ truncation=True,
42
+ max_length=self.max_length,
43
+ return_tensors="pt",
44
+ )
45
+ text_input_ids = text_inputs.input_ids
46
+ # split into max length Clip can handle
47
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
48
+ text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
49
+ elif self.name == "bert":
50
+ text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
51
+
52
+ if self.name == "clip":
53
+ # (batch_Size, text_encoded_dim)
54
+ text_embeddings = self.text_model.get_text_features(
55
+ text_input_ids.to(self.text_model.device))
56
+ # (batch_Size, 1, text_encoded_dim)
57
+ text_embeddings = text_embeddings.unsqueeze(1)
58
+ elif self.name == "clip_hidden":
59
+ # (batch_Size, seq_length , text_encoded_dim)
60
+ text_embeddings = self.text_model.text_model(
61
+ text_input_ids.to(self.text_model.device)).last_hidden_state
62
+ elif self.name == "bert":
63
+ # (batch_Size, seq_length , text_encoded_dim)
64
+ text_embeddings = self.text_model(
65
+ **text_inputs.to(self.text_model.device)).last_hidden_state
66
+ elif self.name == 't5':
67
+ text_embeddings = self.text_model.encode(texts, show_progress_bar=False, convert_to_tensor=True, batch_size=len(texts))
68
+ text_embeddings = text_embeddings.unsqueeze(1)
69
+ else:
70
+ raise NotImplementedError(f"Model {self.name} not implemented")
71
+
72
+ return text_embeddings
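A minimal sketch of the text encoder above; the checkpoint name is an assumption (any Hugging Face CLIP/BERT path, or a sentence-transformers T5 model, accepted by the constructor should work).

import torch

from mld.models.architectures.mld_clip import MldTextEncoder

text_encoder = MldTextEncoder("openai/clip-vit-large-patch14", last_hidden_state=False)
with torch.no_grad():
    emb = text_encoder(["a person walks forward and waves"])
print(emb.shape)  # (1, 1, text_encoded_dim) for pooled CLIP; (1, seq_len, dim) for clip_hidden / bert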
mld/models/architectures/mld_denoiser.py ADDED
@@ -0,0 +1,172 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from mld.models.architectures.tools.embeddings import (TimestepEmbedding,
7
+ Timesteps)
8
+ from mld.models.operator.cross_attention import (SkipTransformerEncoder,
9
+ TransformerDecoder,
10
+ TransformerDecoderLayer,
11
+ TransformerEncoder,
12
+ TransformerEncoderLayer)
13
+ from mld.models.operator.position_encoding import build_position_encoding
14
+
15
+
16
+ class MldDenoiser(nn.Module):
17
+
18
+ def __init__(self,
19
+ latent_dim: list = [1, 256],
20
+ ff_size: int = 1024,
21
+ num_layers: int = 6,
22
+ num_heads: int = 4,
23
+ dropout: float = 0.1,
24
+ normalize_before: bool = False,
25
+ activation: str = "gelu",
26
+ flip_sin_to_cos: bool = True,
27
+ return_intermediate_dec: bool = False,
28
+ position_embedding: str = "learned",
29
+ arch: str = "trans_enc",
30
+ freq_shift: float = 0,
31
+ text_encoded_dim: int = 768,
32
+ time_cond_proj_dim: int = None,
33
+ is_controlnet: bool = False) -> None:
34
+
35
+ super().__init__()
36
+
37
+ self.latent_dim = latent_dim[-1]
38
+ self.text_encoded_dim = text_encoded_dim
39
+
40
+ self.arch = arch
41
+ self.time_cond_proj_dim = time_cond_proj_dim
42
+
43
+ self.time_proj = Timesteps(text_encoded_dim, flip_sin_to_cos, freq_shift)
44
+ self.time_embedding = TimestepEmbedding(text_encoded_dim, self.latent_dim, cond_proj_dim=time_cond_proj_dim)
45
+ if text_encoded_dim != self.latent_dim:
46
+ self.emb_proj = nn.Sequential(nn.ReLU(), nn.Linear(text_encoded_dim, self.latent_dim))
47
+
48
+ self.query_pos = build_position_encoding(
49
+ self.latent_dim, position_embedding=position_embedding)
50
+
51
+ if self.arch == "trans_enc":
52
+ encoder_layer = TransformerEncoderLayer(
53
+ self.latent_dim,
54
+ num_heads,
55
+ ff_size,
56
+ dropout,
57
+ activation,
58
+ normalize_before,
59
+ )
60
+ encoder_norm = None if is_controlnet else nn.LayerNorm(self.latent_dim)
61
+ self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm,
62
+ return_intermediate=is_controlnet)
63
+
64
+ elif self.arch == "trans_dec":
65
+ assert not is_controlnet, "controlnet is not supported in the 'trans_dec' architecture"
66
+ self.mem_pos = build_position_encoding(
67
+ self.latent_dim, position_embedding=position_embedding)
68
+
69
+ decoder_layer = TransformerDecoderLayer(
70
+ self.latent_dim,
71
+ num_heads,
72
+ ff_size,
73
+ dropout,
74
+ activation,
75
+ normalize_before,
76
+ )
77
+ decoder_norm = nn.LayerNorm(self.latent_dim)
78
+ self.decoder = TransformerDecoder(
79
+ decoder_layer,
80
+ num_layers,
81
+ decoder_norm,
82
+ return_intermediate=return_intermediate_dec,
83
+ )
84
+
85
+ else:
86
+ raise ValueError(f"Unsupported architecture: {self.arch}!")
87
+
88
+ self.is_controlnet = is_controlnet
89
+
90
+ def zero_module(module):
91
+ for p in module.parameters():
92
+ nn.init.zeros_(p)
93
+ return module
94
+
95
+ if self.is_controlnet:
96
+ self.controlnet_cond_embedding = nn.Sequential(
97
+ nn.Linear(self.latent_dim, self.latent_dim),
98
+ nn.Linear(self.latent_dim, self.latent_dim),
99
+ zero_module(nn.Linear(self.latent_dim, self.latent_dim))
100
+ )
101
+
102
+ self.controlnet_down_mid_blocks = nn.ModuleList([
103
+ zero_module(nn.Linear(self.latent_dim, self.latent_dim)) for _ in range(num_layers)])
104
+
105
+ def forward(self,
106
+ sample: torch.Tensor,
107
+ timestep: torch.Tensor,
108
+ encoder_hidden_states: torch.Tensor,
109
+ timestep_cond: Optional[torch.Tensor] = None,
110
+ controlnet_cond: Optional[torch.Tensor] = None,
111
+ controlnet_residuals: Optional[list[torch.Tensor]] = None
112
+ ) -> Union[torch.Tensor, list[torch.Tensor]]:
113
+
114
+ # 0. dimension matching
115
+ # sample [latent_dim[0], batch_size, latent_dim] <= [batch_size, latent_dim[0], latent_dim[1]]
116
+ sample = sample.permute(1, 0, 2)
117
+
118
+ # 1. check if controlnet
119
+ if self.is_controlnet:
120
+ controlnet_cond = controlnet_cond.permute(1, 0, 2)
121
+ sample = sample + self.controlnet_cond_embedding(controlnet_cond)
122
+
123
+ # 2. time_embedding
124
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
125
+ timesteps = timestep.expand(sample.shape[1]).clone()
126
+ time_emb = self.time_proj(timesteps)
127
+ time_emb = time_emb.to(dtype=sample.dtype)
128
+ # [1, bs, latent_dim] <= [bs, latent_dim]
129
+ time_emb = self.time_embedding(time_emb, timestep_cond).unsqueeze(0)
130
+
131
+ # 3. condition + time embedding
132
+ # text_emb [seq_len, batch_size, text_encoded_dim] <= [batch_size, seq_len, text_encoded_dim]
133
+ encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2)
134
+ text_emb = encoder_hidden_states # [num_words, bs, latent_dim]
135
+ # text embedding projection
136
+ if self.text_encoded_dim != self.latent_dim:
137
+ # [1 or 2, bs, latent_dim] <= [1 or 2, bs, text_encoded_dim]
138
+ text_emb_latent = self.emb_proj(text_emb)
139
+ else:
140
+ text_emb_latent = text_emb
141
+ emb_latent = torch.cat((time_emb, text_emb_latent), 0)
142
+
143
+ # 4. transformer
144
+ if self.arch == "trans_enc":
145
+ xseq = torch.cat((sample, emb_latent), axis=0)
146
+
147
+ xseq = self.query_pos(xseq)
148
+ tokens = self.encoder(xseq, controlnet_residuals=controlnet_residuals)
149
+
150
+ if self.is_controlnet:
151
+ control_res_samples = []
152
+ for res, block in zip(tokens, self.controlnet_down_mid_blocks):
153
+ r = block(res)
154
+ control_res_samples.append(r)
155
+ return control_res_samples
156
+
157
+ sample = tokens[:sample.shape[0]]
158
+
159
+ elif self.arch == "trans_dec":
160
+ # tgt - [1 or 5 or 10, bs, latent_dim]
161
+ # memory - [token_num, bs, latent_dim]
162
+ sample = self.query_pos(sample)
163
+ emb_latent = self.mem_pos(emb_latent)
164
+ sample = self.decoder(tgt=sample, memory=emb_latent).squeeze(0)
165
+
166
+ else:
167
+ raise TypeError(f"{self.arch} is not supported")
168
+
169
+ # 5. [batch_size, latent_dim[0], latent_dim[1]] <= [latent_dim[0], batch_size, latent_dim[1]]
170
+ sample = sample.permute(1, 0, 2)
171
+
172
+ return sample
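A shape-level sketch of a forward pass through the default 'trans_enc' denoiser, following the permutation comments above (hyper-parameters are the constructor defaults):

import torch

from mld.models.architectures.mld_denoiser import MldDenoiser

denoiser = MldDenoiser(latent_dim=[1, 256], text_encoded_dim=768)
sample = torch.randn(4, 1, 256)        # noisy motion latents [bs, latent_dim[0], latent_dim[1]]
timestep = torch.tensor([10])          # broadcast over the batch inside forward
text_emb = torch.randn(4, 1, 768)      # pooled text features [bs, seq_len, text_encoded_dim]
out = denoiser(sample, timestep, text_emb)
print(out.shape)                       # torch.Size([4, 1, 256])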
mld/models/architectures/mld_traj_encoder.py ADDED
@@ -0,0 +1,78 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from mld.models.operator.cross_attention import SkipTransformerEncoder, TransformerEncoderLayer
7
+ from mld.models.operator.position_encoding import build_position_encoding
8
+ from mld.utils.temos_utils import lengths_to_mask
9
+
10
+
11
+ class MldTrajEncoder(nn.Module):
12
+
13
+ def __init__(self,
14
+ nfeats: int,
15
+ latent_dim: list = [1, 256],
16
+ ff_size: int = 1024,
17
+ num_layers: int = 9,
18
+ num_heads: int = 4,
19
+ dropout: float = 0.1,
20
+ normalize_before: bool = False,
21
+ activation: str = "gelu",
22
+ position_embedding: str = "learned") -> None:
23
+
24
+ super().__init__()
25
+ self.latent_size = latent_dim[0]
26
+ self.latent_dim = latent_dim[-1]
27
+
28
+ self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim)
29
+
30
+ self.query_pos_encoder = build_position_encoding(
31
+ self.latent_dim, position_embedding=position_embedding)
32
+
33
+ encoder_layer = TransformerEncoderLayer(
34
+ self.latent_dim,
35
+ num_heads,
36
+ ff_size,
37
+ dropout,
38
+ activation,
39
+ normalize_before,
40
+ )
41
+ encoder_norm = nn.LayerNorm(self.latent_dim)
42
+ self.encoder = SkipTransformerEncoder(encoder_layer, num_layers,
43
+ encoder_norm)
44
+
45
+ self.global_motion_token = nn.Parameter(
46
+ torch.randn(self.latent_size, self.latent_dim))
47
+
48
+ def forward(self, features: torch.Tensor, lengths: Optional[list[int]] = None,
49
+ mask: Optional[torch.Tensor] = None) -> torch.Tensor:
50
+
51
+ if lengths is None and mask is None:
52
+ lengths = [len(feature) for feature in features]
53
+ mask = lengths_to_mask(lengths, features.device)
54
+
55
+ bs, nframes, nfeats = features.shape
56
+
57
+ x = features
58
+ # Embed each human pose into a latent vector
59
+ x = self.skel_embedding(x)
60
+
61
+ # Switch sequence and batch_size because the input of
62
+ # Pytorch Transformer is [Sequence, Batch size, ...]
63
+ x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim]
64
+
65
+ # Each batch has its own set of tokens
66
+ dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
67
+
68
+ # create a bigger mask, to allow attend to emb
69
+ dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
70
+ aug_mask = torch.cat((dist_masks, mask), 1)
71
+
72
+ # adding the embedding token for all sequences
73
+ xseq = torch.cat((dist, x), 0)
74
+
75
+ xseq = self.query_pos_encoder(xseq)
76
+ global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[:dist.shape[0]]
77
+
78
+ return global_token
mld/models/architectures/mld_vae.py ADDED
@@ -0,0 +1,154 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.distributions.distribution import Distribution
6
+
7
+ from mld.models.operator.cross_attention import (
8
+ SkipTransformerEncoder,
9
+ SkipTransformerDecoder,
10
+ TransformerDecoder,
11
+ TransformerDecoderLayer,
12
+ TransformerEncoder,
13
+ TransformerEncoderLayer,
14
+ )
15
+ from mld.models.operator.position_encoding import build_position_encoding
16
+ from mld.utils.temos_utils import lengths_to_mask
17
+
18
+
19
+ class MldVae(nn.Module):
20
+
21
+ def __init__(self,
22
+ nfeats: int,
23
+ latent_dim: list = [1, 256],
24
+ ff_size: int = 1024,
25
+ num_layers: int = 9,
26
+ num_heads: int = 4,
27
+ dropout: float = 0.1,
28
+ arch: str = "encoder_decoder",
29
+ normalize_before: bool = False,
30
+ activation: str = "gelu",
31
+ position_embedding: str = "learned") -> None:
32
+
33
+ super().__init__()
34
+
35
+ self.latent_size = latent_dim[0]
36
+ self.latent_dim = latent_dim[-1]
37
+ input_feats = nfeats
38
+ output_feats = nfeats
39
+ self.arch = arch
40
+
41
+ self.query_pos_encoder = build_position_encoding(
42
+ self.latent_dim, position_embedding=position_embedding)
43
+
44
+ encoder_layer = TransformerEncoderLayer(
45
+ self.latent_dim,
46
+ num_heads,
47
+ ff_size,
48
+ dropout,
49
+ activation,
50
+ normalize_before,
51
+ )
52
+ encoder_norm = nn.LayerNorm(self.latent_dim)
53
+ self.encoder = SkipTransformerEncoder(encoder_layer, num_layers,
54
+ encoder_norm)
55
+
56
+ if self.arch == "all_encoder":
57
+ decoder_norm = nn.LayerNorm(self.latent_dim)
58
+ self.decoder = SkipTransformerEncoder(encoder_layer, num_layers,
59
+ decoder_norm)
60
+ elif self.arch == 'encoder_decoder':
61
+ self.query_pos_decoder = build_position_encoding(
62
+ self.latent_dim, position_embedding=position_embedding)
63
+
64
+ decoder_layer = TransformerDecoderLayer(
65
+ self.latent_dim,
66
+ num_heads,
67
+ ff_size,
68
+ dropout,
69
+ activation,
70
+ normalize_before,
71
+ )
72
+ decoder_norm = nn.LayerNorm(self.latent_dim)
73
+ self.decoder = SkipTransformerDecoder(decoder_layer, num_layers,
74
+ decoder_norm)
75
+ else:
76
+ raise ValueError(f"Unsupported architecture: {self.arch}!")
77
+
78
+ self.global_motion_token = nn.Parameter(
79
+ torch.randn(self.latent_size * 2, self.latent_dim))
80
+
81
+ self.skel_embedding = nn.Linear(input_feats, self.latent_dim)
82
+ self.final_layer = nn.Linear(self.latent_dim, output_feats)
83
+
84
+ def forward(self, features: torch.Tensor,
85
+ lengths: Optional[list[int]] = None) -> tuple[torch.Tensor, torch.Tensor, Distribution]:
86
+ z, dist = self.encode(features, lengths)
87
+ feats_rst = self.decode(z, lengths)
88
+ return feats_rst, z, dist
89
+
90
+ def encode(self, features: torch.Tensor,
91
+ lengths: Optional[list[int]] = None) -> tuple[torch.Tensor, Distribution]:
92
+ if lengths is None:
93
+ lengths = [len(feature) for feature in features]
94
+
95
+ device = features.device
96
+
97
+ bs, nframes, nfeats = features.shape
98
+ mask = lengths_to_mask(lengths, device)
99
+
100
+ x = features
101
+ # Embed each human pose into a latent vector
102
+ x = self.skel_embedding(x)
103
+
104
+ # Switch sequence and batch_size because the input of
105
+ # Pytorch Transformer is [Sequence, Batch size, ...]
106
+ x = x.permute(1, 0, 2) # now it is [nframes, bs, latent_dim]
107
+
108
+ # Each batch has its own set of tokens
109
+ dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
110
+
111
+ # create a bigger mask, to allow attend to emb
112
+ dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
113
+ aug_mask = torch.cat((dist_masks, mask), 1)
114
+
115
+ # adding the embedding token for all sequences
116
+ xseq = torch.cat((dist, x), 0)
117
+
118
+ xseq = self.query_pos_encoder(xseq)
119
+ dist = self.encoder(xseq, src_key_padding_mask=~aug_mask)[:dist.shape[0]]
120
+
121
+ mu = dist[0:self.latent_size, ...]
122
+ logvar = dist[self.latent_size:, ...]
123
+
124
+ # resampling
125
+ std = logvar.exp().pow(0.5)
126
+ dist = torch.distributions.Normal(mu, std)
127
+ latent = dist.rsample()
128
+ return latent, dist
129
+
130
+ def decode(self, z: torch.Tensor, lengths: list[int]) -> torch.Tensor:
131
+ mask = lengths_to_mask(lengths, z.device)
132
+ bs, nframes = mask.shape
133
+ queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)
134
+
135
+ if self.arch == "all_encoder":
136
+ xseq = torch.cat((z, queries), axis=0)
137
+ z_mask = torch.ones((bs, self.latent_size), dtype=torch.bool, device=z.device)
138
+ aug_mask = torch.cat((z_mask, mask), axis=1)
139
+ xseq = self.query_pos_decoder(xseq)
140
+ output = self.decoder(xseq, src_key_padding_mask=~aug_mask)[z.shape[0]:]
141
+
142
+ elif self.arch == "encoder_decoder":
143
+ queries = self.query_pos_decoder(queries)
144
+ output = self.decoder(
145
+ tgt=queries,
146
+ memory=z,
147
+ tgt_key_padding_mask=~mask)
148
+
149
+ output = self.final_layer(output)
150
+ # zero for padded area
151
+ output[~mask.T] = 0
152
+ # Pytorch Transformer: [Sequence, Batch size, ...]
153
+ feats = output.permute(1, 0, 2)
154
+ return feats
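A shape-level sketch of the VAE round trip; 263 is the HumanML3D feature dimension used elsewhere in this diff.

import torch

from mld.models.architectures.mld_vae import MldVae

vae = MldVae(nfeats=263, latent_dim=[1, 256])
features = torch.randn(2, 60, 263)   # [bs, nframes, nfeats]
lengths = [60, 48]                   # valid frames per sequence; the rest is treated as padding
recon, z, dist = vae(features, lengths)
print(z.shape, recon.shape)          # torch.Size([1, 2, 256]) torch.Size([2, 60, 263])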
mld/models/architectures/t2m_motionenc.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.rnn import pack_padded_sequence
4
+
5
+
6
+ class MovementConvEncoder(nn.Module):
7
+ def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
8
+ super(MovementConvEncoder, self).__init__()
9
+ self.main = nn.Sequential(
10
+ nn.Conv1d(input_size, hidden_size, 4, 2, 1),
11
+ nn.Dropout(0.2, inplace=True),
12
+ nn.LeakyReLU(0.2, inplace=True),
13
+ nn.Conv1d(hidden_size, output_size, 4, 2, 1),
14
+ nn.Dropout(0.2, inplace=True),
15
+ nn.LeakyReLU(0.2, inplace=True),
16
+ )
17
+ self.out_net = nn.Linear(output_size, output_size)
18
+
19
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
20
+ inputs = inputs.permute(0, 2, 1)
21
+ outputs = self.main(inputs).permute(0, 2, 1)
22
+ return self.out_net(outputs)
23
+
24
+
25
+ class MotionEncoderBiGRUCo(nn.Module):
26
+ def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
27
+ super(MotionEncoderBiGRUCo, self).__init__()
28
+
29
+ self.input_emb = nn.Linear(input_size, hidden_size)
30
+ self.gru = nn.GRU(
31
+ hidden_size, hidden_size, batch_first=True, bidirectional=True
32
+ )
33
+ self.output_net = nn.Sequential(
34
+ nn.Linear(hidden_size * 2, hidden_size),
35
+ nn.LayerNorm(hidden_size),
36
+ nn.LeakyReLU(0.2, inplace=True),
37
+ nn.Linear(hidden_size, output_size),
38
+ )
39
+
40
+ self.hidden_size = hidden_size
41
+ self.hidden = nn.Parameter(
42
+ torch.randn((2, 1, self.hidden_size), requires_grad=True)
43
+ )
44
+
45
+ def forward(self, inputs: torch.Tensor, m_lens: torch.Tensor) -> torch.Tensor:
46
+ num_samples = inputs.shape[0]
47
+
48
+ input_embs = self.input_emb(inputs)
49
+ hidden = self.hidden.repeat(1, num_samples, 1)
50
+
51
+ cap_lens = m_lens.data.tolist()
52
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
53
+
54
+ gru_seq, gru_last = self.gru(emb, hidden)
55
+
56
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
57
+
58
+ return self.output_net(gru_last)
mld/models/architectures/t2m_textenc.py ADDED
@@ -0,0 +1,43 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn.utils.rnn import pack_padded_sequence
4
+
5
+
6
+ class TextEncoderBiGRUCo(nn.Module):
7
+ def __init__(self, word_size: int, pos_size: int, hidden_size: int, output_size: int) -> None:
8
+ super(TextEncoderBiGRUCo, self).__init__()
9
+
10
+ self.pos_emb = nn.Linear(pos_size, word_size)
11
+ self.input_emb = nn.Linear(word_size, hidden_size)
12
+ self.gru = nn.GRU(
13
+ hidden_size, hidden_size, batch_first=True, bidirectional=True
14
+ )
15
+ self.output_net = nn.Sequential(
16
+ nn.Linear(hidden_size * 2, hidden_size),
17
+ nn.LayerNorm(hidden_size),
18
+ nn.LeakyReLU(0.2, inplace=True),
19
+ nn.Linear(hidden_size, output_size),
20
+ )
21
+
22
+ self.hidden_size = hidden_size
23
+ self.hidden = nn.Parameter(
24
+ torch.randn((2, 1, self.hidden_size), requires_grad=True)
25
+ )
26
+
27
+ def forward(self, word_embs: torch.Tensor, pos_onehot: torch.Tensor,
28
+ cap_lens: torch.Tensor) -> torch.Tensor:
29
+ num_samples = word_embs.shape[0]
30
+
31
+ pos_embs = self.pos_emb(pos_onehot)
32
+ inputs = word_embs + pos_embs
33
+ input_embs = self.input_emb(inputs)
34
+ hidden = self.hidden.repeat(1, num_samples, 1)
35
+
36
+ cap_lens = cap_lens.data.tolist()
37
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
38
+
39
+ gru_seq, gru_last = self.gru(emb, hidden)
40
+
41
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
42
+
43
+ return self.output_net(gru_last)
mld/models/architectures/tools/embeddings.py ADDED
@@ -0,0 +1,89 @@
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ def get_timestep_embedding(
9
+ timesteps: torch.Tensor,
10
+ embedding_dim: int,
11
+ flip_sin_to_cos: bool = False,
12
+ downscale_freq_shift: float = 1,
13
+ scale: float = 1,
14
+ max_period: int = 10000,
15
+ ) -> torch.Tensor:
16
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
17
+
18
+ half_dim = embedding_dim // 2
19
+ exponent = -math.log(max_period) * torch.arange(
20
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
21
+ )
22
+ exponent = exponent / (half_dim - downscale_freq_shift)
23
+
24
+ emb = torch.exp(exponent)
25
+ emb = timesteps[:, None].float() * emb[None, :]
26
+
27
+ # scale embeddings
28
+ emb = scale * emb
29
+
30
+ # concat sine and cosine embeddings
31
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
32
+
33
+ # flip sine and cosine embeddings
34
+ if flip_sin_to_cos:
35
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
36
+
37
+ # zero pad
38
+ if embedding_dim % 2 == 1:
39
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
40
+ return emb
41
+
42
+
43
+ class TimestepEmbedding(nn.Module):
44
+ def __init__(self, channel: int, time_embed_dim: int,
45
+ act_fn: str = "silu", cond_proj_dim: Optional[int] = None) -> None:
46
+ super().__init__()
47
+
48
+ # optional projection for a distilled classifier-free guidance (CFG) embedding
49
+ if cond_proj_dim is not None:
50
+ self.cond_proj = nn.Linear(cond_proj_dim, channel, bias=False)
51
+ self.cond_proj.weight.data.fill_(0.0)
52
+ else:
53
+ self.cond_proj = None
54
+
55
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
56
+ self.act = None
57
+ if act_fn == "silu":
58
+ self.act = nn.SiLU()
59
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
60
+
61
+ def forward(self, sample: torch.Tensor, timestep_cond: Optional[torch.Tensor] = None) -> torch.Tensor:
62
+ if timestep_cond is not None:
63
+ sample = sample + self.cond_proj(timestep_cond)
64
+
65
+ sample = self.linear_1(sample)
66
+
67
+ if self.act is not None:
68
+ sample = self.act(sample)
69
+
70
+ sample = self.linear_2(sample)
71
+ return sample
72
+
73
+
74
+ class Timesteps(nn.Module):
75
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool,
76
+ downscale_freq_shift: float) -> None:
77
+ super().__init__()
78
+ self.num_channels = num_channels
79
+ self.flip_sin_to_cos = flip_sin_to_cos
80
+ self.downscale_freq_shift = downscale_freq_shift
81
+
82
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
83
+ t_emb = get_timestep_embedding(
84
+ timesteps,
85
+ self.num_channels,
86
+ flip_sin_to_cos=self.flip_sin_to_cos,
87
+ downscale_freq_shift=self.downscale_freq_shift,
88
+ )
89
+ return t_emb
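A minimal sketch of how the two modules above are chained: sinusoidal features first, then the MLP, with the optional zero-initialised cond_proj used when a distilled guidance embedding is provided (the 256-dim condition below is an assumption).

import torch

from mld.models.architectures.tools.embeddings import TimestepEmbedding, Timesteps

time_proj = Timesteps(num_channels=768, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(channel=768, time_embed_dim=256, cond_proj_dim=256)

timesteps = torch.randint(0, 1000, (4,))   # one diffusion step per batch element
t_emb = time_proj(timesteps)               # (4, 768) sinusoidal features
w_emb = torch.randn(4, 256)                # e.g. an embedded guidance scale (assumed)
print(time_embedding(t_emb, w_emb).shape)  # torch.Size([4, 256])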
mld/models/metrics/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .tm2t import TM2TMetrics
2
+ from .mm import MMMetrics
3
+ from .cm import ControlMetrics
mld/models/metrics/cm.py ADDED
@@ -0,0 +1,55 @@
1
+ import torch
2
+ from torchmetrics import Metric
3
+ from torchmetrics.utilities import dim_zero_cat
4
+
5
+ from mld.utils.temos_utils import remove_padding
6
+ from .utils import calculate_skating_ratio, calculate_trajectory_error, control_l2
7
+
8
+
9
+ class ControlMetrics(Metric):
10
+
11
+ def __init__(self, dist_sync_on_step: bool = True) -> None:
12
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
13
+
14
+ self.name = "control_metrics"
15
+
16
+ self.add_state("count_seq", default=torch.tensor(0), dist_reduce_fx="sum")
17
+ self.add_state("skate_ratio_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
18
+ self.add_state("dist_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
19
+ self.add_state("traj_err", default=[], dist_reduce_fx="cat")
20
+ self.traj_err_key = ["traj_fail_20cm", "traj_fail_50cm", "kps_fail_20cm", "kps_fail_50cm", "kps_mean_err(m)"]
21
+
22
+ def compute(self) -> dict:
23
+ count_seq = self.count_seq.item()
24
+
25
+ metrics = dict()
26
+ metrics['Skating Ratio'] = self.skate_ratio_sum / count_seq
27
+ metrics['Control L2 dist'] = self.dist_sum / count_seq
28
+ traj_err = dim_zero_cat(self.traj_err).mean(0)
29
+
30
+ for (k, v) in zip(self.traj_err_key, traj_err):
31
+ metrics[k] = v
32
+
33
+ return {**metrics}
34
+
35
+ def update(self, joints: torch.Tensor, hint: torch.Tensor,
36
+ mask_hint: torch.Tensor, lengths: list[int]) -> None:
37
+ self.count_seq += len(lengths)
38
+
39
+ joints_no_padding = remove_padding(joints, lengths)
40
+ for j in joints_no_padding:
41
+ skate_ratio, _ = calculate_skating_ratio(j.unsqueeze(0).permute(0, 2, 3, 1))
42
+ self.skate_ratio_sum += skate_ratio[0]
43
+
44
+ joints_np = joints.cpu().numpy()
45
+ hint_np = hint.cpu().numpy()
46
+ mask_hint_np = mask_hint.cpu().numpy()
47
+
48
+ for j, h, m in zip(joints_np, hint_np, mask_hint_np):
49
+ control_error = control_l2(j[None], h[None], m[None])
50
+ mean_error = control_error.sum() / m.sum()
51
+ self.dist_sum += mean_error
52
+ control_error = control_error.reshape(-1)
53
+ m = m.reshape(-1)
54
+ err_np = calculate_trajectory_error(control_error, mean_error, m)
55
+ self.traj_err.append(torch.tensor(err_np[None], device=joints.device))