MeYourHint commited on
Commit
c0eac48
·
1 Parent(s): 08572f0

first demo version

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. LICENSE +21 -0
  3. README.md +221 -13
  4. app.py +203 -0
  5. assets/mapping.json +1 -0
  6. assets/mapping6.json +1 -0
  7. assets/text_prompt.txt +12 -0
  8. common/__init__.py +0 -0
  9. common/quaternion.py +423 -0
  10. common/skeleton.py +199 -0
  11. data/__init__.py +0 -0
  12. data/t2m_dataset.py +348 -0
  13. dataset/__init__.py +0 -0
  14. edit_t2m.py +195 -0
  15. environment.yml +204 -0
  16. eval_t2m_trans_res.py +199 -0
  17. eval_t2m_vq.py +123 -0
  18. example_data/000612.mp4 +0 -0
  19. example_data/000612.npy +3 -0
  20. gen_t2m.py +261 -0
  21. models/.DS_Store +0 -0
  22. models/__init__.py +0 -0
  23. models/mask_transformer/__init__.py +0 -0
  24. models/mask_transformer/tools.py +165 -0
  25. models/mask_transformer/transformer.py +1039 -0
  26. models/mask_transformer/transformer_trainer.py +359 -0
  27. models/t2m_eval_modules.py +182 -0
  28. models/t2m_eval_wrapper.py +191 -0
  29. models/vq/__init__.py +0 -0
  30. models/vq/encdec.py +68 -0
  31. models/vq/model.py +124 -0
  32. models/vq/quantizer.py +180 -0
  33. models/vq/residual_vq.py +194 -0
  34. models/vq/resnet.py +84 -0
  35. models/vq/vq_trainer.py +359 -0
  36. motion_loaders/__init__.py +0 -0
  37. motion_loaders/dataset_motion_loader.py +27 -0
  38. options/__init__.py +0 -0
  39. options/base_option.py +61 -0
  40. options/eval_option.py +38 -0
  41. options/train_option.py +64 -0
  42. options/vq_option.py +89 -0
  43. prepare/.DS_Store +0 -0
  44. prepare/download_evaluator.sh +24 -0
  45. prepare/download_glove.sh +9 -0
  46. prepare/download_models.sh +31 -0
  47. prepare/download_models_demo.sh +10 -0
  48. requirements.txt +140 -0
  49. train_res_transformer.py +171 -0
  50. train_t2m_transformer.py +153 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Chuan Guo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,221 @@
