wxDai committed commit 6b1e9f7
Parent(s): init
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +35 -0
- .gitignore +10 -0
- LICENSE +25 -0
- README.md +14 -0
- app.py +258 -0
- configs/mld_control.yaml +105 -0
- configs/mld_t2m_infer.yaml +72 -0
- configs/modules/denoiser.yaml +16 -0
- configs/modules/motion_vae.yaml +13 -0
- configs/modules/scheduler.yaml +20 -0
- configs/modules/text_encoder.yaml +5 -0
- configs/modules/traj_encoder.yaml +12 -0
- configs/modules_mld/denoiser.yaml +16 -0
- configs/modules_mld/motion_vae.yaml +13 -0
- configs/modules_mld/scheduler.yaml +23 -0
- configs/modules_mld/text_encoder.yaml +5 -0
- configs/modules_mld/traj_encoder.yaml +12 -0
- configs/motionlcm_control.yaml +105 -0
- configs/motionlcm_t2m.yaml +100 -0
- demo.py +154 -0
- fit.py +134 -0
- mld/__init__.py +0 -0
- mld/config.py +47 -0
- mld/data/HumanML3D.py +79 -0
- mld/data/Kit.py +79 -0
- mld/data/__init__.py +0 -0
- mld/data/base.py +65 -0
- mld/data/get_data.py +93 -0
- mld/data/humanml/__init__.py +0 -0
- mld/data/humanml/common/quaternion.py +29 -0
- mld/data/humanml/dataset.py +290 -0
- mld/data/humanml/scripts/motion_process.py +51 -0
- mld/data/humanml/utils/__init__.py +0 -0
- mld/data/humanml/utils/paramUtil.py +62 -0
- mld/data/humanml/utils/plot_script.py +98 -0
- mld/data/humanml/utils/word_vectorizer.py +82 -0
- mld/data/utils.py +38 -0
- mld/launch/__init__.py +0 -0
- mld/launch/blender.py +23 -0
- mld/models/__init__.py +0 -0
- mld/models/architectures/__init__.py +0 -0
- mld/models/architectures/mld_clip.py +72 -0
- mld/models/architectures/mld_denoiser.py +172 -0
- mld/models/architectures/mld_traj_encoder.py +78 -0
- mld/models/architectures/mld_vae.py +154 -0
- mld/models/architectures/t2m_motionenc.py +58 -0
- mld/models/architectures/t2m_textenc.py +43 -0
- mld/models/architectures/tools/embeddings.py +89 -0
- mld/models/metrics/__init__.py +3 -0
- mld/models/metrics/cm.py +55 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,10 @@
+**/*.pyc
+.idea/
+__pycache__/
+
+deps/
+datasets/
+experiments_t2m/
+experiments_t2m_test/
+experiments_control/
+experiments_control_test/
LICENSE
ADDED
@@ -0,0 +1,25 @@
+Copyright Tsinghua University and Shanghai AI Laboratory. All Rights Reserved.
+
+License for Non-commercial Scientific Research Purposes.
+
+For more information see <https://github.com/Dai-Wenxun/MotionLCM>.
+If you use this software, please cite the corresponding publications
+listed on the above website.
+
+Permission to use, copy, modify, and distribute this software and its
+documentation for educational, research, and non-profit purposes only.
+Any modification based on this work must be open-source and prohibited
+for commercial, pornographic, military, or surveillance use.
+
+The authors grant you a non-exclusive, worldwide, non-transferable,
+non-sublicensable, revocable, royalty-free, and limited license under
+our copyright interests to reproduce, distribute, and create derivative
+works of the text, videos, and codes solely for your non-commercial
+research purposes.
+
+You must retain, in the source form of any derivative works that you
+distribute, all copyright, patent, trademark, and attribution notices
+from the source form of this work.
+
+For commercial uses of this software, please send email to all people
+in the author list.
README.md
ADDED
@@ -0,0 +1,14 @@
+---
+title: MotionLCM
+emoji: 🏃
+colorFrom: yellow
+colorTo: gray
+sdk: gradio
+sdk_version: 3.24.1
+app_file: app.py
+pinned: false
+license: other
+python_version: 3.10.12
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,258 @@
+import os
+import time
+import random
+import datetime
+import os.path as osp
+from functools import partial
+
+import torch
+import gradio as gr
+from omegaconf import OmegaConf
+
+from mld.config import get_module_config
+from mld.data.get_data import get_datasets
+from mld.models.modeltype.mld import MLD
+from mld.utils.utils import set_seed
+from mld.data.humanml.utils.plot_script import plot_3d_motion
+
+WEBSITE = """
+<div class="embed_hidden">
+<h1 style='text-align: center'> MotionLCM: Real-time Controllable Motion Generation via Latent Consistency Model </h1>
+
+<h2 style='text-align: center'>
+<a href="https://github.com/Dai-Wenxun/" target="_blank"><nobr>Wenxun Dai</nobr><sup>1</sup></a> &emsp;
+<a href="https://lhchen.top/" target="_blank"><nobr>Ling-Hao Chen</nobr></a><sup>1</sup> &emsp;
+<a href="https://wangjingbo1219.github.io/" target="_blank"><nobr>Jingbo Wang</nobr></a><sup>2</sup> &emsp;
+<a href="https://moonsliu.github.io/" target="_blank"><nobr>Jinpeng Liu</nobr></a><sup>1</sup> &emsp;
+<a href="https://daibo.info/" target="_blank"><nobr>Bo Dai</nobr></a><sup>2</sup> &emsp;
+<a href="https://andytang15.github.io/" target="_blank"><nobr>Yansong Tang</nobr></a><sup>1</sup>
+</h2>
+
+<h2 style='text-align: center'>
+<nobr><sup>1</sup>Tsinghua University</nobr> &emsp;
+<nobr><sup>2</sup>Shanghai AI Laboratory</nobr>
+</h2>
+
+</div>
+"""
+
+WEBSITE_bottom = """
+<div class="embed_hidden">
+<p>
+Space adapted from <a href="https://huggingface.co/spaces/Mathux/TMR" target="_blank">TMR</a>
+and <a href="https://huggingface.co/spaces/MeYourHint/MoMask" target="_blank">MoMask</a>.
+</p>
+</div>
+"""
+
+EXAMPLES = [
+    "a person does a jump",
+    "a person waves both arms in the air.",
+    "The person takes 4 steps backwards.",
+    "this person bends forward as if to bow.",
+    "The person was pushed but did not fall.",
+    "a man walks forward in a snake like pattern.",
+    "a man paces back and forth along the same line.",
+    "with arms out to the sides a person walks forward",
+    "A man bends down and picks something up with his right hand.",
+    "The man walked forward, spun right on one foot and walked back to his original position.",
+    "a person slightly bent over with right hand pressing against the air walks forward slowly"
+]
+
+CSS = """
+.contour_video {
+    display: flex;
+    flex-direction: column;
+    justify-content: center;
+    align-items: center;
+    z-index: var(--layer-5);
+    border-radius: var(--block-radius);
+    background: var(--background-fill-primary);
+    padding: 0 var(--size-6);
+    max-height: var(--size-screen-h);
+    overflow: hidden;
+}
+"""
+
+if not os.path.exists("./experiments_t2m/"):
+    os.system("bash prepare/download_pretrained_models.sh")
+if not os.path.exists('./deps/glove/'):
+    os.system("bash prepare/download_glove.sh")
+if not os.path.exists('./deps/sentence-t5-large/'):
+    os.system("bash prepare/prepare_t5.sh")
+if not os.path.exists('./deps/t2m/'):
+    os.system("bash prepare/download_t2m_evaluators.sh")
+if not os.path.exists('./datasets/humanml3d/'):
+    os.system("bash prepare/prepare_tiny_humanml3d.sh")
+
+DEFAULT_TEXT = "A person is "
+MAX_VIDEOS = 12
+T2M_CFG = "./configs/motionlcm_t2m.yaml"
+
+device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+cfg = OmegaConf.load(T2M_CFG)
+cfg_model = get_module_config(cfg.model, cfg.model.target)
+cfg = OmegaConf.merge(cfg, cfg_model)
+set_seed(1949)
+
+name_time_str = osp.join(cfg.NAME, datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
+output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
+vis_dir = osp.join(output_dir, 'samples')
+os.makedirs(output_dir, exist_ok=False)
+os.makedirs(vis_dir, exist_ok=False)
+
+state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
+print("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
+
+lcm_key = 'denoiser.time_embedding.cond_proj.weight'
+is_lcm = False
+if lcm_key in state_dict:
+    is_lcm = True
+    time_cond_proj_dim = state_dict[lcm_key].shape[1]
+    cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
+print(f'Is LCM: {is_lcm}')
+
+cfg.model.is_controlnet = False
+
+datasets = get_datasets(cfg, phase="test")[0]
+model = MLD(cfg, datasets)
+model.to(device)
+model.eval()
+model.load_state_dict(state_dict)
+
+
+@torch.no_grad()
+def generate(text, motion_len, num_videos):
+    batch = {"text": [text] * num_videos, "length": [motion_len] * num_videos}
+
+    s = time.time()
+    joints, _ = model(batch)
+    runtime = round(time.time() - s, 3)
+    runtime_info = f'Inference {len(joints)} motions, runtime: {runtime}s, device: {device}'
+    path = []
+    for i in range(num_videos):
+        uid = random.randrange(999999999)
+        video_path = osp.join(vis_dir, f"sample_{uid}.mp4")
+        plot_3d_motion(video_path, joints[i].detach().cpu().numpy(), '', fps=20)
+        path.append(video_path)
+    return path, runtime_info
+
+
+# HTML component
+def get_video_html(path, video_id, width=700, height=700):
+    video_html = f"""
+<video class="contour_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
+autoplay loop disablepictureinpicture id="{video_id}">
+  <source src="file/{path}" type="video/mp4">
+  Your browser does not support the video tag.
+</video>
+"""
+    return video_html
+
+
+def generate_component(generate_function, text, motion_len, num_videos):
+    if text == DEFAULT_TEXT or text == "" or text is None:
+        return [None for _ in range(MAX_VIDEOS)] + [None]
+
+    motion_len = max(36, min(int(float(motion_len) * 20), 196))
+    paths, info = generate_function(text, motion_len, num_videos)
+    htmls = [get_video_html(path, idx) for idx, path in enumerate(paths)]
+    htmls = htmls + [None for _ in range(max(0, MAX_VIDEOS - num_videos))]
+    return htmls + [info]
+
+
+theme = gr.themes.Default(primary_hue="purple", secondary_hue="gray")
+generate_and_show = partial(generate_component, generate)
+
+with gr.Blocks(css=CSS, theme=theme) as demo:
+    gr.HTML(WEBSITE)
+    videos = []
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            text = gr.Textbox(
+                show_label=True,
+                label="Text prompt",
+                value=DEFAULT_TEXT,
+            )
+
+            with gr.Row():
+                with gr.Column(scale=1):
+                    motion_len = gr.Textbox(
+                        show_label=True,
+                        label="Motion length (in seconds, <=9.8s)",
+                        value=5,
+                        info="Any length exceeding 9.8s will be restricted to 9.8s.",
+                    )
+                with gr.Column(scale=1):
+                    num_videos = gr.Radio(
+                        [1, 4, 8, 12],
+                        label="Videos",
+                        value=8,
+                        info="Number of videos to generate.",
+                    )
+
+            gen_btn = gr.Button("Generate", variant="primary")
+            clear = gr.Button("Clear", variant="secondary")
+
+            results = gr.Textbox(show_label=True,
+                                 label='Inference info (runtime and device)',
+                                 info='Real-time inference cannot be achieved using the free CPU. Local GPU deployment is recommended.',
+                                 interactive=False)
+
+        with gr.Column(scale=2):
+            def generate_example(text, motion_len, num_videos):
+                return generate_and_show(text, motion_len, num_videos)
+
+            examples = gr.Examples(
+                examples=[[x, None, None] for x in EXAMPLES],
+                inputs=[text, motion_len, num_videos],
+                examples_per_page=12,
+                run_on_click=False,
+                cache_examples=False,
+                fn=generate_example,
+                outputs=[],
+            )
+
+    for _ in range(3):
+        with gr.Row():
+            for _ in range(4):
+                video = gr.HTML()
+                videos.append(video)
+
+    # gr.HTML(WEBSITE_bottom)
+    # connect the examples to the output
+    # a bit hacky
+    examples.outputs = videos
+
+    def load_example(example_id):
+        processed_example = examples.non_none_processed_examples[example_id]
+        return gr.utils.resolve_singleton(processed_example)
+
+    examples.dataset.click(
+        load_example,
+        inputs=[examples.dataset],
+        outputs=examples.inputs_with_examples,  # type: ignore
+        show_progress=False,
+        postprocess=False,
+        queue=False,
+    ).then(fn=generate_example, inputs=examples.inputs, outputs=videos + [results])
+
+    gen_btn.click(
+        fn=generate_and_show,
+        inputs=[text, motion_len, num_videos],
+        outputs=videos + [results],
+    )
+    text.submit(
+        fn=generate_and_show,
+        inputs=[text, motion_len, num_videos],
+        outputs=videos + [results],
+    )
+
+    def clear_videos():
+        return [None for _ in range(MAX_VIDEOS)] + [DEFAULT_TEXT] + [None]
+
+    clear.click(fn=clear_videos, outputs=videos + [text] + [results])
+
+demo.launch()
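A note on the length handling in generate_component above: the textbox value is interpreted in seconds, converted at 20 fps, and clamped to the 36-196 frame window the model supports (1.8 s to 9.8 s), which is why the UI caps requests at 9.8 s. A minimal sketch of that arithmetic (the helper name is ours, not part of app.py):

def seconds_to_frames(seconds: float) -> int:
    # mirrors app.py: 20 fps, clamped to the supported 36-196 frame window
    return max(36, min(int(float(seconds) * 20), 196))

assert seconds_to_frames(5) == 100   # the default 5 s prompt -> 100 frames
assert seconds_to_frames(12) == 196  # anything above 9.8 s is clipped to 196 frames
assert seconds_to_frames(1) == 36    # very short requests are raised to the minimum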
configs/mld_control.yaml
ADDED
@@ -0,0 +1,105 @@
+FOLDER: './experiments_control'
+TEST_FOLDER: './experiments_control_test'
+
+NAME: 'mld_humanml'
+
+TRAIN:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 128
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+  SEED_VALUE: 1234
+  PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 1000
+  learning_rate: 1e-4
+  learning_rate_spatial: 1e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+EVAL:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+
+TEST:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 1
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+
+  CHECKPOINTS: 'experiments_control/mld_humanml/mld_humanml.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 1
+  MM_NUM_SAMPLES: 100
+  MM_NUM_REPEATS: 30
+  MM_NUM_TIMES: 10
+  DIVERSITY_TIMES: 300
+  MAX_NUM_SAMPLES: 1024
+
+DATASET:
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    PICK_ONE_TEXT: true
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    SPLIT_ROOT: './datasets/humanml3d'
+  SAMPLER:
+    MAX_LEN: 196
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics', 'ControlMetrics']
+
+model:
+  target: 'modules_mld'
+  latent_dim: [1, 256]
+  guidance_scale: 7.5
+  guidance_uncondp: 0.1
+
+  # ControlNet Args
+  is_controlnet: true
+  is_controlnet_temporal: false
+  training_control_joint: [0]
+  testing_control_joint: [0]
+  training_density: 'random'
+  testing_density: 100
+  control_scale: 1.0
+  vaeloss: true
+  vaeloss_type: 'sum'
+  cond_ratio: 1.0
+  rot_ratio: 0.0
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
configs/mld_t2m_infer.yaml
ADDED
@@ -0,0 +1,72 @@
+FOLDER: './experiments_t2m'
+TEST_FOLDER: './experiments_t2m_test'
+
+NAME: 'mld_humanml'
+
+TRAIN:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 1
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+  SEED_VALUE: 1234
+
+EVAL:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 32
+  SPLIT: test
+  NUM_WORKERS: 12
+
+TEST:
+  DATASETS: ['humanml3d']
+  SPLIT: test
+  BATCH_SIZE: 1
+  NUM_WORKERS: 12
+
+  CHECKPOINTS: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 20
+  MM_NUM_SAMPLES: 100
+  MM_NUM_REPEATS: 30
+  MM_NUM_TIMES: 10
+  DIVERSITY_TIMES: 300
+
+DATASET:
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    PICK_ONE_TEXT: true
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    SPLIT_ROOT: './datasets/humanml3d'
+  SAMPLER:
+    MAX_LEN: 196
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+
+METRIC:
+  DIST_SYNC_ON_STEP: True
+  TYPE: ['TM2TMetrics']
+
+model:
+  target: 'modules_mld'
+  latent_dim: [1, 256]
+  guidance_scale: 7.5
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
configs/modules/denoiser.yaml
ADDED
@@ -0,0 +1,16 @@
+denoiser:
+  target: mld.models.architectures.mld_denoiser.MldDenoiser
+  params:
+    text_encoded_dim: 768
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    activation: 'gelu'
+    flip_sin_to_cos: true
+    return_intermediate_dec: false
+    position_embedding: 'learned'
+    arch: 'trans_enc'
+    freq_shift: 0
+    latent_dim: ${model.latent_dim}
configs/modules/motion_vae.yaml
ADDED
@@ -0,0 +1,13 @@
+motion_vae:
+  target: mld.models.architectures.mld_vae.MldVae
+  params:
+    arch: 'encoder_decoder'
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    activation: 'gelu'
+    position_embedding: 'learned'
+    latent_dim: ${model.latent_dim}
+    nfeats: ${DATASET.NFEATS}
configs/modules/scheduler.yaml
ADDED
@@ -0,0 +1,20 @@
+scheduler:
+  target: diffusers.LCMScheduler
+  num_inference_timesteps: 1
+  params:
+    num_train_timesteps: 1000
+    beta_start: 0.00085
+    beta_end: 0.012
+    beta_schedule: 'scaled_linear'
+    clip_sample: false
+    set_alpha_to_one: false
+
+noise_scheduler:
+  target: diffusers.DDPMScheduler
+  params:
+    num_train_timesteps: 1000
+    beta_start: 0.00085
+    beta_end: 0.012
+    beta_schedule: 'scaled_linear'
+    variance_type: 'fixed_small'
+    clip_sample: false
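Since each module config names a `target` class and a `params` block (see mld/config.py later in this diff), the `params` above presumably map directly onto the constructor of diffusers' LCMScheduler, while `num_inference_timesteps` is read separately by the model code. A minimal sketch of building the same scheduler by hand, assuming a diffusers version that ships LCMScheduler:

from diffusers import LCMScheduler

# Equivalent to instantiating the 'scheduler' entry above with its params block.
scheduler = LCMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)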
configs/modules/text_encoder.yaml
ADDED
@@ -0,0 +1,5 @@
+text_encoder:
+  target: mld.models.architectures.mld_clip.MldTextEncoder
+  params:
+    last_hidden_state: false  # if true, the last hidden state is used as the text embedding
+    modelpath: ${model.t5_path}
configs/modules/traj_encoder.yaml
ADDED
@@ -0,0 +1,12 @@
+traj_encoder:
+  target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
+  params:
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    activation: 'gelu'
+    position_embedding: 'learned'
+    latent_dim: ${model.latent_dim}
+    nfeats: ${DATASET.NJOINTS}
configs/modules_mld/denoiser.yaml
ADDED
@@ -0,0 +1,16 @@
+denoiser:
+  target: mld.models.architectures.mld_denoiser.MldDenoiser
+  params:
+    text_encoded_dim: 768
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    activation: 'gelu'
+    flip_sin_to_cos: true
+    return_intermediate_dec: false
+    position_embedding: 'learned'
+    arch: 'trans_enc'
+    freq_shift: 0
+    latent_dim: ${model.latent_dim}
configs/modules_mld/motion_vae.yaml
ADDED
@@ -0,0 +1,13 @@
+motion_vae:
+  target: mld.models.architectures.mld_vae.MldVae
+  params:
+    arch: 'encoder_decoder'
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    activation: 'gelu'
+    position_embedding: 'learned'
+    latent_dim: ${model.latent_dim}
+    nfeats: ${DATASET.NFEATS}
configs/modules_mld/scheduler.yaml
ADDED
@@ -0,0 +1,23 @@
+scheduler:
+  target: diffusers.DDIMScheduler
+  num_inference_timesteps: 50
+  eta: 0.0
+  params:
+    num_train_timesteps: 1000
+    beta_start: 0.00085
+    beta_end: 0.012
+    beta_schedule: 'scaled_linear'
+    clip_sample: false
+    # below are for ddim
+    set_alpha_to_one: false
+    steps_offset: 1
+
+noise_scheduler:
+  target: diffusers.DDPMScheduler
+  params:
+    num_train_timesteps: 1000
+    beta_start: 0.00085
+    beta_end: 0.012
+    beta_schedule: 'scaled_linear'
+    variance_type: 'fixed_small'
+    clip_sample: false
configs/modules_mld/text_encoder.yaml
ADDED
@@ -0,0 +1,5 @@
+text_encoder:
+  target: mld.models.architectures.mld_clip.MldTextEncoder
+  params:
+    last_hidden_state: false  # if true, the last hidden state is used as the text embedding
+    modelpath: ${model.t5_path}
configs/modules_mld/traj_encoder.yaml
ADDED
@@ -0,0 +1,12 @@
+traj_encoder:
+  target: mld.models.architectures.mld_traj_encoder.MldTrajEncoder
+  params:
+    ff_size: 1024
+    num_layers: 9
+    num_heads: 4
+    dropout: 0.1
+    normalize_before: false
+    activation: 'gelu'
+    position_embedding: 'learned'
+    latent_dim: ${model.latent_dim}
+    nfeats: ${DATASET.NJOINTS}
configs/motionlcm_control.yaml
ADDED
@@ -0,0 +1,105 @@
+FOLDER: './experiments_control'
+TEST_FOLDER: './experiments_control_test'
+
+NAME: 'motionlcm_humanml'
+
+TRAIN:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 128
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+  SEED_VALUE: 1234
+  PRETRAINED: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 1000
+  learning_rate: 1e-4
+  learning_rate_spatial: 1e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+EVAL:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+
+TEST:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 1
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+
+  CHECKPOINTS: 'experiments_control/motionlcm_humanml/motionlcm_humanml.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 1
+  MM_NUM_SAMPLES: 100
+  MM_NUM_REPEATS: 30
+  MM_NUM_TIMES: 10
+  DIVERSITY_TIMES: 300
+  MAX_NUM_SAMPLES: 1024
+
+DATASET:
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    PICK_ONE_TEXT: true
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    SPLIT_ROOT: './datasets/humanml3d'
+  SAMPLER:
+    MAX_LEN: 196
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics', 'ControlMetrics']
+
+model:
+  target: 'modules'
+  latent_dim: [1, 256]
+  guidance_scale: 7.5
+  guidance_uncondp: 0.0
+
+  # ControlNet Args
+  is_controlnet: true
+  is_controlnet_temporal: false
+  training_control_joint: [0]
+  testing_control_joint: [0]
+  training_density: 'random'
+  testing_density: 100
+  control_scale: 1.0
+  vaeloss: true
+  vaeloss_type: 'sum'
+  cond_ratio: 1.0
+  rot_ratio: 0.0
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
configs/motionlcm_t2m.yaml
ADDED
@@ -0,0 +1,100 @@
+FOLDER: './experiments_t2m'
+TEST_FOLDER: './experiments_t2m_test'
+
+NAME: 'motionlcm_humanml'
+
+TRAIN:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 256
+  SPLIT: 'train'
+  NUM_WORKERS: 8
+  PERSISTENT_WORKERS: true
+  SEED_VALUE: 1234
+  PRETRAINED: 'experiments_t2m/mld_humanml/mld_humanml.ckpt'
+
+  validation_steps: -1
+  validation_epochs: 50
+  checkpointing_steps: -1
+  checkpointing_epochs: 50
+  max_train_steps: -1
+  max_train_epochs: 1000
+  learning_rate: 2e-4
+  lr_scheduler: "cosine"
+  lr_warmup_steps: 1000
+  adam_beta1: 0.9
+  adam_beta2: 0.999
+  adam_weight_decay: 0.0
+  adam_epsilon: 1e-08
+  max_grad_norm: 1.0
+
+  # Latent Consistency Distillation Specific Arguments
+  w_min: 5.0
+  w_max: 15.0
+  num_ddim_timesteps: 50
+  loss_type: 'huber'
+  huber_c: 0.001
+  unet_time_cond_proj_dim: 256
+  ema_decay: 0.95
+
+EVAL:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 32
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+
+TEST:
+  DATASETS: ['humanml3d']
+  BATCH_SIZE: 1
+  SPLIT: 'test'
+  NUM_WORKERS: 12
+
+  CHECKPOINTS: 'experiments_t2m/motionlcm_humanml/motionlcm_humanml.ckpt'
+
+  # Testing Args
+  REPLICATION_TIMES: 20
+  MM_NUM_SAMPLES: 100
+  MM_NUM_REPEATS: 30
+  MM_NUM_TIMES: 10
+  DIVERSITY_TIMES: 300
+
+DATASET:
+  SMPL_PATH: './deps/smpl'
+  WORD_VERTILIZER_PATH: './deps/glove/'
+  HUMANML3D:
+    PICK_ONE_TEXT: true
+    FRAME_RATE: 20.0
+    UNIT_LEN: 4
+    ROOT: './datasets/humanml3d'
+    SPLIT_ROOT: './datasets/humanml3d'
+  SAMPLER:
+    MAX_LEN: 196
+    MIN_LEN: 40
+    MAX_TEXT_LEN: 20
+
+METRIC:
+  DIST_SYNC_ON_STEP: true
+  TYPE: ['TM2TMetrics']
+
+model:
+  target: 'modules'
+  latent_dim: [1, 256]
+  guidance_scale: 7.5
+  guidance_uncondp: 0.0
+  is_controlnet: false
+
+  t2m_textencoder:
+    dim_word: 300
+    dim_pos_ohot: 15
+    dim_text_hidden: 512
+    dim_coemb_hidden: 512
+
+  t2m_motionencoder:
+    dim_move_hidden: 512
+    dim_move_latent: 512
+    dim_motion_hidden: 1024
+    dim_motion_latent: 512
+
+  bert_path: './deps/distilbert-base-uncased'
+  clip_path: './deps/clip-vit-large-patch14'
+  t5_path: './deps/sentence-t5-large'
+  t2m_path: './deps/t2m/'
demo.py
ADDED
@@ -0,0 +1,154 @@
+import os
+import pickle
+import sys
+import datetime
+import logging
+import os.path as osp
+
+from omegaconf import OmegaConf
+
+import torch
+
+from mld.config import parse_args
+from mld.data.get_data import get_datasets
+from mld.models.modeltype.mld import MLD
+from mld.utils.utils import set_seed, move_batch_to_device
+from mld.data.humanml.utils.plot_script import plot_3d_motion
+from mld.utils.temos_utils import remove_padding
+
+
+def load_example_input(text_path: str) -> tuple:
+    with open(text_path, "r") as f:
+        lines = f.readlines()
+
+    count = 0
+    texts, lens = [], []
+    # Strips the newline character
+    for line in lines:
+        count += 1
+        s = line.strip()
+        s_l = s.split(" ")[0]
+        s_t = s[(len(s_l) + 1):]
+        lens.append(int(s_l))
+        texts.append(s_t)
+    return texts, lens
+
+
+def main():
+    cfg = parse_args()
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    set_seed(cfg.TRAIN.SEED_VALUE)
+
+    name_time_str = osp.join(cfg.NAME, "demo_" + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
+    output_dir = osp.join(cfg.TEST_FOLDER, name_time_str)
+    vis_dir = osp.join(output_dir, 'samples')
+    os.makedirs(output_dir, exist_ok=False)
+    os.makedirs(vis_dir, exist_ok=False)
+
+    steam_handler = logging.StreamHandler(sys.stdout)
+    file_handler = logging.FileHandler(osp.join(output_dir, 'output.log'))
+    logging.basicConfig(level=logging.INFO,
+                        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+                        datefmt="%m/%d/%Y %H:%M:%S",
+                        handlers=[steam_handler, file_handler])
+    logger = logging.getLogger(__name__)
+
+    OmegaConf.save(cfg, osp.join(output_dir, 'config.yaml'))
+
+    state_dict = torch.load(cfg.TEST.CHECKPOINTS, map_location="cpu")["state_dict"]
+    logger.info("Loading checkpoints from {}".format(cfg.TEST.CHECKPOINTS))
+
+    lcm_key = 'denoiser.time_embedding.cond_proj.weight'
+    is_lcm = False
+    if lcm_key in state_dict:
+        is_lcm = True
+        time_cond_proj_dim = state_dict[lcm_key].shape[1]
+        cfg.model.denoiser.params.time_cond_proj_dim = time_cond_proj_dim
+    logger.info(f'Is LCM: {is_lcm}')
+
+    cn_key = "controlnet.controlnet_cond_embedding.0.weight"
+    is_controlnet = True if cn_key in state_dict else False
+    cfg.model.is_controlnet = is_controlnet
+    logger.info(f'Is Controlnet: {is_controlnet}')
+
+    datasets = get_datasets(cfg, phase="test")[0]
+    model = MLD(cfg, datasets)
+    model.to(device)
+    model.eval()
+    model.load_state_dict(state_dict)
+
+    # example only support text-to-motion
+    if cfg.example is not None and not is_controlnet:
+        text, length = load_example_input(cfg.example)
+        for t, l in zip(text, length):
+            logger.info(f"{l}: {t}")
+
+        batch = {"length": length, "text": text}
+
+        for rep_i in range(cfg.replication):
+            with torch.no_grad():
+                joints, _ = model(batch)
+
+            num_samples = len(joints)
+            batch_id = 0
+            for i in range(num_samples):
+                res = dict()
+                pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
+                res['joints'] = joints[i].detach().cpu().numpy()
+                res['text'] = text[i]
+                res['length'] = length[i]
+                res['hint'] = None
+                with open(pkl_path, 'wb') as f:
+                    pickle.dump(res, f)
+                logger.info(f"Motions are generated here:\n{pkl_path}")
+
+                if not cfg.no_plot:
+                    plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(), text[i], fps=20)
+
+    else:
+        test_dataloader = datasets.test_dataloader()
+        for rep_i in range(cfg.replication):
+            for batch_id, batch in enumerate(test_dataloader):
+                batch = move_batch_to_device(batch, device)
+                with torch.no_grad():
+                    joints, joints_ref = model(batch)
+
+                num_samples = len(joints)
+                text = batch['text']
+                length = batch['length']
+                if 'hint' in batch:
+                    hint = batch['hint']
+                    mask_hint = hint.view(hint.shape[0], hint.shape[1], model.njoints, 3).sum(dim=-1, keepdim=True) != 0
+                    hint = model.datamodule.denorm_spatial(hint)
+                    hint = hint.view(hint.shape[0], hint.shape[1], model.njoints, 3) * mask_hint
+                    hint = remove_padding(hint, lengths=length)
+                else:
+                    hint = None
+
+                for i in range(num_samples):
+                    res = dict()
+                    pkl_path = osp.join(vis_dir, f"batch_id_{batch_id}_sample_id_{i}_length_{length[i]}_rep_{rep_i}.pkl")
+                    res['joints'] = joints[i].detach().cpu().numpy()
+                    res['text'] = text[i]
+                    res['length'] = length[i]
+                    res['hint'] = hint[i].detach().cpu().numpy() if hint is not None else None
+                    with open(pkl_path, 'wb') as f:
+                        pickle.dump(res, f)
+                    logger.info(f"Motions are generated here:\n{pkl_path}")
+
+                    if not cfg.no_plot:
+                        plot_3d_motion(pkl_path.replace('.pkl', '.mp4'), joints[i].detach().cpu().numpy(),
+                                       text[i], fps=20, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
+
+                    if rep_i == 0:
+                        res['joints'] = joints_ref[i].detach().cpu().numpy()
+                        with open(pkl_path.replace('.pkl', '_ref.pkl'), 'wb') as f:
+                            pickle.dump(res, f)
+                        logger.info(f"Motions are generated here:\n{pkl_path.replace('.pkl', '_ref.pkl')}")
+                        if not cfg.no_plot:
+                            plot_3d_motion(pkl_path.replace('.pkl', '_ref.mp4'), joints_ref[i].detach().cpu().numpy(),
+                                           text[i], fps=20, hint=hint[i].detach().cpu().numpy() if hint is not None else None)
+
+
+if __name__ == "__main__":
+    main()
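For reference, load_example_input above expects a plain text file in which every line starts with an integer frame count followed by a space and the prompt. An illustrative file (contents are only an example, not shipped with the repo) would look like:

196 a person walks forward and then turns around.
120 a person jumps in place.

It is passed via the --example flag defined in mld/config.py, e.g. python demo.py --cfg configs/motionlcm_t2m.yaml --example <your_prompts.txt>; without --example, the script falls back to sampling from the test dataloader.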
fit.py
ADDED
@@ -0,0 +1,134 @@
+# borrow from optimization https://github.com/wangsen1312/joints2smpl
+import os
+import argparse
+import pickle
+
+import h5py
+import natsort
+import smplx
+
+import torch
+
+from mld.transforms.joints2rots import config
+from mld.transforms.joints2rots.smplify import SMPLify3D
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--pkl", type=str, default=None, help="pkl motion file")
+parser.add_argument("--dir", type=str, default=None, help="pkl motion folder")
+parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters")
+parser.add_argument("--cuda", type=bool, default=True, help="enables cuda")
+parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids")
+parser.add_argument("--num_joints", type=int, default=22, help="joint number")
+parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence")
+parser.add_argument("--fix_foot", type=str, default="False", help="fix foot or not")
+opt = parser.parse_args()
+print(opt)
+
+if opt.pkl:
+    paths = [opt.pkl]
+elif opt.dir:
+    paths = []
+    file_list = natsort.natsorted(os.listdir(opt.dir))
+    for item in file_list:
+        if item.endswith('.pkl') and not item.endswith("_mesh.pkl"):
+            paths.append(os.path.join(opt.dir, item))
+else:
+    raise ValueError(f'{opt.pkl} and {opt.dir} are both None!')
+
+for path in paths:
+    # load joints
+    if os.path.exists(path.replace('.pkl', '_mesh.pkl')):
+        print(f"{path} is rendered! skip!")
+        continue
+
+    with open(path, 'rb') as f:
+        data = pickle.load(f)
+
+    joints = data['joints']
+    # load predefined something
+    device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu")
+    print(config.SMPL_MODEL_DIR)
+    smplxmodel = smplx.create(
+        config.SMPL_MODEL_DIR,
+        model_type="smpl",
+        gender="neutral",
+        ext="pkl",
+        batch_size=joints.shape[0],
+    ).to(device)
+
+    # load the mean pose as original
+    smpl_mean_file = config.SMPL_MEAN_FILE
+
+    file = h5py.File(smpl_mean_file, "r")
+    init_mean_pose = (
+        torch.from_numpy(file["pose"][:])
+        .unsqueeze(0).repeat(joints.shape[0], 1)
+        .float()
+        .to(device)
+    )
+    init_mean_shape = (
+        torch.from_numpy(file["shape"][:])
+        .unsqueeze(0).repeat(joints.shape[0], 1)
+        .float()
+        .to(device)
+    )
+    cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)
+
+    # initialize SMPLify
+    smplify = SMPLify3D(
+        smplxmodel=smplxmodel,
+        batch_size=joints.shape[0],
+        joints_category=opt.joint_category,
+        num_iters=opt.num_smplify_iters,
+        device=device,
+    )
+    print("initialize SMPLify3D done!")
+
+    print("Start SMPLify!")
+    keypoints_3d = torch.Tensor(joints).to(device).float()
+
+    if opt.joint_category == "AMASS":
+        confidence_input = torch.ones(opt.num_joints)
+        # make sure the foot and ankle
+        if opt.fix_foot:
+            confidence_input[7] = 1.5
+            confidence_input[8] = 1.5
+            confidence_input[10] = 1.5
+            confidence_input[11] = 1.5
+    else:
+        print("Such category not settle down!")
+
+    # ----- from initial to fitting -------
+    (
+        new_opt_vertices,
+        new_opt_joints,
+        new_opt_pose,
+        new_opt_betas,
+        new_opt_cam_t,
+        new_opt_joint_loss,
+    ) = smplify(
+        init_mean_pose.detach(),
+        init_mean_shape.detach(),
+        cam_trans_zero.detach(),
+        keypoints_3d,
+        conf_3d=confidence_input.to(device)
+    )
+
+    # fix shape
+    betas = torch.zeros_like(new_opt_betas)
+    root = keypoints_3d[:, 0, :]
+
+    output = smplxmodel(
+        betas=betas,
+        global_orient=new_opt_pose[:, :3],
+        body_pose=new_opt_pose[:, 3:],
+        transl=root,
+        return_verts=True,
+    )
+    vertices = output.vertices.detach().cpu().numpy()
+    data['vertices'] = vertices
+
+    save_file = path.replace('.pkl', '_mesh.pkl')
+    with open(save_file, 'wb') as f:
+        pickle.dump(data, f)
+    print(f'vertices saved in {save_file}')
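A short sketch of consuming fit.py's output, with a hypothetical file name: each *_mesh.pkl keeps the fields written by demo.py and adds a 'vertices' array holding one fitted SMPL mesh per frame (6,890 vertices per frame for the standard SMPL body, as an assumption about the body model used above).

import pickle

with open("sample_mesh.pkl", "rb") as f:   # hypothetical file produced by fit.py
    data = pickle.load(f)

print(data["text"], data["length"])
print(data["vertices"].shape)              # expected (num_frames, 6890, 3) for SMPL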
mld/__init__.py
ADDED
File without changes
mld/config.py
ADDED
@@ -0,0 +1,47 @@
+import os
+import importlib
+from typing import Type, TypeVar
+from argparse import ArgumentParser
+
+from omegaconf import OmegaConf, DictConfig
+
+
+def get_module_config(cfg_model: DictConfig, path: str = "modules") -> DictConfig:
+    files = os.listdir(f'./configs/{path}/')
+    for file in files:
+        if file.endswith('.yaml'):
+            with open(f'./configs/{path}/' + file, 'r') as f:
+                cfg_model.merge_with(OmegaConf.load(f))
+    return cfg_model
+
+
+def get_obj_from_str(string: str, reload: bool = False) -> Type:
+    module, cls = string.rsplit(".", 1)
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config: DictConfig) -> TypeVar:
+    return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def parse_args() -> DictConfig:
+    parser = ArgumentParser()
+    parser.add_argument("--cfg", type=str, required=True, help="config file")
+
+    # Demo Args
+    parser.add_argument('--example', type=str, required=False, help="input text and lengths with txt format")
+    parser.add_argument('--no-plot', action="store_true", required=False, help="whether plot the skeleton-based motion")
+    parser.add_argument('--replication', type=int, default=1, help="the number of replication of sampling")
+    args = parser.parse_args()
+
+    cfg = OmegaConf.load(args.cfg)
+    cfg_model = get_module_config(cfg.model, cfg.model.target)
+    cfg = OmegaConf.merge(cfg, cfg_model)
+
+    cfg.example = args.example
+    cfg.no_plot = args.no_plot
+    cfg.replication = args.replication
+    return cfg
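A minimal sketch of how the pieces above fit together (run from the repository root): parse_args loads the top-level YAML, get_module_config merges every YAML under configs/<model.target>/ into cfg.model, and instantiate_from_config then imports the class named by 'target' and calls it with 'params' as keyword arguments. This mirrors what app.py does before building the model.

from omegaconf import OmegaConf
from mld.config import get_module_config

cfg = OmegaConf.load("configs/motionlcm_t2m.yaml")
cfg_model = get_module_config(cfg.model, cfg.model.target)  # merges configs/modules/*.yaml
cfg = OmegaConf.merge(cfg, cfg_model)

print(cfg.model.denoiser.target)                    # mld.models.architectures.mld_denoiser.MldDenoiser
print(cfg.model.scheduler.num_inference_timesteps)  # 1: the LCM scheduler runs a single step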
mld/data/HumanML3D.py
ADDED
@@ -0,0 +1,79 @@
+import copy
+from typing import Callable, Optional
+
+import numpy as np
+from omegaconf import DictConfig
+
+import torch
+
+from .base import BASEDataModule
+from .humanml.dataset import Text2MotionDatasetV2
+from .humanml.scripts.motion_process import recover_from_ric
+
+
+class HumanML3DDataModule(BASEDataModule):
+
+    def __init__(self,
+                 cfg: DictConfig,
+                 batch_size: int,
+                 num_workers: int,
+                 collate_fn: Optional[Callable] = None,
+                 persistent_workers: bool = True,
+                 phase: str = "train",
+                 **kwargs) -> None:
+        super().__init__(batch_size=batch_size,
+                         num_workers=num_workers,
+                         collate_fn=collate_fn,
+                         persistent_workers=persistent_workers)
+        self.hparams = copy.deepcopy(kwargs)
+        self.name = "humanml3d"
+        self.njoints = 22
+        if phase == "text_only":
+            raise NotImplementedError
+        else:
+            self.Dataset = Text2MotionDatasetV2
+        self.cfg = cfg
+
+        sample_overrides = {"tiny": True, "progress_bar": False}
+        self._sample_set = self.get_sample_set(overrides=sample_overrides)
+        self.nfeats = self._sample_set.nfeats
+
+    def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
+        raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
+        raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
+        hint = hint * raw_std + raw_mean
+        return hint
+
+    def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
+        raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
+        raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
+        hint = (hint - raw_mean) / raw_std
+        return hint
+
+    def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
+        mean = torch.tensor(self.hparams['mean']).to(features)
+        std = torch.tensor(self.hparams['std']).to(features)
+        features = features * std + mean
+        return recover_from_ric(features, self.njoints)
+
+    def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
+        # renorm to t2m norms for using t2m evaluators
+        ori_mean = torch.tensor(self.hparams['mean']).to(features)
+        ori_std = torch.tensor(self.hparams['std']).to(features)
+        eval_mean = torch.tensor(self.hparams['mean_eval']).to(features)
+        eval_std = torch.tensor(self.hparams['std_eval']).to(features)
+        features = features * ori_std + ori_mean
+        features = (features - eval_mean) / eval_std
+        return features
+
+    def mm_mode(self, mm_on: bool = True) -> None:
+        if mm_on:
+            self.is_mm = True
+            self.name_list = self.test_dataset.name_list
+            self.mm_list = np.random.choice(self.name_list,
+                                            self.cfg.TEST.MM_NUM_SAMPLES,
+                                            replace=False)
+            self.test_dataset.name_list = self.mm_list
+        else:
+            self.is_mm = False
+            self.test_dataset.name_list = self.name_list
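A small usage sketch, assuming dm is an already constructed HumanML3DDataModule, features is a normalized HumanML3D feature batch of shape (batch, frames, nfeats), and hint is a normalized spatial control hint: feats2joints de-normalizes with the stored mean/std and recovers 3D joint positions via recover_from_ric, while norm_spatial and denorm_spatial are exact inverses used for control hints.

joints = dm.feats2joints(features)                    # -> (batch, frames, 22, 3) xyz joints
hint_back = dm.denorm_spatial(dm.norm_spatial(hint))  # round-trips a control hint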
mld/data/Kit.py
ADDED
@@ -0,0 +1,79 @@
+import copy
+from typing import Callable, Optional
+
+import numpy as np
+from omegaconf import DictConfig
+
+import torch
+
+from .base import BASEDataModule
+from .humanml.dataset import Text2MotionDatasetV2
+from .humanml.scripts.motion_process import recover_from_ric
+
+
+class KitDataModule(BASEDataModule):
+
+    def __init__(self,
+                 cfg: DictConfig,
+                 batch_size: int,
+                 num_workers: int,
+                 collate_fn: Optional[Callable] = None,
+                 persistent_workers: bool = True,
+                 phase: str = "train",
+                 **kwargs) -> None:
+        super().__init__(batch_size=batch_size,
+                         num_workers=num_workers,
+                         collate_fn=collate_fn,
+                         persistent_workers=persistent_workers)
+        self.hparams = copy.deepcopy(kwargs)
+        self.name = 'kit'
+        self.njoints = 21
+        if phase == 'text_only':
+            raise NotImplementedError
+        else:
+            self.Dataset = Text2MotionDatasetV2
+        self.cfg = cfg
+
+        sample_overrides = {"tiny": True, "progress_bar": False}
+        self._sample_set = self.get_sample_set(overrides=sample_overrides)
+        self.nfeats = self._sample_set.nfeats
+
+    def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
+        raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
+        raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
+        hint = hint * raw_std + raw_mean
+        return hint
+
+    def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
+        raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
+        raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
+        hint = (hint - raw_mean) / raw_std
+        return hint
+
+    def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
+        mean = torch.tensor(self.hparams['mean']).to(features)
+        std = torch.tensor(self.hparams['std']).to(features)
+        features = features * std + mean
+        return recover_from_ric(features, self.njoints)
+
+    def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
+        # renorm to t2m norms for using t2m evaluators
+        ori_mean = torch.tensor(self.hparams['mean']).to(features)
+        ori_std = torch.tensor(self.hparams['std']).to(features)
+        eval_mean = torch.tensor(self.hparams['mean_eval']).to(features)
+        eval_std = torch.tensor(self.hparams['std_eval']).to(features)
+        features = features * ori_std + ori_mean
+        features = (features - eval_mean) / eval_std
+        return features
+
+    def mm_mode(self, mm_on: bool = True) -> None:
+        if mm_on:
+            self.is_mm = True
+            self.name_list = self.test_dataset.name_list
+            self.mm_list = np.random.choice(self.name_list,
+                                            self.cfg.TEST.MM_NUM_SAMPLES,
+                                            replace=False)
+            self.test_dataset.name_list = self.mm_list
+        else:
+            self.is_mm = False
+            self.test_dataset.name_list = self.name_list
mld/data/__init__.py
ADDED
File without changes
mld/data/base.py
ADDED
@@ -0,0 +1,65 @@
+import copy
+from os.path import join as pjoin
+from typing import Any, Callable
+
+from torch.utils.data import DataLoader
+
+from .humanml.dataset import Text2MotionDatasetV2
+
+
+class BASEDataModule:
+    def __init__(self, collate_fn: Callable, batch_size: int,
+                 num_workers: int, persistent_workers: bool) -> None:
+        super(BASEDataModule, self).__init__()
+        self.dataloader_options = {
+            "batch_size": batch_size,
+            "num_workers": num_workers,
+            "collate_fn": collate_fn,
+            "persistent_workers": persistent_workers
+        }
+        self.is_mm = False
+
+    def get_sample_set(self, overrides: dict) -> Text2MotionDatasetV2:
+        sample_params = copy.deepcopy(self.hparams)
+        sample_params.update(overrides)
+        split_file = pjoin(
+            eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"),
+            self.cfg.EVAL.SPLIT + ".txt",
+        )
+        return self.Dataset(split_file=split_file, **sample_params)
+
+    def __getattr__(self, item: str) -> Any:
+        if item.endswith("_dataset") and not item.startswith("_"):
+            subset = item[:-len("_dataset")]
+            item_c = "_" + item
+            if item_c not in self.__dict__:
+
+                subset = subset.upper() if subset != "val" else "EVAL"
+                split = eval(f"self.cfg.{subset}.SPLIT")
+                split_file = pjoin(
+                    eval(f"self.cfg.DATASET.{self.name.upper()}.SPLIT_ROOT"),
+                    eval(f"self.cfg.{subset}.SPLIT") + ".txt",
+                )
+                self.__dict__[item_c] = self.Dataset(split_file=split_file,
+                                                     split=split,
+                                                     **self.hparams)
+            return getattr(self, item_c)
+        classname = self.__class__.__name__
+        raise AttributeError(f"'{classname}' object has no attribute '{item}'")
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_options)
+
+    def val_dataloader(self) -> DataLoader:
+        dataloader_options = self.dataloader_options.copy()
+        dataloader_options["batch_size"] = self.cfg.EVAL.BATCH_SIZE
+        dataloader_options["num_workers"] = self.cfg.EVAL.NUM_WORKERS
+        dataloader_options["shuffle"] = False
+        return DataLoader(self.val_dataset, **dataloader_options)
+
+    def test_dataloader(self) -> DataLoader:
+        dataloader_options = self.dataloader_options.copy()
+        dataloader_options["batch_size"] = 1 if self.is_mm else self.cfg.TEST.BATCH_SIZE
+        dataloader_options["num_workers"] = self.cfg.TEST.NUM_WORKERS
+        dataloader_options["shuffle"] = False
+        return DataLoader(self.test_dataset, **dataloader_options)
mld/data/get_data.py
ADDED
@@ -0,0 +1,93 @@
from os.path import join as pjoin
from typing import Callable, Optional

import numpy as np

from omegaconf import DictConfig

from .humanml.utils.word_vectorizer import WordVectorizer
from .HumanML3D import HumanML3DDataModule
from .Kit import KitDataModule
from .base import BASEDataModule
from .utils import mld_collate


def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
    name = "t2m" if dataset_name == "humanml3d" else dataset_name
    assert name in ["t2m", "kit"]
    if phase in ["val"]:
        if name == 't2m':
            data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta")
        elif name == 'kit':
            data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta")
        else:
            raise ValueError("Only support t2m and kit")
        mean = np.load(pjoin(data_root, "mean.npy"))
        std = np.load(pjoin(data_root, "std.npy"))
    else:
        data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
        mean = np.load(pjoin(data_root, "Mean.npy"))
        std = np.load(pjoin(data_root, "Std.npy"))

    return mean, std


def get_WordVectorizer(cfg: DictConfig, phase: str, dataset_name: str) -> Optional[WordVectorizer]:
    if phase not in ["text_only"]:
        if dataset_name.lower() in ["humanml3d", "kit"]:
            return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")
        else:
            raise ValueError("Only support WordVectorizer for HumanML3D")
    else:
        return None


def get_collate_fn(name: str) -> Callable:
    if name.lower() in ["humanml3d", "kit"]:
        return mld_collate
    else:
        raise NotImplementedError


dataset_module_map = {"humanml3d": HumanML3DDataModule, "kit": KitDataModule}
motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"}


def get_datasets(cfg: DictConfig, phase: str = "train") -> list[BASEDataModule]:
    dataset_names = eval(f"cfg.{phase.upper()}.DATASETS")
    datasets = []
    for dataset_name in dataset_names:
        if dataset_name.lower() in ["humanml3d", "kit"]:
            data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
            mean, std = get_mean_std(phase, cfg, dataset_name)
            mean_eval, std_eval = get_mean_std("val", cfg, dataset_name)
            wordVectorizer = get_WordVectorizer(cfg, phase, dataset_name)
            collate_fn = get_collate_fn(dataset_name)
            dataset = dataset_module_map[dataset_name.lower()](
                cfg=cfg,
                batch_size=cfg.TRAIN.BATCH_SIZE,
                num_workers=cfg.TRAIN.NUM_WORKERS,
                collate_fn=collate_fn,
                persistent_workers=cfg.TRAIN.PERSISTENT_WORKERS,
                mean=mean,
                std=std,
                mean_eval=mean_eval,
                std_eval=std_eval,
                w_vectorizer=wordVectorizer,
                text_dir=pjoin(data_root, "texts"),
                motion_dir=pjoin(data_root, motion_subdir[dataset_name]),
                max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
                min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
                max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
                unit_length=eval(f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"),
                model_kwargs=cfg.model
            )
            datasets.append(dataset)

        elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]:
            raise NotImplementedError

    cfg.DATASET.NFEATS = datasets[0].nfeats
    cfg.DATASET.NJOINTS = datasets[0].njoints
    return datasets
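
A quick usage sketch, not part of this commit: the config loader used here (the repo's mld/config.py helper) and the presence of the prepared HumanML3D data and evaluator checkpoints on disk are assumptions for illustration.

# Hypothetical usage sketch: build the test datamodule from a config and pull one batch.
from mld.config import parse_args          # assumed helper from mld/config.py
from mld.data.get_data import get_datasets

cfg = parse_args()                          # e.g. python script.py --cfg configs/mld_t2m_infer.yaml
dataset = get_datasets(cfg, phase="test")[0]
test_loader = dataset.test_dataloader()     # batch size comes from cfg.TEST.BATCH_SIZE
batch = next(iter(test_loader))
print(batch["motion"].shape, batch["length"][:4])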
mld/data/humanml/__init__.py
ADDED
File without changes
mld/data/humanml/common/quaternion.py
ADDED
@@ -0,0 +1,29 @@
import torch


def qinv(q: torch.Tensor) -> torch.Tensor:
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    mask = torch.ones_like(q)
    mask[..., 1:] = -mask[..., 1:]
    return q * mask


def qrot(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.
    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    original_shape = list(v.shape)
    q = q.contiguous().view(-1, 4)
    v = v.contiguous().view(-1, 3)

    qvec = q[:, 1:]
    uv = torch.cross(qvec, v, dim=1)
    uuv = torch.cross(qvec, uv, dim=1)
    return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
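
A minimal sanity check (not in the commit) for the two helpers above: a 90-degree rotation about Y maps the x-axis to -z, and rotating by the inverse quaternion undoes it.

import math
import torch
from mld.data.humanml.common.quaternion import qinv, qrot

q = torch.tensor([[math.cos(math.pi / 4), 0.0, math.sin(math.pi / 4), 0.0]])  # (w, x, y, z): +90 deg about Y
v = torch.tensor([[1.0, 0.0, 0.0]])
rotated = qrot(q, v)               # approximately [0, 0, -1]
restored = qrot(qinv(q), rotated)  # back to approximately [1, 0, 0]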
mld/data/humanml/dataset.py
ADDED
@@ -0,0 +1,290 @@
import codecs as cs
import random
from os.path import join as pjoin

import numpy as np
from rich.progress import track

import torch
from torch.utils import data

from mld.data.humanml.scripts.motion_process import recover_from_ric
from .utils.word_vectorizer import WordVectorizer


class Text2MotionDatasetV2(data.Dataset):

    def __init__(
        self,
        mean: np.ndarray,
        std: np.ndarray,
        split_file: str,
        w_vectorizer: WordVectorizer,
        max_motion_length: int,
        min_motion_length: int,
        max_text_len: int,
        unit_length: int,
        motion_dir: str,
        text_dir: str,
        tiny: bool = False,
        progress_bar: bool = True,
        **kwargs,
    ) -> None:
        self.w_vectorizer = w_vectorizer
        self.max_motion_length = max_motion_length
        self.min_motion_length = min_motion_length
        self.max_text_len = max_text_len
        self.unit_length = unit_length

        data_dict = {}
        id_list = []
        with cs.open(split_file, "r") as f:
            for line in f.readlines():
                id_list.append(line.strip())
        self.id_list = id_list

        if tiny:
            progress_bar = False
            maxdata = 10
        else:
            maxdata = 1e10

        if progress_bar:
            enumerator = enumerate(
                track(
                    id_list,
                    f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
                ))
        else:
            enumerator = enumerate(id_list)
        count = 0
        bad_count = 0
        new_name_list = []
        length_list = []
        for i, name in enumerator:
            if count > maxdata:
                break
            try:
                motion = np.load(pjoin(motion_dir, name + ".npy"))
                if (len(motion)) < self.min_motion_length or (len(motion) >= 200):
                    bad_count += 1
                    continue
                text_data = []
                flag = False
                with cs.open(pjoin(text_dir, name + ".txt")) as f:
                    for line in f.readlines():
                        text_dict = {}
                        line_split = line.strip().split("#")
                        caption = line_split[0]
                        tokens = line_split[1].split(" ")
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag

                        text_dict["caption"] = caption
                        text_dict["tokens"] = tokens
                        if f_tag == 0.0 and to_tag == 0.0:
                            flag = True
                            text_data.append(text_dict)
                        else:
                            try:
                                n_motion = motion[int(f_tag * 20):int(to_tag * 20)]
                                if (len(n_motion)) < self.min_motion_length or (len(n_motion) >= 200):
                                    continue
                                new_name = (
                                    random.choice("ABCDEFGHIJKLMNOPQRSTUVW") +
                                    "_" + name)
                                while new_name in data_dict:
                                    new_name = (
                                        random.choice("ABCDEFGHIJKLMNOPQRSTUVW") +
                                        "_" + name)
                                data_dict[new_name] = {
                                    "motion": n_motion,
                                    "length": len(n_motion),
                                    "text": [text_dict],
                                }
                                new_name_list.append(new_name)
                                length_list.append(len(n_motion))
                            except:
                                print(line_split)
                                print(line_split[2], line_split[3], f_tag, to_tag, name)

                if flag:
                    data_dict[name] = {
                        "motion": motion,
                        "length": len(motion),
                        "text": text_data,
                    }
                    new_name_list.append(name)
                    length_list.append(len(motion))
                    count += 1
            except:
                pass

        name_list, length_list = zip(
            *sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        self.mean = mean
        self.std = std

        self.mode = None
        model_params = kwargs['model_kwargs']
        if 'is_controlnet' in model_params and model_params.is_controlnet is True:
            if 'test' in split_file or 'val' in split_file:
                self.mode = 'eval'
            else:
                self.mode = 'train'

            self.t_ctrl = model_params.is_controlnet_temporal
            spatial_norm_path = './datasets/humanml_spatial_norm'
            self.raw_mean = np.load(pjoin(spatial_norm_path, 'Mean_raw.npy'))
            self.raw_std = np.load(pjoin(spatial_norm_path, 'Std_raw.npy'))

            self.training_control_joint = np.array(model_params.training_control_joint)
            self.testing_control_joint = np.array(model_params.testing_control_joint)

            self.training_density = model_params.training_density
            self.testing_density = model_params.testing_density

        self.length_arr = np.array(length_list)
        self.data_dict = data_dict
        self.nfeats = motion.shape[1]
        self.name_list = name_list

    def __len__(self) -> int:
        return len(self.name_list)

    def random_mask(self, joints: np.ndarray, n_joints: int = 22) -> np.ndarray:
        choose_joint = self.testing_control_joint

        length = joints.shape[0]
        density = self.testing_density
        if density in [1, 2, 5]:
            choose_seq_num = density
        else:
            choose_seq_num = int(length * density / 100)

        if self.t_ctrl:
            choose_seq = np.arange(0, choose_seq_num)
        else:
            choose_seq = np.random.choice(length, choose_seq_num, replace=False)
            choose_seq.sort()

        mask_seq = np.zeros((length, n_joints, 3)).astype(bool)

        for cj in choose_joint:
            mask_seq[choose_seq, cj] = True

        # normalize
        joints = (joints - self.raw_mean.reshape(n_joints, 3)) / self.raw_std.reshape(n_joints, 3)
        joints = joints * mask_seq
        return joints

    def random_mask_train(self, joints: np.ndarray, n_joints: int = 22) -> np.ndarray:
        if self.t_ctrl:
            choose_joint = self.training_control_joint
        else:
            num_joints = len(self.training_control_joint)
            num_joints_control = 1
            choose_joint = np.random.choice(num_joints, num_joints_control, replace=False)
            choose_joint = self.training_control_joint[choose_joint]

        length = joints.shape[0]

        if self.training_density == 'random':
            choose_seq_num = np.random.choice(length - 1, 1) + 1
        else:
            choose_seq_num = int(length * random.uniform(self.training_density[0], self.training_density[1]) / 100)

        if self.t_ctrl:
            choose_seq = np.arange(0, choose_seq_num)
        else:
            choose_seq = np.random.choice(length, choose_seq_num, replace=False)
            choose_seq.sort()

        mask_seq = np.zeros((length, n_joints, 3)).astype(bool)

        for cj in choose_joint:
            mask_seq[choose_seq, cj] = True

        # normalize
        joints = (joints - self.raw_mean.reshape(n_joints, 3)) / self.raw_std.reshape(n_joints, 3)
        joints = joints * mask_seq
        return joints

    def __getitem__(self, idx: int) -> tuple:
        data = self.data_dict[self.name_list[idx]]
        motion, m_length, text_list = data["motion"], data["length"], data["text"]
        # Randomly select a caption
        text_data = random.choice(text_list)
        caption, tokens = text_data["caption"], text_data["tokens"]

        if len(tokens) < self.max_text_len:
            # pad with "unk"
            tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
            sent_len = len(tokens)
            tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len)
        else:
            # crop
            tokens = tokens[:self.max_text_len]
            tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
            sent_len = len(tokens)
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        # Crop the motions in to times of 4, and introduce small variations
        if self.unit_length < 10:
            coin2 = np.random.choice(["single", "single", "double"])
        else:
            coin2 = "single"

        if coin2 == "double":
            m_length = (m_length // self.unit_length - 1) * self.unit_length
        elif coin2 == "single":
            m_length = (m_length // self.unit_length) * self.unit_length
        idx = random.randint(0, len(motion) - m_length)
        motion = motion[idx:idx + m_length]

        hint = None
        if self.mode is not None:
            n_joints = 22 if motion.shape[-1] == 263 else 21
            # hint is global position of the controllable joints
            joints = recover_from_ric(torch.from_numpy(motion).float(), n_joints)
            joints = joints.numpy()

            # control any joints at any time
            if self.mode == 'train':
                hint = self.random_mask_train(joints, n_joints)
            else:
                hint = self.random_mask(joints, n_joints)

            hint = hint.reshape(hint.shape[0], -1)

        "Z Normalization"
        motion = (motion - self.mean) / self.std

        # debug check nan
        if np.any(np.isnan(motion)):
            raise ValueError("nan in motion")

        return (
            word_embeddings,
            pos_one_hots,
            caption,
            sent_len,
            motion,
            m_length,
            "_".join(tokens),
            hint
        )
mld/data/humanml/scripts/motion_process.py
ADDED
@@ -0,0 +1,51 @@
import torch

from ..common.quaternion import qinv, qrot


# Recover global angle and positions for rotation dataset
# root_rot_velocity (B, seq_len, 1)
# root_linear_velocity (B, seq_len, 2)
# root_y (B, seq_len, 1)
# ric_data (B, seq_len, (joint_num - 1)*3)
# rot_data (B, seq_len, (joint_num - 1)*6)
# local_velocity (B, seq_len, joint_num*3)
# foot contact (B, seq_len, 4)
def recover_root_rot_pos(data: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    rot_vel = data[..., 0]
    r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
    '''Get Y-axis rotation from rotation velocity'''
    r_rot_ang[..., 1:] = rot_vel[..., :-1]
    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)

    r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
    r_rot_quat[..., 0] = torch.cos(r_rot_ang)
    r_rot_quat[..., 2] = torch.sin(r_rot_ang)

    r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
    r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
    '''Add Y-axis rotation to root position'''
    r_pos = qrot(qinv(r_rot_quat), r_pos)

    r_pos = torch.cumsum(r_pos, dim=-2)

    r_pos[..., 1] = data[..., 3]
    return r_rot_quat, r_pos


def recover_from_ric(data: torch.Tensor, joints_num: int) -> torch.Tensor:
    r_rot_quat, r_pos = recover_root_rot_pos(data)
    positions = data[..., 4:(joints_num - 1) * 3 + 4]
    positions = positions.view(positions.shape[:-1] + (-1, 3))

    '''Add Y-axis rotation to local joints'''
    positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)

    '''Add root XZ to joints'''
    positions[..., 0] += r_pos[..., 0:1]
    positions[..., 2] += r_pos[..., 2:3]

    '''Concat root and joints'''
    positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)

    return positions
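
A shape-only sketch (not in the commit): HumanML3D motion features are 263-dimensional per frame and decode to 22 joints; the tensor here is random, so only the shapes are meaningful.

import torch
from mld.data.humanml.scripts.motion_process import recover_from_ric

feats = torch.randn(1, 60, 263)                # (batch, frames, feature_dim)
joints = recover_from_ric(feats, joints_num=22)
print(joints.shape)                            # torch.Size([1, 60, 22, 3])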
mld/data/humanml/utils/__init__.py
ADDED
File without changes
mld/data/humanml/utils/paramUtil.py
ADDED
@@ -0,0 +1,62 @@
import numpy as np

# Define a kinematic tree for the skeletal structure
kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]

kit_raw_offsets = np.array(
    [
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [-1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [0, 0, 1],
        [0, 0, 1],
        [-1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [0, 0, 1],
        [0, 0, 1]
    ]
)

t2m_raw_offsets = np.array([[0, 0, 0],
                            [1, 0, 0],
                            [-1, 0, 0],
                            [0, 1, 0],
                            [0, -1, 0],
                            [0, -1, 0],
                            [0, 1, 0],
                            [0, -1, 0],
                            [0, -1, 0],
                            [0, 1, 0],
                            [0, 0, 1],
                            [0, 0, 1],
                            [0, 1, 0],
                            [1, 0, 0],
                            [-1, 0, 0],
                            [0, 0, 1],
                            [0, -1, 0],
                            [0, -1, 0],
                            [0, -1, 0],
                            [0, -1, 0],
                            [0, -1, 0],
                            [0, -1, 0]])

t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21],
                       [9, 13, 16, 18, 20]]
t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]

kit_tgt_skel_id = '03950'

t2m_tgt_skel_id = '000021'
mld/data/humanml/utils/plot_script.py
ADDED
@@ -0,0 +1,98 @@
from textwrap import wrap
from typing import Optional

import numpy as np

import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

import mld.data.humanml.utils.paramUtil as paramUtil

skeleton = paramUtil.t2m_kinematic_chain


def plot_3d_motion(save_path: str, joints: np.ndarray, title: str,
                   figsize: tuple[int, int] = (3, 3),
                   fps: int = 120, radius: int = 3, kinematic_tree: list = skeleton,
                   hint: Optional[np.ndarray] = None) -> None:

    title = '\n'.join(wrap(title, 20))

    def init():
        ax.set_xlim3d([-radius / 2, radius / 2])
        ax.set_ylim3d([0, radius])
        ax.set_zlim3d([-radius / 3., radius * 2 / 3.])
        fig.suptitle(title, fontsize=10)
        ax.grid(b=False)

    def plot_xzPlane(minx, maxx, miny, minz, maxz):
        # Plot a plane XZ
        verts = [
            [minx, miny, minz],
            [minx, miny, maxz],
            [maxx, miny, maxz],
            [maxx, miny, minz]
        ]
        xz_plane = Poly3DCollection([verts])
        xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
        ax.add_collection3d(xz_plane)

    # (seq_len, joints_num, 3)
    data = joints.copy().reshape(len(joints), -1, 3)

    data *= 1.3  # scale for visualization
    if hint is not None:
        mask = hint.sum(-1) != 0
        hint = hint[mask]
        hint *= 1.3

    fig = plt.figure(figsize=figsize)
    plt.tight_layout()
    ax = p3.Axes3D(fig)
    init()
    MINS = data.min(axis=0).min(axis=0)
    MAXS = data.max(axis=0).max(axis=0)
    colors = ["#DD5A37", "#D69E00", "#B75A39", "#DD5A37", "#D69E00",
              "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00", "#FF6D00",
              "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", "#DDB50E", ]

    frame_number = data.shape[0]

    height_offset = MINS[1]
    data[:, :, 1] -= height_offset
    if hint is not None:
        hint[..., 1] -= height_offset
    trajec = data[:, 0, [0, 2]]

    data[..., 0] -= data[:, 0:1, 0]
    data[..., 2] -= data[:, 0:1, 2]

    def update(index):
        ax.lines = []
        ax.collections = []
        ax.view_init(elev=120, azim=-90)
        ax.dist = 7.5
        plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
                     MAXS[2] - trajec[index, 1])

        if hint is not None:
            ax.scatter(hint[..., 0] - trajec[index, 0], hint[..., 1], hint[..., 2] - trajec[index, 1], color="#80B79A")

        for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
            if i < 5:
                linewidth = 4.0
            else:
                linewidth = 2.0
            ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
                      color=color)

        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_zticklabels([])

    ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
    ani.save(save_path, fps=fps)
    plt.close()
mld/data/humanml/utils/word_vectorizer.py
ADDED
@@ -0,0 +1,82 @@
import pickle
from os.path import join as pjoin

import numpy as np


POS_enumerator = {
    'VERB': 0,
    'NOUN': 1,
    'DET': 2,
    'ADP': 3,
    'NUM': 4,
    'AUX': 5,
    'PRON': 6,
    'ADJ': 7,
    'ADV': 8,
    'Loc_VIP': 9,
    'Body_VIP': 10,
    'Obj_VIP': 11,
    'Act_VIP': 12,
    'Desc_VIP': 13,
    'OTHER': 14,
}

Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
            'up', 'down', 'straight', 'curve')

Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')

Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')

Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
            'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
            'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')

Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
             'angrily', 'sadly')

VIP_dict = {
    'Loc_VIP': Loc_list,
    'Body_VIP': Body_list,
    'Obj_VIP': Obj_List,
    'Act_VIP': Act_list,
    'Desc_VIP': Desc_list,
}


class WordVectorizer(object):
    def __init__(self, meta_root: str, prefix: str) -> None:
        vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix))
        words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb'))
        word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb'))
        self.word2vec = {w: vectors[word2idx[w]] for w in words}

    def _get_pos_ohot(self, pos: str) -> np.ndarray:
        pos_vec = np.zeros(len(POS_enumerator))
        if pos in POS_enumerator:
            pos_vec[POS_enumerator[pos]] = 1
        else:
            pos_vec[POS_enumerator['OTHER']] = 1
        return pos_vec

    def __len__(self) -> int:
        return len(self.word2vec)

    def __getitem__(self, item: str) -> tuple:
        word, pos = item.split('/')
        if word in self.word2vec:
            word_vec = self.word2vec[word]
            vip_pos = None
            for key, values in VIP_dict.items():
                if word in values:
                    vip_pos = key
                    break
            if vip_pos is not None:
                pos_vec = self._get_pos_ohot(vip_pos)
            else:
                pos_vec = self._get_pos_ohot(pos)
        else:
            word_vec = self.word2vec['unk']
            pos_vec = self._get_pos_ohot('OTHER')
        return word_vec, pos_vec
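
A usage sketch (not in the commit); the GloVe metadata path below is an assumption and must point at the downloaded our_vab_*.npy/.pkl files.

from mld.data.humanml.utils.word_vectorizer import WordVectorizer

w_vectorizer = WordVectorizer('./deps/glove', 'our_vab')   # hypothetical local path
word_emb, pos_oh = w_vectorizer['walk/VERB']               # 'walk' is in Act_list, so the one-hot marks Act_VIP
print(word_emb.shape, pos_oh.argmax())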
mld/data/utils.py
ADDED
@@ -0,0 +1,38 @@
import torch


def collate_tensors(batch: list) -> torch.Tensor:
    dims = batch[0].dim()
    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
    size = (len(batch), ) + tuple(max_size)
    canvas = batch[0].new_zeros(size=size)
    for i, b in enumerate(batch):
        sub_tensor = canvas[i]
        for d in range(dims):
            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
        sub_tensor.add_(b)
    return canvas


def mld_collate(batch: list) -> dict:
    notnone_batches = [b for b in batch if b is not None]
    notnone_batches.sort(key=lambda x: x[3], reverse=True)
    adapted_batch = {
        "motion":
        collate_tensors([torch.tensor(b[4]).float() for b in notnone_batches]),
        "text": [b[2] for b in notnone_batches],
        "length": [b[5] for b in notnone_batches],
        "word_embs":
        collate_tensors([torch.tensor(b[0]).float() for b in notnone_batches]),
        "pos_ohot":
        collate_tensors([torch.tensor(b[1]).float() for b in notnone_batches]),
        "text_len":
        collate_tensors([torch.tensor(b[3]) for b in notnone_batches]),
        "tokens": [b[6] for b in notnone_batches],
    }

    # collate trajectory
    if notnone_batches[0][-1] is not None:
        adapted_batch['hint'] = collate_tensors([torch.tensor(b[-1]).float() for b in notnone_batches])

    return adapted_batch
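
A minimal sketch (not in the commit) of the zero-padding behaviour of collate_tensors: shorter items are padded with zeros up to the longest item in the batch.

import torch
from mld.data.utils import collate_tensors

a = torch.ones(4, 263)
b = torch.ones(7, 263)
padded = collate_tensors([a, b])
print(padded.shape)        # torch.Size([2, 7, 263]); padded[0, 4:] is all zeros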
mld/launch/__init__.py
ADDED
File without changes
mld/launch/blender.py
ADDED
@@ -0,0 +1,23 @@
# Fix blender path
import os
import sys
from argparse import ArgumentParser

sys.path.append(os.path.expanduser("~/.local/lib/python3.9/site-packages"))


# Monkey patch argparse such that
# blender / python parsing works
def parse_args(self, args=None, namespace=None):
    if args is not None:
        return self.parse_args_bak(args=args, namespace=namespace)
    try:
        idx = sys.argv.index("--")
        args = sys.argv[idx + 1:]  # the list after '--'
    except ValueError as e:  # '--' not in the list:
        args = []
    return self.parse_args_bak(args=args, namespace=namespace)


setattr(ArgumentParser, 'parse_args_bak', ArgumentParser.parse_args)
setattr(ArgumentParser, 'parse_args', parse_args)
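
A hypothetical invocation (not in the commit) showing why the patch exists: Blender keeps everything before '--' for itself, and after importing this module an ordinary ArgumentParser inside the script only sees the arguments that follow '--'.

# blender --background --python render_script.py -- --npy results/sample.npy   (hypothetical command)
from argparse import ArgumentParser
import mld.launch.blender  # noqa: F401  (importing applies the monkey patch)

parser = ArgumentParser()
parser.add_argument("--npy", type=str)
opt = parser.parse_args()  # parses only the tokens after '--', or [] if '--' is absent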
mld/models/__init__.py
ADDED
File without changes
mld/models/architectures/__init__.py
ADDED
File without changes
mld/models/architectures/mld_clip.py
ADDED
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn

from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer


class MldTextEncoder(nn.Module):

    def __init__(self, modelpath: str, last_hidden_state: bool = False) -> None:
        super().__init__()

        if 't5' in modelpath:
            self.text_model = SentenceTransformer(modelpath)
            self.tokenizer = self.text_model.tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
            self.text_model = AutoModel.from_pretrained(modelpath)

        self.max_length = self.tokenizer.model_max_length
        if "clip" in modelpath:
            self.text_encoded_dim = self.text_model.config.text_config.hidden_size
            if last_hidden_state:
                self.name = "clip_hidden"
            else:
                self.name = "clip"
        elif "bert" in modelpath:
            self.name = "bert"
            self.text_encoded_dim = self.text_model.config.hidden_size
        elif 't5' in modelpath:
            self.name = 't5'
        else:
            raise ValueError(f"Model {modelpath} not supported")

    def forward(self, texts: list[str]) -> torch.Tensor:
        # get prompt text embeddings
        if self.name in ["clip", "clip_hidden"]:
            text_inputs = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # split into max length Clip can handle
            if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
                text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
        elif self.name == "bert":
            text_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)

        if self.name == "clip":
            # (batch_Size, text_encoded_dim)
            text_embeddings = self.text_model.get_text_features(
                text_input_ids.to(self.text_model.device))
            # (batch_Size, 1, text_encoded_dim)
            text_embeddings = text_embeddings.unsqueeze(1)
        elif self.name == "clip_hidden":
            # (batch_Size, seq_length , text_encoded_dim)
            text_embeddings = self.text_model.text_model(
                text_input_ids.to(self.text_model.device)).last_hidden_state
        elif self.name == "bert":
            # (batch_Size, seq_length , text_encoded_dim)
            text_embeddings = self.text_model(
                **text_inputs.to(self.text_model.device)).last_hidden_state
        elif self.name == 't5':
            text_embeddings = self.text_model.encode(texts, show_progress_bar=False, convert_to_tensor=True, batch_size=len(texts))
            text_embeddings = text_embeddings.unsqueeze(1)
        else:
            raise NotImplementedError(f"Model {self.name} not implemented")

        return text_embeddings
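
A usage sketch (not in the commit); the CLIP checkpoint name is an assumption and the weights must be available locally or downloadable from the Hugging Face Hub.

import torch
from mld.models.architectures.mld_clip import MldTextEncoder

text_encoder = MldTextEncoder('openai/clip-vit-large-patch14', last_hidden_state=False)
with torch.no_grad():
    emb = text_encoder(["a person walks forward and waves"])
print(emb.shape)   # (batch_size, 1, text_encoded_dim), e.g. torch.Size([1, 1, 768])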
mld/models/architectures/mld_denoiser.py
ADDED
@@ -0,0 +1,172 @@
from typing import Optional, Union

import torch
import torch.nn as nn

from mld.models.architectures.tools.embeddings import (TimestepEmbedding,
                                                        Timesteps)
from mld.models.operator.cross_attention import (SkipTransformerEncoder,
                                                 TransformerDecoder,
                                                 TransformerDecoderLayer,
                                                 TransformerEncoder,
                                                 TransformerEncoderLayer)
from mld.models.operator.position_encoding import build_position_encoding


class MldDenoiser(nn.Module):

    def __init__(self,
                 latent_dim: list = [1, 256],
                 ff_size: int = 1024,
                 num_layers: int = 6,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 activation: str = "gelu",
                 flip_sin_to_cos: bool = True,
                 return_intermediate_dec: bool = False,
                 position_embedding: str = "learned",
                 arch: str = "trans_enc",
                 freq_shift: float = 0,
                 text_encoded_dim: int = 768,
                 time_cond_proj_dim: int = None,
                 is_controlnet: bool = False) -> None:

        super().__init__()

        self.latent_dim = latent_dim[-1]
        self.text_encoded_dim = text_encoded_dim

        self.arch = arch
        self.time_cond_proj_dim = time_cond_proj_dim

        self.time_proj = Timesteps(text_encoded_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(text_encoded_dim, self.latent_dim, cond_proj_dim=time_cond_proj_dim)
        if text_encoded_dim != self.latent_dim:
            self.emb_proj = nn.Sequential(nn.ReLU(), nn.Linear(text_encoded_dim, self.latent_dim))

        self.query_pos = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)

        if self.arch == "trans_enc":
            encoder_layer = TransformerEncoderLayer(
                self.latent_dim,
                num_heads,
                ff_size,
                dropout,
                activation,
                normalize_before,
            )
            encoder_norm = None if is_controlnet else nn.LayerNorm(self.latent_dim)
            self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm,
                                                  return_intermediate=is_controlnet)

        elif self.arch == "trans_dec":
            assert not is_controlnet, f"controlnet not supported in architecture: 'trans_dec'"
            self.mem_pos = build_position_encoding(
                self.latent_dim, position_embedding=position_embedding)

            decoder_layer = TransformerDecoderLayer(
                self.latent_dim,
                num_heads,
                ff_size,
                dropout,
                activation,
                normalize_before,
            )
            decoder_norm = nn.LayerNorm(self.latent_dim)
            self.decoder = TransformerDecoder(
                decoder_layer,
                num_layers,
                decoder_norm,
                return_intermediate=return_intermediate_dec,
            )

        else:
            raise ValueError(f"Not supported architecture: {self.arch}!")

        self.is_controlnet = is_controlnet

        def zero_module(module):
            for p in module.parameters():
                nn.init.zeros_(p)
            return module

        if self.is_controlnet:
            self.controlnet_cond_embedding = nn.Sequential(
                nn.Linear(self.latent_dim, self.latent_dim),
                nn.Linear(self.latent_dim, self.latent_dim),
                zero_module(nn.Linear(self.latent_dim, self.latent_dim))
            )

            self.controlnet_down_mid_blocks = nn.ModuleList([
                zero_module(nn.Linear(self.latent_dim, self.latent_dim)) for _ in range(num_layers)])

    def forward(self,
                sample: torch.Tensor,
                timestep: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                timestep_cond: Optional[torch.Tensor] = None,
                controlnet_cond: Optional[torch.Tensor] = None,
                controlnet_residuals: Optional[list[torch.Tensor]] = None
                ) -> Union[torch.Tensor, list[torch.Tensor]]:

        # 0. dimension matching
        # sample [latent_dim[0], batch_size, latent_dim] <= [batch_size, latent_dim[0], latent_dim[1]]
        sample = sample.permute(1, 0, 2)

        # 1. check if controlnet
        if self.is_controlnet:
            controlnet_cond = controlnet_cond.permute(1, 0, 2)
            sample = sample + self.controlnet_cond_embedding(controlnet_cond)

        # 2. time_embedding
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timestep.expand(sample.shape[1]).clone()
        time_emb = self.time_proj(timesteps)
        time_emb = time_emb.to(dtype=sample.dtype)
        # [1, bs, latent_dim] <= [bs, latent_dim]
        time_emb = self.time_embedding(time_emb, timestep_cond).unsqueeze(0)

        # 3. condition + time embedding
        # text_emb [seq_len, batch_size, text_encoded_dim] <= [batch_size, seq_len, text_encoded_dim]
        encoder_hidden_states = encoder_hidden_states.permute(1, 0, 2)
        text_emb = encoder_hidden_states  # [num_words, bs, latent_dim]
        # text embedding projection
        if self.text_encoded_dim != self.latent_dim:
            # [1 or 2, bs, latent_dim] <= [1 or 2, bs, text_encoded_dim]
            text_emb_latent = self.emb_proj(text_emb)
        else:
            text_emb_latent = text_emb
        emb_latent = torch.cat((time_emb, text_emb_latent), 0)

        # 4. transformer
        if self.arch == "trans_enc":
            xseq = torch.cat((sample, emb_latent), axis=0)

            xseq = self.query_pos(xseq)
            tokens = self.encoder(xseq, controlnet_residuals=controlnet_residuals)

            if self.is_controlnet:
                control_res_samples = []
                for res, block in zip(tokens, self.controlnet_down_mid_blocks):
                    r = block(res)
                    control_res_samples.append(r)
                return control_res_samples

            sample = tokens[:sample.shape[0]]

        elif self.arch == "trans_dec":
            # tgt - [1 or 5 or 10, bs, latent_dim]
            # memory - [token_num, bs, latent_dim]
            sample = self.query_pos(sample)
            emb_latent = self.mem_pos(emb_latent)
            sample = self.decoder(tgt=sample, memory=emb_latent).squeeze(0)

        else:
            raise TypeError(f"{self.arch} is not supported")

        # 5. [batch_size, latent_dim[0], latent_dim[1]] <= [latent_dim[0], batch_size, latent_dim[1]]
        sample = sample.permute(1, 0, 2)

        return sample
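
A shape-only sketch (not in the commit) of the default trans_enc path with random tensors; the latent and text dimensions below match the constructor defaults above, not any particular config file.

import torch
from mld.models.architectures.mld_denoiser import MldDenoiser

denoiser = MldDenoiser(latent_dim=[1, 256], text_encoded_dim=768)
sample = torch.randn(2, 1, 256)      # noisy latent: (bs, latent_dim[0], latent_dim[1])
timestep = torch.tensor([999])       # diffusion step, broadcast over the batch
text_emb = torch.randn(2, 1, 768)    # pooled text embedding: (bs, 1, text_encoded_dim)
out = denoiser(sample, timestep, text_emb)
print(out.shape)                     # torch.Size([2, 1, 256])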
mld/models/architectures/mld_traj_encoder.py
ADDED
@@ -0,0 +1,78 @@
from typing import Optional

import torch
import torch.nn as nn

from mld.models.operator.cross_attention import SkipTransformerEncoder, TransformerEncoderLayer
from mld.models.operator.position_encoding import build_position_encoding
from mld.utils.temos_utils import lengths_to_mask


class MldTrajEncoder(nn.Module):

    def __init__(self,
                 nfeats: int,
                 latent_dim: list = [1, 256],
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 activation: str = "gelu",
                 position_embedding: str = "learned") -> None:

        super().__init__()
        self.latent_size = latent_dim[0]
        self.latent_dim = latent_dim[-1]

        self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim)

        self.query_pos_encoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)

        encoder_layer = TransformerEncoderLayer(
            self.latent_dim,
            num_heads,
            ff_size,
            dropout,
            activation,
            normalize_before,
        )
        encoder_norm = nn.LayerNorm(self.latent_dim)
        self.encoder = SkipTransformerEncoder(encoder_layer, num_layers,
                                              encoder_norm)

        self.global_motion_token = nn.Parameter(
            torch.randn(self.latent_size, self.latent_dim))

    def forward(self, features: torch.Tensor, lengths: Optional[list[int]] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:

        if lengths is None and mask is None:
            lengths = [len(feature) for feature in features]
            mask = lengths_to_mask(lengths, features.device)

        bs, nframes, nfeats = features.shape

        x = features
        # Embed each human poses into latent vectors
        x = self.skel_embedding(x)

        # Switch sequence and batch_size because the input of
        # Pytorch Transformer is [Sequence, Batch size, ...]
        x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]

        # Each batch has its own set of tokens
        dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))

        # create a bigger mask, to allow attend to emb
        dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((dist_masks, mask), 1)

        # adding the embedding token for all sequences
        xseq = torch.cat((dist, x), 0)

        xseq = self.query_pos_encoder(xseq)
        global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[:dist.shape[0]]

        return global_token
mld/models/architectures/mld_vae.py
ADDED
@@ -0,0 +1,154 @@
from typing import Optional

import torch
import torch.nn as nn
from torch.distributions.distribution import Distribution

from mld.models.operator.cross_attention import (
    SkipTransformerEncoder,
    SkipTransformerDecoder,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)
from mld.models.operator.position_encoding import build_position_encoding
from mld.utils.temos_utils import lengths_to_mask


class MldVae(nn.Module):

    def __init__(self,
                 nfeats: int,
                 latent_dim: list = [1, 256],
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 arch: str = "encoder_decoder",
                 normalize_before: bool = False,
                 activation: str = "gelu",
                 position_embedding: str = "learned") -> None:

        super().__init__()

        self.latent_size = latent_dim[0]
        self.latent_dim = latent_dim[-1]
        input_feats = nfeats
        output_feats = nfeats
        self.arch = arch

        self.query_pos_encoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)

        encoder_layer = TransformerEncoderLayer(
            self.latent_dim,
            num_heads,
            ff_size,
            dropout,
            activation,
            normalize_before,
        )
        encoder_norm = nn.LayerNorm(self.latent_dim)
        self.encoder = SkipTransformerEncoder(encoder_layer, num_layers,
                                              encoder_norm)

        if self.arch == "all_encoder":
            decoder_norm = nn.LayerNorm(self.latent_dim)
            self.decoder = SkipTransformerEncoder(encoder_layer, num_layers,
                                                  decoder_norm)
        elif self.arch == 'encoder_decoder':
            self.query_pos_decoder = build_position_encoding(
                self.latent_dim, position_embedding=position_embedding)

            decoder_layer = TransformerDecoderLayer(
                self.latent_dim,
                num_heads,
                ff_size,
                dropout,
                activation,
                normalize_before,
            )
            decoder_norm = nn.LayerNorm(self.latent_dim)
            self.decoder = SkipTransformerDecoder(decoder_layer, num_layers,
                                                  decoder_norm)
        else:
            raise ValueError(f"Not support architecture: {self.arch}!")

        self.global_motion_token = nn.Parameter(
            torch.randn(self.latent_size * 2, self.latent_dim))

        self.skel_embedding = nn.Linear(input_feats, self.latent_dim)
        self.final_layer = nn.Linear(self.latent_dim, output_feats)

    def forward(self, features: torch.Tensor,
                lengths: Optional[list[int]] = None) -> tuple[torch.Tensor, torch.Tensor, Distribution]:
        z, dist = self.encode(features, lengths)
        feats_rst = self.decode(z, lengths)
        return feats_rst, z, dist

    def encode(self, features: torch.Tensor,
               lengths: Optional[list[int]] = None) -> tuple[torch.Tensor, Distribution]:
        if lengths is None:
            lengths = [len(feature) for feature in features]

        device = features.device

        bs, nframes, nfeats = features.shape
        mask = lengths_to_mask(lengths, device)

        x = features
        # Embed each human poses into latent vectors
        x = self.skel_embedding(x)

        # Switch sequence and batch_size because the input of
        # Pytorch Transformer is [Sequence, Batch size, ...]
        x = x.permute(1, 0, 2)  # now it is [nframes, bs, latent_dim]

        # Each batch has its own set of tokens
        dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))

        # create a bigger mask, to allow attend to emb
        dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((dist_masks, mask), 1)

        # adding the embedding token for all sequences
        xseq = torch.cat((dist, x), 0)

        xseq = self.query_pos_encoder(xseq)
        dist = self.encoder(xseq, src_key_padding_mask=~aug_mask)[:dist.shape[0]]

        mu = dist[0:self.latent_size, ...]
        logvar = dist[self.latent_size:, ...]

        # resampling
        std = logvar.exp().pow(0.5)
        dist = torch.distributions.Normal(mu, std)
        latent = dist.rsample()
        return latent, dist

    def decode(self, z: torch.Tensor, lengths: list[int]) -> torch.Tensor:
        mask = lengths_to_mask(lengths, z.device)
        bs, nframes = mask.shape
        queries = torch.zeros(nframes, bs, self.latent_dim, device=z.device)

        if self.arch == "all_encoder":
            xseq = torch.cat((z, queries), axis=0)
            z_mask = torch.ones((bs, self.latent_size), dtype=torch.bool, device=z.device)
            aug_mask = torch.cat((z_mask, mask), axis=1)
            xseq = self.query_pos_decoder(xseq)
            output = self.decoder(xseq, src_key_padding_mask=~aug_mask)[z.shape[0]:]

        elif self.arch == "encoder_decoder":
            queries = self.query_pos_decoder(queries)
            output = self.decoder(
                tgt=queries,
                memory=z,
                tgt_key_padding_mask=~mask)

        output = self.final_layer(output)
        # zero for padded area
        output[~mask.T] = 0
        # Pytorch Transformer: [Sequence, Batch size, ...]
        feats = output.permute(1, 0, 2)
        return feats
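
A shape-only round trip (not in the commit) through the VAE with random features; lengths marks the valid frames per sequence and padded frames are zeroed on decode.

import torch
from mld.models.architectures.mld_vae import MldVae

vae = MldVae(nfeats=263, latent_dim=[1, 256])
feats = torch.randn(2, 60, 263)              # (bs, nframes, nfeats)
z, dist = vae.encode(feats, lengths=[60, 48])
print(z.shape)                               # torch.Size([1, 2, 256])
recons = vae.decode(z, lengths=[60, 48])
print(recons.shape)                          # torch.Size([2, 60, 263])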
mld/models/architectures/t2m_motionenc.py
ADDED
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence
+
+
+class MovementConvEncoder(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
+        super(MovementConvEncoder, self).__init__()
+        self.main = nn.Sequential(
+            nn.Conv1d(input_size, hidden_size, 4, 2, 1),
+            nn.Dropout(0.2, inplace=True),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(hidden_size, output_size, 4, 2, 1),
+            nn.Dropout(0.2, inplace=True),
+            nn.LeakyReLU(0.2, inplace=True),
+        )
+        self.out_net = nn.Linear(output_size, output_size)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        inputs = inputs.permute(0, 2, 1)
+        outputs = self.main(inputs).permute(0, 2, 1)
+        return self.out_net(outputs)
+
+
+class MotionEncoderBiGRUCo(nn.Module):
+    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
+        super(MotionEncoderBiGRUCo, self).__init__()
+
+        self.input_emb = nn.Linear(input_size, hidden_size)
+        self.gru = nn.GRU(
+            hidden_size, hidden_size, batch_first=True, bidirectional=True
+        )
+        self.output_net = nn.Sequential(
+            nn.Linear(hidden_size * 2, hidden_size),
+            nn.LayerNorm(hidden_size),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Linear(hidden_size, output_size),
+        )
+
+        self.hidden_size = hidden_size
+        self.hidden = nn.Parameter(
+            torch.randn((2, 1, self.hidden_size), requires_grad=True)
+        )
+
+    def forward(self, inputs: torch.Tensor, m_lens: torch.Tensor) -> torch.Tensor:
+        num_samples = inputs.shape[0]
+
+        input_embs = self.input_emb(inputs)
+        hidden = self.hidden.repeat(1, num_samples, 1)
+
+        cap_lens = m_lens.data.tolist()
+        emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
+
+        gru_seq, gru_last = self.gru(emb, hidden)
+
+        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
+
+        return self.output_net(gru_last)
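
Note (not part of the diff): a hedged usage sketch of the evaluator encoders above. The sizes (259/512 movement features, 1024 hidden units, 512-d output) and the 4x temporal downsampling mirror the usual T2M evaluator configuration and are assumptions here, not values fixed by this file. Because pack_padded_sequence is called with its default enforce_sorted=True, m_lens must be sorted in descending order.

# Illustrative only; dimensions are assumed, not taken from this commit.
import torch

mov_enc = MovementConvEncoder(input_size=259, hidden_size=512, output_size=512)
motion_enc = MotionEncoderBiGRUCo(input_size=512, hidden_size=1024, output_size=512)

feats = torch.randn(4, 196, 259)            # padded motion features (bs, nframes, nfeats)
m_lens = torch.tensor([196, 160, 128, 64])  # true lengths, sorted descending

movements = mov_enc(feats)                  # two stride-2 convs: (4, 49, 512)
motion_emb = motion_enc(movements, m_lens // 4)
print(movements.shape, motion_emb.shape)    # (4, 49, 512) (4, 512)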
mld/models/architectures/t2m_textenc.py
ADDED
@@ -0,0 +1,43 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence
+
+
+class TextEncoderBiGRUCo(nn.Module):
+    def __init__(self, word_size: int, pos_size: int, hidden_size: int, output_size: int) -> None:
+        super(TextEncoderBiGRUCo, self).__init__()
+
+        self.pos_emb = nn.Linear(pos_size, word_size)
+        self.input_emb = nn.Linear(word_size, hidden_size)
+        self.gru = nn.GRU(
+            hidden_size, hidden_size, batch_first=True, bidirectional=True
+        )
+        self.output_net = nn.Sequential(
+            nn.Linear(hidden_size * 2, hidden_size),
+            nn.LayerNorm(hidden_size),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Linear(hidden_size, output_size),
+        )
+
+        self.hidden_size = hidden_size
+        self.hidden = nn.Parameter(
+            torch.randn((2, 1, self.hidden_size), requires_grad=True)
+        )
+
+    def forward(self, word_embs: torch.Tensor, pos_onehot: torch.Tensor,
+                cap_lens: torch.Tensor) -> torch.Tensor:
+        num_samples = word_embs.shape[0]
+
+        pos_embs = self.pos_emb(pos_onehot)
+        inputs = word_embs + pos_embs
+        input_embs = self.input_emb(inputs)
+        hidden = self.hidden.repeat(1, num_samples, 1)
+
+        cap_lens = cap_lens.data.tolist()
+        emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
+
+        gru_seq, gru_last = self.gru(emb, hidden)
+
+        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
+
+        return self.output_net(gru_last)
mld/models/architectures/tools/embeddings.py
ADDED
@@ -0,0 +1,89 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+) -> torch.Tensor:
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(
+        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+    )
+    exponent = exponent / (half_dim - downscale_freq_shift)
+
+    emb = torch.exp(exponent)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # scale embeddings
+    emb = scale * emb
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(self, channel: int, time_embed_dim: int,
+                 act_fn: str = "silu", cond_proj_dim: Optional[int] = None) -> None:
+        super().__init__()
+
+        # distill CFG
+        if cond_proj_dim is not None:
+            self.cond_proj = nn.Linear(cond_proj_dim, channel, bias=False)
+            self.cond_proj.weight.data.fill_(0.0)
+        else:
+            self.cond_proj = None
+
+        self.linear_1 = nn.Linear(channel, time_embed_dim)
+        self.act = None
+        if act_fn == "silu":
+            self.act = nn.SiLU()
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+    def forward(self, sample: torch.Tensor, timestep_cond: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if timestep_cond is not None:
+            sample = sample + self.cond_proj(timestep_cond)
+
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+        return sample
+
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool,
+                 downscale_freq_shift: float) -> None:
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+        t_emb = get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+        )
+        return t_emb
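
Note (not part of the diff): a hedged sketch of how Timesteps and TimestepEmbedding above are usually composed inside a denoiser. The dimensions and the guidance-scale value are assumptions. The zero-initialized cond_proj is what lets a distilled model feed a guidance-scale embedding through timestep_cond without perturbing the timestep path at the start of training.

# Illustrative only; the 256/1024 dims and the guidance scale are assumptions.
import torch

time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0.0)
time_embedding = TimestepEmbedding(channel=256, time_embed_dim=1024, cond_proj_dim=256)

t = torch.randint(0, 1000, (4,))                    # one diffusion timestep per sample
t_emb = time_proj(t)                                # (4, 256) sinusoidal features
w_emb = time_proj(torch.full((4,), 7.5))            # guidance-scale embedding (CFG distillation)
emb = time_embedding(t_emb, timestep_cond=w_emb)    # (4, 1024)
print(emb.shape)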
mld/models/metrics/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .tm2t import TM2TMetrics
+from .mm import MMMetrics
+from .cm import ControlMetrics
mld/models/metrics/cm.py
ADDED
@@ -0,0 +1,55 @@
+import torch
+from torchmetrics import Metric
+from torchmetrics.utilities import dim_zero_cat
+
+from mld.utils.temos_utils import remove_padding
+from .utils import calculate_skating_ratio, calculate_trajectory_error, control_l2
+
+
+class ControlMetrics(Metric):
+
+    def __init__(self, dist_sync_on_step: bool = True) -> None:
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+
+        self.name = "control_metrics"
+
+        self.add_state("count_seq", default=torch.tensor(0), dist_reduce_fx="sum")
+        self.add_state("skate_ratio_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("dist_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
+        self.add_state("traj_err", default=[], dist_reduce_fx="cat")
+        self.traj_err_key = ["traj_fail_20cm", "traj_fail_50cm", "kps_fail_20cm", "kps_fail_50cm", "kps_mean_err(m)"]
+
+    def compute(self) -> dict:
+        count_seq = self.count_seq.item()
+
+        metrics = dict()
+        metrics['Skating Ratio'] = self.skate_ratio_sum / count_seq
+        metrics['Control L2 dist'] = self.dist_sum / count_seq
+        traj_err = dim_zero_cat(self.traj_err).mean(0)
+
+        for (k, v) in zip(self.traj_err_key, traj_err):
+            metrics[k] = v
+
+        return {**metrics}
+
+    def update(self, joints: torch.Tensor, hint: torch.Tensor,
+               mask_hint: torch.Tensor, lengths: list[int]) -> None:
+        self.count_seq += len(lengths)
+
+        joints_no_padding = remove_padding(joints, lengths)
+        for j in joints_no_padding:
+            skate_ratio, _ = calculate_skating_ratio(j.unsqueeze(0).permute(0, 2, 3, 1))
+            self.skate_ratio_sum += skate_ratio[0]
+
+        joints_np = joints.cpu().numpy()
+        hint_np = hint.cpu().numpy()
+        mask_hint_np = mask_hint.cpu().numpy()
+
+        for j, h, m in zip(joints_np, hint_np, mask_hint_np):
+            control_error = control_l2(j[None], h[None], m[None])
+            mean_error = control_error.sum() / m.sum()
+            self.dist_sum += mean_error
+            control_error = control_error.reshape(-1)
+            m = m.reshape(-1)
+            err_np = calculate_trajectory_error(control_error, mean_error, m)
+            self.traj_err.append(torch.tensor(err_np[None], device=joints.device))
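
Note (not part of the diff): a hedged sketch of the torchmetrics lifecycle for ControlMetrics. The tensor shapes (22 joints, xyz coordinates, a per-joint hint mask) are assumptions about the surrounding evaluation code, and the helpers it calls live in .utils; adapt the shapes to the real dataloader.

# Illustrative only; tensor shapes are assumed, not taken from this commit.
import torch

metric = ControlMetrics()
for _ in range(2):                          # stand-in for an evaluation dataloader
    bs, seq = 2, 196
    joints = torch.rand(bs, seq, 22, 3)     # generated joint positions
    hint = joints.clone()                   # spatial control signal (same layout)
    mask_hint = torch.zeros(bs, seq, 22, 1)
    mask_hint[:, ::30, 0] = 1.0             # sparse pelvis hints every 30 frames
    metric.update(joints, hint, mask_hint, lengths=[seq] * bs)

results = metric.compute()                  # dict: skating ratio, control L2, trajectory errors
metric.reset()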