meow committed on
Commit
d6d3a5b
1 Parent(s): f9fd2fa
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +21 -0
  2. README.md +462 -12
  3. app.py +52 -0
  4. cog.yaml +38 -0
  5. common/.gitignore +1 -0
  6. common/___init___.py +0 -0
  7. common/abstract_pl.py +180 -0
  8. common/args_utils.py +15 -0
  9. common/body_models.py +146 -0
  10. common/camera.py +474 -0
  11. common/comet_utils.py +158 -0
  12. common/data_utils.py +371 -0
  13. common/ld_utils.py +116 -0
  14. common/list_utils.py +52 -0
  15. common/mesh.py +94 -0
  16. common/metrics.py +51 -0
  17. common/np_utils.py +7 -0
  18. common/object_tensors.py +293 -0
  19. common/pl_utils.py +63 -0
  20. common/rend_utils.py +139 -0
  21. common/rot.py +782 -0
  22. common/sys_utils.py +44 -0
  23. common/thing.py +66 -0
  24. common/torch_utils.py +212 -0
  25. common/transforms.py +356 -0
  26. common/viewer.py +287 -0
  27. common/vis_utils.py +129 -0
  28. common/xdict.py +288 -0
  29. data_loaders/.DS_Store +0 -0
  30. data_loaders/__pycache__/get_data.cpython-38.pyc +0 -0
  31. data_loaders/__pycache__/tensors.cpython-38.pyc +0 -0
  32. data_loaders/get_data.py +178 -0
  33. data_loaders/humanml/.DS_Store +0 -0
  34. data_loaders/humanml/README.md +1 -0
  35. data_loaders/humanml/common/__pycache__/quaternion.cpython-38.pyc +0 -0
  36. data_loaders/humanml/common/__pycache__/skeleton.cpython-38.pyc +0 -0
  37. data_loaders/humanml/common/quaternion.py +423 -0
  38. data_loaders/humanml/common/skeleton.py +199 -0
  39. data_loaders/humanml/data/__init__.py +0 -0
  40. data_loaders/humanml/data/__pycache__/__init__.cpython-38.pyc +0 -0
  41. data_loaders/humanml/data/__pycache__/dataset.cpython-38.pyc +0 -0
  42. data_loaders/humanml/data/__pycache__/dataset_ours.cpython-38.pyc +0 -0
  43. data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-38.pyc +0 -0
  44. data_loaders/humanml/data/__pycache__/utils.cpython-38.pyc +0 -0
  45. data_loaders/humanml/data/dataset.py +795 -0
  46. data_loaders/humanml/data/dataset_ours.py +0 -0
  47. data_loaders/humanml/data/dataset_ours_single_seq.py +0 -0
  48. data_loaders/humanml/data/utils.py +507 -0
  49. data_loaders/humanml/motion_loaders/__init__.py +0 -0
  50. data_loaders/humanml/motion_loaders/__pycache__/__init__.cpython-38.pyc +0 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Guy Tevet
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,462 @@
- ---
- title: Gene Hoi Denoising
- emoji: 📉
- colorFrom: yellow
- colorTo: yellow
- sdk: gradio
- sdk_version: 4.17.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # MDM: Human Motion Diffusion Model
+
+
+ data in what format and data in this format
+
+
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/human-motion-diffusion-model/motion-synthesis-on-humanact12)](https://paperswithcode.com/sota/motion-synthesis-on-humanact12?p=human-motion-diffusion-model)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/human-motion-diffusion-model/motion-synthesis-on-humanml3d)](https://paperswithcode.com/sota/motion-synthesis-on-humanml3d?p=human-motion-diffusion-model)
+ [![arXiv](https://img.shields.io/badge/arXiv-<2209.14916>-<COLOR>.svg)](https://arxiv.org/abs/2209.14916)
+
+ <a href="https://replicate.com/arielreplicate/motion_diffusion_model"><img src="https://replicate.com/arielreplicate/motion_diffusion_model/badge"></a>
+
+ The official PyTorch implementation of the paper [**"Human Motion Diffusion Model"**](https://arxiv.org/abs/2209.14916).
+
+ Please visit our [**webpage**](https://guytevet.github.io/mdm-page/) for more details.
+
+ ![teaser](https://github.com/GuyTevet/mdm-page/raw/main/static/figures/github.gif)
+
+ #### Bibtex
+ If you find this code useful in your research, please cite:
+
+ ```
+ @article{tevet2022human,
+   title={Human Motion Diffusion Model},
+   author={Tevet, Guy and Raab, Sigal and Gordon, Brian and Shafir, Yonatan and Bermano, Amit H and Cohen-Or, Daniel},
+   journal={arXiv preprint arXiv:2209.14916},
+   year={2022}
+ }
+ ```
+
+ ## News
+
+ 📢 **23/Nov/22** - Fixed evaluation issue (#42) - Please pull and run `bash prepare/download_t2m_evaluators.sh` from the top of the repo to adapt.
+
+ 📢 **4/Nov/22** - Added sampling, training and evaluation of unconstrained tasks.
+ Note slight env changes to support the new code. If you already have an installed environment, run `bash prepare/download_unconstrained_assets.sh; conda install -y -c anaconda scikit-learn` to adapt.
+
+ 📢 **3/Nov/22** - Added in-between and upper-body editing.
+
+ 📢 **31/Oct/22** - Added sampling, training and evaluation of action-to-motion tasks.
+
+ 📢 **9/Oct/22** - Added training and evaluation scripts.
+ Note slight env changes to support the new code. If you already have an installed environment, run `bash prepare/download_glove.sh; pip install clearml` to adapt.
+
+ 📢 **6/Oct/22** - First release - sampling and rendering using pre-trained models.
+
+
+ ## Getting started
+
+ This code was tested on `Ubuntu 18.04.5 LTS` and requires:
+
+ * Python 3.7
+ * conda3 or miniconda3
+ * CUDA capable GPU (one is enough)
+
+ ### 1. Setup environment
+
+ Install ffmpeg (if not already installed):
+
+ ```shell
+ sudo apt update
+ sudo apt install ffmpeg
+ ```
+ For Windows, use [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) instead.
+
+ Set up the conda env:
+ ```shell
+ conda env create -f environment.yml
+ conda activate mdm
+ python -m spacy download en_core_web_sm
+ pip install git+https://github.com/openai/CLIP.git
+ ```
+
+ Download dependencies:
+
+ <details>
+ <summary><b>Text to Motion</b></summary>
+
+ ```bash
+ bash prepare/download_smpl_files.sh
+ bash prepare/download_glove.sh
+ bash prepare/download_t2m_evaluators.sh
+ ```
+ </details>
+
+ <details>
+ <summary><b>Action to Motion</b></summary>
+
+ ```bash
+ bash prepare/download_smpl_files.sh
+ bash prepare/download_recognition_models.sh
+ ```
+ </details>
+
+ <details>
+ <summary><b>Unconstrained</b></summary>
+
+ ```bash
+ bash prepare/download_smpl_files.sh
+ bash prepare/download_recognition_models.sh
+ bash prepare/download_recognition_unconstrained_models.sh
+ ```
+ </details>
+
+ ### 2. Get data
+
+ <details>
+ <summary><b>Text to Motion</b></summary>
+
+ There are two paths to get the data:
+
+ (a) **Go the easy way** if you just want to generate text-to-motion (excluding editing, which does require motion capture data).
+
+ (b) **Get full data** to train and evaluate the model.
+
+
+ #### a. The easy way (text only)
+
+ **HumanML3D** - Clone HumanML3D, then copy the data dir to our repository:
+
+ ```shell
+ cd ..
+ git clone https://github.com/EricGuo5513/HumanML3D.git
+ unzip ./HumanML3D/HumanML3D/texts.zip -d ./HumanML3D/HumanML3D/
+ cp -r HumanML3D/HumanML3D motion-diffusion-model/dataset/HumanML3D
+ cd motion-diffusion-model
+ ```
+
+
+ #### b. Full data (text + motion capture)
+
+ **HumanML3D** - Follow the instructions in [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git),
+ then copy the resulting dataset to our repository:
+
+ ```shell
+ cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
+ ```
+
+ **KIT** - Download from [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git) (no processing needed this time) and place the result in `./dataset/KIT-ML`.
+ </details>
+
+ <details>
+ <summary><b>Action to Motion</b></summary>
+
+ **UESTC, HumanAct12**
+ ```bash
+ bash prepare/download_a2m_datasets.sh
+ ```
+ </details>
+
+ <details>
+ <summary><b>Unconstrained</b></summary>
+
+ **HumanAct12**
+ ```bash
+ bash prepare/download_unconstrained_datasets.sh
+ ```
+ </details>
+
+ ### 3. Download the pretrained models
+
+ Download the model(s) you wish to use, then unzip and place them in `./save/`.
+
+ <details>
+ <summary><b>Text to Motion</b></summary>
+
+ **You need only the first one.**
+
+ **HumanML3D**
+
+ [humanml-encoder-512](https://drive.google.com/file/d/1PE0PK8e5a5j-7-Xhs5YET5U5pGh0c821/view?usp=sharing) (best model)
+
+ [humanml-decoder-512](https://drive.google.com/file/d/1q3soLadvVh7kJuJPd2cegMNY2xVuVudj/view?usp=sharing)
+
+ [humanml-decoder-with-emb-512](https://drive.google.com/file/d/1GnsW0K3UjuOkNkAWmjrGIUmeDDZrmPE5/view?usp=sharing)
+
+ **KIT**
+
+ [kit-encoder-512](https://drive.google.com/file/d/1SHCRcE0es31vkJMLGf9dyLe7YsWj7pNL/view?usp=sharing)
+
+ </details>
+
+ <details>
+ <summary><b>Action to Motion</b></summary>
+
+ **UESTC**
+
+ [uestc](https://drive.google.com/file/d/1goB2DJK4B-fLu2QmqGWKAqWGMTAO6wQ6/view?usp=sharing)
+
+ [uestc_no_fc](https://drive.google.com/file/d/1fpv3mR-qP9CYCsi9CrQhFqlLavcSQky6/view?usp=sharing)
+
+ **HumanAct12**
+
+ [humanact12](https://drive.google.com/file/d/154X8_Lgpec6Xj0glEGql7FVKqPYCdBFO/view?usp=sharing)
+
+ [humanact12_no_fc](https://drive.google.com/file/d/1frKVMBYNiN5Mlq7zsnhDBzs9vGJvFeiQ/view?usp=sharing)
+
+ </details>
+
+ <details>
+ <summary><b>Unconstrained</b></summary>
+
+ **HumanAct12**
+
+ [humanact12_unconstrained](https://drive.google.com/file/d/1uG68m200pZK3pD-zTmPXu5XkgNpx_mEx/view?usp=share_link)
+
+ </details>
+
+
+ ## Example Usage
+
+
+ Example usage and results on the TACO dataset:
+
+
+ | Input | Result | Overlaid |
+ | :----------------------: | :---------------------: | :-----------------------: |
+ | ![](assets/taco-20231104_017-src-a.gif) | ![](assets/taco-20231104_017-res-a.gif) | ![](assets/taco-20231104_017-overlayed-a.gif) |
+
+
+ Follow the steps below to reproduce the above result.
+
+ 1. **Denoising**
+    ```bash
+    bash scripts/val_examples/predict_taco_rndseed_spatial_20231104_017.sh
+    ```
+    Ten random seeds will be used for prediction. The predicted results will be saved in the folder `./data/taco/result`.
+ 2. **Mesh reconstruction**
+    ```bash
+    bash scripts/val_examples/reconstruct_taco_20231104_017.sh
+    ```
+    Results will be saved in the same folder as in the previous step.
+ 3. **Extracting results and visualization**
+
+
+
+ <details>
+ <summary><b>Text to Motion</b></summary>
+
+ ### Generate from test set prompts
+
+ ```shell
+ python -m sample.generate --model_path ./save/humanml_trans_enc_512/model000200000.pt --num_samples 10 --num_repetitions 3
+ ```
+
+ ### Generate from your text file
+
+ ```shell
+ python -m sample.generate --model_path ./save/humanml_trans_enc_512/model000200000.pt --input_text ./assets/example_text_prompts.txt
+ ```
+
+ ### Generate a single prompt
+
+ ```shell
+ python -m sample.generate --model_path ./save/humanml_trans_enc_512/model000200000.pt --text_prompt "the person walked forward and is picking up his toolbox."
+ ```
+ </details>
+
+ <details>
+ <summary><b>Action to Motion</b></summary>
+
+ ### Generate from test set actions
+
+ ```shell
+ python -m sample.generate --model_path ./save/humanact12/model000350000.pt --num_samples 10 --num_repetitions 3
+ ```
+
+ ### Generate from your actions file
+
+ ```shell
+ python -m sample.generate --model_path ./save/humanact12/model000350000.pt --action_file ./assets/example_action_names_humanact12.txt
+ ```
+
+ ### Generate a single action
+
+ ```shell
+ python -m sample.generate --model_path ./save/humanact12/model000350000.pt --text_prompt "drink"
+ ```
+ </details>
+
+ <details>
+ <summary><b>Unconstrained</b></summary>
+
+ ```shell
+ python -m sample.generate --model_path ./save/unconstrained/model000450000.pt --num_samples 10 --num_repetitions 3
+ ```
+
+ In total, (num_samples * num_repetitions) samples are created and visually organized in a grid of num_samples rows and num_repetitions columns (the command above yields 10 * 3 = 30 motions).
+
+ </details>
+
+ **You may also define:**
+ * `--device` id.
+ * `--seed` to sample different prompts.
+ * `--motion_length` (text-to-motion only) in seconds (maximum is 9.8 seconds).
+
+ **Running those will get you:**
+
+ * `results.npy` file with text prompts and xyz positions of the generated animation (see the loading sketch below)
+ * `sample##_rep##.mp4` - a stick figure animation for each generated motion.
+
+ It will look something like this:
+
+ ![example](assets/example_stick_fig.gif)
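+
+ If you want to inspect `results.npy` programmatically, here is a minimal sketch (not an official script). It assumes the file stores a pickled dict; the exact key names (e.g. `text`, `motion`) are an assumption, so print the keys of your own file first:
+
+ ```python
+ import numpy as np
+
+ # results.npy stores a pickled dict, so allow_pickle is required.
+ data = np.load("results.npy", allow_pickle=True).item()
+ print(list(data.keys()))  # check what was actually saved
+
+ # Assumed keys (verify against the printout above):
+ # 'text'   - the text prompts
+ # 'motion' - xyz joint positions of the generated animations
+ if "text" in data and "motion" in data:
+     print(data["text"][0])                   # prompt of the first sample
+     print(np.asarray(data["motion"]).shape)  # layout depends on the model/dataset
+ ```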
+
+ You can stop here, or render the SMPL mesh using the following script.
+
+ ### Render SMPL mesh
+
+ To create an SMPL mesh per frame, run:
+
+ ```shell
+ python -m visualize.render_mesh --input_path /path/to/mp4/stick/figure/file
+ ```
+
+ **This script outputs:**
+ * `sample##_rep##_smpl_params.npy` - SMPL parameters (thetas, root translations, vertices and faces); a loading sketch follows below.
+ * `sample##_rep##_obj` - Mesh per frame in `.obj` format.
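+
+ As a quick check of the saved parameters, here is a minimal sketch (not an official script); the key names are an assumption based on the description above, so inspect the keys of your own file first:
+
+ ```python
+ import numpy as np
+
+ # sample##_rep##_smpl_params.npy (## replaced by actual indices) stores a pickled dict.
+ params = np.load("sample00_rep00_smpl_params.npy", allow_pickle=True).item()
+ print(list(params.keys()))  # expected to cover thetas, root translations, vertices and faces
+
+ # Assuming per-frame vertices and a shared face list are stored, a mesh for one
+ # frame could be rebuilt with any mesh library, e.g.:
+ #   import trimesh
+ #   mesh = trimesh.Trimesh(vertices=params["vertices"][frame_idx], faces=params["faces"])
+ ```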
+
+ **Notes:**
+ * The `.obj` files can be imported into Blender/Maya/3DS-MAX and rendered there.
+ * This script runs [SMPLify](https://smplify.is.tue.mpg.de/) and needs a GPU as well (can be specified with the `--device` flag).
+ * **Important** - Do not change the original `.mp4` path before running the script.
+
+ **Notes for 3D makers:**
+ * You have two ways to animate the sequence:
+   1. Use the [SMPL add-on](https://smpl.is.tue.mpg.de/index.html) and the theta parameters saved to `sample##_rep##_smpl_params.npy` (we always use beta=0 and the gender-neutral model).
+   1. A more straightforward way is using the mesh data itself. All meshes have the same topology (SMPL), so you just need to keyframe vertex locations.
+      Since the OBJs do not preserve vertex order, we also save this data to the `sample##_rep##_smpl_params.npy` file for your convenience.
+
+ ## Motion Editing
+
+ * This feature is available for text-to-motion datasets (HumanML3D and KIT).
+ * In order to use it, you need to acquire the full data (not just the texts).
+ * We support the two modes presented in the paper: `in_between` and `upper_body`.
+
+ ### Unconditioned editing
+
+ ```shell
+ python -m sample.edit --model_path ./save/humanml_trans_enc_512/model000200000.pt --edit_mode in_between
+ ```
+
+ **You may also define:**
+ * `--num_samples` (default is 10) / `--num_repetitions` (default is 3).
+ * `--device` id.
+ * `--seed` to sample different prompts.
+ * `--edit_mode upper_body` for upper-body editing (the lower body is fixed).
+
+
+ The output will look like this (blue frames are from the input motion; orange were generated by the model):
+
+ ![example](assets/in_between_edit.gif)
+
+ * As in *Motion Synthesis*, you may follow the **Render SMPL mesh** section to obtain meshes for your edited motions.
+
+ ### Text conditioned editing
+
+ Just add the text conditioning using `--text_condition`. For example:
+
+ ```shell
+ python -m sample.edit --model_path ./save/humanml_trans_enc_512/model000200000.pt --edit_mode upper_body --text_condition "A person throws a ball"
+ ```
+
+ The output will look like this (blue joints are from the input motion; orange were generated by the model):
+
+ ![example](assets/upper_body_edit.gif)
+
+ ## Train your own MDM
+
+ <details>
+ <summary><b>Text to Motion</b></summary>
+
+ **HumanML3D**
+ ```shell
+ python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset humanml
+ ```
+
+ **KIT**
+ ```shell
+ python -m train.train_mdm --save_dir save/my_kit_trans_enc_512 --dataset kit
+ ```
+ </details>
+ <details>
+ <summary><b>Action to Motion</b></summary>
+
+ ```shell
+ python -m train.train_mdm --save_dir save/my_name --dataset {humanact12,uestc} --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1
+ ```
+ </details>
+
+ <details>
+ <summary><b>Unconstrained</b></summary>
+
+ ```shell
+ python -m train.train_mdm --save_dir save/my_name --dataset humanact12 --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1 --unconstrained
+ ```
+ </details>
+
+ * Use `--device` to define GPU id.
+ * Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is default).
+ * Add `--train_platform_type {ClearmlPlatform, TensorboardPlatform}` to track results with either [ClearML](https://clear.ml/) or [Tensorboard](https://www.tensorflow.org/tensorboard).
+ * Add `--eval_during_training` to run a short (90 minutes) evaluation for each saved checkpoint.
+   This will slow down training but will give you better monitoring.
+
+ ## Evaluate
+
+ <details>
+ <summary><b>Text to Motion</b></summary>
+
+ * Takes about 20 hours (on a single GPU)
+ * The output of this script for the pre-trained models (as reported in the paper) is provided in the checkpoints zip file.
+
+ **HumanML3D**
+ ```shell
+ python -m eval.eval_humanml --model_path ./save/humanml_trans_enc_512/model000475000.pt
+ ```
+
+ **KIT**
+ ```shell
+ python -m eval.eval_humanml --model_path ./save/kit_trans_enc_512/model000400000.pt
+ ```
+ </details>
+
+ <details>
+ <summary><b>Action to Motion</b></summary>
+
+ * Takes about 7 hours for UESTC and 2 hours for HumanAct12 (on a single GPU)
+ * The output of this script for the pre-trained models (as reported in the paper) is provided in the checkpoints zip file.
+
+ ```shell
+ python -m eval.eval_humanact12_uestc --model <path-to-model-ckpt> --eval_mode full
+ ```
+ where `path-to-model-ckpt` can be a path to any of the pretrained action-to-motion models listed above, or to a checkpoint trained by the user.
+
+ </details>
+
+
+ <details>
+ <summary><b>Unconstrained</b></summary>
+
+ * Takes about 3 hours (on a single GPU)
+
+ ```shell
+ python -m eval.eval_humanact12_uestc --model ./save/unconstrained/model000450000.pt --eval_mode full
+ ```
+
+ Precision and recall are not computed to save computing time. If you wish to compute them, edit the file `eval/a2m/gru_eval.py` and change `fast=True` to `fast=False`.
+
+ </details>
+
+ ## Acknowledgments
+
+ This code stands on the shoulders of giants. We want to thank the following projects that our code is based on:
+
+ [guided-diffusion](https://github.com/openai/guided-diffusion), [MotionCLIP](https://github.com/GuyTevet/MotionCLIP), [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [actor](https://github.com/Mathux/ACTOR), [joints2smpl](https://github.com/wangsen1312/joints2smpl), [MoDi](https://github.com/sigal-raab/MoDi).
+
+ ## License
+ This code is distributed under an [MIT LICENSE](LICENSE).
+
+ Note that our code depends on other libraries, including CLIP, SMPL, SMPL-X, PyTorch3D, and uses datasets that each have their own respective licenses that must also be followed.
app.py ADDED
@@ -0,0 +1,52 @@
+ import numpy as np
+
+ import gradio as gr
+
+
+ import os
+
+ import tempfile
+
+ import shutil
+
+ # from gradio_inter.predict_from_file import predict_from_file
+ from gradio_inter.create_bash_file import create_bash_file
+
+
+ def create_temp_file(path: str) -> str:
+     temp_dir = tempfile.gettempdir()
+     temp_folder = os.path.join(temp_dir, "denoising")
+     os.makedirs(temp_folder, exist_ok=True)
+     # Clean up directory
+     # for i in os.listdir(temp_folder):
+     #     print("Removing", i)
+     #     os.remove(os.path.join(temp_folder, i))
+
+     temp_path = os.path.join(temp_folder, path.split("/")[-1])
+     shutil.copy2(path, temp_path)
+     return temp_path
+
+
+ # from gradio_inter.predict import predict_from_data
+ # from gradio_inter.predi
+
+
+ def transpose(matrix):
+     return matrix.T
+
+
+ def predict(file_path: str):
+     temp_file_path = create_temp_file(file_path)
+     # predict_from_file
+     temp_bash_file = create_bash_file(temp_file_path)
+
+     os.system(f"bash {temp_bash_file}")
+
+
+ demo = gr.Interface(
+     predict,
+     # gr.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
+     gr.File(type="filepath"),
+     "dict",
+     cache_examples=False
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
cog.yaml ADDED
@@ -0,0 +1,38 @@
+ build:
+   gpu: true
+   cuda: "11.3"
+   python_version: 3.8
+   system_packages:
+     - libgl1-mesa-glx
+     - libglib2.0-0
+
+   python_packages:
+     - imageio==2.22.2
+     - matplotlib==3.1.3
+     - spacy==3.3.1
+     - smplx==0.1.28
+     - chumpy==0.70
+     - blis==0.7.8
+     - click==8.1.3
+     - confection==0.0.2
+     - ftfy==6.1.1
+     - importlib-metadata==5.0.0
+     - lxml==4.9.1
+     - murmurhash==1.0.8
+     - preshed==3.0.7
+     - pycryptodomex==3.15.0
+     - regex==2022.9.13
+     - srsly==2.4.4
+     - thinc==8.0.17
+     - typing-extensions==4.1.1
+     - urllib3==1.26.12
+     - wasabi==0.10.1
+     - wcwidth==0.2.5
+
+   run:
+     - apt update -y && apt-get install ffmpeg -y
+     # - python -m spacy download en_core_web_sm
+     - git clone https://github.com/openai/CLIP.git sub_modules/CLIP
+     - pip install -e sub_modules/CLIP
+
+ predict: "sample/predict.py:Predictor"
common/.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
common/___init___.py ADDED
File without changes
common/abstract_pl.py ADDED
@@ -0,0 +1,180 @@
+ import time
+
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ import torch.optim as optim
+
+ import common.pl_utils as pl_utils
+ from common.comet_utils import log_dict
+ from common.pl_utils import avg_losses_cpu, push_checkpoint_metric
+ from common.xdict import xdict
+
+
+ class AbstractPL(pl.LightningModule):
+     def __init__(
+         self,
+         args,
+         push_images_fn,
+         tracked_metric,
+         metric_init_val,
+         high_loss_val,
+     ):
+         super().__init__()
+         self.experiment = args.experiment
+         self.args = args
+         self.tracked_metric = tracked_metric
+         self.metric_init_val = metric_init_val
+
+         self.started_training = False
+         self.loss_dict_vec = []
+         self.push_images = push_images_fn
+         self.vis_train_batches = []
+         self.vis_val_batches = []
+         self.high_loss_val = high_loss_val
+         self.max_vis_examples = 20
+         self.val_step_outputs = []
+         self.test_step_outputs = []
+
+     def set_training_flags(self):
+         self.started_training = True
+
+     def load_from_ckpt(self, ckpt_path):
+         sd = torch.load(ckpt_path)["state_dict"]
+         print(self.load_state_dict(sd))
+
+     def training_step(self, batch, batch_idx):
+         self.set_training_flags()
+         if len(self.vis_train_batches) < self.num_vis_train:
+             self.vis_train_batches.append(batch)
+         inputs, targets, meta_info = batch
+
+         out = self.forward(inputs, targets, meta_info, "train")
+         loss = out["loss"]
+
+         loss = {k: loss[k].mean().view(-1) for k in loss}
+         total_loss = sum(loss[k] for k in loss)
+
+         loss_dict = {"total_loss": total_loss, "loss": total_loss}
+         loss_dict.update(loss)
+
+         for k, v in loss_dict.items():
+             if k != "loss":
+                 loss_dict[k] = v.detach()
+
+         log_every = self.args.log_every
+         self.loss_dict_vec.append(loss_dict)
+         self.loss_dict_vec = self.loss_dict_vec[len(self.loss_dict_vec) - log_every :]
+         if batch_idx % log_every == 0 and batch_idx != 0:
+             running_loss_dict = avg_losses_cpu(self.loss_dict_vec)
+             running_loss_dict = xdict(running_loss_dict).postfix("__train")
+             log_dict(self.experiment, running_loss_dict, step=self.global_step)
+         return loss_dict
+
+     def on_train_epoch_end(self):
+         self.experiment.log_epoch_end(self.current_epoch)
+
+     def validation_step(self, batch, batch_idx):
+         if len(self.vis_val_batches) < self.num_vis_val:
+             self.vis_val_batches.append(batch)
+         out = self.inference_step(batch, batch_idx)
+         self.val_step_outputs.append(out)
+         return out
+
+     def on_validation_epoch_end(self):
+         outputs = self.val_step_outputs
+         outputs = self.inference_epoch_end(outputs, postfix="__val")
+         self.log("loss__val", outputs["loss__val"])
+         self.val_step_outputs.clear()  # free memory
+         return outputs
+
+     def inference_step(self, batch, batch_idx):
+         if self.training:
+             self.eval()
+         with torch.no_grad():
+             inputs, targets, meta_info = batch
+             out, loss = self.forward(inputs, targets, meta_info, "test")
+             return {"out_dict": out, "loss": loss}
+
+     def inference_epoch_end(self, out_list, postfix):
+         if not self.started_training:
+             self.started_training = True
+             result = push_checkpoint_metric(self.tracked_metric, self.metric_init_val)
+             return result
+
+         # unpack
+         outputs, loss_dict = pl_utils.reform_outputs(out_list)
+
+         if "test" in postfix:
+             per_img_metric_dict = {}
+             for k, v in outputs.items():
+                 if "metric." in k:
+                     per_img_metric_dict[k] = np.array(v)
+
+         metric_dict = {}
+         for k, v in outputs.items():
+             if "metric." in k:
+                 metric_dict[k] = np.nanmean(np.array(v))
+
+         loss_metric_dict = {}
+         loss_metric_dict.update(metric_dict)
+         loss_metric_dict.update(loss_dict)
+         loss_metric_dict = xdict(loss_metric_dict).postfix(postfix)
+
+         log_dict(
+             self.experiment,
+             loss_metric_dict,
+             step=self.global_step,
+         )
+
+         if self.args.interface_p is None and "test" not in postfix:
+             result = push_checkpoint_metric(
+                 self.tracked_metric, loss_metric_dict[self.tracked_metric]
+             )
+             self.log(self.tracked_metric, result[self.tracked_metric])
+
+         if not self.args.no_vis:
+             print("Rendering train images")
+             self.visualize_batches(self.vis_train_batches, "_train", False)
+             print("Rendering val images")
+             self.visualize_batches(self.vis_val_batches, "_val", False)
+
+         if "test" in postfix:
+             return (
+                 outputs,
+                 {"per_img_metric_dict": per_img_metric_dict},
+                 metric_dict,
+             )
+         return loss_metric_dict
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
+         scheduler = optim.lr_scheduler.MultiStepLR(
+             optimizer, self.args.lr_dec_epoch, gamma=self.args.lr_decay, verbose=True
+         )
+         return [optimizer], [scheduler]
+
+     def visualize_batches(self, batches, postfix, no_tqdm=True):
+         im_list = []
+         if self.training:
+             self.eval()
+
+         tic = time.time()
+         for batch_idx, batch in enumerate(batches):
+             with torch.no_grad():
+                 inputs, targets, meta_info = batch
+                 vis_dict = self.forward(inputs, targets, meta_info, "vis")
+                 for vis_fn in self.vis_fns:
+                     curr_im_list = vis_fn(
+                         vis_dict,
+                         self.max_vis_examples,
+                         self.renderer,
+                         postfix=postfix,
+                         no_tqdm=no_tqdm,
+                     )
+                     im_list += curr_im_list
+                 print("Rendering: %d/%d" % (batch_idx + 1, len(batches)))
+
+         self.push_images(self.experiment, im_list, self.global_step)
+         print("Done rendering (%.1fs)" % (time.time() - tic))
+         return im_list
common/args_utils.py ADDED
@@ -0,0 +1,15 @@
+ from loguru import logger
+
+
+ def set_default_params(args, default_args):
+     # if a val is not set on argparse, use default val
+     # else, use the one in the argparse
+     custom_dict = {}
+     for key, val in args.items():
+         if val is None:
+             args[key] = default_args[key]
+         else:
+             custom_dict[key] = val
+
+     logger.info(f"Using custom values: {custom_dict}")
+     return args
common/body_models.py ADDED
@@ -0,0 +1,146 @@
+ import json
+
+ import numpy as np
+ import torch
+ from smplx import MANO
+
+ from common.mesh import Mesh
+
+
+ class MANODecimator:
+     def __init__(self):
+         data = np.load(
+             "./data/arctic_data/data/meta/mano_decimator_195.npy", allow_pickle=True
+         ).item()
+         mydata = {}
+         for key, val in data.items():
+             # only consider decimation matrix so far
+             if "D" in key:
+                 mydata[key] = torch.FloatTensor(val)
+         self.data = mydata
+
+     def downsample(self, verts, is_right):
+         dev = verts.device
+         flag = "right" if is_right else "left"
+         if self.data[f"D_{flag}"].device != dev:
+             self.data[f"D_{flag}"] = self.data[f"D_{flag}"].to(dev)
+         D = self.data[f"D_{flag}"]
+         batch_size = verts.shape[0]
+         D_batch = D[None, :, :].repeat(batch_size, 1, 1)
+         verts_sub = torch.bmm(D_batch, verts)
+         return verts_sub
+
+
+ MODEL_DIR = "./data/body_models/mano"
+
+ SEAL_FACES_R = [
+     [120, 108, 778],
+     [108, 79, 778],
+     [79, 78, 778],
+     [78, 121, 778],
+     [121, 214, 778],
+     [214, 215, 778],
+     [215, 279, 778],
+     [279, 239, 778],
+     [239, 234, 778],
+     [234, 92, 778],
+     [92, 38, 778],
+     [38, 122, 778],
+     [122, 118, 778],
+     [118, 117, 778],
+     [117, 119, 778],
+     [119, 120, 778],
+ ]
+
+ # vertex ids around the ring of the wrist
+ CIRCLE_V_ID = np.array(
+     [108, 79, 78, 121, 214, 215, 279, 239, 234, 92, 38, 122, 118, 117, 119, 120],
+     dtype=np.int64,
+ )
+
+
+ def seal_mano_mesh(v3d, faces, is_rhand):
+     # v3d: B, 778, 3
+     # faces: 1538, 3
+     # output: v3d(B, 779, 3); faces (1554, 3)
+
+     seal_faces = torch.LongTensor(np.array(SEAL_FACES_R)).to(faces.device)
+     if not is_rhand:
+         # left hand
+         seal_faces = seal_faces[:, np.array([1, 0, 2])]  # invert face normal
+     centers = v3d[:, CIRCLE_V_ID].mean(dim=1)[:, None, :]
+     sealed_vertices = torch.cat((v3d, centers), dim=1)
+     faces = torch.cat((faces, seal_faces), dim=0)
+     return sealed_vertices, faces
+
+
+ def build_layers(device=None):
+     from common.object_tensors import ObjectTensors
+
+     layers = {
+         "right": build_mano_aa(True),
+         "left": build_mano_aa(False),
+         "object_tensors": ObjectTensors(),
+     }
+
+     if device is not None:
+         layers["right"] = layers["right"].to(device)
+         layers["left"] = layers["left"].to(device)
+         layers["object_tensors"].to(device)
+     return layers
+
+
+ MANO_MODEL_DIR = "./data/body_models/mano"
+ SMPLX_MODEL_P = {
+     "male": "./data/body_models/smplx/SMPLX_MALE.npz",
+     "female": "./data/body_models/smplx/SMPLX_FEMALE.npz",
+     "neutral": "./data/body_models/smplx/SMPLX_NEUTRAL.npz",
+ }
+
+
+ def build_smplx(batch_size, gender, vtemplate):
+     import smplx
+
+     subj_m = smplx.create(
+         model_path=SMPLX_MODEL_P[gender],
+         model_type="smplx",
+         gender=gender,
+         num_pca_comps=45,
+         v_template=vtemplate,
+         flat_hand_mean=True,
+         use_pca=False,
+         batch_size=batch_size,
+         # batch_size=320,
+     )
+     return subj_m
+
+
+ def build_subject_smplx(batch_size, subject_id):
+     with open("./data/arctic_data/data/meta/misc.json", "r") as f:
+         misc = json.load(f)
+     vtemplate_p = f"./data/arctic_data/data/meta/subject_vtemplates/{subject_id}.obj"
+     mesh = Mesh(filename=vtemplate_p)
+     vtemplate = mesh.v
+     gender = misc[subject_id]["gender"]
+     return build_smplx(batch_size, gender, vtemplate)
+
+
+ def build_mano_aa(is_rhand, create_transl=False, flat_hand=False):
+     return MANO(
+         MODEL_DIR,
+         create_transl=create_transl,
+         use_pca=False,
+         flat_hand_mean=flat_hand,
+         is_rhand=is_rhand,
+     )
+
+ ##
+ def construct_layers(dev):
+     mano_layers = {
+         "right": build_mano_aa(True, create_transl=True, flat_hand=False),
+         "left": build_mano_aa(False, create_transl=True, flat_hand=False),
+         "smplx": build_smplx(1, "neutral", None),
+     }
+     for layer in mano_layers.values():
+         layer.to(dev)
+     return mano_layers
common/camera.py ADDED
@@ -0,0 +1,474 @@
+ import numpy as np
+ import torch
+
+ """
+ Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
+ Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
+ """
+
+
+ def perspective_to_weak_perspective_torch(
+     perspective_camera,
+     focal_length,
+     img_res,
+ ):
+     # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
+     # in 3D given the bounding box size
+     # This camera translation can be used in a full perspective projection
+     # if isinstance(focal_length, torch.Tensor):
+     #     focal_length = focal_length[:, 0]
+
+     tx = perspective_camera[:, 0]
+     ty = perspective_camera[:, 1]
+     tz = perspective_camera[:, 2]
+
+     weak_perspective_camera = torch.stack(
+         [2 * focal_length / (img_res * tz + 1e-9), tx, ty],
+         dim=-1,
+     )
+     return weak_perspective_camera
+
+
+ def convert_perspective_to_weak_perspective(
+     perspective_camera,
+     focal_length,
+     img_res,
+ ):
+     # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
+     # in 3D given the bounding box size
+     # This camera translation can be used in a full perspective projection
+     # if isinstance(focal_length, torch.Tensor):
+     #     focal_length = focal_length[:, 0]
+
+     weak_perspective_camera = torch.stack(
+         [
+             2 * focal_length / (img_res * perspective_camera[:, 2] + 1e-9),
+             perspective_camera[:, 0],
+             perspective_camera[:, 1],
+         ],
+         dim=-1,
+     )
+     return weak_perspective_camera
+
+
+ def convert_weak_perspective_to_perspective(
+     weak_perspective_camera, focal_length, img_res
+ ):
+     # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
+     # in 3D given the bounding box size
+     # This camera translation can be used in a full perspective projection
+     # if isinstance(focal_length, torch.Tensor):
+     #     focal_length = focal_length[:, 0]
+
+     perspective_camera = torch.stack(
+         [
+             weak_perspective_camera[:, 1],
+             weak_perspective_camera[:, 2],
+             2 * focal_length / (img_res * weak_perspective_camera[:, 0] + 1e-9),
+         ],
+         dim=-1,
+     )
+     return perspective_camera
+
+
+ def get_default_cam_t(f, img_res):
+     cam = torch.tensor([[5.0, 0.0, 0.0]])
+     return convert_weak_perspective_to_perspective(cam, f, img_res)
+
+
+ def estimate_translation_np(S, joints_2d, joints_conf, focal_length, img_size):
+     """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
+     Input:
+         S: (25, 3) 3D joint locations
+         joints: (25, 3) 2D joint locations and confidence
+     Returns:
+         (3,) camera translation vector
+     """
+     num_joints = S.shape[0]
+     # focal length
+
+     f = np.array([focal_length[0], focal_length[1]])
+     # optical center
+     center = np.array([img_size[1] / 2.0, img_size[0] / 2.0])
+
+     # transformations
+     Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
+     XY = np.reshape(S[:, 0:2], -1)
+     O = np.tile(center, num_joints)
+     F = np.tile(f, num_joints)
+     weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
+
+     # least squares
+     Q = np.array(
+         [
+             F * np.tile(np.array([1, 0]), num_joints),
+             F * np.tile(np.array([0, 1]), num_joints),
+             O - np.reshape(joints_2d, -1),
+         ]
+     ).T
+     c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
+
+     # weighted least squares
+     W = np.diagflat(weight2)
+     Q = np.dot(W, Q)
+     c = np.dot(W, c)
+
+     # square matrix
+     A = np.dot(Q.T, Q)
+     b = np.dot(Q.T, c)
+
+     # solution
+     trans = np.linalg.solve(A, b)
+
+     return trans
+
+
+ def estimate_translation(
+     S,
+     joints_2d,
+     focal_length,
+     img_size,
+     use_all_joints=False,
+     rotation=None,
+     pad_2d=False,
+ ):
+     """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
+     Input:
+         S: (B, 49, 3) 3D joint locations
+         joints: (B, 49, 3) 2D joint locations and confidence
+     Returns:
+         (B, 3) camera translation vectors
+     """
+     if pad_2d:
+         batch, num_pts = joints_2d.shape[:2]
+         joints_2d_pad = torch.ones((batch, num_pts, 3))
+         joints_2d_pad[:, :, :2] = joints_2d
+         joints_2d_pad = joints_2d_pad.to(joints_2d.device)
+         joints_2d = joints_2d_pad
+
+     device = S.device
+
+     if rotation is not None:
+         S = torch.einsum("bij,bkj->bki", rotation, S)
+
+     # Use only joints 25:49 (GT joints)
+     if use_all_joints:
+         S = S.cpu().numpy()
+         joints_2d = joints_2d.cpu().numpy()
+     else:
+         S = S[:, 25:, :].cpu().numpy()
+         joints_2d = joints_2d[:, 25:, :].cpu().numpy()
+
+     joints_conf = joints_2d[:, :, -1]
+     joints_2d = joints_2d[:, :, :-1]
+     trans = np.zeros((S.shape[0], 3), dtype=np.float32)
+     # Find the translation for each example in the batch
+     for i in range(S.shape[0]):
+         S_i = S[i]
+         joints_i = joints_2d[i]
+         conf_i = joints_conf[i]
+         trans[i] = estimate_translation_np(
+             S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size
+         )
+     return torch.from_numpy(trans).to(device)
+
+
+ def estimate_translation_cam(
+     S, joints_2d, focal_length, img_size, use_all_joints=False, rotation=None
+ ):
+     """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
+     Input:
+         S: (B, 49, 3) 3D joint locations
+         joints: (B, 49, 3) 2D joint locations and confidence
+     Returns:
+         (B, 3) camera translation vectors
+     """
+
+     def estimate_translation_np(S, joints_2d, joints_conf, focal_length, img_size):
+         """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
+         Input:
+             S: (25, 3) 3D joint locations
+             joints: (25, 3) 2D joint locations and confidence
+         Returns:
+             (3,) camera translation vector
+         """
+
+         num_joints = S.shape[0]
+         # focal length
+         f = np.array([focal_length[0], focal_length[1]])
+         # optical center
+         center = np.array([img_size[0] / 2.0, img_size[1] / 2.0])
+
+         # transformations
+         Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
+         XY = np.reshape(S[:, 0:2], -1)
+         O = np.tile(center, num_joints)
+         F = np.tile(f, num_joints)
+         weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
+
+         # least squares
+         Q = np.array(
+             [
+                 F * np.tile(np.array([1, 0]), num_joints),
+                 F * np.tile(np.array([0, 1]), num_joints),
+                 O - np.reshape(joints_2d, -1),
+             ]
+         ).T
+         c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
+
+         # weighted least squares
+         W = np.diagflat(weight2)
+         Q = np.dot(W, Q)
+         c = np.dot(W, c)
+
+         # square matrix
+         A = np.dot(Q.T, Q)
+         b = np.dot(Q.T, c)
+
+         # solution
+         trans = np.linalg.solve(A, b)
+
+         return trans
+
+     device = S.device
+
+     if rotation is not None:
+         S = torch.einsum("bij,bkj->bki", rotation, S)
+
+     # Use only joints 25:49 (GT joints)
+     if use_all_joints:
+         S = S.cpu().numpy()
+         joints_2d = joints_2d.cpu().numpy()
+     else:
+         S = S[:, 25:, :].cpu().numpy()
+         joints_2d = joints_2d[:, 25:, :].cpu().numpy()
+
+     joints_conf = joints_2d[:, :, -1]
+     joints_2d = joints_2d[:, :, :-1]
+     trans = np.zeros((S.shape[0], 3), dtype=np.float32)
+     # Find the translation for each example in the batch
+     for i in range(S.shape[0]):
+         S_i = S[i]
+         joints_i = joints_2d[i]
+         conf_i = joints_conf[i]
+         trans[i] = estimate_translation_np(
+             S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size
+         )
+     return torch.from_numpy(trans).to(device)
+
+
+ def get_coord_maps(size=56):
+     xx_ones = torch.ones([1, size], dtype=torch.int32)
+     xx_ones = xx_ones.unsqueeze(-1)
+
+     xx_range = torch.arange(size, dtype=torch.int32).unsqueeze(0)
+     xx_range = xx_range.unsqueeze(1)
+
+     xx_channel = torch.matmul(xx_ones, xx_range)
+     xx_channel = xx_channel.unsqueeze(-1)
+
+     yy_ones = torch.ones([1, size], dtype=torch.int32)
+     yy_ones = yy_ones.unsqueeze(1)
+
+     yy_range = torch.arange(size, dtype=torch.int32).unsqueeze(0)
+     yy_range = yy_range.unsqueeze(-1)
+
+     yy_channel = torch.matmul(yy_range, yy_ones)
+     yy_channel = yy_channel.unsqueeze(-1)
+
+     xx_channel = xx_channel.permute(0, 3, 1, 2)
+     yy_channel = yy_channel.permute(0, 3, 1, 2)
+
+     xx_channel = xx_channel.float() / (size - 1)
+     yy_channel = yy_channel.float() / (size - 1)
+
+     xx_channel = xx_channel * 2 - 1
+     yy_channel = yy_channel * 2 - 1
+
+     out = torch.cat([xx_channel, yy_channel], dim=1)
+     return out
+
+
+ def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5):
+     at = at.astype(float).reshape(1, 3)
+     up = up.astype(float).reshape(1, 3)
+
+     eye = eye.reshape(-1, 3)
+     up = up.repeat(eye.shape[0] // up.shape[0], axis=0)
+     eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)
+
+     z_axis = eye - at
+     z_axis /= np.max(np.stack([np.linalg.norm(z_axis, axis=1, keepdims=True), eps]))
+
+     x_axis = np.cross(up, z_axis)
+     x_axis /= np.max(np.stack([np.linalg.norm(x_axis, axis=1, keepdims=True), eps]))
+
+     y_axis = np.cross(z_axis, x_axis)
+     y_axis /= np.max(np.stack([np.linalg.norm(y_axis, axis=1, keepdims=True), eps]))
+
+     r_mat = np.concatenate(
+         (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(-1, 3, 1)),
+         axis=2,
+     )
+
+     return r_mat
+
+
+ def to_sphere(u, v):
+     theta = 2 * np.pi * u
+     phi = np.arccos(1 - 2 * v)
+     cx = np.sin(phi) * np.cos(theta)
+     cy = np.sin(phi) * np.sin(theta)
+     cz = np.cos(phi)
+     s = np.stack([cx, cy, cz])
+     return s
+
+
+ def sample_on_sphere(range_u=(0, 1), range_v=(0, 1)):
+     u = np.random.uniform(*range_u)
+     v = np.random.uniform(*range_v)
+     return to_sphere(u, v)
+
+
+ def sample_pose_on_sphere(range_v=(0, 1), range_u=(0, 1), radius=1, up=[0, 1, 0]):
+     # sample location on unit sphere
+     loc = sample_on_sphere(range_u, range_v)
+
+     # sample radius if necessary
+     if isinstance(radius, tuple):
+         radius = np.random.uniform(*radius)
+
+     loc = loc * radius
+     R = look_at(loc, up=np.array(up))[0]
+
+     RT = np.concatenate([R, loc.reshape(3, 1)], axis=1)
+     RT = torch.Tensor(RT.astype(np.float32))
+     return RT
+
+
+ def rectify_pose(camera_r, body_aa, rotate_x=False):
+     body_r = batch_rodrigues(body_aa).reshape(-1, 3, 3)
+
+     if rotate_x:
+         rotate_x = torch.tensor([[[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]])
+         body_r = body_r @ rotate_x
+
+     final_r = camera_r @ body_r
+     body_aa = batch_rot2aa(final_r)
+     return body_aa
+
+
+ def estimate_translation_k_np(S, joints_2d, joints_conf, K):
+     """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
+     Input:
+         S: (25, 3) 3D joint locations
+         joints: (25, 3) 2D joint locations and confidence
+     Returns:
+         (3,) camera translation vector
+     """
+     num_joints = S.shape[0]
+     # focal length
+
+     focal = np.array([K[0, 0], K[1, 1]])
+     # optical center
+     center = np.array([K[0, 2], K[1, 2]])
+
+     # transformations
+     Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
+     XY = np.reshape(S[:, 0:2], -1)
+     O = np.tile(center, num_joints)
+     F = np.tile(focal, num_joints)
+     weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
+
+     # least squares
+     Q = np.array(
+         [
+             F * np.tile(np.array([1, 0]), num_joints),
+             F * np.tile(np.array([0, 1]), num_joints),
+             O - np.reshape(joints_2d, -1),
+         ]
+     ).T
+     c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
+
+     # weighted least squares
+     W = np.diagflat(weight2)
+     Q = np.dot(W, Q)
+     c = np.dot(W, c)
+
+     # square matrix
+     A = np.dot(Q.T, Q)
+     b = np.dot(Q.T, c)
+
+     # solution
+     trans = np.linalg.solve(A, b)
+
+     return trans
+
+
+ def estimate_translation_k(
+     S,
+     joints_2d,
+     K,
+     use_all_joints=False,
+     rotation=None,
+     pad_2d=False,
+ ):
+     """Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
+     Input:
+         S: (B, 49, 3) 3D joint locations
+         joints: (B, 49, 3) 2D joint locations and confidence
+     Returns:
+         (B, 3) camera translation vectors
+     """
+     if pad_2d:
+         batch, num_pts = joints_2d.shape[:2]
+         joints_2d_pad = torch.ones((batch, num_pts, 3))
+         joints_2d_pad[:, :, :2] = joints_2d
+         joints_2d_pad = joints_2d_pad.to(joints_2d.device)
+         joints_2d = joints_2d_pad
+
+     device = S.device
+
+     if rotation is not None:
+         S = torch.einsum("bij,bkj->bki", rotation, S)
+
+     # Use only joints 25:49 (GT joints)
+     if use_all_joints:
+         S = S.cpu().numpy()
+         joints_2d = joints_2d.cpu().numpy()
+     else:
+         S = S[:, 25:, :].cpu().numpy()
+         joints_2d = joints_2d[:, 25:, :].cpu().numpy()
+
+     joints_conf = joints_2d[:, :, -1]
+     joints_2d = joints_2d[:, :, :-1]
+     trans = np.zeros((S.shape[0], 3), dtype=np.float32)
+     # Find the translation for each example in the batch
+     for i in range(S.shape[0]):
+         S_i = S[i]
+         joints_i = joints_2d[i]
+         conf_i = joints_conf[i]
+         K_i = K[i]
+         trans[i] = estimate_translation_k_np(S_i, joints_i, conf_i, K_i)
+     return torch.from_numpy(trans).to(device)
+
+
+ def weak_perspective_to_perspective_torch(
+     weak_perspective_camera, focal_length, img_res, min_s
+ ):
+     # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
+     # in 3D given the bounding box size
+     # This camera translation can be used in a full perspective projection
+     s = weak_perspective_camera[:, 0]
+     s = torch.clamp(s, min_s)
+     tx = weak_perspective_camera[:, 1]
+     ty = weak_perspective_camera[:, 2]
+     perspective_camera = torch.stack(
+         [
+             tx,
+             ty,
+             2 * focal_length / (img_res * s + 1e-9),
+         ],
+         dim=-1,
+     )
+     return perspective_camera
common/comet_utils.py ADDED
@@ -0,0 +1,158 @@
+ import json
+ import os
+ import os.path as op
+ import time
+
+ import comet_ml
+ import numpy as np
+ import torch
+ from loguru import logger
+ from tqdm import tqdm
+
+ from src.datasets.dataset_utils import copy_repo_arctic
+
+ # folder used for debugging
+ DUMMY_EXP = "xxxxxxxxx"
+
+
+ def add_paths(args):
+     exp_key = args.exp_key
+     args_p = f"./logs/{exp_key}/args.json"
+     ckpt_p = f"./logs/{exp_key}/checkpoints/last.ckpt"
+     if not op.exists(ckpt_p) or DUMMY_EXP in ckpt_p:
+         ckpt_p = ""
+     if args.resume_ckpt != "":
+         ckpt_p = args.resume_ckpt
+     args.ckpt_p = ckpt_p
+     args.log_dir = f"./logs/{exp_key}"
+
+     if args.infer_ckpt != "":
+         basedir = "/".join(args.infer_ckpt.split("/")[:2])
+         basename = op.basename(args.infer_ckpt).replace(".ckpt", ".params.pt")
+         args.interface_p = op.join(basedir, basename)
+     args.args_p = args_p
+     if args.cluster:
+         args.run_p = op.join(args.log_dir, "condor", "run.sh")
+         args.submit_p = op.join(args.log_dir, "condor", "submit.sub")
+         args.repo_p = op.join(args.log_dir, "repo")
+
+     return args
+
+
+ def save_args(args, save_keys):
+     args_save = {}
+     for key, val in args.items():
+         if key in save_keys:
+             args_save[key] = val
+     with open(args.args_p, "w") as f:
+         json.dump(args_save, f, indent=4)
+     logger.info(f"Saved args at {args.args_p}")
+
+
+ def create_files(args):
+     os.makedirs(args.log_dir, exist_ok=True)
+     if args.cluster:
+         os.makedirs(op.dirname(args.run_p), exist_ok=True)
+         copy_repo_arctic(args.exp_key)
+
+
+ def log_exp_meta(args):
+     tags = [args.method]
+     logger.info(f"Experiment tags: {tags}")
+     args.experiment.set_name(args.exp_key)
+     args.experiment.add_tags(tags)
+     args.experiment.log_parameters(args)
+
+
+ def init_experiment(args):
+     if args.resume_ckpt != "":
+         args.exp_key = args.resume_ckpt.split("/")[1]
+     if args.fast_dev_run:
+         args.exp_key = DUMMY_EXP
+     if args.exp_key == "":
+         args.exp_key = generate_exp_key()
+     args = add_paths(args)
+     if op.exists(args.args_p) and args.exp_key not in [DUMMY_EXP]:
+         with open(args.args_p, "r") as f:
+             args_disk = json.load(f)
+             if "comet_key" in args_disk.keys():
+                 args.comet_key = args_disk["comet_key"]
+
+     create_files(args)
+
+     project_name = args.project
+     disabled = args.mute
+     comet_url = args["comet_key"] if "comet_key" in args.keys() else None
+
+     api_key = os.environ["COMET_API_KEY"]
+     workspace = os.environ["COMET_WORKSPACE"]
+     if not args.cluster:
+         if comet_url is None:
+             experiment = comet_ml.Experiment(
+                 api_key=api_key,
+                 workspace=workspace,
+                 project_name=project_name,
+                 disabled=disabled,
+                 display_summary_level=0,
+             )
+             args.comet_key = experiment.get_key()
+         else:
+             experiment = comet_ml.ExistingExperiment(
+                 previous_experiment=comet_url,
+                 api_key=api_key,
+                 project_name=project_name,
+                 workspace=workspace,
+                 disabled=disabled,
+                 display_summary_level=0,
+             )
+
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         logger.add(
+             os.path.join(args.log_dir, "train.log"),
+             level="INFO",
+             colorize=True,
+         )
+         logger.info(torch.cuda.get_device_properties(device))
+         args.gpu = torch.cuda.get_device_properties(device).name
+     else:
+         experiment = None
+     args.experiment = experiment
+     return experiment, args
+
+
+ def log_dict(experiment, metric_dict, step, postfix=None):
+     if experiment is None:
+         return
+     for key, value in metric_dict.items():
+         if postfix is not None:
+             key = key + postfix
+         if isinstance(value, torch.Tensor) and len(value.view(-1)) == 1:
+             value = value.item()
+
+         if isinstance(value, (int, float, np.float32)):
+             experiment.log_metric(key, value, step=step)
+
+
+ def generate_exp_key():
+     import random
+
+     hash = random.getrandbits(128)
+     key = "%032x" % (hash)
+     key = key[:9]
+     return key
+
+
+ def push_images(experiment, all_im_list, global_step=None, no_tqdm=False, verbose=True):
+     if verbose:
+         print("Pushing PIL images")
+         tic = time.time()
+     iterator = all_im_list if no_tqdm else tqdm(all_im_list)
+     for im in iterator:
+         im_np = np.array(im["im"])
+         if "fig_name" in im.keys():
+             experiment.log_image(im_np, im["fig_name"], step=global_step)
+         else:
+             experiment.log_image(im_np, "unnamed", step=global_step)
+     if verbose:
+         toc = time.time()
+         print("Done pushing PIL images (%.1fs)" % (toc - tic))
common/data_utils.py ADDED
@@ -0,0 +1,371 @@
1
+ """
2
+ This file contains functions that are used to perform data augmentation.
3
+ """
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from loguru import logger
8
+
9
+
10
+ def get_transform(center, scale, res, rot=0):
11
+ """Generate transformation matrix."""
12
+ h = 200 * scale
13
+ t = np.zeros((3, 3))
14
+ t[0, 0] = float(res[1]) / h
15
+ t[1, 1] = float(res[0]) / h
16
+ t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
17
+ t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
18
+ t[2, 2] = 1
19
+ if not rot == 0:
20
+ rot = -rot # To match direction of rotation from cropping
21
+ rot_mat = np.zeros((3, 3))
22
+ rot_rad = rot * np.pi / 180
23
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
24
+ rot_mat[0, :2] = [cs, -sn]
25
+ rot_mat[1, :2] = [sn, cs]
26
+ rot_mat[2, 2] = 1
27
+ # Need to rotate around center
28
+ t_mat = np.eye(3)
29
+ t_mat[0, 2] = -res[1] / 2
30
+ t_mat[1, 2] = -res[0] / 2
31
+ t_inv = t_mat.copy()
32
+ t_inv[:2, 2] *= -1
33
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
34
+ return t
35
+
36
+
37
+ def transform(pt, center, scale, res, invert=0, rot=0):
38
+ """Transform pixel location to different reference."""
39
+ t = get_transform(center, scale, res, rot=rot)
40
+ if invert:
41
+ t = np.linalg.inv(t)
42
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
43
+ new_pt = np.dot(t, new_pt)
44
+ return new_pt[:2].astype(int) + 1
45
+
46
+
47
+ def rotate_2d(pt_2d, rot_rad):
48
+ x = pt_2d[0]
49
+ y = pt_2d[1]
50
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
51
+ xx = x * cs - y * sn
52
+ yy = x * sn + y * cs
53
+ return np.array([xx, yy], dtype=np.float32)
54
+
55
+
56
+ def gen_trans_from_patch_cv(
57
+ c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False
58
+ ):
59
+ # augment size with scale
60
+ src_w = src_width * scale
61
+ src_h = src_height * scale
62
+ src_center = np.array([c_x, c_y], dtype=np.float32)
63
+
64
+ # augment rotation
65
+ rot_rad = np.pi * rot / 180
66
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
67
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
68
+
69
+ dst_w = dst_width
70
+ dst_h = dst_height
71
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
72
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
73
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
74
+
75
+ src = np.zeros((3, 2), dtype=np.float32)
76
+ src[0, :] = src_center
77
+ src[1, :] = src_center + src_downdir
78
+ src[2, :] = src_center + src_rightdir
79
+
80
+ dst = np.zeros((3, 2), dtype=np.float32)
81
+ dst[0, :] = dst_center
82
+ dst[1, :] = dst_center + dst_downdir
83
+ dst[2, :] = dst_center + dst_rightdir
84
+
85
+ if inv:
86
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
87
+ else:
88
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
89
+
90
+ trans = trans.astype(np.float32)
91
+ return trans
92
+
93
+
94
+ def generate_patch_image(
95
+ cvimg,
96
+ bbox,
97
+ scale,
98
+ rot,
99
+ out_shape,
100
+ interpl_strategy,
101
+ gauss_kernel=5,
102
+ gauss_sigma=8.0,
103
+ ):
104
+ img = cvimg.copy()
105
+
106
+ bb_c_x = float(bbox[0])
107
+ bb_c_y = float(bbox[1])
108
+ bb_width = float(bbox[2])
109
+ bb_height = float(bbox[3])
110
+
111
+ trans = gen_trans_from_patch_cv(
112
+ bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot
113
+ )
114
+
115
+ # anti-aliasing
116
+ blur = cv2.GaussianBlur(img, (gauss_kernel, gauss_kernel), gauss_sigma)
117
+ img_patch = cv2.warpAffine(
118
+ blur, trans, (int(out_shape[1]), int(out_shape[0])), flags=interpl_strategy
119
+ )
120
+ img_patch = img_patch.astype(np.float32)
121
+ inv_trans = gen_trans_from_patch_cv(
122
+ bb_c_x,
123
+ bb_c_y,
124
+ bb_width,
125
+ bb_height,
126
+ out_shape[1],
127
+ out_shape[0],
128
+ scale,
129
+ rot,
130
+ inv=True,
131
+ )
132
+
133
+ return img_patch, trans, inv_trans
134
+
135
+
136
+ def augm_params(is_train, flip_prob, noise_factor, rot_factor, scale_factor):
137
+ """Get augmentation parameters."""
138
+ flip = 0 # flipping
139
+ pn = np.ones(3) # per channel pixel-noise
140
+ rot = 0 # rotation
141
+ sc = 1 # scaling
142
+ if is_train:
143
+ # Flip with probability flip_prob (flipping is currently unsupported)
144
+ if np.random.uniform() <= flip_prob:
145
+ flip = 1
146
+ assert False, "Flipping not supported"
147
+
148
+ # Each channel is multiplied with a number
149
+ # in the range [1 - noise_factor, 1 + noise_factor]
150
+ pn = np.random.uniform(1 - noise_factor, 1 + noise_factor, 3)
151
+
152
+ # The rotation is a number in the range [-2 * rot_factor, 2 * rot_factor]
153
+ rot = min(
154
+ 2 * rot_factor,
155
+ max(
156
+ -2 * rot_factor,
157
+ np.random.randn() * rot_factor,
158
+ ),
159
+ )
160
+
161
+ # The scale is multiplied with a number
162
+ # in the range [1 - scale_factor, 1 + scale_factor]
163
+ sc = min(
164
+ 1 + scale_factor,
165
+ max(
166
+ 1 - scale_factor,
167
+ np.random.randn() * scale_factor + 1,
168
+ ),
169
+ )
170
+ # but it is zero with probability 3/5
171
+ if np.random.uniform() <= 0.6:
172
+ rot = 0
173
+
174
+ augm_dict = {}
175
+ augm_dict["flip"] = flip
176
+ augm_dict["pn"] = pn
177
+ augm_dict["rot"] = rot
178
+ augm_dict["sc"] = sc
179
+ return augm_dict
180
+
181
+
182
+ def rgb_processing(is_train, rgb_img, center, bbox_dim, augm_dict, img_res):
183
+ rot = augm_dict["rot"]
184
+ sc = augm_dict["sc"]
185
+ pn = augm_dict["pn"]
186
+ scale = sc * bbox_dim
187
+
188
+ crop_dim = int(scale * 200)
189
+ # faster cropping!!
190
+ rgb_img = generate_patch_image(
191
+ rgb_img,
192
+ [center[0], center[1], crop_dim, crop_dim],
193
+ 1.0,
194
+ rot,
195
+ [img_res, img_res],
196
+ cv2.INTER_CUBIC,
197
+ )[0]
198
+
199
+ # in the rgb image we add pixel noise in a channel-wise manner
200
+ rgb_img[:, :, 0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 0] * pn[0]))
201
+ rgb_img[:, :, 1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 1] * pn[1]))
202
+ rgb_img[:, :, 2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 2] * pn[2]))
203
+ rgb_img = np.transpose(rgb_img.astype("float32"), (2, 0, 1)) / 255.0
204
+ return rgb_img
205
+
206
+
207
+ def transform_kp2d(kp2d, bbox):
208
+ # bbox: (cx, cy, scale) in the original image space
209
+ # scale is normalized
210
+ assert isinstance(kp2d, np.ndarray)
211
+ assert len(kp2d.shape) == 2
212
+ cx, cy, scale = bbox
213
+ s = 200 * scale # to px
214
+ cap_dim = 1000 # px
215
+ factor = cap_dim / (1.5 * s)
216
+ kp2d_cropped = np.copy(kp2d)
217
+ kp2d_cropped[:, 0] -= cx - 1.5 / 2 * s
218
+ kp2d_cropped[:, 1] -= cy - 1.5 / 2 * s
219
+ kp2d_cropped[:, 0] *= factor
220
+ kp2d_cropped[:, 1] *= factor
221
+ return kp2d_cropped
222
+
223
+
224
+ def j2d_processing(kp, center, bbox_dim, augm_dict, img_res):
225
+ """Process gt 2D keypoints and apply all augmentation transforms."""
226
+ scale = augm_dict["sc"] * bbox_dim
227
+ rot = augm_dict["rot"]
228
+
229
+ nparts = kp.shape[0]
230
+ for i in range(nparts):
231
+ kp[i, 0:2] = transform(
232
+ kp[i, 0:2] + 1,
233
+ center,
234
+ scale,
235
+ [img_res, img_res],
236
+ rot=rot,
237
+ )
238
+ # convert to normalized coordinates
239
+ kp = normalize_kp2d_np(kp, img_res)
240
+ kp = kp.astype("float32")
241
+ return kp
242
+
243
+
244
+ def pose_processing(pose, augm_dict):
245
+ """Process SMPL theta parameters and apply all augmentation transforms."""
246
+ rot = augm_dict["rot"]
247
+ # rotation or the pose parameters
248
+ pose[:3] = rot_aa(pose[:3], rot)
249
+ # flip the pose parameters
250
+ # (72),float
251
+ pose = pose.astype("float32")
252
+ return pose
253
+
254
+
255
+ def rot_aa(aa, rot):
256
+ """Rotate axis angle parameters."""
257
+ # pose parameters
258
+ R = np.array(
259
+ [
260
+ [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
261
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
262
+ [0, 0, 1],
263
+ ]
264
+ )
265
+ # find the rotation of the body in camera frame
266
+ per_rdg, _ = cv2.Rodrigues(aa)
267
+ # apply the global rotation to the global orientation
268
+ resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
269
+ aa = (resrot.T)[0]
270
+ return aa
271
+
272
+
273
+ def denormalize_images(images):
274
+ images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(
275
+ 1, 3, 1, 1
276
+ )
277
+ images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(
278
+ 1, 3, 1, 1
279
+ )
280
+ return images
281
+
282
+
283
+ def read_img(img_fn, dummy_shape):
284
+ try:
285
+ cv_img = _read_img(img_fn)
286
+ except Exception:
287
+ logger.warning(f"Unable to load {img_fn}")
288
+ cv_img = np.zeros(dummy_shape, dtype=np.float32)
289
+ return cv_img, False
290
+ return cv_img, True
291
+
292
+
293
+ def _read_img(img_fn):
294
+ img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
295
+ return img.astype(np.float32)
296
+
297
+
298
+ def normalize_kp2d_np(kp2d: np.ndarray, img_res):
299
+ assert kp2d.shape[1] == 3
300
+ kp2d_normalized = kp2d.copy()
301
+ kp2d_normalized[:, :2] = 2.0 * kp2d[:, :2] / img_res - 1.0
302
+ return kp2d_normalized
303
+
304
+
305
+ def unnormalize_2d_kp(kp_2d_np: np.ndarray, res):
306
+ assert kp_2d_np.shape[1] == 3
307
+ kp_2d = np.copy(kp_2d_np)
308
+ kp_2d[:, :2] = 0.5 * res * (kp_2d[:, :2] + 1)
309
+ return kp_2d
310
+
311
+
312
+ def normalize_kp2d(kp2d: torch.Tensor, img_res):
313
+ assert len(kp2d.shape) == 3
314
+ kp2d_normalized = kp2d.clone()
315
+ kp2d_normalized[:, :, :2] = 2.0 * kp2d[:, :, :2] / img_res - 1.0
316
+ return kp2d_normalized
317
+
318
+
319
+ def unormalize_kp2d(kp2d_normalized: torch.Tensor, img_res):
320
+ assert len(kp2d_normalized.shape) == 3
321
+ assert kp2d_normalized.shape[2] == 2
322
+ kp2d = kp2d_normalized.clone()
323
+ kp2d = 0.5 * img_res * (kp2d + 1)
324
+ return kp2d
325
+
326
+
327
+ def get_wp_intrix(fixed_focal: float, img_res):
328
+ # construct weak-perspective intrinsics on the patch
329
+ camera_center = np.array([img_res // 2, img_res // 2])
330
+ intrx = torch.zeros([3, 3])
331
+ intrx[0, 0] = fixed_focal
332
+ intrx[1, 1] = fixed_focal
333
+ intrx[2, 2] = 1.0
334
+ intrx[0, -1] = camera_center[0]
335
+ intrx[1, -1] = camera_center[1]
336
+ return intrx
337
+
338
+
339
+ def get_aug_intrix(
340
+ intrx, fixed_focal: float, img_res, use_gt_k, bbox_cx, bbox_cy, scale
341
+ ):
342
+ """
343
+ This function returns camera intrinsics under scaling.
344
+ If use_gt_k, the GT K is used, but scaled based on the amount of scaling in the patch.
345
+ Else, we construct an intrinsic camera with a fixed focal length and fixed camera center.
346
+ """
347
+
348
+ if not use_gt_k:
349
+ # construct weak-perspective intrinsics on the patch
350
+ intrx = get_wp_intrix(fixed_focal, img_res)
351
+ else:
352
+ # update the GT intrinsics (full image space)
353
+ # such that it matches the scale of the patch
354
+
355
+ dim = scale * 200.0 # bbox size
356
+ k_scale = float(img_res) / dim # resized_dim / bbox_size in full image space
357
+ """
358
+ Intrinsics after data augmentation,
359
+ where (x1, y1) is the top-left corner of the bbox:
360
+ fx' = k*fx
361
+ fy' = k*fy
362
+ cx' = k*(cx - x1)
363
+ cy' = k*(cy - y1)
364
+ """
365
+ intrx[0, 0] *= k_scale # k*fx
366
+ intrx[1, 1] *= k_scale # k*fy
367
+ intrx[0, 2] -= bbox_cx - dim / 2.0
368
+ intrx[1, 2] -= bbox_cy - dim / 2.0
369
+ intrx[0, 2] *= k_scale
370
+ intrx[1, 2] *= k_scale
371
+ return intrx
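A minimal sketch of the cropping and keypoint-normalization pipeline above (illustrative values, assuming common.data_utils is importable from the repo root):

```python
import cv2
import numpy as np

from common.data_utils import (
    augm_params,
    generate_patch_image,
    normalize_kp2d_np,
    unnormalize_2d_kp,
)

img = (np.random.rand(480, 640, 3) * 255).astype(np.float32)  # dummy RGB frame
bbox = [320.0, 240.0, 200.0, 200.0]  # (cx, cy, w, h) in pixels

# is_train=False returns identity augmentation (no noise, rotation or scaling);
# with is_train=True the pn/rot/sc entries are sampled as described above.
augm = augm_params(
    is_train=False, flip_prob=0.0, noise_factor=0.4, rot_factor=30.0, scale_factor=0.25
)

patch, trans, inv_trans = generate_patch_image(
    img, bbox, augm["sc"], augm["rot"], (224, 224), cv2.INTER_CUBIC
)
print(patch.shape)  # (224, 224, 3)

# Keypoints are (x, y, confidence); x and y are mapped to [-1, 1] and back.
kp2d = np.array([[112.0, 112.0, 1.0], [10.0, 200.0, 1.0]])
kp_norm = normalize_kp2d_np(kp2d, img_res=224)
kp_back = unnormalize_2d_kp(kp_norm, res=224)
print(np.allclose(kp2d[:, :2], kp_back[:, :2]))  # True
```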
common/ld_utils.py ADDED
@@ -0,0 +1,116 @@
1
+ import itertools
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def sort_dict(disordered):
8
+ sorted_dict = {k: disordered[k] for k in sorted(disordered)}
9
+ return sorted_dict
10
+
11
+
12
+ def prefix_dict(mydict, prefix):
13
+ out = {prefix + k: v for k, v in mydict.items()}
14
+ return out
15
+
16
+
17
+ def postfix_dict(mydict, postfix):
18
+ out = {k + postfix: v for k, v in mydict.items()}
19
+ return out
20
+
21
+
22
+ def unsort(L, sort_idx):
23
+ assert isinstance(sort_idx, list)
24
+ assert isinstance(L, list)
25
+ LL = zip(sort_idx, L)
26
+ LL = sorted(LL, key=lambda x: x[0])
27
+ _, L = zip(*LL)
28
+ return list(L)
29
+
30
+
31
+ def cat_dl(out_list, dim, verbose=True, squeeze=True):
32
+ out = {}
33
+ for key, val in out_list.items():
34
+ if isinstance(val[0], torch.Tensor):
35
+ out[key] = torch.cat(val, dim=dim)
36
+ if squeeze:
37
+ out[key] = out[key].squeeze()
38
+ elif isinstance(val[0], np.ndarray):
39
+ out[key] = np.concatenate(val, axis=dim)
40
+ if squeeze:
41
+ out[key] = np.squeeze(out[key])
42
+ elif isinstance(val[0], list):
43
+ out[key] = sum(val, [])
44
+ else:
45
+ if verbose:
46
+ print(f"Ignoring {key} undefined type {type(val[0])}")
47
+ return out
48
+
49
+
50
+ def stack_dl(out_list, dim, verbose=True, squeeze=True):
51
+ out = {}
52
+ for key, val in out_list.items():
53
+ if isinstance(val[0], torch.Tensor):
54
+ out[key] = torch.stack(val, dim=dim)
55
+ if squeeze:
56
+ out[key] = out[key].squeeze()
57
+ elif isinstance(val[0], np.ndarray):
58
+ out[key] = np.stack(val, axis=dim)
59
+ if squeeze:
60
+ out[key] = np.squeeze(out[key])
61
+ elif isinstance(val[0], list):
62
+ out[key] = sum(val, [])
63
+ else:
64
+ out[key] = val
65
+ if verbose:
66
+ print(f"Processing {key} undefined type {type(val[0])}")
67
+ return out
68
+
69
+
70
+ def add_prefix_postfix(mydict, prefix="", postfix=""):
71
+ assert isinstance(mydict, dict)
72
+ return dict((prefix + key + postfix, value) for (key, value) in mydict.items())
73
+
74
+
75
+ def ld2dl(LD):
76
+ assert isinstance(LD, list)
77
+ assert isinstance(LD[0], dict)
78
+ """
79
+ A list of dict (same keys) to a dict of lists
80
+ """
81
+ dict_list = {k: [dic[k] for dic in LD] for k in LD[0]}
82
+ return dict_list
83
+
84
+
85
+ class NameSpace(object):
86
+ def __init__(self, adict):
87
+ self.__dict__.update(adict)
88
+
89
+
90
+ def dict2ns(mydict):
91
+ """
92
+ Convert dict objec to namespace
93
+ """
94
+ return NameSpace(mydict)
95
+
96
+
97
+ def ld2dev(ld, dev):
98
+ """
99
+ Convert tensors in a list or dict to a device recursively
100
+ """
101
+ if isinstance(ld, torch.Tensor):
102
+ return ld.to(dev)
103
+ if isinstance(ld, dict):
104
+ for k, v in ld.items():
105
+ ld[k] = ld2dev(v, dev)
106
+ return ld
107
+ if isinstance(ld, list):
108
+ return [ld2dev(x, dev) for x in ld]
109
+ return ld
110
+
111
+
112
+ def all_comb_dict(hyper_dict):
113
+ assert isinstance(hyper_dict, dict)
114
+ keys, values = zip(*hyper_dict.items())
115
+ permute_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
116
+ return permute_dicts
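A short sketch of the list-of-dicts helpers above (assumed usage with dummy tensors):

```python
import torch

from common.ld_utils import all_comb_dict, ld2dl, prefix_dict, stack_dl

batches = [
    {"loss": torch.tensor(0.5), "names": ["a"]},
    {"loss": torch.tensor(0.3), "names": ["b"]},
]
dl = ld2dl(batches)        # {"loss": [tensor, tensor], "names": [["a"], ["b"]]}
out = stack_dl(dl, dim=0)  # tensors stacked, nested lists concatenated
print(out["loss"].shape, out["names"])  # torch.Size([2]) ['a', 'b']

print(prefix_dict({"mpjpe": 1.0}, "val/"))  # {'val/mpjpe': 1.0}

# every hyper-parameter combination as its own dict
grid = all_comb_dict({"lr": [1e-4, 1e-5], "batch_size": [32, 64]})
print(len(grid))  # 4
```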
common/list_utils.py ADDED
@@ -0,0 +1,52 @@
1
+ import math
2
+
3
+
4
+ def chunks_by_len(L, n):
5
+ """
6
+ Split a list into at most n chunks of near-equal size.
7
+ """
8
+ num_chunks = int(math.ceil(float(len(L)) / n))
9
+ splits = [L[x : x + num_chunks] for x in range(0, len(L), num_chunks)]
10
+ return splits
11
+
12
+
13
+ def chunks_by_size(L, n):
14
+ """Yield successive n-sized chunks from lst."""
15
+ seqs = []
16
+ for i in range(0, len(L), n):
17
+ seqs.append(L[i : i + n])
18
+ return seqs
19
+
20
+
21
+ def unsort(L, sort_idx):
22
+ assert isinstance(sort_idx, list)
23
+ assert isinstance(L, list)
24
+ LL = zip(sort_idx, L)
25
+ LL = sorted(LL, key=lambda x: x[0])
26
+ _, L = zip(*LL)
27
+ return list(L)
28
+
29
+
30
+ def add_prefix_postfix(mydict, prefix="", postfix=""):
31
+ assert isinstance(mydict, dict)
32
+ return dict((prefix + key + postfix, value) for (key, value) in mydict.items())
33
+
34
+
35
+ def ld2dl(LD):
36
+ assert isinstance(LD, list)
37
+ assert isinstance(LD[0], dict)
38
+ """
39
+ A list of dict (same keys) to a dict of lists
40
+ """
41
+ dict_list = {k: [dic[k] for dic in LD] for k in LD[0]}
42
+ return dict_list
43
+
44
+
45
+ def chunks(lst, n):
46
+ """Yield successive n-sized chunks from lst."""
47
+ seqs = []
48
+ for i in range(0, len(lst), n):
49
+ seqs.append(lst[i : i + n])
50
+ seqs_chunked = sum(seqs, [])
51
+ assert set(seqs_chunked) == set(lst)
52
+ return seqs
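A short sketch of the chunking helpers above (assumed usage):

```python
from common.list_utils import chunks_by_len, chunks_by_size, unsort

L = list(range(10))
print(chunks_by_len(L, 3))   # at most 3 chunks: [0..3], [4..7], [8, 9]
print(chunks_by_size(L, 3))  # chunks of size 3: [0,1,2], [3,4,5], [6,7,8], [9]

# unsort places the i-th element back at position sort_idx[i]
sort_idx = [2, 0, 1]
sorted_vals = ["c", "a", "b"]
print(unsort(sorted_vals, sort_idx))  # ['a', 'b', 'c']
```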
common/mesh.py ADDED
@@ -0,0 +1,94 @@
1
+ import numpy as np
2
+ import trimesh
3
+
4
+ colors = {
5
+ "pink": [1.00, 0.75, 0.80],
6
+ "purple": [0.63, 0.13, 0.94],
7
+ "red": [1.0, 0.0, 0.0],
8
+ "green": [0.0, 1.0, 0.0],
9
+ "yellow": [1.0, 1.0, 0],
10
+ "brown": [1.00, 0.25, 0.25],
11
+ "blue": [0.0, 0.0, 1.0],
12
+ "white": [1.0, 1.0, 1.0],
13
+ "orange": [1.00, 0.65, 0.00],
14
+ "grey": [0.75, 0.75, 0.75],
15
+ "black": [0.0, 0.0, 0.0],
16
+ }
17
+
18
+
19
+ class Mesh(trimesh.Trimesh):
20
+ def __init__(
21
+ self,
22
+ filename=None,
23
+ v=None,
24
+ f=None,
25
+ vc=None,
26
+ fc=None,
27
+ process=False,
28
+ visual=None,
29
+ **kwargs
30
+ ):
31
+ if filename is not None:
32
+ mesh = trimesh.load(filename, process=process)
33
+ v = mesh.vertices
34
+ f = mesh.faces
35
+ visual = mesh.visual
36
+
37
+ super(Mesh, self).__init__(
38
+ vertices=v, faces=f, visual=visual, process=process, **kwargs
39
+ )
40
+
41
+ self.v = self.vertices
42
+ self.f = self.faces
43
+ assert self.v is self.vertices
44
+ assert self.f is self.faces
45
+
46
+ if vc is not None:
47
+ self.set_vc(vc)
48
+ self.vc = self.visual.vertex_colors
49
+ assert self.vc is self.visual.vertex_colors
50
+ if fc is not None:
51
+ self.set_fc(fc)
52
+ self.fc = self.visual.face_colors
53
+ assert self.fc is self.visual.face_colors
54
+
55
+ def rot_verts(self, vertices, rxyz):
56
+ return np.array(vertices * rxyz.T)
57
+
58
+ def colors_like(self, color, array, ids):
59
+ color = np.array(color)
60
+
61
+ if color.max() <= 1.0:
62
+ color = color * 255
63
+ color = color.astype(np.int8)
64
+
65
+ n_color = color.shape[0]
66
+ n_ids = ids.shape[0]
67
+
68
+ new_color = np.array(array)
69
+ if n_color <= 4:
70
+ new_color[ids, :n_color] = np.repeat(color[np.newaxis], n_ids, axis=0)
71
+ else:
72
+ new_color[ids, :] = color
73
+
74
+ return new_color
75
+
76
+ def set_vc(self, vc, vertex_ids=None):
77
+ all_ids = np.arange(self.vertices.shape[0])
78
+ if vertex_ids is None:
79
+ vertex_ids = all_ids
80
+
81
+ vertex_ids = all_ids[vertex_ids]
82
+ new_vc = self.colors_like(vc, self.visual.vertex_colors, vertex_ids)
83
+ self.visual.vertex_colors[:] = new_vc
84
+
85
+ def set_fc(self, fc, face_ids=None):
86
+ if face_ids is None:
87
+ face_ids = np.arange(self.faces.shape[0])
88
+
89
+ new_fc = self.colors_like(fc, self.visual.face_colors, face_ids)
90
+ self.visual.face_colors[:] = new_fc
91
+
92
+ @staticmethod
93
+ def cat(meshes):
94
+ return trimesh.util.concatenate(meshes)
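A minimal sketch of the Mesh wrapper above, built from a trimesh primitive instead of a file on disk (assumed usage):

```python
import trimesh

from common.mesh import Mesh, colors

box = trimesh.creation.box(extents=(0.1, 0.1, 0.1))
m = Mesh(v=box.vertices, f=box.faces, vc=colors["red"])  # per-vertex color

combined = Mesh.cat([m, box])   # plain trimesh concatenation of both meshes
print(combined.vertices.shape)  # (16, 3)
```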
common/metrics.py ADDED
@@ -0,0 +1,51 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ def compute_v2v_dist_no_reduce(v3d_cam_gt, v3d_cam_pred, is_valid):
8
+ assert isinstance(v3d_cam_gt, list)
9
+ assert isinstance(v3d_cam_pred, list)
10
+ assert len(v3d_cam_gt) == len(v3d_cam_pred)
11
+ assert len(v3d_cam_gt) == len(is_valid)
12
+ v2v = []
13
+ for v_gt, v_pred, valid in zip(v3d_cam_gt, v3d_cam_pred, is_valid):
14
+ if valid:
15
+ dist = ((v_gt - v_pred) ** 2).sum(dim=1).sqrt().cpu().numpy() # meter
16
+ else:
17
+ dist = None
18
+ v2v.append(dist)
19
+ return v2v
20
+
21
+
22
+ def compute_joint3d_error(joints3d_cam_gt, joints3d_cam_pred, valid_jts):
23
+ valid_jts = valid_jts.view(-1)
24
+ assert joints3d_cam_gt.shape == joints3d_cam_pred.shape
25
+ assert joints3d_cam_gt.shape[0] == valid_jts.shape[0]
26
+ dist = ((joints3d_cam_gt - joints3d_cam_pred) ** 2).sum(dim=2).sqrt()
27
+ invalid_idx = torch.nonzero((1 - valid_jts).long()).view(-1)
28
+ dist[invalid_idx, :] = float("nan")
29
+ dist = dist.cpu().numpy()
30
+ return dist
31
+
32
+
33
+ def compute_mrrpe(root_r_gt, root_l_gt, root_r_pred, root_l_pred, is_valid):
34
+ rel_vec_gt = root_l_gt - root_r_gt
35
+ rel_vec_pred = root_l_pred - root_r_pred
36
+
37
+ invalid_idx = torch.nonzero((1 - is_valid).long()).view(-1)
38
+ mrrpe = ((rel_vec_pred - rel_vec_gt) ** 2).sum(dim=1).sqrt()
39
+ mrrpe[invalid_idx] = float("nan")
40
+ mrrpe = mrrpe.cpu().numpy()
41
+ return mrrpe
42
+
43
+
44
+ def compute_arti_deg_error(pred_radian, gt_radian):
45
+ assert pred_radian.shape == gt_radian.shape
46
+
47
+ # articulation error in degree
48
+ pred_degree = pred_radian / math.pi * 180 # degree
49
+ gt_degree = gt_radian / math.pi * 180 # degree
50
+ err_deg = torch.abs(pred_degree - gt_degree).tolist()
51
+ return np.array(err_deg, dtype=np.float32)
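A minimal sketch of the metric helpers above with random tensors (assumed usage; distances are in metres):

```python
import torch

from common.metrics import compute_joint3d_error, compute_mrrpe

B, J = 4, 21
gt = torch.rand(B, J, 3)
pred = gt + 0.01 * torch.randn(B, J, 3)
valid = torch.tensor([1.0, 1.0, 0.0, 1.0])  # third sample is invalid

err = compute_joint3d_error(gt, pred, valid)  # (B, J), NaN rows where invalid
print(err.shape)

mrrpe = compute_mrrpe(
    root_r_gt=torch.rand(B, 3),
    root_l_gt=torch.rand(B, 3),
    root_r_pred=torch.rand(B, 3),
    root_l_pred=torch.rand(B, 3),
    is_valid=valid,
)
print(mrrpe.shape)  # (4,)
```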
common/np_utils.py ADDED
@@ -0,0 +1,7 @@
1
+ import numpy as np
2
+
3
+
4
+ def permute_np(x, idx):
5
+ original_perm = tuple(range(len(x.shape)))
6
+ x = np.moveaxis(x, original_perm, idx)
7
+ return x
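A short sketch (assumed usage): permute_np moves axis i of x to position idx[i] (np.moveaxis semantics), which is the inverse convention of np.transpose:

```python
import numpy as np

from common.np_utils import permute_np

x = np.zeros((2, 3, 4))
print(permute_np(x, (1, 2, 0)).shape)    # (4, 2, 3): axis 0 -> pos 1, 1 -> 2, 2 -> 0
print(np.transpose(x, (1, 2, 0)).shape)  # (3, 4, 2)
```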
common/object_tensors.py ADDED
@@ -0,0 +1,293 @@
1
+ import json
2
+ import os.path as op
3
+ import sys
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import trimesh
9
+ from easydict import EasyDict
10
+ from scipy.spatial.distance import cdist
11
+
12
+ sys.path = [".."] + sys.path
13
+ import common.thing as thing
14
+ from common.rot import axis_angle_to_quaternion, quaternion_apply
15
+ from common.torch_utils import pad_tensor_list
16
+ from common.xdict import xdict
17
+
18
+ # objects to consider for training so far
19
+ OBJECTS = [
20
+ "capsulemachine",
21
+ "box",
22
+ "ketchup",
23
+ "laptop",
24
+ "microwave",
25
+ "mixer",
26
+ "notebook",
27
+ "espressomachine",
28
+ "waffleiron",
29
+ "scissors",
30
+ "phone",
31
+ ]
32
+
33
+
34
+ class ObjectTensors(nn.Module):
35
+ def __init__(self):
36
+ super(ObjectTensors, self).__init__()
37
+ self.obj_tensors = thing.thing2dev(construct_obj_tensors(OBJECTS), "cpu")
38
+ self.dev = None
39
+
40
+ def forward_7d_batch(
41
+ self,
42
+ angles: (None, torch.Tensor),
43
+ global_orient: (None, torch.Tensor),
44
+ transl: (None, torch.Tensor),
45
+ query_names: list,
46
+ fwd_template: bool,
47
+ ):
48
+ self._sanity_check(angles, global_orient, transl, query_names, fwd_template)
49
+
50
+ # store output
51
+ out = xdict()
52
+
53
+ # meta info
54
+ obj_idx = np.array(
55
+ [self.obj_tensors["names"].index(name) for name in query_names]
56
+ )
57
+ out["diameter"] = self.obj_tensors["diameter"][obj_idx]
58
+ out["f"] = self.obj_tensors["f"][obj_idx]
59
+ out["f_len"] = self.obj_tensors["f_len"][obj_idx]
60
+ out["v_len"] = self.obj_tensors["v_len"][obj_idx]
61
+
62
+ max_len = out["v_len"].max()
63
+ out["v"] = self.obj_tensors["v"][obj_idx][:, :max_len]
64
+ out["mask"] = self.obj_tensors["mask"][obj_idx][:, :max_len]
65
+ out["v_sub"] = self.obj_tensors["v_sub"][obj_idx]
66
+ out["parts_ids"] = self.obj_tensors["parts_ids"][obj_idx][:, :max_len]
67
+ out["parts_sub_ids"] = self.obj_tensors["parts_sub_ids"][obj_idx]
68
+
69
+ if fwd_template:
70
+ return out
71
+
72
+ # articulation + global rotation
73
+ quat_arti = axis_angle_to_quaternion(self.obj_tensors["z_axis"] * angles)
74
+ quat_global = axis_angle_to_quaternion(global_orient.view(-1, 3))
75
+
76
+ # mm
77
+ # collect entities to be transformed
78
+ tf_dict = xdict()
79
+ tf_dict["v_top"] = out["v"].clone()
80
+ tf_dict["v_sub_top"] = out["v_sub"].clone()
81
+ tf_dict["v_bottom"] = out["v"].clone()
82
+ tf_dict["v_sub_bottom"] = out["v_sub"].clone()
83
+ tf_dict["bbox_top"] = self.obj_tensors["bbox_top"][obj_idx]
84
+ tf_dict["bbox_bottom"] = self.obj_tensors["bbox_bottom"][obj_idx]
85
+ tf_dict["kp_top"] = self.obj_tensors["kp_top"][obj_idx]
86
+ tf_dict["kp_bottom"] = self.obj_tensors["kp_bottom"][obj_idx]
87
+
88
+ # articulate top parts
89
+ for key, val in tf_dict.items():
90
+ if "top" in key:
91
+ val_rot = quaternion_apply(quat_arti[:, None, :], val)
92
+ tf_dict.overwrite(key, val_rot)
93
+
94
+ # global rotation for all
95
+ for key, val in tf_dict.items():
96
+ val_rot = quaternion_apply(quat_global[:, None, :], val)
97
+ if transl is not None:
98
+ val_rot = val_rot + transl[:, None, :]
99
+ tf_dict.overwrite(key, val_rot)
100
+
101
+ # prep output
102
+ top_idx = out["parts_ids"] == 1
103
+ v_tensor = tf_dict["v_bottom"].clone()
104
+ v_tensor[top_idx, :] = tf_dict["v_top"][top_idx, :]
105
+
106
+ top_idx = out["parts_sub_ids"] == 1
107
+ v_sub_tensor = tf_dict["v_sub_bottom"].clone()
108
+ v_sub_tensor[top_idx, :] = tf_dict["v_sub_top"][top_idx, :]
109
+
110
+ bbox = torch.cat((tf_dict["bbox_top"], tf_dict["bbox_bottom"]), dim=1)
111
+ kp3d = torch.cat((tf_dict["kp_top"], tf_dict["kp_bottom"]), dim=1)
112
+
113
+ out.overwrite("v", v_tensor)
114
+ out.overwrite("v_sub", v_sub_tensor)
115
+ out.overwrite("bbox3d", bbox)
116
+ out.overwrite("kp3d", kp3d)
117
+ return out
118
+
119
+ def forward(self, angles, global_orient, transl, query_names):
120
+ out = self.forward_7d_batch(
121
+ angles, global_orient, transl, query_names, fwd_template=False
122
+ )
123
+ return out
124
+
125
+ def forward_template(self, query_names):
126
+ out = self.forward_7d_batch(
127
+ angles=None,
128
+ global_orient=None,
129
+ transl=None,
130
+ query_names=query_names,
131
+ fwd_template=True,
132
+ )
133
+ return out
134
+
135
+ def to(self, dev):
136
+ self.obj_tensors = thing.thing2dev(self.obj_tensors, dev)
137
+ self.dev = dev
138
+
139
+ def _sanity_check(self, angles, global_orient, transl, query_names, fwd_template):
140
+ # sanity check
141
+ if not fwd_template:
142
+ # assume transl is in meter
143
+ if transl is not None:
144
+ transl = transl * 1000 # mm
145
+
146
+ batch_size = angles.shape[0]
147
+ assert angles.shape == (batch_size, 1)
148
+ assert global_orient.shape == (batch_size, 3)
149
+ if transl is not None:
150
+ assert isinstance(transl, torch.Tensor)
151
+ assert transl.shape == (batch_size, 3)
152
+ assert len(query_names) == batch_size
153
+
154
+
155
+ def construct_obj(object_model_p):
156
+ # load vtemplate
157
+ mesh_p = op.join(object_model_p, "mesh.obj")
158
+ parts_p = op.join(object_model_p, f"parts.json")
159
+ json_p = op.join(object_model_p, "object_params.json")
160
+ obj_name = op.basename(object_model_p)
161
+
162
+ top_sub_p = f"./data/arctic_data/data/meta/object_vtemplates/{obj_name}/top_keypoints_300.json"
163
+ bottom_sub_p = top_sub_p.replace("top_", "bottom_")
164
+ with open(top_sub_p, "r") as f:
165
+ sub_top = np.array(json.load(f)["keypoints"])
166
+
167
+ with open(bottom_sub_p, "r") as f:
168
+ sub_bottom = np.array(json.load(f)["keypoints"])
169
+ sub_v = np.concatenate((sub_top, sub_bottom), axis=0)
170
+
171
+ with open(parts_p, "r") as f:
172
+ parts = np.array(json.load(f), dtype=bool)
173
+
174
+ assert op.exists(mesh_p), f"Not found: {mesh_p}"
175
+
176
+ mesh = trimesh.exchange.load.load_mesh(mesh_p, process=False)
177
+ mesh_v = mesh.vertices
178
+
179
+ mesh_f = torch.LongTensor(mesh.faces)
180
+ vidx = np.argmin(cdist(sub_v, mesh_v, metric="euclidean"), axis=1)
181
+ parts_sub = parts[vidx]
182
+
183
+ vsk = object_model_p.split("/")[-1]
184
+
185
+ with open(json_p, "r") as f:
186
+ params = json.load(f)
187
+ rest = EasyDict()
188
+ rest.top = np.array(params["mocap_top"])
189
+ rest.bottom = np.array(params["mocap_bottom"])
190
+ bbox_top = np.array(params["bbox_top"])
191
+ bbox_bottom = np.array(params["bbox_bottom"])
192
+ kp_top = np.array(params["keypoints_top"])
193
+ kp_bottom = np.array(params["keypoints_bottom"])
194
+
195
+ np.random.seed(1)
196
+
197
+ obj = EasyDict()
198
+ obj.name = vsk
199
+ obj.obj_name = "".join([i for i in vsk if not i.isdigit()])
200
+ obj.v = torch.FloatTensor(mesh_v)
201
+ obj.v_sub = torch.FloatTensor(sub_v)
202
+ obj.f = torch.LongTensor(mesh_f)
203
+ obj.parts = torch.LongTensor(parts)
204
+ obj.parts_sub = torch.LongTensor(parts_sub)
205
+
206
+ with open("./data/arctic_data/data/meta/object_meta.json", "r") as f:
207
+ object_meta = json.load(f)
208
+ obj.diameter = torch.FloatTensor(np.array(object_meta[obj.obj_name]["diameter"]))
209
+ obj.bbox_top = torch.FloatTensor(bbox_top)
210
+ obj.bbox_bottom = torch.FloatTensor(bbox_bottom)
211
+ obj.kp_top = torch.FloatTensor(kp_top)
212
+ obj.kp_bottom = torch.FloatTensor(kp_bottom)
213
+ obj.mocap_top = torch.FloatTensor(np.array(params["mocap_top"]))
214
+ obj.mocap_bottom = torch.FloatTensor(np.array(params["mocap_bottom"]))
215
+ return obj
216
+
217
+
218
+ def construct_obj_tensors(object_names):
219
+ obj_list = []
220
+ for k in object_names:
221
+ object_model_p = f"./data/arctic_data/data/meta/object_vtemplates/%s" % (k)
222
+ obj = construct_obj(object_model_p)
223
+ obj_list.append(obj)
224
+
225
+ bbox_top_list = []
226
+ bbox_bottom_list = []
227
+ mocap_top_list = []
228
+ mocap_bottom_list = []
229
+ kp_top_list = []
230
+ kp_bottom_list = []
231
+ v_list = []
232
+ v_sub_list = []
233
+ f_list = []
234
+ parts_list = []
235
+ parts_sub_list = []
236
+ diameter_list = []
237
+ for obj in obj_list:
238
+ v_list.append(obj.v)
239
+ v_sub_list.append(obj.v_sub)
240
+ f_list.append(obj.f)
241
+
242
+ # root_list.append(obj.root)
243
+ bbox_top_list.append(obj.bbox_top)
244
+ bbox_bottom_list.append(obj.bbox_bottom)
245
+ kp_top_list.append(obj.kp_top)
246
+ kp_bottom_list.append(obj.kp_bottom)
247
+ mocap_top_list.append(obj.mocap_top / 1000)
248
+ mocap_bottom_list.append(obj.mocap_bottom / 1000)
249
+ parts_list.append(obj.parts + 1)
250
+ parts_sub_list.append(obj.parts_sub + 1)
251
+ diameter_list.append(obj.diameter)
252
+
253
+ v_list, v_len_list = pad_tensor_list(v_list)
254
+ p_list, p_len_list = pad_tensor_list(parts_list)
255
+ ps_list = torch.stack(parts_sub_list, dim=0)
256
+ assert (p_len_list - v_len_list).sum() == 0
257
+
258
+ max_len = v_len_list.max()
259
+ mask = torch.zeros(len(obj_list), max_len)
260
+ for idx, vlen in enumerate(v_len_list):
261
+ mask[idx, :vlen] = 1.0
262
+
263
+ v_sub_list = torch.stack(v_sub_list, dim=0)
264
+ diameter_list = torch.stack(diameter_list, dim=0)
265
+
266
+ f_list, f_len_list = pad_tensor_list(f_list)
267
+
268
+ bbox_top_list = torch.stack(bbox_top_list, dim=0)
269
+ bbox_bottom_list = torch.stack(bbox_bottom_list, dim=0)
270
+ kp_top_list = torch.stack(kp_top_list, dim=0)
271
+ kp_bottom_list = torch.stack(kp_bottom_list, dim=0)
272
+
273
+ obj_tensors = {}
274
+ obj_tensors["names"] = object_names
275
+ obj_tensors["parts_ids"] = p_list
276
+ obj_tensors["parts_sub_ids"] = ps_list
277
+
278
+ obj_tensors["v"] = v_list.float() / 1000
279
+ obj_tensors["v_sub"] = v_sub_list.float() / 1000
280
+ obj_tensors["v_len"] = v_len_list
281
+ obj_tensors["f"] = f_list
282
+ obj_tensors["f_len"] = f_len_list
283
+ obj_tensors["diameter"] = diameter_list.float()
284
+
285
+ obj_tensors["mask"] = mask
286
+ obj_tensors["bbox_top"] = bbox_top_list.float() / 1000
287
+ obj_tensors["bbox_bottom"] = bbox_bottom_list.float() / 1000
288
+ obj_tensors["kp_top"] = kp_top_list.float() / 1000
289
+ obj_tensors["kp_bottom"] = kp_bottom_list.float() / 1000
290
+ obj_tensors["mocap_top"] = mocap_top_list
291
+ obj_tensors["mocap_bottom"] = mocap_bottom_list
292
+ obj_tensors["z_axis"] = torch.FloatTensor(np.array([0, 0, -1])).view(1, 3)
293
+ return obj_tensors
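A minimal usage sketch for the object layer above. It assumes the ARCTIC meta files under ./data/arctic_data/data/meta/ have been downloaded, since ObjectTensors loads every object template from disk at construction time:

```python
import torch

from common.object_tensors import ObjectTensors

obj_layer = ObjectTensors()
obj_layer.to("cpu")

B = 2
out = obj_layer(
    angles=torch.zeros(B, 1),         # articulation angle in radians
    global_orient=torch.zeros(B, 3),  # axis-angle global rotation
    transl=torch.zeros(B, 3),         # translation in metres
    query_names=["box", "laptop"],
)
# padded vertices, per-vertex validity mask and 3D keypoints (top + bottom parts)
print(out["v"].shape, out["mask"].shape, out["kp3d"].shape)
```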
common/pl_utils.py ADDED
@@ -0,0 +1,63 @@
1
+ import random
2
+ import time
3
+
4
+ import torch
5
+
6
+ import common.thing as thing
7
+ from common.ld_utils import ld2dl
8
+
9
+
10
+ def reweight_loss_by_keys(loss_dict, keys, alpha):
11
+ for key in keys:
12
+ val, weight = loss_dict[key]
13
+ weight_new = weight * alpha
14
+ loss_dict[key] = (val, weight_new)
15
+ return loss_dict
16
+
17
+
18
+ def select_loss_group(groups, agent_id, alphas):
19
+ random.seed(1)
20
+ random.shuffle(groups)
21
+
22
+ keys = groups[agent_id % len(groups)]
23
+
24
+ random.seed(time.time())
25
+ alpha = random.choice(alphas)
26
+ random.seed(1)
27
+ return keys, alpha
28
+
29
+
30
+ def push_checkpoint_metric(key, val):
31
+ val = float(val)
32
+ checkpt_metric = torch.FloatTensor([val])
33
+ result = {key: checkpt_metric}
34
+ return result
35
+
36
+
37
+ def avg_losses_cpu(outputs):
38
+ outputs = ld2dl(outputs)
39
+ for key, val in outputs.items():
40
+ val = [v.cpu() for v in val]
41
+ val = torch.cat(val, dim=0).view(-1)
42
+ outputs[key] = val.mean()
43
+ return outputs
44
+
45
+
46
+ def reform_outputs(out_list):
47
+ out_list_dict = ld2dl(out_list)
48
+ outputs = ld2dl(out_list_dict["out_dict"])
49
+ losses = ld2dl(out_list_dict["loss"])
50
+
51
+ for k, tensor in outputs.items():
52
+ if isinstance(tensor[0], list):
53
+ outputs[k] = sum(tensor, [])
54
+ else:
55
+ outputs[k] = torch.cat(tensor)
56
+
57
+ for k, tensor in losses.items():
58
+ tensor = [ten.view(-1) for ten in tensor]
59
+ losses[k] = torch.cat(tensor)
60
+
61
+ outputs = {k: thing.thing2np(v) for k, v in outputs.items()}
62
+ loss_dict = {k: v.mean().item() for k, v in losses.items()}
63
+ return outputs, loss_dict
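A short sketch of the loss bookkeeping helpers above (assumed usage with dummy values):

```python
import torch

from common.pl_utils import avg_losses_cpu, push_checkpoint_metric, reweight_loss_by_keys

loss_dict = {
    "loss/kp2d": (torch.tensor(0.8), 10.0),  # (value, weight) pairs
    "loss/pose": (torch.tensor(0.2), 1.0),
}
loss_dict = reweight_loss_by_keys(loss_dict, keys=["loss/kp2d"], alpha=0.5)
print(loss_dict["loss/kp2d"][1])  # weight halved to 5.0

print(push_checkpoint_metric("mpjpe", 42.0))  # {'mpjpe': tensor([42.])}

# average per-step losses collected over a validation epoch
steps = [{"loss": torch.tensor([0.5])}, {"loss": torch.tensor([0.3])}]
print(avg_losses_cpu(steps)["loss"])  # tensor(0.4000)
```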
common/rend_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ import copy
2
+ import os
3
+
4
+ import numpy as np
5
+ import pyrender
6
+ import trimesh
7
+
8
+ # offline rendering
9
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
10
+
11
+
12
+ def flip_meshes(meshes):
13
+ rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
14
+ for mesh in meshes:
15
+ mesh.apply_transform(rot)
16
+ return meshes
17
+
18
+
19
+ def color2material(mesh_color: list):
20
+ material = pyrender.MetallicRoughnessMaterial(
21
+ metallicFactor=0.1,
22
+ alphaMode="OPAQUE",
23
+ baseColorFactor=(
24
+ mesh_color[0] / 255.0,
25
+ mesh_color[1] / 255.0,
26
+ mesh_color[2] / 255.0,
27
+ 0.5,
28
+ ),
29
+ )
30
+ return material
31
+
32
+
33
+ class Renderer:
34
+ def __init__(self, img_res: int) -> None:
35
+ self.renderer = pyrender.OffscreenRenderer(
36
+ viewport_width=img_res, viewport_height=img_res, point_size=1.0
37
+ )
38
+
39
+ self.img_res = img_res
40
+
41
+ def render_meshes_pose(
42
+ self,
43
+ meshes,
44
+ image=None,
45
+ cam_transl=None,
46
+ cam_center=None,
47
+ K=None,
48
+ materials=None,
49
+ sideview_angle=None,
50
+ ):
51
+ # unpack
52
+ if cam_transl is not None:
53
+ cam_trans = np.copy(cam_transl)
54
+ cam_trans[0] *= -1.0
55
+ else:
56
+ cam_trans = None
57
+ meshes = copy.deepcopy(meshes)
58
+ meshes = flip_meshes(meshes)
59
+
60
+ if sideview_angle is not None:
61
+ # center around the final mesh
62
+ anchor_mesh = meshes[-1]
63
+ center = anchor_mesh.vertices.mean(axis=0)
64
+
65
+ rot = trimesh.transformations.rotation_matrix(
66
+ np.radians(sideview_angle), [0, 1, 0]
67
+ )
68
+ out_meshes = []
69
+ for mesh in copy.deepcopy(meshes):
70
+ mesh.vertices -= center
71
+ mesh.apply_transform(rot)
72
+ mesh.vertices += center
73
+ # further away to see more
74
+ mesh.vertices += np.array([0, 0, -0.10])
75
+ out_meshes.append(mesh)
76
+ meshes = out_meshes
77
+
78
+ # setting up
79
+ self.create_scene()
80
+ self.setup_light()
81
+ self.position_camera(cam_trans, K)
82
+ if materials is not None:
83
+ meshes = [
84
+ pyrender.Mesh.from_trimesh(mesh, material=material)
85
+ for mesh, material in zip(meshes, materials)
86
+ ]
87
+ else:
88
+ meshes = [pyrender.Mesh.from_trimesh(mesh) for mesh in meshes]
89
+
90
+ for mesh in meshes:
91
+ self.scene.add(mesh)
92
+
93
+ color, valid_mask = self.render_rgb()
94
+ if image is None:
95
+ output_img = color[:, :, :3]
96
+ else:
97
+ output_img = self.overlay_image(color, valid_mask, image)
98
+ rend_img = (output_img * 255).astype(np.uint8)
99
+ return rend_img
100
+
101
+ def render_rgb(self):
102
+ color, rend_depth = self.renderer.render(
103
+ self.scene, flags=pyrender.RenderFlags.RGBA
104
+ )
105
+ color = color.astype(np.float32) / 255.0
106
+ valid_mask = (rend_depth > 0)[:, :, None]
107
+ return color, valid_mask
108
+
109
+ def overlay_image(self, color, valid_mask, image):
110
+ output_img = color[:, :, :3] * valid_mask + (1 - valid_mask) * image
111
+ return output_img
112
+
113
+ def position_camera(self, cam_transl, K):
114
+ camera_pose = np.eye(4)
115
+ if cam_transl is not None:
116
+ camera_pose[:3, 3] = cam_transl
117
+
118
+ fx = K[0, 0]
119
+ fy = K[1, 1]
120
+ cx = K[0, 2]
121
+ cy = K[1, 2]
122
+ camera = pyrender.IntrinsicsCamera(fx=fx, fy=fy, cx=cx, cy=cy)
123
+ self.scene.add(camera, pose=camera_pose)
124
+
125
+ def setup_light(self):
126
+ light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1)
127
+ light_pose = np.eye(4)
128
+
129
+ light_pose[:3, 3] = np.array([0, -1, 1])
130
+ self.scene.add(light, pose=light_pose)
131
+
132
+ light_pose[:3, 3] = np.array([0, 1, 1])
133
+ self.scene.add(light, pose=light_pose)
134
+
135
+ light_pose[:3, 3] = np.array([1, 1, 2])
136
+ self.scene.add(light, pose=light_pose)
137
+
138
+ def create_scene(self):
139
+ self.scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5))
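A minimal sketch of the offscreen renderer above (assumed usage; it needs a working EGL/GPU setup because PYOPENGL_PLATFORM is forced to "egl" at import time):

```python
import numpy as np
import trimesh

from common.rend_utils import Renderer, color2material

img_res = 256
K = np.array(
    [[500.0, 0.0, img_res / 2.0],
     [0.0, 500.0, img_res / 2.0],
     [0.0, 0.0, 1.0]]
)

# a small sphere in front of the camera (meshes are flipped 180 deg about x inside)
sphere = trimesh.creation.icosphere(radius=0.05)
sphere.vertices += np.array([0.0, 0.0, 0.5])

rend = Renderer(img_res)
img = rend.render_meshes_pose(
    meshes=[sphere],
    cam_transl=np.zeros(3),
    K=K,
    materials=[color2material([0, 255, 0])],
)
print(img.shape, img.dtype)  # (256, 256, 3) uint8
```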
common/rot.py ADDED
@@ -0,0 +1,782 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+ """
7
+ Taken from https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html
8
+ Copied here to avoid a hard dependency on pytorch3d.
9
+ """
10
+
11
+
12
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
13
+ """
14
+ Convert a unit quaternion to a standard form: one in which the real
15
+ part is non negative.
16
+
17
+ Args:
18
+ quaternions: Quaternions with real part first,
19
+ as tensor of shape (..., 4).
20
+
21
+ Returns:
22
+ Standardized quaternions as tensor of shape (..., 4).
23
+ """
24
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
25
+
26
+
27
+ def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
28
+ """
29
+ Multiply two quaternions representing rotations, returning the quaternion
30
+ representing their composition, i.e. the versor with nonnegative real part.
31
+ Usual torch rules for broadcasting apply.
32
+
33
+ Args:
34
+ a: Quaternions as tensor of shape (..., 4), real part first.
35
+ b: Quaternions as tensor of shape (..., 4), real part first.
36
+
37
+ Returns:
38
+ The product of a and b, a tensor of quaternions of shape (..., 4).
39
+ """
40
+ ab = quaternion_raw_multiply(a, b)
41
+ return standardize_quaternion(ab)
42
+
43
+
44
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
45
+ """
46
+ Returns torch.sqrt(torch.max(0, x))
47
+ but with a zero subgradient where x is 0.
48
+ """
49
+ ret = torch.zeros_like(x)
50
+ positive_mask = x > 0
51
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
52
+ return ret
53
+
54
+
55
+ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
56
+ """
57
+ Convert rotations given as quaternions to axis/angle.
58
+
59
+ Args:
60
+ quaternions: quaternions with real part first,
61
+ as tensor of shape (..., 4).
62
+
63
+ Returns:
64
+ Rotations given as a vector in axis angle form, as a tensor
65
+ of shape (..., 3), where the magnitude is the angle
66
+ turned anticlockwise in radians around the vector's
67
+ direction.
68
+ """
69
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
70
+ half_angles = torch.atan2(norms, quaternions[..., :1])
71
+ angles = 2 * half_angles
72
+ eps = 1e-6
73
+ small_angles = angles.abs() < eps
74
+ sin_half_angles_over_angles = torch.empty_like(angles)
75
+ sin_half_angles_over_angles[~small_angles] = (
76
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
77
+ )
78
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
79
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
80
+ sin_half_angles_over_angles[small_angles] = (
81
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
82
+ )
83
+ return quaternions[..., 1:] / sin_half_angles_over_angles
84
+
85
+
86
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
87
+ """
88
+ Convert rotations given as quaternions to rotation matrices.
89
+
90
+ Args:
91
+ quaternions: quaternions with real part first,
92
+ as tensor of shape (..., 4).
93
+
94
+ Returns:
95
+ Rotation matrices as tensor of shape (..., 3, 3).
96
+ """
97
+ r, i, j, k = torch.unbind(quaternions, -1)
98
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
99
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
100
+
101
+ o = torch.stack(
102
+ (
103
+ 1 - two_s * (j * j + k * k),
104
+ two_s * (i * j - k * r),
105
+ two_s * (i * k + j * r),
106
+ two_s * (i * j + k * r),
107
+ 1 - two_s * (i * i + k * k),
108
+ two_s * (j * k - i * r),
109
+ two_s * (i * k - j * r),
110
+ two_s * (j * k + i * r),
111
+ 1 - two_s * (i * i + j * j),
112
+ ),
113
+ -1,
114
+ )
115
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
116
+
117
+
118
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
119
+ """
120
+ Convert rotations given as rotation matrices to quaternions.
121
+
122
+ Args:
123
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
124
+
125
+ Returns:
126
+ quaternions with real part first, as tensor of shape (..., 4).
127
+ """
128
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
129
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
130
+
131
+ batch_dim = matrix.shape[:-2]
132
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
133
+ matrix.reshape(batch_dim + (9,)), dim=-1
134
+ )
135
+
136
+ q_abs = _sqrt_positive_part(
137
+ torch.stack(
138
+ [
139
+ 1.0 + m00 + m11 + m22,
140
+ 1.0 + m00 - m11 - m22,
141
+ 1.0 - m00 + m11 - m22,
142
+ 1.0 - m00 - m11 + m22,
143
+ ],
144
+ dim=-1,
145
+ )
146
+ )
147
+
148
+ # we produce the desired quaternion multiplied by each of r, i, j, k
149
+ quat_by_rijk = torch.stack(
150
+ [
151
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
152
+ # `int`.
153
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
154
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
155
+ # `int`.
156
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
157
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
158
+ # `int`.
159
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
160
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
161
+ # `int`.
162
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
163
+ ],
164
+ dim=-2,
165
+ )
166
+
167
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
168
+ # the candidate won't be picked.
169
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
170
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
171
+
172
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
173
+ # forall i; we pick the best-conditioned one (with the largest denominator)
174
+
175
+ return quat_candidates[
176
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
177
+ ].reshape(batch_dim + (4,))
178
+
179
+
180
+ def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
181
+ """
182
+ Convert rotations given as rotation matrices to axis/angle.
183
+
184
+ Args:
185
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
186
+
187
+ Returns:
188
+ Rotations given as a vector in axis angle form, as a tensor
189
+ of shape (..., 3), where the magnitude is the angle
190
+ turned anticlockwise in radians around the vector's
191
+ direction.
192
+ """
193
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
194
+
195
+
196
+ def rot_aa(aa, rot):
197
+ """Rotate axis angle parameters."""
198
+ # pose parameters
199
+ R = np.array(
200
+ [
201
+ [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
202
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
203
+ [0, 0, 1],
204
+ ]
205
+ )
206
+ # find the rotation of the body in camera frame
207
+ per_rdg, _ = cv2.Rodrigues(aa)
208
+ # apply the global rotation to the global orientation
209
+ resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
210
+ aa = (resrot.T)[0]
211
+ return aa
212
+
213
+
214
+ def quat2mat(quat):
215
+ """
216
+ This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L50
217
+ Convert quaternion coefficients to rotation matrix.
218
+ Args:
219
+ quat: size = [batch_size, 4] 4 <===>(w, x, y, z)
220
+ Returns:
221
+ Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
222
+ """
223
+ norm_quat = quat
224
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
225
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
226
+
227
+ batch_size = quat.size(0)
228
+
229
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
230
+ wx, wy, wz = w * x, w * y, w * z
231
+ xy, xz, yz = x * y, x * z, y * z
232
+
233
+ rotMat = torch.stack(
234
+ [
235
+ w2 + x2 - y2 - z2,
236
+ 2 * xy - 2 * wz,
237
+ 2 * wy + 2 * xz,
238
+ 2 * wz + 2 * xy,
239
+ w2 - x2 + y2 - z2,
240
+ 2 * yz - 2 * wx,
241
+ 2 * xz - 2 * wy,
242
+ 2 * wx + 2 * yz,
243
+ w2 - x2 - y2 + z2,
244
+ ],
245
+ dim=1,
246
+ ).view(batch_size, 3, 3)
247
+ return rotMat
248
+
249
+
250
+ def batch_aa2rot(axisang):
251
+ # This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L37
252
+ assert len(axisang.shape) == 2
253
+ assert axisang.shape[1] == 3
254
+ # axisang N x 3
255
+ axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1)
256
+ angle = torch.unsqueeze(axisang_norm, -1)
257
+ axisang_normalized = torch.div(axisang, angle)
258
+ angle = angle * 0.5
259
+ v_cos = torch.cos(angle)
260
+ v_sin = torch.sin(angle)
261
+ quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1)
262
+ rot_mat = quat2mat(quat)
263
+ rot_mat = rot_mat.view(rot_mat.shape[0], 9)
264
+ return rot_mat
265
+
266
+
267
+ def batch_rot2aa(Rs):
268
+ assert len(Rs.shape) == 3
269
+ assert Rs.shape[1] == Rs.shape[2]
270
+ assert Rs.shape[1] == 3
271
+
272
+ """
273
+ Rs is B x 3 x 3
274
+ void cMathUtil::RotMatToAxisAngle(const tMatrix& mat, tVector& out_axis,
275
+ double& out_theta)
276
+ {
277
+ double c = 0.5 * (mat(0, 0) + mat(1, 1) + mat(2, 2) - 1);
278
+ c = cMathUtil::Clamp(c, -1.0, 1.0);
279
+
280
+ out_theta = std::acos(c);
281
+
282
+ if (std::abs(out_theta) < 0.00001)
283
+ {
284
+ out_axis = tVector(0, 0, 1, 0);
285
+ }
286
+ else
287
+ {
288
+ double m21 = mat(2, 1) - mat(1, 2);
289
+ double m02 = mat(0, 2) - mat(2, 0);
290
+ double m10 = mat(1, 0) - mat(0, 1);
291
+ double denom = std::sqrt(m21 * m21 + m02 * m02 + m10 * m10);
292
+ out_axis[0] = m21 / denom;
293
+ out_axis[1] = m02 / denom;
294
+ out_axis[2] = m10 / denom;
295
+ out_axis[3] = 0;
296
+ }
297
+ }
298
+ """
299
+ cos = 0.5 * (torch.stack([torch.trace(x) for x in Rs]) - 1)
300
+ cos = torch.clamp(cos, -1, 1)
301
+
302
+ theta = torch.acos(cos)
303
+
304
+ m21 = Rs[:, 2, 1] - Rs[:, 1, 2]
305
+ m02 = Rs[:, 0, 2] - Rs[:, 2, 0]
306
+ m10 = Rs[:, 1, 0] - Rs[:, 0, 1]
307
+ denom = torch.sqrt(m21 * m21 + m02 * m02 + m10 * m10)
308
+
309
+ axis0 = torch.where(torch.abs(theta) < 0.00001, m21, m21 / denom)
310
+ axis1 = torch.where(torch.abs(theta) < 0.00001, m02, m02 / denom)
311
+ axis2 = torch.where(torch.abs(theta) < 0.00001, m10, m10 / denom)
312
+
313
+ return theta.unsqueeze(1) * torch.stack([axis0, axis1, axis2], 1)
314
+
315
+
316
+ def batch_rodrigues(theta):
317
+ """Convert axis-angle representation to rotation matrix.
318
+ Args:
319
+ theta: size = [B, 3]
320
+ Returns:
321
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
322
+ """
323
+ l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
324
+ angle = torch.unsqueeze(l1norm, -1)
325
+ normalized = torch.div(theta, angle)
326
+ angle = angle * 0.5
327
+ v_cos = torch.cos(angle)
328
+ v_sin = torch.sin(angle)
329
+ quat = torch.cat([v_cos, v_sin * normalized], dim=1)
330
+ return quat_to_rotmat(quat)
331
+
332
+
333
+ def quat_to_rotmat(quat):
334
+ """Convert quaternion coefficients to rotation matrix.
335
+ Args:
336
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
337
+ Returns:
338
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
339
+ """
340
+ norm_quat = quat
341
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
342
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
343
+
344
+ B = quat.size(0)
345
+
346
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
347
+ wx, wy, wz = w * x, w * y, w * z
348
+ xy, xz, yz = x * y, x * z, y * z
349
+
350
+ rotMat = torch.stack(
351
+ [
352
+ w2 + x2 - y2 - z2,
353
+ 2 * xy - 2 * wz,
354
+ 2 * wy + 2 * xz,
355
+ 2 * wz + 2 * xy,
356
+ w2 - x2 + y2 - z2,
357
+ 2 * yz - 2 * wx,
358
+ 2 * xz - 2 * wy,
359
+ 2 * wx + 2 * yz,
360
+ w2 - x2 - y2 + z2,
361
+ ],
362
+ dim=1,
363
+ ).view(B, 3, 3)
364
+ return rotMat
365
+
366
+
367
+ def rot6d_to_rotmat(x):
368
+ """Convert 6D rotation representation to 3x3 rotation matrix.
369
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
370
+ Input:
371
+ (B,6) Batch of 6-D rotation representations
372
+ Output:
373
+ (B,3,3) Batch of corresponding rotation matrices
374
+ """
375
+ x = x.reshape(-1, 3, 2)
376
+ a1 = x[:, :, 0]
377
+ a2 = x[:, :, 1]
378
+ b1 = F.normalize(a1)
379
+ b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
380
+ b3 = torch.cross(b1, b2, dim=-1)
381
+ return torch.stack((b1, b2, b3), dim=-1)
382
+
383
+
384
+ def rotmat_to_rot6d(x):
385
+ rotmat = x.reshape(-1, 3, 3)
386
+ rot6d = rotmat[:, :, :2].reshape(x.shape[0], -1)
387
+ return rot6d
388
+
389
+
390
+ def rotation_matrix_to_angle_axis(rotation_matrix):
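A quick round-trip sketch (assumed usage) tying together the conversion helpers in this file:

```python
import torch

from common.rot import (
    batch_rodrigues,
    matrix_to_quaternion,
    quaternion_to_axis_angle,
    rot6d_to_rotmat,
    rotmat_to_rot6d,
)

aa = torch.tensor([[0.0, 0.0, 1.5708]])      # ~90 deg about z, shape (B, 3)
R = batch_rodrigues(aa)                      # (B, 3, 3)
r6d = rotmat_to_rot6d(R)                     # (B, 6): first two matrix columns
R_back = rot6d_to_rotmat(r6d)                # Gram-Schmidt recovers the matrix
print(torch.allclose(R, R_back, atol=1e-5))  # True

q = matrix_to_quaternion(R)                  # (B, 4), real part first
aa_back = quaternion_to_axis_angle(q)        # (B, 3)
print(torch.allclose(aa, aa_back, atol=1e-4))  # True
```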
391
+ """
392
+ This function is borrowed from https://github.com/kornia/kornia
393
+
394
+ Convert 3x4 rotation matrix to Rodrigues vector
395
+
396
+ Args:
397
+ rotation_matrix (Tensor): rotation matrix.
398
+
399
+ Returns:
400
+ Tensor: Rodrigues vector transformation.
401
+
402
+ Shape:
403
+ - Input: :math:`(N, 3, 4)`
404
+ - Output: :math:`(N, 3)`
405
+
406
+ Example:
407
+ >>> input = torch.rand(2, 3, 4) # Nx4x4
408
+ >>> output = tgm.rotation_matrix_to_angle_axis(input) # Nx3
409
+ """
410
+ if rotation_matrix.shape[1:] == (3, 3):
411
+ rot_mat = rotation_matrix.reshape(-1, 3, 3)
412
+ hom = (
413
+ torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device)
414
+ .reshape(1, 3, 1)
415
+ .expand(rot_mat.shape[0], -1, -1)
416
+ )
417
+ rotation_matrix = torch.cat([rot_mat, hom], dim=-1)
418
+
419
+ quaternion = rotation_matrix_to_quaternion(rotation_matrix)
420
+ aa = quaternion_to_angle_axis(quaternion)
421
+ aa[torch.isnan(aa)] = 0.0
422
+ return aa
423
+
424
+
425
+ def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
426
+ """
427
+ This function is borrowed from https://github.com/kornia/kornia
428
+
429
+ Convert quaternion vector to angle axis of rotation.
430
+
431
+ Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h
432
+
433
+ Args:
434
+ quaternion (torch.Tensor): tensor with quaternions.
435
+
436
+ Return:
437
+ torch.Tensor: tensor with angle axis of rotation.
438
+
439
+ Shape:
440
+ - Input: :math:`(*, 4)` where `*` means, any number of dimensions
441
+ - Output: :math:`(*, 3)`
442
+
443
+ Example:
444
+ >>> quaternion = torch.rand(2, 4) # Nx4
445
+ >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion) # Nx3
446
+ """
447
+ if not torch.is_tensor(quaternion):
448
+ raise TypeError(
449
+ "Input type is not a torch.Tensor. Got {}".format(type(quaternion))
450
+ )
451
+
452
+ if not quaternion.shape[-1] == 4:
453
+ raise ValueError(
454
+ "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
455
+ )
456
+ # unpack input and compute conversion
457
+ q1: torch.Tensor = quaternion[..., 1]
458
+ q2: torch.Tensor = quaternion[..., 2]
459
+ q3: torch.Tensor = quaternion[..., 3]
460
+ sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
461
+
462
+ sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
463
+ cos_theta: torch.Tensor = quaternion[..., 0]
464
+ two_theta: torch.Tensor = 2.0 * torch.where(
465
+ cos_theta < 0.0,
466
+ torch.atan2(-sin_theta, -cos_theta),
467
+ torch.atan2(sin_theta, cos_theta),
468
+ )
469
+
470
+ k_pos: torch.Tensor = two_theta / sin_theta
471
+ k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
472
+ k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)
473
+
474
+ angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
475
+ angle_axis[..., 0] += q1 * k
476
+ angle_axis[..., 1] += q2 * k
477
+ angle_axis[..., 2] += q3 * k
478
+ return angle_axis
479
+
480
+
481
+ def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
482
+ """
483
+ This function is borrowed from https://github.com/kornia/kornia
484
+
485
+ Convert 3x4 rotation matrix to 4d quaternion vector
486
+
487
+ This algorithm is based on algorithm described in
488
+ https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201
489
+
490
+ Args:
491
+ rotation_matrix (Tensor): the rotation matrix to convert.
492
+
493
+ Return:
494
+ Tensor: the rotation in quaternion
495
+
496
+ Shape:
497
+ - Input: :math:`(N, 3, 4)`
498
+ - Output: :math:`(N, 4)`
499
+
500
+ Example:
501
+ >>> input = torch.rand(4, 3, 4) # Nx3x4
502
+ >>> output = tgm.rotation_matrix_to_quaternion(input) # Nx4
503
+ """
504
+ if not torch.is_tensor(rotation_matrix):
505
+ raise TypeError(
506
+ "Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix))
507
+ )
508
+
509
+ if len(rotation_matrix.shape) > 3:
510
+ raise ValueError(
511
+ "Input size must be a three dimensional tensor. Got {}".format(
512
+ rotation_matrix.shape
513
+ )
514
+ )
515
+ if not rotation_matrix.shape[-2:] == (3, 4):
516
+ raise ValueError(
517
+ "Input size must be a N x 3 x 4 tensor. Got {}".format(
518
+ rotation_matrix.shape
519
+ )
520
+ )
521
+
522
+ rmat_t = torch.transpose(rotation_matrix, 1, 2)
523
+
524
+ mask_d2 = rmat_t[:, 2, 2] < eps
525
+
526
+ mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
527
+ mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]
528
+
529
+ t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
530
+ q0 = torch.stack(
531
+ [
532
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
533
+ t0,
534
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
535
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
536
+ ],
537
+ -1,
538
+ )
539
+ t0_rep = t0.repeat(4, 1).t()
540
+
541
+ t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
542
+ q1 = torch.stack(
543
+ [
544
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
545
+ rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
546
+ t1,
547
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
548
+ ],
549
+ -1,
550
+ )
551
+ t1_rep = t1.repeat(4, 1).t()
552
+
553
+ t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
554
+ q2 = torch.stack(
555
+ [
556
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
557
+ rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
558
+ rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
559
+ t2,
560
+ ],
561
+ -1,
562
+ )
563
+ t2_rep = t2.repeat(4, 1).t()
564
+
565
+ t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
566
+ q3 = torch.stack(
567
+ [
568
+ t3,
569
+ rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
570
+ rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
571
+ rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
572
+ ],
573
+ -1,
574
+ )
575
+ t3_rep = t3.repeat(4, 1).t()
576
+
577
+ mask_c0 = mask_d2 * mask_d0_d1
578
+ mask_c1 = mask_d2 * ~mask_d0_d1
579
+ mask_c2 = ~mask_d2 * mask_d0_nd1
580
+ mask_c3 = ~mask_d2 * ~mask_d0_nd1
581
+ mask_c0 = mask_c0.view(-1, 1).type_as(q0)
582
+ mask_c1 = mask_c1.view(-1, 1).type_as(q1)
583
+ mask_c2 = mask_c2.view(-1, 1).type_as(q2)
584
+ mask_c3 = mask_c3.view(-1, 1).type_as(q3)
585
+
586
+ q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
587
+ q /= torch.sqrt(
588
+ t0_rep * mask_c0
589
+ + t1_rep * mask_c1
590
+ + t2_rep * mask_c2 # noqa
591
+ + t3_rep * mask_c3
592
+ ) # noqa
593
+ q *= 0.5
594
+ return q
595
+
596
+
597
+ def batch_euler2matrix(r):
598
+ return quaternion_to_rotation_matrix(euler_to_quaternion(r))
599
+
600
+
601
+ def euler_to_quaternion(r):
602
+ x = r[..., 0]
603
+ y = r[..., 1]
604
+ z = r[..., 2]
605
+
606
+ z = z / 2.0
607
+ y = y / 2.0
608
+ x = x / 2.0
609
+ cz = torch.cos(z)
610
+ sz = torch.sin(z)
611
+ cy = torch.cos(y)
612
+ sy = torch.sin(y)
613
+ cx = torch.cos(x)
614
+ sx = torch.sin(x)
615
+ quaternion = torch.zeros_like(r.repeat(1, 2))[..., :4].to(r.device)
616
+ quaternion[..., 0] += cx * cy * cz - sx * sy * sz
617
+ quaternion[..., 1] += cx * sy * sz + cy * cz * sx
618
+ quaternion[..., 2] += cx * cz * sy - sx * cy * sz
619
+ quaternion[..., 3] += cx * cy * sz + sx * cz * sy
620
+ return quaternion
621
+
622
+
623
+ def quaternion_to_rotation_matrix(quat):
624
+ """Convert quaternion coefficients to rotation matrix.
625
+ Args:
626
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
627
+ Returns:
628
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
629
+ """
630
+ norm_quat = quat
631
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
632
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
633
+
634
+ B = quat.size(0)
635
+
636
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
637
+ wx, wy, wz = w * x, w * y, w * z
638
+ xy, xz, yz = x * y, x * z, y * z
639
+
640
+ rotMat = torch.stack(
641
+ [
642
+ w2 + x2 - y2 - z2,
643
+ 2 * xy - 2 * wz,
644
+ 2 * wy + 2 * xz,
645
+ 2 * wz + 2 * xy,
646
+ w2 - x2 + y2 - z2,
647
+ 2 * yz - 2 * wx,
648
+ 2 * xz - 2 * wy,
649
+ 2 * wx + 2 * yz,
650
+ w2 - x2 - y2 + z2,
651
+ ],
652
+ dim=1,
653
+ ).view(B, 3, 3)
654
+ return rotMat
655
+
656
+
657
+ def euler_angles_from_rotmat(R):
658
+ """
659
+ compute euler angles for rotation around x, y, z axis
660
+ from rotation matrix
661
+ R: 4x4 rotation matrix
662
+ https://www.gregslabaugh.net/publications/euler.pdf
663
+ """
664
+ r21 = np.round(R[:, 2, 0].item(), 4)
665
+ if abs(r21) != 1:
666
+ y_angle1 = -1 * torch.asin(R[:, 2, 0])
667
+ y_angle2 = math.pi + torch.asin(R[:, 2, 0])
668
+ cy1, cy2 = torch.cos(y_angle1), torch.cos(y_angle2)
669
+
670
+ x_angle1 = torch.atan2(R[:, 2, 1] / cy1, R[:, 2, 2] / cy1)
671
+ x_angle2 = torch.atan2(R[:, 2, 1] / cy2, R[:, 2, 2] / cy2)
672
+ z_angle1 = torch.atan2(R[:, 1, 0] / cy1, R[:, 0, 0] / cy1)
673
+ z_angle2 = torch.atan2(R[:, 1, 0] / cy2, R[:, 0, 0] / cy2)
674
+
675
+ s1 = (x_angle1, y_angle1, z_angle1)
676
+ s2 = (x_angle2, y_angle2, z_angle2)
677
+ s = (s1, s2)
678
+
679
+ else:
680
+ z_angle = torch.tensor([0], device=R.device).float()
681
+ if r21 == -1:
682
+ y_angle = torch.tensor([math.pi / 2], device=R.device).float()
683
+ x_angle = z_angle + torch.atan2(R[:, 0, 1], R[:, 0, 2])
684
+ else:
685
+ y_angle = -torch.tensor([math.pi / 2], device=R.device).float()
686
+ x_angle = -z_angle + torch.atan2(-R[:, 0, 1], R[:, 0, 2])
687
+ s = ((x_angle, y_angle, z_angle),)
688
+ return s
689
+
690
+
691
+ def quaternion_raw_multiply(a, b):
692
+ """
693
+ Source: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
694
+ Multiply two quaternions.
695
+ Usual torch rules for broadcasting apply.
696
+
697
+ Args:
698
+ a: Quaternions as tensor of shape (..., 4), real part first.
699
+ b: Quaternions as tensor of shape (..., 4), real part first.
700
+
701
+ Returns:
702
+ The product of a and b, a tensor of quaternions shape (..., 4).
703
+ """
704
+ aw, ax, ay, az = torch.unbind(a, -1)
705
+ bw, bx, by, bz = torch.unbind(b, -1)
706
+ ow = aw * bw - ax * bx - ay * by - az * bz
707
+ ox = aw * bx + ax * bw + ay * bz - az * by
708
+ oy = aw * by - ax * bz + ay * bw + az * bx
709
+ oz = aw * bz + ax * by - ay * bx + az * bw
710
+ return torch.stack((ow, ox, oy, oz), -1)
711
+
712
+
713
+ def quaternion_invert(quaternion):
714
+ """
715
+ Source: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
716
+ Given a quaternion representing rotation, get the quaternion representing
717
+ its inverse.
718
+
719
+ Args:
720
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
721
+ first, which must be versors (unit quaternions).
722
+
723
+ Returns:
724
+ The inverse, a tensor of quaternions of shape (..., 4).
725
+ """
726
+
727
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
728
+
729
+
730
+ def quaternion_apply(quaternion, point):
731
+ """
732
+ Source: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
733
+ Apply the rotation given by a quaternion to a 3D point.
734
+ Usual torch rules for broadcasting apply.
735
+
736
+ Args:
737
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
738
+ point: Tensor of 3D points of shape (..., 3).
739
+
740
+ Returns:
741
+ Tensor of rotated points of shape (..., 3).
742
+ """
743
+ if point.size(-1) != 3:
744
+ raise ValueError(f"Points are not in 3D, f{point.shape}.")
745
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
746
+ point_as_quaternion = torch.cat((real_parts, point), -1)
747
+ out = quaternion_raw_multiply(
748
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
749
+ quaternion_invert(quaternion),
750
+ )
751
+ return out[..., 1:]
752
+
753
+
754
+ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
755
+ """
756
+ Source: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
757
+ Convert rotations given as axis/angle to quaternions.
758
+ Args:
759
+ axis_angle: Rotations given as a vector in axis angle form,
760
+ as a tensor of shape (..., 3), where the magnitude is
761
+ the angle turned anticlockwise in radians around the
762
+ vector's direction.
763
+ Returns:
764
+ quaternions with real part first, as tensor of shape (..., 4).
765
+ """
766
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
767
+ half_angles = angles * 0.5
768
+ eps = 1e-6
769
+ small_angles = angles.abs() < eps
770
+ sin_half_angles_over_angles = torch.empty_like(angles)
771
+ sin_half_angles_over_angles[~small_angles] = (
772
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
773
+ )
774
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
775
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
776
+ sin_half_angles_over_angles[small_angles] = (
777
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
778
+ )
779
+ quaternions = torch.cat(
780
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
781
+ )
782
+ return quaternions
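A minimal round-trip sketch for the conversion helpers above (dummy values; assumes these functions are in scope, e.g. imported from this module):

import torch

aa = torch.tensor([[0.3, -0.2, 0.1]])      # (N, 3) axis-angle, angle = vector norm
quat = axis_angle_to_quaternion(aa)        # (N, 4) quaternion, real part first
aa_rec = quaternion_to_angle_axis(quat)    # (N, 3) back to axis-angle
assert torch.allclose(aa, aa_rec, atol=1e-5)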
common/sys_utils.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ import os.path as op
3
+ import shutil
4
+ from glob import glob
5
+
6
+ from loguru import logger
7
+
8
+
9
+ def copy(src, dst):
10
+ if os.path.islink(src):
11
+ linkto = os.readlink(src)
12
+ os.symlink(linkto, dst)
13
+ else:
14
+ if os.path.isdir(src):
15
+ shutil.copytree(src, dst)
16
+ else:
17
+ shutil.copy(src, dst)
18
+
19
+
20
+ def copy_repo(src_files, dst_folder, filter_keywords):
21
+ src_files = [
22
+ f for f in src_files if not any(keyword in f for keyword in filter_keywords)
23
+ ]
24
+ dst_files = [op.join(dst_folder, op.basename(f)) for f in src_files]
25
+ for src_f, dst_f in zip(src_files, dst_files):
26
+ logger.info(f"FROM: {src_f}\nTO:{dst_f}")
27
+ copy(src_f, dst_f)
28
+
29
+
30
+ def mkdir(directory):
31
+ if not os.path.exists(directory):
32
+ os.makedirs(directory)
33
+
34
+
35
+ def mkdir_p(exp_path):
36
+ os.makedirs(exp_path, exist_ok=True)
37
+
38
+
39
+ def count_files(path):
40
+ """
41
+ Non-recursively count number of files in a folder.
42
+ """
43
+ files = glob(path)
44
+ return len(files)
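A small usage sketch (all paths hypothetical). Note that count_files takes a glob pattern rather than a directory, since it calls glob(path) directly:

mkdir_p("./out/src_backup")                              # create nested folders, no error if present
copy_repo(["train.py", "model.py"], "./out/src_backup",  # copy files, skipping filtered keywords
          filter_keywords=["__pycache__"])
n_py = count_files("./out/src_backup/*.py")              # counts entries matching the glob pattern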
common/thing.py ADDED
@@ -0,0 +1,66 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ """
5
+ This file stores functions for conversion between numpy, torch, list, etc.
6
+ Also deal with general operations such as to(dev), detach, etc.
7
+ """
8
+
9
+
10
+ def thing2list(thing):
11
+ if isinstance(thing, torch.Tensor):
12
+ return thing.tolist()
13
+ if isinstance(thing, np.ndarray):
14
+ return thing.tolist()
15
+ if isinstance(thing, dict):
16
+ return {k: thing2list(v) for k, v in thing.items()}
17
+ if isinstance(thing, list):
18
+ return [thing2list(ten) for ten in thing]
19
+ return thing
20
+
21
+
22
+ def thing2dev(thing, dev):
23
+ if hasattr(thing, "to"):
24
+ thing = thing.to(dev)
25
+ return thing
26
+ if isinstance(thing, list):
27
+ return [thing2dev(ten, dev) for ten in thing]
28
+ if isinstance(thing, tuple):
29
+ return tuple(thing2dev(list(thing), dev))
30
+ if isinstance(thing, dict):
31
+ return {k: thing2dev(v, dev) for k, v in thing.items()}
32
+ if isinstance(thing, torch.Tensor):
33
+ return thing.to(dev)
34
+ return thing
35
+
36
+
37
+ def thing2np(thing):
38
+ if isinstance(thing, list):
39
+ return np.array(thing)
40
+ if isinstance(thing, torch.Tensor):
41
+ return thing.cpu().detach().numpy()
42
+ if isinstance(thing, dict):
43
+ return {k: thing2np(v) for k, v in thing.items()}
44
+ return thing
45
+
46
+
47
+ def thing2torch(thing):
48
+ if isinstance(thing, list):
49
+ return torch.tensor(np.array(thing))
50
+ if isinstance(thing, np.ndarray):
51
+ return torch.from_numpy(thing)
52
+ if isinstance(thing, dict):
53
+ return {k: thing2torch(v) for k, v in thing.items()}
54
+ return thing
55
+
56
+
57
+ def detach_thing(thing):
58
+ if isinstance(thing, torch.Tensor):
59
+ return thing.cpu().detach()
60
+ if isinstance(thing, list):
61
+ return [detach_thing(ten) for ten in thing]
62
+ if isinstance(thing, tuple):
63
+ return tuple(detach_thing(list(thing)))
64
+ if isinstance(thing, dict):
65
+ return {k: detach_thing(v) for k, v in thing.items()}
66
+ return thing
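A short sketch of how these converters compose on nested containers (dummy data; the device string is just an example):

import numpy as np
import torch

batch = {"verts": np.zeros((8, 3)), "ids": [0, 1, 2], "tag": "demo"}
batch_t = thing2torch(batch)           # numpy arrays / numeric lists -> torch tensors, other types untouched
batch_t = thing2dev(batch_t, "cpu")    # recursively move anything with a .to() method to the device
batch_np = thing2np(batch_t)           # tensors -> numpy arrays (detached, on cpu)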
common/torch_utils.py ADDED
@@ -0,0 +1,212 @@
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+
8
+ from common.ld_utils import unsort as unsort_list
9
+
10
+
11
+ # pytorch implementation for np.nanmean
12
+ # https://github.com/pytorch/pytorch/issues/21987#issuecomment-539402619
13
+ def nanmean(v, *args, inplace=False, **kwargs):
14
+ if not inplace:
15
+ v = v.clone()
16
+ is_nan = torch.isnan(v)
17
+ v[is_nan] = 0
18
+ return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)
19
+
20
+
21
+ def grad_norm(model):
22
+ # compute norm of gradient for a model
23
+ total_norm = None
24
+ for p in model.parameters():
25
+ if p.grad is not None:
26
+ if total_norm is None:
27
+ total_norm = 0
28
+ param_norm = p.grad.detach().data.norm(2)
29
+ total_norm += param_norm.item() ** 2
30
+
31
+ if total_norm is not None:
32
+ total_norm = total_norm ** (1.0 / 2)
33
+ else:
34
+ total_norm = 0.0
35
+ return total_norm
36
+
37
+
38
+ def pad_tensor_list(v_list: list):
39
+ dev = v_list[0].device
40
+ num_meshes = len(v_list)
41
+ num_dim = 1 if len(v_list[0].shape) == 1 else v_list[0].shape[1]
42
+ v_len_list = []
43
+ for verts in v_list:
44
+ v_len_list.append(verts.shape[0])
45
+
46
+ pad_len = max(v_len_list)
47
+ dtype = v_list[0].dtype
48
+ if num_dim == 1:
49
+ padded_tensor = torch.zeros(num_meshes, pad_len, dtype=dtype)
50
+ else:
51
+ padded_tensor = torch.zeros(num_meshes, pad_len, num_dim, dtype=dtype)
52
+ for idx, (verts, v_len) in enumerate(zip(v_list, v_len_list)):
53
+ padded_tensor[idx, :v_len] = verts
54
+ padded_tensor = padded_tensor.to(dev)
55
+ v_len_list = torch.LongTensor(v_len_list).to(dev)
56
+ return padded_tensor, v_len_list
57
+
58
+
59
+ def unpad_vtensor(
60
+ vtensor: (torch.Tensor), lens: (torch.LongTensor, torch.cuda.LongTensor)
61
+ ):
62
+ tensors_list = []
63
+ for verts, vlen in zip(vtensor, lens):
64
+ tensors_list.append(verts[:vlen])
65
+ return tensors_list
66
+
67
+
68
+ def one_hot_embedding(labels, num_classes):
69
+ """Embedding labels to one-hot form.
70
+ Args:
71
+ labels: (LongTensor) class labels, sized [N, D1, D2, ..].
72
+ num_classes: (int) number of classes.
73
+ Returns:
74
+ (tensor) encoded labels, sized [N, D1, D2, .., Dk, #classes].
75
+ """
76
+ y = torch.eye(num_classes).float()
77
+ return y[labels]
78
+
79
+
80
+ def unsort(ten, sort_idx):
81
+ """
82
+ Unsort a tensor of shape (N, *) using the sort_idx list(N).
83
+ Return a tensor of the pre-sorting order in shape (N, *)
84
+ """
85
+ assert isinstance(ten, torch.Tensor)
86
+ assert isinstance(sort_idx, list)
87
+ assert ten.shape[0] == len(sort_idx)
88
+
89
+ out_list = list(torch.chunk(ten, ten.size(0), dim=0))
90
+ out_list = unsort_list(out_list, sort_idx)
91
+ out_list = torch.cat(out_list, dim=0)
92
+ return out_list
93
+
94
+
95
+ def all_comb(X, Y):
96
+ """
97
+ Returns all possible combinations of elements in X and Y.
98
+ X: (n_x, d_x)
99
+ Y: (n_y, d_y)
100
+ Output: Z: (n_x*n_y, d_x+d_y)
101
+ Example:
102
+ X = tensor([[8, 8, 8],
103
+ [7, 5, 9]])
104
+ Y = tensor([[3, 8, 7, 7],
105
+ [3, 7, 9, 9],
106
+ [6, 4, 3, 7]])
107
+ Z = tensor([[8, 8, 8, 3, 8, 7, 7],
108
+ [8, 8, 8, 3, 7, 9, 9],
109
+ [8, 8, 8, 6, 4, 3, 7],
110
+ [7, 5, 9, 3, 8, 7, 7],
111
+ [7, 5, 9, 3, 7, 9, 9],
112
+ [7, 5, 9, 6, 4, 3, 7]])
113
+ """
114
+ assert len(X.size()) == 2
115
+ assert len(Y.size()) == 2
116
+ X1 = X.unsqueeze(1)
117
+ Y1 = Y.unsqueeze(0)
118
+ X2 = X1.repeat(1, Y.shape[0], 1)
119
+ Y2 = Y1.repeat(X.shape[0], 1, 1)
120
+ Z = torch.cat([X2, Y2], -1)
121
+ Z = Z.view(-1, Z.shape[-1])
122
+ return Z
123
+
124
+
125
+ def toggle_parameters(model, requires_grad):
126
+ """
127
+ Set all weights to requires_grad or not.
128
+ """
129
+ for param in model.parameters():
130
+ param.requires_grad = requires_grad
131
+
132
+
133
+ def detach_tensor(ten):
134
+ """This function move tensor to cpu and convert to numpy"""
135
+ if isinstance(ten, torch.Tensor):
136
+ return ten.cpu().detach().numpy()
137
+ return ten
138
+
139
+
140
+ def count_model_parameters(model):
141
+ """
142
+ Return the number of parameters that require gradients.
143
+ """
144
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
145
+
146
+
147
+ def reset_all_seeds(seed):
148
+ """Reset all seeds for reproduciability."""
149
+ random.seed(seed)
150
+ torch.manual_seed(seed)
151
+ np.random.seed(seed)
152
+
153
+
154
+ def get_activation(name):
155
+ """This function return an activation constructor by name."""
156
+ if name == "tanh":
157
+ return nn.Tanh()
158
+ elif name == "sigmoid":
159
+ return nn.Sigmoid()
160
+ elif name == "relu":
161
+ return nn.ReLU()
162
+ elif name == "selu":
163
+ return nn.SELU()
164
+ elif name == "relu6":
165
+ return nn.ReLU6()
166
+ elif name == "softplus":
167
+ return nn.Softplus()
168
+ elif name == "softshrink":
169
+ return nn.Softshrink()
170
+ else:
171
+ print("Undefined activation: %s" % (name))
172
+ assert False
173
+
174
+
175
+ def stack_ll_tensors(tensor_list_list):
176
+ """
177
+ Recursively stack a list of lists of lists .. whose elements are tensors with the same shape
178
+ """
179
+ if isinstance(tensor_list_list, torch.Tensor):
180
+ return tensor_list_list
181
+ assert isinstance(tensor_list_list, list)
182
+ if isinstance(tensor_list_list[0], torch.Tensor):
183
+ return torch.stack(tensor_list_list)
184
+
185
+ stacked_tensor = []
186
+ for tensor_list in tensor_list_list:
187
+ stacked_tensor.append(stack_ll_tensors(tensor_list))
188
+ stacked_tensor = torch.stack(stacked_tensor)
189
+ return stacked_tensor
190
+
191
+
192
+ def get_optim(name):
193
+ """This function return an optimizer constructor by name."""
194
+ if name == "adam":
195
+ return optim.Adam
196
+ elif name == "rmsprop":
197
+ return optim.RMSprop
198
+ elif name == "sgd":
199
+ return optim.SGD
200
+ else:
201
+ print("Undefined optim: %s" % (name))
202
+ assert False
203
+
204
+
205
+ def decay_lr(optimizer, gamma):
206
+ """
207
+ Decay the learning rate by gamma
208
+ """
209
+ assert isinstance(gamma, float)
210
+ assert 0 <= gamma and gamma <= 1.0
211
+ for param_group in optimizer.param_groups:
212
+ param_group["lr"] *= gamma
common/transforms.py ADDED
@@ -0,0 +1,356 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ import common.data_utils as data_utils
5
+ from common.np_utils import permute_np
6
+
7
+ """
8
+ Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
9
+ Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
10
+ """
11
+
12
+
13
+ def to_xy(x_homo):
14
+ assert isinstance(x_homo, (torch.FloatTensor, torch.cuda.FloatTensor))
15
+ assert x_homo.shape[1] == 3
16
+ assert len(x_homo.shape) == 2
17
+ batch_size = x_homo.shape[0]
18
+ x = torch.ones(batch_size, 2, device=x_homo.device)
19
+ x = x_homo[:, :2] / x_homo[:, 2:3]
20
+ return x
21
+
22
+
23
+ def to_xyz(x_homo):
24
+ assert isinstance(x_homo, (torch.FloatTensor, torch.cuda.FloatTensor))
25
+ assert x_homo.shape[1] == 4
26
+ assert len(x_homo.shape) == 2
27
+ batch_size = x_homo.shape[0]
28
+ x = torch.ones(batch_size, 3, device=x_homo.device)
29
+ x = x_homo[:, :3] / x_homo[:, 3:4]
30
+ return x
31
+
32
+
33
+ def to_homo(x):
34
+ assert isinstance(x, (torch.FloatTensor, torch.cuda.FloatTensor))
35
+ assert x.shape[1] == 3
36
+ assert len(x.shape) == 2
37
+ batch_size = x.shape[0]
38
+ x_homo = torch.ones(batch_size, 4, device=x.device)
39
+ x_homo[:, :3] = x.clone()
40
+ return x_homo
41
+
42
+
43
+ def to_homo_batch(x):
44
+ assert isinstance(x, (torch.FloatTensor, torch.cuda.FloatTensor))
45
+ assert x.shape[2] == 3
46
+ assert len(x.shape) == 3
47
+ batch_size = x.shape[0]
48
+ num_pts = x.shape[1]
49
+ x_homo = torch.ones(batch_size, num_pts, 4, device=x.device)
50
+ x_homo[:, :, :3] = x.clone()
51
+ return x_homo
52
+
53
+
54
+ def to_xyz_batch(x_homo):
55
+ """
56
+ Input: (B, N, 4)
57
+ Output: (B, N, 3)
58
+ """
59
+ assert isinstance(x_homo, (torch.FloatTensor, torch.cuda.FloatTensor))
60
+ assert x_homo.shape[2] == 4
61
+ assert len(x_homo.shape) == 3
62
+ batch_size = x_homo.shape[0]
63
+ num_pts = x_homo.shape[1]
64
+ x = torch.ones(batch_size, num_pts, 3, device=x_homo.device)
65
+ x = x_homo[:, :, :3] / x_homo[:, :, 3:4]
66
+ return x
67
+
68
+
69
+ def to_xy_batch(x_homo):
70
+ assert isinstance(x_homo, (torch.FloatTensor, torch.cuda.FloatTensor))
71
+ assert x_homo.shape[2] == 3
72
+ assert len(x_homo.shape) == 3
73
+ batch_size = x_homo.shape[0]
74
+ num_pts = x_homo.shape[1]
75
+ x = torch.ones(batch_size, num_pts, 2, device=x_homo.device)
76
+ x = x_homo[:, :, :2] / x_homo[:, :, 2:3]
77
+ return x
78
+
79
+
80
+ # VR Distortion Correction Using Vertex Displacement
81
+ # https://stackoverflow.com/questions/44489686/camera-lens-distortion-in-opengl
82
+ def distort_pts3d_all(_pts_cam, dist_coeffs):
83
+ # egocentric cameras commonly has heavy distortion
84
+ # this function transform points in the undistorted camera coord
85
+ # to distorted camera coord such that the 2d projection can match the pixels.
86
+ pts_cam = _pts_cam.clone().double()
87
+ z = pts_cam[:, :, 2]
88
+
89
+ z_inv = 1 / z
90
+
91
+ x1 = pts_cam[:, :, 0] * z_inv
92
+ y1 = pts_cam[:, :, 1] * z_inv
93
+
94
+ # precalculations
95
+ x1_2 = x1 * x1
96
+ y1_2 = y1 * y1
97
+ x1_y1 = x1 * y1
98
+ r2 = x1_2 + y1_2
99
+ r4 = r2 * r2
100
+ r6 = r4 * r2
101
+
102
+ r_dist = (1 + dist_coeffs[0] * r2 + dist_coeffs[1] * r4 + dist_coeffs[4] * r6) / (
103
+ 1 + dist_coeffs[5] * r2 + dist_coeffs[6] * r4 + dist_coeffs[7] * r6
104
+ )
105
+
106
+ # full (rational + tangential) distortion
107
+ x2 = x1 * r_dist + 2 * dist_coeffs[2] * x1_y1 + dist_coeffs[3] * (r2 + 2 * x1_2)
108
+ y2 = y1 * r_dist + 2 * dist_coeffs[3] * x1_y1 + dist_coeffs[2] * (r2 + 2 * y1_2)
109
+ # denormalize for projection (which is a linear operation)
110
+ cam_pts_dist = torch.stack([x2 * z, y2 * z, z], dim=2).float()
111
+ return cam_pts_dist
112
+
113
+
114
+ def rigid_tf_torch_batch(points, R, T):
115
+ """
116
+ Performs rigid transformation to incoming points but batched
117
+ Q = (points*R.T) + T
118
+ points: (batch, num, 3)
119
+ R: (batch, 3, 3)
120
+ T: (batch, 3, 1)
121
+ out: (batch, num, 3)
122
+ """
123
+ points_out = torch.bmm(R, points.permute(0, 2, 1)) + T
124
+ points_out = points_out.permute(0, 2, 1)
125
+ return points_out
126
+
127
+
128
+ def solve_rigid_tf_np(A: np.ndarray, B: np.ndarray):
129
+ """
130
+ “Least-Squares Fitting of Two 3-D Point Sets”, Arun, K. S. , May 1987
131
+ Input: expects Nx3 matrix of points
132
+ Returns R,t
133
+ R = 3x3 rotation matrix
134
+ t = 3x1 column vector
135
+
136
+ This function should be a fix for compute_rigid_tf when the det == -1
137
+ """
138
+
139
+ assert A.shape == B.shape
140
+ A = A.T
141
+ B = B.T
142
+
143
+ num_rows, num_cols = A.shape
144
+ if num_rows != 3:
145
+ raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
146
+
147
+ num_rows, num_cols = B.shape
148
+ if num_rows != 3:
149
+ raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
150
+
151
+ # find mean column wise
152
+ centroid_A = np.mean(A, axis=1)
153
+ centroid_B = np.mean(B, axis=1)
154
+
155
+ # ensure centroids are 3x1
156
+ centroid_A = centroid_A.reshape(-1, 1)
157
+ centroid_B = centroid_B.reshape(-1, 1)
158
+
159
+ # subtract mean
160
+ Am = A - centroid_A
161
+ Bm = B - centroid_B
162
+
163
+ H = Am @ np.transpose(Bm)
164
+
165
+ # find rotation
166
+ U, S, Vt = np.linalg.svd(H)
167
+ R = Vt.T @ U.T
168
+
169
+ # special reflection case
170
+ if np.linalg.det(R) < 0:
171
+ Vt[2, :] *= -1
172
+ R = Vt.T @ U.T
173
+
174
+ t = -R @ centroid_A + centroid_B
175
+
176
+ return R, t
177
+
178
+
179
+ def batch_solve_rigid_tf(A, B):
180
+ """
181
+ “Least-Squares Fitting of Two 3-D Point Sets”, Arun, K. S. , May 1987
182
+ Input: expects BxNx3 matrix of points
183
+ Returns R,t
184
+ R = Bx3x3 rotation matrix
185
+ t = Bx3x1 column vector
186
+ """
187
+
188
+ assert A.shape == B.shape
189
+ dev = A.device
190
+ A = A.cpu().numpy()
191
+ B = B.cpu().numpy()
192
+ A = permute_np(A, (0, 2, 1))
193
+ B = permute_np(B, (0, 2, 1))
194
+
195
+ batch, num_rows, num_cols = A.shape
196
+ if num_rows != 3:
197
+ raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
198
+
199
+ _, num_rows, num_cols = B.shape
200
+ if num_rows != 3:
201
+ raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
202
+
203
+ # find mean column wise
204
+ centroid_A = np.mean(A, axis=2)
205
+ centroid_B = np.mean(B, axis=2)
206
+
207
+ # ensure centroids are 3x1
208
+ centroid_A = centroid_A.reshape(batch, -1, 1)
209
+ centroid_B = centroid_B.reshape(batch, -1, 1)
210
+
211
+ # subtract mean
212
+ Am = A - centroid_A
213
+ Bm = B - centroid_B
214
+
215
+ H = np.matmul(Am, permute_np(Bm, (0, 2, 1)))
216
+
217
+ # find rotation
218
+ U, S, Vt = np.linalg.svd(H)
219
+ R = np.matmul(permute_np(Vt, (0, 2, 1)), permute_np(U, (0, 2, 1)))
220
+
221
+ # special reflection case
222
+ neg_idx = np.linalg.det(R) < 0
223
+ if neg_idx.sum() > 0:
224
+ raise Exception(
225
+ f"some rotation matrices are not orthogonal; make sure implementation is correct for such case: {neg_idx}"
226
+ )
227
+ Vt[neg_idx, 2, :] *= -1
228
+ R[neg_idx, :, :] = np.matmul(
229
+ permute_np(Vt[neg_idx], (0, 2, 1)), permute_np(U[neg_idx], (0, 2, 1))
230
+ )
231
+
232
+ t = np.matmul(-R, centroid_A) + centroid_B
233
+
234
+ R = torch.FloatTensor(R).to(dev)
235
+ t = torch.FloatTensor(t).to(dev)
236
+ return R, t
237
+
238
+
239
+ def rigid_tf_np(points, R, T):
240
+ """
241
+ Performs rigid transformation to incoming points
242
+ Q = (points*R.T) + T
243
+ points: (num, 3)
244
+ R: (3, 3)
245
+ T: (1, 3)
246
+
247
+ out: (num, 3)
248
+ """
249
+
250
+ assert isinstance(points, np.ndarray)
251
+ assert isinstance(R, np.ndarray)
252
+ assert isinstance(T, np.ndarray)
253
+ assert len(points.shape) == 2
254
+ assert points.shape[1] == 3
255
+ assert R.shape == (3, 3)
256
+ assert T.shape == (1, 3)
257
+ points_new = np.matmul(R, points.T).T + T
258
+ return points_new
259
+
260
+
261
+ def transform_points(world2cam_mat, pts):
262
+ """
263
+ Map points from one coord to another based on the 4x4 matrix.
264
+ e.g., map points from world to camera coord.
265
+ pts: (N, 3), in METERS!!
266
+ world2cam_mat: (4, 4)
267
+ Output: points in cam coord (N, 3)
268
+ We follow this convention:
269
+ | R T | |pt|
270
+ | 0 1 | * | 1|
271
+ i.e. we rotate first then translate as T is the camera translation not position.
272
+ """
273
+ assert isinstance(pts, (torch.FloatTensor, torch.cuda.FloatTensor))
274
+ assert isinstance(world2cam_mat, (torch.FloatTensor, torch.cuda.FloatTensor))
275
+ assert world2cam_mat.shape == (4, 4)
276
+ assert len(pts.shape) == 2
277
+ assert pts.shape[1] == 3
278
+ pts_homo = to_homo(pts)
279
+
280
+ # mocap to cam
281
+ pts_cam_homo = torch.matmul(world2cam_mat, pts_homo.T).T
282
+ pts_cam = to_xyz(pts_cam_homo)
283
+
284
+ assert pts_cam.shape[1] == 3
285
+ return pts_cam
286
+
287
+
288
+ def transform_points_batch(world2cam_mat, pts):
289
+ """
290
+ Map points from one coord to another based on the 4x4 matrix.
291
+ e.g., map points from world to camera coord.
292
+ pts: (B, N, 3), in METERS!!
293
+ world2cam_mat: (B, 4, 4)
294
+ Output: points in cam coord (B, N, 3)
295
+ We follow this convention:
296
+ | R T | |pt|
297
+ | 0 1 | * | 1|
298
+ i.e. we rotate first then translate as T is the camera translation not position.
299
+ """
300
+ assert isinstance(pts, (torch.FloatTensor, torch.cuda.FloatTensor))
301
+ assert isinstance(world2cam_mat, (torch.FloatTensor, torch.cuda.FloatTensor))
302
+ assert world2cam_mat.shape[1:] == (4, 4)
303
+ assert len(pts.shape) == 3
304
+ assert pts.shape[2] == 3
305
+ batch_size = pts.shape[0]
306
+ pts_homo = to_homo_batch(pts)
307
+
308
+ # mocap to cam
309
+ pts_cam_homo = torch.bmm(world2cam_mat, pts_homo.permute(0, 2, 1)).permute(0, 2, 1)
310
+ pts_cam = to_xyz_batch(pts_cam_homo)
311
+
312
+ assert pts_cam.shape[2] == 3
313
+ return pts_cam
314
+
315
+
316
+ def project2d_batch(K, pts_cam):
317
+ """
318
+ K: (B, 3, 3)
319
+ pts_cam: (B, N, 3)
320
+ """
321
+
322
+ assert isinstance(K, (torch.FloatTensor, torch.cuda.FloatTensor))
323
+ assert isinstance(pts_cam, (torch.FloatTensor, torch.cuda.FloatTensor))
324
+ assert K.shape[1:] == (3, 3)
325
+ assert pts_cam.shape[2] == 3
326
+ assert len(pts_cam.shape) == 3
327
+ pts2d_homo = torch.bmm(K, pts_cam.permute(0, 2, 1)).permute(0, 2, 1)
328
+ pts2d = to_xy_batch(pts2d_homo)
329
+ return pts2d
330
+
331
+
332
+ def project2d_norm_batch(K, pts_cam, patch_width):
333
+ """
334
+ K: (B, 3, 3)
335
+ pts_cam: (B, N, 3)
336
+ """
337
+
338
+ assert isinstance(K, (torch.FloatTensor, torch.cuda.FloatTensor))
339
+ assert isinstance(pts_cam, (torch.FloatTensor, torch.cuda.FloatTensor))
340
+ assert K.shape[1:] == (3, 3)
341
+ assert pts_cam.shape[2] == 3
342
+ assert len(pts_cam.shape) == 3
343
+ v2d = project2d_batch(K, pts_cam)
344
+ v2d_norm = data_utils.normalize_kp2d(v2d, patch_width)
345
+ return v2d_norm
346
+
347
+
348
+ def project2d(K, pts_cam):
349
+ assert isinstance(K, (torch.FloatTensor, torch.cuda.FloatTensor))
350
+ assert isinstance(pts_cam, (torch.FloatTensor, torch.cuda.FloatTensor))
351
+ assert K.shape == (3, 3)
352
+ assert pts_cam.shape[1] == 3
353
+ assert len(pts_cam.shape) == 2
354
+ pts2d_homo = torch.matmul(K, pts_cam.T).T
355
+ pts2d = to_xy(pts2d_homo)
356
+ return pts2d
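A compact sketch of the world-to-camera plus pinhole projection path above (identity extrinsics and a made-up intrinsic matrix):

import torch

B, N = 2, 5
pts_world = torch.rand(B, N, 3) + torch.tensor([0.0, 0.0, 2.0])   # keep z > 0 in front of the camera
world2cam = torch.eye(4).expand(B, 4, 4).contiguous()             # (B, 4, 4) identity extrinsics
K = torch.tensor([[1000.0, 0.0, 112.0],
                  [0.0, 1000.0, 112.0],
                  [0.0, 0.0, 1.0]]).expand(B, 3, 3).contiguous()   # (B, 3, 3) example intrinsics

pts_cam = transform_points_batch(world2cam, pts_world)   # (B, N, 3) points in camera coordinates
pts2d = project2d_batch(K, pts_cam)                       # (B, N, 2) pixel coordinates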
common/viewer.py ADDED
@@ -0,0 +1,287 @@
1
+ import os
2
+ import os.path as op
3
+ import re
4
+ from abc import abstractmethod
5
+
6
+ import matplotlib.cm as cm
7
+ import numpy as np
8
+ from aitviewer.headless import HeadlessRenderer
9
+ from aitviewer.renderables.billboard import Billboard
10
+ from aitviewer.renderables.meshes import Meshes
11
+ from aitviewer.scene.camera import OpenCVCamera
12
+ from aitviewer.scene.material import Material
13
+ from aitviewer.utils.so3 import aa2rot_numpy
14
+ from aitviewer.viewer import Viewer
15
+ from easydict import EasyDict as edict
16
+ from loguru import logger
17
+ from PIL import Image
18
+ from tqdm import tqdm
19
+
20
+ OBJ_ID = 100
21
+ SMPLX_ID = 150
22
+ LEFT_ID = 200
23
+ RIGHT_ID = 250
24
+ SEGM_IDS = {"object": OBJ_ID, "smplx": SMPLX_ID, "left": LEFT_ID, "right": RIGHT_ID}
25
+
26
+ cmap = cm.get_cmap("plasma")
27
+ materials = {
28
+ "none": None,
29
+ "white": Material(color=(1.0, 1.0, 1.0, 1.0), ambient=0.2),
30
+ "red": Material(color=(0.969, 0.106, 0.059, 1.0), ambient=0.2),
31
+ "blue": Material(color=(0.0, 0.0, 1.0, 1.0), ambient=0.2),
32
+ "green": Material(color=(1.0, 0.0, 0.0, 1.0), ambient=0.2),
33
+ "cyan": Material(color=(0.051, 0.659, 0.051, 1.0), ambient=0.2),
34
+ "light-blue": Material(color=(0.588, 0.5647, 0.9725, 1.0), ambient=0.2),
35
+ "cyan-light": Material(color=(0.051, 0.659, 0.051, 1.0), ambient=0.2),
36
+ "dark-light": Material(color=(0.404, 0.278, 0.278, 1.0), ambient=0.2),
37
+ "rice": Material(color=(0.922, 0.922, 0.102, 1.0), ambient=0.2),
38
+ }
39
+
40
+
41
+ class ViewerData(edict):
42
+ """
43
+ Interface to standardize viewer data.
44
+ """
45
+
46
+ def __init__(self, Rt, K, cols, rows, imgnames=None):
47
+ self.imgnames = imgnames
48
+ self.Rt = Rt
49
+ self.K = K
50
+ self.num_frames = Rt.shape[0]
51
+ self.cols = cols
52
+ self.rows = rows
53
+ self.validate_format()
54
+
55
+ def validate_format(self):
56
+ assert len(self.Rt.shape) == 3
57
+ assert self.Rt.shape[0] == self.num_frames
58
+ assert self.Rt.shape[1] == 3
59
+ assert self.Rt.shape[2] == 4
60
+
61
+ assert len(self.K.shape) == 2
62
+ assert self.K.shape[0] == 3
63
+ assert self.K.shape[1] == 3
64
+ if self.imgnames is not None:
65
+ assert self.num_frames == len(self.imgnames)
66
+ assert self.num_frames > 0
67
+ im_p = self.imgnames[0]
68
+ assert op.exists(im_p), f"Image path {im_p} does not exist"
69
+
70
+
71
+ class ARCTICViewer:
72
+ def __init__(
73
+ self,
74
+ render_types=["rgb", "depth", "mask"],
75
+ interactive=True,
76
+ size=(2024, 2024),
77
+ ):
78
+ if not interactive:
79
+ v = HeadlessRenderer()
80
+ else:
81
+ v = Viewer(size=size)
82
+
83
+ self.v = v
84
+ self.interactive = interactive
85
+ # self.layers = layers
86
+ self.render_types = render_types
87
+
88
+ def view_interactive(self):
89
+ self.v.run()
90
+
91
+ def view_fn_headless(self, num_iter, out_folder):
92
+ v = self.v
93
+
94
+ v._init_scene()
95
+
96
+ logger.info("Rendering to video")
97
+ if "video" in self.render_types:
98
+ vid_p = op.join(out_folder, "video.mp4")
99
+ v.save_video(video_dir=vid_p)
100
+
101
+ pbar = tqdm(range(num_iter))
102
+ for fidx in pbar:
103
+ out_rgb = op.join(out_folder, "images", f"rgb/{fidx:04d}.png")
104
+ out_mask = op.join(out_folder, "images", f"mask/{fidx:04d}.png")
105
+ out_depth = op.join(out_folder, "images", f"depth/{fidx:04d}.npy")
106
+
107
+ # render RGB, depth, segmentation masks
108
+ if "rgb" in self.render_types:
109
+ v.export_frame(out_rgb)
110
+ if "depth" in self.render_types:
111
+ os.makedirs(op.dirname(out_depth), exist_ok=True)
112
+ render_depth(v, out_depth)
113
+ if "mask" in self.render_types:
114
+ os.makedirs(op.dirname(out_mask), exist_ok=True)
115
+ render_mask(v, out_mask)
116
+ v.scene.next_frame()
117
+ logger.info(f"Exported to {out_folder}")
118
+
119
+ @abstractmethod
120
+ def load_data(self):
121
+ pass
122
+
123
+ def check_format(self, batch):
124
+ meshes_all, data = batch
125
+ assert isinstance(meshes_all, dict)
126
+ assert len(meshes_all) > 0
127
+ for mesh in meshes_all.values():
128
+ assert isinstance(mesh, Meshes)
129
+ assert isinstance(data, ViewerData)
130
+
131
+ def render_seq(self, batch, out_folder="./render_out"):
132
+ meshes_all, data = batch
133
+ self.setup_viewer(data)
134
+ for mesh in meshes_all.values():
135
+ self.v.scene.add(mesh)
136
+ if self.interactive:
137
+ self.view_interactive()
138
+ else:
139
+ num_iter = data["num_frames"]
140
+ self.view_fn_headless(num_iter, out_folder)
141
+
142
+ def setup_viewer(self, data):
143
+ v = self.v
144
+ fps = 30
145
+ if "imgnames" in data:
146
+ setup_billboard(data, v)
147
+
148
+ # camera.show_path()
149
+ v.run_animations = True # autoplay
150
+ v.run_animations = False # autoplay
151
+ v.playback_fps = fps
152
+ v.scene.fps = fps
153
+ v.scene.origin.enabled = False
154
+ v.scene.floor.enabled = False
155
+ v.auto_set_floor = False
156
+ v.scene.floor.position[1] = -3
157
+ # v.scene.camera.position = np.array((0.0, 0.0, 0))
158
+ self.v = v
159
+
160
+
161
+ def dist2vc(dist_ro, dist_lo, dist_o, _cmap, tf_fn=None):
162
+ if tf_fn is not None:
163
+ exp_map = tf_fn
164
+ else:
165
+ exp_map = small_exp_map
166
+ dist_ro = exp_map(dist_ro)
167
+ dist_lo = exp_map(dist_lo)
168
+ dist_o = exp_map(dist_o)
169
+
170
+ vc_ro = _cmap(dist_ro)
171
+ vc_lo = _cmap(dist_lo)
172
+ vc_o = _cmap(dist_o)
173
+ return vc_ro, vc_lo, vc_o
174
+
175
+
176
+ def small_exp_map(_dist):
177
+ dist = np.copy(_dist)
178
+ # dist = 1.0 - np.clip(dist, 0, 0.1) / 0.1
179
+ dist = np.exp(-20.0 * dist)
180
+ return dist
181
+
182
+
183
+ def construct_viewer_meshes(data, draw_edges=False, flat_shading=True):
184
+ rotation_flip = aa2rot_numpy(np.array([1, 0, 0]) * np.pi)
185
+ meshes = {}
186
+ for key, val in data.items():
187
+ if "object" in key:
188
+ flat_shading = False
189
+ else:
190
+ flat_shading = flat_shading
191
+ v3d = val["v3d"]
192
+ meshes[key] = Meshes(
193
+ v3d,
194
+ val["f3d"],
195
+ vertex_colors=val["vc"],
196
+ name=val["name"],
197
+ flat_shading=flat_shading,
198
+ draw_edges=draw_edges,
199
+ material=materials[val["color"]],
200
+ rotation=rotation_flip,
201
+ )
202
+ return meshes
203
+
204
+
205
+ def setup_viewer(
206
+ v, shared_folder_p, video, images_path, data, flag, seq_name, side_angle
207
+ ):
208
+ fps = 10
209
+ cols, rows = 224, 224
210
+ focal = 1000.0
211
+
212
+ # setup image paths
213
+ regex = re.compile(r"(\d*)$")
214
+
215
+ def sort_key(x):
216
+ name = os.path.splitext(x)[0]
217
+ return int(regex.search(name).group(0))
218
+
219
+ # setup billboard
220
+ images_path = op.join(shared_folder_p, "images")
221
+ images_paths = [
222
+ os.path.join(images_path, f)
223
+ for f in sorted(os.listdir(images_path), key=sort_key)
224
+ ]
225
+ assert len(images_paths) > 0
226
+
227
+ cam_t = data[f"{flag}.object.cam_t"]
228
+ num_frames = min(cam_t.shape[0], len(images_paths))
229
+ cam_t = cam_t[:num_frames]
230
+ # setup camera
231
+ K = np.array([[focal, 0, rows / 2.0], [0, focal, cols / 2.0], [0, 0, 1]])
232
+ Rt = np.zeros((num_frames, 3, 4))
233
+ Rt[:, :, 3] = cam_t
234
+ Rt[:, :3, :3] = np.eye(3)
235
+ Rt[:, 1:3, :3] *= -1.0
236
+
237
+ camera = OpenCVCamera(K, Rt, cols, rows, viewer=v)
238
+ if side_angle is None:
239
+ billboard = Billboard.from_camera_and_distance(
240
+ camera, 10.0, cols, rows, images_paths
241
+ )
242
+ v.scene.add(billboard)
243
+ v.scene.add(camera)
244
+ v.run_animations = True # autoplay
245
+ v.playback_fps = fps
246
+ v.scene.fps = fps
247
+ v.scene.origin.enabled = False
248
+ v.scene.floor.enabled = False
249
+ v.auto_set_floor = False
250
+ v.scene.floor.position[1] = -3
251
+ v.set_temp_camera(camera)
252
+ # v.scene.camera.position = np.array((0.0, 0.0, 0))
253
+ return v
254
+
255
+
256
+ def render_depth(v, depth_p):
257
+ depth = np.array(v.get_depth()).astype(np.float16)
258
+ np.save(depth_p, depth)
259
+
260
+
261
+ def render_mask(v, mask_p):
262
+ nodes_uid = {node.name: node.uid for node in v.scene.collect_nodes()}
263
+ my_cmap = {
264
+ uid: [SEGM_IDS[name], SEGM_IDS[name], SEGM_IDS[name]]
265
+ for name, uid in nodes_uid.items()
266
+ if name in SEGM_IDS.keys()
267
+ }
268
+ mask = np.array(v.get_mask(color_map=my_cmap)).astype(np.uint8)
269
+ mask = Image.fromarray(mask)
270
+ mask.save(mask_p)
271
+
272
+
273
+ def setup_billboard(data, v):
274
+ images_paths = data.imgnames
275
+ K = data.K
276
+ Rt = data.Rt
277
+ rows = data.rows
278
+ cols = data.cols
279
+ camera = OpenCVCamera(K, Rt, cols, rows, viewer=v)
280
+ if images_paths is not None:
281
+ billboard = Billboard.from_camera_and_distance(
282
+ camera, 10.0, cols, rows, images_paths
283
+ )
284
+ v.scene.add(billboard)
285
+ v.scene.add(camera)
286
+ v.scene.camera.load_cam()
287
+ v.set_temp_camera(camera)
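A minimal sketch of packing camera data into ViewerData (dummy extrinsics and intrinsics; assumes aitviewer is installed, since this module imports it at load time):

import numpy as np

T = 30                                     # number of frames
Rt = np.zeros((T, 3, 4), dtype=np.float32)
Rt[:, :3, :3] = np.eye(3)                  # identity rotation per frame
Rt[:, :, 3] = np.array([0.0, 0.0, 2.0])    # constant camera translation
K = np.array([[1000.0, 0.0, 112.0],
              [0.0, 1000.0, 112.0],
              [0.0, 0.0, 1.0]])

data = ViewerData(Rt, K, cols=224, rows=224, imgnames=None)   # passes validate_format()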
common/vis_utils.py ADDED
@@ -0,0 +1,129 @@
1
+ import matplotlib.cm as cm
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # connection between the 8 points of 3d bbox
7
+ BONES_3D_BBOX = [
8
+ (0, 1),
9
+ (1, 2),
10
+ (2, 3),
11
+ (3, 0),
12
+ (0, 4),
13
+ (1, 5),
14
+ (2, 6),
15
+ (3, 7),
16
+ (4, 5),
17
+ (5, 6),
18
+ (6, 7),
19
+ (7, 4),
20
+ ]
21
+
22
+
23
+ def plot_2d_bbox(bbox_2d, bones, color, ax):
24
+ if ax is None:
25
+ axx = plt
26
+ else:
27
+ axx = ax
28
+ colors = cm.rainbow(np.linspace(0, 1, len(bbox_2d)))
29
+ for pt, c in zip(bbox_2d, colors):
30
+ axx.scatter(pt[0], pt[1], color=c, s=50)
31
+
32
+ if bones is None:
33
+ bones = BONES_3D_BBOX
34
+ for bone in bones:
35
+ sidx, eidx = bone
36
+ # bottom of bbox is white
37
+ if min(sidx, eidx) >= 4:
38
+ color = "w"
39
+ axx.plot(
40
+ [bbox_2d[sidx][0], bbox_2d[eidx][0]],
41
+ [bbox_2d[sidx][1], bbox_2d[eidx][1]],
42
+ color,
43
+ )
44
+ return axx
45
+
46
+
47
+ # http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
48
+ def fig2data(fig):
49
+ """
50
+ @brief Convert a Matplotlib figure to a 3D
51
+ numpy array with RGBA channels and return it
52
+ @param fig a matplotlib figure
53
+ @return a numpy 3D array of RGBA values
54
+ """
55
+ # draw the renderer
56
+ fig.canvas.draw()
57
+
58
+ # Get the RGBA buffer from the figure
59
+ w, h = fig.canvas.get_width_height()
60
+ buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
61
+ buf.shape = (w, h, 4)
62
+
63
+ # canvas.tostring_argb give pixmap in ARGB mode.
64
+ # Roll the ALPHA channel to have it in RGBA mode
65
+ buf = np.roll(buf, 3, axis=2)
66
+ return buf
67
+
68
+
69
+ # http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
70
+ def fig2img(fig):
71
+ """
72
+ @brief Convert a Matplotlib figure to a PIL Image
73
+ in RGBA format and return it
74
+ @param fig a matplotlib figure
75
+ @return a Python Imaging Library ( PIL ) image
76
+ """
77
+ # put the figure pixmap into a numpy array
78
+ buf = fig2data(fig)
79
+ w, h, _ = buf.shape
80
+ return Image.frombytes("RGBA", (w, h), buf.tobytes())
81
+
82
+
83
+ def concat_pil_images(images):
84
+ """
85
+ Put a list of PIL images next to each other
86
+ """
87
+ assert isinstance(images, list)
88
+ widths, heights = zip(*(i.size for i in images))
89
+
90
+ total_width = sum(widths)
91
+ max_height = max(heights)
92
+
93
+ new_im = Image.new("RGB", (total_width, max_height))
94
+
95
+ x_offset = 0
96
+ for im in images:
97
+ new_im.paste(im, (x_offset, 0))
98
+ x_offset += im.size[0]
99
+ return new_im
100
+
101
+
102
+ def stack_pil_images(images):
103
+ """
104
+ Stack a list of PIL images next to each other
105
+ """
106
+ assert isinstance(images, list)
107
+ widths, heights = zip(*(i.size for i in images))
108
+
109
+ total_height = sum(heights)
110
+ max_width = max(widths)
111
+
112
+ new_im = Image.new("RGB", (max_width, total_height))
113
+
114
+ y_offset = 0
115
+ for im in images:
116
+ new_im.paste(im, (0, y_offset))
117
+ y_offset += im.size[1]
118
+ return new_im
119
+
120
+
121
+ def im_list_to_plt(image_list, figsize, title_list=None):
122
+ fig, axes = plt.subplots(nrows=1, ncols=len(image_list), figsize=figsize)
123
+ for idx, (ax, im) in enumerate(zip(axes, image_list)):
124
+ ax.imshow(im)
125
+ ax.set_title(title_list[idx])
126
+ fig.tight_layout()
127
+ im = fig2img(fig)
128
+ plt.close()
129
+ return im
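A tiny sketch of the PIL tiling helpers above (solid-color dummy images):

from PIL import Image

im_a = Image.new("RGB", (64, 48), (255, 0, 0))
im_b = Image.new("RGB", (32, 48), (0, 255, 0))

row = concat_pil_images([im_a, im_b])   # 96x48, images placed side by side
col = stack_pil_images([im_a, im_b])    # 64x96, images stacked vertically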
common/xdict.py ADDED
@@ -0,0 +1,288 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ import common.thing as thing
5
+
6
+
7
+ def _print_stat(key, thing):
8
+ """
9
+ Helper function for printing statistics about a key-value pair in an xdict.
10
+ """
11
+ mytype = type(thing)
12
+ if isinstance(thing, (list, tuple)):
13
+ print("{:<20}: {:<30}\t{:}".format(key, len(thing), mytype))
14
+ elif isinstance(thing, (torch.Tensor)):
15
+ dev = thing.device
16
+ shape = str(thing.shape).replace(" ", "")
17
+ print("{:<20}: {:<30}\t{:}\t{}".format(key, shape, mytype, dev))
18
+ elif isinstance(thing, (np.ndarray)):
19
+ dev = ""
20
+ shape = str(thing.shape).replace(" ", "")
21
+ print("{:<20}: {:<30}\t{:}".format(key, shape, mytype))
22
+ else:
23
+ print("{:<20}: {:}".format(key, mytype))
24
+
25
+
26
+ class xdict(dict):
27
+ """
28
+ A subclass of Python's built-in dict class, which provides additional methods for manipulating and operating on dictionaries.
29
+ """
30
+
31
+ def __init__(self, mydict=None):
32
+ """
33
+ Constructor for the xdict class. Creates a new xdict object and optionally initializes it with key-value pairs from the provided dictionary mydict. If mydict is not provided, an empty xdict is created.
34
+ """
35
+ if mydict is None:
36
+ return
37
+
38
+ for k, v in mydict.items():
39
+ super().__setitem__(k, v)
40
+
41
+ def subset(self, keys):
42
+ """
43
+ Returns a new xdict object containing only the key-value pairs with keys in the provided list 'keys'.
44
+ """
45
+ out_dict = {}
46
+ for k in keys:
47
+ out_dict[k] = self[k]
48
+ return xdict(out_dict)
49
+
50
+ def __setitem__(self, key, val):
51
+ """
52
+ Overrides the dict.__setitem__ method to raise an assertion error if a key already exists.
53
+ """
54
+ assert key not in self.keys(), f"Key already exists {key}"
55
+ super().__setitem__(key, val)
56
+
57
+ def search(self, keyword, replace_to=None):
58
+ """
59
+ Returns a new xdict object containing only the key-value pairs with keys that contain the provided keyword.
60
+ """
61
+ out_dict = {}
62
+ for k in self.keys():
63
+ if keyword in k:
64
+ if replace_to is None:
65
+ out_dict[k] = self[k]
66
+ else:
67
+ out_dict[k.replace(keyword, replace_to)] = self[k]
68
+ return xdict(out_dict)
69
+
70
+ def rm(self, keyword, keep_list=[], verbose=False):
71
+ """
72
+ Returns a new xdict object with keys that contain keyword removed. Keys in keep_list are excluded from the removal.
73
+ """
74
+ out_dict = {}
75
+ for k in self.keys():
76
+ if keyword not in k or k in keep_list:
77
+ out_dict[k] = self[k]
78
+ else:
79
+ if verbose:
80
+ print(f"Removing: {k}")
81
+ return xdict(out_dict)
82
+
83
+ def overwrite(self, k, v):
84
+ """
85
+ The original assignment operation of Python dict
86
+ """
87
+ super().__setitem__(k, v)
88
+
89
+ def merge(self, dict2):
90
+ """
91
+ Same as dict.update(), but raises an assertion error if there are duplicate keys between the two dictionaries.
92
+
93
+ Args:
94
+ dict2 (dict or xdict): The dictionary or xdict instance to merge with.
95
+
96
+ Raises:
97
+ AssertionError: If dict2 is not a dictionary or xdict instance.
98
+ AssertionError: If there are duplicate keys between the two instances.
99
+ """
100
+ assert isinstance(dict2, (dict, xdict))
101
+ mykeys = set(self.keys())
102
+ intersect = mykeys.intersection(set(dict2.keys()))
103
+ assert len(intersect) == 0, f"Merge failed: duplicate keys ({intersect})"
104
+ self.update(dict2)
105
+
106
+ def mul(self, scalar):
107
+ """
108
+ Multiplies each value (could be tensor, np.array, list) in the xdict instance by the provided scalar.
109
+
110
+ Args:
111
+ scalar (float): The scalar to multiply the values by.
112
+
113
+ Raises:
114
+ AssertionError: If scalar is not a float.
115
+ """
116
+ if isinstance(scalar, int):
117
+ scalar = 1.0 * scalar
118
+ assert isinstance(scalar, float)
119
+ out_dict = {}
120
+ for k in self.keys():
121
+ if isinstance(self[k], list):
122
+ out_dict[k] = [v * scalar for v in self[k]]
123
+ else:
124
+ out_dict[k] = self[k] * scalar
125
+ return xdict(out_dict)
126
+
127
+ def prefix(self, text):
128
+ """
129
+ Adds a prefix to each key in the xdict instance.
130
+
131
+ Args:
132
+ text (str): The prefix to add.
133
+
134
+ Returns:
135
+ xdict: The xdict instance with the added prefix.
136
+ """
137
+ out_dict = {}
138
+ for k in self.keys():
139
+ out_dict[text + k] = self[k]
140
+ return xdict(out_dict)
141
+
142
+ def replace_keys(self, str_src, str_tar):
143
+ """
144
+ Replaces a substring in all keys of the xdict instance.
145
+
146
+ Args:
147
+ str_src (str): The substring to replace.
148
+ str_tar (str): The replacement string.
149
+
150
+ Returns:
151
+ xdict: The xdict instance with the replaced keys.
152
+ """
153
+ out_dict = {}
154
+ for k in self.keys():
155
+ old_key = k
156
+ new_key = old_key.replace(str_src, str_tar)
157
+ out_dict[new_key] = self[k]
158
+ return xdict(out_dict)
159
+
160
+ def postfix(self, text):
161
+ """
162
+ Adds a postfix to each key in the xdict instance.
163
+
164
+ Args:
165
+ text (str): The postfix to add.
166
+
167
+ Returns:
168
+ xdict: The xdict instance with the added postfix.
169
+ """
170
+ out_dict = {}
171
+ for k in self.keys():
172
+ out_dict[k + text] = self[k]
173
+ return xdict(out_dict)
174
+
175
+ def sorted_keys(self):
176
+ """
177
+ Returns a sorted list of the keys in the xdict instance.
178
+
179
+ Returns:
180
+ list: A sorted list of keys in the xdict instance.
181
+ """
182
+ return sorted(list(self.keys()))
183
+
184
+ def to(self, dev):
185
+ """
186
+ Moves the xdict instance to a specific device.
187
+
188
+ Args:
189
+ dev (torch.device): The device to move the instance to.
190
+
191
+ Returns:
192
+ xdict: The xdict instance moved to the specified device.
193
+ """
194
+ if dev is None:
195
+ return self
196
+ raw_dict = dict(self)
197
+ return xdict(thing.thing2dev(raw_dict, dev))
198
+
199
+ def to_torch(self):
200
+ """
201
+ Converts elements in the xdict to Torch tensors and returns a new xdict.
202
+
203
+ Returns:
204
+ xdict: A new xdict with Torch tensors as values.
205
+ """
206
+ return xdict(thing.thing2torch(self))
207
+
208
+ def to_np(self):
209
+ """
210
+ Converts elements in the xdict to numpy arrays and returns a new xdict.
211
+
212
+ Returns:
213
+ xdict: A new xdict with numpy arrays as values.
214
+ """
215
+ return xdict(thing.thing2np(self))
216
+
217
+ def tolist(self):
218
+ """
219
+ Converts elements in the xdict to Python lists and returns a new xdict.
220
+
221
+ Returns:
222
+ xdict: A new xdict with Python lists as values.
223
+ """
224
+ return xdict(thing.thing2list(self))
225
+
226
+ def print_stat(self):
227
+ """
228
+ Prints statistics for each item in the xdict.
229
+ """
230
+ for k, v in self.items():
231
+ _print_stat(k, v)
232
+
233
+ def detach(self):
234
+ """
235
+ Detaches all Torch tensors in the xdict from the computational graph and moves them to the CPU.
236
+ Non-tensor objects are ignored.
237
+
238
+ Returns:
239
+ xdict: A new xdict with detached Torch tensors as values.
240
+ """
241
+ return xdict(thing.detach_thing(self))
242
+
243
+ def has_invalid(self):
244
+ """
245
+ Checks if any of the Torch tensors in the xdict contain NaN or Inf values.
246
+
247
+ Returns:
248
+ bool: True if at least one tensor contains NaN or Inf values, False otherwise.
249
+ """
250
+ for k, v in self.items():
251
+ if isinstance(v, torch.Tensor):
252
+ if torch.isnan(v).any():
253
+ print(f"{k} contains nan values")
254
+ return True
255
+ if torch.isinf(v).any():
256
+ print(f"{k} contains inf values")
257
+ return True
258
+ return False
259
+
260
+ def apply(self, operation, criterion=None):
261
+ """
262
+ Applies an operation to the values in the xdict, based on an optional criterion.
263
+
264
+ Args:
265
+ operation (callable): A callable object that takes a single argument and returns a value.
266
+ criterion (callable, optional): A callable object that takes two arguments (key and value) and returns a boolean.
267
+
268
+ Returns:
269
+ xdict: A new xdict with the same keys as the original, but with the values modified by the operation.
270
+ """
271
+ out = {}
272
+ for k, v in self.items():
273
+ if criterion is None or criterion(k, v):
274
+ out[k] = operation(v)
275
+ return xdict(out)
276
+
277
+ def save(self, path, dev=None, verbose=True):
278
+ """
279
+ Saves the xdict to disk as a Torch tensor.
280
+
281
+ Args:
282
+ path (str): The path to save the xdict.
283
+ dev (torch.device, optional): The device to use for saving the tensor (default is CPU).
284
+ verbose (bool, optional): Whether to print a message indicating that the xdict has been saved (default is True).
285
+ """
286
+ if verbose:
287
+ print(f"Saving to {path}")
288
+ torch.save(self.to(dev), path)
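A short sketch of chaining a few xdict helpers (dummy tensors; key names are illustrative only):

import torch

out = xdict()
out["pred.verts"] = torch.rand(4, 100, 3)
out["pred.joints"] = torch.rand(4, 21, 3)
out["targets.verts"] = torch.rand(4, 100, 3)

preds = out.search("pred.", replace_to="")    # keep matching keys, renamed to "verts", "joints"
preds = preds.prefix("right.").to_np()        # "right.verts", "right.joints" as numpy arrays
preds.print_stat()                            # per-key shape/type summary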
data_loaders/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data_loaders/__pycache__/get_data.cpython-38.pyc ADDED
Binary file (4.42 kB). View file
 
data_loaders/__pycache__/tensors.cpython-38.pyc ADDED
Binary file (6.98 kB). View file
 
data_loaders/get_data.py ADDED
@@ -0,0 +1,178 @@
1
+ from torch.utils.data import DataLoader
2
+ from data_loaders.tensors import collate as all_collate
3
+ from data_loaders.tensors import t2m_collate, motion_ours_collate, motion_ours_singe_seq_collate, motion_ours_obj_base_rel_dist_collate
4
+ # from data_loaders.humanml.data.dataset import HumanML3D
5
+ import torch
6
+
7
+ def get_dataset_class(name, args=None):
8
+ if name == "amass":
9
+ from .amass import AMASS
10
+ return AMASS
11
+ elif name == "uestc":
12
+ from .a2m.uestc import UESTC
13
+ return UESTC
14
+ elif name == "humanact12":
15
+ from .a2m.humanact12poses import HumanAct12Poses
16
+ return HumanAct12Poses ## to pose ##
17
+ elif name == "humanml":
18
+ from data_loaders.humanml.data.dataset import HumanML3D
19
+ return HumanML3D
20
+ elif name == "kit":
21
+ from data_loaders.humanml.data.dataset import KIT
22
+ return KIT
23
+ elif name == "motion_ours": # motion ours
24
+ if len(args.single_seq_path) > 0 and not args.use_predicted_infos and not args.use_interpolated_infos:
25
+ print(f"Using single frame dataset for evaluation purpose...")
26
+ # from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V16
27
+ if args.rep_type == "obj_base_rel_dist":
28
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V17 as my_data
29
+ elif args.rep_type == "ambient_obj_base_rel_dist":
30
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V18 as my_data
31
+ elif args.rep_type in ["obj_base_rel_dist_we", "obj_base_rel_dist_we_wj", "obj_base_rel_dist_we_wj_latents"]:
32
+ if args.use_arctic and args.use_pose_pred:
33
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Arctic_from_Pred as my_data
34
+ elif args.use_hho:
35
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_HHO as my_data
36
+ elif args.use_arctic:
37
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Arctic as my_data
38
+ elif len(args.cad_model_fn) > 0:
39
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Ours as my_data
40
+ elif len(args.predicted_info_fn) > 0:
41
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_From_Evaluated_Info as my_data
42
+ else:
43
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19 as my_data
44
+ else:
45
+ from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V16 as my_data
46
+ return my_data
47
+ else:
48
+ if args.rep_type == "obj_base_rel_dist":
49
+ from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V17 as my_data
50
+ elif args.rep_type == "ambient_obj_base_rel_dist":
51
+ from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V18 as my_data
52
+ elif args.rep_type in ["obj_base_rel_dist_we", "obj_base_rel_dist_we_wj", "obj_base_rel_dist_we_wj_latents"]:
53
+ if args.use_arctic:
54
+ from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V19_ARCTIC as my_data
55
+ elif args.use_vox_data: # use vox data here #
56
+ from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V20 as my_data
57
+ elif args.use_predicted_infos: # train with predicted infos for test-time adaptation
58
+ from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V21 as my_data
59
+ elif args.use_interpolated_infos:
60
+ # GRAB_Dataset_V22
61
+ from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V22 as my_data
62
+ else:
63
+ from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V19 as my_data
64
+ else:
65
+ from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V16 as my_data
66
+ return my_data
67
+ # from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V16
68
+ # return GRAB_Dataset_V16
69
+ else:
70
+ raise ValueError(f'Unsupported dataset name [{name}]')
71
+
72
+ def get_collate_fn(name, hml_mode='train', args=None):
73
+ print(f"name: {name}, hml_mode: {hml_mode}")
74
+ if hml_mode == 'gt':
75
+ from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
76
+ return t2m_eval_collate
77
+ if name in ["humanml", "kit"]:
78
+ return t2m_collate
79
+ elif name in ["motion_ours"]:
80
+ ## === single seq path === ##
81
+ print(f"single_seq_path: {args.single_seq_path}, rep_type: {args.rep_type}")
82
+ # motion_ours_obj_base_rel_dist_collate
83
+ ## rep_type of the obj_base_pts rel_dist; ambient obj base rel dist ##
84
+ if args.rep_type in ["obj_base_rel_dist", "ambient_obj_base_rel_dist", "obj_base_rel_dist_we", "obj_base_rel_dist_we_wj", "obj_base_rel_dist_we_wj_latents"]:
85
+ return motion_ours_obj_base_rel_dist_collate
86
+ else: # single_seq_path #
87
+ if len(args.single_seq_path) > 0:
88
+ return motion_ours_singe_seq_collate
89
+ else:
90
+ return motion_ours_collate
91
+ # if len(args.single_seq_path) > 0:
92
+ # return motion_ours_singe_seq_collate
93
+ # else:
94
+ # if args.rep_type == "obj_base_rel_dist":
95
+ # return motion_ours_obj_base_rel_dist_collate
96
+ # else:
97
+ # return motion_ours_collate
98
+ else:
99
+ return all_collate
100
+
101
+ ## get dataset ##
102
+ def get_dataset(name, num_frames, split='train', hml_mode='train', args=None):
103
+ DATA = get_dataset_class(name, args=args)
104
+ if name in ["humanml", "kit"]:
105
+ dataset = DATA(split=split, num_frames=num_frames, mode=hml_mode)
106
+ elif name in ["motion_ours"]:
107
+ # humanml_datawarper = HumanML3D(split=split, num_frames=num_frames, mode=hml_mode, load_vectorizer=True)
108
+ # w_vectorizer = humanml_datawarper.w_vectorizer
109
+
110
+ w_vectorizer = None
111
+ # split = "val" ## add split, split here --> split --> split and split ##
112
+ data_path = "/data1/sim/GRAB_processed"
113
+ # split, w_vectorizer, window_size=30, step_size=15, num_points=8000, args=None
114
+ window_size = args.window_size
115
+ # split= "val"
116
+ dataset = DATA(data_path, split=split, w_vectorizer=w_vectorizer, window_size=window_size, step_size=15, num_points=8000, args=args)
117
+ else:
118
+ dataset = DATA(split=split, num_frames=num_frames)
119
+ return dataset
120
+
121
+
122
+ def get_dataset_only(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
123
+ dataset = get_dataset(name, num_frames, split, hml_mode, args=args)
124
+ return dataset
125
+
126
+ # python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset motion_ours
127
+ def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
128
+ dataset = get_dataset(name, num_frames, split, hml_mode, args=args)
129
+ collate = get_collate_fn(name, hml_mode, args=args)
130
+
131
+ if args is not None and name in ["motion_ours"] and len(args.single_seq_path) > 0:
132
+ shuffle_loader = False
133
+ drop_last = False
134
+ else:
135
+ shuffle_loader = True
136
+ drop_last = True
137
+
138
+ num_workers = 8  # default number of dataloader workers
139
+ num_workers = 16  # overrides the default above; use 16 workers
140
+ ### ==== create dataloader here ==== ###
141
+
142
+ loader = DataLoader( # tag for each sequence
143
+ dataset, batch_size=batch_size, shuffle=shuffle_loader,
144
+ num_workers=num_workers, drop_last=drop_last, collate_fn=collate
145
+ )
146
+
147
+ return loader
148
+
149
+
150
+ # python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset motion_ours
151
+ def get_dataset_loader_dist(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
152
+ dataset = get_dataset(name, num_frames, split, hml_mode, args=args)
153
+ collate = get_collate_fn(name, hml_mode, args=args)
154
+
155
+ if args is not None and name in ["motion_ours"] and len(args.single_seq_path) > 0:
156
+ # shuffle_loader = False
157
+ drop_last = False
158
+ else:
159
+ # shuffle_loader = True
160
+ drop_last = True
161
+
162
+ num_workers = 8  # default number of dataloader workers
163
+ num_workers = 16  # overrides the default above; use 16 workers
164
+ ### ==== create dataloader here ==== ###
165
+
166
+
167
+ ''' dist sampler and loader '''
168
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
169
+ loader = DataLoader(dataset, batch_size=batch_size,
170
+ sampler=sampler, num_workers=num_workers, drop_last=drop_last, collate_fn=collate)
171
+
172
+
173
+ # loader = DataLoader( # tag for each sequence
174
+ # dataset, batch_size=batch_size, shuffle=shuffle_loader,
175
+ # num_workers=num_workers, drop_last=drop_last, collate_fn=collate
176
+ # )
177
+
178
+ return loader
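A minimal usage sketch (annotation only, not part of this commit): how a training script might call `get_dataset_loader` above for the `motion_ours` dataset. The flag names mirror the attributes read in `get_dataset_class` and `get_collate_fn`; the concrete values are illustrative assumptions, and the loader still expects the processed GRAB data at the hard-coded `/data1/sim/GRAB_processed` path.

from argparse import Namespace
from data_loaders.get_data import get_dataset_loader

# Illustrative flag values; real runs pass the project's full argument parser.
args = Namespace(
    single_seq_path="",            # empty -> full-dataset training branch
    use_predicted_infos=False,
    use_interpolated_infos=False,
    use_arctic=False,
    use_vox_data=False,
    rep_type="obj_base_rel_dist_we",
    window_size=30,
)
loader = get_dataset_loader(
    name="motion_ours", batch_size=32, num_frames=60,
    split="train", hml_mode="train", args=args,
)
for batch in loader:
    ...  # batches are collated by motion_ours_obj_base_rel_dist_collate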
data_loaders/humanml/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data_loaders/humanml/README.md ADDED
@@ -0,0 +1 @@
1
+ This code is based on https://github.com/EricGuo5513/text-to-motion.git
data_loaders/humanml/common/__pycache__/quaternion.cpython-38.pyc ADDED
Binary file (11.6 kB). View file
 
data_loaders/humanml/common/__pycache__/skeleton.cpython-38.pyc ADDED
Binary file (6.15 kB). View file
 
data_loaders/humanml/common/quaternion.py ADDED
@@ -0,0 +1,423 @@
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import numpy as np
10
+
11
+ _EPS4 = np.finfo(float).eps * 4.0
12
+
13
+ _FLOAT_EPS = np.finfo(float).eps  # the np.float alias was removed in NumPy 1.24; use the builtin float
14
+
15
+ # PyTorch-backed implementations
16
+ def qinv(q):
17
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18
+ mask = torch.ones_like(q)
19
+ mask[..., 1:] = -mask[..., 1:]
20
+ return q * mask
21
+
22
+
23
+ def qinv_np(q):
24
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25
+ return qinv(torch.from_numpy(q).float()).numpy()
26
+
27
+
28
+ def qnormalize(q):
29
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30
+ return q / torch.norm(q, dim=-1, keepdim=True)
31
+
32
+
33
+ def qmul(q, r):
34
+ """
35
+ Multiply quaternion(s) q with quaternion(s) r.
36
+ Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37
+ Returns q*r as a tensor of shape (*, 4).
38
+ """
39
+ assert q.shape[-1] == 4
40
+ assert r.shape[-1] == 4
41
+
42
+ original_shape = q.shape
43
+
44
+ # Compute outer product
45
+ terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46
+
47
+ w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48
+ x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49
+ y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50
+ z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51
+ return torch.stack((w, x, y, z), dim=1).view(original_shape)
52
+
53
+
54
+ def qrot(q, v):
55
+ """
56
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
57
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58
+ where * denotes any number of dimensions.
59
+ Returns a tensor of shape (*, 3).
60
+ """
61
+ assert q.shape[-1] == 4
62
+ assert v.shape[-1] == 3
63
+ assert q.shape[:-1] == v.shape[:-1]
64
+
65
+ original_shape = list(v.shape)
66
+ # print(q.shape)
67
+ q = q.contiguous().view(-1, 4)
68
+ v = v.contiguous().view(-1, 3)
69
+
70
+ qvec = q[:, 1:]
71
+ uv = torch.cross(qvec, v, dim=1)
72
+ uuv = torch.cross(qvec, uv, dim=1)
73
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74
+
75
+
76
+ def qeuler(q, order, epsilon=0, deg=True):
77
+ """
78
+ Convert quaternion(s) q to Euler angles.
79
+ Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80
+ Returns a tensor of shape (*, 3).
81
+ """
82
+ assert q.shape[-1] == 4
83
+
84
+ original_shape = list(q.shape)
85
+ original_shape[-1] = 3
86
+ q = q.view(-1, 4)
87
+
88
+ q0 = q[:, 0]
89
+ q1 = q[:, 1]
90
+ q2 = q[:, 2]
91
+ q3 = q[:, 3]
92
+
93
+ if order == 'xyz':
94
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95
+ y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97
+ elif order == 'yzx':
98
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100
+ z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101
+ elif order == 'zxy':
102
+ x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105
+ elif order == 'xzy':
106
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107
+ y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108
+ z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109
+ elif order == 'yxz':
110
+ x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111
+ y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112
+ z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113
+ elif order == 'zyx':
114
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115
+ y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116
+ z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117
+ else:
118
+ raise ValueError(f"Unsupported Euler angle order: {order}")
119
+
120
+ if deg:
121
+ return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122
+ else:
123
+ return torch.stack((x, y, z), dim=1).view(original_shape)
124
+
125
+
126
+ # Numpy-backed implementations
127
+
128
+ def qmul_np(q, r):
129
+ q = torch.from_numpy(q).contiguous().float()
130
+ r = torch.from_numpy(r).contiguous().float()
131
+ return qmul(q, r).numpy()
132
+
133
+
134
+ def qrot_np(q, v):
135
+ q = torch.from_numpy(q).contiguous().float()
136
+ v = torch.from_numpy(v).contiguous().float()
137
+ return qrot(q, v).numpy()
138
+
139
+
140
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
141
+ if use_gpu:
142
+ q = torch.from_numpy(q).cuda().float()
143
+ return qeuler(q, order, epsilon).cpu().numpy()
144
+ else:
145
+ q = torch.from_numpy(q).contiguous().float()
146
+ return qeuler(q, order, epsilon).numpy()
147
+
148
+
149
+ def qfix(q):
150
+ """
151
+ Enforce quaternion continuity across the time dimension by selecting
152
+ the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153
+ between two consecutive frames.
154
+
155
+ Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156
+ Returns a tensor of the same shape.
157
+ """
158
+ assert len(q.shape) == 3
159
+ assert q.shape[-1] == 4
160
+
161
+ result = q.copy()
162
+ dot_products = np.sum(q[1:] * q[:-1], axis=2)
163
+ mask = dot_products < 0
164
+ mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165
+ result[1:][mask] *= -1
166
+ return result
167
+
168
+
169
+ def euler2quat(e, order, deg=True):
170
+ """
171
+ Convert Euler angles to quaternions.
172
+ """
173
+ assert e.shape[-1] == 3
174
+
175
+ original_shape = list(e.shape)
176
+ original_shape[-1] = 4
177
+
178
+ e = e.view(-1, 3)
179
+
180
+ ## if euler angles in degrees
181
+ if deg:
182
+ e = e * np.pi / 180.
183
+
184
+ x = e[:, 0]
185
+ y = e[:, 1]
186
+ z = e[:, 2]
187
+
188
+ rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189
+ ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190
+ rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191
+
192
+ result = None
193
+ for coord in order:
194
+ if coord == 'x':
195
+ r = rx
196
+ elif coord == 'y':
197
+ r = ry
198
+ elif coord == 'z':
199
+ r = rz
200
+ else:
201
+ raise ValueError(f"Unsupported axis in order: {coord}")
202
+ if result is None:
203
+ result = r
204
+ else:
205
+ result = qmul(result, r)
206
+
207
+ # Reverse antipodal representation to have a non-negative "w"
208
+ if order in ['xyz', 'yzx', 'zxy']:
209
+ result *= -1
210
+
211
+ return result.view(original_shape)
212
+
213
+
214
+ def expmap_to_quaternion(e):
215
+ """
216
+ Convert axis-angle rotations (aka exponential maps) to quaternions.
217
+ Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218
+ Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219
+ Returns a tensor of shape (*, 4).
220
+ """
221
+ assert e.shape[-1] == 3
222
+
223
+ original_shape = list(e.shape)
224
+ original_shape[-1] = 4
225
+ e = e.reshape(-1, 3)
226
+
227
+ theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228
+ w = np.cos(0.5 * theta).reshape(-1, 1)
229
+ xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230
+ return np.concatenate((w, xyz), axis=1).reshape(original_shape)
231
+
232
+
233
+ def euler_to_quaternion(e, order):
234
+ """
235
+ Convert Euler angles to quaternions.
236
+ """
237
+ assert e.shape[-1] == 3
238
+
239
+ original_shape = list(e.shape)
240
+ original_shape[-1] = 4
241
+
242
+ e = e.reshape(-1, 3)
243
+
244
+ x = e[:, 0]
245
+ y = e[:, 1]
246
+ z = e[:, 2]
247
+
248
+ rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249
+ ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250
+ rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251
+
252
+ result = None
253
+ for coord in order:
254
+ if coord == 'x':
255
+ r = rx
256
+ elif coord == 'y':
257
+ r = ry
258
+ elif coord == 'z':
259
+ r = rz
260
+ else:
261
+ raise ValueError(f"Unsupported axis in order: {coord}")
262
+ if result is None:
263
+ result = r
264
+ else:
265
+ result = qmul_np(result, r)
266
+
267
+ # Reverse antipodal representation to have a non-negative "w"
268
+ if order in ['xyz', 'yzx', 'zxy']:
269
+ result *= -1
270
+
271
+ return result.reshape(original_shape)
272
+
273
+
274
+ def quaternion_to_matrix(quaternions):
275
+ """
276
+ Convert rotations given as quaternions to rotation matrices.
277
+ Args:
278
+ quaternions: quaternions with real part first,
279
+ as tensor of shape (..., 4).
280
+ Returns:
281
+ Rotation matrices as tensor of shape (..., 3, 3).
282
+ """
283
+ r, i, j, k = torch.unbind(quaternions, -1)
284
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
285
+
286
+ o = torch.stack(
287
+ (
288
+ 1 - two_s * (j * j + k * k),
289
+ two_s * (i * j - k * r),
290
+ two_s * (i * k + j * r),
291
+ two_s * (i * j + k * r),
292
+ 1 - two_s * (i * i + k * k),
293
+ two_s * (j * k - i * r),
294
+ two_s * (i * k - j * r),
295
+ two_s * (j * k + i * r),
296
+ 1 - two_s * (i * i + j * j),
297
+ ),
298
+ -1,
299
+ )
300
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
301
+
302
+
303
+ def quaternion_to_matrix_np(quaternions):
304
+ q = torch.from_numpy(quaternions).contiguous().float()
305
+ return quaternion_to_matrix(q).numpy()
306
+
307
+
308
+ def quaternion_to_cont6d_np(quaternions):
309
+ rotation_mat = quaternion_to_matrix_np(quaternions)
310
+ cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311
+ return cont_6d
312
+
313
+
314
+ def quaternion_to_cont6d(quaternions):
315
+ rotation_mat = quaternion_to_matrix(quaternions)
316
+ cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317
+ return cont_6d
318
+
319
+
320
+ def cont6d_to_matrix(cont6d):
321
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322
+ x_raw = cont6d[..., 0:3]
323
+ y_raw = cont6d[..., 3:6]
324
+
325
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326
+ z = torch.cross(x, y_raw, dim=-1)
327
+ z = z / torch.norm(z, dim=-1, keepdim=True)
328
+
329
+ y = torch.cross(z, x, dim=-1)
330
+
331
+ x = x[..., None]
332
+ y = y[..., None]
333
+ z = z[..., None]
334
+
335
+ mat = torch.cat([x, y, z], dim=-1)
336
+ return mat
337
+
338
+
339
+ def cont6d_to_matrix_np(cont6d):
340
+ q = torch.from_numpy(cont6d).contiguous().float()
341
+ return cont6d_to_matrix(q).numpy()
342
+
343
+
344
+ def qpow(q0, t, dtype=torch.float):
345
+ ''' q0 : tensor of quaternions
346
+ t: tensor of powers
347
+ '''
348
+ q0 = qnormalize(q0)
349
+ theta0 = torch.acos(q0[..., 0])
350
+
351
+ ## if theta0 is close to zero, add epsilon to avoid NaNs
352
+ mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
353
+ theta0 = (1 - mask) * theta0 + mask * 10e-10
354
+ v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355
+
356
+ if isinstance(t, torch.Tensor):
357
+ q = torch.zeros(t.shape + q0.shape)
358
+ theta = t.view(-1, 1) * theta0.view(1, -1)
359
+ else: ## if t is a number
360
+ q = torch.zeros(q0.shape)
361
+ theta = t * theta0
362
+
363
+ q[..., 0] = torch.cos(theta)
364
+ q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365
+
366
+ return q.to(dtype)
367
+
368
+
369
+ def qslerp(q0, q1, t):
370
+ '''
371
+ q0: starting quaternion
372
+ q1: ending quaternion
373
+ t: array of points along the way
374
+
375
+ Returns:
376
+ Tensor of Slerps: t.shape + q0.shape
377
+ '''
378
+
379
+ q0 = qnormalize(q0)
380
+ q1 = qnormalize(q1)
381
+ q_ = qpow(qmul(q1, qinv(q0)), t)
382
+
383
+ return qmul(q_,
384
+ q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
385
+
386
+
387
+ def qbetween(v0, v1):
388
+ '''
389
+ find the quaternion used to rotate v0 to v1
390
+ '''
391
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393
+
394
+ v = torch.cross(v0, v1)
395
+ w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
396
+ keepdim=True)
397
+ return qnormalize(torch.cat([w, v], dim=-1))
398
+
399
+
400
+ def qbetween_np(v0, v1):
401
+ '''
402
+ find the quaternion used to rotate v0 to v1
403
+ '''
404
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406
+
407
+ v0 = torch.from_numpy(v0).float()
408
+ v1 = torch.from_numpy(v1).float()
409
+ return qbetween(v0, v1).numpy()
410
+
411
+
412
+ def lerp(p0, p1, t):
413
+ if not isinstance(t, torch.Tensor):
414
+ t = torch.Tensor([t])
415
+
416
+ new_shape = t.shape + p0.shape
417
+ new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418
+ new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419
+ p0 = p0.view(new_view_p).expand(new_shape)
420
+ p1 = p1.view(new_view_p).expand(new_shape)
421
+ t = t.view(new_view_t).expand(new_shape)
422
+
423
+ return p0 + t * (p1 - p0)
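A short sanity-check sketch (annotation only, not part of the commit) for the quaternion helpers above: build a quaternion from Euler angles, rotate a vector with `qrot`, and undo the rotation with `qinv`.

import torch
from data_loaders.humanml.common.quaternion import euler2quat, qrot, qinv

angles = torch.tensor([[0.0, 90.0, 0.0]])      # degrees, interpreted with order 'xyz'
q = euler2quat(angles, order='xyz', deg=True)  # shape (1, 4), w-first quaternion
v = torch.tensor([[1.0, 0.0, 0.0]])
v_rot = qrot(q, v)                             # ~(0, 0, -1): x-axis rotated 90 deg about y
v_back = qrot(qinv(q), v_rot)                  # recovers (1, 0, 0) up to float error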
data_loaders/humanml/common/skeleton.py ADDED
@@ -0,0 +1,199 @@
1
+ from data_loaders.humanml.common.quaternion import *
2
+ import scipy.ndimage.filters as filters
3
+
4
+ class Skeleton(object):
5
+ def __init__(self, offset, kinematic_tree, device):
6
+ self.device = device
7
+ self._raw_offset_np = offset.numpy()
8
+ self._raw_offset = offset.clone().detach().to(device).float()
9
+ self._kinematic_tree = kinematic_tree
10
+ self._offset = None
11
+ self._parents = [0] * len(self._raw_offset)
12
+ self._parents[0] = -1
13
+ for chain in self._kinematic_tree:
14
+ for j in range(1, len(chain)):
15
+ self._parents[chain[j]] = chain[j-1]
16
+
17
+ def njoints(self):
18
+ return len(self._raw_offset)
19
+
20
+ def offset(self):
21
+ return self._offset
22
+
23
+ def set_offset(self, offsets):
24
+ self._offset = offsets.clone().detach().to(self.device).float()
25
+
26
+ def kinematic_tree(self):
27
+ return self._kinematic_tree
28
+
29
+ def parents(self):
30
+ return self._parents
31
+
32
+ # joints (batch_size, joints_num, 3)
33
+ def get_offsets_joints_batch(self, joints):
34
+ assert len(joints.shape) == 3
35
+ _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
36
+ for i in range(1, self._raw_offset.shape[0]):
37
+ _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
38
+
39
+ self._offset = _offsets.detach()
40
+ return _offsets
41
+
42
+ # joints (joints_num, 3)
43
+ def get_offsets_joints(self, joints):
44
+ assert len(joints.shape) == 2
45
+ _offsets = self._raw_offset.clone()
46
+ for i in range(1, self._raw_offset.shape[0]):
47
+ # print(joints.shape)
48
+ _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
49
+
50
+ self._offset = _offsets.detach()
51
+ return _offsets
52
+
53
+ # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
54
+ # joints (batch_size, joints_num, 3)
55
+ def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
56
+ assert len(face_joint_idx) == 4
57
+ '''Get Forward Direction'''
58
+ l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
59
+ across1 = joints[:, r_hip] - joints[:, l_hip]
60
+ across2 = joints[:, sdr_r] - joints[:, sdr_l]
61
+ across = across1 + across2
62
+ across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
63
+ # print(across1.shape, across2.shape)
64
+
65
+ # forward (batch_size, 3)
66
+ forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
67
+ if smooth_forward:
68
+ forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
69
+ # forward (batch_size, 3)
70
+ forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
71
+
72
+ '''Get Root Rotation'''
73
+ target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
74
+ root_quat = qbetween_np(forward, target)
75
+
76
+ '''Inverse Kinematics'''
77
+ # quat_params (batch_size, joints_num, 4)
78
+ # print(joints.shape[:-1])
79
+ quat_params = np.zeros(joints.shape[:-1] + (4,))
80
+ # print(quat_params.shape)
81
+ root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
82
+ quat_params[:, 0] = root_quat
83
+ # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
84
+ for chain in self._kinematic_tree:
85
+ R = root_quat
86
+ for j in range(len(chain) - 1):
87
+ # (batch, 3)
88
+ u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
89
+ # print(u.shape)
90
+ # (batch, 3)
91
+ v = joints[:, chain[j+1]] - joints[:, chain[j]]
92
+ v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
93
+ # print(u.shape, v.shape)
94
+ rot_u_v = qbetween_np(u, v)
95
+
96
+ R_loc = qmul_np(qinv_np(R), rot_u_v)
97
+
98
+ quat_params[:,chain[j + 1], :] = R_loc
99
+ R = qmul_np(R, R_loc)
100
+
101
+ return quat_params
102
+
103
+ # Be sure root joint is at the beginning of kinematic chains
104
+ def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
105
+ # quat_params (batch_size, joints_num, 4)
106
+ # joints (batch_size, joints_num, 3)
107
+ # root_pos (batch_size, 3)
108
+ if skel_joints is not None:
109
+ offsets = self.get_offsets_joints_batch(skel_joints)
110
+ if len(self._offset.shape) == 2:
111
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
112
+ joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
113
+ joints[:, 0] = root_pos
114
+ for chain in self._kinematic_tree:
115
+ if do_root_R:
116
+ R = quat_params[:, 0]
117
+ else:
118
+ R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
119
+ for i in range(1, len(chain)):
120
+ R = qmul(R, quat_params[:, chain[i]])
121
+ offset_vec = offsets[:, chain[i]]
122
+ joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
123
+ return joints
124
+
125
+ # Be sure root joint is at the beginning of kinematic chains
126
+ def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
127
+ # quat_params (batch_size, joints_num, 4)
128
+ # joints (batch_size, joints_num, 3)
129
+ # root_pos (batch_size, 3)
130
+ if skel_joints is not None:
131
+ skel_joints = torch.from_numpy(skel_joints)
132
+ offsets = self.get_offsets_joints_batch(skel_joints)
133
+ if len(self._offset.shape) == 2:
134
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
135
+ offsets = offsets.numpy()
136
+ joints = np.zeros(quat_params.shape[:-1] + (3,))
137
+ joints[:, 0] = root_pos
138
+ for chain in self._kinematic_tree:
139
+ if do_root_R:
140
+ R = quat_params[:, 0]
141
+ else:
142
+ R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
143
+ for i in range(1, len(chain)):
144
+ R = qmul_np(R, quat_params[:, chain[i]])
145
+ offset_vec = offsets[:, chain[i]]
146
+ joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
147
+ return joints
148
+
149
+ def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
150
+ # cont6d_params (batch_size, joints_num, 6)
151
+ # joints (batch_size, joints_num, 3)
152
+ # root_pos (batch_size, 3)
153
+ if skel_joints is not None:
154
+ skel_joints = torch.from_numpy(skel_joints)
155
+ offsets = self.get_offsets_joints_batch(skel_joints)
156
+ if len(self._offset.shape) == 2:
157
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
158
+ offsets = offsets.numpy()
159
+ joints = np.zeros(cont6d_params.shape[:-1] + (3,))
160
+ joints[:, 0] = root_pos
161
+ for chain in self._kinematic_tree:
162
+ if do_root_R:
163
+ matR = cont6d_to_matrix_np(cont6d_params[:, 0])
164
+ else:
165
+ matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
166
+ for i in range(1, len(chain)):
167
+ matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
168
+ offset_vec = offsets[:, chain[i]][..., np.newaxis]
169
+ # print(matR.shape, offset_vec.shape)
170
+ joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
171
+ return joints
172
+
173
+ def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
174
+ # cont6d_params (batch_size, joints_num, 6)
175
+ # joints (batch_size, joints_num, 3)
176
+ # root_pos (batch_size, 3)
177
+ if skel_joints is not None:
178
+ # skel_joints = torch.from_numpy(skel_joints)
179
+ offsets = self.get_offsets_joints_batch(skel_joints)
180
+ if len(self._offset.shape) == 2:
181
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
182
+ joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
183
+ joints[..., 0, :] = root_pos
184
+ for chain in self._kinematic_tree:
185
+ if do_root_R:
186
+ matR = cont6d_to_matrix(cont6d_params[:, 0])
187
+ else:
188
+ matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
189
+ for i in range(1, len(chain)):
190
+ matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
191
+ offset_vec = offsets[:, chain[i]].unsqueeze(-1)
192
+ # print(matR.shape, offset_vec.shape)
193
+ joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
194
+ return joints
195
+
196
+
197
+
198
+
199
+
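A toy forward-kinematics sketch (annotation only; the three-joint chain is hypothetical, not a real body model) showing how the `Skeleton` class above is driven: set per-joint offsets, then run `forward_kinematics` with identity rotations.

import torch
from data_loaders.humanml.common.skeleton import Skeleton

offsets = torch.tensor([[0.0, 0.0, 0.0],    # root
                        [0.0, 1.0, 0.0],    # joint 1: unit offset along +y
                        [0.0, 1.0, 0.0]])   # joint 2: unit offset along +y
kinematic_tree = [[0, 1, 2]]                # a single root -> 1 -> 2 chain
skel = Skeleton(offsets, kinematic_tree, device='cpu')
skel.set_offset(offsets)

identity = torch.tensor([1.0, 0.0, 0.0, 0.0])
quat_params = identity.repeat(1, 3, 1)      # (batch=1, joints=3, 4): no rotation anywhere
root_pos = torch.zeros(1, 3)
joints = skel.forward_kinematics(quat_params, root_pos)
# With identity rotations the joints land at (0, 0, 0), (0, 1, 0), (0, 2, 0).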
data_loaders/humanml/data/__init__.py ADDED
File without changes
data_loaders/humanml/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (172 Bytes). View file
 
data_loaders/humanml/data/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (19.1 kB). View file
 
data_loaders/humanml/data/__pycache__/dataset_ours.cpython-38.pyc ADDED
Binary file (73.1 kB). View file
 
data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-38.pyc ADDED
Binary file (87.9 kB). View file
 
data_loaders/humanml/data/__pycache__/utils.cpython-38.pyc ADDED
Binary file (15.3 kB). View file
 
data_loaders/humanml/data/dataset.py ADDED
@@ -0,0 +1,795 @@
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ import os
5
+ from os.path import join as pjoin
6
+ import random
7
+ import codecs as cs
8
+ from tqdm import tqdm
9
+ import spacy
10
+
11
+ from torch.utils.data._utils.collate import default_collate
12
+ from data_loaders.humanml.utils.word_vectorizer import WordVectorizer
13
+ from data_loaders.humanml.utils.get_opt import get_opt
14
+
15
+ # import spacy
16
+
17
+ def collate_fn(batch):
18
+ batch.sort(key=lambda x: x[3], reverse=True)
19
+ return default_collate(batch)
20
+
21
+
22
+ '''For use of training text-2-motion generative model'''
23
+ class Text2MotionDataset(data.Dataset):
24
+ def __init__(self, opt, mean, std, split_file, w_vectorizer):
25
+ self.opt = opt
26
+ self.w_vectorizer = w_vectorizer
27
+ self.max_length = 20
28
+ self.pointer = 0
29
+ min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
30
+
31
+ joints_num = opt.joints_num
32
+
33
+ data_dict = {}
34
+ id_list = []
35
+ with cs.open(split_file, 'r') as f:
36
+ for line in f.readlines():
37
+ id_list.append(line.strip())
38
+
39
+ new_name_list = []
40
+ length_list = []
41
+ for name in tqdm(id_list):
42
+ try:
43
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
44
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
45
+ continue
46
+ text_data = []
47
+ flag = False
48
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
49
+ for line in f.readlines():
50
+ text_dict = {}
51
+ line_split = line.strip().split('#')
52
+ caption = line_split[0]
53
+ tokens = line_split[1].split(' ')
54
+ f_tag = float(line_split[2])
55
+ to_tag = float(line_split[3])
56
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
57
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
58
+
59
+ text_dict['caption'] = caption
60
+ text_dict['tokens'] = tokens
61
+ if f_tag == 0.0 and to_tag == 0.0:
62
+ flag = True
63
+ text_data.append(text_dict)
64
+ else:
65
+ try:
66
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
67
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
68
+ continue
69
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
70
+ while new_name in data_dict:
71
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
72
+ data_dict[new_name] = {'motion': n_motion,
73
+ 'length': len(n_motion),
74
+ 'text':[text_dict]}
75
+ new_name_list.append(new_name)
76
+ length_list.append(len(n_motion))
77
+ except:
78
+ print(line_split)
79
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
80
+ # break
81
+
82
+ if flag:
83
+ data_dict[name] = {'motion': motion,
84
+ 'length': len(motion),
85
+ 'text':text_data}
86
+ new_name_list.append(name)
87
+ length_list.append(len(motion))
88
+ except:
89
+ # Some motion may not exist in KIT dataset
90
+ pass
91
+
92
+
93
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
94
+
95
+ if opt.is_train:
96
+ # root_rot_velocity (B, seq_len, 1)
97
+ std[0:1] = std[0:1] / opt.feat_bias
98
+ # root_linear_velocity (B, seq_len, 2)
99
+ std[1:3] = std[1:3] / opt.feat_bias
100
+ # root_y (B, seq_len, 1)
101
+ std[3:4] = std[3:4] / opt.feat_bias
102
+ # ric_data (B, seq_len, (joint_num - 1)*3)
103
+ std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
104
+ # rot_data (B, seq_len, (joint_num - 1)*6)
105
+ std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
106
+ joints_num - 1) * 9] / 1.0
107
+ # local_velocity (B, seq_len, joint_num*3)
108
+ std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
109
+ 4 + (joints_num - 1) * 9: 4 + (
110
+ joints_num - 1) * 9 + joints_num * 3] / 1.0
111
+ # foot contact (B, seq_len, 4)
112
+ std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
113
+ 4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias
114
+
115
+ assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
116
+ np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
117
+ np.save(pjoin(opt.meta_dir, 'std.npy'), std)
118
+
119
+ self.mean = mean
120
+ self.std = std
121
+ self.length_arr = np.array(length_list)
122
+ self.data_dict = data_dict
123
+ self.name_list = name_list
124
+ self.reset_max_len(self.max_length)
125
+
126
+ def reset_max_len(self, length):
127
+ assert length <= self.opt.max_motion_length
128
+ self.pointer = np.searchsorted(self.length_arr, length)
129
+ print("Pointer Pointing at %d"%self.pointer)
130
+ self.max_length = length
131
+
132
+ def inv_transform(self, data):
133
+ return data * self.std + self.mean
134
+
135
+ def __len__(self):
136
+ return len(self.data_dict) - self.pointer
137
+
138
+ def __getitem__(self, item):
139
+ idx = self.pointer + item
140
+ data = self.data_dict[self.name_list[idx]]
141
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
142
+ # Randomly select a caption
143
+ text_data = random.choice(text_list)
144
+ caption, tokens = text_data['caption'], text_data['tokens']
145
+
146
+ if len(tokens) < self.opt.max_text_len:
147
+ # pad with "unk"
148
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
149
+ sent_len = len(tokens)
150
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
151
+ else:
152
+ # crop
153
+ tokens = tokens[:self.opt.max_text_len]
154
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
155
+ sent_len = len(tokens)
156
+ pos_one_hots = []
157
+ word_embeddings = []
158
+ for token in tokens:
159
+ word_emb, pos_oh = self.w_vectorizer[token]
160
+ pos_one_hots.append(pos_oh[None, :])
161
+ word_embeddings.append(word_emb[None, :])
162
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
163
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
164
+
165
+ len_gap = (m_length - self.max_length) // self.opt.unit_length
166
+
167
+ if self.opt.is_train:
168
+ if m_length != self.max_length:
169
+ # print("Motion original length:%d_%d"%(m_length, len(motion)))
170
+ if self.opt.unit_length < 10:
171
+ coin2 = np.random.choice(['single', 'single', 'double'])
172
+ else:
173
+ coin2 = 'single'
174
+ if len_gap == 0 or (len_gap == 1 and coin2 == 'double'):
175
+ m_length = self.max_length
176
+ idx = random.randint(0, m_length - self.max_length)
177
+ motion = motion[idx:idx+self.max_length]
178
+ else:
179
+ if coin2 == 'single':
180
+ n_m_length = self.max_length + self.opt.unit_length * len_gap
181
+ else:
182
+ n_m_length = self.max_length + self.opt.unit_length * (len_gap - 1)
183
+ idx = random.randint(0, m_length - n_m_length)
184
+ motion = motion[idx:idx + self.max_length]
185
+ m_length = n_m_length
186
+ # print(len_gap, idx, coin2)
187
+ else:
188
+ if self.opt.unit_length < 10:
189
+ coin2 = np.random.choice(['single', 'single', 'double'])
190
+ else:
191
+ coin2 = 'single'
192
+
193
+ if coin2 == 'double':
194
+ m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
195
+ elif coin2 == 'single':
196
+ m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
197
+ idx = random.randint(0, len(motion) - m_length)
198
+ motion = motion[idx:idx+m_length]
199
+
200
+ "Z Normalization"
201
+ motion = (motion - self.mean) / self.std
202
+
203
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length
204
+
205
+
206
+ '''For use of training text motion matching model, and evaluations'''
207
+ ## text2motions dataset v2 ##
208
+ class Text2MotionDatasetV2(data.Dataset): # text2motion dataset
209
+ def __init__(self, opt, mean, std, split_file, w_vectorizer):
210
+ self.opt = opt
211
+ self.w_vectorizer = w_vectorizer
212
+ self.max_length = 20
213
+ self.pointer = 0
214
+ self.max_motion_length = opt.max_motion_length
215
+ min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
216
+
217
+ data_dict = {}
218
+ id_list = []
219
+ with cs.open(split_file, 'r') as f:
220
+ for line in f.readlines():
221
+ id_list.append(line.strip()) ## id list ##
222
+ # id_list = id_list[:200]
223
+
224
+ new_name_list = []
225
+ length_list = []
226
+ for name in tqdm(id_list):
227
+ try:
228
+ ## motion_dir ##
229
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
230
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
231
+ continue
232
+ text_data = []
233
+ flag = False
234
+ # read the text annotations for this motion
235
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
236
+ for line in f.readlines():
237
+ text_dict = {}
238
+ line_split = line.strip().split('#')
239
+ caption = line_split[0]
240
+ tokens = line_split[1].split(' ')
241
+ f_tag = float(line_split[2])
242
+ to_tag = float(line_split[3])
243
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
244
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
245
+
246
+ text_dict['caption'] = caption ## caption, motion, ##
247
+ text_dict['tokens'] = tokens
248
+ if f_tag == 0.0 and to_tag == 0.0:
249
+ flag = True
250
+ text_data.append(text_dict)
251
+ else:
252
+ try:
253
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
254
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
255
+ continue
256
+ # new name for indexing #
257
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
258
+ while new_name in data_dict:
259
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
260
+ data_dict[new_name] = {'motion': n_motion,
261
+ 'length': len(n_motion), ## length of motion ##
262
+ 'text':[text_dict]}
263
+ new_name_list.append(new_name)
264
+ length_list.append(len(n_motion))
265
+ except:
266
+ print(line_split)
267
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
268
+ # break
269
+
270
+ if flag:
271
+ ## motion, lenght, text ##
272
+ data_dict[name] = {'motion': motion, ## motion, length of the motion, text data
273
+ 'length': len(motion), ## motion, lenght, text ##
274
+ 'text': text_data}
275
+ new_name_list.append(name)
276
+ length_list.append(len(motion))
277
+ except:
278
+ pass
279
+
280
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
281
+
282
+ self.mean = mean
283
+ self.std = std
284
+ self.length_arr = np.array(length_list)
285
+ self.data_dict = data_dict
286
+ self.name_list = name_list
287
+ self.reset_max_len(self.max_length)
288
+
289
+ def reset_max_len(self, length):
290
+ assert length <= self.max_motion_length
291
+ self.pointer = np.searchsorted(self.length_arr, length)
292
+ print("Pointer Pointing at %d"%self.pointer)
293
+ self.max_length = length
294
+
295
+ def inv_transform(self, data):
296
+ return data * self.std + self.mean
297
+
298
+ def __len__(self):
299
+ return len(self.data_dict) - self.pointer
300
+
301
+ def __getitem__(self, item):
302
+ idx = self.pointer + item
303
+ data = self.data_dict[self.name_list[idx]] # data
304
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
305
+ # Randomly select a caption
306
+ text_data = random.choice(text_list)
307
+ caption, tokens = text_data['caption'], text_data['tokens']
308
+
309
+ if len(tokens) < self.opt.max_text_len:
310
+ # pad with "unk"
311
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
312
+ sent_len = len(tokens)
313
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
314
+ else:
315
+ # crop
316
+ tokens = tokens[:self.opt.max_text_len]
317
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
318
+ sent_len = len(tokens)
319
+ pos_one_hots = [] ## pose one hots ##
320
+ word_embeddings = []
321
+ for token in tokens:
322
+ word_emb, pos_oh = self.w_vectorizer[token]
323
+ pos_one_hots.append(pos_oh[None, :])
324
+ word_embeddings.append(word_emb[None, :])
325
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
326
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
327
+
328
+ # Crop the motions into multiples of opt.unit_length and introduce small variations
329
+ if self.opt.unit_length < 10:
330
+ coin2 = np.random.choice(['single', 'single', 'double'])
331
+ else:
332
+ coin2 = 'single'
333
+
334
+ if coin2 == 'double':
335
+ m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
336
+ elif coin2 == 'single':
337
+ m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
338
+ idx = random.randint(0, len(motion) - m_length)
339
+ motion = motion[idx:idx+m_length]
340
+
341
+ "Z Normalization"
342
+ motion = (motion - self.mean) / self.std
343
+
344
+ if m_length < self.max_motion_length:
345
+ motion = np.concatenate([motion,  # zero-pad up to max_motion_length
346
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
347
+ ], axis=0)
348
+ # print(word_embeddings.shape, motion.shape)
349
+ # print(tokens)
350
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
351
+
352
+
353
+
354
+ '''For use of training baseline'''
355
+ class Text2MotionDatasetBaseline(data.Dataset):
356
+ def __init__(self, opt, mean, std, split_file, w_vectorizer):
357
+ self.opt = opt
358
+ self.w_vectorizer = w_vectorizer
359
+ self.max_length = 20
360
+ self.pointer = 0
361
+ self.max_motion_length = opt.max_motion_length
362
+ min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
363
+
364
+ data_dict = {}
365
+ id_list = []
366
+ with cs.open(split_file, 'r') as f:
367
+ for line in f.readlines():
368
+ id_list.append(line.strip())
369
+ # id_list = id_list[:200]
370
+
371
+ new_name_list = []
372
+ length_list = []
373
+ for name in tqdm(id_list):
374
+ try:
375
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
376
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
377
+ continue
378
+ text_data = []
379
+ flag = False
380
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
381
+ for line in f.readlines():
382
+ text_dict = {}
383
+ line_split = line.strip().split('#')
384
+ caption = line_split[0]
385
+ tokens = line_split[1].split(' ')
386
+ f_tag = float(line_split[2])
387
+ to_tag = float(line_split[3])
388
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
389
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
390
+
391
+ text_dict['caption'] = caption
392
+ text_dict['tokens'] = tokens
393
+ if f_tag == 0.0 and to_tag == 0.0:
394
+ flag = True
395
+ text_data.append(text_dict)
396
+ else:
397
+ try:
398
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
399
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
400
+ continue
401
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
402
+ while new_name in data_dict:
403
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
404
+ data_dict[new_name] = {'motion': n_motion,
405
+ 'length': len(n_motion),
406
+ 'text':[text_dict]}
407
+ new_name_list.append(new_name)
408
+ length_list.append(len(n_motion))
409
+ except:
410
+ print(line_split)
411
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
412
+ # break
413
+
414
+ if flag:
415
+ data_dict[name] = {'motion': motion,
416
+ 'length': len(motion),
417
+ 'text': text_data}
418
+ new_name_list.append(name)
419
+ length_list.append(len(motion))
420
+ except:
421
+ pass
422
+
423
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
424
+
425
+ self.mean = mean
426
+ self.std = std
427
+ self.length_arr = np.array(length_list)
428
+ self.data_dict = data_dict
429
+ self.name_list = name_list
430
+ self.reset_max_len(self.max_length)
431
+
432
+ def reset_max_len(self, length):
433
+ assert length <= self.max_motion_length
434
+ self.pointer = np.searchsorted(self.length_arr, length)
435
+ print("Pointer Pointing at %d"%self.pointer)
436
+ self.max_length = length
437
+
438
+ def inv_transform(self, data):
439
+ return data * self.std + self.mean
440
+
441
+ def __len__(self):
442
+ return len(self.data_dict) - self.pointer
443
+
444
+ def __getitem__(self, item):
445
+ idx = self.pointer + item
446
+ data = self.data_dict[self.name_list[idx]]
447
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
448
+ # Randomly select a caption
449
+ text_data = random.choice(text_list)
450
+ caption, tokens = text_data['caption'], text_data['tokens']
451
+
452
+ if len(tokens) < self.opt.max_text_len:
453
+ # pad with "unk"
454
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
455
+ sent_len = len(tokens)
456
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
457
+ else:
458
+ # crop
459
+ tokens = tokens[:self.opt.max_text_len]
460
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
461
+ sent_len = len(tokens)
462
+ pos_one_hots = []
463
+ word_embeddings = []
464
+ for token in tokens:
465
+ word_emb, pos_oh = self.w_vectorizer[token]
466
+ pos_one_hots.append(pos_oh[None, :])
467
+ word_embeddings.append(word_emb[None, :])
468
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
469
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
470
+
471
+ len_gap = (m_length - self.max_length) // self.opt.unit_length
472
+
473
+ if m_length != self.max_length:
474
+ # print("Motion original length:%d_%d"%(m_length, len(motion)))
475
+ if self.opt.unit_length < 10:
476
+ coin2 = np.random.choice(['single', 'single', 'double'])
477
+ else:
478
+ coin2 = 'single'
479
+ if len_gap == 0 or (len_gap == 1 and coin2 == 'double'):
480
+ m_length = self.max_length
481
+ s_idx = random.randint(0, m_length - self.max_length)
482
+ else:
483
+ if coin2 == 'single':
484
+ n_m_length = self.max_length + self.opt.unit_length * len_gap
485
+ else:
486
+ n_m_length = self.max_length + self.opt.unit_length * (len_gap - 1)
487
+ s_idx = random.randint(0, m_length - n_m_length)
488
+ m_length = n_m_length
489
+ else:
490
+ s_idx = 0
491
+
492
+ src_motion = motion[s_idx: s_idx + m_length]
493
+ tgt_motion = motion[s_idx: s_idx + self.max_length]
494
+
495
+ "Z Normalization"
496
+ src_motion = (src_motion - self.mean) / self.std
497
+ tgt_motion = (tgt_motion - self.mean) / self.std
498
+
499
+ if m_length < self.max_motion_length:
500
+ src_motion = np.concatenate([src_motion,
501
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
502
+ ], axis=0)
503
+ # print(m_length, src_motion.shape, tgt_motion.shape)
504
+ # print(word_embeddings.shape, motion.shape)
505
+ # print(tokens)
506
+ return word_embeddings, caption, sent_len, src_motion, tgt_motion, m_length
507
+
508
+
509
+ class MotionDatasetV2(data.Dataset):
510
+ def __init__(self, opt, mean, std, split_file):
511
+ self.opt = opt
512
+ joints_num = opt.joints_num
513
+
514
+ self.data = []
515
+ self.lengths = []
516
+ id_list = []
517
+ with cs.open(split_file, 'r') as f:
518
+ for line in f.readlines():
519
+ id_list.append(line.strip())
520
+
521
+ for name in tqdm(id_list):
522
+ try:
523
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
524
+ if motion.shape[0] < opt.window_size:
525
+ continue
526
+ self.lengths.append(motion.shape[0] - opt.window_size)
527
+ self.data.append(motion)
528
+ except:
529
+ # Some motion may not exist in KIT dataset
530
+ pass
531
+
532
+ self.cumsum = np.cumsum([0] + self.lengths)
533
+
534
+ if opt.is_train:
535
+ # root_rot_velocity (B, seq_len, 1)
536
+ std[0:1] = std[0:1] / opt.feat_bias
537
+ # root_linear_velocity (B, seq_len, 2)
538
+ std[1:3] = std[1:3] / opt.feat_bias
539
+ # root_y (B, seq_len, 1)
540
+ std[3:4] = std[3:4] / opt.feat_bias
541
+ # ric_data (B, seq_len, (joint_num - 1)*3)
542
+ std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
543
+ # rot_data (B, seq_len, (joint_num - 1)*6)
544
+ std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
545
+ joints_num - 1) * 9] / 1.0
546
+ # local_velocity (B, seq_len, joint_num*3)
547
+ std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
548
+ 4 + (joints_num - 1) * 9: 4 + (
549
+ joints_num - 1) * 9 + joints_num * 3] / 1.0
550
+ # foot contact (B, seq_len, 4)
551
+ std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
552
+ 4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias
553
+
554
+ assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
555
+ np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
556
+ np.save(pjoin(opt.meta_dir, 'std.npy'), std)
557
+
558
+ self.mean = mean
559
+ self.std = std
560
+ print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
561
+
562
+ def inv_transform(self, data):
563
+ return data * self.std + self.mean
564
+
565
+ def __len__(self):
566
+ return self.cumsum[-1]
567
+
568
+ def __getitem__(self, item):
569
+ if item != 0:
570
+ motion_id = np.searchsorted(self.cumsum, item) - 1
571
+ idx = item - self.cumsum[motion_id] - 1
572
+ else:
573
+ motion_id = 0
574
+ idx = 0
575
+ # idx + j
576
+ motion = self.data[motion_id][idx:idx+self.opt.window_size]
577
+ "Z Normalization"
578
+ motion = (motion - self.mean) / self.std
579
+
580
+ return motion
581
+
582
+
583
+ class RawTextDataset(data.Dataset):
584
+ def __init__(self, opt, mean, std, text_file, w_vectorizer):
585
+ self.mean = mean
586
+ self.std = std
587
+ self.opt = opt
588
+ self.data_dict = []
589
+ self.nlp = spacy.load('en_core_web_sm')
590
+
591
+ with cs.open(text_file) as f:
592
+ for line in f.readlines():
593
+ word_list, pos_list = self.process_text(line.strip())
594
+ tokens = ['%s/%s'%(word_list[i], pos_list[i]) for i in range(len(word_list))]
595
+ self.data_dict.append({'caption':line.strip(), "tokens":tokens})
596
+
597
+ self.w_vectorizer = w_vectorizer
598
+ print("Total number of descriptions {}".format(len(self.data_dict)))
599
+
600
+
601
+ def process_text(self, sentence):
602
+ sentence = sentence.replace('-', '')
603
+ doc = self.nlp(sentence)
604
+ word_list = []
605
+ pos_list = []
606
+ for token in doc:
607
+ word = token.text
608
+ if not word.isalpha():
609
+ continue
610
+ if (token.pos_ == 'NOUN' or token.pos_ == 'VERB') and (word != 'left'):
611
+ word_list.append(token.lemma_)
612
+ else:
613
+ word_list.append(word)
614
+ pos_list.append(token.pos_)
615
+ return word_list, pos_list
616
+
617
+ def inv_transform(self, data):
618
+ return data * self.std + self.mean
619
+
620
+ def __len__(self):
621
+ return len(self.data_dict)
622
+
623
+ def __getitem__(self, item):
624
+ data = self.data_dict[item]
625
+ caption, tokens = data['caption'], data['tokens']
626
+
627
+ if len(tokens) < self.opt.max_text_len:
628
+ # pad with "unk"
629
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
630
+ sent_len = len(tokens)
631
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
632
+ else:
633
+ # crop
634
+ tokens = tokens[:self.opt.max_text_len]
635
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
636
+ sent_len = len(tokens)
637
+ pos_one_hots = []
638
+ word_embeddings = []
639
+ for token in tokens:
640
+ word_emb, pos_oh = self.w_vectorizer[token]
641
+ pos_one_hots.append(pos_oh[None, :])
642
+ word_embeddings.append(word_emb[None, :])
643
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
644
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
645
+
646
+ return word_embeddings, pos_one_hots, caption, sent_len
647
+
648
+ class TextOnlyDataset(data.Dataset):
649
+ def __init__(self, opt, mean, std, split_file):
650
+ self.mean = mean
651
+ self.std = std
652
+ self.opt = opt
653
+ self.data_dict = []
654
+ self.max_length = 20
655
+ self.pointer = 0
656
+ self.fixed_length = 120
657
+
658
+
659
+ data_dict = {}
660
+ id_list = []
661
+ with cs.open(split_file, 'r') as f:
662
+ for line in f.readlines():
663
+ id_list.append(line.strip())
664
+ # id_list = id_list[:200]
665
+
666
+ new_name_list = []
667
+ length_list = []
668
+ for name in tqdm(id_list):
669
+ try:
670
+ text_data = []
671
+ flag = False
672
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
673
+ for line in f.readlines():
674
+ text_dict = {}
675
+ line_split = line.strip().split('#')
676
+ caption = line_split[0]
677
+ tokens = line_split[1].split(' ')
678
+ f_tag = float(line_split[2])
679
+ to_tag = float(line_split[3])
680
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
681
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
682
+
683
+ text_dict['caption'] = caption
684
+ text_dict['tokens'] = tokens
685
+ if f_tag == 0.0 and to_tag == 0.0:
686
+ flag = True
687
+ text_data.append(text_dict)
688
+ else:
689
+ try:
690
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
691
+ while new_name in data_dict:
692
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
693
+ data_dict[new_name] = {'text':[text_dict]}
694
+ new_name_list.append(new_name)
695
+ except:
696
+ print(line_split)
697
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
698
+ # break
699
+
700
+ if flag:
701
+ data_dict[name] = {'text': text_data}
702
+ new_name_list.append(name)
703
+ except:
704
+ pass
705
+
706
+ self.length_arr = np.array(length_list)
707
+ self.data_dict = data_dict
708
+ self.name_list = new_name_list
709
+
710
+ def inv_transform(self, data):
711
+ return data * self.std + self.mean
712
+
713
+ def __len__(self):
714
+ return len(self.data_dict)
715
+
716
+ def __getitem__(self, item):
717
+ idx = self.pointer + item
718
+ data = self.data_dict[self.name_list[idx]]
719
+ text_list = data['text']
720
+
721
+ # Randomly select a caption
722
+ text_data = random.choice(text_list)
723
+ caption, tokens = text_data['caption'], text_data['tokens']
724
+ return None, None, caption, None, np.array([0]), self.fixed_length, None
725
+ # fixed_length can be set from outside before sampling
726
+
727
+ ## t2m original dataset
728
+ # A wrapper class for t2m original dataset for MDM purposes
729
+ # humanml 3D
730
+ class HumanML3D(data.Dataset):  # wrapper around the HumanML3D (T2M) text-to-motion dataset
731
+ def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", load_vectorizer=False, **kwargs):
732
+ self.mode = mode
733
+
734
+ self.dataset_name = 't2m'
735
+ self.dataname = 't2m'
736
+
737
+
738
+ # Configurations of the T2M and KIT datasets are almost the same
739
+ abs_base_path = f'.'
740
+ dataset_opt_path = pjoin(abs_base_path, datapath)
741
+ device = None # torch.device('cuda:4') # This param is not in use in this context
742
+ opt = get_opt(dataset_opt_path, device)
743
+ opt.meta_dir = pjoin(abs_base_path, opt.meta_dir)
744
+ opt.motion_dir = pjoin(abs_base_path, opt.motion_dir)
745
+ opt.text_dir = pjoin(abs_base_path, opt.text_dir)
746
+ opt.model_dir = pjoin(abs_base_path, opt.model_dir)
747
+ opt.checkpoints_dir = pjoin(abs_base_path, opt.checkpoints_dir)
748
+ opt.data_root = pjoin(abs_base_path, opt.data_root)
749
+ opt.save_root = pjoin(abs_base_path, opt.save_root)
750
+ opt.meta_dir = './dataset'
751
+ self.opt = opt
752
+ print('Loading dataset %s ...' % opt.dataset_name)
753
+
754
+ if mode == 'gt':
755
+ # used by T2M models (including evaluators)
756
+ self.mean = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_mean.npy'))
757
+ self.std = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_std.npy'))
758
+ elif mode in ['train', 'eval', 'text_only']:
759
+ # used by our models
760
+ self.mean = np.load(pjoin(opt.data_root, 'Mean.npy'))
761
+ self.std = np.load(pjoin(opt.data_root, 'Std.npy'))
762
+
763
+ if mode == 'eval':
764
+ # used by T2M models (including evaluators)
765
+ # this is to translate their norms to ours
766
+ self.mean_for_eval = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_mean.npy'))
767
+ self.std_for_eval = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_std.npy'))
768
+ print(f"dataset_name: {opt.dataset_name}")
769
+ if load_vectorizer:
770
+ self.split_file = pjoin(opt.data_root, f'train.txt')
771
+ else:
772
+ self.split_file = pjoin(opt.data_root, f'{split}.txt')
773
+ if mode == 'text_only' and (not load_vectorizer):
774
+ self.t2m_dataset = TextOnlyDataset(self.opt, self.mean, self.std, self.split_file)
775
+ else:
776
+ self.w_vectorizer = WordVectorizer(pjoin(abs_base_path, 'glove'), 'our_vab')
777
+ # text-to-motion dataset using the GloVe word vectorizer loaded above
778
+ self.t2m_dataset = Text2MotionDatasetV2(self.opt, self.mean, self.std, self.split_file, self.w_vectorizer)
779
+ self.num_actions = 1 # dummy placeholder
780
+
781
+ # assert len(self.t2m_dataset) > 1, 'You loaded an empty dataset, ' \
782
+ # 'it is probably because your data dir has only texts and no motions.\n' \
783
+ # 'To train and evaluate MDM you should get the FULL data as described ' \
784
+ # 'in the README file.'
785
+
786
+ def __getitem__(self, item):
787
+ return self.t2m_dataset.__getitem__(item)
788
+
789
+ def __len__(self):
790
+ return self.t2m_dataset.__len__()
791
+
792
+ # A wrapper class for t2m original dataset for MDM purposes
793
+ class KIT(HumanML3D):
794
+ def __init__(self, mode, datapath='./dataset/kit_opt.txt', split="train", **kwargs):
795
+ super(KIT, self).__init__(mode, datapath, split, **kwargs)
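A minimal usage sketch for the wrapper above, assuming the default ./dataset/humanml_opt.txt option file and the HumanML3D files it points to (Mean.npy, Std.npy, the split lists and the text directory) are already in place:

from data_loaders.humanml.data.dataset import HumanML3D

# 'text_only' mode builds a TextOnlyDataset, so no motion files are touched
dataset = HumanML3D(mode='text_only', split='train')
print(len(dataset))
_, _, caption, _, _, length, _ = dataset[0]   # TextOnlyDataset returns 7-tuples
print(caption, length)                        # a caption and the fixed motion length (120)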
data_loaders/humanml/data/dataset_ours.py ADDED
The diff for this file is too large to render. See raw diff
 
data_loaders/humanml/data/dataset_ours_single_seq.py ADDED
The diff for this file is too large to render. See raw diff
 
data_loaders/humanml/data/utils.py ADDED
@@ -0,0 +1,507 @@
1
+ import numpy as np
2
+ import torch
3
+ import time
4
+ from scipy.spatial.transform import Rotation as R
5
+
6
+ try:
7
+ from torch_cluster import fps
8
+ except ImportError:
9
+ pass  # optional dependency; farthest_point_sampling below requires torch_cluster
10
+ from collections import OrderedDict
11
+ import os, argparse, copy, json
+ import logging
+ logger = logging.getLogger(__name__)  # used by load_network below
12
+ import math
13
+
14
+ def sample_pcd_from_mesh(vertices, triangles, npoints=512):
15
+ arears = []
16
+ for i in range(triangles.shape[0]):
17
+ v_a, v_b, v_c = int(triangles[i, 0].item()), int(triangles[i, 1].item()), int(triangles[i, 2].item())
18
+ v_a, v_b, v_c = vertices[v_a], vertices[v_b], vertices[v_c]
19
+ ab, ac = v_b - v_a, v_c - v_a
20
+ cos_ab_ac = (np.sum(ab * ac) / np.clip(np.sqrt(np.sum(ab ** 2)) * np.sqrt(np.sum(ac ** 2)), a_min=1e-9, a_max=9999999.0)).item()
21
+ sin_ab_ac = math.sqrt(1. - cos_ab_ac ** 2)
22
+ cur_area = 0.5 * sin_ab_ac * np.sqrt(np.sum(ab ** 2)).item() * np.sqrt(np.sum(ac ** 2)).item()
23
+ arears.append(cur_area)
24
+ tot_area = sum(arears)
25
+
26
+ sampled_pcts = []
27
+ tot_indices = []
28
+ tot_factors = []
29
+ for i in range(triangles.shape[0]):
30
+
31
+ v_a, v_b, v_c = int(triangles[i, 0].item()), int(triangles[i, 1].item()), int(
32
+ triangles[i, 2].item())
33
+ v_a, v_b, v_c = vertices[v_a], vertices[v_b], vertices[v_c]
34
+ # ab, ac = v_b - v_a, v_c - v_a
35
+ # cur_sampled_pts = int(npoints * (arears[i] / tot_area))
36
+ cur_sampled_pts = math.ceil(npoints * (arears[i] / tot_area))
37
+ # if cur_sampled_pts == 0:
38
+
39
+ cur_sampled_pts = int(arears[i] * npoints)
40
+ cur_sampled_pts = 1 if cur_sampled_pts == 0 else cur_sampled_pts
41
+
42
+ tmp_x, tmp_y = np.random.uniform(0, 1., (cur_sampled_pts,)).tolist(), np.random.uniform(0., 1., (cur_sampled_pts,)).tolist()
43
+
44
+ for xx, yy in zip(tmp_x, tmp_y):
45
+ sqrt_xx, sqrt_yy = math.sqrt(xx), math.sqrt(yy)
46
+ aa = 1. - sqrt_xx
47
+ bb = sqrt_xx * (1. - yy)
48
+ cc = yy * sqrt_xx
49
+ cur_pos = v_a * aa + v_b * bb + v_c * cc
50
+ sampled_pcts.append(cur_pos)
51
+
52
+ tot_indices.append(triangles[i]) # vertex indices of the source triangle
53
+ tot_factors.append([aa, bb, cc])
54
+
55
+ tot_indices = np.array(tot_indices, dtype=np.int64)
56
+ tot_factors = np.array(tot_factors, dtype=np.float32)
57
+
58
+ sampled_pcts = np.array(sampled_pcts)
59
+ print("sampled points from surface:", sampled_pcts.shape)
60
+ # sampled_pcts = np.concatenate([sampled_pcts, vertices], axis=0)
61
+ return sampled_pcts, tot_indices, tot_factors
62
+
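A quick sketch exercising sample_pcd_from_mesh on a hypothetical two-triangle unit square (the mesh data is made up for illustration):

import numpy as np
from data_loaders.humanml.data.utils import sample_pcd_from_mesh

verts = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]])
tris = np.array([[0, 1, 2], [0, 2, 3]])
pts, tri_idx, bary = sample_pcd_from_mesh(verts, tris, npoints=256)
print(pts.shape, tri_idx.shape, bary.shape)   # sampled xyz points, triangle vertex indices, barycentric factors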
63
+
64
+ def read_obj_file_ours(obj_fn, sub_one=False):
65
+ vertices = []
66
+ faces = []
67
+ with open(obj_fn, "r") as rf:
68
+ for line in rf:
69
+ items = line.strip().split(" ")
70
+ if items[0] == 'v':
71
+ cur_verts = items[1:]
72
+ cur_verts = [float(vv) for vv in cur_verts]
73
+ vertices.append(cur_verts)
74
+ elif items[0] == 'f':
75
+ cur_faces = items[1:] # faces
76
+ cur_face_idxes = []
77
+ for cur_f in cur_faces:
78
+ try:
79
+ cur_f_idx = int(cur_f.split("/")[0])
80
+ except:
81
+ cur_f_idx = int(cur_f.split("//")[0])
82
+ cur_face_idxes.append(cur_f_idx if not sub_one else cur_f_idx - 1)
83
+ faces.append(cur_face_idxes)
84
+ rf.close()
85
+ vertices = np.array(vertices, dtype=np.float64)
86
+ return vertices, faces
87
+
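A small sketch of read_obj_file_ours on a throwaway OBJ file (the OBJ content here is a made-up example):

import os, tempfile
from data_loaders.humanml.data.utils import read_obj_file_ours

obj_text = "v 0 0 0\nv 1 0 0\nv 0 1 0\nf 1 2 3\n"
with tempfile.NamedTemporaryFile('w', suffix='.obj', delete=False) as tmp:
    tmp.write(obj_text)
verts, faces = read_obj_file_ours(tmp.name, sub_one=True)   # sub_one converts 1-based OBJ indices to 0-based
os.remove(tmp.name)
print(verts.shape, faces)   # (3, 3) [[0, 1, 2]]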
88
+ def clamp_gradient(model, clip):
89
+ for p in model.parameters():
90
+ torch.nn.utils.clip_grad_value_(p, clip)
91
+
92
+ def clamp_gradient_norm(model, max_norm, norm_type=2):
93
+ for p in model.parameters():
94
+ torch.nn.utils.clip_grad_norm_(p, max_norm, norm_type=norm_type)
95
+
96
+
97
+ def save_network(net, directory, network_label, epoch_label=None, **kwargs):
98
+ """
99
+ save model to directory with name {network_label}_{epoch_label}.pth
100
+ Args:
101
+ net: pytorch model
102
+ directory: output directory
103
+ network_label: str
104
+ epoch_label: convertible to str
105
+ kwargs: additional values to be stored alongside the state dict
106
+ """
107
+ save_filename = "_".join((network_label, str(epoch_label))) + ".pth"
108
+ save_path = os.path.join(directory, save_filename)
109
+ merge_states = OrderedDict()
110
+ merge_states["states"] = net.cpu().state_dict()
111
+ for k in kwargs:
112
+ merge_states[k] = kwargs[k]
113
+ torch.save(merge_states, save_path)
114
+ net = net.cuda()
115
+
116
+
117
+ def load_network(net, path):
118
+ """
119
+ Load network parameters whose names exist in the pth file (strict=False).
120
+ return:
121
+ None; missing and unexpected keys are logged as warnings.
122
+ """
123
+ # warnings.DeprecationWarning("load_network is deprecated. Use module.load_state_dict(strict=False) instead.")
124
+ if isinstance(path, str):
125
+ logger.info("loading network from {}".format(path))
126
+ if path[-3:] == "pth":
127
+ loaded_state = torch.load(path)
128
+ if "states" in loaded_state:
129
+ loaded_state = loaded_state["states"]
130
+ else:
131
+ loaded_state = np.load(path).item()
132
+ if "states" in loaded_state:
133
+ loaded_state = loaded_state["states"]
134
+ elif isinstance(path, dict):
135
+ loaded_state = path
136
+
137
+ network = net.module if isinstance(
138
+ net, torch.nn.DataParallel) else net
139
+
140
+ missingkeys, unexpectedkeys = network.load_state_dict(loaded_state, strict=False)
141
+ if len(missingkeys) > 0:
142
+ logger.warning("load_network: {} missing keys\n{}".format(len(missingkeys), "\n".join(missingkeys)))
143
+ if len(unexpectedkeys) > 0:
144
+ logger.warning("load_network: {} unexpected keys\n{}".format(len(unexpectedkeys), "\n".join(unexpectedkeys)))
145
+
146
+
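A sketch of the save/load round trip, assuming a CUDA device is available (save_network moves the model back to the GPU after writing) and a throwaway temporary directory:

import os, tempfile
import torch
from data_loaders.humanml.data.utils import save_network, load_network

net = torch.nn.Linear(4, 2)
ckpt_dir = tempfile.mkdtemp()
save_network(net, ckpt_dir, "toy", epoch_label=1, step=100)   # extra kwargs are stored next to the weights
load_network(net, os.path.join(ckpt_dir, "toy_1.pth"))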
147
+
148
+ def weights_init(m):
149
+ """
150
+ initialize the weights of the network for convolutional/linear layers and batchnorm layers
151
+ """
152
+ if isinstance(m, (torch.nn.modules.conv._ConvNd, torch.nn.Linear)):
153
+ torch.nn.init.xavier_uniform_(m.weight)
154
+ if m.bias is not None:
155
+ torch.nn.init.constant_(m.bias, 0.0)
156
+ elif isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
157
+ torch.nn.init.constant_(m.bias, 0.0)
158
+ torch.nn.init.constant_(m.weight, 1.0)
159
+
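For reference, a typical way to apply the initializer above to a small, arbitrary network:

import torch
from data_loaders.humanml.data.utils import weights_init

net = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.BatchNorm1d(16), torch.nn.Linear(16, 4))
net.apply(weights_init)   # Xavier-uniform weights for Linear/Conv layers, (0, 1) bias/weight for batchnorm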
160
+ def seal(mesh_to_seal):
161
+ circle_v_id = np.array([108, 79, 78, 121, 214, 215, 279, 239, 234, 92, 38, 122, 118, 117, 119, 120], dtype = np.int32)
162
+ center = (mesh_to_seal.v[circle_v_id, :]).mean(0)
163
+
164
+ sealed_mesh = copy.copy(mesh_to_seal)
165
+ sealed_mesh.v = np.vstack([mesh_to_seal.v, center])
166
+ center_v_id = sealed_mesh.v.shape[0] - 1
167
+
168
+ for i in range(circle_v_id.shape[0]):
169
+ new_faces = [circle_v_id[i-1], circle_v_id[i], center_v_id]
170
+ sealed_mesh.f = np.vstack([sealed_mesh.f, new_faces])
171
+ return sealed_mesh
172
+
173
+ def read_pos_fr_txt(txt_fn):
174
+ pos_data = []
175
+ with open(txt_fn, "r") as rf:
176
+ for line in rf:
177
+ cur_pos = line.strip().split(" ")
178
+ cur_pos = [float(p) for p in cur_pos]
179
+ pos_data.append(cur_pos)
180
+ rf.close()
181
+ pos_data = np.array(pos_data, dtype=np.float32)
182
+ print(f"pos_data: {pos_data.shape}")
183
+ return pos_data
184
+
185
+ def read_field_data_fr_txt(field_fn):
186
+ field_data = []
187
+ with open(field_fn, "r") as rf:
188
+ for line in rf:
189
+ cur_field = line.strip().split(" ")
190
+ cur_field = [float(p) for p in cur_field]
191
+ field_data.append(cur_field)
192
+ rf.close()
193
+ field_data = np.array(field_data, dtype=np.float32)
194
+ print(f"field_data: {field_data.shape}")
195
+ return field_data
196
+
197
+ def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
198
+ bz, N = pos.size(0), pos.size(1)
199
+ feat_dim = pos.size(-1)
200
+ device = pos.device
201
+ sampling_ratio = float(n_sampling / N)
202
+ pos_float = pos.float()
203
+
204
+ batch = torch.arange(bz, dtype=torch.long).view(bz, 1).to(device)
205
+ mult_one = torch.ones((N,), dtype=torch.long).view(1, N).to(device)
206
+
207
+ batch = batch * mult_one
208
+ batch = batch.view(-1)
209
+ pos_float = pos_float.contiguous().view(-1, feat_dim).contiguous() # (bz x N, 3)
210
+ # sampling_ratio = torch.tensor([sampling_ratio for _ in range(bz)], dtype=torch.float).to(device)
211
+ # batch = torch.zeros((N, ), dtype=torch.long, device=device)
212
+ sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
213
+ # sampled_idx: flat indices into the stacked (bz * N) points
214
+ return sampled_idx
215
+
216
+
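A sketch of how the FPS helper is typically called, assuming the optional torch_cluster dependency is installed:

import torch
from data_loaders.humanml.data.utils import farthest_point_sampling

pos = torch.rand(2, 1024, 3)                          # two point clouds of 1024 points
idx = farthest_point_sampling(pos, n_sampling=256)    # flat indices into the stacked (2 * 1024) points
sampled = pos.reshape(-1, 3)[idx].reshape(2, 256, 3)  # assumes fps returns exactly 256 indices per cloud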
217
+ def batched_index_select_ours(values, indices, dim = 1):
218
+ value_dims = values.shape[(dim + 1):]
219
+ values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
220
+ indices = indices[(..., *((None,) * len(value_dims)))]
221
+ indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
222
+ value_expand_len = len(indices_shape) - (dim + 1)
223
+ values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
224
+
225
+ value_expand_shape = [-1] * len(values.shape)
226
+ expand_slice = slice(dim, (dim + value_expand_len))
227
+ value_expand_shape[expand_slice] = indices.shape[expand_slice]
228
+ values = values.expand(*value_expand_shape)
229
+
230
+ dim += value_expand_len
231
+ return values.gather(dim, indices)
232
+
233
+ def compute_nearest(query, verts):
234
+ # query: bsz x nn_q x 3
235
+ # verts: bsz x nn_v x 3
236
+ dists = torch.sum((query.unsqueeze(2) - verts.unsqueeze(1)) ** 2, dim=-1)
237
+ minn_dists, minn_dists_idx = torch.min(dists, dim=-1) # bsz x nn_q
238
+ minn_pts_pos = batched_index_select_ours(values=verts, indices=minn_dists_idx, dim=1)
239
+ minn_pts_pos = minn_pts_pos.unsqueeze(2)
240
+ minn_dists_idx = minn_dists_idx.unsqueeze(2)
241
+ return minn_dists, minn_dists_idx, minn_pts_pos
242
+
243
+
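A shape sketch for compute_nearest on random tensors (purely illustrative):

import torch
from data_loaders.humanml.data.utils import compute_nearest

query = torch.rand(1, 5, 3)
verts = torch.rand(1, 100, 3)
d2, idx, nearest = compute_nearest(query, verts)
print(d2.shape, idx.shape, nearest.shape)   # (1, 5) squared distances, (1, 5, 1) indices, (1, 5, 1, 3) points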
244
+ def batched_index_select(t, dim, inds):
245
+ """
246
+ Helper function to extract batch-varying indices along an array
247
+ :param t: array to select from
248
+ :param dim: dimension to select along
249
+ :param inds: batch-varying indices
250
+ :return:
251
+ """
252
+ dummy = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), t.size(2))
253
+ out = t.gather(dim, dummy) # b x e x f
254
+ return out
255
+
256
+
257
+ def batched_get_rot_mtx_fr_vecs(normal_vecs):
258
+ # normal_vecs: nn_pts x 3 #
259
+ #
260
+ normal_vecs = normal_vecs / torch.clamp(torch.norm(normal_vecs, p=2, dim=-1, keepdim=True), min=1e-5)
261
+ sin_theta = normal_vecs[..., 0]
262
+ cos_theta = torch.sqrt(1. - sin_theta ** 2)
263
+ sin_phi = normal_vecs[..., 1] / torch.clamp(cos_theta, min=1e-5)
264
+ # cos_phi = torch.sqrt(1. - sin_phi ** 2)
265
+ cos_phi = normal_vecs[..., 2] / torch.clamp(cos_theta, min=1e-5)
266
+
267
+ sin_phi[cos_theta < 1e-5] = 1.
268
+ cos_phi[cos_theta < 1e-5] = 0.
269
+
270
+ #
271
+ y_rot_mtx = torch.stack(
272
+ [
273
+ torch.stack([cos_theta, torch.zeros_like(cos_theta), -sin_theta], dim=-1),
274
+ torch.stack([torch.zeros_like(cos_theta), torch.ones_like(cos_theta), torch.zeros_like(cos_theta)], dim=-1),
275
+ torch.stack([sin_theta, torch.zeros_like(cos_theta), cos_theta], dim=-1)
276
+ ], dim=-1
277
+ )
278
+ x_rot_mtx = torch.stack(
279
+ [
280
+ torch.stack([torch.ones_like(cos_theta), torch.zeros_like(cos_theta), torch.zeros_like(cos_theta)], dim=-1),
281
+ torch.stack([torch.zeros_like(cos_phi), cos_phi, -sin_phi], dim=-1),
282
+ torch.stack([torch.zeros_like(cos_phi), sin_phi, cos_phi], dim=-1)
283
+ ], dim=-1
284
+ )
285
+ rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
286
+ return rot_mtx
287
+
288
+
289
+ def batched_get_rot_mtx_fr_vecs_v2(normal_vecs):
290
+ # normal_vecs: nn_pts x 3 #
291
+ #
292
+ normal_vecs = normal_vecs / torch.clamp(torch.norm(normal_vecs, p=2, dim=-1, keepdim=True), min=1e-5)
293
+ sin_theta = normal_vecs[..., 0]
294
+ cos_theta = torch.sqrt(1. - sin_theta ** 2)
295
+ sin_phi = normal_vecs[..., 1] / torch.clamp(cos_theta, min=1e-5)
296
+ # cos_phi = torch.sqrt(1. - sin_phi ** 2)
297
+ cos_phi = normal_vecs[..., 2] / torch.clamp(cos_theta, min=1e-5)
298
+
299
+ sin_phi[cos_theta < 1e-5] = 1.
300
+ cos_phi[cos_theta < 1e-5] = 0.
301
+
302
+ # o: nn_pts x 3 #
303
+ o = torch.stack(
304
+ [torch.zeros_like(cos_phi), cos_phi, -sin_phi], dim=-1
305
+ )
306
+ nxo = torch.cross(o, normal_vecs)
307
+ # rot_mtx: nn_pts x 3 x 3 #
308
+ rot_mtx = torch.stack(
309
+ [nxo, o, normal_vecs], dim=-1
310
+ )
311
+ return rot_mtx
312
+
313
+
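A small sanity check of the frame construction above: the third column of each returned matrix is the normalized input normal, so rotating the z-axis should reproduce it.

import torch
from data_loaders.humanml.data.utils import batched_get_rot_mtx_fr_vecs_v2

normals = torch.nn.functional.normalize(torch.randn(16, 3), dim=-1)
R = batched_get_rot_mtx_fr_vecs_v2(normals)      # (16, 3, 3)
z = torch.tensor([0., 0., 1.])
print(torch.allclose(R @ z, normals, atol=1e-4))  # True (away from the cos_theta ~ 0 degenerate case)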
314
+ def batched_get_orientation_matrices(rot_vec):
315
+ rot_matrices = []
316
+ for i_w in range(rot_vec.shape[0]):
317
+ cur_rot_vec = rot_vec[i_w]
318
+ cur_rot_mtx = R.from_rotvec(cur_rot_vec).as_matrix()
319
+ rot_matrices.append(cur_rot_mtx)
320
+ rot_matrices = np.stack(rot_matrices, axis=0)
321
+ return rot_matrices
322
+
323
+ def batched_get_minn_dist_corresponding_pts(tips, obj_pcs):
324
+ dist_tips_to_obj_pc_minn_idx = np.argmin(
325
+ ((tips.reshape(tips.shape[0], tips.shape[1], 1, 3) - obj_pcs.reshape(obj_pcs.shape[0], 1, obj_pcs.shape[1], 3)) ** 2).sum(axis=-1), axis=-1
326
+ )
327
+ obj_pcs_th = torch.from_numpy(obj_pcs).float()
328
+ dist_tips_to_obj_pc_minn_idx_th = torch.from_numpy(dist_tips_to_obj_pc_minn_idx).long()
329
+ nearest_pc_th = batched_index_select(obj_pcs_th, 1, dist_tips_to_obj_pc_minn_idx_th)
330
+ return nearest_pc_th, dist_tips_to_obj_pc_minn_idx_th
331
+
332
+ def get_affinity_fr_dist(dist, s=0.02):
333
+ ### affinity scores ###
334
+ k = 0.5 * torch.cos(torch.pi / s * torch.abs(dist)) + 0.5
335
+ return k
336
+
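The affinity above maps a distance of 0 to 1.0 and a distance of s to 0.0; a tiny check with illustrative values:

import torch
from data_loaders.humanml.data.utils import get_affinity_fr_dist

d = torch.tensor([0.0, 0.01, 0.02])
print(get_affinity_fr_dist(d, s=0.02))   # ~[1.0, 0.5, 0.0]; the cosine is periodic, so callers presumably keep |dist| within s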
337
+ def batched_reverse_transform(rot, transl, t_pc, trans=True):
338
+ # t_pc: ws x nn_obj x 3
339
+ # rot; ws x 3 x 3
340
+ # transl: ws x 1 x 3
341
+ if trans:
342
+ reverse_trans_pc = t_pc - transl
343
+ else:
344
+ reverse_trans_pc = t_pc
345
+ reverse_trans_pc = np.matmul(np.transpose(rot, (0, 2, 1)), np.transpose(reverse_trans_pc, (0, 2, 1)))
346
+ reverse_trans_pc = np.transpose(reverse_trans_pc, (0, 2, 1))
347
+ return reverse_trans_pc
348
+
349
+
350
+ def capsule_sdf(mesh_verts, mesh_normals, query_points, query_normals, caps_rad, caps_top, caps_bot, foreach_on_mesh):
351
+ # if caps on hand: mesh_verts = hand vert
352
+ """
353
+ Find the SDF of query points to mesh verts
354
+ Capsule SDF formulation from https://iquilezles.org/www/articles/distfunctions/distfunctions.htm
355
+
356
+ :param mesh_verts: (batch, V, 3)
357
+ :param mesh_normals: (batch, V, 3)
358
+ :param query_points: (batch, Q, 3)
359
+ :param caps_rad: scalar, radius of capsules
360
+ :param caps_top: scalar, distance from mesh to top of capsule
361
+ :param caps_bot: scalar, distance from mesh to bottom of capsule
362
+ :param foreach_on_mesh: boolean; for each point on the mesh find the closest query (V), or for each query find the closest mesh point (Q)
364
+ :return: normalized SDF (batch, V or Q), normal dot product (or None), nearest verts, nearest normals
364
+ """
365
+ # TODO implement normal check?
366
+ if foreach_on_mesh: # Foreach mesh vert, find closest query point
367
+ # knn_dists, nearest_idx, nearest_pos = pytorch3d.ops.knn_points(mesh_verts, query_points, K=1, return_nn=True) # TODO should attract capsule middle?
368
+ # knn_dists, nearest_idx, nearest_pos = compute_nearest(query_points, mesh_verts)
369
+ knn_dists, nearest_idx, nearest_pos = compute_nearest(mesh_verts, query_points)
370
+
371
+ capsule_tops = mesh_verts + mesh_normals * caps_top
372
+ capsule_bots = mesh_verts + mesh_normals * caps_bot
373
+ delta_top = nearest_pos[:, :, 0, :] - capsule_tops
374
+ normal_dot = torch.sum(mesh_normals * batched_index_select(query_normals, 1, nearest_idx.squeeze(2)), dim=2)
375
+
376
+ rt_nearest_verts = mesh_verts
377
+ rt_nearest_normals = mesh_normals
378
+
379
+ else: # Foreach query vert, find closest mesh point
380
+ # knn_dists, nearest_idx, nearest_pos = pytorch3d.ops.knn_points(query_points, mesh_verts, K=1, return_nn=True) # TODO should attract capsule middle?
381
+ st_time = time.time()
382
+ knn_dists, nearest_idx, nearest_pos = compute_nearest(query_points, mesh_verts)
383
+ ed_time = time.time()
384
+ # print(f"Time for computing nearest: {ed_time - st_time}")
385
+
386
+ closest_mesh_verts = batched_index_select(mesh_verts, 1, nearest_idx.squeeze(2)) # Shape (batch, V, 3)
387
+ closest_mesh_normals = batched_index_select(mesh_normals, 1, nearest_idx.squeeze(2)) # Shape (batch, V, 3)
388
+
389
+ capsule_tops = closest_mesh_verts + closest_mesh_normals * caps_top # Coordinates of the top focii of the capsules (batch, V, 3)
390
+ capsule_bots = closest_mesh_verts + closest_mesh_normals * caps_bot
391
+ delta_top = query_points - capsule_tops
392
+ # normal_dot = torch.sum(query_normals * closest_mesh_normals, dim=2)
393
+ normal_dot = None
394
+
395
+ rt_nearest_verts = closest_mesh_verts
396
+ rt_nearest_normals = closest_mesh_normals
397
+
398
+ # (top -> bot) #!!#
399
+ bot_to_top = capsule_bots - capsule_tops # Vector from capsule bottom to top
400
+ along_axis = torch.sum(delta_top * bot_to_top, dim=2) # Dot product
401
+ top_to_bot_square = torch.sum(bot_to_top * bot_to_top, dim=2)
402
+
403
+ # print(f"top_to_bot_square: {top_to_bot_square[..., :10]}")
404
+ h = torch.clamp(along_axis / top_to_bot_square, 0, 1) # Could avoid NaNs with offset in division here
405
+ dist_to_axis = torch.norm(delta_top - bot_to_top * h.unsqueeze(2), dim=2) # Distance to capsule centerline
406
+
407
+ # two endpoints; edge of the capsule #
408
+ return dist_to_axis / caps_rad, normal_dot, rt_nearest_verts, rt_nearest_normals # normalized SDF: 0 on the capsule axis, 1 on the capsule surface
409
+
410
+
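A shape-level sketch of the query-side branch (foreach_on_mesh=False) with made-up tensors; query normals are unused in this branch, so None is passed:

import torch
from data_loaders.humanml.data.utils import capsule_sdf

mesh_v = torch.rand(1, 10, 3)
mesh_n = torch.zeros(1, 10, 3); mesh_n[..., 2] = 1.0   # dummy +z normals
query = torch.rand(1, 4, 3)
sdf, normal_dot, nearest_v, nearest_n = capsule_sdf(
    mesh_v, mesh_n, query, None, caps_rad=0.01, caps_top=0.005, caps_bot=-0.005, foreach_on_mesh=False)
print(sdf.shape)   # (1, 4): distance to the capsule axis divided by caps_rad; normal_dot is None here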
411
+
412
+ def reparameterize_gaussian(mean, logvar):
413
+ std = torch.exp(0.5 * logvar)
414
+ eps = torch.randn(std.size()).to(mean.device)
415
+ return mean + std * eps
416
+
417
+
418
+ def gaussian_entropy(logvar):
419
+ const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2))
420
+ ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const
421
+ return ent
422
+
423
+
424
+ def standard_normal_logprob(z): # feature dim
425
+ dim = z.size(-1)
426
+ log_z = -0.5 * dim * np.log(2 * np.pi)
427
+ return log_z - z.pow(2) / 2
428
+
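The three helpers above are the usual VAE pieces (reparameterization trick, Gaussian entropy, standard-normal log-density); a shape sketch:

import torch
from data_loaders.humanml.data.utils import (
    reparameterize_gaussian, gaussian_entropy, standard_normal_logprob)

mean, logvar = torch.zeros(4, 8), torch.zeros(4, 8)
z = reparameterize_gaussian(mean, logvar)        # (4, 8) sample, here from a standard normal
print(gaussian_entropy(logvar).shape)            # (4,)  per-sample entropy
print(standard_normal_logprob(z).shape)          # (4, 8) elementwise log-prob terms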
429
+
430
+ def truncated_normal_(tensor, mean=0, std=1, trunc_std=2):
431
+ """
432
+ Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
433
+ """
434
+ size = tensor.shape
435
+ tmp = tensor.new_empty(size + (4,)).normal_()
436
+ valid = (tmp < trunc_std) & (tmp > -trunc_std)
437
+ ind = valid.max(-1, keepdim=True)[1]
438
+ tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
439
+ tensor.data.mul_(std).add_(mean)
440
+ return tensor
441
+
442
+
443
+ def makepath(desired_path, isfile = False):
444
+ '''
445
+ if the path does not exist make it
446
+ :param desired_path: can be path to a file or a folder name
447
+ :return:
448
+ '''
449
+ import os
450
+ if isfile:
451
+ if not os.path.exists(os.path.dirname(desired_path)):os.makedirs(os.path.dirname(desired_path))
452
+ else:
453
+ if not os.path.exists(desired_path): os.makedirs(desired_path)
454
+ return desired_path
455
+
456
+
457
+ def batch_gather(arr, ind):
458
+ """
459
+ :param arr: B x N x D
460
+ :param ind: B x M
461
+ :return: B x M x D
462
+ """
463
+ dummy = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), arr.size(2))
464
+ out = torch.gather(arr, 1, dummy)
465
+ return out
466
+
467
+
468
+ def random_rotate_np(x):
469
+ aa = np.random.randn(3)
470
+ theta = np.sqrt(np.sum(aa**2))
471
+ k = aa / np.maximum(theta, 1e-6)
472
+ K = np.array([[0, -k[2], k[1]],
473
+ [k[2], 0, -k[0]],
474
+ [-k[1], k[0], 0]])
475
+ R = np.eye(3) + np.sin(theta)*K + (1-np.cos(theta))*np.matmul(K, K)
476
+ R = R.astype(np.float32)
477
+ return np.matmul(x, R), R
478
+
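random_rotate_np draws a random axis-angle rotation via the Rodrigues formula; a quick check that the returned matrix is a proper rotation:

import numpy as np
from data_loaders.humanml.data.utils import random_rotate_np

x = np.random.randn(100, 3).astype(np.float32)
x_rot, R = random_rotate_np(x)
print(np.allclose(R @ R.T, np.eye(3), atol=1e-5), np.isclose(np.linalg.det(R), 1.0, atol=1e-5))   # True True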
479
+
480
+ def rotate_x(x, rad):
481
+ rad = -rad
482
+ rotmat = np.array([
483
+ [1, 0, 0],
484
+ [0, np.cos(rad), -np.sin(rad)],
485
+ [0, np.sin(rad), np.cos(rad)]
486
+ ])
487
+ return np.dot(x, rotmat)
488
+
489
+ def rotate_y(x, rad):
490
+ rad = -rad
491
+ rotmat = np.array([
492
+ [np.cos(rad), 0, np.sin(rad)],
493
+ [0, 1, 0],
494
+ [-np.sin(rad), 0, np.cos(rad)]
495
+ ])
496
+ return np.dot(x, rotmat)
497
+
498
+ def rotate_z(x, rad):
499
+ rad = -rad
500
+ rotmat = np.array([
501
+ [np.cos(rad), -np.sin(rad), 0],
502
+ [np.sin(rad), np.cos(rad), 0],
503
+ [0, 0, 1]
504
+ ])
505
+ return np.dot(x, rotmat)
506
+
507
+
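The three axis rotations above preserve point norms, which gives a cheap sanity check of their composition:

import numpy as np
from data_loaders.humanml.data.utils import rotate_x, rotate_y, rotate_z

pts = np.random.randn(64, 3)
pts_r = rotate_z(rotate_y(rotate_x(pts, np.pi / 2), np.pi / 4), np.pi)
print(np.allclose(np.linalg.norm(pts_r, axis=-1), np.linalg.norm(pts, axis=-1)))   # True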
data_loaders/humanml/motion_loaders/__init__.py ADDED
File without changes
data_loaders/humanml/motion_loaders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (182 Bytes). View file