1
- ---
2
- title: MoMask
3
- emoji: 💻
4
- colorFrom: indigo
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.12.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MoMask: Generative Masked Modeling of 3D Human Motions
2
+ ## [[Project Page]](https://ericguo5513.github.io/momask) [[Paper]](https://arxiv.org/abs/2312.00063)
3
+ ![teaser_image](https://ericguo5513.github.io/momask/static/images/teaser.png)
4
+
5
+ If you find our code or paper helpful, please consider citing:
6
+ ```
7
+ @article{guo2023momask,
8
+ title={MoMask: Generative Masked Modeling of 3D Human Motions},
9
+ author={Chuan Guo and Yuxuan Mu and Muhammad Gohar Javed and Sen Wang and Li Cheng},
10
+ year={2023},
11
+ eprint={2312.00063},
12
+ archivePrefix={arXiv},
13
+ primaryClass={cs.CV}
14
+ }
15
+ ```
16
+
17
+ ## :postbox: News
18
+ 📢 **2023-12-19** --- Release scripts for temporal inpainting.
19
+
20
+ 📢 **2023-12-15** --- Release codes and models for momask. Including training/eval/generation scripts.
21
+
22
+ 📢 **2023-11-29** --- Initialized the webpage and git project.
23
+
24
+
25
+ ## :round_pushpin: Get You Ready
26
+
27
+ <details>
28
+
29
+ ### 1. Conda Environment
30
+ ```
31
+ conda env create -f environment.yml
32
+ conda activate momask
33
+ pip install git+https://github.com/openai/CLIP.git
34
+ ```
35
+ We test our code on Python 3.7.13 and PyTorch 1.7.1
36
+
37
+
38
+ ### 2. Models and Dependencies
39
+
40
+ #### Download Pre-trained Models
41
+ ```
42
+ bash prepare/download_models.sh
43
+ ```
44
+
45
+ #### Download Evaluation Models and Gloves
46
+ For evaluation only.
47
+ ```
48
+ bash prepare/download_evaluator.sh
49
+ bash prepare/download_glove.sh
50
+ ```
51
+
52
+ #### Troubleshooting
53
+ To address the download error related to gdown: "Cannot retrieve the public link of the file. You may need to change the permission to 'Anyone with the link', or have had many accesses". A potential solution is to run `pip install --upgrade --no-cache-dir gdown`, as suggested on https://github.com/wkentaro/gdown/issues/43. This should help resolve the issue.
54
+
55
+ #### (Optional) Download Mannually
56
+ Visit [[Google Drive]](https://drive.google.com/drive/folders/1b3GnAbERH8jAoO5mdWgZhyxHB73n23sK?usp=drive_link) to download the models and evaluators mannually.
57
+
58
+ ### 3. Get Data
59
+
60
+ You have two options here:
61
+ * **Skip getting data**, if you just want to generate motions using *own* descriptions.
62
+ * **Get full data**, if you want to *re-train* and *evaluate* the model.
63
+
64
+ **(a). Full data (text + motion)**
65
+
66
+ **HumanML3D** - Follow the instruction in [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git), then copy the result dataset to our repository:
67
+ ```
68
+ cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
69
+ ```
70
+ **KIT**-Download from [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git), then place result in `./dataset/KIT-ML`
71
+
72
+ ####
73
+
74
+ </details>
75
+
76
+ ## :rocket: Demo
77
+ <details>
78
+
79
+ ### (a) Generate from a single prompt
80
+ ```
81
+ python gen_t2m.py --gpu_id 1 --ext exp1 --text_prompt "A person is running on a treadmill."
82
+ ```
83
+ ### (b) Generate from a prompt file
84
+ An example of prompt file is given in `./assets/text_prompt.txt`. Please follow the format of `<text description>#<motion length>` at each line. Motion length indicates the number of poses, which must be integeter and will be rounded by 4. In our work, motion is in 20 fps.
85
+
86
+ If you write `<text description>#NA`, our model will determine a length. Note once there is **one** NA, all the others will be **NA** automatically.
87
+
88
+ ```
89
+ python gen_t2m.py --gpu_id 1 --ext exp2 --text_path ./assets/text_prompt.txt
90
+ ```
91
+
92
+
93
+ A few more parameters you may be interested:
94
+ * `--repeat_times`: number of replications for generation, default `1`.
95
+ * `--motion_length`: specify the number of poses for generation, only applicable in (a).
96
+
97
+ The output files are stored under folder `./generation/<ext>/`. They are
98
+ * `numpy files`: generated motions with shape of (nframe, 22, 3), under subfolder `./joints`.
99
+ * `video files`: stick figure animation in mp4 format, under subfolder `./animation`.
100
+ * `bvh files`: bvh files of the generated motion, under subfolder `./animation`.
101
+
102
+ We also apply naive foot ik to the generated motions, see files with suffix `_ik`. It sometimes works well, but sometimes will fail.
103
+
104
+ </details>
105
+
106
+ ## :dancers: Visualization
107
+ <details>
108
+
109
+ All the animations are manually rendered in blender. We use the characters from [mixamo](https://www.mixamo.com/#/). You need to download the characters in T-Pose with skeleton.
110
+
111
+ ### Retargeting
112
+ For retargeting, we found rokoko usually leads to large error on foot. On the other hand, [keemap.rig.transfer](https://github.com/nkeeline/Keemap-Blender-Rig-ReTargeting-Addon/releases) shows more precise retargetting. You could watch the [tutorial](https://www.youtube.com/watch?v=EG-VCMkVpxg) here.
113
+
114
+ Following these steps:
115
+ * Download keemap.rig.transfer from the github, and install it in blender.
116
+ * Import both the motion files (.bvh) and character files (.fbx) in blender.
117
+ * `Shift + Select` the both source and target skeleton. (Do not need to be Rest Position)
118
+ * Switch to `Pose Mode`, then unfold the `KeeMapRig` tool at the top-right corner of the view window.
119
+ * Load and read the bone mapping file `./assets/mapping.json`(or `mapping6.json` if it doesn't work). This file is manually made by us. It works for most characters in mixamo. You could make your own.
120
+ * Adjust the `Number of Samples`, `Source Rig`, `Destination Rig Name`.
121
+ * Clik `Transfer Animation from Source Destination`, wait a few seconds.
122
+
123
+ We didn't tried other retargetting tools. Welcome to comment if you find others are more useful.
124
+
125
+ ### Scene
126
+
127
+ We use this [scene](https://drive.google.com/file/d/1lg62nugD7RTAIz0Q_YP2iZsxpUzzOkT1/view?usp=sharing) for animation.
128
+
129
+
130
+ </details>
131
+
132
+ ## :clapper: Temporal Inpainting
133
+ <details>
134
+ We conduct mask-based editing in the m-transformer stage, followed by the regeneration of residual tokens for the entire sequence. To load your own motion, provide the path through `--source_motion`. Utilize `-msec` to specify the mask section, supporting either ratio or frame index. For instance, `-msec 0.3,0.6` with `max_motion_length=196` is equivalent to `-msec 59,118`, indicating the editing of the frame section [59, 118].
135
+
136
+ ```
137
+ python edit_t2m.py --gpu_id 1 --ext exp3 --use_res_model -msec 0.4,0.7 --text_prompt "A man picks something from the ground using his right hand."
138
+ ```
139
+
140
+ Note: Presently, the source motion must adhere to the format of a HumanML3D dim-263 feature vector. An example motion vector data from the HumanML3D test set is available in `example_data/000612.npy`. To process your own motion data, you can utilize the `process_file` function from `utils/motion_process.py`.
141
+
142
+ </details>
143
+
144
+ ## :space_invader: Train Your Own Models
145
+ <details>
146
+
147
+
148
+ **Note**: You have to train RVQ **BEFORE** training masked/residual transformers. The latter two can be trained simultaneously.
149
+
150
+ ### Train RVQ
151
+ ```
152
+ python train_vq.py --name rvq_name --gpu_id 1 --dataset_name t2m --batch_size 512 --num_quantizers 6 --max_epoch 500 --quantize_drop_prob 0.2
153
+ ```
154
+
155
+ ### Train Masked Transformer
156
+ ```
157
+ python train_t2m_transformer.py --name mtrans_name --gpu_id 2 --dataset_name t2m --batch_size 64 --vq_name rvq_name
158
+ ```
159
+
160
+ ### Train Residual Transformer
161
+ ```
162
+ python train_res_transformer.py --name rtrans_name --gpu_id 2 --dataset_name t2m --batch_size 64 --vq_name rvq_name --cond_drop_prob 0.2 --share_weight
163
+ ```
164
+
165
+ * `--dataset_name`: motion dataset, `t2m` for HumanML3D and `kit` for KIT-ML.
166
+ * `--name`: name your model. This will create to model space as `./checkpoints/<dataset_name>/<name>`
167
+ * `--gpu_id`: GPU id.
168
+ * `--batch_size`: we use `512` for rvq training. For masked/residual transformer, we use `64` on HumanML3D and `16` for KIT-ML.
169
+ * `--num_quantizers`: number of quantization layers, `6` is used in our case.
170
+ * `--quantize_drop_prob`: quantization dropout ratio, `0.2` is used.
171
+ * `--vq_name`: when training masked/residual transformer, you need to specify the name of rvq model for tokenization.
172
+ * `--cond_drop_prob`: condition drop ratio, for classifier-free guidance. `0.2` is used.
173
+ * `--share_weight`: whether to share the projection/embedding weights in residual transformer.
174
+
175
+ All the pre-trained models and intermediate results will be saved in space `./checkpoints/<dataset_name>/<name>`.
176
+ </details>
177
+
178
+ ## :book: Evaluation
179
+ <details>
180
+
181
+ ### Evaluate RVQ Reconstruction:
182
+ HumanML3D:
183
+ ```
184
+ python eval_t2m_vq.py --gpu_id 0 --name rvq_nq6_dc512_nc512_noshare_qdp0.2 --dataset_name t2m --ext rvq_nq6
185
+
186
+ ```
187
+ KIT-ML:
188
+ ```
189
+ python eval_t2m_vq.py --gpu_id 0 --name rvq_nq6_dc512_nc512_noshare_qdp0.2_k --dataset_name kit --ext rvq_nq6
190
+ ```
191
+
192
+ ### Evaluate Text2motion Generation:
193
+ HumanML3D:
194
+ ```
195
+ python eval_t2m_trans_res.py --res_name tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw --dataset_name t2m --name t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns --gpu_id 1 --cond_scale 4 --time_steps 10 --ext evaluation
196
+ ```
197
+ KIT-ML:
198
+ ```
199
+ python eval_t2m_trans_res.py --res_name tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw_k --dataset_name kit --name t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns_k --gpu_id 0 --cond_scale 2 --time_steps 10 --ext evaluation
200
+ ```
201
+
202
+ * `--res_name`: model name of `residual transformer`.
203
+ * `--name`: model name of `masked transformer`.
204
+ * `--cond_scale`: scale of classifer-free guidance.
205
+ * `--time_steps`: number of iterations for inference.
206
+ * `--ext`: filename for saving evaluation results.
207
+
208
+ The final evaluation results will be saved in `./checkpoints/<dataset_name>/<name>/eval/<ext>.log`
209
+
210
+ </details>
211
+
212
+ ## Acknowlegements
213
+
214
+ We sincerely thank the open-sourcing of these works where our code is based on:
215
+
216
+ [deep-motion-editing](https://github.com/DeepMotionEditing/deep-motion-editing), [Muse](https://github.com/lucidrains/muse-maskgit-pytorch), [vector-quantize-pytorch](https://github.com/lucidrains/vector-quantize-pytorch), [T2M-GPT](https://github.com/Mael-zys/T2M-GPT), [MDM](https://github.com/GuyTevet/motion-diffusion-model/tree/main) and [MLD](https://github.com/ChenFengYe/motion-latent-diffusion/tree/main)
217
+
218
+ ## License
219
+ This code is distributed under an [MIT LICENSE](https://github.com/EricGuo5513/momask-codes/tree/main?tab=MIT-1-ov-file#readme).
220
+
221
+ Note that our code depends on other libraries, including SMPL, SMPL-X, PyTorch3D, and uses datasets which each have their own respective licenses that must also be followed.
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import os
3
+
4
+ import torch
5
+ import numpy as np
6
+ import gradio as gr
7
+ import gdown
8
+
9
+
10
+ WEBSITE = """
11
+ <div class="embed_hidden">
12
+ <h1 style='text-align: center'> MoMask: Generative Masked Modeling of 3D Human Motions </h1>
13
+ <h2 style='text-align: center'>
14
+ <a href="https://ericguo5513.github.io" target="_blank"><nobr>Chuan Guo*</nobr></a> &emsp;
15
+ <a href="https://yxmu.foo/" target="_blank"><nobr>Yuxuan Mu*</nobr></a> &emsp;
16
+ <a href="https://scholar.google.com/citations?user=w4e-j9sAAAAJ&hl=en" target="_blank"><nobr>Muhammad Gohar Javed*</nobr></a> &emsp;
17
+ <a href="https://sites.google.com/site/senwang1312home/" target="_blank"><nobr>Sen Wang</nobr></a> &emsp;
18
+ <a href="https://www.ece.ualberta.ca/~lcheng5/" target="_blank"><nobr>Li Cheng</nobr></a>
19
+ </h2>
20
+ <h2 style='text-align: center'>
21
+ <nobr>arXiv 2023</nobr>
22
+ </h2>
23
+ <h3 style="text-align:center;">
24
+ <a target="_blank" href="https://arxiv.org/abs/2312.00063"> <button type="button" class="btn btn-primary btn-lg"> Paper </button></a> &ensp;
25
+ <a target="_blank" href="https://github.com/EricGuo5513/momask-codes"> <button type="button" class="btn btn-primary btn-lg"> Code </button></a> &ensp;
26
+ <a target="_blank" href="https://ericguo5513.github.io/momask/"> <button type="button" class="btn btn-primary btn-lg"> Webpage </button></a> &ensp;
27
+ <a target="_blank" href="https://ericguo5513.github.io/source_files/momask_2023_bib.txt"> <button type="button" class="btn btn-primary btn-lg"> BibTex </button></a>
28
+ </h3>
29
+ <h3> Description </h3>
30
+ <p>
31
+ This space illustrates <a href='https://ericguo5513.github.io/momask/' target='_blank'><b>MoMask</b></a>, a method for text-to-motion generation.
32
+ </p>
33
+ </div>
34
+ """
35
+
36
+ EXAMPLES = [
37
+ "A person is walking slowly",
38
+ "A person is walking in a circle",
39
+ "A person is jumping rope",
40
+ "Someone is doing a backflip",
41
+ "A person is doing a moonwalk",
42
+ "A person walks forward and then turns back",
43
+ "Picking up an object",
44
+ "A person is swimming in the sea",
45
+ "A human is squatting",
46
+ "Someone is jumping with one foot",
47
+ "A person is chopping vegetables",
48
+ "Someone walks backward",
49
+ "Somebody is ascending a staircase",
50
+ "A person is sitting down",
51
+ "A person is taking the stairs",
52
+ "Someone is doing jumping jacks",
53
+ "The person walked forward and is picking up his toolbox",
54
+ "The person angrily punching the air",
55
+ ]
56
+
57
+ # Show closest text in the training
58
+
59
+
60
+ # css to make videos look nice
61
+ # var(--block-border-color); TODO
62
+ CSS = """
63
+ .retrieved_video {
64
+ position: relative;
65
+ margin: 0;
66
+ box-shadow: var(--block-shadow);
67
+ border-width: var(--block-border-width);
68
+ border-color: #000000;
69
+ border-radius: var(--block-radius);
70
+ background: var(--block-background-fill);
71
+ width: 100%;
72
+ line-height: var(--line-sm);
73
+ }
74
+ }
75
+ """
76
+
77
+
78
+ DEFAULT_TEXT = "A person is "
79
+
80
+ def generate(
81
+ text, uid, motion_length=0, seed=351540, repeat_times=4,
82
+ ):
83
+ os.system(f'python gen_t2m.py --gpu_id 0 --seed {seed} --ext {uid} --repeat_times {repeat_times} --motion_length {motion_length} --text_prompt {text}')
84
+ datas = []
85
+ for n in repeat_times:
86
+ data_unit = {
87
+ "url": f"./generation/{uid}/animations/0/sample0_repeat{n}_len196_ik.mp4"
88
+ }
89
+ datas.append(data_unit)
90
+ return datas
91
+
92
+
93
+ # HTML component
94
+ def get_video_html(data, video_id, width=700, height=700):
95
+ url = data["url"]
96
+ # class="wrap default svelte-gjihhp hide"
97
+ # <div class="contour_video" style="position: absolute; padding: 10px;">
98
+ # width="{width}" height="{height}"
99
+ video_html = f"""
100
+ <video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()"
101
+ autoplay loop disablepictureinpicture id="{video_id}">
102
+ <source src="{url}" type="video/mp4">
103
+ Your browser does not support the video tag.
104
+ </video>
105
+ """
106
+ return video_html
107
+
108
+
109
+ def generate_component(generate_function, text):
110
+ if text == DEFAULT_TEXT or text == "" or text is None:
111
+ return [None for _ in range(4)]
112
+
113
+ datas = generate_function(text, )
114
+ htmls = [get_video_html(data, idx) for idx, data in enumerate(datas)]
115
+ return htmls
116
+
117
+
118
+ if not os.path.exists("checkpoints/t2m"):
119
+ os.system("bash prepare/download_models.sh")
120
+
121
+
122
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
+
124
+ # LOADING
125
+
126
+ # DEMO
127
+ theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
128
+ generate_and_show = partial(generate_component, generate)
129
+
130
+ with gr.Blocks(css=CSS, theme=theme) as demo:
131
+ gr.Markdown(WEBSITE)
132
+ videos = []
133
+
134
+ with gr.Row():
135
+ with gr.Column(scale=3):
136
+ with gr.Column(scale=2):
137
+ text = gr.Textbox(
138
+ show_label=True,
139
+ label="Text prompt",
140
+ value=DEFAULT_TEXT,
141
+ )
142
+ with gr.Column(scale=1):
143
+ gen_btn = gr.Button("Generate", variant="primary")
144
+ clear = gr.Button("Clear", variant="secondary")
145
+
146
+ with gr.Column(scale=2):
147
+
148
+ def generate_example(text):
149
+ return generate_and_show(text)
150
+
151
+ examples = gr.Examples(
152
+ examples=[[x, None, None] for x in EXAMPLES],
153
+ inputs=[text],
154
+ examples_per_page=20,
155
+ run_on_click=False,
156
+ cache_examples=False,
157
+ fn=generate_example,
158
+ outputs=[],
159
+ )
160
+
161
+ i = -1
162
+ # should indent
163
+ for _ in range(1):
164
+ with gr.Row():
165
+ for _ in range(4):
166
+ i += 1
167
+ video = gr.HTML()
168
+ videos.append(video)
169
+
170
+ # connect the examples to the output
171
+ # a bit hacky
172
+ examples.outputs = videos
173
+
174
+ def load_example(example_id):
175
+ processed_example = examples.non_none_processed_examples[example_id]
176
+ return gr.utils.resolve_singleton(processed_example)
177
+
178
+ examples.dataset.click(
179
+ load_example,
180
+ inputs=[examples.dataset],
181
+ outputs=examples.inputs_with_examples, # type: ignore
182
+ show_progress=False,
183
+ postprocess=False,
184
+ queue=False,
185
+ ).then(fn=generate_example, inputs=examples.inputs, outputs=videos)
186
+
187
+ gen_btn.click(
188
+ fn=generate_and_show,
189
+ inputs=[text],
190
+ outputs=videos,
191
+ )
192
+ text.submit(
193
+ fn=generate_and_show,
194
+ inputs=[text],
195
+ outputs=videos,
196
+ )
197
+
198
+ def clear_videos():
199
+ return [None for x in range(4)] + [DEFAULT_TEXT]
200
+
201
+ clear.click(fn=clear_videos, outputs=videos + [text])
202
+
203
+ demo.launch()
assets/mapping.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bones": [{"name": "Hips", "label": "", "description": "", "SourceBoneName": "Hips", "DestinationBoneName": "mixamorig:Hips", "keyframe_this_bone": true, "CorrectionFactorX": 2.6179938316345215, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": true, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.2588190734386444, "QuatCorrectionFactorx": 0.965925931930542, "QuatCorrectionFactory": 2.7939677238464355e-09, "QuatCorrectionFactorz": -2.7939677238464355e-09, "scale_secondary_bone_name": ""}, {"name": "RightUpLeg", "label": "", "description": "", "SourceBoneName": "RightUpLeg", "DestinationBoneName": "mixamorig:RightUpLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftUpLeg", "label": "", "description": "", "SourceBoneName": "LeftUpLeg", "DestinationBoneName": "mixamorig:LeftUpLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightLeg", "label": "", "description": "", "SourceBoneName": "RightLeg", "DestinationBoneName": "mixamorig:RightLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 2.094395160675049, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftLeg", "label": "", "description": "", "SourceBoneName": "LeftLeg", "DestinationBoneName": "mixamorig:LeftLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 3.665191411972046, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightShoulder", "label": "", "description": "", "SourceBoneName": "RightShoulder", "DestinationBoneName": "mixamorig:RightShoulder", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftShoulder", "label": "", "description": "", "SourceBoneName": "LeftShoulder", "DestinationBoneName": "mixamorig:LeftShoulder", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightArm", "label": "", "description": "", "SourceBoneName": "RightArm", "DestinationBoneName": "mixamorig:RightArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": -1.0471975803375244, "CorrectionFactorZ": -0.1745329201221466, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftArm", "label": "", "description": "", "SourceBoneName": "LeftArm", "DestinationBoneName": "mixamorig:LeftArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 1.0471975803375244, "CorrectionFactorZ": 0.1745329201221466, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightForeArm", "label": "", "description": "", "SourceBoneName": "RightForeArm", "DestinationBoneName": "mixamorig:RightForeArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": -2.094395160675049, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftForeArm", "label": "", "description": "", "SourceBoneName": "LeftForeArm", "DestinationBoneName": "mixamorig:LeftForeArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 1.5707963705062866, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine", "label": "", "description": "", "SourceBoneName": "Spine", "DestinationBoneName": "mixamorig:Spine", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine1", "label": "", "description": "", "SourceBoneName": "Spine1", "DestinationBoneName": "mixamorig:Spine1", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine2", "label": "", "description": "", "SourceBoneName": "Spine2", "DestinationBoneName": "mixamorig:Spine2", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Neck", "label": "", "description": "", "SourceBoneName": "Neck", "DestinationBoneName": "mixamorig:Neck", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Head", "label": "", "description": "", "SourceBoneName": "Head", "DestinationBoneName": "mixamorig:Head", "keyframe_this_bone": true, "CorrectionFactorX": 0.3490658402442932, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightFoot", "label": "", "description": "", "SourceBoneName": "RightFoot", "DestinationBoneName": "mixamorig:RightFoot", "keyframe_this_bone": true, "CorrectionFactorX": -0.19192171096801758, "CorrectionFactorY": 2.979980945587158, "CorrectionFactorZ": -0.05134282633662224, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": -0.082771435379982, "QuatCorrectionFactorx": -0.0177358016371727, "QuatCorrectionFactory": -0.9920229315757751, "QuatCorrectionFactorz": -0.09340716898441315, "scale_secondary_bone_name": ""}, {"name": "LeftFoot", "label": "", "description": "", "SourceBoneName": "LeftFoot", "DestinationBoneName": "mixamorig:LeftFoot", "keyframe_this_bone": true, "CorrectionFactorX": -0.25592508912086487, "CorrectionFactorY": -2.936899423599243, "CorrectionFactorZ": 0.2450830191373825, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.11609010398387909, "QuatCorrectionFactorx": 0.10766097158193588, "QuatCorrectionFactory": -0.9808290004730225, "QuatCorrectionFactorz": -0.11360746622085571, "scale_secondary_bone_name": ""}], "start_frame_to_apply": 0, "number_of_frames_to_apply": 196, "keyframe_every_n_frames": 1, "source_rig_name": "bvh_batch1_sample30_repeat1_len48", "destination_rig_name": "Armature", "bone_rotation_mode": "EULER", "bone_mapping_file": "C:\\Users\\cguo2\\Documents\\CVPR2024_MoMask\\mapping.json"}
assets/mapping6.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bones": [{"name": "Hips", "label": "", "description": "", "SourceBoneName": "Hips", "DestinationBoneName": "mixamorig6:Hips", "keyframe_this_bone": true, "CorrectionFactorX": 2.6179938316345215, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": true, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.2588190734386444, "QuatCorrectionFactorx": 0.965925931930542, "QuatCorrectionFactory": 2.7939677238464355e-09, "QuatCorrectionFactorz": -2.7939677238464355e-09, "scale_secondary_bone_name": ""}, {"name": "RightUpLeg", "label": "", "description": "", "SourceBoneName": "RightUpLeg", "DestinationBoneName": "mixamorig6:RightUpLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftUpLeg", "label": "", "description": "", "SourceBoneName": "LeftUpLeg", "DestinationBoneName": "mixamorig6:LeftUpLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightLeg", "label": "", "description": "", "SourceBoneName": "RightLeg", "DestinationBoneName": "mixamorig6:RightLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 2.094395160675049, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftLeg", "label": "", "description": "", "SourceBoneName": "LeftLeg", "DestinationBoneName": "mixamorig6:LeftLeg", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 3.665191411972046, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightShoulder", "label": "", "description": "", "SourceBoneName": "RightShoulder", "DestinationBoneName": "mixamorig6:RightShoulder", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftShoulder", "label": "", "description": "", "SourceBoneName": "LeftShoulder", "DestinationBoneName": "mixamorig6:LeftShoulder", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightArm", "label": "", "description": "", "SourceBoneName": "RightArm", "DestinationBoneName": "mixamorig6:RightArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": -1.0471975803375244, "CorrectionFactorZ": -0.1745329201221466, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftArm", "label": "", "description": "", "SourceBoneName": "LeftArm", "DestinationBoneName": "mixamorig6:LeftArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 1.0471975803375244, "CorrectionFactorZ": 0.1745329201221466, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "RightForeArm", "label": "", "description": "", "SourceBoneName": "RightForeArm", "DestinationBoneName": "mixamorig6:RightForeArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": -2.094395160675049, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "LeftForeArm", "label": "", "description": "", "SourceBoneName": "LeftForeArm", "DestinationBoneName": "mixamorig6:LeftForeArm", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 1.5707963705062866, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine", "label": "", "description": "", "SourceBoneName": "Spine", "DestinationBoneName": "mixamorig6:Spine", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine1", "label": "", "description": "", "SourceBoneName": "Spine1", "DestinationBoneName": "mixamorig6:Spine1", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Spine2", "label": "", "description": "", "SourceBoneName": "Spine2", "DestinationBoneName": "mixamorig6:Spine2", "keyframe_this_bone": true, "CorrectionFactorX": 0.0, "CorrectionFactorY": 0.0, "CorrectionFactorZ": 0.0, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 1.0, "QuatCorrectionFactorx": 0.0, "QuatCorrectionFactory": 0.0, "QuatCorrectionFactorz": 0.0, "scale_secondary_bone_name": ""}, {"name": "Neck", "label": "", "description": "", "SourceBoneName": "Neck", "DestinationBoneName": "mixamorig6:Neck", "keyframe_this_bone": true, "CorrectionFactorX": -0.994345486164093, "CorrectionFactorY": -0.006703000050038099, "CorrectionFactorZ": 0.04061730206012726, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.8787809014320374, "QuatCorrectionFactorx": -0.4767816960811615, "QuatCorrectionFactory": -0.01263047568500042, "QuatCorrectionFactorz": 0.016250507906079292, "scale_secondary_bone_name": ""}, {"name": "Head", "label": "", "description": "", "SourceBoneName": "Head", "DestinationBoneName": "mixamorig6:Head", "keyframe_this_bone": true, "CorrectionFactorX": -0.07639937847852707, "CorrectionFactorY": 0.011205507442355156, "CorrectionFactorZ": 0.011367863975465298, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": 0.9992374181747437, "QuatCorrectionFactorx": -0.038221005350351334, "QuatCorrectionFactory": 0.0053814793936908245, "QuatCorrectionFactorz": 0.005893632769584656, "scale_secondary_bone_name": ""}, {"name": "RightFoot", "label": "", "description": "", "SourceBoneName": "RightFoot", "DestinationBoneName": "mixamorig6:RightFoot", "keyframe_this_bone": true, "CorrectionFactorX": -0.17194896936416626, "CorrectionFactorY": 2.7372374534606934, "CorrectionFactorZ": -0.029542576521635056, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": -0.20128199458122253, "QuatCorrectionFactorx": 0.002824343740940094, "QuatCorrectionFactory": -0.9761614799499512, "QuatCorrectionFactorz": -0.08115538209676743, "scale_secondary_bone_name": ""}, {"name": "LeftFoot", "label": "", "description": "", "SourceBoneName": "LeftFoot", "DestinationBoneName": "mixamorig6:LeftFoot", "keyframe_this_bone": true, "CorrectionFactorX": -0.09363158047199249, "CorrectionFactorY": -2.9336421489715576, "CorrectionFactorZ": -0.17343592643737793, "has_twist_bone": false, "TwistBoneName": "", "set_bone_position": false, "set_bone_rotation": true, "bone_rotation_application_axis": "XYZ", "position_correction_factorX": 0.0, "position_correction_factorY": 0.0, "position_correction_factorZ": 0.0, "position_gain": 1.0, "position_pole_distance": 0.30000001192092896, "postion_type": "SINGLE_BONE_OFFSET", "set_bone_scale": false, "scale_gain": 1.0, "scale_max": 1.0, "scale_min": 0.5, "bone_scale_application_axis": "Y", "QuatCorrectionFactorw": -0.09925344586372375, "QuatCorrectionFactorx": 0.09088610112667084, "QuatCorrectionFactory": 0.9893556833267212, "QuatCorrectionFactorz": 0.05535021424293518, "scale_secondary_bone_name": ""}], "start_frame_to_apply": 0, "number_of_frames_to_apply": 196, "keyframe_every_n_frames": 1, "source_rig_name": "MoMask__02_ik", "destination_rig_name": "Armature", "bone_rotation_mode": "EULER", "bone_mapping_file": "C:\\Users\\cguo2\\Documents\\CVPR2024_MoMask\\mapping6.json"}
assets/text_prompt.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ the person holds his left foot with his left hand, puts his right foot up and left hand up too.#132
2
+ a man bends down and picks something up with his left hand.#84
3
+ A man stands for few seconds and picks up his arms and shakes them.#176
4
+ A person walks with a limp, their left leg get injured.#192
5
+ a person jumps up and then lands.#52
6
+ a person performs a standing back kick.#52
7
+ A person pokes their right hand along the ground, like they might be planting seeds.#60
8
+ the person steps forward and uses the left leg to kick something forward.#92
9
+ the man walked forward, spun right on one foot and walked back to his original position.#92
10
+ the person was pushed but did not fall.#124
11
+ this person stumbles left and right while moving forward.#132
12
+ a person reaching down and picking something up.#148
common/__init__.py ADDED
File without changes
common/quaternion.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import numpy as np
10
+
11
+ _EPS4 = np.finfo(float).eps * 4.0
12
+
13
+ _FLOAT_EPS = np.finfo(np.float).eps
14
+
15
+ # PyTorch-backed implementations
16
+ def qinv(q):
17
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18
+ mask = torch.ones_like(q)
19
+ mask[..., 1:] = -mask[..., 1:]
20
+ return q * mask
21
+
22
+
23
+ def qinv_np(q):
24
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25
+ return qinv(torch.from_numpy(q).float()).numpy()
26
+
27
+
28
+ def qnormalize(q):
29
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30
+ return q / torch.norm(q, dim=-1, keepdim=True)
31
+
32
+
33
+ def qmul(q, r):
34
+ """
35
+ Multiply quaternion(s) q with quaternion(s) r.
36
+ Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37
+ Returns q*r as a tensor of shape (*, 4).
38
+ """
39
+ assert q.shape[-1] == 4
40
+ assert r.shape[-1] == 4
41
+
42
+ original_shape = q.shape
43
+
44
+ # Compute outer product
45
+ terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46
+
47
+ w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48
+ x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49
+ y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50
+ z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51
+ return torch.stack((w, x, y, z), dim=1).view(original_shape)
52
+
53
+
54
+ def qrot(q, v):
55
+ """
56
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
57
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58
+ where * denotes any number of dimensions.
59
+ Returns a tensor of shape (*, 3).
60
+ """
61
+ assert q.shape[-1] == 4
62
+ assert v.shape[-1] == 3
63
+ assert q.shape[:-1] == v.shape[:-1]
64
+
65
+ original_shape = list(v.shape)
66
+ # print(q.shape)
67
+ q = q.contiguous().view(-1, 4)
68
+ v = v.contiguous().view(-1, 3)
69
+
70
+ qvec = q[:, 1:]
71
+ uv = torch.cross(qvec, v, dim=1)
72
+ uuv = torch.cross(qvec, uv, dim=1)
73
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74
+
75
+
76
+ def qeuler(q, order, epsilon=0, deg=True):
77
+ """
78
+ Convert quaternion(s) q to Euler angles.
79
+ Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80
+ Returns a tensor of shape (*, 3).
81
+ """
82
+ assert q.shape[-1] == 4
83
+
84
+ original_shape = list(q.shape)
85
+ original_shape[-1] = 3
86
+ q = q.view(-1, 4)
87
+
88
+ q0 = q[:, 0]
89
+ q1 = q[:, 1]
90
+ q2 = q[:, 2]
91
+ q3 = q[:, 3]
92
+
93
+ if order == 'xyz':
94
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95
+ y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97
+ elif order == 'yzx':
98
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100
+ z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101
+ elif order == 'zxy':
102
+ x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105
+ elif order == 'xzy':
106
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107
+ y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108
+ z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109
+ elif order == 'yxz':
110
+ x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111
+ y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112
+ z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113
+ elif order == 'zyx':
114
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115
+ y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116
+ z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117
+ else:
118
+ raise
119
+
120
+ if deg:
121
+ return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122
+ else:
123
+ return torch.stack((x, y, z), dim=1).view(original_shape)
124
+
125
+
126
+ # Numpy-backed implementations
127
+
128
+ def qmul_np(q, r):
129
+ q = torch.from_numpy(q).contiguous().float()
130
+ r = torch.from_numpy(r).contiguous().float()
131
+ return qmul(q, r).numpy()
132
+
133
+
134
+ def qrot_np(q, v):
135
+ q = torch.from_numpy(q).contiguous().float()
136
+ v = torch.from_numpy(v).contiguous().float()
137
+ return qrot(q, v).numpy()
138
+
139
+
140
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
141
+ if use_gpu:
142
+ q = torch.from_numpy(q).cuda().float()
143
+ return qeuler(q, order, epsilon).cpu().numpy()
144
+ else:
145
+ q = torch.from_numpy(q).contiguous().float()
146
+ return qeuler(q, order, epsilon).numpy()
147
+
148
+
149
+ def qfix(q):
150
+ """
151
+ Enforce quaternion continuity across the time dimension by selecting
152
+ the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153
+ between two consecutive frames.
154
+
155
+ Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156
+ Returns a tensor of the same shape.
157
+ """
158
+ assert len(q.shape) == 3
159
+ assert q.shape[-1] == 4
160
+
161
+ result = q.copy()
162
+ dot_products = np.sum(q[1:] * q[:-1], axis=2)
163
+ mask = dot_products < 0
164
+ mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165
+ result[1:][mask] *= -1
166
+ return result
167
+
168
+
169
+ def euler2quat(e, order, deg=True):
170
+ """
171
+ Convert Euler angles to quaternions.
172
+ """
173
+ assert e.shape[-1] == 3
174
+
175
+ original_shape = list(e.shape)
176
+ original_shape[-1] = 4
177
+
178
+ e = e.view(-1, 3)
179
+
180
+ ## if euler angles in degrees
181
+ if deg:
182
+ e = e * np.pi / 180.
183
+
184
+ x = e[:, 0]
185
+ y = e[:, 1]
186
+ z = e[:, 2]
187
+
188
+ rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189
+ ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190
+ rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191
+
192
+ result = None
193
+ for coord in order:
194
+ if coord == 'x':
195
+ r = rx
196
+ elif coord == 'y':
197
+ r = ry
198
+ elif coord == 'z':
199
+ r = rz
200
+ else:
201
+ raise
202
+ if result is None:
203
+ result = r
204
+ else:
205
+ result = qmul(result, r)
206
+
207
+ # Reverse antipodal representation to have a non-negative "w"
208
+ if order in ['xyz', 'yzx', 'zxy']:
209
+ result *= -1
210
+
211
+ return result.view(original_shape)
212
+
213
+
214
+ def expmap_to_quaternion(e):
215
+ """
216
+ Convert axis-angle rotations (aka exponential maps) to quaternions.
217
+ Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218
+ Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219
+ Returns a tensor of shape (*, 4).
220
+ """
221
+ assert e.shape[-1] == 3
222
+
223
+ original_shape = list(e.shape)
224
+ original_shape[-1] = 4
225
+ e = e.reshape(-1, 3)
226
+
227
+ theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228
+ w = np.cos(0.5 * theta).reshape(-1, 1)
229
+ xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230
+ return np.concatenate((w, xyz), axis=1).reshape(original_shape)
231
+
232
+
233
+ def euler_to_quaternion(e, order):
234
+ """
235
+ Convert Euler angles to quaternions.
236
+ """
237
+ assert e.shape[-1] == 3
238
+
239
+ original_shape = list(e.shape)
240
+ original_shape[-1] = 4
241
+
242
+ e = e.reshape(-1, 3)
243
+
244
+ x = e[:, 0]
245
+ y = e[:, 1]
246
+ z = e[:, 2]
247
+
248
+ rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249
+ ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250
+ rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251
+
252
+ result = None
253
+ for coord in order:
254
+ if coord == 'x':
255
+ r = rx
256
+ elif coord == 'y':
257
+ r = ry
258
+ elif coord == 'z':
259
+ r = rz
260
+ else:
261
+ raise
262
+ if result is None:
263
+ result = r
264
+ else:
265
+ result = qmul_np(result, r)
266
+
267
+ # Reverse antipodal representation to have a non-negative "w"
268
+ if order in ['xyz', 'yzx', 'zxy']:
269
+ result *= -1
270
+
271
+ return result.reshape(original_shape)
272
+
273
+
274
+ def quaternion_to_matrix(quaternions):
275
+ """
276
+ Convert rotations given as quaternions to rotation matrices.
277
+ Args:
278
+ quaternions: quaternions with real part first,
279
+ as tensor of shape (..., 4).
280
+ Returns:
281
+ Rotation matrices as tensor of shape (..., 3, 3).
282
+ """
283
+ r, i, j, k = torch.unbind(quaternions, -1)
284
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
285
+
286
+ o = torch.stack(
287
+ (
288
+ 1 - two_s * (j * j + k * k),
289
+ two_s * (i * j - k * r),
290
+ two_s * (i * k + j * r),
291
+ two_s * (i * j + k * r),
292
+ 1 - two_s * (i * i + k * k),
293
+ two_s * (j * k - i * r),
294
+ two_s * (i * k - j * r),
295
+ two_s * (j * k + i * r),
296
+ 1 - two_s * (i * i + j * j),
297
+ ),
298
+ -1,
299
+ )
300
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
301
+
302
+
303
+ def quaternion_to_matrix_np(quaternions):
304
+ q = torch.from_numpy(quaternions).contiguous().float()
305
+ return quaternion_to_matrix(q).numpy()
306
+
307
+
308
+ def quaternion_to_cont6d_np(quaternions):
309
+ rotation_mat = quaternion_to_matrix_np(quaternions)
310
+ cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311
+ return cont_6d
312
+
313
+
314
+ def quaternion_to_cont6d(quaternions):
315
+ rotation_mat = quaternion_to_matrix(quaternions)
316
+ cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317
+ return cont_6d
318
+
319
+
320
+ def cont6d_to_matrix(cont6d):
321
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322
+ x_raw = cont6d[..., 0:3]
323
+ y_raw = cont6d[..., 3:6]
324
+
325
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326
+ z = torch.cross(x, y_raw, dim=-1)
327
+ z = z / torch.norm(z, dim=-1, keepdim=True)
328
+
329
+ y = torch.cross(z, x, dim=-1)
330
+
331
+ x = x[..., None]
332
+ y = y[..., None]
333
+ z = z[..., None]
334
+
335
+ mat = torch.cat([x, y, z], dim=-1)
336
+ return mat
337
+
338
+
339
+ def cont6d_to_matrix_np(cont6d):
340
+ q = torch.from_numpy(cont6d).contiguous().float()
341
+ return cont6d_to_matrix(q).numpy()
342
+
343
+
344
+ def qpow(q0, t, dtype=torch.float):
345
+ ''' q0 : tensor of quaternions
346
+ t: tensor of powers
347
+ '''
348
+ q0 = qnormalize(q0)
349
+ theta0 = torch.acos(q0[..., 0])
350
+
351
+ ## if theta0 is close to zero, add epsilon to avoid NaNs
352
+ mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
353
+ theta0 = (1 - mask) * theta0 + mask * 10e-10
354
+ v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355
+
356
+ if isinstance(t, torch.Tensor):
357
+ q = torch.zeros(t.shape + q0.shape)
358
+ theta = t.view(-1, 1) * theta0.view(1, -1)
359
+ else: ## if t is a number
360
+ q = torch.zeros(q0.shape)
361
+ theta = t * theta0
362
+
363
+ q[..., 0] = torch.cos(theta)
364
+ q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365
+
366
+ return q.to(dtype)
367
+
368
+
369
+ def qslerp(q0, q1, t):
370
+ '''
371
+ q0: starting quaternion
372
+ q1: ending quaternion
373
+ t: array of points along the way
374
+
375
+ Returns:
376
+ Tensor of Slerps: t.shape + q0.shape
377
+ '''
378
+
379
+ q0 = qnormalize(q0)
380
+ q1 = qnormalize(q1)
381
+ q_ = qpow(qmul(q1, qinv(q0)), t)
382
+
383
+ return qmul(q_,
384
+ q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
385
+
386
+
387
+ def qbetween(v0, v1):
388
+ '''
389
+ find the quaternion used to rotate v0 to v1
390
+ '''
391
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393
+
394
+ v = torch.cross(v0, v1)
395
+ w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
396
+ keepdim=True)
397
+ return qnormalize(torch.cat([w, v], dim=-1))
398
+
399
+
400
+ def qbetween_np(v0, v1):
401
+ '''
402
+ find the quaternion used to rotate v0 to v1
403
+ '''
404
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406
+
407
+ v0 = torch.from_numpy(v0).float()
408
+ v1 = torch.from_numpy(v1).float()
409
+ return qbetween(v0, v1).numpy()
410
+
411
+
412
+ def lerp(p0, p1, t):
413
+ if not isinstance(t, torch.Tensor):
414
+ t = torch.Tensor([t])
415
+
416
+ new_shape = t.shape + p0.shape
417
+ new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418
+ new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419
+ p0 = p0.view(new_view_p).expand(new_shape)
420
+ p1 = p1.view(new_view_p).expand(new_shape)
421
+ t = t.view(new_view_t).expand(new_shape)
422
+
423
+ return p0 + t * (p1 - p0)
common/skeleton.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from common.quaternion import *
2
+ import scipy.ndimage.filters as filters
3
+
4
+ class Skeleton(object):
5
+ def __init__(self, offset, kinematic_tree, device):
6
+ self.device = device
7
+ self._raw_offset_np = offset.numpy()
8
+ self._raw_offset = offset.clone().detach().to(device).float()
9
+ self._kinematic_tree = kinematic_tree
10
+ self._offset = None
11
+ self._parents = [0] * len(self._raw_offset)
12
+ self._parents[0] = -1
13
+ for chain in self._kinematic_tree:
14
+ for j in range(1, len(chain)):
15
+ self._parents[chain[j]] = chain[j-1]
16
+
17
+ def njoints(self):
18
+ return len(self._raw_offset)
19
+
20
+ def offset(self):
21
+ return self._offset
22
+
23
+ def set_offset(self, offsets):
24
+ self._offset = offsets.clone().detach().to(self.device).float()
25
+
26
+ def kinematic_tree(self):
27
+ return self._kinematic_tree
28
+
29
+ def parents(self):
30
+ return self._parents
31
+
32
+ # joints (batch_size, joints_num, 3)
33
+ def get_offsets_joints_batch(self, joints):
34
+ assert len(joints.shape) == 3
35
+ _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
36
+ for i in range(1, self._raw_offset.shape[0]):
37
+ _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
38
+
39
+ self._offset = _offsets.detach()
40
+ return _offsets
41
+
42
+ # joints (joints_num, 3)
43
+ def get_offsets_joints(self, joints):
44
+ assert len(joints.shape) == 2
45
+ _offsets = self._raw_offset.clone()
46
+ for i in range(1, self._raw_offset.shape[0]):
47
+ # print(joints.shape)
48
+ _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
49
+
50
+ self._offset = _offsets.detach()
51
+ return _offsets
52
+
53
+ # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
54
+ # joints (batch_size, joints_num, 3)
55
+ def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
56
+ assert len(face_joint_idx) == 4
57
+ '''Get Forward Direction'''
58
+ l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
59
+ across1 = joints[:, r_hip] - joints[:, l_hip]
60
+ across2 = joints[:, sdr_r] - joints[:, sdr_l]
61
+ across = across1 + across2
62
+ across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
63
+ # print(across1.shape, across2.shape)
64
+
65
+ # forward (batch_size, 3)
66
+ forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
67
+ if smooth_forward:
68
+ forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
69
+ # forward (batch_size, 3)
70
+ forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
71
+
72
+ '''Get Root Rotation'''
73
+ target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
74
+ root_quat = qbetween_np(forward, target)
75
+
76
+ '''Inverse Kinematics'''
77
+ # quat_params (batch_size, joints_num, 4)
78
+ # print(joints.shape[:-1])
79
+ quat_params = np.zeros(joints.shape[:-1] + (4,))
80
+ # print(quat_params.shape)
81
+ root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
82
+ quat_params[:, 0] = root_quat
83
+ # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
84
+ for chain in self._kinematic_tree:
85
+ R = root_quat
86
+ for j in range(len(chain) - 1):
87
+ # (batch, 3)
88
+ u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
89
+ # print(u.shape)
90
+ # (batch, 3)
91
+ v = joints[:, chain[j+1]] - joints[:, chain[j]]
92
+ v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
93
+ # print(u.shape, v.shape)
94
+ rot_u_v = qbetween_np(u, v)
95
+
96
+ R_loc = qmul_np(qinv_np(R), rot_u_v)
97
+
98
+ quat_params[:,chain[j + 1], :] = R_loc
99
+ R = qmul_np(R, R_loc)
100
+
101
+ return quat_params
102
+
103
+ # Be sure root joint is at the beginning of kinematic chains
104
+ def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
105
+ # quat_params (batch_size, joints_num, 4)
106
+ # joints (batch_size, joints_num, 3)
107
+ # root_pos (batch_size, 3)
108
+ if skel_joints is not None:
109
+ offsets = self.get_offsets_joints_batch(skel_joints)
110
+ if len(self._offset.shape) == 2:
111
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
112
+ joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
113
+ joints[:, 0] = root_pos
114
+ for chain in self._kinematic_tree:
115
+ if do_root_R:
116
+ R = quat_params[:, 0]
117
+ else:
118
+ R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
119
+ for i in range(1, len(chain)):
120
+ R = qmul(R, quat_params[:, chain[i]])
121
+ offset_vec = offsets[:, chain[i]]
122
+ joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
123
+ return joints
124
+
125
+ # Be sure root joint is at the beginning of kinematic chains
126
+ def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
127
+ # quat_params (batch_size, joints_num, 4)
128
+ # joints (batch_size, joints_num, 3)
129
+ # root_pos (batch_size, 3)
130
+ if skel_joints is not None:
131
+ skel_joints = torch.from_numpy(skel_joints)
132
+ offsets = self.get_offsets_joints_batch(skel_joints)
133
+ if len(self._offset.shape) == 2:
134
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
135
+ offsets = offsets.numpy()
136
+ joints = np.zeros(quat_params.shape[:-1] + (3,))
137
+ joints[:, 0] = root_pos
138
+ for chain in self._kinematic_tree:
139
+ if do_root_R:
140
+ R = quat_params[:, 0]
141
+ else:
142
+ R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
143
+ for i in range(1, len(chain)):
144
+ R = qmul_np(R, quat_params[:, chain[i]])
145
+ offset_vec = offsets[:, chain[i]]
146
+ joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
147
+ return joints
148
+
149
+ def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
150
+ # cont6d_params (batch_size, joints_num, 6)
151
+ # joints (batch_size, joints_num, 3)
152
+ # root_pos (batch_size, 3)
153
+ if skel_joints is not None:
154
+ skel_joints = torch.from_numpy(skel_joints)
155
+ offsets = self.get_offsets_joints_batch(skel_joints)
156
+ if len(self._offset.shape) == 2:
157
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
158
+ offsets = offsets.numpy()
159
+ joints = np.zeros(cont6d_params.shape[:-1] + (3,))
160
+ joints[:, 0] = root_pos
161
+ for chain in self._kinematic_tree:
162
+ if do_root_R:
163
+ matR = cont6d_to_matrix_np(cont6d_params[:, 0])
164
+ else:
165
+ matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
166
+ for i in range(1, len(chain)):
167
+ matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
168
+ offset_vec = offsets[:, chain[i]][..., np.newaxis]
169
+ # print(matR.shape, offset_vec.shape)
170
+ joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
171
+ return joints
172
+
173
+ def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
174
+ # cont6d_params (batch_size, joints_num, 6)
175
+ # joints (batch_size, joints_num, 3)
176
+ # root_pos (batch_size, 3)
177
+ if skel_joints is not None:
178
+ # skel_joints = torch.from_numpy(skel_joints)
179
+ offsets = self.get_offsets_joints_batch(skel_joints)
180
+ if len(self._offset.shape) == 2:
181
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
182
+ joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
183
+ joints[..., 0, :] = root_pos
184
+ for chain in self._kinematic_tree:
185
+ if do_root_R:
186
+ matR = cont6d_to_matrix(cont6d_params[:, 0])
187
+ else:
188
+ matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
189
+ for i in range(1, len(chain)):
190
+ matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
191
+ offset_vec = offsets[:, chain[i]].unsqueeze(-1)
192
+ # print(matR.shape, offset_vec.shape)
193
+ joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
194
+ return joints
195
+
196
+
197
+
198
+
199
+
data/__init__.py ADDED
File without changes
data/t2m_dataset.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os.path import join as pjoin
2
+ import torch
3
+ from torch.utils import data
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from torch.utils.data._utils.collate import default_collate
7
+ import random
8
+ import codecs as cs
9
+
10
+
11
+ def collate_fn(batch):
12
+ batch.sort(key=lambda x: x[3], reverse=True)
13
+ return default_collate(batch)
14
+
15
+ class MotionDataset(data.Dataset):
16
+ def __init__(self, opt, mean, std, split_file):
17
+ self.opt = opt
18
+ joints_num = opt.joints_num
19
+
20
+ self.data = []
21
+ self.lengths = []
22
+ id_list = []
23
+ with open(split_file, 'r') as f:
24
+ for line in f.readlines():
25
+ id_list.append(line.strip())
26
+
27
+ for name in tqdm(id_list):
28
+ try:
29
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
30
+ if motion.shape[0] < opt.window_size:
31
+ continue
32
+ self.lengths.append(motion.shape[0] - opt.window_size)
33
+ self.data.append(motion)
34
+ except Exception as e:
35
+ # Some motion may not exist in KIT dataset
36
+ print(e)
37
+ pass
38
+
39
+ self.cumsum = np.cumsum([0] + self.lengths)
40
+
41
+ if opt.is_train:
42
+ # root_rot_velocity (B, seq_len, 1)
43
+ std[0:1] = std[0:1] / opt.feat_bias
44
+ # root_linear_velocity (B, seq_len, 2)
45
+ std[1:3] = std[1:3] / opt.feat_bias
46
+ # root_y (B, seq_len, 1)
47
+ std[3:4] = std[3:4] / opt.feat_bias
48
+ # ric_data (B, seq_len, (joint_num - 1)*3)
49
+ std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
50
+ # rot_data (B, seq_len, (joint_num - 1)*6)
51
+ std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
52
+ joints_num - 1) * 9] / 1.0
53
+ # local_velocity (B, seq_len, joint_num*3)
54
+ std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
55
+ 4 + (joints_num - 1) * 9: 4 + (
56
+ joints_num - 1) * 9 + joints_num * 3] / 1.0
57
+ # foot contact (B, seq_len, 4)
58
+ std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
59
+ 4 + (
60
+ joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias
61
+
62
+ assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
63
+ np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
64
+ np.save(pjoin(opt.meta_dir, 'std.npy'), std)
65
+
66
+ self.mean = mean
67
+ self.std = std
68
+ print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
69
+
70
+ def inv_transform(self, data):
71
+ return data * self.std + self.mean
72
+
73
+ def __len__(self):
74
+ return self.cumsum[-1]
75
+
76
+ def __getitem__(self, item):
77
+ if item != 0:
78
+ motion_id = np.searchsorted(self.cumsum, item) - 1
79
+ idx = item - self.cumsum[motion_id] - 1
80
+ else:
81
+ motion_id = 0
82
+ idx = 0
83
+ motion = self.data[motion_id][idx:idx + self.opt.window_size]
84
+ "Z Normalization"
85
+ motion = (motion - self.mean) / self.std
86
+
87
+ return motion
88
+
89
+
90
+ class Text2MotionDatasetEval(data.Dataset):
91
+ def __init__(self, opt, mean, std, split_file, w_vectorizer):
92
+ self.opt = opt
93
+ self.w_vectorizer = w_vectorizer
94
+ self.max_length = 20
95
+ self.pointer = 0
96
+ self.max_motion_length = opt.max_motion_length
97
+ min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
98
+
99
+ data_dict = {}
100
+ id_list = []
101
+ with cs.open(split_file, 'r') as f:
102
+ for line in f.readlines():
103
+ id_list.append(line.strip())
104
+ # id_list = id_list[:250]
105
+
106
+ new_name_list = []
107
+ length_list = []
108
+ for name in tqdm(id_list):
109
+ try:
110
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
111
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
112
+ continue
113
+ text_data = []
114
+ flag = False
115
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
116
+ for line in f.readlines():
117
+ text_dict = {}
118
+ line_split = line.strip().split('#')
119
+ caption = line_split[0]
120
+ tokens = line_split[1].split(' ')
121
+ f_tag = float(line_split[2])
122
+ to_tag = float(line_split[3])
123
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
124
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
125
+
126
+ text_dict['caption'] = caption
127
+ text_dict['tokens'] = tokens
128
+ if f_tag == 0.0 and to_tag == 0.0:
129
+ flag = True
130
+ text_data.append(text_dict)
131
+ else:
132
+ try:
133
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
134
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
135
+ continue
136
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
137
+ while new_name in data_dict:
138
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
139
+ data_dict[new_name] = {'motion': n_motion,
140
+ 'length': len(n_motion),
141
+ 'text':[text_dict]}
142
+ new_name_list.append(new_name)
143
+ length_list.append(len(n_motion))
144
+ except:
145
+ print(line_split)
146
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
147
+ # break
148
+
149
+ if flag:
150
+ data_dict[name] = {'motion': motion,
151
+ 'length': len(motion),
152
+ 'text': text_data}
153
+ new_name_list.append(name)
154
+ length_list.append(len(motion))
155
+ except:
156
+ pass
157
+
158
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
159
+
160
+ self.mean = mean
161
+ self.std = std
162
+ self.length_arr = np.array(length_list)
163
+ self.data_dict = data_dict
164
+ self.name_list = name_list
165
+ self.reset_max_len(self.max_length)
166
+
167
+ def reset_max_len(self, length):
168
+ assert length <= self.max_motion_length
169
+ self.pointer = np.searchsorted(self.length_arr, length)
170
+ print("Pointer Pointing at %d"%self.pointer)
171
+ self.max_length = length
172
+
173
+ def inv_transform(self, data):
174
+ return data * self.std + self.mean
175
+
176
+ def __len__(self):
177
+ return len(self.data_dict) - self.pointer
178
+
179
+ def __getitem__(self, item):
180
+ idx = self.pointer + item
181
+ data = self.data_dict[self.name_list[idx]]
182
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
183
+ # Randomly select a caption
184
+ text_data = random.choice(text_list)
185
+ caption, tokens = text_data['caption'], text_data['tokens']
186
+
187
+ if len(tokens) < self.opt.max_text_len:
188
+ # pad with "unk"
189
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
190
+ sent_len = len(tokens)
191
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
192
+ else:
193
+ # crop
194
+ tokens = tokens[:self.opt.max_text_len]
195
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
196
+ sent_len = len(tokens)
197
+ pos_one_hots = []
198
+ word_embeddings = []
199
+ for token in tokens:
200
+ word_emb, pos_oh = self.w_vectorizer[token]
201
+ pos_one_hots.append(pos_oh[None, :])
202
+ word_embeddings.append(word_emb[None, :])
203
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
204
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
205
+
206
+ if self.opt.unit_length < 10:
207
+ coin2 = np.random.choice(['single', 'single', 'double'])
208
+ else:
209
+ coin2 = 'single'
210
+
211
+ if coin2 == 'double':
212
+ m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
213
+ elif coin2 == 'single':
214
+ m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
215
+ idx = random.randint(0, len(motion) - m_length)
216
+ motion = motion[idx:idx+m_length]
217
+
218
+ "Z Normalization"
219
+ motion = (motion - self.mean) / self.std
220
+
221
+ if m_length < self.max_motion_length:
222
+ motion = np.concatenate([motion,
223
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
224
+ ], axis=0)
225
+ # print(word_embeddings.shape, motion.shape)
226
+ # print(tokens)
227
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
228
+
229
+
230
+ class Text2MotionDataset(data.Dataset):
231
+ def __init__(self, opt, mean, std, split_file):
232
+ self.opt = opt
233
+ self.max_length = 20
234
+ self.pointer = 0
235
+ self.max_motion_length = opt.max_motion_length
236
+ min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
237
+
238
+ data_dict = {}
239
+ id_list = []
240
+ with cs.open(split_file, 'r') as f:
241
+ for line in f.readlines():
242
+ id_list.append(line.strip())
243
+ # id_list = id_list[:250]
244
+
245
+ new_name_list = []
246
+ length_list = []
247
+ for name in tqdm(id_list):
248
+ try:
249
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
250
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
251
+ continue
252
+ text_data = []
253
+ flag = False
254
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
255
+ for line in f.readlines():
256
+ text_dict = {}
257
+ line_split = line.strip().split('#')
258
+ # print(line)
259
+ caption = line_split[0]
260
+ tokens = line_split[1].split(' ')
261
+ f_tag = float(line_split[2])
262
+ to_tag = float(line_split[3])
263
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
264
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
265
+
266
+ text_dict['caption'] = caption
267
+ text_dict['tokens'] = tokens
268
+ if f_tag == 0.0 and to_tag == 0.0:
269
+ flag = True
270
+ text_data.append(text_dict)
271
+ else:
272
+ try:
273
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
274
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
275
+ continue
276
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
277
+ while new_name in data_dict:
278
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
279
+ data_dict[new_name] = {'motion': n_motion,
280
+ 'length': len(n_motion),
281
+ 'text':[text_dict]}
282
+ new_name_list.append(new_name)
283
+ length_list.append(len(n_motion))
284
+ except:
285
+ print(line_split)
286
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
287
+ # break
288
+
289
+ if flag:
290
+ data_dict[name] = {'motion': motion,
291
+ 'length': len(motion),
292
+ 'text': text_data}
293
+ new_name_list.append(name)
294
+ length_list.append(len(motion))
295
+ except Exception as e:
296
+ # print(e)
297
+ pass
298
+
299
+ # name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
300
+ name_list, length_list = new_name_list, length_list
301
+
302
+ self.mean = mean
303
+ self.std = std
304
+ self.length_arr = np.array(length_list)
305
+ self.data_dict = data_dict
306
+ self.name_list = name_list
307
+
308
+ def inv_transform(self, data):
309
+ return data * self.std + self.mean
310
+
311
+ def __len__(self):
312
+ return len(self.data_dict) - self.pointer
313
+
314
+ def __getitem__(self, item):
315
+ idx = self.pointer + item
316
+ data = self.data_dict[self.name_list[idx]]
317
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
318
+ # Randomly select a caption
319
+ text_data = random.choice(text_list)
320
+ caption, tokens = text_data['caption'], text_data['tokens']
321
+
322
+ if self.opt.unit_length < 10:
323
+ coin2 = np.random.choice(['single', 'single', 'double'])
324
+ else:
325
+ coin2 = 'single'
326
+
327
+ if coin2 == 'double':
328
+ m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
329
+ elif coin2 == 'single':
330
+ m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
331
+ idx = random.randint(0, len(motion) - m_length)
332
+ motion = motion[idx:idx+m_length]
333
+
334
+ "Z Normalization"
335
+ motion = (motion - self.mean) / self.std
336
+
337
+ if m_length < self.max_motion_length:
338
+ motion = np.concatenate([motion,
339
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
340
+ ], axis=0)
341
+ # print(word_embeddings.shape, motion.shape)
342
+ # print(tokens)
343
+ return caption, motion, m_length
344
+
345
+ def reset_min_len(self, length):
346
+ assert length <= self.max_motion_length
347
+ self.pointer = np.searchsorted(self.length_arr, length)
348
+ print("Pointer Pointing at %d" % self.pointer)
dataset/__init__.py ADDED
File without changes
edit_t2m.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from os.path import join as pjoin
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
8
+ from models.vq.model import RVQVAE, LengthEstimator
9
+
10
+ from options.eval_option import EvalT2MOptions
11
+ from utils.get_opt import get_opt
12
+
13
+ from utils.fixseed import fixseed
14
+ from visualization.joints2bvh import Joint2BVHConvertor
15
+
16
+ from utils.motion_process import recover_from_ric
17
+ from utils.plot_script import plot_3d_motion
18
+
19
+ from utils.paramUtil import t2m_kinematic_chain
20
+
21
+ import numpy as np
22
+
23
+ from gen_t2m import load_vq_model, load_res_model, load_trans_model
24
+
25
+ if __name__ == '__main__':
26
+ parser = EvalT2MOptions()
27
+ opt = parser.parse()
28
+ fixseed(opt.seed)
29
+
30
+ opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
31
+ torch.autograd.set_detect_anomaly(True)
32
+
33
+ dim_pose = 251 if opt.dataset_name == 'kit' else 263
34
+
35
+ root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
36
+ model_dir = pjoin(root_dir, 'model')
37
+ result_dir = pjoin('./editing', opt.ext)
38
+ joints_dir = pjoin(result_dir, 'joints')
39
+ animation_dir = pjoin(result_dir, 'animations')
40
+ os.makedirs(joints_dir, exist_ok=True)
41
+ os.makedirs(animation_dir,exist_ok=True)
42
+
43
+ model_opt_path = pjoin(root_dir, 'opt.txt')
44
+ model_opt = get_opt(model_opt_path, device=opt.device)
45
+
46
+ #######################
47
+ ######Loading RVQ######
48
+ #######################
49
+ vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
50
+ vq_opt = get_opt(vq_opt_path, device=opt.device)
51
+ vq_opt.dim_pose = dim_pose
52
+ vq_model, vq_opt = load_vq_model(vq_opt)
53
+
54
+ model_opt.num_tokens = vq_opt.nb_code
55
+ model_opt.num_quantizers = vq_opt.num_quantizers
56
+ model_opt.code_dim = vq_opt.code_dim
57
+
58
+ #################################
59
+ ######Loading R-Transformer######
60
+ #################################
61
+ res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
62
+ res_opt = get_opt(res_opt_path, device=opt.device)
63
+ res_model = load_res_model(res_opt, vq_opt, opt)
64
+
65
+ assert res_opt.vq_name == model_opt.vq_name
66
+
67
+ #################################
68
+ ######Loading M-Transformer######
69
+ #################################
70
+ t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')
71
+
72
+ t2m_transformer.eval()
73
+ vq_model.eval()
74
+ res_model.eval()
75
+
76
+ res_model.to(opt.device)
77
+ t2m_transformer.to(opt.device)
78
+ vq_model.to(opt.device)
79
+
80
+ ##### ---- Data ---- #####
81
+ max_motion_length = 196
82
+ mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
83
+ std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))
84
+ def inv_transform(data):
85
+ return data * std + mean
86
+ ### We provided an example source motion (from 'new_joint_vecs') for editing. See './example_data/000612.mp4'###
87
+ motion = np.load(opt.source_motion)
88
+ m_length = len(motion)
89
+ motion = (motion - mean) / std
90
+ if max_motion_length > m_length:
91
+ motion = np.concatenate([motion, np.zeros((max_motion_length - m_length, motion.shape[1])) ], axis=0)
92
+ motion = torch.from_numpy(motion)[None].to(opt.device)
93
+
94
+ prompt_list = []
95
+ length_list = []
96
+ if opt.motion_length == 0:
97
+ opt.motion_length = m_length
98
+ print("Using default motion length.")
99
+
100
+ prompt_list.append(opt.text_prompt)
101
+ length_list.append(opt.motion_length)
102
+ if opt.text_prompt == "":
103
+ raise "Using an empty text prompt."
104
+
105
+ token_lens = torch.LongTensor(length_list) // 4
106
+ token_lens = token_lens.to(opt.device).long()
107
+
108
+ m_length = token_lens * 4
109
+ captions = prompt_list
110
+ print_captions = captions[0]
111
+
112
+ _edit_slice = opt.mask_edit_section
113
+ edit_slice = []
114
+ for eds in _edit_slice:
115
+ _start, _end = eds.split(',')
116
+ _start = eval(_start)
117
+ _end = eval(_end)
118
+ edit_slice.append([_start, _end])
119
+
120
+ sample = 0
121
+ kinematic_chain = t2m_kinematic_chain
122
+ converter = Joint2BVHConvertor()
123
+
124
+ with torch.no_grad():
125
+ tokens, features = vq_model.encode(motion)
126
+ ### build editing mask, TOEDIT marked as 1 ###
127
+ edit_mask = torch.zeros_like(tokens[..., 0])
128
+ seq_len = tokens.shape[1]
129
+ for _start, _end in edit_slice:
130
+ if isinstance(_start, float):
131
+ _start = int(_start*seq_len)
132
+ _end = int(_end*seq_len)
133
+ else:
134
+ _start //= 4
135
+ _end //= 4
136
+ edit_mask[:, _start: _end] = 1
137
+ print_captions = f'{print_captions} [{_start*4/20.}s - {_end*4/20.}s]'
138
+ edit_mask = edit_mask.bool()
139
+ for r in range(opt.repeat_times):
140
+ print("-->Repeat %d"%r)
141
+ with torch.no_grad():
142
+ mids = t2m_transformer.edit(
143
+ captions, tokens[..., 0].clone(), m_length//4,
144
+ timesteps=opt.time_steps,
145
+ cond_scale=opt.cond_scale,
146
+ temperature=opt.temperature,
147
+ topk_filter_thres=opt.topkr,
148
+ gsample=opt.gumbel_sample,
149
+ force_mask=opt.force_mask,
150
+ edit_mask=edit_mask.clone(),
151
+ )
152
+ if opt.use_res_model:
153
+ mids = res_model.generate(mids, captions, m_length//4, temperature=1, cond_scale=2)
154
+ else:
155
+ mids.unsqueeze_(-1)
156
+
157
+ pred_motions = vq_model.forward_decoder(mids)
158
+
159
+ pred_motions = pred_motions.detach().cpu().numpy()
160
+
161
+ source_motions = motion.detach().cpu().numpy()
162
+
163
+ data = inv_transform(pred_motions)
164
+ source_data = inv_transform(source_motions)
165
+
166
+ for k, (caption, joint_data, source_data) in enumerate(zip(captions, data, source_data)):
167
+ print("---->Sample %d: %s %d"%(k, caption, m_length[k]))
168
+ animation_path = pjoin(animation_dir, str(k))
169
+ joint_path = pjoin(joints_dir, str(k))
170
+
171
+ os.makedirs(animation_path, exist_ok=True)
172
+ os.makedirs(joint_path, exist_ok=True)
173
+
174
+ joint_data = joint_data[:m_length[k]]
175
+ joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()
176
+
177
+ source_data = source_data[:m_length[k]]
178
+ soucre_joint = recover_from_ric(torch.from_numpy(source_data).float(), 22).numpy()
179
+
180
+ bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k]))
181
+ _, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)
182
+
183
+ bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))
184
+ _, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)
185
+
186
+
187
+ save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k]))
188
+ ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k]))
189
+ source_save_path = pjoin(animation_path, "sample%d_source_len%d.mp4"%(k, m_length[k]))
190
+
191
+ plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=print_captions, fps=20)
192
+ plot_3d_motion(save_path, kinematic_chain, joint, title=print_captions, fps=20)
193
+ plot_3d_motion(source_save_path, kinematic_chain, soucre_joint, title='None', fps=20)
194
+ np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint)
195
+ np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint)
environment.yml ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: momask
2
+ channels:
3
+ - pytorch
4
+ - anaconda
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=main
9
+ - _openmp_mutex=5.1=1_gnu
10
+ - absl-py=1.4.0=pyhd8ed1ab_0
11
+ - aiohttp=3.8.3=py37h5eee18b_0
12
+ - aiosignal=1.2.0=pyhd3eb1b0_0
13
+ - argon2-cffi=21.3.0=pyhd3eb1b0_0
14
+ - argon2-cffi-bindings=21.2.0=py37h7f8727e_0
15
+ - async-timeout=4.0.2=py37h06a4308_0
16
+ - asynctest=0.13.0=py_0
17
+ - attrs=22.1.0=py37h06a4308_0
18
+ - backcall=0.2.0=pyhd3eb1b0_0
19
+ - beautifulsoup4=4.11.1=pyha770c72_0
20
+ - blas=1.0=mkl
21
+ - bleach=4.1.0=pyhd3eb1b0_0
22
+ - blinker=1.4=py37h06a4308_0
23
+ - brotlipy=0.7.0=py37h540881e_1004
24
+ - c-ares=1.19.0=h5eee18b_0
25
+ - ca-certificates=2023.05.30=h06a4308_0
26
+ - catalogue=2.0.8=py37h89c1867_0
27
+ - certifi=2022.12.7=py37h06a4308_0
28
+ - cffi=1.15.1=py37h74dc2b5_0
29
+ - charset-normalizer=2.1.1=pyhd8ed1ab_0
30
+ - click=8.0.4=py37h89c1867_0
31
+ - colorama=0.4.5=pyhd8ed1ab_0
32
+ - cryptography=35.0.0=py37hf1a17b8_2
33
+ - cudatoolkit=11.0.221=h6bb024c_0
34
+ - cycler=0.11.0=pyhd3eb1b0_0
35
+ - cymem=2.0.6=py37hd23a5d3_3
36
+ - cython-blis=0.7.7=py37hda87dfa_1
37
+ - dataclasses=0.8=pyhc8e2a94_3
38
+ - dbus=1.13.18=hb2f20db_0
39
+ - debugpy=1.5.1=py37h295c915_0
40
+ - decorator=5.1.1=pyhd3eb1b0_0
41
+ - defusedxml=0.7.1=pyhd3eb1b0_0
42
+ - entrypoints=0.4=py37h06a4308_0
43
+ - expat=2.4.9=h6a678d5_0
44
+ - fftw=3.3.9=h27cfd23_1
45
+ - filelock=3.8.0=pyhd8ed1ab_0
46
+ - fontconfig=2.13.1=h6c09931_0
47
+ - freetype=2.11.0=h70c0345_0
48
+ - frozenlist=1.3.3=py37h5eee18b_0
49
+ - giflib=5.2.1=h7b6447c_0
50
+ - glib=2.69.1=h4ff587b_1
51
+ - gst-plugins-base=1.14.0=h8213a91_2
52
+ - gstreamer=1.14.0=h28cd5cc_2
53
+ - h5py=3.7.0=py37h737f45e_0
54
+ - hdf5=1.10.6=h3ffc7dd_1
55
+ - icu=58.2=he6710b0_3
56
+ - idna=3.4=pyhd8ed1ab_0
57
+ - importlib-metadata=4.11.4=py37h89c1867_0
58
+ - intel-openmp=2021.4.0=h06a4308_3561
59
+ - ipykernel=6.15.2=py37h06a4308_0
60
+ - ipython=7.31.1=py37h06a4308_1
61
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
62
+ - jedi=0.18.1=py37h06a4308_1
63
+ - jinja2=3.1.2=pyhd8ed1ab_1
64
+ - joblib=1.1.0=pyhd3eb1b0_0
65
+ - jpeg=9b=h024ee3a_2
66
+ - jsonschema=3.0.2=py37_0
67
+ - jupyter_client=7.4.9=py37h06a4308_0
68
+ - jupyter_core=4.11.2=py37h06a4308_0
69
+ - jupyterlab_pygments=0.1.2=py_0
70
+ - kiwisolver=1.4.2=py37h295c915_0
71
+ - langcodes=3.3.0=pyhd8ed1ab_0
72
+ - lcms2=2.12=h3be6417_0
73
+ - ld_impl_linux-64=2.38=h1181459_1
74
+ - libffi=3.3=he6710b0_2
75
+ - libgcc-ng=11.2.0=h1234567_1
76
+ - libgfortran-ng=11.2.0=h00389a5_1
77
+ - libgfortran5=11.2.0=h1234567_1
78
+ - libgomp=11.2.0=h1234567_1
79
+ - libpng=1.6.37=hbc83047_0
80
+ - libprotobuf=3.15.8=h780b84a_1
81
+ - libsodium=1.0.18=h7b6447c_0
82
+ - libstdcxx-ng=11.2.0=h1234567_1
83
+ - libtiff=4.1.0=h2733197_1
84
+ - libuuid=1.0.3=h7f8727e_2
85
+ - libuv=1.40.0=h7b6447c_0
86
+ - libwebp=1.2.0=h89dd481_0
87
+ - libxcb=1.15=h7f8727e_0
88
+ - libxml2=2.9.14=h74e7548_0
89
+ - lz4-c=1.9.3=h295c915_1
90
+ - markdown=3.4.3=pyhd8ed1ab_0
91
+ - markupsafe=2.1.1=py37h540881e_1
92
+ - matplotlib=3.1.3=py37_0
93
+ - matplotlib-base=3.1.3=py37hef1b27d_0
94
+ - matplotlib-inline=0.1.6=py37h06a4308_0
95
+ - mistune=0.8.4=py37h14c3975_1001
96
+ - mkl=2021.4.0=h06a4308_640
97
+ - mkl-service=2.4.0=py37h7f8727e_0
98
+ - mkl_fft=1.3.1=py37hd3c417c_0
99
+ - mkl_random=1.2.2=py37h51133e4_0
100
+ - multidict=6.0.2=py37h5eee18b_0
101
+ - murmurhash=1.0.7=py37hd23a5d3_0
102
+ - nb_conda_kernels=2.3.1=py37h06a4308_0
103
+ - nbclient=0.5.13=py37h06a4308_0
104
+ - nbconvert=6.4.4=py37h06a4308_0
105
+ - nbformat=5.5.0=py37h06a4308_0
106
+ - ncurses=6.3=h5eee18b_3
107
+ - nest-asyncio=1.5.6=py37h06a4308_0
108
+ - ninja=1.10.2=h06a4308_5
109
+ - ninja-base=1.10.2=hd09550d_5
110
+ - notebook=6.4.12=py37h06a4308_0
111
+ - numpy=1.21.5=py37h6c91a56_3
112
+ - numpy-base=1.21.5=py37ha15fc14_3
113
+ - openssl=1.1.1v=h7f8727e_0
114
+ - packaging=21.3=pyhd8ed1ab_0
115
+ - pandocfilters=1.5.0=pyhd3eb1b0_0
116
+ - parso=0.8.3=pyhd3eb1b0_0
117
+ - pathy=0.6.2=pyhd8ed1ab_0
118
+ - pcre=8.45=h295c915_0
119
+ - pexpect=4.8.0=pyhd3eb1b0_3
120
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
121
+ - pillow=9.2.0=py37hace64e9_1
122
+ - pip=22.2.2=py37h06a4308_0
123
+ - preshed=3.0.6=py37hd23a5d3_2
124
+ - prometheus_client=0.14.1=py37h06a4308_0
125
+ - prompt-toolkit=3.0.36=py37h06a4308_0
126
+ - psutil=5.9.0=py37h5eee18b_0
127
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
128
+ - pycparser=2.21=pyhd8ed1ab_0
129
+ - pydantic=1.8.2=py37h5e8e339_2
130
+ - pygments=2.11.2=pyhd3eb1b0_0
131
+ - pyjwt=2.4.0=py37h06a4308_0
132
+ - pyopenssl=22.0.0=pyhd8ed1ab_1
133
+ - pyparsing=3.0.9=py37h06a4308_0
134
+ - pyqt=5.9.2=py37h05f1152_2
135
+ - pyrsistent=0.18.0=py37heee7806_0
136
+ - pysocks=1.7.1=py37h89c1867_5
137
+ - python=3.7.13=h12debd9_0
138
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
139
+ - python-fastjsonschema=2.16.2=py37h06a4308_0
140
+ - python_abi=3.7=2_cp37m
141
+ - pytorch=1.7.1=py3.7_cuda11.0.221_cudnn8.0.5_0
142
+ - pyzmq=23.2.0=py37h6a678d5_0
143
+ - qt=5.9.7=h5867ecd_1
144
+ - readline=8.1.2=h7f8727e_1
145
+ - requests=2.28.1=pyhd8ed1ab_1
146
+ - scikit-learn=1.0.2=py37h51133e4_1
147
+ - scipy=1.7.3=py37h6c91a56_2
148
+ - send2trash=1.8.0=pyhd3eb1b0_1
149
+ - setuptools=63.4.1=py37h06a4308_0
150
+ - shellingham=1.5.0=pyhd8ed1ab_0
151
+ - sip=4.19.8=py37hf484d3e_0
152
+ - six=1.16.0=pyhd3eb1b0_1
153
+ - smart_open=5.2.1=pyhd8ed1ab_0
154
+ - soupsieve=2.3.2.post1=pyhd8ed1ab_0
155
+ - spacy=3.3.1=py37h79cecc1_0
156
+ - spacy-legacy=3.0.10=pyhd8ed1ab_0
157
+ - spacy-loggers=1.0.3=pyhd8ed1ab_0
158
+ - sqlite=3.39.3=h5082296_0
159
+ - srsly=2.4.3=py37hd23a5d3_1
160
+ - tensorboard-plugin-wit=1.8.1=py37h06a4308_0
161
+ - terminado=0.17.1=py37h06a4308_0
162
+ - testpath=0.6.0=py37h06a4308_0
163
+ - thinc=8.0.15=py37h48bf904_0
164
+ - threadpoolctl=2.2.0=pyh0d69192_0
165
+ - tk=8.6.12=h1ccaba5_0
166
+ - torchaudio=0.7.2=py37
167
+ - torchvision=0.8.2=py37_cu110
168
+ - tornado=6.2=py37h5eee18b_0
169
+ - tqdm=4.64.1=py37h06a4308_0
170
+ - traitlets=5.7.1=py37h06a4308_0
171
+ - trimesh=3.15.3=pyh1a96a4e_0
172
+ - typer=0.4.2=pyhd8ed1ab_0
173
+ - typing-extensions=3.10.0.2=hd8ed1ab_0
174
+ - typing_extensions=3.10.0.2=pyha770c72_0
175
+ - urllib3=1.26.15=pyhd8ed1ab_0
176
+ - wasabi=0.10.1=pyhd8ed1ab_1
177
+ - webencodings=0.5.1=py37_1
178
+ - werkzeug=2.2.3=pyhd8ed1ab_0
179
+ - wheel=0.37.1=pyhd3eb1b0_0
180
+ - xz=5.2.6=h5eee18b_0
181
+ - yarl=1.8.1=py37h5eee18b_0
182
+ - zeromq=4.3.4=h2531618_0
183
+ - zipp=3.8.1=pyhd8ed1ab_0
184
+ - zlib=1.2.12=h5eee18b_3
185
+ - zstd=1.4.9=haebb681_0
186
+ - pip:
187
+ - cachetools==5.3.1
188
+ - einops==0.6.1
189
+ - ftfy==6.1.1
190
+ - gdown==4.7.1
191
+ - google-auth==2.22.0
192
+ - google-auth-oauthlib==0.4.6
193
+ - grpcio==1.57.0
194
+ - oauthlib==3.2.2
195
+ - protobuf==3.20.3
196
+ - pyasn1==0.5.0
197
+ - pyasn1-modules==0.3.0
198
+ - regex==2023.8.8
199
+ - requests-oauthlib==1.3.1
200
+ - rsa==4.9
201
+ - tensorboard==2.11.2
202
+ - tensorboard-data-server==0.6.1
203
+ - wcwidth==0.2.6
204
+ prefix: /home/chuan/anaconda3/envs/momask
eval_t2m_trans_res.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from os.path import join as pjoin
3
+
4
+ import torch
5
+
6
+ from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
7
+ from models.vq.model import RVQVAE
8
+
9
+ from options.eval_option import EvalT2MOptions
10
+ from utils.get_opt import get_opt
11
+ from motion_loaders.dataset_motion_loader import get_dataset_motion_loader
12
+ from models.t2m_eval_wrapper import EvaluatorModelWrapper
13
+
14
+ import utils.eval_t2m as eval_t2m
15
+ from utils.fixseed import fixseed
16
+
17
+ import numpy as np
18
+
19
+ def load_vq_model(vq_opt):
20
+ # opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
21
+ vq_model = RVQVAE(vq_opt,
22
+ dim_pose,
23
+ vq_opt.nb_code,
24
+ vq_opt.code_dim,
25
+ vq_opt.output_emb_width,
26
+ vq_opt.down_t,
27
+ vq_opt.stride_t,
28
+ vq_opt.width,
29
+ vq_opt.depth,
30
+ vq_opt.dilation_growth_rate,
31
+ vq_opt.vq_act,
32
+ vq_opt.vq_norm)
33
+ ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
34
+ map_location=opt.device)
35
+ model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
36
+ vq_model.load_state_dict(ckpt[model_key])
37
+ print(f'Loading VQ Model {vq_opt.name} Completed!')
38
+ return vq_model, vq_opt
39
+
40
+ def load_trans_model(model_opt, which_model):
41
+ t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
42
+ cond_mode='text',
43
+ latent_dim=model_opt.latent_dim,
44
+ ff_size=model_opt.ff_size,
45
+ num_layers=model_opt.n_layers,
46
+ num_heads=model_opt.n_heads,
47
+ dropout=model_opt.dropout,
48
+ clip_dim=512,
49
+ cond_drop_prob=model_opt.cond_drop_prob,
50
+ clip_version=clip_version,
51
+ opt=model_opt)
52
+ ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model),
53
+ map_location=opt.device)
54
+ model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
55
+ # print(ckpt.keys())
56
+ missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
57
+ assert len(unexpected_keys) == 0
58
+ assert all([k.startswith('clip_model.') for k in missing_keys])
59
+ print(f'Loading Mask Transformer {opt.name} from epoch {ckpt["ep"]}!')
60
+ return t2m_transformer
61
+
62
+ def load_res_model(res_opt):
63
+ res_opt.num_quantizers = vq_opt.num_quantizers
64
+ res_opt.num_tokens = vq_opt.nb_code
65
+ res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
66
+ cond_mode='text',
67
+ latent_dim=res_opt.latent_dim,
68
+ ff_size=res_opt.ff_size,
69
+ num_layers=res_opt.n_layers,
70
+ num_heads=res_opt.n_heads,
71
+ dropout=res_opt.dropout,
72
+ clip_dim=512,
73
+ shared_codebook=vq_opt.shared_codebook,
74
+ cond_drop_prob=res_opt.cond_drop_prob,
75
+ # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,
76
+ share_weight=res_opt.share_weight,
77
+ clip_version=clip_version,
78
+ opt=res_opt)
79
+
80
+ ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'),
81
+ map_location=opt.device)
82
+ missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False)
83
+ assert len(unexpected_keys) == 0
84
+ assert all([k.startswith('clip_model.') for k in missing_keys])
85
+ print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!')
86
+ return res_transformer
87
+
88
+ if __name__ == '__main__':
89
+ parser = EvalT2MOptions()
90
+ opt = parser.parse()
91
+ fixseed(opt.seed)
92
+
93
+ opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
94
+ torch.autograd.set_detect_anomaly(True)
95
+
96
+ dim_pose = 251 if opt.dataset_name == 'kit' else 263
97
+
98
+ # out_dir = pjoin(opt.check)
99
+ root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
100
+ model_dir = pjoin(root_dir, 'model')
101
+ out_dir = pjoin(root_dir, 'eval')
102
+ os.makedirs(out_dir, exist_ok=True)
103
+
104
+ out_path = pjoin(out_dir, "%s.log"%opt.ext)
105
+
106
+ f = open(pjoin(out_path), 'w')
107
+
108
+ model_opt_path = pjoin(root_dir, 'opt.txt')
109
+ model_opt = get_opt(model_opt_path, device=opt.device)
110
+ clip_version = 'ViT-B/32'
111
+
112
+ vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
113
+ vq_opt = get_opt(vq_opt_path, device=opt.device)
114
+ vq_model, vq_opt = load_vq_model(vq_opt)
115
+
116
+ model_opt.num_tokens = vq_opt.nb_code
117
+ model_opt.num_quantizers = vq_opt.num_quantizers
118
+ model_opt.code_dim = vq_opt.code_dim
119
+
120
+ res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
121
+ res_opt = get_opt(res_opt_path, device=opt.device)
122
+ res_model = load_res_model(res_opt)
123
+
124
+ assert res_opt.vq_name == model_opt.vq_name
125
+
126
+ dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if opt.dataset_name == 'kit' \
127
+ else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
128
+
129
+ wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
130
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
131
+
132
+ ##### ---- Dataloader ---- #####
133
+ opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22
134
+
135
+ eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'test', device=opt.device)
136
+
137
+ # model_dir = pjoin(opt.)
138
+ for file in os.listdir(model_dir):
139
+ if opt.which_epoch != "all" and opt.which_epoch not in file:
140
+ continue
141
+ print('loading checkpoint {}'.format(file))
142
+ t2m_transformer = load_trans_model(model_opt, file)
143
+ t2m_transformer.eval()
144
+ vq_model.eval()
145
+ res_model.eval()
146
+
147
+ t2m_transformer.to(opt.device)
148
+ vq_model.to(opt.device)
149
+ res_model.to(opt.device)
150
+
151
+ fid = []
152
+ div = []
153
+ top1 = []
154
+ top2 = []
155
+ top3 = []
156
+ matching = []
157
+ mm = []
158
+
159
+ repeat_time = 20
160
+ for i in range(repeat_time):
161
+ with torch.no_grad():
162
+ best_fid, best_div, Rprecision, best_matching, best_mm = \
163
+ eval_t2m.evaluation_mask_transformer_test_plus_res(eval_val_loader, vq_model, res_model, t2m_transformer,
164
+ i, eval_wrapper=eval_wrapper,
165
+ time_steps=opt.time_steps, cond_scale=opt.cond_scale,
166
+ temperature=opt.temperature, topkr=opt.topkr,
167
+ force_mask=opt.force_mask, cal_mm=True)
168
+ fid.append(best_fid)
169
+ div.append(best_div)
170
+ top1.append(Rprecision[0])
171
+ top2.append(Rprecision[1])
172
+ top3.append(Rprecision[2])
173
+ matching.append(best_matching)
174
+ mm.append(best_mm)
175
+
176
+ fid = np.array(fid)
177
+ div = np.array(div)
178
+ top1 = np.array(top1)
179
+ top2 = np.array(top2)
180
+ top3 = np.array(top3)
181
+ matching = np.array(matching)
182
+ mm = np.array(mm)
183
+
184
+ print(f'{file} final result:')
185
+ print(f'{file} final result:', file=f, flush=True)
186
+
187
+ msg_final = f"\tFID: {np.mean(fid):.3f}, conf. {np.std(fid) * 1.96 / np.sqrt(repeat_time):.3f}\n" \
188
+ f"\tDiversity: {np.mean(div):.3f}, conf. {np.std(div) * 1.96 / np.sqrt(repeat_time):.3f}\n" \
189
+ f"\tTOP1: {np.mean(top1):.3f}, conf. {np.std(top1) * 1.96 / np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2) * 1.96 / np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3) * 1.96 / np.sqrt(repeat_time):.3f}\n" \
190
+ f"\tMatching: {np.mean(matching):.3f}, conf. {np.std(matching) * 1.96 / np.sqrt(repeat_time):.3f}\n" \
191
+ f"\tMultimodality:{np.mean(mm):.3f}, conf.{np.std(mm) * 1.96 / np.sqrt(repeat_time):.3f}\n\n"
192
+ # logger.info(msg_final)
193
+ print(msg_final)
194
+ print(msg_final, file=f, flush=True)
195
+
196
+ f.close()
197
+
198
+
199
+ # python eval_t2m_trans.py --name t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_vq --dataset_name t2m --gpu_id 3 --cond_scale 4 --time_steps 18 --temperature 1 --topkr 0.9 --gumbel_sample --ext cs4_ts18_tau1_topkr0.9_gs
eval_t2m_vq.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ from os.path import join as pjoin
4
+
5
+ import torch
6
+ from models.vq.model import RVQVAE
7
+ from options.vq_option import arg_parse
8
+ from motion_loaders.dataset_motion_loader import get_dataset_motion_loader
9
+ import utils.eval_t2m as eval_t2m
10
+ from utils.get_opt import get_opt
11
+ from models.t2m_eval_wrapper import EvaluatorModelWrapper
12
+ import warnings
13
+ warnings.filterwarnings('ignore')
14
+ import numpy as np
15
+ from utils.word_vectorizer import WordVectorizer
16
+
17
+ def load_vq_model(vq_opt, which_epoch):
18
+ # opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
19
+
20
+ vq_model = RVQVAE(vq_opt,
21
+ dim_pose,
22
+ vq_opt.nb_code,
23
+ vq_opt.code_dim,
24
+ vq_opt.code_dim,
25
+ vq_opt.down_t,
26
+ vq_opt.stride_t,
27
+ vq_opt.width,
28
+ vq_opt.depth,
29
+ vq_opt.dilation_growth_rate,
30
+ vq_opt.vq_act,
31
+ vq_opt.vq_norm)
32
+ ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', which_epoch),
33
+ map_location='cpu')
34
+ model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
35
+ vq_model.load_state_dict(ckpt[model_key])
36
+ vq_epoch = ckpt['ep'] if 'ep' in ckpt else -1
37
+ print(f'Loading VQ Model {vq_opt.name} Completed!, Epoch {vq_epoch}')
38
+ return vq_model, vq_epoch
39
+
40
+ if __name__ == "__main__":
41
+ ##### ---- Exp dirs ---- #####
42
+ args = arg_parse(False)
43
+ args.device = torch.device("cpu" if args.gpu_id == -1 else "cuda:" + str(args.gpu_id))
44
+
45
+ args.out_dir = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'eval')
46
+ os.makedirs(args.out_dir, exist_ok=True)
47
+
48
+ f = open(pjoin(args.out_dir, '%s.log'%args.ext), 'w')
49
+
50
+ dataset_opt_path = 'checkpoints/kit/Comp_v6_KLD005/opt.txt' if args.dataset_name == 'kit' \
51
+ else 'checkpoints/t2m/Comp_v6_KLD005/opt.txt'
52
+
53
+ wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
54
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
55
+
56
+ ##### ---- Dataloader ---- #####
57
+ args.nb_joints = 21 if args.dataset_name == 'kit' else 22
58
+ dim_pose = 251 if args.dataset_name == 'kit' else 263
59
+
60
+ eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'test', device=args.device)
61
+
62
+ print(len(eval_val_loader))
63
+
64
+ ##### ---- Network ---- #####
65
+ vq_opt_path = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'opt.txt')
66
+ vq_opt = get_opt(vq_opt_path, device=args.device)
67
+ # net = load_vq_model()
68
+
69
+ model_dir = pjoin(args.checkpoints_dir, args.dataset_name, args.name, 'model')
70
+ for file in os.listdir(model_dir):
71
+ # if not file.endswith('tar'):
72
+ # continue
73
+ # if not file.startswith('net_best_fid'):
74
+ # continue
75
+ if args.which_epoch != "all" and args.which_epoch not in file:
76
+ continue
77
+ print(file)
78
+ net, ep = load_vq_model(vq_opt, file)
79
+
80
+ net.eval()
81
+ net.cuda()
82
+
83
+ fid = []
84
+ div = []
85
+ top1 = []
86
+ top2 = []
87
+ top3 = []
88
+ matching = []
89
+ mae = []
90
+ repeat_time = 20
91
+ for i in range(repeat_time):
92
+ best_fid, best_div, Rprecision, best_matching, l1_dist = \
93
+ eval_t2m.evaluation_vqvae_plus_mpjpe(eval_val_loader, net, i, eval_wrapper=eval_wrapper, num_joint=args.nb_joints)
94
+ fid.append(best_fid)
95
+ div.append(best_div)
96
+ top1.append(Rprecision[0])
97
+ top2.append(Rprecision[1])
98
+ top3.append(Rprecision[2])
99
+ matching.append(best_matching)
100
+ mae.append(l1_dist)
101
+
102
+ fid = np.array(fid)
103
+ div = np.array(div)
104
+ top1 = np.array(top1)
105
+ top2 = np.array(top2)
106
+ top3 = np.array(top3)
107
+ matching = np.array(matching)
108
+ mae = np.array(mae)
109
+
110
+ print(f'{file} final result, epoch {ep}')
111
+ print(f'{file} final result, epoch {ep}', file=f, flush=True)
112
+
113
+ msg_final = f"\tFID: {np.mean(fid):.3f}, conf. {np.std(fid)*1.96/np.sqrt(repeat_time):.3f}\n" \
114
+ f"\tDiversity: {np.mean(div):.3f}, conf. {np.std(div)*1.96/np.sqrt(repeat_time):.3f}\n" \
115
+ f"\tTOP1: {np.mean(top1):.3f}, conf. {np.std(top1)*1.96/np.sqrt(repeat_time):.3f}, TOP2. {np.mean(top2):.3f}, conf. {np.std(top2)*1.96/np.sqrt(repeat_time):.3f}, TOP3. {np.mean(top3):.3f}, conf. {np.std(top3)*1.96/np.sqrt(repeat_time):.3f}\n" \
116
+ f"\tMatching: {np.mean(matching):.3f}, conf. {np.std(matching)*1.96/np.sqrt(repeat_time):.3f}\n" \
117
+ f"\tMAE:{np.mean(mae):.3f}, conf.{np.std(mae)*1.96/np.sqrt(repeat_time):.3f}\n\n"
118
+ # logger.info(msg_final)
119
+ print(msg_final)
120
+ print(msg_final, file=f, flush=True)
121
+
122
+ f.close()
123
+
example_data/000612.mp4 ADDED
Binary file (154 kB). View file
 
example_data/000612.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:85e5a8081278a0e31488eaa29386940b9e4b739fb401042f7ad883afb475ab10
3
+ size 418824
gen_t2m.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from os.path import join as pjoin
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from models.mask_transformer.transformer import MaskTransformer, ResidualTransformer
8
+ from models.vq.model import RVQVAE, LengthEstimator
9
+
10
+ from options.eval_option import EvalT2MOptions
11
+ from utils.get_opt import get_opt
12
+
13
+ from utils.fixseed import fixseed
14
+ from visualization.joints2bvh import Joint2BVHConvertor
15
+ from torch.distributions.categorical import Categorical
16
+
17
+
18
+ from utils.motion_process import recover_from_ric
19
+ from utils.plot_script import plot_3d_motion
20
+
21
+ from utils.paramUtil import t2m_kinematic_chain
22
+
23
+ import numpy as np
24
+ clip_version = 'ViT-B/32'
25
+
26
+ def load_vq_model(vq_opt):
27
+ # opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
28
+ vq_model = RVQVAE(vq_opt,
29
+ vq_opt.dim_pose,
30
+ vq_opt.nb_code,
31
+ vq_opt.code_dim,
32
+ vq_opt.output_emb_width,
33
+ vq_opt.down_t,
34
+ vq_opt.stride_t,
35
+ vq_opt.width,
36
+ vq_opt.depth,
37
+ vq_opt.dilation_growth_rate,
38
+ vq_opt.vq_act,
39
+ vq_opt.vq_norm)
40
+ ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
41
+ map_location='cpu')
42
+ model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
43
+ vq_model.load_state_dict(ckpt[model_key])
44
+ print(f'Loading VQ Model {vq_opt.name} Completed!')
45
+ return vq_model, vq_opt
46
+
47
+ def load_trans_model(model_opt, opt, which_model):
48
+ t2m_transformer = MaskTransformer(code_dim=model_opt.code_dim,
49
+ cond_mode='text',
50
+ latent_dim=model_opt.latent_dim,
51
+ ff_size=model_opt.ff_size,
52
+ num_layers=model_opt.n_layers,
53
+ num_heads=model_opt.n_heads,
54
+ dropout=model_opt.dropout,
55
+ clip_dim=512,
56
+ cond_drop_prob=model_opt.cond_drop_prob,
57
+ clip_version=clip_version,
58
+ opt=model_opt)
59
+ ckpt = torch.load(pjoin(model_opt.checkpoints_dir, model_opt.dataset_name, model_opt.name, 'model', which_model),
60
+ map_location='cpu')
61
+ model_key = 't2m_transformer' if 't2m_transformer' in ckpt else 'trans'
62
+ # print(ckpt.keys())
63
+ missing_keys, unexpected_keys = t2m_transformer.load_state_dict(ckpt[model_key], strict=False)
64
+ assert len(unexpected_keys) == 0
65
+ assert all([k.startswith('clip_model.') for k in missing_keys])
66
+ print(f'Loading Transformer {opt.name} from epoch {ckpt["ep"]}!')
67
+ return t2m_transformer
68
+
69
+ def load_res_model(res_opt, vq_opt, opt):
70
+ res_opt.num_quantizers = vq_opt.num_quantizers
71
+ res_opt.num_tokens = vq_opt.nb_code
72
+ res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
73
+ cond_mode='text',
74
+ latent_dim=res_opt.latent_dim,
75
+ ff_size=res_opt.ff_size,
76
+ num_layers=res_opt.n_layers,
77
+ num_heads=res_opt.n_heads,
78
+ dropout=res_opt.dropout,
79
+ clip_dim=512,
80
+ shared_codebook=vq_opt.shared_codebook,
81
+ cond_drop_prob=res_opt.cond_drop_prob,
82
+ # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,
83
+ share_weight=res_opt.share_weight,
84
+ clip_version=clip_version,
85
+ opt=res_opt)
86
+
87
+ ckpt = torch.load(pjoin(res_opt.checkpoints_dir, res_opt.dataset_name, res_opt.name, 'model', 'net_best_fid.tar'),
88
+ map_location=opt.device)
89
+ missing_keys, unexpected_keys = res_transformer.load_state_dict(ckpt['res_transformer'], strict=False)
90
+ assert len(unexpected_keys) == 0
91
+ assert all([k.startswith('clip_model.') for k in missing_keys])
92
+ print(f'Loading Residual Transformer {res_opt.name} from epoch {ckpt["ep"]}!')
93
+ return res_transformer
94
+
95
+ def load_len_estimator(opt):
96
+ model = LengthEstimator(512, 50)
97
+ ckpt = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_estimator', 'model', 'finest.tar'),
98
+ map_location=opt.device)
99
+ model.load_state_dict(ckpt['estimator'])
100
+ print(f'Loading Length Estimator from epoch {ckpt["epoch"]}!')
101
+ return model
102
+
103
+
104
+ if __name__ == '__main__':
105
+ parser = EvalT2MOptions()
106
+ opt = parser.parse()
107
+ fixseed(opt.seed)
108
+
109
+ opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
110
+ torch.autograd.set_detect_anomaly(True)
111
+
112
+ dim_pose = 251 if opt.dataset_name == 'kit' else 263
113
+
114
+ # out_dir = pjoin(opt.check)
115
+ root_dir = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
116
+ model_dir = pjoin(root_dir, 'model')
117
+ result_dir = pjoin('./generation', opt.ext)
118
+ joints_dir = pjoin(result_dir, 'joints')
119
+ animation_dir = pjoin(result_dir, 'animations')
120
+ os.makedirs(joints_dir, exist_ok=True)
121
+ os.makedirs(animation_dir,exist_ok=True)
122
+
123
+ model_opt_path = pjoin(root_dir, 'opt.txt')
124
+ model_opt = get_opt(model_opt_path, device=opt.device)
125
+
126
+
127
+ #######################
128
+ ######Loading RVQ######
129
+ #######################
130
+ vq_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'opt.txt')
131
+ vq_opt = get_opt(vq_opt_path, device=opt.device)
132
+ vq_opt.dim_pose = dim_pose
133
+ vq_model, vq_opt = load_vq_model(vq_opt)
134
+
135
+ model_opt.num_tokens = vq_opt.nb_code
136
+ model_opt.num_quantizers = vq_opt.num_quantizers
137
+ model_opt.code_dim = vq_opt.code_dim
138
+
139
+ #################################
140
+ ######Loading R-Transformer######
141
+ #################################
142
+ res_opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.res_name, 'opt.txt')
143
+ res_opt = get_opt(res_opt_path, device=opt.device)
144
+ res_model = load_res_model(res_opt, vq_opt, opt)
145
+
146
+ assert res_opt.vq_name == model_opt.vq_name
147
+
148
+ #################################
149
+ ######Loading M-Transformer######
150
+ #################################
151
+ t2m_transformer = load_trans_model(model_opt, opt, 'latest.tar')
152
+
153
+ ##################################
154
+ #####Loading Length Predictor#####
155
+ ##################################
156
+ length_estimator = load_len_estimator(model_opt)
157
+
158
+ t2m_transformer.eval()
159
+ vq_model.eval()
160
+ res_model.eval()
161
+ length_estimator.eval()
162
+
163
+ res_model.to(opt.device)
164
+ t2m_transformer.to(opt.device)
165
+ vq_model.to(opt.device)
166
+ length_estimator.to(opt.device)
167
+
168
+ ##### ---- Dataloader ---- #####
169
+ opt.nb_joints = 21 if opt.dataset_name == 'kit' else 22
170
+
171
+ mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'mean.npy'))
172
+ std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, model_opt.vq_name, 'meta', 'std.npy'))
173
+ def inv_transform(data):
174
+ return data * std + mean
175
+
176
+ prompt_list = []
177
+ length_list = []
178
+
179
+ est_length = False
180
+ if opt.text_prompt != "":
181
+ prompt_list.append(opt.text_prompt)
182
+ if opt.motion_length == 0:
183
+ est_length = True
184
+ else:
185
+ length_list.append(opt.motion_length)
186
+ elif opt.text_path != "":
187
+ with open(opt.text_path, 'r') as f:
188
+ lines = f.readlines()
189
+ for line in lines:
190
+ infos = line.split('#')
191
+ prompt_list.append(infos[0])
192
+ if len(infos) == 1 or (not infos[1].isdigit()):
193
+ est_length = True
194
+ length_list = []
195
+ else:
196
+ length_list.append(int(infos[-1]))
197
+ else:
198
+ raise "A text prompt, or a file a text prompts are required!!!"
199
+ # print('loading checkpoint {}'.format(file))
200
+
201
+ if est_length:
202
+ print("Since no motion length are specified, we will use estimated motion lengthes!!")
203
+ text_embedding = t2m_transformer.encode_text(prompt_list)
204
+ pred_dis = length_estimator(text_embedding)
205
+ probs = F.softmax(pred_dis, dim=-1) # (b, ntoken)
206
+ token_lens = Categorical(probs).sample() # (b, seqlen)
207
+ # lengths = torch.multinomial()
208
+ else:
209
+ token_lens = torch.LongTensor(length_list) // 4
210
+ token_lens = token_lens.to(opt.device).long()
211
+
212
+ m_length = token_lens * 4
213
+ captions = prompt_list
214
+
215
+ sample = 0
216
+ kinematic_chain = t2m_kinematic_chain
217
+ converter = Joint2BVHConvertor()
218
+
219
+ for r in range(opt.repeat_times):
220
+ print("-->Repeat %d"%r)
221
+ with torch.no_grad():
222
+ mids = t2m_transformer.generate(captions, token_lens,
223
+ timesteps=opt.time_steps,
224
+ cond_scale=opt.cond_scale,
225
+ temperature=opt.temperature,
226
+ topk_filter_thres=opt.topkr,
227
+ gsample=opt.gumbel_sample)
228
+ # print(mids)
229
+ # print(mids.shape)
230
+ mids = res_model.generate(mids, captions, token_lens, temperature=1, cond_scale=5)
231
+ pred_motions = vq_model.forward_decoder(mids)
232
+
233
+ pred_motions = pred_motions.detach().cpu().numpy()
234
+
235
+ data = inv_transform(pred_motions)
236
+
237
+ for k, (caption, joint_data) in enumerate(zip(captions, data)):
238
+ print("---->Sample %d: %s %d"%(k, caption, m_length[k]))
239
+ animation_path = pjoin(animation_dir, str(k))
240
+ joint_path = pjoin(joints_dir, str(k))
241
+
242
+ os.makedirs(animation_path, exist_ok=True)
243
+ os.makedirs(joint_path, exist_ok=True)
244
+
245
+ joint_data = joint_data[:m_length[k]]
246
+ joint = recover_from_ric(torch.from_numpy(joint_data).float(), 22).numpy()
247
+
248
+ bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.bvh"%(k, r, m_length[k]))
249
+ _, ik_joint = converter.convert(joint, filename=bvh_path, iterations=100)
250
+
251
+ bvh_path = pjoin(animation_path, "sample%d_repeat%d_len%d.bvh" % (k, r, m_length[k]))
252
+ _, joint = converter.convert(joint, filename=bvh_path, iterations=100, foot_ik=False)
253
+
254
+
255
+ save_path = pjoin(animation_path, "sample%d_repeat%d_len%d.mp4"%(k, r, m_length[k]))
256
+ ik_save_path = pjoin(animation_path, "sample%d_repeat%d_len%d_ik.mp4"%(k, r, m_length[k]))
257
+
258
+ plot_3d_motion(ik_save_path, kinematic_chain, ik_joint, title=caption, fps=20)
259
+ plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
260
+ np.save(pjoin(joint_path, "sample%d_repeat%d_len%d.npy"%(k, r, m_length[k])), joint)
261
+ np.save(pjoin(joint_path, "sample%d_repeat%d_len%d_ik.npy"%(k, r, m_length[k])), ik_joint)
models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/__init__.py ADDED
File without changes
models/mask_transformer/__init__.py ADDED
File without changes
models/mask_transformer/tools.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+ from einops import rearrange
5
+
6
+ # return mask where padding is FALSE
7
+ def lengths_to_mask(lengths, max_len):
8
+ # max_len = max(lengths)
9
+ mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
10
+ return mask #(b, len)
11
+
12
+ # return mask where padding is ALL FALSE
13
+ def get_pad_mask_idx(seq, pad_idx):
14
+ return (seq != pad_idx).unsqueeze(1)
15
+
16
+ # Given seq: (b, s)
17
+ # Return mat: (1, s, s)
18
+ # Example Output:
19
+ # [[[ True, False, False],
20
+ # [ True, True, False],
21
+ # [ True, True, True]]]
22
+ # For causal attention
23
+ def get_subsequent_mask(seq):
24
+ sz_b, seq_len = seq.shape
25
+ subsequent_mask = (1 - torch.triu(
26
+ torch.ones((1, seq_len, seq_len)), diagonal=1)).bool()
27
+ return subsequent_mask.to(seq.device)
28
+
29
+
30
+ def exists(val):
31
+ return val is not None
32
+
33
+ def default(val, d):
34
+ return val if exists(val) else d
35
+
36
+ def eval_decorator(fn):
37
+ def inner(model, *args, **kwargs):
38
+ was_training = model.training
39
+ model.eval()
40
+ out = fn(model, *args, **kwargs)
41
+ model.train(was_training)
42
+ return out
43
+ return inner
44
+
45
+ def l2norm(t):
46
+ return F.normalize(t, dim = -1)
47
+
48
+ # tensor helpers
49
+
50
+ # Get a random subset of TRUE mask, with prob
51
+ def get_mask_subset_prob(mask, prob):
52
+ subset_mask = torch.bernoulli(mask, p=prob) & mask
53
+ return subset_mask
54
+
55
+
56
+ # Get mask of special_tokens in ids
57
+ def get_mask_special_tokens(ids, special_ids):
58
+ mask = torch.zeros_like(ids).bool()
59
+ for special_id in special_ids:
60
+ mask |= (ids==special_id)
61
+ return mask
62
+
63
+ # network builder helpers
64
+ def _get_activation_fn(activation):
65
+ if activation == "relu":
66
+ return F.relu
67
+ elif activation == "gelu":
68
+ return F.gelu
69
+
70
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
71
+
72
+ # classifier free guidance functions
73
+
74
+ def uniform(shape, device=None):
75
+ return torch.zeros(shape, device=device).float().uniform_(0, 1)
76
+
77
+ def prob_mask_like(shape, prob, device=None):
78
+ if prob == 1:
79
+ return torch.ones(shape, device=device, dtype=torch.bool)
80
+ elif prob == 0:
81
+ return torch.zeros(shape, device=device, dtype=torch.bool)
82
+ else:
83
+ return uniform(shape, device=device) < prob
84
+
85
+ # sampling helpers
86
+
87
+ def log(t, eps = 1e-20):
88
+ return torch.log(t.clamp(min = eps))
89
+
90
+ def gumbel_noise(t):
91
+ noise = torch.zeros_like(t).uniform_(0, 1)
92
+ return -log(-log(noise))
93
+
94
+ def gumbel_sample(t, temperature = 1., dim = 1):
95
+ return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
96
+
97
+
98
+ # Example input:
99
+ # [[ 0.3596, 0.0862, 0.9771, -1.0000, -1.0000, -1.0000],
100
+ # [ 0.4141, 0.1781, 0.6628, 0.5721, -1.0000, -1.0000],
101
+ # [ 0.9428, 0.3586, 0.1659, 0.8172, 0.9273, -1.0000]]
102
+ # Example output:
103
+ # [[ -inf, -inf, 0.9771, -inf, -inf, -inf],
104
+ # [ -inf, -inf, 0.6628, -inf, -inf, -inf],
105
+ # [0.9428, -inf, -inf, -inf, -inf, -inf]]
106
+ def top_k(logits, thres = 0.9, dim = 1):
107
+ k = math.ceil((1 - thres) * logits.shape[dim])
108
+ val, ind = logits.topk(k, dim = dim)
109
+ probs = torch.full_like(logits, float('-inf'))
110
+ probs.scatter_(dim, ind, val)
111
+ # func verified
112
+ # print(probs)
113
+ # print(logits)
114
+ # raise
115
+ return probs
116
+
117
+ # noise schedules
118
+
119
+ # More on large value, less on small
120
+ def cosine_schedule(t):
121
+ return torch.cos(t * math.pi * 0.5)
122
+
123
+ def scale_cosine_schedule(t, scale):
124
+ return torch.clip(scale*torch.cos(t * math.pi * 0.5) + 1 - scale, min=0., max=1.)
125
+
126
+ # More on small value, less on large
127
+ def q_schedule(bs, low, high, device):
128
+ noise = uniform((bs,), device=device)
129
+ schedule = 1 - cosine_schedule(noise)
130
+ return torch.round(schedule * (high - low - 1)).long() + low
131
+
132
+ def cal_performance(pred, labels, ignore_index=None, smoothing=0., tk=1):
133
+ loss = cal_loss(pred, labels, ignore_index, smoothing=smoothing)
134
+ # pred_id = torch.argmax(pred, dim=1)
135
+ # mask = labels.ne(ignore_index)
136
+ # n_correct = pred_id.eq(labels).masked_select(mask)
137
+ # acc = torch.mean(n_correct.float()).item()
138
+ pred_id_k = torch.topk(pred, k=tk, dim=1).indices
139
+ pred_id = pred_id_k[:, 0]
140
+ mask = labels.ne(ignore_index)
141
+ n_correct = (pred_id_k == labels.unsqueeze(1)).any(dim=1).masked_select(mask)
142
+ acc = torch.mean(n_correct.float()).item()
143
+
144
+ return loss, pred_id, acc
145
+
146
+
147
+ def cal_loss(pred, labels, ignore_index=None, smoothing=0.):
148
+ '''Calculate cross entropy loss, apply label smoothing if needed.'''
149
+ # print(pred.shape, labels.shape) #torch.Size([64, 1028, 55]) torch.Size([64, 55])
150
+ # print(pred.shape, labels.shape) #torch.Size([64, 1027, 55]) torch.Size([64, 55])
151
+ if smoothing:
152
+ space = 2
153
+ n_class = pred.size(1)
154
+ mask = labels.ne(ignore_index)
155
+ one_hot = rearrange(F.one_hot(labels, n_class + space), 'a ... b -> a b ...')[:, :n_class]
156
+ # one_hot = torch.zeros_like(pred).scatter(1, labels.unsqueeze(1), 1)
157
+ sm_one_hot = one_hot * (1 - smoothing) + (1 - one_hot) * smoothing / (n_class - 1)
158
+ neg_log_prb = -F.log_softmax(pred, dim=1)
159
+ loss = (sm_one_hot * neg_log_prb).sum(dim=1)
160
+ # loss = F.cross_entropy(pred, sm_one_hot, reduction='none')
161
+ loss = torch.mean(loss.masked_select(mask))
162
+ else:
163
+ loss = F.cross_entropy(pred, labels, ignore_index=ignore_index)
164
+
165
+ return loss
models/mask_transformer/transformer.py ADDED
@@ -0,0 +1,1039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ # from networks.layers import *
5
+ import torch.nn.functional as F
6
+ import clip
7
+ from einops import rearrange, repeat
8
+ import math
9
+ from random import random
10
+ from tqdm.auto import tqdm
11
+ from typing import Callable, Optional, List, Dict
12
+ from copy import deepcopy
13
+ from functools import partial
14
+ from models.mask_transformer.tools import *
15
+ from torch.distributions.categorical import Categorical
16
+
17
+ class InputProcess(nn.Module):
18
+ def __init__(self, input_feats, latent_dim):
19
+ super().__init__()
20
+ self.input_feats = input_feats
21
+ self.latent_dim = latent_dim
22
+ self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
23
+
24
+ def forward(self, x):
25
+ # [bs, ntokens, input_feats]
26
+ x = x.permute((1, 0, 2)) # [seqen, bs, input_feats]
27
+ # print(x.shape)
28
+ x = self.poseEmbedding(x) # [seqlen, bs, d]
29
+ return x
30
+
31
+ class PositionalEncoding(nn.Module):
32
+ #Borrow from MDM, the same as above, but add dropout, exponential may improve precision
33
+ def __init__(self, d_model, dropout=0.1, max_len=5000):
34
+ super(PositionalEncoding, self).__init__()
35
+ self.dropout = nn.Dropout(p=dropout)
36
+
37
+ pe = torch.zeros(max_len, d_model)
38
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
39
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
40
+ pe[:, 0::2] = torch.sin(position * div_term)
41
+ pe[:, 1::2] = torch.cos(position * div_term)
42
+ pe = pe.unsqueeze(0).transpose(0, 1) #[max_len, 1, d_model]
43
+
44
+ self.register_buffer('pe', pe)
45
+
46
+ def forward(self, x):
47
+ # not used in the final model
48
+ x = x + self.pe[:x.shape[0], :]
49
+ return self.dropout(x)
50
+
51
+ class OutputProcess_Bert(nn.Module):
52
+ def __init__(self, out_feats, latent_dim):
53
+ super().__init__()
54
+ self.dense = nn.Linear(latent_dim, latent_dim)
55
+ self.transform_act_fn = F.gelu
56
+ self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
57
+ self.poseFinal = nn.Linear(latent_dim, out_feats) #Bias!
58
+
59
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
60
+ hidden_states = self.dense(hidden_states)
61
+ hidden_states = self.transform_act_fn(hidden_states)
62
+ hidden_states = self.LayerNorm(hidden_states)
63
+ output = self.poseFinal(hidden_states) # [seqlen, bs, out_feats]
64
+ output = output.permute(1, 2, 0) # [bs, c, seqlen]
65
+ return output
66
+
67
+ class OutputProcess(nn.Module):
68
+ def __init__(self, out_feats, latent_dim):
69
+ super().__init__()
70
+ self.dense = nn.Linear(latent_dim, latent_dim)
71
+ self.transform_act_fn = F.gelu
72
+ self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
73
+ self.poseFinal = nn.Linear(latent_dim, out_feats) #Bias!
74
+
75
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
76
+ hidden_states = self.dense(hidden_states)
77
+ hidden_states = self.transform_act_fn(hidden_states)
78
+ hidden_states = self.LayerNorm(hidden_states)
79
+ output = self.poseFinal(hidden_states) # [seqlen, bs, out_feats]
80
+ output = output.permute(1, 2, 0) # [bs, e, seqlen]
81
+ return output
82
+
83
+
84
+ class MaskTransformer(nn.Module):
85
+ def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8,
86
+ num_heads=4, dropout=0.1, clip_dim=512, cond_drop_prob=0.1,
87
+ clip_version=None, opt=None, **kargs):
88
+ super(MaskTransformer, self).__init__()
89
+ print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')
90
+
91
+ self.code_dim = code_dim
92
+ self.latent_dim = latent_dim
93
+ self.clip_dim = clip_dim
94
+ self.dropout = dropout
95
+ self.opt = opt
96
+
97
+ self.cond_mode = cond_mode
98
+ self.cond_drop_prob = cond_drop_prob
99
+
100
+ if self.cond_mode == 'action':
101
+ assert 'num_actions' in kargs
102
+ self.num_actions = kargs.get('num_actions', 1)
103
+
104
+ '''
105
+ Preparing Networks
106
+ '''
107
+ self.input_process = InputProcess(self.code_dim, self.latent_dim)
108
+ self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)
109
+
110
+ seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
111
+ nhead=num_heads,
112
+ dim_feedforward=ff_size,
113
+ dropout=dropout,
114
+ activation='gelu')
115
+
116
+ self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
117
+ num_layers=num_layers)
118
+
119
+ self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
120
+
121
+ # if self.cond_mode != 'no_cond':
122
+ if self.cond_mode == 'text':
123
+ self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
124
+ elif self.cond_mode == 'action':
125
+ self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
126
+ elif self.cond_mode == 'uncond':
127
+ self.cond_emb = nn.Identity()
128
+ else:
129
+ raise KeyError("Unsupported condition mode!!!")
130
+
131
+
132
+ _num_tokens = opt.num_tokens + 2 # two dummy tokens, one for masking, one for padding
133
+ self.mask_id = opt.num_tokens
134
+ self.pad_id = opt.num_tokens + 1
135
+
136
+ self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)
137
+
138
+ self.token_emb = nn.Embedding(_num_tokens, self.code_dim)
139
+
140
+ self.apply(self.__init_weights)
141
+
142
+ '''
143
+ Preparing frozen weights
144
+ '''
145
+
146
+ if self.cond_mode == 'text':
147
+ print('Loading CLIP...')
148
+ self.clip_version = clip_version
149
+ self.clip_model = self.load_and_freeze_clip(clip_version)
150
+
151
+ self.noise_schedule = cosine_schedule
152
+
153
+ def load_and_freeze_token_emb(self, codebook):
154
+ '''
155
+ :param codebook: (c, d)
156
+ :return:
157
+ '''
158
+ assert self.training, 'Only necessary in training mode'
159
+ c, d = codebook.shape
160
+ self.token_emb.weight = nn.Parameter(torch.cat([codebook, torch.zeros(size=(2, d), device=codebook.device)], dim=0)) #add two dummy tokens, 0 vectors
161
+ self.token_emb.requires_grad_(False)
162
+ # self.token_emb.weight.requires_grad = False
163
+ # self.token_emb_ready = True
164
+ print("Token embedding initialized!")
165
+
166
+ def __init_weights(self, module):
167
+ if isinstance(module, (nn.Linear, nn.Embedding)):
168
+ module.weight.data.normal_(mean=0.0, std=0.02)
169
+ if isinstance(module, nn.Linear) and module.bias is not None:
170
+ module.bias.data.zero_()
171
+ elif isinstance(module, nn.LayerNorm):
172
+ module.bias.data.zero_()
173
+ module.weight.data.fill_(1.0)
174
+
175
+ def parameters_wo_clip(self):
176
+ return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
177
+
178
+ def load_and_freeze_clip(self, clip_version):
179
+ clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
180
+ jit=False) # Must set jit=False for training
181
+ # Cannot run on cpu
182
+ clip.model.convert_weights(
183
+ clip_model) # Actually this line is unnecessary since clip by default already on float16
184
+ # Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
185
+
186
+ # Freeze CLIP weights
187
+ clip_model.eval()
188
+ for p in clip_model.parameters():
189
+ p.requires_grad = False
190
+
191
+ return clip_model
192
+
193
+ def encode_text(self, raw_text):
194
+ device = next(self.parameters()).device
195
+ text = clip.tokenize(raw_text, truncate=True).to(device)
196
+ feat_clip_text = self.clip_model.encode_text(text).float()
197
+ return feat_clip_text
198
+
199
+ def mask_cond(self, cond, force_mask=False):
200
+ bs, d = cond.shape
201
+ if force_mask:
202
+ return torch.zeros_like(cond)
203
+ elif self.training and self.cond_drop_prob > 0.:
204
+ mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
205
+ return cond * (1. - mask)
206
+ else:
207
+ return cond
208
+
209
+ def trans_forward(self, motion_ids, cond, padding_mask, force_mask=False):
210
+ '''
211
+ :param motion_ids: (b, seqlen)
212
+ :padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
213
+ :param cond: (b, embed_dim) for text, (b, num_actions) for action
214
+ :param force_mask: boolean
215
+ :return:
216
+ -logits: (b, num_token, seqlen)
217
+ '''
218
+
219
+ cond = self.mask_cond(cond, force_mask=force_mask)
220
+
221
+ # print(motion_ids.shape)
222
+ x = self.token_emb(motion_ids)
223
+ # print(x.shape)
224
+ # (b, seqlen, d) -> (seqlen, b, latent_dim)
225
+ x = self.input_process(x)
226
+
227
+ cond = self.cond_emb(cond).unsqueeze(0) #(1, b, latent_dim)
228
+
229
+ x = self.position_enc(x)
230
+ xseq = torch.cat([cond, x], dim=0) #(seqlen+1, b, latent_dim)
231
+
232
+ padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:1]), padding_mask], dim=1) #(b, seqlen+1)
233
+ # print(xseq.shape, padding_mask.shape)
234
+
235
+ # print(padding_mask.shape, xseq.shape)
236
+
237
+ output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[1:] #(seqlen, b, e)
238
+ logits = self.output_process(output) #(seqlen, b, e) -> (b, ntoken, seqlen)
239
+ return logits
240
+
241
+ def forward(self, ids, y, m_lens):
242
+ '''
243
+ :param ids: (b, n)
244
+ :param y: raw text for cond_mode=text, (b, ) for cond_mode=action
245
+ :m_lens: (b,)
246
+ :return:
247
+ '''
248
+
249
+ bs, ntokens = ids.shape
250
+ device = ids.device
251
+
252
+ # Positions that are PADDED are ALL FALSE
253
+ non_pad_mask = lengths_to_mask(m_lens, ntokens) #(b, n)
254
+ ids = torch.where(non_pad_mask, ids, self.pad_id)
255
+
256
+ force_mask = False
257
+ if self.cond_mode == 'text':
258
+ with torch.no_grad():
259
+ cond_vector = self.encode_text(y)
260
+ elif self.cond_mode == 'action':
261
+ cond_vector = self.enc_action(y).to(device).float()
262
+ elif self.cond_mode == 'uncond':
263
+ cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
264
+ force_mask = True
265
+ else:
266
+ raise NotImplementedError("Unsupported condition mode!!!")
267
+
268
+
269
+ '''
270
+ Prepare mask
271
+ '''
272
+ rand_time = uniform((bs,), device=device)
273
+ rand_mask_probs = self.noise_schedule(rand_time)
274
+ num_token_masked = (ntokens * rand_mask_probs).round().clamp(min=1)
275
+
276
+ batch_randperm = torch.rand((bs, ntokens), device=device).argsort(dim=-1)
277
+ # Positions to be MASKED are ALL TRUE
278
+ mask = batch_randperm < num_token_masked.unsqueeze(-1)
279
+
280
+ # Positions to be MASKED must also be NON-PADDED
281
+ mask &= non_pad_mask
282
+
283
+ # Note this is our training target, not input
284
+ labels = torch.where(mask, ids, self.mask_id)
285
+
286
+ x_ids = ids.clone()
287
+
288
+ # Further Apply Bert Masking Scheme
289
+ # Step 1: 10% replace with an incorrect token
290
+ mask_rid = get_mask_subset_prob(mask, 0.1)
291
+ rand_id = torch.randint_like(x_ids, high=self.opt.num_tokens)
292
+ x_ids = torch.where(mask_rid, rand_id, x_ids)
293
+ # Step 2: 90% x 10% replace with correct token, and 90% x 88% replace with mask token
294
+ mask_mid = get_mask_subset_prob(mask & ~mask_rid, 0.88)
295
+
296
+ # mask_mid = mask
297
+
298
+ x_ids = torch.where(mask_mid, self.mask_id, x_ids)
299
+
300
+ logits = self.trans_forward(x_ids, cond_vector, ~non_pad_mask, force_mask)
301
+ ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)
302
+
303
+ return ce_loss, pred_id, acc
304
+
305
+ def forward_with_cond_scale(self,
306
+ motion_ids,
307
+ cond_vector,
308
+ padding_mask,
309
+ cond_scale=3,
310
+ force_mask=False):
311
+ # bs = motion_ids.shape[0]
312
+ # if cond_scale == 1:
313
+ if force_mask:
314
+ return self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)
315
+
316
+ logits = self.trans_forward(motion_ids, cond_vector, padding_mask)
317
+ if cond_scale == 1:
318
+ return logits
319
+
320
+ aux_logits = self.trans_forward(motion_ids, cond_vector, padding_mask, force_mask=True)
321
+
322
+ scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
323
+ return scaled_logits
324
+
325
+ @torch.no_grad()
326
+ @eval_decorator
327
+ def generate(self,
328
+ conds,
329
+ m_lens,
330
+ timesteps: int,
331
+ cond_scale: int,
332
+ temperature=1,
333
+ topk_filter_thres=0.9,
334
+ gsample=False,
335
+ force_mask=False
336
+ ):
337
+ # print(self.opt.num_quantizers)
338
+ # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
339
+
340
+ device = next(self.parameters()).device
341
+ seq_len = max(m_lens)
342
+ batch_size = len(m_lens)
343
+
344
+ if self.cond_mode == 'text':
345
+ with torch.no_grad():
346
+ cond_vector = self.encode_text(conds)
347
+ elif self.cond_mode == 'action':
348
+ cond_vector = self.enc_action(conds).to(device)
349
+ elif self.cond_mode == 'uncond':
350
+ cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
351
+ else:
352
+ raise NotImplementedError("Unsupported condition mode!!!")
353
+
354
+ padding_mask = ~lengths_to_mask(m_lens, seq_len)
355
+ # print(padding_mask.shape, )
356
+
357
+ # Start from all tokens being masked
358
+ ids = torch.where(padding_mask, self.pad_id, self.mask_id)
359
+ scores = torch.where(padding_mask, 1e5, 0.)
360
+ starting_temperature = temperature
361
+
362
+ for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
363
+ # 0 < timestep < 1
364
+ rand_mask_prob = self.noise_schedule(timestep) # Tensor
365
+
366
+ '''
367
+ Maskout, and cope with variable length
368
+ '''
369
+ # fix: the ratio regarding lengths, instead of seq_len
370
+ num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(min=1) # (b, )
371
+
372
+ # select num_token_masked tokens with lowest scores to be masked
373
+ sorted_indices = scores.argsort(
374
+ dim=1) # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
375
+ ranks = sorted_indices.argsort(dim=1) # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
376
+ is_mask = (ranks < num_token_masked.unsqueeze(-1))
377
+ ids = torch.where(is_mask, self.mask_id, ids)
378
+
379
+ '''
380
+ Preparing input
381
+ '''
382
+ # (b, num_token, seqlen)
383
+ logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
384
+ padding_mask=padding_mask,
385
+ cond_scale=cond_scale,
386
+ force_mask=force_mask)
387
+
388
+ logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
389
+ # print(logits.shape, self.opt.num_tokens)
390
+ # clean low prob token
391
+ filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
392
+
393
+ '''
394
+ Update ids
395
+ '''
396
+ # if force_mask:
397
+ temperature = starting_temperature
398
+ # else:
399
+ # temperature = starting_temperature * (steps_until_x0 / timesteps)
400
+ # temperature = max(temperature, 1e-4)
401
+ # print(filtered_logits.shape)
402
+ # temperature is annealed, gradually reducing temperature as well as randomness
403
+ if gsample: # use gumbel_softmax sampling
404
+ # print("1111")
405
+ pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
406
+ else: # use multinomial sampling
407
+ # print("2222")
408
+ probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
409
+ # print(temperature, starting_temperature, steps_until_x0, timesteps)
410
+ # print(probs / temperature)
411
+ pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
412
+
413
+ # print(pred_ids.max(), pred_ids.min())
414
+ # if pred_ids.
415
+ ids = torch.where(is_mask, pred_ids, ids)
416
+
417
+ '''
418
+ Updating scores
419
+ '''
420
+ probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
421
+ scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1)) # (b, seqlen, 1)
422
+ scores = scores.squeeze(-1) # (b, seqlen)
423
+
424
+ # We do not want to re-mask the previously kept tokens, or pad tokens
425
+ scores = scores.masked_fill(~is_mask, 1e5)
426
+
427
+ ids = torch.where(padding_mask, -1, ids)
428
+ # print("Final", ids.max(), ids.min())
429
+ return ids
430
+
431
+
432
+ @torch.no_grad()
433
+ @eval_decorator
434
+ def edit(self,
435
+ conds,
436
+ tokens,
437
+ m_lens,
438
+ timesteps: int,
439
+ cond_scale: int,
440
+ temperature=1,
441
+ topk_filter_thres=0.9,
442
+ gsample=False,
443
+ force_mask=False,
444
+ edit_mask=None,
445
+ padding_mask=None,
446
+ ):
447
+
448
+ assert edit_mask.shape == tokens.shape if edit_mask is not None else True
449
+ device = next(self.parameters()).device
450
+ seq_len = tokens.shape[1]
451
+
452
+ if self.cond_mode == 'text':
453
+ with torch.no_grad():
454
+ cond_vector = self.encode_text(conds)
455
+ elif self.cond_mode == 'action':
456
+ cond_vector = self.enc_action(conds).to(device)
457
+ elif self.cond_mode == 'uncond':
458
+ cond_vector = torch.zeros(1, self.latent_dim).float().to(device)
459
+ else:
460
+ raise NotImplementedError("Unsupported condition mode!!!")
461
+
462
+ if padding_mask == None:
463
+ padding_mask = ~lengths_to_mask(m_lens, seq_len)
464
+
465
+ # Start from all tokens being masked
466
+ if edit_mask == None:
467
+ mask_free = True
468
+ ids = torch.where(padding_mask, self.pad_id, tokens)
469
+ edit_mask = torch.ones_like(padding_mask)
470
+ edit_mask = edit_mask & ~padding_mask
471
+ edit_len = edit_mask.sum(dim=-1)
472
+ scores = torch.where(edit_mask, 0., 1e5)
473
+ else:
474
+ mask_free = False
475
+ edit_mask = edit_mask & ~padding_mask
476
+ edit_len = edit_mask.sum(dim=-1)
477
+ ids = torch.where(edit_mask, self.mask_id, tokens)
478
+ scores = torch.where(edit_mask, 0., 1e5)
479
+ starting_temperature = temperature
480
+
481
+ for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
482
+ # 0 < timestep < 1
483
+ rand_mask_prob = 0.16 if mask_free else self.noise_schedule(timestep) # Tensor
484
+
485
+ '''
486
+ Maskout, and cope with variable length
487
+ '''
488
+ # fix: the ratio regarding lengths, instead of seq_len
489
+ num_token_masked = torch.round(rand_mask_prob * edit_len).clamp(min=1) # (b, )
490
+
491
+ # select num_token_masked tokens with lowest scores to be masked
492
+ sorted_indices = scores.argsort(
493
+ dim=1) # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
494
+ ranks = sorted_indices.argsort(dim=1) # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
495
+ is_mask = (ranks < num_token_masked.unsqueeze(-1))
496
+ # is_mask = (torch.rand_like(scores) < 0.8) * ~padding_mask if mask_free else is_mask
497
+ ids = torch.where(is_mask, self.mask_id, ids)
498
+
499
+ '''
500
+ Preparing input
501
+ '''
502
+ # (b, num_token, seqlen)
503
+ logits = self.forward_with_cond_scale(ids, cond_vector=cond_vector,
504
+ padding_mask=padding_mask,
505
+ cond_scale=cond_scale,
506
+ force_mask=force_mask)
507
+
508
+ logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
509
+ # print(logits.shape, self.opt.num_tokens)
510
+ # clean low prob token
511
+ filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
512
+
513
+ '''
514
+ Update ids
515
+ '''
516
+ # if force_mask:
517
+ temperature = starting_temperature
518
+ # else:
519
+ # temperature = starting_temperature * (steps_until_x0 / timesteps)
520
+ # temperature = max(temperature, 1e-4)
521
+ # print(filtered_logits.shape)
522
+ # temperature is annealed, gradually reducing temperature as well as randomness
523
+ if gsample: # use gumbel_softmax sampling
524
+ # print("1111")
525
+ pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
526
+ else: # use multinomial sampling
527
+ # print("2222")
528
+ probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
529
+ # print(temperature, starting_temperature, steps_until_x0, timesteps)
530
+ # print(probs / temperature)
531
+ pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
532
+
533
+ # print(pred_ids.max(), pred_ids.min())
534
+ # if pred_ids.
535
+ ids = torch.where(is_mask, pred_ids, ids)
536
+
537
+ '''
538
+ Updating scores
539
+ '''
540
+ probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
541
+ scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1)) # (b, seqlen, 1)
542
+ scores = scores.squeeze(-1) # (b, seqlen)
543
+
544
+ # We do not want to re-mask the previously kept tokens, or pad tokens
545
+ scores = scores.masked_fill(~edit_mask, 1e5) if mask_free else scores.masked_fill(~is_mask, 1e5)
546
+
547
+ ids = torch.where(padding_mask, -1, ids)
548
+ # print("Final", ids.max(), ids.min())
549
+ return ids
550
+
551
+ @torch.no_grad()
552
+ @eval_decorator
553
+ def edit_beta(self,
554
+ conds,
555
+ conds_og,
556
+ tokens,
557
+ m_lens,
558
+ cond_scale: int,
559
+ force_mask=False,
560
+ ):
561
+
562
+ device = next(self.parameters()).device
563
+ seq_len = tokens.shape[1]
564
+
565
+ if self.cond_mode == 'text':
566
+ with torch.no_grad():
567
+ cond_vector = self.encode_text(conds)
568
+ if conds_og is not None:
569
+ cond_vector_og = self.encode_text(conds_og)
570
+ else:
571
+ cond_vector_og = None
572
+ elif self.cond_mode == 'action':
573
+ cond_vector = self.enc_action(conds).to(device)
574
+ if conds_og is not None:
575
+ cond_vector_og = self.enc_action(conds_og).to(device)
576
+ else:
577
+ cond_vector_og = None
578
+ else:
579
+ raise NotImplementedError("Unsupported condition mode!!!")
580
+
581
+ padding_mask = ~lengths_to_mask(m_lens, seq_len)
582
+
583
+ # Start from all tokens being masked
584
+ ids = torch.where(padding_mask, self.pad_id, tokens) # Do not mask anything
585
+
586
+ '''
587
+ Preparing input
588
+ '''
589
+ # (b, num_token, seqlen)
590
+ logits = self.forward_with_cond_scale(ids,
591
+ cond_vector=cond_vector,
592
+ cond_vector_neg=cond_vector_og,
593
+ padding_mask=padding_mask,
594
+ cond_scale=cond_scale,
595
+ force_mask=force_mask)
596
+
597
+ logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
598
+
599
+ '''
600
+ Updating scores
601
+ '''
602
+ probs_without_temperature = logits.softmax(dim=-1) # (b, seqlen, ntoken)
603
+ tokens[tokens == -1] = 0 # just to get through an error when index = -1 using gather
604
+ og_tokens_scores = probs_without_temperature.gather(2, tokens.unsqueeze(dim=-1)) # (b, seqlen, 1)
605
+ og_tokens_scores = og_tokens_scores.squeeze(-1) # (b, seqlen)
606
+
607
+ return og_tokens_scores
608
+
609
+
610
+ class ResidualTransformer(nn.Module):
611
+ def __init__(self, code_dim, cond_mode, latent_dim=256, ff_size=1024, num_layers=8, cond_drop_prob=0.1,
612
+ num_heads=4, dropout=0.1, clip_dim=512, shared_codebook=False, share_weight=False,
613
+ clip_version=None, opt=None, **kargs):
614
+ super(ResidualTransformer, self).__init__()
615
+ print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')
616
+
617
+ # assert shared_codebook == True, "Only support shared codebook right now!"
618
+
619
+ self.code_dim = code_dim
620
+ self.latent_dim = latent_dim
621
+ self.clip_dim = clip_dim
622
+ self.dropout = dropout
623
+ self.opt = opt
624
+
625
+ self.cond_mode = cond_mode
626
+ # self.cond_drop_prob = cond_drop_prob
627
+
628
+ if self.cond_mode == 'action':
629
+ assert 'num_actions' in kargs
630
+ self.num_actions = kargs.get('num_actions', 1)
631
+ self.cond_drop_prob = cond_drop_prob
632
+
633
+ '''
634
+ Preparing Networks
635
+ '''
636
+ self.input_process = InputProcess(self.code_dim, self.latent_dim)
637
+ self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)
638
+
639
+ seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
640
+ nhead=num_heads,
641
+ dim_feedforward=ff_size,
642
+ dropout=dropout,
643
+ activation='gelu')
644
+
645
+ self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
646
+ num_layers=num_layers)
647
+
648
+ self.encode_quant = partial(F.one_hot, num_classes=self.opt.num_quantizers)
649
+ self.encode_action = partial(F.one_hot, num_classes=self.num_actions)
650
+
651
+ self.quant_emb = nn.Linear(self.opt.num_quantizers, self.latent_dim)
652
+ # if self.cond_mode != 'no_cond':
653
+ if self.cond_mode == 'text':
654
+ self.cond_emb = nn.Linear(self.clip_dim, self.latent_dim)
655
+ elif self.cond_mode == 'action':
656
+ self.cond_emb = nn.Linear(self.num_actions, self.latent_dim)
657
+ else:
658
+ raise KeyError("Unsupported condition mode!!!")
659
+
660
+
661
+ _num_tokens = opt.num_tokens + 1 # one dummy tokens for padding
662
+ self.pad_id = opt.num_tokens
663
+
664
+ # self.output_process = OutputProcess_Bert(out_feats=opt.num_tokens, latent_dim=latent_dim)
665
+ self.output_process = OutputProcess(out_feats=code_dim, latent_dim=latent_dim)
666
+
667
+ if shared_codebook:
668
+ token_embed = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
669
+ self.token_embed_weight = token_embed.expand(opt.num_quantizers-1, _num_tokens, code_dim)
670
+ if share_weight:
671
+ self.output_proj_weight = self.token_embed_weight
672
+ self.output_proj_bias = None
673
+ else:
674
+ output_proj = nn.Parameter(torch.normal(mean=0, std=0.02, size=(_num_tokens, code_dim)))
675
+ output_bias = nn.Parameter(torch.zeros(size=(_num_tokens,)))
676
+ # self.output_proj_bias = 0
677
+ self.output_proj_weight = output_proj.expand(opt.num_quantizers-1, _num_tokens, code_dim)
678
+ self.output_proj_bias = output_bias.expand(opt.num_quantizers-1, _num_tokens)
679
+
680
+ else:
681
+ if share_weight:
682
+ self.embed_proj_shared_weight = nn.Parameter(torch.normal(mean=0, std=0.02, size=(opt.num_quantizers - 2, _num_tokens, code_dim)))
683
+ self.token_embed_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
684
+ self.output_proj_weight_ = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, _num_tokens, code_dim)))
685
+ self.output_proj_bias = None
686
+ self.registered = False
687
+ else:
688
+ output_proj_weight = torch.normal(mean=0, std=0.02,
689
+ size=(opt.num_quantizers - 1, _num_tokens, code_dim))
690
+
691
+ self.output_proj_weight = nn.Parameter(output_proj_weight)
692
+ self.output_proj_bias = nn.Parameter(torch.zeros(size=(opt.num_quantizers, _num_tokens)))
693
+ token_embed_weight = torch.normal(mean=0, std=0.02,
694
+ size=(opt.num_quantizers - 1, _num_tokens, code_dim))
695
+ self.token_embed_weight = nn.Parameter(token_embed_weight)
696
+
697
+ self.apply(self.__init_weights)
698
+ self.shared_codebook = shared_codebook
699
+ self.share_weight = share_weight
700
+
701
+ if self.cond_mode == 'text':
702
+ print('Loading CLIP...')
703
+ self.clip_version = clip_version
704
+ self.clip_model = self.load_and_freeze_clip(clip_version)
705
+
706
+ # def
707
+
708
+ def mask_cond(self, cond, force_mask=False):
709
+ bs, d = cond.shape
710
+ if force_mask:
711
+ return torch.zeros_like(cond)
712
+ elif self.training and self.cond_drop_prob > 0.:
713
+ mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1)
714
+ return cond * (1. - mask)
715
+ else:
716
+ return cond
717
+
718
+ def __init_weights(self, module):
719
+ if isinstance(module, (nn.Linear, nn.Embedding)):
720
+ module.weight.data.normal_(mean=0.0, std=0.02)
721
+ if isinstance(module, nn.Linear) and module.bias is not None:
722
+ module.bias.data.zero_()
723
+ elif isinstance(module, nn.LayerNorm):
724
+ module.bias.data.zero_()
725
+ module.weight.data.fill_(1.0)
726
+
727
+ def parameters_wo_clip(self):
728
+ return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
729
+
730
+ def load_and_freeze_clip(self, clip_version):
731
+ clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
732
+ jit=False) # Must set jit=False for training
733
+ # Cannot run on cpu
734
+ clip.model.convert_weights(
735
+ clip_model) # Actually this line is unnecessary since clip by default already on float16
736
+ # Date 0707: It's necessary, only unecessary when load directly to gpu. Disable if need to run on cpu
737
+
738
+ # Freeze CLIP weights
739
+ clip_model.eval()
740
+ for p in clip_model.parameters():
741
+ p.requires_grad = False
742
+
743
+ return clip_model
744
+
745
+ def encode_text(self, raw_text):
746
+ device = next(self.parameters()).device
747
+ text = clip.tokenize(raw_text, truncate=True).to(device)
748
+ feat_clip_text = self.clip_model.encode_text(text).float()
749
+ return feat_clip_text
750
+
751
+
752
+ def q_schedule(self, bs, low, high):
753
+ noise = uniform((bs,), device=self.opt.device)
754
+ schedule = 1 - cosine_schedule(noise)
755
+ return torch.round(schedule * (high - low)) + low
756
+
757
+ def process_embed_proj_weight(self):
758
+ if self.share_weight and (not self.shared_codebook):
759
+ # if not self.registered:
760
+ self.output_proj_weight = torch.cat([self.embed_proj_shared_weight, self.output_proj_weight_], dim=0)
761
+ self.token_embed_weight = torch.cat([self.token_embed_weight_, self.embed_proj_shared_weight], dim=0)
762
+ # self.registered = True
763
+
764
+ def output_project(self, logits, qids):
765
+ '''
766
+ :logits: (bs, code_dim, seqlen)
767
+ :qids: (bs)
768
+
769
+ :return:
770
+ -logits (bs, ntoken, seqlen)
771
+ '''
772
+ # (num_qlayers-1, num_token, code_dim) -> (bs, ntoken, code_dim)
773
+ output_proj_weight = self.output_proj_weight[qids]
774
+ # (num_qlayers, ntoken) -> (bs, ntoken)
775
+ output_proj_bias = None if self.output_proj_bias is None else self.output_proj_bias[qids]
776
+
777
+ output = torch.einsum('bnc, bcs->bns', output_proj_weight, logits)
778
+ if output_proj_bias is not None:
779
+ output += output + output_proj_bias.unsqueeze(-1)
780
+ return output
781
+
782
+
783
+
784
+ def trans_forward(self, motion_codes, qids, cond, padding_mask, force_mask=False):
785
+ '''
786
+ :param motion_codes: (b, seqlen, d)
787
+ :padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
788
+ :param qids: (b), quantizer layer ids
789
+ :param cond: (b, embed_dim) for text, (b, num_actions) for action
790
+ :return:
791
+ -logits: (b, num_token, seqlen)
792
+ '''
793
+ cond = self.mask_cond(cond, force_mask=force_mask)
794
+
795
+ # (b, seqlen, d) -> (seqlen, b, latent_dim)
796
+ x = self.input_process(motion_codes)
797
+
798
+ # (b, num_quantizer)
799
+ q_onehot = self.encode_quant(qids).float().to(x.device)
800
+
801
+ q_emb = self.quant_emb(q_onehot).unsqueeze(0) # (1, b, latent_dim)
802
+ cond = self.cond_emb(cond).unsqueeze(0) # (1, b, latent_dim)
803
+
804
+ x = self.position_enc(x)
805
+ xseq = torch.cat([cond, q_emb, x], dim=0) # (seqlen+2, b, latent_dim)
806
+
807
+ padding_mask = torch.cat([torch.zeros_like(padding_mask[:, 0:2]), padding_mask], dim=1) # (b, seqlen+2)
808
+ output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[2:] # (seqlen, b, e)
809
+ logits = self.output_process(output)
810
+ return logits
811
+
812
+ def forward_with_cond_scale(self,
813
+ motion_codes,
814
+ q_id,
815
+ cond_vector,
816
+ padding_mask,
817
+ cond_scale=3,
818
+ force_mask=False):
819
+ bs = motion_codes.shape[0]
820
+ # if cond_scale == 1:
821
+ qids = torch.full((bs,), q_id, dtype=torch.long, device=motion_codes.device)
822
+ if force_mask:
823
+ logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
824
+ logits = self.output_project(logits, qids-1)
825
+ return logits
826
+
827
+ logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask)
828
+ logits = self.output_project(logits, qids-1)
829
+ if cond_scale == 1:
830
+ return logits
831
+
832
+ aux_logits = self.trans_forward(motion_codes, qids, cond_vector, padding_mask, force_mask=True)
833
+ aux_logits = self.output_project(aux_logits, qids-1)
834
+
835
+ scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
836
+ return scaled_logits
837
+
838
+ def forward(self, all_indices, y, m_lens):
839
+ '''
840
+ :param all_indices: (b, n, q)
841
+ :param y: raw text for cond_mode=text, (b, ) for cond_mode=action
842
+ :m_lens: (b,)
843
+ :return:
844
+ '''
845
+
846
+ self.process_embed_proj_weight()
847
+
848
+ bs, ntokens, num_quant_layers = all_indices.shape
849
+ device = all_indices.device
850
+
851
+ # Positions that are PADDED are ALL FALSE
852
+ non_pad_mask = lengths_to_mask(m_lens, ntokens) # (b, n)
853
+
854
+ q_non_pad_mask = repeat(non_pad_mask, 'b n -> b n q', q=num_quant_layers)
855
+ all_indices = torch.where(q_non_pad_mask, all_indices, self.pad_id) #(b, n, q)
856
+
857
+ # randomly sample quantization layers to work on, [1, num_q)
858
+ active_q_layers = q_schedule(bs, low=1, high=num_quant_layers, device=device)
859
+
860
+ # print(self.token_embed_weight.shape, all_indices.shape)
861
+ token_embed = repeat(self.token_embed_weight, 'q c d-> b c d q', b=bs)
862
+ gather_indices = repeat(all_indices[..., :-1], 'b n q -> b n d q', d=token_embed.shape[2])
863
+ # print(token_embed.shape, gather_indices.shape)
864
+ all_codes = token_embed.gather(1, gather_indices) # (b, n, d, q-1)
865
+
866
+ cumsum_codes = torch.cumsum(all_codes, dim=-1) #(b, n, d, q-1)
867
+
868
+ active_indices = all_indices[torch.arange(bs), :, active_q_layers] # (b, n)
869
+ history_sum = cumsum_codes[torch.arange(bs), :, :, active_q_layers - 1]
870
+
871
+ force_mask = False
872
+ if self.cond_mode == 'text':
873
+ with torch.no_grad():
874
+ cond_vector = self.encode_text(y)
875
+ elif self.cond_mode == 'action':
876
+ cond_vector = self.enc_action(y).to(device).float()
877
+ elif self.cond_mode == 'uncond':
878
+ cond_vector = torch.zeros(bs, self.latent_dim).float().to(device)
879
+ force_mask = True
880
+ else:
881
+ raise NotImplementedError("Unsupported condition mode!!!")
882
+
883
+ logits = self.trans_forward(history_sum, active_q_layers, cond_vector, ~non_pad_mask, force_mask)
884
+ logits = self.output_project(logits, active_q_layers-1)
885
+ ce_loss, pred_id, acc = cal_performance(logits, active_indices, ignore_index=self.pad_id)
886
+
887
+ return ce_loss, pred_id, acc
888
+
889
+ @torch.no_grad()
890
+ @eval_decorator
891
+ def generate(self,
892
+ motion_ids,
893
+ conds,
894
+ m_lens,
895
+ temperature=1,
896
+ topk_filter_thres=0.9,
897
+ cond_scale=2,
898
+ num_res_layers=-1, # If it's -1, use all.
899
+ ):
900
+
901
+ # print(self.opt.num_quantizers)
902
+ # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
903
+ self.process_embed_proj_weight()
904
+
905
+ device = next(self.parameters()).device
906
+ seq_len = motion_ids.shape[1]
907
+ batch_size = len(conds)
908
+
909
+ if self.cond_mode == 'text':
910
+ with torch.no_grad():
911
+ cond_vector = self.encode_text(conds)
912
+ elif self.cond_mode == 'action':
913
+ cond_vector = self.enc_action(conds).to(device)
914
+ elif self.cond_mode == 'uncond':
915
+ cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
916
+ else:
917
+ raise NotImplementedError("Unsupported condition mode!!!")
918
+
919
+ # token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
920
+ # gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
921
+ # history_sum = token_embed.gather(1, gathered_ids)
922
+
923
+ # print(pa, seq_len)
924
+ padding_mask = ~lengths_to_mask(m_lens, seq_len)
925
+ # print(padding_mask.shape, motion_ids.shape)
926
+ motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
927
+ all_indices = [motion_ids]
928
+ history_sum = 0
929
+ num_quant_layers = self.opt.num_quantizers if num_res_layers==-1 else num_res_layers+1
930
+
931
+ for i in range(1, num_quant_layers):
932
+ # print(f"--> Working on {i}-th quantizer")
933
+ # Start from all tokens being masked
934
+ # qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
935
+ token_embed = self.token_embed_weight[i-1]
936
+ token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
937
+ gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
938
+ history_sum += token_embed.gather(1, gathered_ids)
939
+
940
+ logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
941
+ # logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)
942
+
943
+ logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
944
+ # clean low prob token
945
+ filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
946
+
947
+ pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
948
+
949
+ # probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
950
+ # # print(temperature, starting_temperature, steps_until_x0, timesteps)
951
+ # # print(probs / temperature)
952
+ # pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
953
+
954
+ ids = torch.where(padding_mask, self.pad_id, pred_ids)
955
+
956
+ motion_ids = ids
957
+ all_indices.append(ids)
958
+
959
+ all_indices = torch.stack(all_indices, dim=-1)
960
+ # padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
961
+ # all_indices = torch.where(padding_mask, -1, all_indices)
962
+ all_indices = torch.where(all_indices==self.pad_id, -1, all_indices)
963
+ # all_indices = all_indices.masked_fill()
964
+ return all_indices
965
+
966
+ @torch.no_grad()
967
+ @eval_decorator
968
+ def edit(self,
969
+ motion_ids,
970
+ conds,
971
+ m_lens,
972
+ temperature=1,
973
+ topk_filter_thres=0.9,
974
+ cond_scale=2
975
+ ):
976
+
977
+ # print(self.opt.num_quantizers)
978
+ # assert len(timesteps) >= len(cond_scales) == self.opt.num_quantizers
979
+ self.process_embed_proj_weight()
980
+
981
+ device = next(self.parameters()).device
982
+ seq_len = motion_ids.shape[1]
983
+ batch_size = len(conds)
984
+
985
+ if self.cond_mode == 'text':
986
+ with torch.no_grad():
987
+ cond_vector = self.encode_text(conds)
988
+ elif self.cond_mode == 'action':
989
+ cond_vector = self.enc_action(conds).to(device)
990
+ elif self.cond_mode == 'uncond':
991
+ cond_vector = torch.zeros(batch_size, self.latent_dim).float().to(device)
992
+ else:
993
+ raise NotImplementedError("Unsupported condition mode!!!")
994
+
995
+ # token_embed = repeat(self.token_embed_weight, 'c d -> b c d', b=batch_size)
996
+ # gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
997
+ # history_sum = token_embed.gather(1, gathered_ids)
998
+
999
+ # print(pa, seq_len)
1000
+ padding_mask = ~lengths_to_mask(m_lens, seq_len)
1001
+ # print(padding_mask.shape, motion_ids.shape)
1002
+ motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
1003
+ all_indices = [motion_ids]
1004
+ history_sum = 0
1005
+
1006
+ for i in range(1, self.opt.num_quantizers):
1007
+ # print(f"--> Working on {i}-th quantizer")
1008
+ # Start from all tokens being masked
1009
+ # qids = torch.full((batch_size,), i, dtype=torch.long, device=motion_ids.device)
1010
+ token_embed = self.token_embed_weight[i-1]
1011
+ token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
1012
+ gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
1013
+ history_sum += token_embed.gather(1, gathered_ids)
1014
+
1015
+ logits = self.forward_with_cond_scale(history_sum, i, cond_vector, padding_mask, cond_scale=cond_scale)
1016
+ # logits = self.trans_forward(history_sum, qids, cond_vector, padding_mask)
1017
+
1018
+ logits = logits.permute(0, 2, 1) # (b, seqlen, ntoken)
1019
+ # clean low prob token
1020
+ filtered_logits = top_k(logits, topk_filter_thres, dim=-1)
1021
+
1022
+ pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1) # (b, seqlen)
1023
+
1024
+ # probs = F.softmax(filtered_logits, dim=-1) # (b, seqlen, ntoken)
1025
+ # # print(temperature, starting_temperature, steps_until_x0, timesteps)
1026
+ # # print(probs / temperature)
1027
+ # pred_ids = Categorical(probs / temperature).sample() # (b, seqlen)
1028
+
1029
+ ids = torch.where(padding_mask, self.pad_id, pred_ids)
1030
+
1031
+ motion_ids = ids
1032
+ all_indices.append(ids)
1033
+
1034
+ all_indices = torch.stack(all_indices, dim=-1)
1035
+ # padding_mask = repeat(padding_mask, 'b n -> b n q', q=all_indices.shape[-1])
1036
+ # all_indices = torch.where(padding_mask, -1, all_indices)
1037
+ all_indices = torch.where(all_indices==self.pad_id, -1, all_indices)
1038
+ # all_indices = all_indices.masked_fill()
1039
+ return all_indices
models/mask_transformer/transformer_trainer.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from collections import defaultdict
3
+ import torch.optim as optim
4
+ # import tensorflow as tf
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ from collections import OrderedDict
7
+ from utils.utils import *
8
+ from os.path import join as pjoin
9
+ from utils.eval_t2m import evaluation_mask_transformer, evaluation_res_transformer
10
+ from models.mask_transformer.tools import *
11
+
12
+ from einops import rearrange, repeat
13
+
14
+ def def_value():
15
+ return 0.0
16
+
17
+ class MaskTransformerTrainer:
18
+ def __init__(self, args, t2m_transformer, vq_model):
19
+ self.opt = args
20
+ self.t2m_transformer = t2m_transformer
21
+ self.vq_model = vq_model
22
+ self.device = args.device
23
+ self.vq_model.eval()
24
+
25
+ if args.is_train:
26
+ self.logger = SummaryWriter(args.log_dir)
27
+
28
+
29
+ def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
30
+
31
+ current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
32
+ for param_group in self.opt_t2m_transformer.param_groups:
33
+ param_group["lr"] = current_lr
34
+
35
+ return current_lr
36
+
37
+
38
+ def forward(self, batch_data):
39
+
40
+ conds, motion, m_lens = batch_data
41
+ motion = motion.detach().float().to(self.device)
42
+ m_lens = m_lens.detach().long().to(self.device)
43
+
44
+ # (b, n, q)
45
+ code_idx, _ = self.vq_model.encode(motion)
46
+ m_lens = m_lens // 4
47
+
48
+ conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds
49
+
50
+ # loss_dict = {}
51
+ # self.pred_ids = []
52
+ # self.acc = []
53
+
54
+ _loss, _pred_ids, _acc = self.t2m_transformer(code_idx[..., 0], conds, m_lens)
55
+
56
+ return _loss, _acc
57
+
58
+ def update(self, batch_data):
59
+ loss, acc = self.forward(batch_data)
60
+
61
+ self.opt_t2m_transformer.zero_grad()
62
+ loss.backward()
63
+ self.opt_t2m_transformer.step()
64
+ self.scheduler.step()
65
+
66
+ return loss.item(), acc
67
+
68
+ def save(self, file_name, ep, total_it):
69
+ t2m_trans_state_dict = self.t2m_transformer.state_dict()
70
+ clip_weights = [e for e in t2m_trans_state_dict.keys() if e.startswith('clip_model.')]
71
+ for e in clip_weights:
72
+ del t2m_trans_state_dict[e]
73
+ state = {
74
+ 't2m_transformer': t2m_trans_state_dict,
75
+ 'opt_t2m_transformer': self.opt_t2m_transformer.state_dict(),
76
+ 'scheduler':self.scheduler.state_dict(),
77
+ 'ep': ep,
78
+ 'total_it': total_it,
79
+ }
80
+ torch.save(state, file_name)
81
+
82
+ def resume(self, model_dir):
83
+ checkpoint = torch.load(model_dir, map_location=self.device)
84
+ missing_keys, unexpected_keys = self.t2m_transformer.load_state_dict(checkpoint['t2m_transformer'], strict=False)
85
+ assert len(unexpected_keys) == 0
86
+ assert all([k.startswith('clip_model.') for k in missing_keys])
87
+
88
+ try:
89
+ self.opt_t2m_transformer.load_state_dict(checkpoint['opt_t2m_transformer']) # Optimizer
90
+
91
+ self.scheduler.load_state_dict(checkpoint['scheduler']) # Scheduler
92
+ except:
93
+ print('Resume wo optimizer')
94
+ return checkpoint['ep'], checkpoint['total_it']
95
+
96
+ def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval):
97
+ self.t2m_transformer.to(self.device)
98
+ self.vq_model.to(self.device)
99
+
100
+ self.opt_t2m_transformer = optim.AdamW(self.t2m_transformer.parameters(), betas=(0.9, 0.99), lr=self.opt.lr, weight_decay=1e-5)
101
+ self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_t2m_transformer,
102
+ milestones=self.opt.milestones,
103
+ gamma=self.opt.gamma)
104
+
105
+ epoch = 0
106
+ it = 0
107
+
108
+ if self.opt.is_continue:
109
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar') # TODO
110
+ epoch, it = self.resume(model_dir)
111
+ print("Load model epoch:%d iterations:%d"%(epoch, it))
112
+
113
+ start_time = time.time()
114
+ total_iters = self.opt.max_epoch * len(train_loader)
115
+ print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
116
+ print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader)))
117
+ logs = defaultdict(def_value, OrderedDict())
118
+
119
+ best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_mask_transformer(
120
+ self.opt.save_root, eval_val_loader, self.t2m_transformer, self.vq_model, self.logger, epoch,
121
+ best_fid=100, best_div=100,
122
+ best_top1=0, best_top2=0, best_top3=0,
123
+ best_matching=100, eval_wrapper=eval_wrapper,
124
+ plot_func=plot_eval, save_ckpt=False, save_anim=False
125
+ )
126
+ best_acc = 0.
127
+
128
+ while epoch < self.opt.max_epoch:
129
+ self.t2m_transformer.train()
130
+ self.vq_model.eval()
131
+
132
+ for i, batch in enumerate(train_loader):
133
+ it += 1
134
+ if it < self.opt.warm_up_iter:
135
+ self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)
136
+
137
+ loss, acc = self.update(batch_data=batch)
138
+ logs['loss'] += loss
139
+ logs['acc'] += acc
140
+ logs['lr'] += self.opt_t2m_transformer.param_groups[0]['lr']
141
+
142
+ if it % self.opt.log_every == 0:
143
+ mean_loss = OrderedDict()
144
+ # self.logger.add_scalar('val_loss', val_loss, it)
145
+ # self.l
146
+ for tag, value in logs.items():
147
+ self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it)
148
+ mean_loss[tag] = value / self.opt.log_every
149
+ logs = defaultdict(def_value, OrderedDict())
150
+ print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)
151
+
152
+ if it % self.opt.save_latest == 0:
153
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
154
+
155
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
156
+ epoch += 1
157
+
158
+ print('Validation time:')
159
+ self.vq_model.eval()
160
+ self.t2m_transformer.eval()
161
+
162
+ val_loss = []
163
+ val_acc = []
164
+ with torch.no_grad():
165
+ for i, batch_data in enumerate(val_loader):
166
+ loss, acc = self.forward(batch_data)
167
+ val_loss.append(loss.item())
168
+ val_acc.append(acc)
169
+
170
+ print(f"Validation loss:{np.mean(val_loss):.3f}, accuracy:{np.mean(val_acc):.3f}")
171
+
172
+ self.logger.add_scalar('Val/loss', np.mean(val_loss), epoch)
173
+ self.logger.add_scalar('Val/acc', np.mean(val_acc), epoch)
174
+
175
+ if np.mean(val_acc) > best_acc:
176
+ print(f"Improved accuracy from {best_acc:.02f} to {np.mean(val_acc)}!!!")
177
+ self.save(pjoin(self.opt.model_dir, 'net_best_acc.tar'), epoch, it)
178
+ best_acc = np.mean(val_acc)
179
+
180
+ best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_mask_transformer(
181
+ self.opt.save_root, eval_val_loader, self.t2m_transformer, self.vq_model, self.logger, epoch, best_fid=best_fid,
182
+ best_div=best_div, best_top1=best_top1, best_top2=best_top2, best_top3=best_top3,
183
+ best_matching=best_matching, eval_wrapper=eval_wrapper,
184
+ plot_func=plot_eval, save_ckpt=True, save_anim=(epoch%self.opt.eval_every_e==0)
185
+ )
186
+
187
+
188
+ class ResidualTransformerTrainer:
189
+ def __init__(self, args, res_transformer, vq_model):
190
+ self.opt = args
191
+ self.res_transformer = res_transformer
192
+ self.vq_model = vq_model
193
+ self.device = args.device
194
+ self.vq_model.eval()
195
+
196
+ if args.is_train:
197
+ self.logger = SummaryWriter(args.log_dir)
198
+ # self.l1_criterion = torch.nn.SmoothL1Loss()
199
+
200
+
201
+ def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
202
+
203
+ current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
204
+ for param_group in self.opt_res_transformer.param_groups:
205
+ param_group["lr"] = current_lr
206
+
207
+ return current_lr
208
+
209
+
210
+ def forward(self, batch_data):
211
+
212
+ conds, motion, m_lens = batch_data
213
+ motion = motion.detach().float().to(self.device)
214
+ m_lens = m_lens.detach().long().to(self.device)
215
+
216
+ # (b, n, q), (q, b, n ,d)
217
+ code_idx, all_codes = self.vq_model.encode(motion)
218
+ m_lens = m_lens // 4
219
+
220
+ conds = conds.to(self.device).float() if torch.is_tensor(conds) else conds
221
+
222
+ ce_loss, pred_ids, acc = self.res_transformer(code_idx, conds, m_lens)
223
+
224
+ return ce_loss, acc
225
+
226
+ def update(self, batch_data):
227
+ loss, acc = self.forward(batch_data)
228
+
229
+ self.opt_res_transformer.zero_grad()
230
+ loss.backward()
231
+ self.opt_res_transformer.step()
232
+ self.scheduler.step()
233
+
234
+ return loss.item(), acc
235
+
236
+ def save(self, file_name, ep, total_it):
237
+ res_trans_state_dict = self.res_transformer.state_dict()
238
+ clip_weights = [e for e in res_trans_state_dict.keys() if e.startswith('clip_model.')]
239
+ for e in clip_weights:
240
+ del res_trans_state_dict[e]
241
+ state = {
242
+ 'res_transformer': res_trans_state_dict,
243
+ 'opt_res_transformer': self.opt_res_transformer.state_dict(),
244
+ 'scheduler':self.scheduler.state_dict(),
245
+ 'ep': ep,
246
+ 'total_it': total_it,
247
+ }
248
+ torch.save(state, file_name)
249
+
250
+ def resume(self, model_dir):
251
+ checkpoint = torch.load(model_dir, map_location=self.device)
252
+ missing_keys, unexpected_keys = self.res_transformer.load_state_dict(checkpoint['res_transformer'], strict=False)
253
+ assert len(unexpected_keys) == 0
254
+ assert all([k.startswith('clip_model.') for k in missing_keys])
255
+
256
+ try:
257
+ self.opt_res_transformer.load_state_dict(checkpoint['opt_res_transformer']) # Optimizer
258
+
259
+ self.scheduler.load_state_dict(checkpoint['scheduler']) # Scheduler
260
+ except:
261
+ print('Resume wo optimizer')
262
+ return checkpoint['ep'], checkpoint['total_it']
263
+
264
+ def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval):
265
+ self.res_transformer.to(self.device)
266
+ self.vq_model.to(self.device)
267
+
268
+ self.opt_res_transformer = optim.AdamW(self.res_transformer.parameters(), betas=(0.9, 0.99), lr=self.opt.lr, weight_decay=1e-5)
269
+ self.scheduler = optim.lr_scheduler.MultiStepLR(self.opt_res_transformer,
270
+ milestones=self.opt.milestones,
271
+ gamma=self.opt.gamma)
272
+
273
+ epoch = 0
274
+ it = 0
275
+
276
+ if self.opt.is_continue:
277
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar') # TODO
278
+ epoch, it = self.resume(model_dir)
279
+ print("Load model epoch:%d iterations:%d"%(epoch, it))
280
+
281
+ start_time = time.time()
282
+ total_iters = self.opt.max_epoch * len(train_loader)
283
+ print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
284
+ print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(val_loader)))
285
+ logs = defaultdict(def_value, OrderedDict())
286
+
287
+ best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_res_transformer(
288
+ self.opt.save_root, eval_val_loader, self.res_transformer, self.vq_model, self.logger, epoch,
289
+ best_fid=100, best_div=100,
290
+ best_top1=0, best_top2=0, best_top3=0,
291
+ best_matching=100, eval_wrapper=eval_wrapper,
292
+ plot_func=plot_eval, save_ckpt=False, save_anim=False
293
+ )
294
+ best_loss = 100
295
+ best_acc = 0
296
+
297
+ while epoch < self.opt.max_epoch:
298
+ self.res_transformer.train()
299
+ self.vq_model.eval()
300
+
301
+ for i, batch in enumerate(train_loader):
302
+ it += 1
303
+ if it < self.opt.warm_up_iter:
304
+ self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)
305
+
306
+ loss, acc = self.update(batch_data=batch)
307
+ logs['loss'] += loss
308
+ logs["acc"] += acc
309
+ logs['lr'] += self.opt_res_transformer.param_groups[0]['lr']
310
+
311
+ if it % self.opt.log_every == 0:
312
+ mean_loss = OrderedDict()
313
+ # self.logger.add_scalar('val_loss', val_loss, it)
314
+ # self.l
315
+ for tag, value in logs.items():
316
+ self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it)
317
+ mean_loss[tag] = value / self.opt.log_every
318
+ logs = defaultdict(def_value, OrderedDict())
319
+ print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)
320
+
321
+ if it % self.opt.save_latest == 0:
322
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
323
+
324
+ epoch += 1
325
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
326
+
327
+ print('Validation time:')
328
+ self.vq_model.eval()
329
+ self.res_transformer.eval()
330
+
331
+ val_loss = []
332
+ val_acc = []
333
+ with torch.no_grad():
334
+ for i, batch_data in enumerate(val_loader):
335
+ loss, acc = self.forward(batch_data)
336
+ val_loss.append(loss.item())
337
+ val_acc.append(acc)
338
+
339
+ print(f"Validation loss:{np.mean(val_loss):.3f}, Accuracy:{np.mean(val_acc):.3f}")
340
+
341
+ self.logger.add_scalar('Val/loss', np.mean(val_loss), epoch)
342
+ self.logger.add_scalar('Val/acc', np.mean(val_acc), epoch)
343
+
344
+ if np.mean(val_loss) < best_loss:
345
+ print(f"Improved loss from {best_loss:.02f} to {np.mean(val_loss)}!!!")
346
+ self.save(pjoin(self.opt.model_dir, 'net_best_loss.tar'), epoch, it)
347
+ best_loss = np.mean(val_loss)
348
+
349
+ if np.mean(val_acc) > best_acc:
350
+ print(f"Improved acc from {best_acc:.02f} to {np.mean(val_acc)}!!!")
351
+ # self.save(pjoin(self.opt.model_dir, 'net_best_loss.tar'), epoch, it)
352
+ best_acc = np.mean(val_acc)
353
+
354
+ best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_res_transformer(
355
+ self.opt.save_root, eval_val_loader, self.res_transformer, self.vq_model, self.logger, epoch, best_fid=best_fid,
356
+ best_div=best_div, best_top1=best_top1, best_top2=best_top2, best_top3=best_top3,
357
+ best_matching=best_matching, eval_wrapper=eval_wrapper,
358
+ plot_func=plot_eval, save_ckpt=True, save_anim=(epoch%self.opt.eval_every_e==0)
359
+ )
models/t2m_eval_modules.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import time
5
+ import math
6
+ import random
7
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
8
+ # from networks.layers import *
9
+
10
+
11
+ def init_weight(m):
12
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
13
+ nn.init.xavier_normal_(m.weight)
14
+ # m.bias.data.fill_(0.01)
15
+ if m.bias is not None:
16
+ nn.init.constant_(m.bias, 0)
17
+
18
+
19
+ # batch_size, dimension and position
20
+ # output: (batch_size, dim)
21
+ def positional_encoding(batch_size, dim, pos):
22
+ assert batch_size == pos.shape[0]
23
+ positions_enc = np.array([
24
+ [pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)]
25
+ for j in range(batch_size)
26
+ ], dtype=np.float32)
27
+ positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2])
28
+ positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2])
29
+ return torch.from_numpy(positions_enc).float()
30
+
31
+
32
+ def get_padding_mask(batch_size, seq_len, cap_lens):
33
+ cap_lens = cap_lens.data.tolist()
34
+ mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32)
35
+ for i, cap_len in enumerate(cap_lens):
36
+ mask_2d[i, :, :cap_len] = 0
37
+ return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone()
38
+
39
+
40
+ def top_k_logits(logits, k):
41
+ v, ix = torch.topk(logits, k)
42
+ out = logits.clone()
43
+ out[out < v[:, [-1]]] = -float('Inf')
44
+ return out
45
+
46
+
47
+ class PositionalEncoding(nn.Module):
48
+
49
+ def __init__(self, d_model, max_len=300):
50
+ super(PositionalEncoding, self).__init__()
51
+
52
+ pe = torch.zeros(max_len, d_model)
53
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
54
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
55
+ pe[:, 0::2] = torch.sin(position * div_term)
56
+ pe[:, 1::2] = torch.cos(position * div_term)
57
+ # pe = pe.unsqueeze(0).transpose(0, 1)
58
+ self.register_buffer('pe', pe)
59
+
60
+ def forward(self, pos):
61
+ return self.pe[pos]
62
+
63
+
64
+ class MovementConvEncoder(nn.Module):
65
+ def __init__(self, input_size, hidden_size, output_size):
66
+ super(MovementConvEncoder, self).__init__()
67
+ self.main = nn.Sequential(
68
+ nn.Conv1d(input_size, hidden_size, 4, 2, 1),
69
+ nn.Dropout(0.2, inplace=True),
70
+ nn.LeakyReLU(0.2, inplace=True),
71
+ nn.Conv1d(hidden_size, output_size, 4, 2, 1),
72
+ nn.Dropout(0.2, inplace=True),
73
+ nn.LeakyReLU(0.2, inplace=True),
74
+ )
75
+ self.out_net = nn.Linear(output_size, output_size)
76
+ self.main.apply(init_weight)
77
+ self.out_net.apply(init_weight)
78
+
79
+ def forward(self, inputs):
80
+ inputs = inputs.permute(0, 2, 1)
81
+ outputs = self.main(inputs).permute(0, 2, 1)
82
+ # print(outputs.shape)
83
+ return self.out_net(outputs)
84
+
85
+
86
+ class MovementConvDecoder(nn.Module):
87
+ def __init__(self, input_size, hidden_size, output_size):
88
+ super(MovementConvDecoder, self).__init__()
89
+ self.main = nn.Sequential(
90
+ nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
91
+ # nn.Dropout(0.2, inplace=True),
92
+ nn.LeakyReLU(0.2, inplace=True),
93
+ nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
94
+ # nn.Dropout(0.2, inplace=True),
95
+ nn.LeakyReLU(0.2, inplace=True),
96
+ )
97
+ self.out_net = nn.Linear(output_size, output_size)
98
+
99
+ self.main.apply(init_weight)
100
+ self.out_net.apply(init_weight)
101
+
102
+ def forward(self, inputs):
103
+ inputs = inputs.permute(0, 2, 1)
104
+ outputs = self.main(inputs).permute(0, 2, 1)
105
+ return self.out_net(outputs)
106
+
107
+ class TextEncoderBiGRUCo(nn.Module):
108
+ def __init__(self, word_size, pos_size, hidden_size, output_size, device):
109
+ super(TextEncoderBiGRUCo, self).__init__()
110
+ self.device = device
111
+
112
+ self.pos_emb = nn.Linear(pos_size, word_size)
113
+ self.input_emb = nn.Linear(word_size, hidden_size)
114
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
115
+ self.output_net = nn.Sequential(
116
+ nn.Linear(hidden_size * 2, hidden_size),
117
+ nn.LayerNorm(hidden_size),
118
+ nn.LeakyReLU(0.2, inplace=True),
119
+ nn.Linear(hidden_size, output_size)
120
+ )
121
+
122
+ self.input_emb.apply(init_weight)
123
+ self.pos_emb.apply(init_weight)
124
+ self.output_net.apply(init_weight)
125
+ # self.linear2.apply(init_weight)
126
+ # self.batch_size = batch_size
127
+ self.hidden_size = hidden_size
128
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
129
+
130
+ # input(batch_size, seq_len, dim)
131
+ def forward(self, word_embs, pos_onehot, cap_lens):
132
+ num_samples = word_embs.shape[0]
133
+
134
+ pos_embs = self.pos_emb(pos_onehot)
135
+ inputs = word_embs + pos_embs
136
+ input_embs = self.input_emb(inputs)
137
+ hidden = self.hidden.repeat(1, num_samples, 1)
138
+
139
+ cap_lens = cap_lens.data.tolist()
140
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
141
+
142
+ gru_seq, gru_last = self.gru(emb, hidden)
143
+
144
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
145
+
146
+ return self.output_net(gru_last)
147
+
148
+
149
+ class MotionEncoderBiGRUCo(nn.Module):
150
+ def __init__(self, input_size, hidden_size, output_size, device):
151
+ super(MotionEncoderBiGRUCo, self).__init__()
152
+ self.device = device
153
+
154
+ self.input_emb = nn.Linear(input_size, hidden_size)
155
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
156
+ self.output_net = nn.Sequential(
157
+ nn.Linear(hidden_size*2, hidden_size),
158
+ nn.LayerNorm(hidden_size),
159
+ nn.LeakyReLU(0.2, inplace=True),
160
+ nn.Linear(hidden_size, output_size)
161
+ )
162
+
163
+ self.input_emb.apply(init_weight)
164
+ self.output_net.apply(init_weight)
165
+ self.hidden_size = hidden_size
166
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
167
+
168
+ # input(batch_size, seq_len, dim)
169
+ def forward(self, inputs, m_lens):
170
+ num_samples = inputs.shape[0]
171
+
172
+ input_embs = self.input_emb(inputs)
173
+ hidden = self.hidden.repeat(1, num_samples, 1)
174
+
175
+ cap_lens = m_lens.data.tolist()
176
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
177
+
178
+ gru_seq, gru_last = self.gru(emb, hidden)
179
+
180
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
181
+
182
+ return self.output_net(gru_last)
models/t2m_eval_wrapper.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.t2m_eval_modules import *
2
+ from utils.word_vectorizer import POS_enumerator
3
+ from os.path import join as pjoin
4
+
5
+ def build_models(opt):
6
+ movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
7
+ text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
8
+ pos_size=opt.dim_pos_ohot,
9
+ hidden_size=opt.dim_text_hidden,
10
+ output_size=opt.dim_coemb_hidden,
11
+ device=opt.device)
12
+
13
+ motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
14
+ hidden_size=opt.dim_motion_hidden,
15
+ output_size=opt.dim_coemb_hidden,
16
+ device=opt.device)
17
+
18
+ checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
19
+ map_location=opt.device)
20
+ movement_enc.load_state_dict(checkpoint['movement_encoder'])
21
+ text_enc.load_state_dict(checkpoint['text_encoder'])
22
+ motion_enc.load_state_dict(checkpoint['motion_encoder'])
23
+ print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
24
+ return text_enc, motion_enc, movement_enc
25
+
26
+
27
+ class EvaluatorModelWrapper(object):
28
+
29
+ def __init__(self, opt):
30
+
31
+ if opt.dataset_name == 't2m':
32
+ opt.dim_pose = 263
33
+ elif opt.dataset_name == 'kit':
34
+ opt.dim_pose = 251
35
+ else:
36
+ raise KeyError('Dataset not Recognized!!!')
37
+
38
+ opt.dim_word = 300
39
+ opt.max_motion_length = 196
40
+ opt.dim_pos_ohot = len(POS_enumerator)
41
+ opt.dim_motion_hidden = 1024
42
+ opt.max_text_len = 20
43
+ opt.dim_text_hidden = 512
44
+ opt.dim_coemb_hidden = 512
45
+
46
+ # print(opt)
47
+
48
+ self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
49
+ self.opt = opt
50
+ self.device = opt.device
51
+
52
+ self.text_encoder.to(opt.device)
53
+ self.motion_encoder.to(opt.device)
54
+ self.movement_encoder.to(opt.device)
55
+
56
+ self.text_encoder.eval()
57
+ self.motion_encoder.eval()
58
+ self.movement_encoder.eval()
59
+
60
+ # Please note that the results does not follow the order of inputs
61
+ def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
62
+ with torch.no_grad():
63
+ word_embs = word_embs.detach().to(self.device).float()
64
+ pos_ohot = pos_ohot.detach().to(self.device).float()
65
+ motions = motions.detach().to(self.device).float()
66
+
67
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
68
+ motions = motions[align_idx]
69
+ m_lens = m_lens[align_idx]
70
+
71
+ '''Movement Encoding'''
72
+ movements = self.movement_encoder(motions[..., :-4]).detach()
73
+ m_lens = m_lens // self.opt.unit_length
74
+ motion_embedding = self.motion_encoder(movements, m_lens)
75
+
76
+ '''Text Encoding'''
77
+ text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
78
+ text_embedding = text_embedding[align_idx]
79
+ return text_embedding, motion_embedding
80
+
81
+ # Please note that the results does not follow the order of inputs
82
+ def get_motion_embeddings(self, motions, m_lens):
83
+ with torch.no_grad():
84
+ motions = motions.detach().to(self.device).float()
85
+
86
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
87
+ motions = motions[align_idx]
88
+ m_lens = m_lens[align_idx]
89
+
90
+ '''Movement Encoding'''
91
+ movements = self.movement_encoder(motions[..., :-4]).detach()
92
+ m_lens = m_lens // self.opt.unit_length
93
+ motion_embedding = self.motion_encoder(movements, m_lens)
94
+ return motion_embedding
95
+
96
+ ## Borrowed form MDM
97
+ # our version
98
+ def build_evaluators(opt):
99
+ movement_enc = MovementConvEncoder(opt['dim_pose']-4, opt['dim_movement_enc_hidden'], opt['dim_movement_latent'])
100
+ text_enc = TextEncoderBiGRUCo(word_size=opt['dim_word'],
101
+ pos_size=opt['dim_pos_ohot'],
102
+ hidden_size=opt['dim_text_hidden'],
103
+ output_size=opt['dim_coemb_hidden'],
104
+ device=opt['device'])
105
+
106
+ motion_enc = MotionEncoderBiGRUCo(input_size=opt['dim_movement_latent'],
107
+ hidden_size=opt['dim_motion_hidden'],
108
+ output_size=opt['dim_coemb_hidden'],
109
+ device=opt['device'])
110
+
111
+ ckpt_dir = opt['dataset_name']
112
+ if opt['dataset_name'] == 'humanml':
113
+ ckpt_dir = 't2m'
114
+
115
+ checkpoint = torch.load(pjoin(opt['checkpoints_dir'], ckpt_dir, 'text_mot_match', 'model', 'finest.tar'),
116
+ map_location=opt['device'])
117
+ movement_enc.load_state_dict(checkpoint['movement_encoder'])
118
+ text_enc.load_state_dict(checkpoint['text_encoder'])
119
+ motion_enc.load_state_dict(checkpoint['motion_encoder'])
120
+ print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
121
+ return text_enc, motion_enc, movement_enc
122
+
123
+ # our wrapper
124
+ class EvaluatorWrapper(object):
125
+
126
+ def __init__(self, dataset_name, device):
127
+ opt = {
128
+ 'dataset_name': dataset_name,
129
+ 'device': device,
130
+ 'dim_word': 300,
131
+ 'max_motion_length': 196,
132
+ 'dim_pos_ohot': len(POS_enumerator),
133
+ 'dim_motion_hidden': 1024,
134
+ 'max_text_len': 20,
135
+ 'dim_text_hidden': 512,
136
+ 'dim_coemb_hidden': 512,
137
+ 'dim_pose': 263 if dataset_name == 'humanml' else 251,
138
+ 'dim_movement_enc_hidden': 512,
139
+ 'dim_movement_latent': 512,
140
+ 'checkpoints_dir': './checkpoints',
141
+ 'unit_length': 4,
142
+ }
143
+
144
+ self.text_encoder, self.motion_encoder, self.movement_encoder = build_evaluators(opt)
145
+ self.opt = opt
146
+ self.device = opt['device']
147
+
148
+ self.text_encoder.to(opt['device'])
149
+ self.motion_encoder.to(opt['device'])
150
+ self.movement_encoder.to(opt['device'])
151
+
152
+ self.text_encoder.eval()
153
+ self.motion_encoder.eval()
154
+ self.movement_encoder.eval()
155
+
156
+ # Please note that the results does not following the order of inputs
157
+ def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
158
+ with torch.no_grad():
159
+ word_embs = word_embs.detach().to(self.device).float()
160
+ pos_ohot = pos_ohot.detach().to(self.device).float()
161
+ motions = motions.detach().to(self.device).float()
162
+
163
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
164
+ motions = motions[align_idx]
165
+ m_lens = m_lens[align_idx]
166
+
167
+ '''Movement Encoding'''
168
+ movements = self.movement_encoder(motions[..., :-4]).detach()
169
+ m_lens = m_lens // self.opt['unit_length']
170
+ motion_embedding = self.motion_encoder(movements, m_lens)
171
+ # print(motions.shape, movements.shape, motion_embedding.shape, m_lens)
172
+
173
+ '''Text Encoding'''
174
+ text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
175
+ text_embedding = text_embedding[align_idx]
176
+ return text_embedding, motion_embedding
177
+
178
+ # Please note that the results does not following the order of inputs
179
+ def get_motion_embeddings(self, motions, m_lens):
180
+ with torch.no_grad():
181
+ motions = motions.detach().to(self.device).float()
182
+
183
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
184
+ motions = motions[align_idx]
185
+ m_lens = m_lens[align_idx]
186
+
187
+ '''Movement Encoding'''
188
+ movements = self.movement_encoder(motions[..., :-4]).detach()
189
+ m_lens = m_lens // self.opt['unit_length']
190
+ motion_embedding = self.motion_encoder(movements, m_lens)
191
+ return motion_embedding
models/vq/__init__.py ADDED
File without changes
models/vq/encdec.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from models.vq.resnet import Resnet1D
3
+
4
+
5
+ class Encoder(nn.Module):
6
+ def __init__(self,
7
+ input_emb_width=3,
8
+ output_emb_width=512,
9
+ down_t=2,
10
+ stride_t=2,
11
+ width=512,
12
+ depth=3,
13
+ dilation_growth_rate=3,
14
+ activation='relu',
15
+ norm=None):
16
+ super().__init__()
17
+
18
+ blocks = []
19
+ filter_t, pad_t = stride_t * 2, stride_t // 2
20
+ blocks.append(nn.Conv1d(input_emb_width, width, 3, 1, 1))
21
+ blocks.append(nn.ReLU())
22
+
23
+ for i in range(down_t):
24
+ input_dim = width
25
+ block = nn.Sequential(
26
+ nn.Conv1d(input_dim, width, filter_t, stride_t, pad_t),
27
+ Resnet1D(width, depth, dilation_growth_rate, activation=activation, norm=norm),
28
+ )
29
+ blocks.append(block)
30
+ blocks.append(nn.Conv1d(width, output_emb_width, 3, 1, 1))
31
+ self.model = nn.Sequential(*blocks)
32
+
33
+ def forward(self, x):
34
+ return self.model(x)
35
+
36
+
37
+ class Decoder(nn.Module):
38
+ def __init__(self,
39
+ input_emb_width=3,
40
+ output_emb_width=512,
41
+ down_t=2,
42
+ stride_t=2,
43
+ width=512,
44
+ depth=3,
45
+ dilation_growth_rate=3,
46
+ activation='relu',
47
+ norm=None):
48
+ super().__init__()
49
+ blocks = []
50
+
51
+ blocks.append(nn.Conv1d(output_emb_width, width, 3, 1, 1))
52
+ blocks.append(nn.ReLU())
53
+ for i in range(down_t):
54
+ out_dim = width
55
+ block = nn.Sequential(
56
+ Resnet1D(width, depth, dilation_growth_rate, reverse_dilation=True, activation=activation, norm=norm),
57
+ nn.Upsample(scale_factor=2, mode='nearest'),
58
+ nn.Conv1d(width, out_dim, 3, 1, 1)
59
+ )
60
+ blocks.append(block)
61
+ blocks.append(nn.Conv1d(width, width, 3, 1, 1))
62
+ blocks.append(nn.ReLU())
63
+ blocks.append(nn.Conv1d(width, input_emb_width, 3, 1, 1))
64
+ self.model = nn.Sequential(*blocks)
65
+
66
+ def forward(self, x):
67
+ x = self.model(x)
68
+ return x.permute(0, 2, 1)
models/vq/model.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch.nn as nn
4
+ from models.vq.encdec import Encoder, Decoder
5
+ from models.vq.residual_vq import ResidualVQ
6
+
7
+ class RVQVAE(nn.Module):
8
+ def __init__(self,
9
+ args,
10
+ input_width=263,
11
+ nb_code=1024,
12
+ code_dim=512,
13
+ output_emb_width=512,
14
+ down_t=3,
15
+ stride_t=2,
16
+ width=512,
17
+ depth=3,
18
+ dilation_growth_rate=3,
19
+ activation='relu',
20
+ norm=None):
21
+
22
+ super().__init__()
23
+ assert output_emb_width == code_dim
24
+ self.code_dim = code_dim
25
+ self.num_code = nb_code
26
+ # self.quant = args.quantizer
27
+ self.encoder = Encoder(input_width, output_emb_width, down_t, stride_t, width, depth,
28
+ dilation_growth_rate, activation=activation, norm=norm)
29
+ self.decoder = Decoder(input_width, output_emb_width, down_t, stride_t, width, depth,
30
+ dilation_growth_rate, activation=activation, norm=norm)
31
+ rvqvae_config = {
32
+ 'num_quantizers': args.num_quantizers,
33
+ 'shared_codebook': args.shared_codebook,
34
+ 'quantize_dropout_prob': args.quantize_dropout_prob,
35
+ 'quantize_dropout_cutoff_index': 0,
36
+ 'nb_code': nb_code,
37
+ 'code_dim':code_dim,
38
+ 'args': args,
39
+ }
40
+ self.quantizer = ResidualVQ(**rvqvae_config)
41
+
42
+ def preprocess(self, x):
43
+ # (bs, T, Jx3) -> (bs, Jx3, T)
44
+ x = x.permute(0, 2, 1).float()
45
+ return x
46
+
47
+ def postprocess(self, x):
48
+ # (bs, Jx3, T) -> (bs, T, Jx3)
49
+ x = x.permute(0, 2, 1)
50
+ return x
51
+
52
+ def encode(self, x):
53
+ N, T, _ = x.shape
54
+ x_in = self.preprocess(x)
55
+ x_encoder = self.encoder(x_in)
56
+ # print(x_encoder.shape)
57
+ code_idx, all_codes = self.quantizer.quantize(x_encoder, return_latent=True)
58
+ # print(code_idx.shape)
59
+ # code_idx = code_idx.view(N, -1)
60
+ # (N, T, Q)
61
+ # print()
62
+ return code_idx, all_codes
63
+
64
+ def forward(self, x):
65
+ x_in = self.preprocess(x)
66
+ # Encode
67
+ x_encoder = self.encoder(x_in)
68
+
69
+ ## quantization
70
+ # x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5,
71
+ # force_dropout_index=0) #TODO hardcode
72
+ x_quantized, code_idx, commit_loss, perplexity = self.quantizer(x_encoder, sample_codebook_temp=0.5)
73
+
74
+ # print(code_idx[0, :, 1])
75
+ ## decoder
76
+ x_out = self.decoder(x_quantized)
77
+ # x_out = self.postprocess(x_decoder)
78
+ return x_out, commit_loss, perplexity
79
+
80
+ def forward_decoder(self, x):
81
+ x_d = self.quantizer.get_codes_from_indices(x)
82
+ # x_d = x_d.view(1, -1, self.code_dim).permute(0, 2, 1).contiguous()
83
+ x = x_d.sum(dim=0).permute(0, 2, 1)
84
+
85
+ # decoder
86
+ x_out = self.decoder(x)
87
+ # x_out = self.postprocess(x_decoder)
88
+ return x_out
89
+
90
+ class LengthEstimator(nn.Module):
91
+ def __init__(self, input_size, output_size):
92
+ super(LengthEstimator, self).__init__()
93
+ nd = 512
94
+ self.output = nn.Sequential(
95
+ nn.Linear(input_size, nd),
96
+ nn.LayerNorm(nd),
97
+ nn.LeakyReLU(0.2, inplace=True),
98
+
99
+ nn.Dropout(0.2),
100
+ nn.Linear(nd, nd // 2),
101
+ nn.LayerNorm(nd // 2),
102
+ nn.LeakyReLU(0.2, inplace=True),
103
+
104
+ nn.Dropout(0.2),
105
+ nn.Linear(nd // 2, nd // 4),
106
+ nn.LayerNorm(nd // 4),
107
+ nn.LeakyReLU(0.2, inplace=True),
108
+
109
+ nn.Linear(nd // 4, output_size)
110
+ )
111
+
112
+ self.output.apply(self.__init_weights)
113
+
114
+ def __init_weights(self, module):
115
+ if isinstance(module, (nn.Linear, nn.Embedding)):
116
+ module.weight.data.normal_(mean=0.0, std=0.02)
117
+ if isinstance(module, nn.Linear) and module.bias is not None:
118
+ module.bias.data.zero_()
119
+ elif isinstance(module, nn.LayerNorm):
120
+ module.bias.data.zero_()
121
+ module.weight.data.fill_(1.0)
122
+
123
+ def forward(self, text_emb):
124
+ return self.output(text_emb)
models/vq/quantizer.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat, reduce, pack, unpack
6
+
7
+ # from vector_quantize_pytorch import ResidualVQ
8
+
9
+ #Borrow from vector_quantize_pytorch
10
+
11
+ def log(t, eps = 1e-20):
12
+ return torch.log(t.clamp(min = eps))
13
+
14
+ def gumbel_noise(t):
15
+ noise = torch.zeros_like(t).uniform_(0, 1)
16
+ return -log(-log(noise))
17
+
18
+ def gumbel_sample(
19
+ logits,
20
+ temperature = 1.,
21
+ stochastic = False,
22
+ dim = -1,
23
+ training = True
24
+ ):
25
+
26
+ if training and stochastic and temperature > 0:
27
+ sampling_logits = (logits / temperature) + gumbel_noise(logits)
28
+ else:
29
+ sampling_logits = logits
30
+
31
+ ind = sampling_logits.argmax(dim = dim)
32
+
33
+ return ind
34
+
35
+ class QuantizeEMAReset(nn.Module):
36
+ def __init__(self, nb_code, code_dim, args):
37
+ super(QuantizeEMAReset, self).__init__()
38
+ self.nb_code = nb_code
39
+ self.code_dim = code_dim
40
+ self.mu = args.mu ##TO_DO
41
+ self.reset_codebook()
42
+
43
+ def reset_codebook(self):
44
+ self.init = False
45
+ self.code_sum = None
46
+ self.code_count = None
47
+ self.register_buffer('codebook', torch.zeros(self.nb_code, self.code_dim, requires_grad=False).cuda())
48
+
49
+ def _tile(self, x):
50
+ nb_code_x, code_dim = x.shape
51
+ if nb_code_x < self.nb_code:
52
+ n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x
53
+ std = 0.01 / np.sqrt(code_dim)
54
+ out = x.repeat(n_repeats, 1)
55
+ out = out + torch.randn_like(out) * std
56
+ else:
57
+ out = x
58
+ return out
59
+
60
+ def init_codebook(self, x):
61
+ out = self._tile(x)
62
+ self.codebook = out[:self.nb_code]
63
+ self.code_sum = self.codebook.clone()
64
+ self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
65
+ self.init = True
66
+
67
+ def quantize(self, x, sample_codebook_temp=0.):
68
+ # N X C -> C X N
69
+ k_w = self.codebook.t()
70
+ # x: NT X C
71
+ # NT X N
72
+ distance = torch.sum(x ** 2, dim=-1, keepdim=True) - \
73
+ 2 * torch.matmul(x, k_w) + \
74
+ torch.sum(k_w ** 2, dim=0, keepdim=True) # (N * L, b)
75
+
76
+ # code_idx = torch.argmin(distance, dim=-1)
77
+
78
+ code_idx = gumbel_sample(-distance, dim = -1, temperature = sample_codebook_temp, stochastic=True, training = self.training)
79
+
80
+ return code_idx
81
+
82
+ def dequantize(self, code_idx):
83
+ x = F.embedding(code_idx, self.codebook)
84
+ return x
85
+
86
+ def get_codebook_entry(self, indices):
87
+ return self.dequantize(indices).permute(0, 2, 1)
88
+
89
+ @torch.no_grad()
90
+ def compute_perplexity(self, code_idx):
91
+ # Calculate new centres
92
+ code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device) # nb_code, N * L
93
+ code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)
94
+
95
+ code_count = code_onehot.sum(dim=-1) # nb_code
96
+ prob = code_count / torch.sum(code_count)
97
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
98
+ return perplexity
99
+
100
+ @torch.no_grad()
101
+ def update_codebook(self, x, code_idx):
102
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
103
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
104
+
105
+ code_sum = torch.matmul(code_onehot, x) # nb_code, c
106
+ code_count = code_onehot.sum(dim=-1) # nb_code
107
+
108
+ out = self._tile(x)
109
+ code_rand = out[:self.nb_code]
110
+
111
+ # Update centres
112
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
113
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count
114
+
115
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
116
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
117
+ self.codebook = usage * code_update + (1-usage) * code_rand
118
+
119
+
120
+ prob = code_count / torch.sum(code_count)
121
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
122
+
123
+ return perplexity
124
+
125
+ def preprocess(self, x):
126
+ # NCT -> NTC -> [NT, C]
127
+ # x = x.permute(0, 2, 1).contiguous()
128
+ # x = x.view(-1, x.shape[-1])
129
+ x = rearrange(x, 'n c t -> (n t) c')
130
+ return x
131
+
132
+ def forward(self, x, return_idx=False, temperature=0.):
133
+ N, width, T = x.shape
134
+
135
+ x = self.preprocess(x)
136
+ if self.training and not self.init:
137
+ self.init_codebook(x)
138
+
139
+ code_idx = self.quantize(x, temperature)
140
+ x_d = self.dequantize(code_idx)
141
+
142
+ if self.training:
143
+ perplexity = self.update_codebook(x, code_idx)
144
+ else:
145
+ perplexity = self.compute_perplexity(code_idx)
146
+
147
+ commit_loss = F.mse_loss(x, x_d.detach()) # It's right. the t2m-gpt paper is wrong on embed loss and commitment loss.
148
+
149
+ # Passthrough
150
+ x_d = x + (x_d - x).detach()
151
+
152
+ # Postprocess
153
+ x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
154
+ code_idx = code_idx.view(N, T).contiguous()
155
+ # print(code_idx[0])
156
+ if return_idx:
157
+ return x_d, code_idx, commit_loss, perplexity
158
+ return x_d, commit_loss, perplexity
159
+
160
+ class QuantizeEMA(QuantizeEMAReset):
161
+ @torch.no_grad()
162
+ def update_codebook(self, x, code_idx):
163
+ code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
164
+ code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)
165
+
166
+ code_sum = torch.matmul(code_onehot, x) # nb_code, c
167
+ code_count = code_onehot.sum(dim=-1) # nb_code
168
+
169
+ # Update centres
170
+ self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
171
+ self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count
172
+
173
+ usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
174
+ code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
175
+ self.codebook = usage * code_update + (1-usage) * self.codebook
176
+
177
+ prob = code_count / torch.sum(code_count)
178
+ perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
179
+
180
+ return perplexity
models/vq/residual_vq.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from math import ceil
3
+ from functools import partial
4
+ from itertools import zip_longest
5
+ from random import randrange
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ # from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
11
+ from models.vq.quantizer import QuantizeEMAReset, QuantizeEMA
12
+
13
+ from einops import rearrange, repeat, pack, unpack
14
+
15
+ # helper functions
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+ def default(val, d):
21
+ return val if exists(val) else d
22
+
23
+ def round_up_multiple(num, mult):
24
+ return ceil(num / mult) * mult
25
+
26
+ # main class
27
+
28
+ class ResidualVQ(nn.Module):
29
+ """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
30
+ def __init__(
31
+ self,
32
+ num_quantizers,
33
+ shared_codebook=False,
34
+ quantize_dropout_prob=0.5,
35
+ quantize_dropout_cutoff_index=0,
36
+ **kwargs
37
+ ):
38
+ super().__init__()
39
+
40
+ self.num_quantizers = num_quantizers
41
+
42
+ # self.layers = nn.ModuleList([VectorQuantize(accept_image_fmap = accept_image_fmap, **kwargs) for _ in range(num_quantizers)])
43
+ if shared_codebook:
44
+ layer = QuantizeEMAReset(**kwargs)
45
+ self.layers = nn.ModuleList([layer for _ in range(num_quantizers)])
46
+ else:
47
+ self.layers = nn.ModuleList([QuantizeEMAReset(**kwargs) for _ in range(num_quantizers)])
48
+ # self.layers = nn.ModuleList([QuantizeEMA(**kwargs) for _ in range(num_quantizers)])
49
+
50
+ # self.quantize_dropout = quantize_dropout and num_quantizers > 1
51
+
52
+ assert quantize_dropout_cutoff_index >= 0 and quantize_dropout_prob >= 0
53
+
54
+ self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
55
+ self.quantize_dropout_prob = quantize_dropout_prob
56
+
57
+
58
+ @property
59
+ def codebooks(self):
60
+ codebooks = [layer.codebook for layer in self.layers]
61
+ codebooks = torch.stack(codebooks, dim = 0)
62
+ return codebooks # 'q c d'
63
+
64
+ def get_codes_from_indices(self, indices): #indices shape 'b n q' # dequantize
65
+
66
+ batch, quantize_dim = indices.shape[0], indices.shape[-1]
67
+
68
+ # because of quantize dropout, one can pass in indices that are coarse
69
+ # and the network should be able to reconstruct
70
+
71
+ if quantize_dim < self.num_quantizers:
72
+ indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
73
+
74
+ # get ready for gathering
75
+
76
+ codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
77
+ gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])
78
+
79
+ # take care of quantizer dropout
80
+
81
+ mask = gather_indices == -1.
82
+ gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
83
+
84
+ # print(gather_indices.max(), gather_indices.min())
85
+ all_codes = codebooks.gather(2, gather_indices) # gather all codes
86
+
87
+ # mask out any codes that were dropout-ed
88
+
89
+ all_codes = all_codes.masked_fill(mask, 0.)
90
+
91
+ return all_codes # 'q b n d'
92
+
93
+ def get_codebook_entry(self, indices): #indices shape 'b n q'
94
+ all_codes = self.get_codes_from_indices(indices) #'q b n d'
95
+ latent = torch.sum(all_codes, dim=0) #'b n d'
96
+ latent = latent.permute(0, 2, 1)
97
+ return latent
98
+
99
+ def forward(self, x, return_all_codes = False, sample_codebook_temp = None, force_dropout_index=-1):
100
+ # debug check
101
+ # print(self.codebooks[:,0,0].detach().cpu().numpy())
102
+ num_quant, quant_dropout_prob, device = self.num_quantizers, self.quantize_dropout_prob, x.device
103
+
104
+ quantized_out = 0.
105
+ residual = x
106
+
107
+ all_losses = []
108
+ all_indices = []
109
+ all_perplexity = []
110
+
111
+
112
+ should_quantize_dropout = self.training and random.random() < self.quantize_dropout_prob
113
+
114
+ start_drop_quantize_index = num_quant
115
+ # To ensure the first-k layers learn things as much as possible, we randomly dropout the last q - k layers
116
+ if should_quantize_dropout:
117
+ start_drop_quantize_index = randrange(self.quantize_dropout_cutoff_index, num_quant) # keep quant layers <= quantize_dropout_cutoff_index, TODO vary in batch
118
+ null_indices_shape = [x.shape[0], x.shape[-1]] # 'b*n'
119
+ null_indices = torch.full(null_indices_shape, -1., device = device, dtype = torch.long)
120
+ # null_loss = 0.
121
+
122
+ if force_dropout_index >= 0:
123
+ should_quantize_dropout = True
124
+ start_drop_quantize_index = force_dropout_index
125
+ null_indices_shape = [x.shape[0], x.shape[-1]] # 'b*n'
126
+ null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
127
+
128
+ # print(force_dropout_index)
129
+ # go through the layers
130
+
131
+ for quantizer_index, layer in enumerate(self.layers):
132
+
133
+ if should_quantize_dropout and quantizer_index > start_drop_quantize_index:
134
+ all_indices.append(null_indices)
135
+ # all_losses.append(null_loss)
136
+ continue
137
+
138
+ # layer_indices = None
139
+ # if return_loss:
140
+ # layer_indices = indices[..., quantizer_index] #gt indices
141
+
142
+ # quantized, *rest = layer(residual, indices = layer_indices, sample_codebook_temp = sample_codebook_temp) #single quantizer TODO
143
+ quantized, *rest = layer(residual, return_idx=True, temperature=sample_codebook_temp) #single quantizer
144
+
145
+ # print(quantized.shape, residual.shape)
146
+ residual -= quantized.detach()
147
+ quantized_out += quantized
148
+
149
+ embed_indices, loss, perplexity = rest
150
+ all_indices.append(embed_indices)
151
+ all_losses.append(loss)
152
+ all_perplexity.append(perplexity)
153
+
154
+
155
+ # stack all losses and indices
156
+ all_indices = torch.stack(all_indices, dim=-1)
157
+ all_losses = sum(all_losses)/len(all_losses)
158
+ all_perplexity = sum(all_perplexity)/len(all_perplexity)
159
+
160
+ ret = (quantized_out, all_indices, all_losses, all_perplexity)
161
+
162
+ if return_all_codes:
163
+ # whether to return all codes from all codebooks across layers
164
+ all_codes = self.get_codes_from_indices(all_indices)
165
+
166
+ # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
167
+ ret = (*ret, all_codes)
168
+
169
+ return ret
170
+
171
+ def quantize(self, x, return_latent=False):
172
+ all_indices = []
173
+ quantized_out = 0.
174
+ residual = x
175
+ all_codes = []
176
+ for quantizer_index, layer in enumerate(self.layers):
177
+
178
+ quantized, *rest = layer(residual, return_idx=True) #single quantizer
179
+
180
+ residual = residual - quantized.detach()
181
+ quantized_out = quantized_out + quantized
182
+
183
+ embed_indices, loss, perplexity = rest
184
+ all_indices.append(embed_indices)
185
+ # print(quantizer_index, embed_indices[0])
186
+ # print(quantizer_index, quantized[0])
187
+ # break
188
+ all_codes.append(quantized)
189
+
190
+ code_idx = torch.stack(all_indices, dim=-1)
191
+ all_codes = torch.stack(all_codes, dim=0)
192
+ if return_latent:
193
+ return code_idx, all_codes
194
+ return code_idx
models/vq/resnet.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class nonlinearity(nn.Module):
5
+ def __init(self):
6
+ super().__init__()
7
+
8
+ def forward(self, x):
9
+ return x * torch.sigmoid(x)
10
+
11
+
12
+ class ResConv1DBlock(nn.Module):
13
+ def __init__(self, n_in, n_state, dilation=1, activation='silu', norm=None, dropout=0.2):
14
+ super(ResConv1DBlock, self).__init__()
15
+
16
+ padding = dilation
17
+ self.norm = norm
18
+
19
+ if norm == "LN":
20
+ self.norm1 = nn.LayerNorm(n_in)
21
+ self.norm2 = nn.LayerNorm(n_in)
22
+ elif norm == "GN":
23
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
24
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=n_in, eps=1e-6, affine=True)
25
+ elif norm == "BN":
26
+ self.norm1 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
27
+ self.norm2 = nn.BatchNorm1d(num_features=n_in, eps=1e-6, affine=True)
28
+ else:
29
+ self.norm1 = nn.Identity()
30
+ self.norm2 = nn.Identity()
31
+
32
+ if activation == "relu":
33
+ self.activation1 = nn.ReLU()
34
+ self.activation2 = nn.ReLU()
35
+
36
+ elif activation == "silu":
37
+ self.activation1 = nonlinearity()
38
+ self.activation2 = nonlinearity()
39
+
40
+ elif activation == "gelu":
41
+ self.activation1 = nn.GELU()
42
+ self.activation2 = nn.GELU()
43
+
44
+ self.conv1 = nn.Conv1d(n_in, n_state, 3, 1, padding, dilation)
45
+ self.conv2 = nn.Conv1d(n_state, n_in, 1, 1, 0, )
46
+ self.dropout = nn.Dropout(dropout)
47
+
48
+ def forward(self, x):
49
+ x_orig = x
50
+ if self.norm == "LN":
51
+ x = self.norm1(x.transpose(-2, -1))
52
+ x = self.activation1(x.transpose(-2, -1))
53
+ else:
54
+ x = self.norm1(x)
55
+ x = self.activation1(x)
56
+
57
+ x = self.conv1(x)
58
+
59
+ if self.norm == "LN":
60
+ x = self.norm2(x.transpose(-2, -1))
61
+ x = self.activation2(x.transpose(-2, -1))
62
+ else:
63
+ x = self.norm2(x)
64
+ x = self.activation2(x)
65
+
66
+ x = self.conv2(x)
67
+ x = self.dropout(x)
68
+ x = x + x_orig
69
+ return x
70
+
71
+
72
+ class Resnet1D(nn.Module):
73
+ def __init__(self, n_in, n_depth, dilation_growth_rate=1, reverse_dilation=True, activation='relu', norm=None):
74
+ super().__init__()
75
+
76
+ blocks = [ResConv1DBlock(n_in, n_in, dilation=dilation_growth_rate ** depth, activation=activation, norm=norm)
77
+ for depth in range(n_depth)]
78
+ if reverse_dilation:
79
+ blocks = blocks[::-1]
80
+
81
+ self.model = nn.Sequential(*blocks)
82
+
83
+ def forward(self, x):
84
+ return self.model(x)
models/vq/vq_trainer.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from torch.nn.utils import clip_grad_norm_
4
+ from torch.utils.tensorboard import SummaryWriter
5
+ from os.path import join as pjoin
6
+ import torch.nn.functional as F
7
+
8
+ import torch.optim as optim
9
+
10
+ import time
11
+ import numpy as np
12
+ from collections import OrderedDict, defaultdict
13
+ from utils.eval_t2m import evaluation_vqvae, evaluation_res_conv
14
+ from utils.utils import print_current_loss
15
+
16
+ import os
17
+ import sys
18
+
19
+ def def_value():
20
+ return 0.0
21
+
22
+
23
+ class RVQTokenizerTrainer:
24
+ def __init__(self, args, vq_model):
25
+ self.opt = args
26
+ self.vq_model = vq_model
27
+ self.device = args.device
28
+
29
+ if args.is_train:
30
+ self.logger = SummaryWriter(args.log_dir)
31
+ if args.recons_loss == 'l1':
32
+ self.l1_criterion = torch.nn.L1Loss()
33
+ elif args.recons_loss == 'l1_smooth':
34
+ self.l1_criterion = torch.nn.SmoothL1Loss()
35
+
36
+ # self.critic = CriticWrapper(self.opt.dataset_name, self.opt.device)
37
+
38
+ def forward(self, batch_data):
39
+ motions = batch_data.detach().to(self.device).float()
40
+ pred_motion, loss_commit, perplexity = self.vq_model(motions)
41
+
42
+ self.motions = motions
43
+ self.pred_motion = pred_motion
44
+
45
+ loss_rec = self.l1_criterion(pred_motion, motions)
46
+ pred_local_pos = pred_motion[..., 4 : (self.opt.joints_num - 1) * 3 + 4]
47
+ local_pos = motions[..., 4 : (self.opt.joints_num - 1) * 3 + 4]
48
+ loss_explicit = self.l1_criterion(pred_local_pos, local_pos)
49
+
50
+ loss = loss_rec + self.opt.loss_vel * loss_explicit + self.opt.commit * loss_commit
51
+
52
+ # return loss, loss_rec, loss_vel, loss_commit, perplexity
53
+ # return loss, loss_rec, loss_percept, loss_commit, perplexity
54
+ return loss, loss_rec, loss_explicit, loss_commit, perplexity
55
+
56
+
57
+ # @staticmethod
58
+ def update_lr_warm_up(self, nb_iter, warm_up_iter, lr):
59
+
60
+ current_lr = lr * (nb_iter + 1) / (warm_up_iter + 1)
61
+ for param_group in self.opt_vq_model.param_groups:
62
+ param_group["lr"] = current_lr
63
+
64
+ return current_lr
65
+
66
+ def save(self, file_name, ep, total_it):
67
+ state = {
68
+ "vq_model": self.vq_model.state_dict(),
69
+ "opt_vq_model": self.opt_vq_model.state_dict(),
70
+ "scheduler": self.scheduler.state_dict(),
71
+ 'ep': ep,
72
+ 'total_it': total_it,
73
+ }
74
+ torch.save(state, file_name)
75
+
76
+ def resume(self, model_dir):
77
+ checkpoint = torch.load(model_dir, map_location=self.device)
78
+ self.vq_model.load_state_dict(checkpoint['vq_model'])
79
+ self.opt_vq_model.load_state_dict(checkpoint['opt_vq_model'])
80
+ self.scheduler.load_state_dict(checkpoint['scheduler'])
81
+ return checkpoint['ep'], checkpoint['total_it']
82
+
83
+ def train(self, train_loader, val_loader, eval_val_loader, eval_wrapper, plot_eval=None):
84
+ self.vq_model.to(self.device)
85
+
86
+ self.opt_vq_model = optim.AdamW(self.vq_model.parameters(), lr=self.opt.lr, betas=(0.9, 0.99), weight_decay=self.opt.weight_decay)
87
+ self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.opt_vq_model, milestones=self.opt.milestones, gamma=self.opt.gamma)
88
+
89
+ epoch = 0
90
+ it = 0
91
+ if self.opt.is_continue:
92
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar')
93
+ epoch, it = self.resume(model_dir)
94
+ print("Load model epoch:%d iterations:%d"%(epoch, it))
95
+
96
+ start_time = time.time()
97
+ total_iters = self.opt.max_epoch * len(train_loader)
98
+ print(f'Total Epochs: {self.opt.max_epoch}, Total Iters: {total_iters}')
99
+ print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_loader), len(eval_val_loader)))
100
+ # val_loss = 0
101
+ # min_val_loss = np.inf
102
+ # min_val_epoch = epoch
103
+ current_lr = self.opt.lr
104
+ logs = defaultdict(def_value, OrderedDict())
105
+
106
+ # sys.exit()
107
+ best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_vqvae(
108
+ self.opt.model_dir, eval_val_loader, self.vq_model, self.logger, epoch, best_fid=1000,
109
+ best_div=100, best_top1=0,
110
+ best_top2=0, best_top3=0, best_matching=100,
111
+ eval_wrapper=eval_wrapper, save=False)
112
+
113
+ while epoch < self.opt.max_epoch:
114
+ self.vq_model.train()
115
+ for i, batch_data in enumerate(train_loader):
116
+ it += 1
117
+ if it < self.opt.warm_up_iter:
118
+ current_lr = self.update_lr_warm_up(it, self.opt.warm_up_iter, self.opt.lr)
119
+ loss, loss_rec, loss_vel, loss_commit, perplexity = self.forward(batch_data)
120
+ self.opt_vq_model.zero_grad()
121
+ loss.backward()
122
+ self.opt_vq_model.step()
123
+
124
+ if it >= self.opt.warm_up_iter:
125
+ self.scheduler.step()
126
+
127
+ logs['loss'] += loss.item()
128
+ logs['loss_rec'] += loss_rec.item()
129
+ # Note it not necessarily velocity, too lazy to change the name now
130
+ logs['loss_vel'] += loss_vel.item()
131
+ logs['loss_commit'] += loss_commit.item()
132
+ logs['perplexity'] += perplexity.item()
133
+ logs['lr'] += self.opt_vq_model.param_groups[0]['lr']
134
+
135
+ if it % self.opt.log_every == 0:
136
+ mean_loss = OrderedDict()
137
+ # self.logger.add_scalar('val_loss', val_loss, it)
138
+ # self.l
139
+ for tag, value in logs.items():
140
+ self.logger.add_scalar('Train/%s'%tag, value / self.opt.log_every, it)
141
+ mean_loss[tag] = value / self.opt.log_every
142
+ logs = defaultdict(def_value, OrderedDict())
143
+ print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)
144
+
145
+ if it % self.opt.save_latest == 0:
146
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
147
+
148
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
149
+
150
+ epoch += 1
151
+ # if epoch % self.opt.save_every_e == 0:
152
+ # self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, total_it=it)
153
+
154
+ print('Validation time:')
155
+ self.vq_model.eval()
156
+ val_loss_rec = []
157
+ val_loss_vel = []
158
+ val_loss_commit = []
159
+ val_loss = []
160
+ val_perpexity = []
161
+ with torch.no_grad():
162
+ for i, batch_data in enumerate(val_loader):
163
+ loss, loss_rec, loss_vel, loss_commit, perplexity = self.forward(batch_data)
164
+ # val_loss_rec += self.l1_criterion(self.recon_motions, self.motions).item()
165
+ # val_loss_emb += self.embedding_loss.item()
166
+ val_loss.append(loss.item())
167
+ val_loss_rec.append(loss_rec.item())
168
+ val_loss_vel.append(loss_vel.item())
169
+ val_loss_commit.append(loss_commit.item())
170
+ val_perpexity.append(perplexity.item())
171
+
172
+ # val_loss = val_loss_rec / (len(val_dataloader) + 1)
173
+ # val_loss = val_loss / (len(val_dataloader) + 1)
174
+ # val_loss_rec = val_loss_rec / (len(val_dataloader) + 1)
175
+ # val_loss_emb = val_loss_emb / (len(val_dataloader) + 1)
176
+ self.logger.add_scalar('Val/loss', sum(val_loss) / len(val_loss), epoch)
177
+ self.logger.add_scalar('Val/loss_rec', sum(val_loss_rec) / len(val_loss_rec), epoch)
178
+ self.logger.add_scalar('Val/loss_vel', sum(val_loss_vel) / len(val_loss_vel), epoch)
179
+ self.logger.add_scalar('Val/loss_commit', sum(val_loss_commit) / len(val_loss), epoch)
180
+ self.logger.add_scalar('Val/loss_perplexity', sum(val_perpexity) / len(val_loss_rec), epoch)
181
+
182
+ print('Validation Loss: %.5f Reconstruction: %.5f, Velocity: %.5f, Commit: %.5f' %
183
+ (sum(val_loss)/len(val_loss), sum(val_loss_rec)/len(val_loss),
184
+ sum(val_loss_vel)/len(val_loss), sum(val_loss_commit)/len(val_loss)))
185
+
186
+ # if sum(val_loss) / len(val_loss) < min_val_loss:
187
+ # min_val_loss = sum(val_loss) / len(val_loss)
188
+ # # if sum(val_loss_vel) / len(val_loss_vel) < min_val_loss:
189
+ # # min_val_loss = sum(val_loss_vel) / len(val_loss_vel)
190
+ # min_val_epoch = epoch
191
+ # self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
192
+ # print('Best Validation Model So Far!~')
193
+
194
+ best_fid, best_div, best_top1, best_top2, best_top3, best_matching, writer = evaluation_vqvae(
195
+ self.opt.model_dir, eval_val_loader, self.vq_model, self.logger, epoch, best_fid=best_fid,
196
+ best_div=best_div, best_top1=best_top1,
197
+ best_top2=best_top2, best_top3=best_top3, best_matching=best_matching, eval_wrapper=eval_wrapper)
198
+
199
+
200
+ if epoch % self.opt.eval_every_e == 0:
201
+ data = torch.cat([self.motions[:4], self.pred_motion[:4]], dim=0).detach().cpu().numpy()
202
+ # np.save(pjoin(self.opt.eval_dir, 'E%04d.npy' % (epoch)), data)
203
+ save_dir = pjoin(self.opt.eval_dir, 'E%04d' % (epoch))
204
+ os.makedirs(save_dir, exist_ok=True)
205
+ plot_eval(data, save_dir)
206
+ # if plot_eval is not None:
207
+ # save_dir = pjoin(self.opt.eval_dir, 'E%04d' % (epoch))
208
+ # os.makedirs(save_dir, exist_ok=True)
209
+ # plot_eval(data, save_dir)
210
+
211
+ # if epoch - min_val_epoch >= self.opt.early_stop_e:
212
+ # print('Early Stopping!~')
213
+
214
+
215
+ class LengthEstTrainer(object):
216
+
217
+ def __init__(self, args, estimator, text_encoder, encode_fnc):
218
+ self.opt = args
219
+ self.estimator = estimator
220
+ self.text_encoder = text_encoder
221
+ self.encode_fnc = encode_fnc
222
+ self.device = args.device
223
+
224
+ if args.is_train:
225
+ # self.motion_dis
226
+ self.logger = SummaryWriter(args.log_dir)
227
+ self.mul_cls_criterion = torch.nn.CrossEntropyLoss()
228
+
229
+ def resume(self, model_dir):
230
+ checkpoints = torch.load(model_dir, map_location=self.device)
231
+ self.estimator.load_state_dict(checkpoints['estimator'])
232
+ # self.opt_estimator.load_state_dict(checkpoints['opt_estimator'])
233
+ return checkpoints['epoch'], checkpoints['iter']
234
+
235
+ def save(self, model_dir, epoch, niter):
236
+ state = {
237
+ 'estimator': self.estimator.state_dict(),
238
+ # 'opt_estimator': self.opt_estimator.state_dict(),
239
+ 'epoch': epoch,
240
+ 'niter': niter,
241
+ }
242
+ torch.save(state, model_dir)
243
+
244
+ @staticmethod
245
+ def zero_grad(opt_list):
246
+ for opt in opt_list:
247
+ opt.zero_grad()
248
+
249
+ @staticmethod
250
+ def clip_norm(network_list):
251
+ for network in network_list:
252
+ clip_grad_norm_(network.parameters(), 0.5)
253
+
254
+ @staticmethod
255
+ def step(opt_list):
256
+ for opt in opt_list:
257
+ opt.step()
258
+
259
+ def train(self, train_dataloader, val_dataloader):
260
+ self.estimator.to(self.device)
261
+ self.text_encoder.to(self.device)
262
+
263
+ self.opt_estimator = optim.Adam(self.estimator.parameters(), lr=self.opt.lr)
264
+
265
+ epoch = 0
266
+ it = 0
267
+
268
+ if self.opt.is_continue:
269
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar')
270
+ epoch, it = self.resume(model_dir)
271
+
272
+ start_time = time.time()
273
+ total_iters = self.opt.max_epoch * len(train_dataloader)
274
+ print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))
275
+ val_loss = 0
276
+ min_val_loss = np.inf
277
+ logs = defaultdict(float)
278
+ while epoch < self.opt.max_epoch:
279
+ # time0 = time.time()
280
+ for i, batch_data in enumerate(train_dataloader):
281
+ self.estimator.train()
282
+
283
+ conds, _, m_lens = batch_data
284
+ # word_emb = word_emb.detach().to(self.device).float()
285
+ # pos_ohot = pos_ohot.detach().to(self.device).float()
286
+ # m_lens = m_lens.to(self.device).long()
287
+ text_embs = self.encode_fnc(self.text_encoder, conds, self.opt.device).detach()
288
+ # print(text_embs.shape, text_embs.device)
289
+
290
+ pred_dis = self.estimator(text_embs)
291
+
292
+ self.zero_grad([self.opt_estimator])
293
+
294
+ gt_labels = m_lens // self.opt.unit_length
295
+ gt_labels = gt_labels.long().to(self.device)
296
+ # print(gt_labels.shape, pred_dis.shape)
297
+ # print(gt_labels.max(), gt_labels.min())
298
+ # print(pred_dis)
299
+ acc = (gt_labels == pred_dis.argmax(dim=-1)).sum() / len(gt_labels)
300
+ loss = self.mul_cls_criterion(pred_dis, gt_labels)
301
+
302
+ loss.backward()
303
+
304
+ self.clip_norm([self.estimator])
305
+ self.step([self.opt_estimator])
306
+
307
+ logs['loss'] += loss.item()
308
+ logs['acc'] += acc.item()
309
+
310
+ it += 1
311
+ if it % self.opt.log_every == 0:
312
+ mean_loss = OrderedDict({'val_loss': val_loss})
313
+ # self.logger.add_scalar('Val/loss', val_loss, it)
314
+
315
+ for tag, value in logs.items():
316
+ self.logger.add_scalar("Train/%s"%tag, value / self.opt.log_every, it)
317
+ mean_loss[tag] = value / self.opt.log_every
318
+ logs = defaultdict(float)
319
+ print_current_loss(start_time, it, total_iters, mean_loss, epoch=epoch, inner_iter=i)
320
+
321
+ if it % self.opt.save_latest == 0:
322
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
323
+
324
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
325
+
326
+ epoch += 1
327
+
328
+ print('Validation time:')
329
+
330
+ val_loss = 0
331
+ val_acc = 0
332
+ # self.estimator.eval()
333
+ with torch.no_grad():
334
+ for i, batch_data in enumerate(val_dataloader):
335
+ self.estimator.eval()
336
+
337
+ conds, _, m_lens = batch_data
338
+ # word_emb = word_emb.detach().to(self.device).float()
339
+ # pos_ohot = pos_ohot.detach().to(self.device).float()
340
+ # m_lens = m_lens.to(self.device).long()
341
+ text_embs = self.encode_fnc(self.text_encoder, conds, self.opt.device)
342
+ pred_dis = self.estimator(text_embs)
343
+
344
+ gt_labels = m_lens // self.opt.unit_length
345
+ gt_labels = gt_labels.long().to(self.device)
346
+ loss = self.mul_cls_criterion(pred_dis, gt_labels)
347
+ acc = (gt_labels == pred_dis.argmax(dim=-1)).sum() / len(gt_labels)
348
+
349
+ val_loss += loss.item()
350
+ val_acc += acc.item()
351
+
352
+
353
+ val_loss = val_loss / len(val_dataloader)
354
+ val_acc = val_acc / len(val_dataloader)
355
+ print('Validation Loss: %.5f Validation Acc: %.5f' % (val_loss, val_acc))
356
+
357
+ if val_loss < min_val_loss:
358
+ self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
359
+ min_val_loss = val_loss
motion_loaders/__init__.py ADDED
File without changes
motion_loaders/dataset_motion_loader.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from data.t2m_dataset import Text2MotionDatasetEval, collate_fn # TODO
2
+ from utils.word_vectorizer import WordVectorizer
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ from torch.utils.data import DataLoader
6
+ from utils.get_opt import get_opt
7
+
8
+ def get_dataset_motion_loader(opt_path, batch_size, fname, device):
9
+ opt = get_opt(opt_path, device)
10
+
11
+ # Configurations of T2M dataset and KIT dataset is almost the same
12
+ if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
13
+ print('Loading dataset %s ...' % opt.dataset_name)
14
+
15
+ mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
16
+ std = np.load(pjoin(opt.meta_dir, 'std.npy'))
17
+
18
+ w_vectorizer = WordVectorizer('./glove', 'our_vab')
19
+ split_file = pjoin(opt.data_root, '%s.txt'%fname)
20
+ dataset = Text2MotionDatasetEval(opt, mean, std, split_file, w_vectorizer)
21
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True,
22
+ collate_fn=collate_fn, shuffle=True)
23
+ else:
24
+ raise KeyError('Dataset not Recognized !!')
25
+
26
+ print('Ground Truth Dataset Loading Completed!!!')
27
+ return dataloader, dataset
options/__init__.py ADDED
File without changes
options/base_option.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+
5
+ class BaseOptions():
6
+ def __init__(self):
7
+ self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
8
+ self.initialized = False
9
+
10
+ def initialize(self):
11
+ self.parser.add_argument('--name', type=str, default="t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns", help='Name of this trial')
12
+
13
+ self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
14
+
15
+ self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
16
+ self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
17
+ self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')
18
+
19
+ self.parser.add_argument('--latent_dim', type=int, default=384, help='Dimension of transformer latent.')
20
+ self.parser.add_argument('--n_heads', type=int, default=6, help='Number of heads.')
21
+ self.parser.add_argument('--n_layers', type=int, default=8, help='Number of attention layers.')
22
+ self.parser.add_argument('--ff_size', type=int, default=1024, help='FF_Size')
23
+ self.parser.add_argument('--dropout', type=float, default=0.2, help='Dropout ratio in transformer')
24
+
25
+ self.parser.add_argument("--max_motion_length", type=int, default=196, help="Max length of motion")
26
+ self.parser.add_argument("--unit_length", type=int, default=4, help="Downscale ratio of VQ")
27
+
28
+ self.parser.add_argument('--force_mask', action="store_true", help='True: mask out conditions')
29
+
30
+ self.initialized = True
31
+
32
+ def parse(self):
33
+ if not self.initialized:
34
+ self.initialize()
35
+
36
+ self.opt = self.parser.parse_args()
37
+
38
+ self.opt.is_train = self.is_train
39
+
40
+ if self.opt.gpu_id != -1:
41
+ # self.opt.gpu_id = int(self.opt.gpu_id)
42
+ torch.cuda.set_device(self.opt.gpu_id)
43
+
44
+ args = vars(self.opt)
45
+
46
+ print('------------ Options -------------')
47
+ for k, v in sorted(args.items()):
48
+ print('%s: %s' % (str(k), str(v)))
49
+ print('-------------- End ----------------')
50
+ if self.is_train:
51
+ # save to the disk
52
+ expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name)
53
+ if not os.path.exists(expr_dir):
54
+ os.makedirs(expr_dir)
55
+ file_name = os.path.join(expr_dir, 'opt.txt')
56
+ with open(file_name, 'wt') as opt_file:
57
+ opt_file.write('------------ Options -------------\n')
58
+ for k, v in sorted(args.items()):
59
+ opt_file.write('%s: %s\n' % (str(k), str(v)))
60
+ opt_file.write('-------------- End ----------------\n')
61
+ return self.opt
options/eval_option.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from options.base_option import BaseOptions
2
+
3
+ class EvalT2MOptions(BaseOptions):
4
+ def initialize(self):
5
+ BaseOptions.initialize(self)
6
+ self.parser.add_argument('--which_epoch', type=str, default="latest", help='Checkpoint you want to use, {latest, net_best_fid, etc}')
7
+ self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
8
+
9
+ self.parser.add_argument('--ext', type=str, default='text2motion', help='Extension of the result file or folder')
10
+ self.parser.add_argument("--num_batch", default=2, type=int,
11
+ help="Number of batch for generation")
12
+ self.parser.add_argument("--repeat_times", default=1, type=int,
13
+ help="Number of repetitions, per sample text prompt")
14
+ self.parser.add_argument("--cond_scale", default=4, type=float,
15
+ help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
16
+ self.parser.add_argument("--temperature", default=1., type=float,
17
+ help="Sampling Temperature.")
18
+ self.parser.add_argument("--topkr", default=0.9, type=float,
19
+ help="Filter out percentil low prop entries.")
20
+ self.parser.add_argument("--time_steps", default=18, type=int,
21
+ help="Mask Generate steps.")
22
+ self.parser.add_argument("--seed", default=10107, type=int)
23
+
24
+ self.parser.add_argument('--gumbel_sample', action="store_true", help='True: gumbel sampling, False: categorical sampling.')
25
+ self.parser.add_argument('--use_res_model', action="store_true", help='Whether to use residual transformer.')
26
+ # self.parser.add_argument('--est_length', action="store_true", help='Training iterations')
27
+
28
+ self.parser.add_argument('--res_name', type=str, default='tres_nlayer8_ld384_ff1024_rvq6ns_cdp0.2_sw', help='Model name of residual transformer')
29
+ self.parser.add_argument('--text_path', type=str, default="", help='Text prompt file')
30
+
31
+
32
+ self.parser.add_argument('-msec', '--mask_edit_section', nargs='*', type=str, help='Indicate sections for editing, use comma to separate the start and end of a section'
33
+ 'type int will specify the token frame, type float will specify the ratio of seq_len')
34
+ self.parser.add_argument('--text_prompt', default='', type=str, help="A text prompt to be generated. If empty, will take text prompts from dataset.")
35
+ self.parser.add_argument('--source_motion', default='example_data/000612.npy', type=str, help="Source motion path for editing. (new_joint_vecs format .npy file)")
36
+ self.parser.add_argument("--motion_length", default=0, type=int,
37
+ help="Motion length for generation, only applicable with single text prompt.")
38
+ self.is_train = False
options/train_option.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from options.base_option import BaseOptions
2
+ import argparse
3
+
4
+ class TrainT2MOptions(BaseOptions):
5
+ def initialize(self):
6
+ BaseOptions.initialize(self)
7
+ self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
8
+ self.parser.add_argument('--max_epoch', type=int, default=500, help='Maximum number of epoch for training')
9
+ # self.parser.add_argument('--max_iters', type=int, default=150_000, help='Training iterations')
10
+
11
+ '''LR scheduler'''
12
+ self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate')
13
+ self.parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate schedule factor')
14
+ self.parser.add_argument('--milestones', default=[50_000], nargs="+", type=int,
15
+ help="learning rate schedule (iterations)")
16
+ self.parser.add_argument('--warm_up_iter', default=2000, type=int, help='number of total iterations for warmup')
17
+
18
+ '''Condition'''
19
+ self.parser.add_argument('--cond_drop_prob', type=float, default=0.1, help='Drop ratio of condition, for classifier-free guidance')
20
+ self.parser.add_argument("--seed", default=3407, type=int, help="Seed")
21
+
22
+ self.parser.add_argument('--is_continue', action="store_true", help='Is this trial continuing previous state?')
23
+ self.parser.add_argument('--gumbel_sample', action="store_true", help='Strategy for token sampling, True: Gumbel sampling, False: Categorical sampling')
24
+ self.parser.add_argument('--share_weight', action="store_true", help='Whether to share weight for projection/embedding, for residual transformer.')
25
+
26
+ self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress, (iteration)')
27
+ # self.parser.add_argument('--save_every_e', type=int, default=100, help='Frequency of printing training progress')
28
+ self.parser.add_argument('--eval_every_e', type=int, default=10, help='Frequency of animating eval results, (epoch)')
29
+ self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of saving checkpoint, (iteration)')
30
+
31
+
32
+ self.is_train = True
33
+
34
+
35
+ class TrainLenEstOptions():
36
+ def __init__(self):
37
+ self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
38
+ self.parser.add_argument('--name', type=str, default="test", help='Name of this trial')
39
+ self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
40
+
41
+ self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name')
42
+ self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
43
+
44
+ self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
45
+
46
+ self.parser.add_argument("--unit_length", type=int, default=4, help="Length of motion")
47
+ self.parser.add_argument("--max_text_len", type=int, default=20, help="Length of motion")
48
+
49
+ self.parser.add_argument('--max_epoch', type=int, default=300, help='Training iterations')
50
+
51
+ self.parser.add_argument('--lr', type=float, default=1e-4, help='Layers of GRU')
52
+
53
+ self.parser.add_argument('--is_continue', action="store_true", help='Training iterations')
54
+
55
+ self.parser.add_argument('--log_every', type=int, default=50, help='Frequency of printing training progress')
56
+ self.parser.add_argument('--save_every_e', type=int, default=5, help='Frequency of printing training progress')
57
+ self.parser.add_argument('--eval_every_e', type=int, default=3, help='Frequency of printing training progress')
58
+ self.parser.add_argument('--save_latest', type=int, default=500, help='Frequency of printing training progress')
59
+
60
+ def parse(self):
61
+ self.opt = self.parser.parse_args()
62
+ self.opt.is_train = True
63
+ # args = vars(self.opt)
64
+ return self.opt
options/vq_option.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import torch
4
+
5
+ def arg_parse(is_train=False):
6
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7
+
8
+ ## dataloader
9
+ parser.add_argument('--dataset_name', type=str, default='humanml3d', help='dataset directory')
10
+ parser.add_argument('--batch_size', default=256, type=int, help='batch size')
11
+ parser.add_argument('--window_size', type=int, default=64, help='training motion length')
12
+ parser.add_argument("--gpu_id", type=int, default=0, help='GPU id')
13
+
14
+ ## optimization
15
+ parser.add_argument('--max_epoch', default=50, type=int, help='number of total epochs to run')
16
+ # parser.add_argument('--total_iter', default=None, type=int, help='number of total iterations to run')
17
+ parser.add_argument('--warm_up_iter', default=2000, type=int, help='number of total iterations for warmup')
18
+ parser.add_argument('--lr', default=2e-4, type=float, help='max learning rate')
19
+ parser.add_argument('--milestones', default=[150000, 250000], nargs="+", type=int, help="learning rate schedule (iterations)")
20
+ parser.add_argument('--gamma', default=0.1, type=float, help="learning rate decay")
21
+
22
+ parser.add_argument('--weight_decay', default=0.0, type=float, help='weight decay')
23
+ parser.add_argument("--commit", type=float, default=0.02, help="hyper-parameter for the commitment loss")
24
+ parser.add_argument('--loss_vel', type=float, default=0.5, help='hyper-parameter for the velocity loss')
25
+ parser.add_argument('--recons_loss', type=str, default='l1_smooth', help='reconstruction loss')
26
+
27
+ ## vqvae arch
28
+ parser.add_argument("--code_dim", type=int, default=512, help="embedding dimension")
29
+ parser.add_argument("--nb_code", type=int, default=512, help="nb of embedding")
30
+ parser.add_argument("--mu", type=float, default=0.99, help="exponential moving average to update the codebook")
31
+ parser.add_argument("--down_t", type=int, default=2, help="downsampling rate")
32
+ parser.add_argument("--stride_t", type=int, default=2, help="stride size")
33
+ parser.add_argument("--width", type=int, default=512, help="width of the network")
34
+ parser.add_argument("--depth", type=int, default=3, help="num of resblocks for each res")
35
+ parser.add_argument("--dilation_growth_rate", type=int, default=3, help="dilation growth rate")
36
+ parser.add_argument("--output_emb_width", type=int, default=512, help="output embedding width")
37
+ parser.add_argument('--vq_act', type=str, default='relu', choices=['relu', 'silu', 'gelu'],
38
+ help='dataset directory')
39
+ parser.add_argument('--vq_norm', type=str, default=None, help='dataset directory')
40
+
41
+ parser.add_argument('--num_quantizers', type=int, default=3, help='num_quantizers')
42
+ parser.add_argument('--shared_codebook', action="store_true")
43
+ parser.add_argument('--quantize_dropout_prob', type=float, default=0.2, help='quantize_dropout_prob')
44
+ # parser.add_argument('--use_vq_prob', type=float, default=0.8, help='quantize_dropout_prob')
45
+
46
+ parser.add_argument('--ext', type=str, default='default', help='reconstruction loss')
47
+
48
+
49
+ ## other
50
+ parser.add_argument('--name', type=str, default="test", help='Name of this trial')
51
+ parser.add_argument('--is_continue', action="store_true", help='Name of this trial')
52
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
53
+ parser.add_argument('--log_every', default=10, type=int, help='iter log frequency')
54
+ parser.add_argument('--save_latest', default=500, type=int, help='iter save latest model frequency')
55
+ parser.add_argument('--save_every_e', default=2, type=int, help='save model every n epoch')
56
+ parser.add_argument('--eval_every_e', default=1, type=int, help='save eval results every n epoch')
57
+ # parser.add_argument('--early_stop_e', default=5, type=int, help='early stopping epoch')
58
+ parser.add_argument('--feat_bias', type=float, default=5, help='Layers of GRU')
59
+
60
+ parser.add_argument('--which_epoch', type=str, default="all", help='Name of this trial')
61
+
62
+ ## For Res Predictor only
63
+ parser.add_argument('--vq_name', type=str, default="rvq_nq6_dc512_nc512_noshare_qdp0.2", help='Name of this trial')
64
+ parser.add_argument('--n_res', type=int, default=2, help='Name of this trial')
65
+ parser.add_argument('--do_vq_res', action="store_true")
66
+ parser.add_argument("--seed", default=3407, type=int)
67
+
68
+ opt = parser.parse_args()
69
+ torch.cuda.set_device(opt.gpu_id)
70
+
71
+ args = vars(opt)
72
+
73
+ print('------------ Options -------------')
74
+ for k, v in sorted(args.items()):
75
+ print('%s: %s' % (str(k), str(v)))
76
+ print('-------------- End ----------------')
77
+ opt.is_train = is_train
78
+ if is_train:
79
+ # save to the disk
80
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.dataset_name, opt.name)
81
+ if not os.path.exists(expr_dir):
82
+ os.makedirs(expr_dir)
83
+ file_name = os.path.join(expr_dir, 'opt.txt')
84
+ with open(file_name, 'wt') as opt_file:
85
+ opt_file.write('------------ Options -------------\n')
86
+ for k, v in sorted(args.items()):
87
+ opt_file.write('%s: %s\n' % (str(k), str(v)))
88
+ opt_file.write('-------------- End ----------------\n')
89
+ return opt
prepare/.DS_Store ADDED
Binary file (6.15 kB). View file
 
prepare/download_evaluator.sh ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cd checkpoints
2
+
3
+ cd t2m
4
+ echo -e "Downloading evaluation models for HumanML3D dataset"
5
+ gdown --fuzzy https://drive.google.com/file/d/1oLhSH7zTlYkQdUWPv3-v4opigB7pXkFk/view?usp=sharing
6
+ echo -e "Unzipping humanml3d_evaluator.zip"
7
+ unzip humanml3d_evaluator.zip
8
+
9
+ echo -e "Clearning humanml3d_evaluator.zip"
10
+ rm humanml3d_evaluator.zip
11
+
12
+ cd ../kit/
13
+ echo -e "Downloading pretrained models for KIT-ML dataset"
14
+ gdown --fuzzy https://drive.google.com/file/d/115n1ijntyKDDIZZEuA_aBgffyplNE5az/view?usp=sharing
15
+
16
+ echo -e "Unzipping kit_evaluator.zip"
17
+ unzip kit_evaluator.zip
18
+
19
+ echo -e "Clearning kit_evaluator.zip"
20
+ rm kit_evaluator.zip
21
+
22
+ cd ../../
23
+
24
+ echo -e "Downloading done!"
prepare/download_glove.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ echo -e "Downloading glove (in use by the evaluators, not by MoMask itself)"
2
+ gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing
3
+ rm -rf glove
4
+
5
+ unzip glove.zip
6
+ echo -e "Cleaning\n"
7
+ rm glove.zip
8
+
9
+ echo -e "Downloading done!"
prepare/download_models.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ rm -rf checkpoints
2
+ mkdir checkpoints
3
+ cd checkpoints
4
+ mkdir t2m
5
+
6
+ cd t2m
7
+ echo -e "Downloading pretrained models for HumanML3D dataset"
8
+ gdown --fuzzy https://drive.google.com/file/d/1dtKP2xBk-UjG9o16MVfBJDmGNSI56Dch/view?usp=sharing
9
+
10
+ echo -e "Unzipping humanml3d_models.zip"
11
+ unzip humanml3d_models.zip
12
+
13
+ echo -e "Cleaning humanml3d_models.zip"
14
+ rm humanml3d_models.zip
15
+
16
+ cd ../
17
+ mkdir kit
18
+ cd kit
19
+
20
+ echo -e "Downloading pretrained models for KIT-ML dataset"
21
+ gdown --fuzzy https://drive.google.com/file/d/1MNMdUdn5QoO8UW1iwTcZ0QNaLSH4A6G9/view?usp=sharing
22
+
23
+ echo -e "Unzipping kit_models.zip"
24
+ unzip kit_models.zip
25
+
26
+ echo -e "Cleaning kit_models.zip"
27
+ rm kit_models.zip
28
+
29
+ cd ../../
30
+
31
+ echo -e "Downloading done!"
prepare/download_models_demo.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ rm -rf checkpoints
2
+ mkdir checkpoints
3
+ cd checkpoints
4
+ mkdir t2m
5
+ cd t2m
6
+ echo -e "Downloading pretrained models for HumanML3D dataset"
7
+ gdown --fuzzy https://drive.google.com/file/d/1dtKP2xBk-UjG9o16MVfBJDmGNSI56Dch/view?usp=sharing
8
+ unzip humanml3d_models.zip
9
+ rm humanml3d_models.zip
10
+ cd ../../
requirements.txt ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py @ file:///home/conda/feedstock_root/build_artifacts/absl-py_1673535674859/work
2
+ aiofiles==23.2.1
3
+ aiohttp @ file:///croot/aiohttp_1670009560265/work
4
+ aiosignal @ file:///tmp/build/80754af9/aiosignal_1637843061372/work
5
+ altair==5.0.1
6
+ anyio==3.7.1
7
+ async-timeout @ file:///opt/conda/conda-bld/async-timeout_1664876359750/work
8
+ asynctest==0.13.0
9
+ attrs @ file:///croot/attrs_1668696182826/work
10
+ beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1649463573192/work
11
+ blinker==1.4
12
+ blis==0.7.8
13
+ blobfile==2.0.2
14
+ brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854164153/work
15
+ cachetools==5.3.1
16
+ catalogue @ file:///home/conda/feedstock_root/build_artifacts/catalogue_1661366519934/work
17
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
18
+ cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work
19
+ charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1661170624537/work
20
+ chumpy==0.70
21
+ click==8.1.3
22
+ clip @ git+https://github.com/openai/CLIP.git@a9b1bf5920416aaeaec965c25dd9e8f98c864f16
23
+ colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1655412516417/work
24
+ confection==0.0.2
25
+ cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1636040646098/work
26
+ cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
27
+ cymem @ file:///home/conda/feedstock_root/build_artifacts/cymem_1649412169067/work
28
+ dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
29
+ einops==0.6.1
30
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl
31
+ exceptiongroup==1.2.0
32
+ fastapi==0.103.2
33
+ ffmpy==0.3.1
34
+ filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1660129891014/work
35
+ frozenlist @ file:///croot/frozenlist_1670004507010/work
36
+ fsspec==2023.1.0
37
+ ftfy==6.1.1
38
+ gdown==4.7.1
39
+ google-auth==2.19.1
40
+ google-auth-oauthlib==0.4.6
41
+ gradio==3.34.0
42
+ gradio_client==0.2.6
43
+ grpcio==1.54.2
44
+ h11==0.14.0
45
+ h5py @ file:///tmp/abs_4aewd3wzey/croots/recipe/h5py_1659091371897/work
46
+ httpcore==0.17.3
47
+ httpx==0.24.1
48
+ huggingface-hub==0.16.4
49
+ idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
50
+ importlib-metadata==5.0.0
51
+ importlib-resources==5.12.0
52
+ Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
53
+ joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work
54
+ jsonschema==4.17.3
55
+ kiwisolver @ file:///opt/conda/conda-bld/kiwisolver_1653292039266/work
56
+ langcodes @ file:///home/conda/feedstock_root/build_artifacts/langcodes_1636741340529/work
57
+ linkify-it-py==2.0.2
58
+ loralib==0.1.1
59
+ lxml==4.9.1
60
+ Markdown @ file:///home/conda/feedstock_root/build_artifacts/markdown_1679584000376/work
61
+ markdown-it-py==2.2.0
62
+ MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1648737551960/work
63
+ matplotlib==3.1.3
64
+ mdit-py-plugins==0.3.3
65
+ mdurl==0.1.2
66
+ mkl-fft==1.3.1
67
+ mkl-random @ file:///tmp/build/80754af9/mkl_random_1626179032232/work
68
+ mkl-service==2.4.0
69
+ multidict @ file:///croot/multidict_1665674239670/work
70
+ murmurhash==1.0.8
71
+ numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1653915516269/work
72
+ oauthlib==3.2.2
73
+ orjson==3.9.7
74
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1637239678211/work
75
+ pandas==1.3.5
76
+ pathy @ file:///home/conda/feedstock_root/build_artifacts/pathy_1656568808184/work
77
+ Pillow==9.2.0
78
+ pkgutil_resolve_name==1.3.10
79
+ preshed==3.0.7
80
+ protobuf==3.20.3
81
+ pyasn1==0.5.0
82
+ pyasn1-modules==0.3.0
83
+ pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
84
+ pycryptodomex==3.15.0
85
+ pydantic @ file:///home/conda/feedstock_root/build_artifacts/pydantic_1636021129189/work
86
+ pydub==0.25.1
87
+ Pygments==2.17.2
88
+ PyJWT @ file:///opt/conda/conda-bld/pyjwt_1657544592787/work
89
+ pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1663846997386/work
90
+ pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
91
+ pyrsistent==0.19.3
92
+ PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1648857264451/work
93
+ python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
94
+ python-multipart==0.0.6
95
+ pytz==2023.3
96
+ PyYAML==6.0
97
+ regex==2022.9.13
98
+ requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1661872987712/work
99
+ requests-oauthlib==1.3.1
100
+ rsa==4.9
101
+ scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1642601761909/work
102
+ scipy @ file:///opt/conda/conda-bld/scipy_1661390393401/work
103
+ semantic-version==2.10.0
104
+ shellingham @ file:///home/conda/feedstock_root/build_artifacts/shellingham_1659638615822/work
105
+ six @ file:///tmp/build/80754af9/six_1644875935023/work
106
+ smart-open @ file:///home/conda/feedstock_root/build_artifacts/smart_open_1630238320325/work
107
+ smplx==0.1.28
108
+ sniffio==1.3.0
109
+ soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1658207591808/work
110
+ spacy @ file:///opt/conda/conda-bld/spacy_1656601313568/work
111
+ spacy-legacy @ file:///home/conda/feedstock_root/build_artifacts/spacy-legacy_1660748275723/work
112
+ spacy-loggers @ file:///home/conda/feedstock_root/build_artifacts/spacy-loggers_1661365735520/work
113
+ srsly==2.4.4
114
+ starlette==0.27.0
115
+ tensorboard==2.11.2
116
+ tensorboard-data-server==0.6.1
117
+ tensorboard-plugin-wit @ file:///home/builder/tkoch/workspace/tensorflow/tensorboard-plugin-wit_1658918494740/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
118
+ tensorboardX==2.6
119
+ thinc==8.0.17
120
+ threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work
121
+ toolz==0.12.0
122
+ torch==1.7.1
123
+ torch-tb-profiler==0.4.1
124
+ torchaudio==0.7.0a0+a853dff
125
+ torchvision==0.8.2
126
+ tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work
127
+ tqdm @ file:///opt/conda/conda-bld/tqdm_1664392687731/work
128
+ trimesh @ file:///home/conda/feedstock_root/build_artifacts/trimesh_1664841281434/work
129
+ typer @ file:///home/conda/feedstock_root/build_artifacts/typer_1657029164904/work
130
+ typing_extensions==4.7.1
131
+ uc-micro-py==1.0.2
132
+ urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1678635778344/work
133
+ uvicorn==0.22.0
134
+ vector-quantize-pytorch==1.6.30
135
+ wasabi @ file:///home/conda/feedstock_root/build_artifacts/wasabi_1668249950899/work
136
+ wcwidth==0.2.5
137
+ websockets==11.0.3
138
+ Werkzeug @ file:///home/conda/feedstock_root/build_artifacts/werkzeug_1676411946679/work
139
+ yarl @ file:///opt/conda/conda-bld/yarl_1661437085904/work
140
+ zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1659400682470/work
train_res_transformer.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+
5
+ from torch.utils.data import DataLoader
6
+ from os.path import join as pjoin
7
+
8
+ from models.mask_transformer.transformer import ResidualTransformer
9
+ from models.mask_transformer.transformer_trainer import ResidualTransformerTrainer
10
+ from models.vq.model import RVQVAE
11
+
12
+ from options.train_option import TrainT2MOptions
13
+
14
+ from utils.plot_script import plot_3d_motion
15
+ from utils.motion_process import recover_from_ric
16
+ from utils.get_opt import get_opt
17
+ from utils.fixseed import fixseed
18
+ from utils.paramUtil import t2m_kinematic_chain, kit_kinematic_chain
19
+
20
+ from data.t2m_dataset import Text2MotionDataset
21
+ from motion_loaders.dataset_motion_loader import get_dataset_motion_loader
22
+ from models.t2m_eval_wrapper import EvaluatorModelWrapper
23
+
24
+
25
+ def plot_t2m(data, save_dir, captions, m_lengths):
26
+ data = train_dataset.inv_transform(data)
27
+
28
+ # print(ep_curves.shape)
29
+ for i, (caption, joint_data) in enumerate(zip(captions, data)):
30
+ joint_data = joint_data[:m_lengths[i]]
31
+ joint = recover_from_ric(torch.from_numpy(joint_data).float(), opt.joints_num).numpy()
32
+ save_path = pjoin(save_dir, '%02d.mp4'%i)
33
+ # print(joint.shape)
34
+ plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
35
+
36
+ def load_vq_model():
37
+ opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
38
+ vq_opt = get_opt(opt_path, opt.device)
39
+ vq_model = RVQVAE(vq_opt,
40
+ dim_pose,
41
+ vq_opt.nb_code,
42
+ vq_opt.code_dim,
43
+ vq_opt.output_emb_width,
44
+ vq_opt.down_t,
45
+ vq_opt.stride_t,
46
+ vq_opt.width,
47
+ vq_opt.depth,
48
+ vq_opt.dilation_growth_rate,
49
+ vq_opt.vq_act,
50
+ vq_opt.vq_norm)
51
+ ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
52
+ map_location=opt.device)
53
+ model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
54
+ vq_model.load_state_dict(ckpt[model_key])
55
+ print(f'Loading VQ Model {opt.vq_name}')
56
+ vq_model.to(opt.device)
57
+ return vq_model, vq_opt
58
+
59
+ if __name__ == '__main__':
60
+ parser = TrainT2MOptions()
61
+ opt = parser.parse()
62
+ fixseed(opt.seed)
63
+
64
+ opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
65
+ torch.autograd.set_detect_anomaly(True)
66
+
67
+ opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
68
+ opt.model_dir = pjoin(opt.save_root, 'model')
69
+ # opt.meta_dir = pjoin(opt.save_root, 'meta')
70
+ opt.eval_dir = pjoin(opt.save_root, 'animation')
71
+ opt.log_dir = pjoin('./log/res/', opt.dataset_name, opt.name)
72
+
73
+ os.makedirs(opt.model_dir, exist_ok=True)
74
+ # os.makedirs(opt.meta_dir, exist_ok=True)
75
+ os.makedirs(opt.eval_dir, exist_ok=True)
76
+ os.makedirs(opt.log_dir, exist_ok=True)
77
+
78
+ if opt.dataset_name == 't2m':
79
+ opt.data_root = './dataset/HumanML3D'
80
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
81
+ opt.joints_num = 22
82
+ opt.max_motion_len = 55
83
+ dim_pose = 263
84
+ radius = 4
85
+ fps = 20
86
+ kinematic_chain = t2m_kinematic_chain
87
+ dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt'
88
+
89
+ elif opt.dataset_name == 'kit': #TODO
90
+ opt.data_root = './dataset/KIT-ML'
91
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
92
+ opt.joints_num = 21
93
+ radius = 240 * 8
94
+ fps = 12.5
95
+ dim_pose = 251
96
+ opt.max_motion_len = 55
97
+ kinematic_chain = kit_kinematic_chain
98
+ dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt'
99
+
100
+ else:
101
+ raise KeyError('Dataset Does Not Exist')
102
+
103
+ opt.text_dir = pjoin(opt.data_root, 'texts')
104
+
105
+ vq_model, vq_opt = load_vq_model()
106
+
107
+ clip_version = 'ViT-B/32'
108
+
109
+ opt.num_tokens = vq_opt.nb_code
110
+ opt.num_quantizers = vq_opt.num_quantizers
111
+
112
+ # if opt.is_v2:
113
+ res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
114
+ cond_mode='text',
115
+ latent_dim=opt.latent_dim,
116
+ ff_size=opt.ff_size,
117
+ num_layers=opt.n_layers,
118
+ num_heads=opt.n_heads,
119
+ dropout=opt.dropout,
120
+ clip_dim=512,
121
+ shared_codebook=vq_opt.shared_codebook,
122
+ cond_drop_prob=opt.cond_drop_prob,
123
+ # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,
124
+ share_weight=opt.share_weight,
125
+ clip_version=clip_version,
126
+ opt=opt)
127
+ # else:
128
+ # res_transformer = ResidualTransformer(code_dim=vq_opt.code_dim,
129
+ # cond_mode='text',
130
+ # latent_dim=opt.latent_dim,
131
+ # ff_size=opt.ff_size,
132
+ # num_layers=opt.n_layers,
133
+ # num_heads=opt.n_heads,
134
+ # dropout=opt.dropout,
135
+ # clip_dim=512,
136
+ # shared_codebook=vq_opt.shared_codebook,
137
+ # cond_drop_prob=opt.cond_drop_prob,
138
+ # # codebook=vq_model.quantizer.codebooks[0] if opt.fix_token_emb else None,
139
+ # clip_version=clip_version,
140
+ # opt=opt)
141
+
142
+
143
+ all_params = 0
144
+ pc_transformer = sum(param.numel() for param in res_transformer.parameters_wo_clip())
145
+
146
+ print(res_transformer)
147
+ # print("Total parameters of t2m_transformer net: {:.2f}M".format(pc_transformer / 1000_000))
148
+ all_params += pc_transformer
149
+
150
+ print('Total parameters of all models: {:.2f}M'.format(all_params / 1000_000))
151
+
152
+ mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'mean.npy'))
153
+ std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'std.npy'))
154
+
155
+ train_split_file = pjoin(opt.data_root, 'train.txt')
156
+ val_split_file = pjoin(opt.data_root, 'val.txt')
157
+
158
+ train_dataset = Text2MotionDataset(opt, mean, std, train_split_file)
159
+ val_dataset = Text2MotionDataset(opt, mean, std, val_split_file)
160
+
161
+ train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True)
162
+ val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True)
163
+
164
+ eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'val', device=opt.device)
165
+
166
+ wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
167
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
168
+
169
+ trainer = ResidualTransformerTrainer(opt, res_transformer, vq_model)
170
+
171
+ trainer.train(train_loader, val_loader, eval_val_loader, eval_wrapper=eval_wrapper, plot_eval=plot_t2m)
train_t2m_transformer.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+
5
+ from torch.utils.data import DataLoader
6
+ from os.path import join as pjoin
7
+
8
+ from models.mask_transformer.transformer import MaskTransformer
9
+ from models.mask_transformer.transformer_trainer import MaskTransformerTrainer
10
+ from models.vq.model import RVQVAE
11
+
12
+ from options.train_option import TrainT2MOptions
13
+
14
+ from utils.plot_script import plot_3d_motion
15
+ from utils.motion_process import recover_from_ric
16
+ from utils.get_opt import get_opt
17
+ from utils.fixseed import fixseed
18
+ from utils.paramUtil import t2m_kinematic_chain, kit_kinematic_chain
19
+
20
+ from data.t2m_dataset import Text2MotionDataset
21
+ from motion_loaders.dataset_motion_loader import get_dataset_motion_loader
22
+ from models.t2m_eval_wrapper import EvaluatorModelWrapper
23
+
24
+
25
+ def plot_t2m(data, save_dir, captions, m_lengths):
26
+ data = train_dataset.inv_transform(data)
27
+
28
+ # print(ep_curves.shape)
29
+ for i, (caption, joint_data) in enumerate(zip(captions, data)):
30
+ joint_data = joint_data[:m_lengths[i]]
31
+ joint = recover_from_ric(torch.from_numpy(joint_data).float(), opt.joints_num).numpy()
32
+ save_path = pjoin(save_dir, '%02d.mp4'%i)
33
+ # print(joint.shape)
34
+ plot_3d_motion(save_path, kinematic_chain, joint, title=caption, fps=20)
35
+
36
+ def load_vq_model():
37
+ opt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'opt.txt')
38
+ vq_opt = get_opt(opt_path, opt.device)
39
+ vq_model = RVQVAE(vq_opt,
40
+ dim_pose,
41
+ vq_opt.nb_code,
42
+ vq_opt.code_dim,
43
+ vq_opt.output_emb_width,
44
+ vq_opt.down_t,
45
+ vq_opt.stride_t,
46
+ vq_opt.width,
47
+ vq_opt.depth,
48
+ vq_opt.dilation_growth_rate,
49
+ vq_opt.vq_act,
50
+ vq_opt.vq_norm)
51
+ ckpt = torch.load(pjoin(vq_opt.checkpoints_dir, vq_opt.dataset_name, vq_opt.name, 'model', 'net_best_fid.tar'),
52
+ map_location='cpu')
53
+ model_key = 'vq_model' if 'vq_model' in ckpt else 'net'
54
+ vq_model.load_state_dict(ckpt[model_key])
55
+ print(f'Loading VQ Model {opt.vq_name}')
56
+ return vq_model, vq_opt
57
+
58
+ if __name__ == '__main__':
59
+ parser = TrainT2MOptions()
60
+ opt = parser.parse()
61
+ fixseed(opt.seed)
62
+
63
+ opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:" + str(opt.gpu_id))
64
+ torch.autograd.set_detect_anomaly(True)
65
+
66
+ opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
67
+ opt.model_dir = pjoin(opt.save_root, 'model')
68
+ # opt.meta_dir = pjoin(opt.save_root, 'meta')
69
+ opt.eval_dir = pjoin(opt.save_root, 'animation')
70
+ opt.log_dir = pjoin('./log/t2m/', opt.dataset_name, opt.name)
71
+
72
+ os.makedirs(opt.model_dir, exist_ok=True)
73
+ # os.makedirs(opt.meta_dir, exist_ok=True)
74
+ os.makedirs(opt.eval_dir, exist_ok=True)
75
+ os.makedirs(opt.log_dir, exist_ok=True)
76
+
77
+ if opt.dataset_name == 't2m':
78
+ opt.data_root = './dataset/HumanML3D'
79
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
80
+ opt.joints_num = 22
81
+ opt.max_motion_len = 55
82
+ dim_pose = 263
83
+ radius = 4
84
+ fps = 20
85
+ kinematic_chain = t2m_kinematic_chain
86
+ dataset_opt_path = './checkpoints/t2m/Comp_v6_KLD005/opt.txt'
87
+
88
+ elif opt.dataset_name == 'kit': #TODO
89
+ opt.data_root = './dataset/KIT-ML'
90
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
91
+ opt.joints_num = 21
92
+ radius = 240 * 8
93
+ fps = 12.5
94
+ dim_pose = 251
95
+ opt.max_motion_len = 55
96
+ kinematic_chain = kit_kinematic_chain
97
+ dataset_opt_path = './checkpoints/kit/Comp_v6_KLD005/opt.txt'
98
+
99
+ else:
100
+ raise KeyError('Dataset Does Not Exist')
101
+
102
+ opt.text_dir = pjoin(opt.data_root, 'texts')
103
+
104
+ vq_model, vq_opt = load_vq_model()
105
+
106
+ clip_version = 'ViT-B/32'
107
+
108
+ opt.num_tokens = vq_opt.nb_code
109
+
110
+ t2m_transformer = MaskTransformer(code_dim=vq_opt.code_dim,
111
+ cond_mode='text',
112
+ latent_dim=opt.latent_dim,
113
+ ff_size=opt.ff_size,
114
+ num_layers=opt.n_layers,
115
+ num_heads=opt.n_heads,
116
+ dropout=opt.dropout,
117
+ clip_dim=512,
118
+ cond_drop_prob=opt.cond_drop_prob,
119
+ clip_version=clip_version,
120
+ opt=opt)
121
+
122
+ # if opt.fix_token_emb:
123
+ # t2m_transformer.load_and_freeze_token_emb(vq_model.quantizer.codebooks[0])
124
+
125
+ all_params = 0
126
+ pc_transformer = sum(param.numel() for param in t2m_transformer.parameters_wo_clip())
127
+
128
+ # print(t2m_transformer)
129
+ # print("Total parameters of t2m_transformer net: {:.2f}M".format(pc_transformer / 1000_000))
130
+ all_params += pc_transformer
131
+
132
+ print('Total parameters of all models: {:.2f}M'.format(all_params / 1000_000))
133
+
134
+ mean = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'mean.npy'))
135
+ std = np.load(pjoin(opt.checkpoints_dir, opt.dataset_name, opt.vq_name, 'meta', 'std.npy'))
136
+
137
+ train_split_file = pjoin(opt.data_root, 'train.txt')
138
+ val_split_file = pjoin(opt.data_root, 'val.txt')
139
+
140
+ train_dataset = Text2MotionDataset(opt, mean, std, train_split_file)
141
+ val_dataset = Text2MotionDataset(opt, mean, std, val_split_file)
142
+
143
+ train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True)
144
+ val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, num_workers=4, shuffle=True, drop_last=True)
145
+
146
+ eval_val_loader, _ = get_dataset_motion_loader(dataset_opt_path, 32, 'val', device=opt.device)
147
+
148
+ wrapper_opt = get_opt(dataset_opt_path, torch.device('cuda'))
149
+ eval_wrapper = EvaluatorModelWrapper(wrapper_opt)
150
+
151
+ trainer = MaskTransformerTrainer(opt, t2m_transformer, vq_model)
152
+
153
+ trainer.train(train_loader, val_loader, eval_val_loader, eval_wrapper=eval_wrapper, plot_eval=plot_t2m)