brjathu commited on
Commit
29a229f
•
1 Parent(s): 2cbaea4

Adding HF files

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +145 -0
  2. README.md +4 -4
  3. README_github.md +74 -0
  4. app.py +149 -0
  5. app_mesh.py +141 -0
  6. clean_ckpt.ipynb +93 -0
  7. demo.py +130 -0
  8. environment.yml +20 -0
  9. fetch_data.sh +2 -0
  10. hmr2/__init__.py +0 -0
  11. hmr2/configs/__init__.py +88 -0
  12. hmr2/datasets/__init__.py +0 -0
  13. hmr2/datasets/utils.py +999 -0
  14. hmr2/datasets/vitdet_dataset.py +89 -0
  15. hmr2/models/__init__.py +3 -0
  16. hmr2/models/backbones/__init__.py +7 -0
  17. hmr2/models/backbones/vit.py +348 -0
  18. hmr2/models/backbones/vit_vitpose.py +17 -0
  19. hmr2/models/components/__init__.py +0 -0
  20. hmr2/models/components/pose_transformer.py +358 -0
  21. hmr2/models/components/t_cond_mlp.py +199 -0
  22. hmr2/models/discriminator.py +99 -0
  23. hmr2/models/heads/__init__.py +1 -0
  24. hmr2/models/heads/smpl_head.py +111 -0
  25. hmr2/models/hmr2.py +363 -0
  26. hmr2/models/losses.py +92 -0
  27. hmr2/models/smpl_wrapper.py +41 -0
  28. hmr2/utils/__init__.py +25 -0
  29. hmr2/utils/geometry.py +102 -0
  30. hmr2/utils/mesh_renderer.py +149 -0
  31. hmr2/utils/pose_utils.py +306 -0
  32. hmr2/utils/render_openpose.py +149 -0
  33. hmr2/utils/renderer.py +396 -0
  34. hmr2/utils/skeleton_renderer.py +122 -0
  35. hmr2/utils/texture_utils.py +85 -0
  36. hmr2/utils/utils_detectron2.py +93 -0
  37. requirements.txt +29 -0
  38. setup.py +8 -0
  39. vendor/detectron2/.circleci/config.yml +271 -0
  40. vendor/detectron2/.circleci/import-tests.sh +16 -0
  41. vendor/detectron2/.clang-format +85 -0
  42. vendor/detectron2/.flake8 +15 -0
  43. vendor/detectron2/.github/CODE_OF_CONDUCT.md +5 -0
  44. vendor/detectron2/.github/CONTRIBUTING.md +68 -0
  45. vendor/detectron2/.github/Detectron2-Logo-Horz.svg +1 -0
  46. vendor/detectron2/.github/ISSUE_TEMPLATE.md +5 -0
  47. vendor/detectron2/.github/ISSUE_TEMPLATE/bugs.md +38 -0
  48. vendor/detectron2/.github/ISSUE_TEMPLATE/config.yml +17 -0
  49. vendor/detectron2/.github/ISSUE_TEMPLATE/documentation.md +14 -0
  50. vendor/detectron2/.github/ISSUE_TEMPLATE/feature-request.md +31 -0
.gitignore ADDED
@@ -0,0 +1,145 @@
1
+ # Specific
2
+ /logs*/
3
+ /results/
4
+ /sandbox/
5
+ *.lock
6
+ *.pt
7
+ *.npy
8
+ /example_data/downloaded*
9
+ *.tar
10
+ *.tar.gz
11
+ /discord_sandbox/
12
+ /demo_out/
13
+ token_channel.csv
14
+
15
+ # Byte-compiled / optimized / DLL files
16
+ __pycache__/
17
+ *.py[cod]
18
+ *$py.class
19
+ logs/
20
+ # C extensions
21
+ *.so
22
+
23
+ # Distribution / packaging
24
+ .Python
25
+ build/
26
+ develop-eggs/
27
+ dist/
28
+ downloads/
29
+ eggs/
30
+ .eggs/
31
+ lib/
32
+ lib64/
33
+ parts/
34
+ sdist/
35
+ var/
36
+ wheels/
37
+ pip-wheel-metadata/
38
+ share/python-wheels/
39
+ *.egg-info/
40
+ .installed.cfg
41
+ *.egg
42
+ MANIFEST
43
+
44
+ # PyInstaller
45
+ # Usually these files are written by a python script from a template
46
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
47
+ *.manifest
48
+ *.spec
49
+
50
+ # Installer logs
51
+ pip-log.txt
52
+ pip-delete-this-directory.txt
53
+
54
+ # Unit test / coverage reports
55
+ htmlcov/
56
+ .tox/
57
+ .nox/
58
+ .coverage
59
+ .coverage.*
60
+ .cache
61
+ nosetests.xml
62
+ coverage.xml
63
+ *.cover
64
+ *.py,cover
65
+ .hypothesis/
66
+ .pytest_cache/
67
+
68
+ # Translations
69
+ *.mo
70
+ *.pot
71
+
72
+ # Django stuff:
73
+ *.log
74
+ local_settings.py
75
+ db.sqlite3
76
+ db.sqlite3-journal
77
+
78
+ # Flask stuff:
79
+ instance/
80
+ .webassets-cache
81
+
82
+ # Scrapy stuff:
83
+ .scrapy
84
+
85
+ # Sphinx documentation
86
+ docs/_build/
87
+
88
+ # PyBuilder
89
+ target/
90
+
91
+ # Jupyter Notebook
92
+ .ipynb_checkpoints
93
+
94
+ # IPython
95
+ profile_default/
96
+ ipython_config.py
97
+
98
+ # pyenv
99
+ .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109
+ __pypackages__/
110
+
111
+ # Celery stuff
112
+ celerybeat-schedule
113
+ celerybeat.pid
114
+
115
+ # SageMath parsed files
116
+ *.sage.py
117
+
118
+ # Environments
119
+ .env
120
+ .venv
121
+ env/
122
+ venv/
123
+ ENV/
124
+ env.bak/
125
+ venv.bak/
126
+
127
+ # Spyder project settings
128
+ .spyderproject
129
+ .spyproject
130
+
131
+ # Rope project settings
132
+ .ropeproject
133
+
134
+ # mkdocs documentation
135
+ /site
136
+
137
+ # mypy
138
+ .mypy_cache/
139
+ .dmypy.json
140
+ dmypy.json
141
+
142
+ # Pyre type checker
143
+ .pyre/
144
+ /checkpoints/
145
+ /data/
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: HMR
3
- emoji: 📈
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.32.0
8
  app_file: app.py
 
1
  ---
2
+ title: HMR2.0
3
+ emoji: 🔥
4
+ colorFrom: yellow
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 3.32.0
8
  app_file: app.py
README_github.md ADDED
@@ -0,0 +1,74 @@
1
+ # 4DHumans: Reconstructing and Tracking Humans with Transformers
2
+ Code repository for the paper:
3
+ **Humans in 4D: Reconstructing and Tracking Humans with Transformers**
4
+ [Shubham Goel](https://people.eecs.berkeley.edu/~shubham-goel/), [Georgios Pavlakos](https://geopavlakos.github.io/), [Jathushan Rajasegaran](http://people.eecs.berkeley.edu/~jathushan/), [Angjoo Kanazawa](https://people.eecs.berkeley.edu/~kanazawa/)<sup>\*</sup>, [Jitendra Malik](http://people.eecs.berkeley.edu/~malik/)<sup>\*</sup>
5
+ arXiv preprint 2023
6
+ [[paper]()] [[project page](https://shubham-goel.github.io/4dhumans/)] [[Hugging Face space]()]
7
+
8
+ ![teaser](assets/teaser.png)
9
+
10
+ ## Download dependencies
11
+ Our demo code depends on [detectron2](https://github.com/facebookresearch/detectron2) to detect humans.
12
+ To automatically download this dependency, clone this repo using `--recursive`, or run `git submodule update --init` if you've already cloned the repository. You should see the detectron2 source code at `vendor/detectron2`.
13
+ ```bash
14
+ git clone https://github.com/shubham-goel/4D-Humans.git --recursive
15
+ # OR
16
+ git clone https://github.com/shubham-goel/4D-Humans.git
17
+ cd 4D-Humans
18
+ git submodule update --init
19
+ ```
20
+
21
+ ## Installation
22
+ We recommend creating a clean [conda](https://docs.conda.io/) environment and installing all dependencies, as follows:
23
+ ```bash
24
+ conda env create -f environment.yml
25
+ ```
26
+
27
+ After the installation is complete, you can activate the conda environment by running:
28
+ ```
29
+ conda activate 4D-humans
30
+ ```
31
+
32
+ ## Download checkpoints and SMPL models
33
+ To download the checkpoints and SMPL models, run
34
+ ```bash
35
+ ./fetch_data.sh
36
+ ```
37
+
38
+ ## Run demo on images
39
+ You can now run our demo to reconstruct humans in 3D from images. The following command runs ViTDet and HMR2.0 on all images in the specified `--img_folder` and saves renderings of the reconstructions in `--out_folder`. You can also pass the `--side_view` flag to additionally render a side view of the reconstructed mesh, and use `--batch_size` to batch images together for faster processing.
40
+ ```bash
41
+ python demo.py \
42
+ --img_folder example_data/images \
43
+ --out_folder demo_out \
44
+ --batch_size=48 --side_view
45
+ ```
46
+
47
+ ## Run demo on videos
48
+ Coming soon.
49
+
50
+ ## Training and evaluation
51
+ Coming soon.
52
+
53
+ ## Acknowledgements
54
+ Parts of the code are taken or adapted from the following repos:
55
+ - [ProHMR](https://github.com/nkolot/ProHMR)
56
+ - [SPIN](https://github.com/nkolot/SPIN)
57
+ - [SMPLify-X](https://github.com/vchoutas/smplify-x)
58
+ - [HMR](https://github.com/akanazawa/hmr)
59
+ - [ViTPose](https://github.com/ViTAE-Transformer/ViTPose)
60
+ - [Detectron2](https://github.com/facebookresearch/detectron2)
61
+
62
+ Additionally, we thank [StabilityAI](https://stability.ai/) for a generous compute grant that enabled this work.
63
+
64
+ ## Citing
65
+ If you find this code useful for your research, please consider citing the following paper:
66
+
67
+ ```
68
+ @article{4DHUMANS,
69
+ title={Humans in 4{D}: Reconstructing and Tracking Humans with Transformers},
70
+ author={Goel, Shubham and Pavlakos, Georgios and Rajasegaran, Jathushan and Kanazawa, Angjoo and Malik, Jitendra},
71
+ journal={arXiv preprint},
72
+ year={2023}
73
+ }
74
+ ```
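The demo described in the README above is implemented by the scripts added below (`app.py`, `app_mesh.py`, `demo.py`). For orientation, here is a minimal sketch of the core inference pipeline they share. The checkpoint path mirrors the demo scripts; the example image filename and the single full-image box are hypothetical placeholders (the actual scripts obtain person boxes from a ViTDet detector).

```python
# Minimal sketch of the HMR2.0 inference pipeline shared by demo.py / app.py (added below).
# The image path and the full-image box are placeholders; the demo scripts use ViTDet
# detections for the boxes instead.
from pathlib import Path
import cv2
import numpy as np
import torch

from hmr2.configs import get_config
from hmr2.models import HMR2
from hmr2.utils import recursive_to
from hmr2.datasets.vitdet_dataset import ViTDetDataset

checkpoint = 'logs/train/multiruns/hmr2/0/checkpoints/epoch=35-step=1000000.ckpt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_cfg = get_config(str(Path(checkpoint).parent.parent / 'model_config.yaml'))
model = HMR2.load_from_checkpoint(checkpoint, strict=False, cfg=model_cfg).to(device)
model.eval()

img_bgr = cv2.imread('example_data/images/example.png')            # placeholder image
boxes = np.array([[0, 0, img_bgr.shape[1], img_bgr.shape[0]]],     # fall back to one full-image box
                 dtype=np.float32)

dataset = ViTDetDataset(model_cfg, img_bgr, boxes)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
with torch.no_grad():
    for batch in loader:
        out = model(recursive_to(batch, device))
        verts = out['pred_vertices']   # per-person SMPL vertices
```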
app.py ADDED
@@ -0,0 +1,149 @@
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+
11
+ from hmr2.configs import get_config
12
+ from hmr2.datasets.vitdet_dataset import (DEFAULT_MEAN, DEFAULT_STD,
13
+ ViTDetDataset)
14
+ from hmr2.models import HMR2
15
+ from hmr2.utils import recursive_to
16
+ from hmr2.utils.renderer import Renderer, cam_crop_to_full
17
+
18
+ # Setup HMR2.0 model
19
+ LIGHT_BLUE=(0.65098039, 0.74117647, 0.85882353)
20
+ DEFAULT_CHECKPOINT='logs/train/multiruns/hmr2/0/checkpoints/epoch=35-step=1000000.ckpt'
21
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
22
+ model_cfg = str(Path(DEFAULT_CHECKPOINT).parent.parent / 'model_config.yaml')
23
+ model_cfg = get_config(model_cfg)
24
+ model = HMR2.load_from_checkpoint(DEFAULT_CHECKPOINT, strict=False, cfg=model_cfg).to(device)
25
+ model.eval()
26
+
27
+
28
+ # Load detector
29
+ from detectron2.config import LazyConfig
30
+
31
+ from hmr2.utils.utils_detectron2 import DefaultPredictor_Lazy
32
+
33
+ detectron2_cfg = LazyConfig.load(f"vendor/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py")
34
+ detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl"
35
+ for i in range(3):
36
+ detectron2_cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25
37
+ detector = DefaultPredictor_Lazy(detectron2_cfg)
38
+
39
+ # Setup the renderer
40
+ renderer = Renderer(model_cfg, faces=model.smpl.faces)
41
+
42
+
43
+ import numpy as np
44
+
45
+
46
+ def infer(in_pil_img, in_threshold=0.8, out_pil_img=None):
47
+
48
+ open_cv_image = np.array(in_pil_img)
49
+ # Convert RGB to BGR
50
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
51
+ print("EEEEE", open_cv_image.shape)
52
+ det_out = detector(open_cv_image)
53
+ det_instances = det_out['instances']
54
+ valid_idx = (det_instances.pred_classes==0) & (det_instances.scores > in_threshold)
55
+ boxes=det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
56
+
57
+ # Run HMR2.0 on all detected humans
58
+ dataset = ViTDetDataset(model_cfg, open_cv_image, boxes)
59
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
60
+
61
+ all_verts = []
62
+ all_cam_t = []
63
+
64
+ for batch in dataloader:
65
+ batch = recursive_to(batch, device)
66
+ with torch.no_grad():
67
+ out = model(batch)
68
+
69
+ pred_cam = out['pred_cam']
70
+ box_center = batch["box_center"].float()
71
+ box_size = batch["box_size"].float()
72
+ img_size = batch["img_size"].float()
73
+ render_size = img_size
74
+ pred_cam_t = cam_crop_to_full(pred_cam, box_center, box_size, render_size).detach().cpu().numpy()
75
+
76
+ # Render the result
77
+ batch_size = batch['img'].shape[0]
78
+ for n in range(batch_size):
79
+ # Get filename from path img_path
80
+ # img_fn, _ = os.path.splitext(os.path.basename(img_path))
81
+ person_id = int(batch['personid'][n])
82
+ white_img = (torch.ones_like(batch['img'][n]).cpu() - DEFAULT_MEAN[:,None,None]/255) / (DEFAULT_STD[:,None,None]/255)
83
+ input_patch = batch['img'][n].cpu() * (DEFAULT_STD[:,None,None]/255) + (DEFAULT_MEAN[:,None,None]/255)
84
+ input_patch = input_patch.permute(1,2,0).numpy()
85
+
86
+ regression_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
87
+ out['pred_cam_t'][n].detach().cpu().numpy(),
88
+ batch['img'][n],
89
+ mesh_base_color=LIGHT_BLUE,
90
+ scene_bg_color=(1, 1, 1),
91
+ )
92
+
93
+
94
+ verts = out['pred_vertices'][n].detach().cpu().numpy()
95
+ cam_t = pred_cam_t[n]
96
+
97
+ all_verts.append(verts)
98
+ all_cam_t.append(cam_t)
99
+
100
+
101
+ # Render front view
102
+ if len(all_verts) > 0:
103
+ misc_args = dict(
104
+ mesh_base_color=LIGHT_BLUE,
105
+ scene_bg_color=(1, 1, 1),
106
+ )
107
+ cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=render_size[n], **misc_args)
108
+
109
+ # Overlay image
110
+ input_img = open_cv_image.astype(np.float32)[:,:,::-1]/255.0
111
+ input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
112
+ input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
113
+
114
+ # convert to PIL image
115
+ out_pil_img = Image.fromarray((input_img_overlay*255).astype(np.uint8))
116
+
117
+ return out_pil_img
118
+ else:
119
+ return None
120
+
121
+
122
+ with gr.Blocks(title="4DHumans", css=".gradio-container") as demo:
123
+
124
+ gr.HTML("""<div style="font-weight:bold; text-align:center; color:royalblue;">HMR 2.0</div>""")
125
+
126
+ with gr.Row():
127
+ input_image = gr.Image(label="Input image", type="pil", width=300, height=300, fixed_size=True)
128
+ output_image = gr.Image(label="Reconstructions", type="pil", width=300, height=300, fixed_size=True)
129
+
130
+ gr.HTML("""<br/>""")
131
+
132
+ with gr.Row():
133
+ threshold = gr.Slider(0, 1.0, value=0.8, label='Detection Threshold')
134
+ send_btn = gr.Button("Infer")
135
+ send_btn.click(fn=infer, inputs=[input_image, threshold], outputs=[output_image])
136
+
137
+ # gr.Examples(['samples/img1.jpg', 'samples/img2.png', 'samples/img3.jpg', 'samples/img4.jpg'], inputs=input_image)
138
+
139
+ gr.HTML("""</ul>""")
140
+
141
+
142
+
143
+ #demo.queue()
144
+ demo.launch(debug=True)
145
+
146
+
147
+
148
+
149
+ ### EOF ###
app_mesh.py ADDED
@@ -0,0 +1,141 @@
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import gradio as gr
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ import trimesh
11
+ import tempfile
12
+
13
+ from hmr2.configs import get_config
14
+ from hmr2.datasets.vitdet_dataset import (DEFAULT_MEAN, DEFAULT_STD,
15
+ ViTDetDataset)
16
+ from hmr2.models import HMR2
17
+ from hmr2.utils import recursive_to
18
+ from hmr2.utils.renderer import Renderer, cam_crop_to_full
19
+
20
+ # Setup HMR2.0 model
21
+ LIGHT_BLUE=(0.65098039, 0.74117647, 0.85882353)
22
+ DEFAULT_CHECKPOINT='logs/train/multiruns/hmr2/0/checkpoints/epoch=35-step=1000000.ckpt'
23
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
24
+ model_cfg = str(Path(DEFAULT_CHECKPOINT).parent.parent / 'model_config.yaml')
25
+ model_cfg = get_config(model_cfg)
26
+ model = HMR2.load_from_checkpoint(DEFAULT_CHECKPOINT, strict=False, cfg=model_cfg).to(device)
27
+ model.eval()
28
+
29
+
30
+ # Load detector
31
+ from detectron2.config import LazyConfig
32
+
33
+ from hmr2.utils.utils_detectron2 import DefaultPredictor_Lazy
34
+
35
+ detectron2_cfg = LazyConfig.load(f"vendor/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py")
36
+ detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl"
37
+ for i in range(3):
38
+ detectron2_cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25
39
+ detector = DefaultPredictor_Lazy(detectron2_cfg)
40
+
41
+ # Setup the renderer
42
+ renderer = Renderer(model_cfg, faces=model.smpl.faces)
43
+
44
+
45
+ import numpy as np
46
+
47
+
48
+ def infer(in_pil_img, in_threshold=0.8):
49
+
50
+ open_cv_image = np.array(in_pil_img)
51
+ # Convert RGB to BGR
52
+ open_cv_image = open_cv_image[:, :, ::-1].copy()
53
+ print("EEEEE", open_cv_image.shape)
54
+ det_out = detector(open_cv_image)
55
+ det_instances = det_out['instances']
56
+ valid_idx = (det_instances.pred_classes==0) & (det_instances.scores > in_threshold)
57
+ boxes=det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
58
+
59
+ # Run HMR2.0 on all detected humans
60
+ dataset = ViTDetDataset(model_cfg, open_cv_image, boxes)
61
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
62
+
63
+ all_verts = []
64
+ all_cam_t = []
65
+
66
+ for batch in dataloader:
67
+ batch = recursive_to(batch, device)
68
+ with torch.no_grad():
69
+ out = model(batch)
70
+
71
+ pred_cam = out['pred_cam']
72
+ box_center = batch["box_center"].float()
73
+ box_size = batch["box_size"].float()
74
+ img_size = batch["img_size"].float()
75
+ render_size = img_size
76
+ pred_cam_t = cam_crop_to_full(pred_cam, box_center, box_size, render_size, focal_length=img_size.mean()*2).detach().cpu().numpy()
77
+
78
+ # Render the result
79
+ batch_size = batch['img'].shape[0]
80
+ for n in range(batch_size):
81
+ # Get filename from path img_path
82
+ # img_fn, _ = os.path.splitext(os.path.basename(img_path))
83
+ person_id = int(batch['personid'][n])
84
+ white_img = (torch.ones_like(batch['img'][n]).cpu() - DEFAULT_MEAN[:,None,None]/255) / (DEFAULT_STD[:,None,None]/255)
85
+ input_patch = batch['img'][n].cpu() * (DEFAULT_STD[:,None,None]/255) + (DEFAULT_MEAN[:,None,None]/255)
86
+ input_patch = input_patch.permute(1,2,0).numpy()
87
+
88
+ regression_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
89
+ out['pred_cam_t'][n].detach().cpu().numpy(),
90
+ batch['img'][n],
91
+ mesh_base_color=LIGHT_BLUE,
92
+ scene_bg_color=(1, 1, 1),
93
+ )
94
+
95
+
96
+ verts = out['pred_vertices'][n].detach().cpu().numpy()
97
+ cam_t = pred_cam_t[n]
98
+
99
+ all_verts.append(verts)
100
+ all_cam_t.append(cam_t)
101
+
102
+ # Return mesh path
103
+ trimeshes = [renderer.vertices_to_trimesh(vvv, ttt.copy(), LIGHT_BLUE) for vvv,ttt in zip(all_verts, all_cam_t)]
104
+
105
+ # Join meshes
106
+ mesh = trimesh.util.concatenate(trimeshes)
107
+
108
+ # Save mesh to file
109
+ temp_name = next(tempfile._get_candidate_names()) + '.obj'
110
+ trimesh.exchange.export.export_mesh(mesh, temp_name)
111
+ return temp_name
112
+
113
+
114
+ with gr.Blocks(title="4DHumans", css=".gradio-container") as demo:
115
+
116
+ gr.HTML("""<div style="font-weight:bold; text-align:center; color:royalblue;">HMR 2.0</div>""")
117
+
118
+ with gr.Row():
119
+ input_image = gr.Image(label="Input image", type="pil", width=300, height=300, fixed_size=True)
120
+ output_model = gr.Model3D(label="Reconstructions", width=300, height=300, fixed_size=True, clear_color=[0.0, 0.0, 0.0, 0.0])
121
+
122
+ gr.HTML("""<br/>""")
123
+
124
+ with gr.Row():
125
+ threshold = gr.Slider(0, 1.0, value=0.8, label='Detection Threshold')
126
+ send_btn = gr.Button("Infer")
127
+ send_btn.click(fn=infer, inputs=[input_image, threshold], outputs=[output_model])
128
+
129
+ # gr.Examples(['samples/img1.jpg', 'samples/img2.png', 'samples/img3.jpg', 'samples/img4.jpg'], inputs=input_image)
130
+
131
+ gr.HTML("""</ul>""")
132
+
133
+
134
+
135
+ #demo.queue()
136
+ demo.launch(debug=True)
137
+
138
+
139
+
140
+
141
+ ### EOF ###
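`app_mesh.py` differs from `app.py` only in its output: its `infer` concatenates the meshes of all detected people, exports them to a temporary `.obj` file, and returns the file path, which the Gradio `Model3D` component then displays. A hedged usage sketch (the image path is a placeholder, and the module-level model/detector/renderer setup above must have run):

```python
# Hypothetical direct call to app_mesh.py's infer(); the returned path points to a
# temporary .obj in the current working directory containing all detected people
# merged into a single mesh.
from PIL import Image

mesh_path = infer(Image.open('example_data/images/example.png'), in_threshold=0.8)
print(mesh_path)
```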
clean_ckpt.ipynb ADDED
@@ -0,0 +1,93 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "ckpt_path = 'logs/train/multiruns/hmr2/0/checkpoints/epoch=35-step=1000000.ckpt'"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "# Load ckpt\n",
20
+ "ckpt = torch.load(ckpt_path, map_location='cpu')"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "ckpt.keys()\n"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "ckpt['loops']"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "# Delete optimizer_states\n",
48
+ "del ckpt['optimizer_states']\n",
49
+ "del ckpt['callbacks']\n",
50
+ "del ckpt['loops']"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "# Save new ckpt\n",
60
+ "torch.save(ckpt, ckpt_path)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": []
69
+ }
70
+ ],
71
+ "metadata": {
72
+ "kernelspec": {
73
+ "display_name": "4D-humans",
74
+ "language": "python",
75
+ "name": "python3"
76
+ },
77
+ "language_info": {
78
+ "codemirror_mode": {
79
+ "name": "ipython",
80
+ "version": 3
81
+ },
82
+ "file_extension": ".py",
83
+ "mimetype": "text/x-python",
84
+ "name": "python",
85
+ "nbconvert_exporter": "python",
86
+ "pygments_lexer": "ipython3",
87
+ "version": "3.10.6"
88
+ },
89
+ "orig_nbformat": 4
90
+ },
91
+ "nbformat": 4,
92
+ "nbformat_minor": 2
93
+ }
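The notebook above amounts to the following standalone script. The checkpoint path and the deleted keys are taken from the notebook, and, like the notebook, it overwrites the checkpoint in place, so keep a backup if that is not what you want.

```python
# Script equivalent of clean_ckpt.ipynb: strip training-only state from a
# PyTorch Lightning checkpoint so that mainly the model weights remain.
# Overwrites the file in place, mirroring the notebook.
import torch

ckpt_path = 'logs/train/multiruns/hmr2/0/checkpoints/epoch=35-step=1000000.ckpt'
ckpt = torch.load(ckpt_path, map_location='cpu')

for key in ('optimizer_states', 'callbacks', 'loops'):
    ckpt.pop(key, None)   # drop training-only state if present

torch.save(ckpt, ckpt_path)
```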
demo.py ADDED
@@ -0,0 +1,130 @@
1
+ from pathlib import Path
2
+ import torch
3
+ import argparse
4
+ import os
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from hmr2.configs import get_config
9
+ from hmr2.models import HMR2
10
+ from hmr2.utils import recursive_to
11
+ from hmr2.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
12
+ from hmr2.utils.renderer import Renderer, cam_crop_to_full
13
+
14
+ LIGHT_BLUE=(0.65098039, 0.74117647, 0.85882353)
15
+ # DEFAULT_CHECKPOINT='logs/train/multiruns/20b1_mix11_a1/0/checkpoints/epoch=30-step=1000000.ckpt'
16
+ DEFAULT_CHECKPOINT='logs/train/multiruns/hmr2/0/checkpoints/epoch=35-step=1000000.ckpt'
17
+ parser = argparse.ArgumentParser(description='HMR2 demo code')
18
+ parser.add_argument('--checkpoint', type=str, default=DEFAULT_CHECKPOINT, help='Path to pretrained model checkpoint')
19
+ parser.add_argument('--img_folder', type=str, default='example_data/images', help='Folder with input images')
20
+ parser.add_argument('--out_folder', type=str, default='demo_out', help='Output folder to save rendered results')
21
+ parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='If set, render side view also')
22
+ parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference/fitting')
23
+
24
+ args = parser.parse_args()
25
+
26
+ # Setup HMR2.0 model
27
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
28
+ model_cfg = str(Path(args.checkpoint).parent.parent / 'model_config.yaml')
29
+ model_cfg = get_config(model_cfg)
30
+ model = HMR2.load_from_checkpoint(args.checkpoint, strict=False, cfg=model_cfg).to(device)
31
+ model.eval()
32
+
33
+ # Load detector
34
+ from detectron2.config import LazyConfig
35
+ from hmr2.utils.utils_detectron2 import DefaultPredictor_Lazy
36
+ detectron2_cfg = LazyConfig.load(f"vendor/detectron2/projects/ViTDet/configs/COCO/cascade_mask_rcnn_vitdet_h_75ep.py")
37
+ detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl"
38
+ for i in range(3):
39
+ detectron2_cfg.model.roi_heads.box_predictors[i].test_score_thresh = 0.25
40
+ detector = DefaultPredictor_Lazy(detectron2_cfg)
41
+
42
+ # Setup the renderer
43
+ renderer = Renderer(model_cfg, faces=model.smpl.faces)
44
+
45
+ # Make output directory if it does not exist
46
+ os.makedirs(args.out_folder, exist_ok=True)
47
+
48
+ # Iterate over all images in folder
49
+ for img_path in Path(args.img_folder).glob('*.png'):
50
+ img_cv2 = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
51
+
52
+ # Detect humans in image
53
+ det_out = detector(img_cv2)
54
+
55
+ det_instances = det_out['instances']
56
+ valid_idx = (det_instances.pred_classes==0) & (det_instances.scores > 0.5)
57
+ boxes=det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
58
+
59
+ # Run HMR2.0 on all detected humans
60
+ dataset = ViTDetDataset(model_cfg, img_cv2.copy(), boxes)
61
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=0)
62
+
63
+
64
+ all_verts = []
65
+ all_cam_t = []
66
+
67
+ for batch in dataloader:
68
+ batch = recursive_to(batch, device)
69
+ with torch.no_grad():
70
+ out = model(batch)
71
+
72
+ pred_cam = out['pred_cam']
73
+ box_center = batch["box_center"].float()
74
+ box_size = batch["box_size"].float()
75
+ img_size = batch["img_size"].float()
76
+ render_size = img_size
77
+ pred_cam_t = cam_crop_to_full(pred_cam, box_center, box_size, render_size).detach().cpu().numpy()
78
+
79
+ # Render the result
80
+ batch_size = batch['img'].shape[0]
81
+ for n in range(batch_size):
82
+ # Get filename from path img_path
83
+ img_fn, _ = os.path.splitext(os.path.basename(img_path))
84
+ person_id = int(batch['personid'][n])
85
+ white_img = (torch.ones_like(batch['img'][n]).cpu() - DEFAULT_MEAN[:,None,None]/255) / (DEFAULT_STD[:,None,None]/255)
86
+ input_patch = batch['img'][n].cpu() * (DEFAULT_STD[:,None,None]/255) + (DEFAULT_MEAN[:,None,None]/255)
87
+ input_patch = input_patch.permute(1,2,0).numpy()
88
+
89
+ regression_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
90
+ out['pred_cam_t'][n].detach().cpu().numpy(),
91
+ batch['img'][n],
92
+ mesh_base_color=LIGHT_BLUE,
93
+ scene_bg_color=(1, 1, 1),
94
+ )
95
+
96
+ if args.side_view:
97
+ side_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
98
+ out['pred_cam_t'][n].detach().cpu().numpy(),
99
+ white_img,
100
+ mesh_base_color=LIGHT_BLUE,
101
+ scene_bg_color=(1, 1, 1),
102
+ side_view=True)
103
+ final_img = np.concatenate([input_patch, regression_img, side_img], axis=1)
104
+ else:
105
+ final_img = np.concatenate([input_patch, regression_img], axis=1)
106
+
107
+
108
+ verts = out['pred_vertices'][n].detach().cpu().numpy()
109
+ cam_t = pred_cam_t[n]
110
+
111
+ all_verts.append(verts)
112
+ all_cam_t.append(cam_t)
113
+
114
+ misc_args = dict(
115
+ mesh_base_color=LIGHT_BLUE,
116
+ scene_bg_color=(1, 1, 1),
117
+ )
118
+
119
+ # Render front view
120
+ if len(all_verts) > 0:
121
+ cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=render_size[n], **misc_args)
122
+
123
+ # Overlay image
124
+ input_img = img_cv2.astype(np.float32)[:,:,::-1]/255.0
125
+ input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
126
+ input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
127
+
128
+
129
+ # cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}_{person_id}.jpg'), 255*final_img[:, :, ::-1])
130
+ cv2.imwrite(os.path.join(args.out_folder, f'rend_{img_fn}.jpg'), 255*input_img_overlay[:, :, ::-1])
environment.yml ADDED
@@ -0,0 +1,20 @@
1
+ name: 4D-humans
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ dependencies:
7
+ - python=3.10
8
+ - pytorch-cuda=11.8
9
+ - torchvision
10
+ - pip
11
+ - pip:
12
+ - pytorch-lightning
13
+ - smplx==0.1.28
14
+ - pyrender
15
+ - opencv-python
16
+ - yacs
17
+ - scikit-image
18
+ - einops
19
+ - timm
20
+ - -e ./vendor/detectron2/
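After creating and activating the environment, a quick import check is one way to confirm that the key dependencies listed above resolved; this is a suggested sanity check, not part of the repository.

```python
# Suggested sanity check for the 4D-humans environment (not part of the repo).
import torch
import pytorch_lightning
import smplx
import pyrender
import detectron2

print('torch', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('detectron2', detectron2.__version__)
```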
fetch_data.sh ADDED
@@ -0,0 +1,2 @@
1
+ wget https://people.eecs.berkeley.edu/~shubham-goel/projects/4DHumans/hmr2_data.tar.gz
2
+ tar -xvzf hmr2_data.tar.gz
hmr2/__init__.py ADDED
File without changes
hmr2/configs/__init__.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ from typing import Dict
3
+ from yacs.config import CfgNode as CN
4
+
5
+
6
+ def to_lower(x: Dict) -> Dict:
7
+ """
8
+ Convert all dictionary keys to lowercase
9
+ Args:
10
+ x (dict): Input dictionary
11
+ Returns:
12
+ dict: Output dictionary with all keys converted to lowercase
13
+ """
14
+ return {k.lower(): v for k, v in x.items()}
15
+
16
+ _C = CN(new_allowed=True)
17
+
18
+ _C.GENERAL = CN(new_allowed=True)
19
+ _C.GENERAL.RESUME = True
20
+ _C.GENERAL.TIME_TO_RUN = 3300
21
+ _C.GENERAL.VAL_STEPS = 100
22
+ _C.GENERAL.LOG_STEPS = 100
23
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
24
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
25
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
26
+ _C.GENERAL.NUM_GPUS = 1
27
+ _C.GENERAL.NUM_WORKERS = 4
28
+ _C.GENERAL.MIXED_PRECISION = True
29
+ _C.GENERAL.ALLOW_CUDA = True
30
+ _C.GENERAL.PIN_MEMORY = False
31
+ _C.GENERAL.DISTRIBUTED = False
32
+ _C.GENERAL.LOCAL_RANK = 0
33
+ _C.GENERAL.USE_SYNCBN = False
34
+ _C.GENERAL.WORLD_SIZE = 1
35
+
36
+ _C.TRAIN = CN(new_allowed=True)
37
+ _C.TRAIN.NUM_EPOCHS = 100
38
+ _C.TRAIN.BATCH_SIZE = 32
39
+ _C.TRAIN.SHUFFLE = True
40
+ _C.TRAIN.WARMUP = False
41
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
42
+ _C.TRAIN.CLIP_GRAD = False
43
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
44
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
45
+
46
+ _C.DATASETS = CN(new_allowed=True)
47
+
48
+ _C.MODEL = CN(new_allowed=True)
49
+ _C.MODEL.IMAGE_SIZE = 224
50
+
51
+ _C.EXTRA = CN(new_allowed=True)
52
+ _C.EXTRA.FOCAL_LENGTH = 5000
53
+
54
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
55
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
56
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
57
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
58
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
59
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
60
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
61
+ _C.DATASETS.CONFIG.DO_FLIP = True
62
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
63
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
64
+
65
+ def default_config() -> CN:
66
+ """
67
+ Get a yacs CfgNode object with the default config values.
68
+ """
69
+ # Return a clone so that the defaults will not be altered
70
+ # This is for the "local variable" use pattern
71
+ return _C.clone()
72
+
73
+ def get_config(config_file: str, merge: bool = True) -> CN:
74
+ """
75
+ Read a config file and optionally merge it with the default config file.
76
+ Args:
77
+ config_file (str): Path to config file.
78
+ merge (bool): Whether to merge with the default config or not.
79
+ Returns:
80
+ CfgNode: Config as a yacs CfgNode object.
81
+ """
82
+ if merge:
83
+ cfg = default_config()
84
+ else:
85
+ cfg = CN(new_allowed=True)
86
+ cfg.merge_from_file(config_file)
87
+ cfg.freeze()
88
+ return cfg
hmr2/datasets/__init__.py ADDED
File without changes
hmr2/datasets/utils.py ADDED
@@ -0,0 +1,999 @@
1
+ """
2
+ Parts of the code are taken or adapted from
3
+ https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
4
+ """
5
+ import torch
6
+ import numpy as np
7
+ from skimage.transform import rotate, resize
8
+ from skimage.filters import gaussian
9
+ import random
10
+ import cv2
11
+ from typing import List, Dict, Tuple
12
+ from yacs.config import CfgNode
13
+
14
+ def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
15
+ """Increase the size of the bounding box to match the target shape."""
16
+ if target_aspect_ratio is None:
17
+ return input_shape
18
+
19
+ try:
20
+ w , h = input_shape
21
+ except (ValueError, TypeError):
22
+ return input_shape
23
+
24
+ w_t, h_t = target_aspect_ratio
25
+ if h / w < h_t / w_t:
26
+ h_new = max(w * h_t / w_t, h)
27
+ w_new = w
28
+ else:
29
+ h_new = h
30
+ w_new = max(h * w_t / h_t, w)
31
+ if h_new < h or w_new < w:
32
+ breakpoint()
33
+ return np.array([w_new, h_new])
34
+
35
+ def do_augmentation(aug_config: CfgNode) -> Tuple:
36
+ """
37
+ Compute random augmentation parameters.
38
+ Args:
39
+ aug_config (CfgNode): Config containing augmentation parameters.
40
+ Returns:
41
+ scale (float): Box rescaling factor.
42
+ rot (float): Random image rotation.
43
+ do_flip (bool): Whether to flip image or not.
44
+ do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
45
+ color_scale (List): Color rescaling factor
46
+ tx (float): Random translation along the x axis.
47
+ ty (float): Random translation along the y axis.
48
+ """
49
+
50
+ tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
51
+ ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
52
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
53
+ rot = np.clip(np.random.randn(), -2.0,
54
+ 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
55
+ do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
56
+ do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
57
+ extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
58
+ # extreme_crop_lvl = 0
59
+ c_up = 1.0 + aug_config.COLOR_SCALE
60
+ c_low = 1.0 - aug_config.COLOR_SCALE
61
+ color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
62
+ return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
63
+
64
+ def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
65
+ """
66
+ Rotate a 2D point on the x-y plane.
67
+ Args:
68
+ pt_2d (np.array): Input 2D point with shape (2,).
69
+ rot_rad (float): Rotation angle
70
+ Returns:
71
+ np.array: Rotated 2D point.
72
+ """
73
+ x = pt_2d[0]
74
+ y = pt_2d[1]
75
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
76
+ xx = x * cs - y * sn
77
+ yy = x * sn + y * cs
78
+ return np.array([xx, yy], dtype=np.float32)
79
+
80
+
81
+ def gen_trans_from_patch_cv(c_x: float, c_y: float,
82
+ src_width: float, src_height: float,
83
+ dst_width: float, dst_height: float,
84
+ scale: float, rot: float) -> np.array:
85
+ """
86
+ Create transformation matrix for the bounding box crop.
87
+ Args:
88
+ c_x (float): Bounding box center x coordinate in the original image.
89
+ c_y (float): Bounding box center y coordinate in the original image.
90
+ src_width (float): Bounding box width.
91
+ src_height (float): Bounding box height.
92
+ dst_width (float): Output box width.
93
+ dst_height (float): Output box height.
94
+ scale (float): Rescaling factor for the bounding box (augmentation).
95
+ rot (float): Random rotation applied to the box.
96
+ Returns:
97
+ trans (np.array): Target geometric transformation.
98
+ """
99
+ # augment size with scale
100
+ src_w = src_width * scale
101
+ src_h = src_height * scale
102
+ src_center = np.zeros(2)
103
+ src_center[0] = c_x
104
+ src_center[1] = c_y
105
+ # augment rotation
106
+ rot_rad = np.pi * rot / 180
107
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
108
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
109
+
110
+ dst_w = dst_width
111
+ dst_h = dst_height
112
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
113
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
114
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
115
+
116
+ src = np.zeros((3, 2), dtype=np.float32)
117
+ src[0, :] = src_center
118
+ src[1, :] = src_center + src_downdir
119
+ src[2, :] = src_center + src_rightdir
120
+
121
+ dst = np.zeros((3, 2), dtype=np.float32)
122
+ dst[0, :] = dst_center
123
+ dst[1, :] = dst_center + dst_downdir
124
+ dst[2, :] = dst_center + dst_rightdir
125
+
126
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
127
+
128
+ return trans
129
+
130
+
131
+ def trans_point2d(pt_2d: np.array, trans: np.array):
132
+ """
133
+ Transform a 2D point using translation matrix trans.
134
+ Args:
135
+ pt_2d (np.array): Input 2D point with shape (2,).
136
+ trans (np.array): Transformation matrix.
137
+ Returns:
138
+ np.array: Transformed 2D point.
139
+ """
140
+ src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
141
+ dst_pt = np.dot(trans, src_pt)
142
+ return dst_pt[0:2]
143
+
144
+ def get_transform(center, scale, res, rot=0):
145
+ """Generate transformation matrix."""
146
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
147
+ h = 200 * scale
148
+ t = np.zeros((3, 3))
149
+ t[0, 0] = float(res[1]) / h
150
+ t[1, 1] = float(res[0]) / h
151
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
152
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
153
+ t[2, 2] = 1
154
+ if not rot == 0:
155
+ rot = -rot # To match direction of rotation from cropping
156
+ rot_mat = np.zeros((3, 3))
157
+ rot_rad = rot * np.pi / 180
158
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
159
+ rot_mat[0, :2] = [cs, -sn]
160
+ rot_mat[1, :2] = [sn, cs]
161
+ rot_mat[2, 2] = 1
162
+ # Need to rotate around center
163
+ t_mat = np.eye(3)
164
+ t_mat[0, 2] = -res[1] / 2
165
+ t_mat[1, 2] = -res[0] / 2
166
+ t_inv = t_mat.copy()
167
+ t_inv[:2, 2] *= -1
168
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
169
+ return t
170
+
171
+
172
+ def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
173
+ """Transform pixel location to different reference."""
174
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
175
+ t = get_transform(center, scale, res, rot=rot)
176
+ if invert:
177
+ t = np.linalg.inv(t)
178
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
179
+ new_pt = np.dot(t, new_pt)
180
+ if as_int:
181
+ new_pt = new_pt.astype(int)
182
+ return new_pt[:2] + 1
183
+
184
+ def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
185
+ c_x = (ul[0] + br[0])/2
186
+ c_y = (ul[1] + br[1])/2
187
+ bb_width = patch_width = br[0] - ul[0]
188
+ bb_height = patch_height = br[1] - ul[1]
189
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
190
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
191
+ flags=cv2.INTER_LINEAR,
192
+ borderMode=border_mode,
193
+ borderValue=border_value
194
+ )
195
+
196
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
197
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
198
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
199
+ flags=cv2.INTER_LINEAR,
200
+ borderMode=cv2.BORDER_CONSTANT,
201
+ )
202
+
203
+ return img_patch
204
+
205
+ def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
206
+ bb_width: float, bb_height: float,
207
+ patch_width: float, patch_height: float,
208
+ do_flip: bool, scale: float, rot: float,
209
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
210
+ """
211
+ Crop image according to the supplied bounding box.
212
+ Args:
213
+ img (np.array): Input image of shape (H, W, 3)
214
+ c_x (float): Bounding box center x coordinate in the original image.
215
+ c_y (float): Bounding box center y coordinate in the original image.
216
+ bb_width (float): Bounding box width.
217
+ bb_height (float): Bounding box height.
218
+ patch_width (float): Output box width.
219
+ patch_height (float): Output box height.
220
+ do_flip (bool): Whether to flip image or not.
221
+ scale (float): Rescaling factor for the bounding box (augmentation).
222
+ rot (float): Random rotation applied to the box.
223
+ Returns:
224
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
225
+ trans (np.array): Transformation matrix.
226
+ """
227
+
228
+ img_height, img_width, img_channels = img.shape
229
+ if do_flip:
230
+ img = img[:, ::-1, :]
231
+ c_x = img_width - c_x - 1
232
+
233
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
234
+
235
+ #img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)
236
+
237
+ # skimage
238
+ center = np.zeros(2)
239
+ center[0] = c_x
240
+ center[1] = c_y
241
+ res = np.zeros(2)
242
+ res[0] = patch_width
243
+ res[1] = patch_height
244
+ # assumes bb_width = bb_height
245
+ # assumes patch_width = patch_height
246
+ assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
247
+ assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
248
+ scale1 = scale*bb_width/200.
249
+
250
+ # Upper left point
251
+ ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
252
+ # Bottom right point
253
+ br = np.array(transform([res[0] + 1,
254
+ res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1
255
+
256
+ # Padding so that when rotated proper amount of context is included
257
+ try:
258
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
259
+ except:
260
+ breakpoint()
261
+ if not rot == 0:
262
+ ul -= pad
263
+ br += pad
264
+
265
+
266
+ if False:
267
+ # Old way of cropping image
268
+ ul_int = ul.astype(int)
269
+ br_int = br.astype(int)
270
+ new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
271
+ if len(img.shape) > 2:
272
+ new_shape += [img.shape[2]]
273
+ new_img = np.zeros(new_shape)
274
+
275
+ # Range to fill new array
276
+ new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
277
+ new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
278
+ # Range to sample from original image
279
+ old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
280
+ old_y = max(0, ul_int[1]), min(len(img), br_int[1])
281
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
282
+ old_x[0]:old_x[1]]
283
+
284
+ # New way of cropping image
285
+ new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)
286
+
287
+ # print(f'{new_img.shape=}')
288
+ # print(f'{new_img1.shape=}')
289
+ # print(f'{np.allclose(new_img, new_img1)=}')
290
+ # print(f'{img.dtype=}')
291
+
292
+
293
+ if not rot == 0:
294
+ # Remove padding
295
+
296
+ new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
297
+ new_img = new_img[pad:-pad, pad:-pad]
298
+
299
+ if new_img.shape[0] < 1 or new_img.shape[1] < 1:
300
+ print(f'{img.shape=}')
301
+ print(f'{new_img.shape=}')
302
+ print(f'{ul=}')
303
+ print(f'{br=}')
304
+ print(f'{pad=}')
305
+ print(f'{rot=}')
306
+
307
+ breakpoint()
308
+
309
+ # resize image
310
+ new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
311
+
312
+ new_img = np.clip(new_img, 0, 255).astype(np.uint8)
313
+
314
+ return new_img, trans
315
+
316
+
317
+ def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
318
+ bb_width: float, bb_height: float,
319
+ patch_width: float, patch_height: float,
320
+ do_flip: bool, scale: float, rot: float,
321
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
322
+ """
323
+ Crop the input image and return the crop and the corresponding transformation matrix.
324
+ Args:
325
+ img (np.array): Input image of shape (H, W, 3)
326
+ c_x (float): Bounding box center x coordinate in the original image.
327
+ c_y (float): Bounding box center y coordinate in the original image.
328
+ bb_width (float): Bounding box width.
329
+ bb_height (float): Bounding box height.
330
+ patch_width (float): Output box width.
331
+ patch_height (float): Output box height.
332
+ do_flip (bool): Whether to flip image or not.
333
+ scale (float): Rescaling factor for the bounding box (augmentation).
334
+ rot (float): Random rotation applied to the box.
335
+ Returns:
336
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
337
+ trans (np.array): Transformation matrix.
338
+ """
339
+
340
+ img_height, img_width, img_channels = img.shape
341
+ if do_flip:
342
+ img = img[:, ::-1, :]
343
+ c_x = img_width - c_x - 1
344
+
345
+
346
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
347
+
348
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
349
+ flags=cv2.INTER_LINEAR,
350
+ borderMode=border_mode,
351
+ borderValue=border_value,
352
+ )
353
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
354
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
355
+ img_patch[:,:,3] = cv2.warpAffine(img[:,:,3], trans, (int(patch_width), int(patch_height)),
356
+ flags=cv2.INTER_LINEAR,
357
+ borderMode=cv2.BORDER_CONSTANT,
358
+ )
359
+
360
+ return img_patch, trans
361
+
362
+
363
+ def convert_cvimg_to_tensor(cvimg: np.array):
364
+ """
365
+ Convert image from HWC to CHW format.
366
+ Args:
367
+ cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV.
368
+ Returns:
369
+ np.array: Output image of shape (3, H, W).
370
+ """
371
+ # from h,w,c(OpenCV) to c,h,w
372
+ img = cvimg.copy()
373
+ img = np.transpose(img, (2, 0, 1))
374
+ # from int to float
375
+ img = img.astype(np.float32)
376
+ return img
377
+
378
+ def fliplr_params(smpl_params: Dict, has_smpl_params: Dict) -> Tuple[Dict, Dict]:
379
+ """
380
+ Flip SMPL parameters when flipping the image.
381
+ Args:
382
+ smpl_params (Dict): SMPL parameter annotations.
383
+ has_smpl_params (Dict): Whether SMPL annotations are valid.
384
+ Returns:
385
+ Dict, Dict: Flipped SMPL parameters and valid flags.
386
+ """
387
+ global_orient = smpl_params['global_orient'].copy()
388
+ body_pose = smpl_params['body_pose'].copy()
389
+ betas = smpl_params['betas'].copy()
390
+ has_global_orient = has_smpl_params['global_orient'].copy()
391
+ has_body_pose = has_smpl_params['body_pose'].copy()
392
+ has_betas = has_smpl_params['betas'].copy()
393
+
394
+ body_pose_permutation = [6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
395
+ 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
396
+ 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
397
+ 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
398
+ 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
399
+ body_pose_permutation = body_pose_permutation[:len(body_pose)]
400
+ body_pose_permutation = [i-3 for i in body_pose_permutation]
401
+
402
+ body_pose = body_pose[body_pose_permutation]
403
+
404
+ global_orient[1::3] *= -1
405
+ global_orient[2::3] *= -1
406
+ body_pose[1::3] *= -1
407
+ body_pose[2::3] *= -1
408
+
409
+ smpl_params = {'global_orient': global_orient.astype(np.float32),
410
+ 'body_pose': body_pose.astype(np.float32),
411
+ 'betas': betas.astype(np.float32)
412
+ }
413
+
414
+ has_smpl_params = {'global_orient': has_global_orient,
415
+ 'body_pose': has_body_pose,
416
+ 'betas': has_betas
417
+ }
418
+
419
+ return smpl_params, has_smpl_params
420
+
421
+
422
+ def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
423
+ """
424
+ Flip 2D or 3D keypoints.
425
+ Args:
426
+ joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
427
+ flip_permutation (List): Permutation to apply after flipping.
428
+ Returns:
429
+ np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
430
+ """
431
+ joints = joints.copy()
432
+ # Flip horizontal
433
+ joints[:, 0] = width - joints[:, 0] - 1
434
+ joints = joints[flip_permutation, :]
435
+
436
+ return joints
437
+
438
+ def keypoint_3d_processing(keypoints_3d: np.array, flip_permutation: List[int], rot: float, do_flip: float) -> np.array:
439
+ """
440
+ Process 3D keypoints (rotation/flipping).
441
+ Args:
442
+ keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
443
+ flip_permutation (List): Permutation to apply after flipping.
444
+ rot (float): Random rotation applied to the keypoints.
445
+ do_flip (bool): Whether to flip keypoints or not.
446
+ Returns:
447
+ np.array: Transformed 3D keypoints with shape (N, 4).
448
+ """
449
+ if do_flip:
450
+ keypoints_3d = fliplr_keypoints(keypoints_3d, 1, flip_permutation)
451
+ # in-plane rotation
452
+ rot_mat = np.eye(3)
453
+ if not rot == 0:
454
+ rot_rad = -rot * np.pi / 180
455
+ sn,cs = np.sin(rot_rad), np.cos(rot_rad)
456
+ rot_mat[0,:2] = [cs, -sn]
457
+ rot_mat[1,:2] = [sn, cs]
458
+ keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
459
+ # flip the x coordinates
460
+ keypoints_3d = keypoints_3d.astype('float32')
461
+ return keypoints_3d
462
+
463
+ def rot_aa(aa: np.array, rot: float) -> np.array:
464
+ """
465
+ Rotate axis angle parameters.
466
+ Args:
467
+ aa (np.array): Axis-angle vector of shape (3,).
468
+ rot (np.array): Rotation angle in degrees.
469
+ Returns:
470
+ np.array: Rotated axis-angle vector.
471
+ """
472
+ # pose parameters
473
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
474
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
475
+ [0, 0, 1]])
476
+ # find the rotation of the body in camera frame
477
+ per_rdg, _ = cv2.Rodrigues(aa)
478
+ # apply the global rotation to the global orientation
479
+ resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
480
+ aa = (resrot.T)[0]
481
+ return aa.astype(np.float32)
482
+
483
+ def smpl_param_processing(smpl_params: Dict, has_smpl_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
484
+ """
485
+ Apply random augmentations to the SMPL parameters.
486
+ Args:
487
+ smpl_params (Dict): SMPL parameter annotations.
488
+ has_smpl_params (Dict): Whether SMPL annotations are valid.
489
+ rot (float): Random rotation applied to the keypoints.
490
+ do_flip (bool): Whether to flip keypoints or not.
491
+ Returns:
492
+ Dict, Dict: Transformed SMPL parameters and valid flags.
493
+ """
494
+ if do_flip:
495
+ smpl_params, has_smpl_params = fliplr_params(smpl_params, has_smpl_params)
496
+ smpl_params['global_orient'] = rot_aa(smpl_params['global_orient'], rot)
497
+ return smpl_params, has_smpl_params
498
+
499
+
500
+
501
+ def get_example(img_path: str|np.ndarray, center_x: float, center_y: float,
502
+ width: float, height: float,
503
+ keypoints_2d: np.array, keypoints_3d: np.array,
504
+ smpl_params: Dict, has_smpl_params: Dict,
505
+ flip_kp_permutation: List[int],
506
+ patch_width: int, patch_height: int,
507
+ mean: np.array, std: np.array,
508
+ do_augment: bool, augm_config: CfgNode,
509
+ is_bgr: bool = True,
510
+ use_skimage_antialias: bool = False,
511
+ border_mode: int = cv2.BORDER_CONSTANT,
512
+ return_trans: bool = False) -> Tuple:
513
+ """
514
+ Get an example from the dataset and (possibly) apply random augmentations.
515
+ Args:
516
+ img_path (str): Image filename
517
+ center_x (float): Bounding box center x coordinate in the original image.
518
+ center_y (float): Bounding box center y coordinate in the original image.
519
+ width (float): Bounding box width.
520
+ height (float): Bounding box height.
521
+ keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
522
+ keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
523
+ smpl_params (Dict): SMPL parameter annotations.
524
+ has_smpl_params (Dict): Whether SMPL annotations are valid.
525
+ flip_kp_permutation (List): Permutation to apply to the keypoints after flipping.
526
+ patch_width (float): Output box width.
527
+ patch_height (float): Output box height.
528
+ mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
529
+ std (np.array): Array of shape (3,) containing the std for normalizing the input image.
530
+ do_augment (bool): Whether to apply data augmentation or not.
531
+ aug_config (CfgNode): Config containing augmentation parameters.
532
+ Returns:
533
+ return img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size
534
+ img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_height)
535
+ keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
536
+ keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
537
+ smpl_params (Dict): Transformed SMPL parameters.
538
+ has_smpl_params (Dict): Valid flag for transformed SMPL parameters.
539
+ img_size (np.array): Image size of the original image.
540
+ """
541
+ if isinstance(img_path, str):
542
+ # 1. load image
543
+ cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
544
+ if not isinstance(cvimg, np.ndarray):
545
+ raise IOError("Fail to read %s" % img_path)
546
+ elif isinstance(img_path, np.ndarray):
547
+ cvimg = img_path
548
+ else:
549
+ raise TypeError('img_path must be either a string or a numpy array')
550
+ img_height, img_width, img_channels = cvimg.shape
551
+
552
+ img_size = np.array([img_height, img_width])
553
+
554
+ # 2. get augmentation params
555
+ if do_augment:
556
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
557
+ else:
558
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0.
559
+
560
+ if width < 1 or height < 1:
561
+ breakpoint()
562
+
563
+ if do_extreme_crop:
564
+ if extreme_crop_lvl == 0:
565
+ center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
566
+ elif extreme_crop_lvl == 1:
567
+ center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d)
568
+
569
+ THRESH = 4
570
+ if width1 < THRESH or height1 < THRESH:
571
+ # print(f'{do_extreme_crop=}')
572
+ # print(f'width: {width}, height: {height}')
573
+ # print(f'width1: {width1}, height1: {height1}')
574
+ # print(f'center_x: {center_x}, center_y: {center_y}')
575
+ # print(f'center_x1: {center_x1}, center_y1: {center_y1}')
576
+ # print(f'keypoints_2d: {keypoints_2d}')
577
+ # print(f'\n\n', flush=True)
578
+ # breakpoint()
579
+ pass
580
+ # print(f'skip ==> width1: {width1}, height1: {height1}, width: {width}, height: {height}')
581
+ else:
582
+ center_x, center_y, width, height = center_x1, center_y1, width1, height1
583
+
584
+ center_x += width * tx
585
+ center_y += height * ty
586
+
587
+ # Process 3D keypoints
588
+ keypoints_3d = keypoint_3d_processing(keypoints_3d, flip_kp_permutation, rot, do_flip)
589
+
590
+ # 3. generate image patch
591
+ if use_skimage_antialias:
592
+ # Blur image to avoid aliasing artifacts
593
+ downsampling_factor = (patch_width / (width*scale))
594
+ if downsampling_factor > 1.1:
595
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True, truncate=3.0)
596
+
597
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
598
+ center_x, center_y,
599
+ width, height,
600
+ patch_width, patch_height,
601
+ do_flip, scale, rot,
602
+ border_mode=border_mode)
603
+ # img_patch_cv, trans = generate_image_patch_skimage(cvimg,
604
+ # center_x, center_y,
605
+ # width, height,
606
+ # patch_width, patch_height,
607
+ # do_flip, scale, rot,
608
+ # border_mode=border_mode)
609
+
610
+ image = img_patch_cv.copy()
611
+ if is_bgr:
612
+ image = image[:, :, ::-1]
613
+ img_patch_cv = image.copy()
614
+ img_patch = convert_cvimg_to_tensor(image)
615
+
616
+
617
+ smpl_params, has_smpl_params = smpl_param_processing(smpl_params, has_smpl_params, rot, do_flip)
618
+
619
+ # apply normalization
620
+ for n_c in range(min(img_channels, 3)):
621
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
622
+ if mean is not None and std is not None:
623
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
624
+ if do_flip:
625
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, flip_kp_permutation)
626
+
627
+
628
+ for n_jt in range(len(keypoints_2d)):
629
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
630
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
631
+
632
+ if not return_trans:
633
+ return img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size
634
+ else:
635
+ return img_patch, keypoints_2d, keypoints_3d, smpl_params, has_smpl_params, img_size, trans
636
+
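For reference, a minimal sketch of calling get_example at inference time with augmentation disabled. The 44-keypoint layout (25 OpenPose joints plus extra joints), the 256x256 patch size, the ImageNet mean/std, and the SMPL parameter shapes are assumptions, not values taken from this file.

from hmr2.datasets.utils import get_example
import numpy as np

img = np.zeros((480, 640, 3), dtype=np.uint8)                 # dummy BGR image
keypoints_2d = np.zeros((44, 3), dtype=np.float32)            # (x, y, confidence)
keypoints_3d = np.zeros((44, 4), dtype=np.float32)            # (x, y, z, valid)
smpl_params = {'global_orient': np.zeros(3, dtype=np.float32),
               'body_pose': np.zeros(69, dtype=np.float32),
               'betas': np.zeros(10, dtype=np.float32)}
has_smpl_params = {k: np.array(0., dtype=np.float32) for k in smpl_params}

patch, kp2d, kp3d, smpl_params, has_smpl_params, img_size = get_example(
    img, center_x=320., center_y=240., width=200., height=200.,
    keypoints_2d=keypoints_2d, keypoints_3d=keypoints_3d,
    smpl_params=smpl_params, has_smpl_params=has_smpl_params,
    flip_kp_permutation=list(range(44)),
    patch_width=256, patch_height=256,
    mean=255. * np.array([0.485, 0.456, 0.406]),
    std=255. * np.array([0.229, 0.224, 0.225]),
    do_augment=False, augm_config=None)
# patch: cropped, normalized patch of shape (3, 256, 256); kp2d is rescaled to roughly [-0.5, 0.5]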
637
+ def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
638
+ """
639
+ Extreme cropping: Crop the box up to the hip locations.
640
+ Args:
641
+ center_x (float): x coordinate of the bounding box center.
642
+ center_y (float): y coordinate of the bounding box center.
643
+ width (float): Bounding box width.
644
+ height (float): Bounding box height.
645
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
646
+ Returns:
647
+ center_x (float): x coordinate of the new bounding box center.
648
+ center_y (float): y coordinate of the new bounding box center.
649
+ width (float): New bounding box width.
650
+ height (float): New bounding box height.
651
+ """
652
+ keypoints_2d = keypoints_2d.copy()
653
+ lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5]
654
+ keypoints_2d[lower_body_keypoints, :] = 0
655
+ if keypoints_2d[:, -1].sum() > 1:
656
+ center, scale = get_bbox(keypoints_2d)
657
+ center_x = center[0]
658
+ center_y = center[1]
659
+ width = 1.1 * scale[0]
660
+ height = 1.1 * scale[1]
661
+ return center_x, center_y, width, height
662
+
663
+
664
+ def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
665
+ """
666
+ Extreme cropping: Crop the box up to the shoulder locations.
667
+ Args:
668
+ center_x (float): x coordinate of the bounding box center.
669
+ center_y (float): y coordinate of the bounding box center.
670
+ width (float): Bounding box width.
671
+ height (float): Bounding box height.
672
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
673
+ Returns:
674
+ center_x (float): x coordinate of the new bounding box center.
675
+ center_y (float): y coordinate of the new bounding box center.
676
+ width (float): New bounding box width.
677
+ height (float): New bounding box height.
678
+ """
679
+ keypoints_2d = keypoints_2d.copy()
680
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
681
+ keypoints_2d[lower_body_keypoints, :] = 0
682
683
+ if keypoints_2d[:, -1].sum() > 1:
684
+ center, scale = get_bbox(keypoints_2d)
685
+ center_x = center[0]
686
+ center_y = center[1]
687
+ width = 1.2 * scale[0]
688
+ height = 1.2 * scale[1]
689
+ return center_x, center_y, width, height
690
+
691
+ def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
692
+ """
693
+ Extreme cropping: Crop the box and keep only the head.
694
+ Args:
695
+ center_x (float): x coordinate of the bounding box center.
696
+ center_y (float): y coordinate of the bounding box center.
697
+ width (float): Bounding box width.
698
+ height (float): Bounding box height.
699
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
700
+ Returns:
701
+ center_x (float): x coordinate of the new bounding box center.
702
+ center_y (float): y coordinate of the new bounding box center.
703
+ width (float): New bounding box width.
704
+ height (float): New bounding box height.
705
+ """
706
+ keypoints_2d = keypoints_2d.copy()
707
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]]
708
+ keypoints_2d[lower_body_keypoints, :] = 0
709
+ if keypoints_2d[:, -1].sum() > 1:
710
+ center, scale = get_bbox(keypoints_2d)
711
+ center_x = center[0]
712
+ center_y = center[1]
713
+ width = 1.3 * scale[0]
714
+ height = 1.3 * scale[1]
715
+ return center_x, center_y, width, height
716
+
717
+ def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
718
+ """
719
+ Extreme cropping: Crop the box and keep only the torso.
720
+ Args:
721
+ center_x (float): x coordinate of the bounding box center.
722
+ center_y (float): y coordinate of the bounding box center.
723
+ width (float): Bounding box width.
724
+ height (float): Bounding box height.
725
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
726
+ Returns:
727
+ center_x (float): x coordinate of the new bounding box center.
728
+ center_y (float): y coordinate of the new bounding box center.
729
+ width (float): New bounding box width.
730
+ height (float): New bounding box height.
731
+ """
732
+ keypoints_2d = keypoints_2d.copy()
733
+ nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]]
734
+ keypoints_2d[nontorso_body_keypoints, :] = 0
735
+ if keypoints_2d[:, -1].sum() > 1:
736
+ center, scale = get_bbox(keypoints_2d)
737
+ center_x = center[0]
738
+ center_y = center[1]
739
+ width = 1.1 * scale[0]
740
+ height = 1.1 * scale[1]
741
+ return center_x, center_y, width, height
742
+
743
+ def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
744
+ """
745
+ Extreme cropping: Crop the box and keep only the right arm.
746
+ Args:
747
+ center_x (float): x coordinate of the bounding box center.
748
+ center_y (float): y coordinate of the bounding box center.
749
+ width (float): Bounding box width.
750
+ height (float): Bounding box height.
751
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
752
+ Returns:
753
+ center_x (float): x coordinate of the new bounding box center.
754
+ center_y (float): y coordinate of the new bounding box center.
755
+ width (float): New bounding box width.
756
+ height (float): New bounding box height.
757
+ """
758
+ keypoints_2d = keypoints_2d.copy()
759
+ nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
760
+ keypoints_2d[nonrightarm_body_keypoints, :] = 0
761
+ if keypoints_2d[:, -1].sum() > 1:
762
+ center, scale = get_bbox(keypoints_2d)
763
+ center_x = center[0]
764
+ center_y = center[1]
765
+ width = 1.1 * scale[0]
766
+ height = 1.1 * scale[1]
767
+ return center_x, center_y, width, height
768
+
769
+ def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
770
+ """
771
+ Extreme cropping: Crop the box and keep only the left arm.
772
+ Args:
773
+ center_x (float): x coordinate of the bounding box center.
774
+ center_y (float): y coordinate of the bounding box center.
775
+ width (float): Bounding box width.
776
+ height (float): Bounding box height.
777
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
778
+ Returns:
779
+ center_x (float): x coordinate of the new bounding box center.
780
+ center_y (float): y coordinate of the new bounding box center.
781
+ width (float): New bounding box width.
782
+ height (float): New bounding box height.
783
+ """
784
+ keypoints_2d = keypoints_2d.copy()
785
+ nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
786
+ keypoints_2d[nonleftarm_body_keypoints, :] = 0
787
+ if keypoints_2d[:, -1].sum() > 1:
788
+ center, scale = get_bbox(keypoints_2d)
789
+ center_x = center[0]
790
+ center_y = center[1]
791
+ width = 1.1 * scale[0]
792
+ height = 1.1 * scale[1]
793
+ return center_x, center_y, width, height
794
+
795
+ def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
796
+ """
797
+ Extreme cropping: Crop the box and keep only the legs.
798
+ Args:
799
+ center_x (float): x coordinate of the bounding box center.
800
+ center_y (float): y coordinate of the bounding box center.
801
+ width (float): Bounding box width.
802
+ height (float): Bounding box height.
803
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
804
+ Returns:
805
+ center_x (float): x coordinate of the new bounding box center.
806
+ center_y (float): y coordinate of the new bounding box center.
807
+ width (float): New bounding box width.
808
+ height (float): New bounding box height.
809
+ """
810
+ keypoints_2d = keypoints_2d.copy()
811
+ nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
812
+ keypoints_2d[nonlegs_body_keypoints, :] = 0
813
+ if keypoints_2d[:, -1].sum() > 1:
814
+ center, scale = get_bbox(keypoints_2d)
815
+ center_x = center[0]
816
+ center_y = center[1]
817
+ width = 1.1 * scale[0]
818
+ height = 1.1 * scale[1]
819
+ return center_x, center_y, width, height
820
+
821
+ def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
822
+ """
823
+ Extreme cropping: Crop the box and keep only the right leg.
824
+ Args:
825
+ center_x (float): x coordinate of the bounding box center.
826
+ center_y (float): y coordinate of the bounding box center.
827
+ width (float): Bounding box width.
828
+ height (float): Bounding box height.
829
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
830
+ Returns:
831
+ center_x (float): x coordinate of the new bounding box center.
832
+ center_y (float): y coordinate of the new bounding box center.
833
+ width (float): New bounding box width.
834
+ height (float): New bounding box height.
835
+ """
836
+ keypoints_2d = keypoints_2d.copy()
837
+ nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
838
+ keypoints_2d[nonrightleg_body_keypoints, :] = 0
839
+ if keypoints_2d[:, -1].sum() > 1:
840
+ center, scale = get_bbox(keypoints_2d)
841
+ center_x = center[0]
842
+ center_y = center[1]
843
+ width = 1.1 * scale[0]
844
+ height = 1.1 * scale[1]
845
+ return center_x, center_y, width, height
846
+
847
+ def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
848
+ """
849
+ Extreme cropping: Crop the box and keep only the left leg.
850
+ Args:
851
+ center_x (float): x coordinate of the bounding box center.
852
+ center_y (float): y coordinate of the bounding box center.
853
+ width (float): Bounding box width.
854
+ height (float): Bounding box height.
855
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
856
+ Returns:
857
+ center_x (float): x coordinate of the new bounding box center.
858
+ center_y (float): y coordinate of the new bounding box center.
859
+ width (float): New bounding box width.
860
+ height (float): New bounding box height.
861
+ """
862
+ keypoints_2d = keypoints_2d.copy()
863
+ nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
864
+ keypoints_2d[nonleftleg_body_keypoints, :] = 0
865
+ if keypoints_2d[:, -1].sum() > 1:
866
+ center, scale = get_bbox(keypoints_2d)
867
+ center_x = center[0]
868
+ center_y = center[1]
869
+ width = 1.1 * scale[0]
870
+ height = 1.1 * scale[1]
871
+ return center_x, center_y, width, height
872
+
873
+ def full_body(keypoints_2d: np.array) -> bool:
874
+ """
875
+ Check if all main body joints are visible.
876
+ Args:
877
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
878
+ Returns:
879
+ bool: True if all main body joints are visible.
880
+ """
881
+
882
+ body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
883
+ body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
884
+ return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(body_keypoints)
885
+
886
+ def upper_body(keypoints_2d: np.array):
887
+ """
888
+ Check if all upper body joints are visible.
889
+ Args:
890
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
891
+ Returns:
892
+ bool: True if no lower body joints are visible and at least two upper body joints are.
893
+ """
894
+ lower_body_keypoints_openpose = [10, 11, 13, 14]
895
+ lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
896
+ upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
897
+ upper_body_keypoints = [25+8, 25+9, 25+12, 25+13, 25+17, 25+18]
898
+ return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0)\
899
+ and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
900
+
901
+ def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
902
+ """
903
+ Get center and scale for bounding box from openpose detections.
904
+ Args:
905
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
906
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
907
+ Returns:
908
+ center (np.array): Array of shape (2,) containing the new bounding box center.
909
+ scale (np.array): Array of shape (2,) containing the new bounding box size.
910
+ """
911
+ valid = keypoints_2d[:,-1] > 0
912
+ valid_keypoints = keypoints_2d[valid][:,:-1]
913
+ center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
914
+ bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
915
+ # adjust bounding box tightness
916
+ scale = bbox_size
917
+ scale *= rescale
918
+ return center, scale
919
+
920
+ def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
921
+ """
922
+ Perform extreme cropping
923
+ Args:
924
+ center_x (float): x coordinate of bounding box center.
925
+ center_y (float): y coordinate of bounding box center.
926
+ width (float): bounding box width.
927
+ height (float): bounding box height.
928
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
929
930
+ Returns:
931
+ center_x (float): x coordinate of bounding box center.
932
+ center_y (float): y coordinate of bounding box center.
933
+ width (float): bounding box width.
934
+ height (float): bounding box height.
935
+ """
936
+ p = torch.rand(1).item()
937
+ if full_body(keypoints_2d):
938
+ if p < 0.7:
939
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
940
+ elif p < 0.9:
941
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
942
+ else:
943
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
944
+ elif upper_body(keypoints_2d):
945
+ if p < 0.9:
946
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
947
+ else:
948
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
949
+
950
+ return center_x, center_y, max(width, height), max(width, height)
951
+
952
+ def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
953
+ """
954
+ Perform aggressive extreme cropping
955
+ Args:
956
+ center_x (float): x coordinate of bounding box center.
957
+ center_y (float): y coordinate of bounding box center.
958
+ width (float): bounding box width.
959
+ height (float): bounding box height.
960
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
961
962
+ Returns:
963
+ center_x (float): x coordinate of bounding box center.
964
+ center_y (float): y coordinate of bounding box center.
965
+ width (float): bounding box width.
966
+ height (float): bounding box height.
967
+ """
968
+ p = torch.rand(1).item()
969
+ if full_body(keypoints_2d):
970
+ if p < 0.2:
971
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
972
+ elif p < 0.3:
973
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
974
+ elif p < 0.4:
975
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
976
+ elif p < 0.5:
977
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
978
+ elif p < 0.6:
979
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
980
+ elif p < 0.7:
981
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
982
+ elif p < 0.8:
983
+ center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
984
+ elif p < 0.9:
985
+ center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
986
+ else:
987
+ center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
988
+ elif upper_body(keypoints_2d):
989
+ if p < 0.2:
990
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
991
+ elif p < 0.4:
992
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
993
+ elif p < 0.6:
994
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
995
+ elif p < 0.8:
996
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
997
+ else:
998
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
999
+ return center_x, center_y, max(width, height), max(width, height)
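The cropping helpers above are composed by extreme_cropping / extreme_cropping_aggressive during augmentation. A small illustrative sketch (the 44-keypoint layout is an assumption): with every joint marked visible, full_body() is satisfied and a random part crop is returned as a square box.

from hmr2.datasets.utils import extreme_cropping_aggressive
import numpy as np
import torch

torch.manual_seed(0)                          # the crop choice is drawn with torch.rand
kp = np.zeros((44, 3), dtype=np.float32)
kp[:, :2] = np.random.rand(44, 2) * 200.0     # random joint locations inside a 200x200 box
kp[:, 2] = 1.0                                # mark every joint as confidently detected
cx, cy, w, h = extreme_cropping_aggressive(100., 100., 200., 200., kp)
print(cx, cy, w, h)                           # square crop around a randomly chosen body part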
hmr2/datasets/vitdet_dataset.py ADDED
@@ -0,0 +1,89 @@
1
+ from typing import Dict
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from skimage.filters import gaussian
6
+ from yacs.config import CfgNode
7
+ import torch
8
+
9
+ from .utils import (convert_cvimg_to_tensor,
10
+ expand_to_aspect_ratio,
11
+ generate_image_patch_cv2)
12
+
13
+ DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
14
+ DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
15
+
16
+ class ViTDetDataset(torch.utils.data.Dataset):
17
+
18
+ def __init__(self,
19
+ cfg: CfgNode,
20
+ img_cv2: np.array,
21
+ boxes: np.array,
22
+ train: bool = False,
23
+ **kwargs):
24
+ super().__init__()
25
+ self.cfg = cfg
26
+ self.img_cv2 = img_cv2
27
+ # self.boxes = boxes
28
+
29
+ assert train == False, "ViTDetDataset is only for inference"
30
+ self.train = train
31
+ self.img_size = cfg.MODEL.IMAGE_SIZE
32
+ self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
33
+ self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
34
+
35
+ # Preprocess annotations
36
+ boxes = boxes.astype(np.float32)
37
+ self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
38
+ self.scale = (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
39
+ self.personid = np.arange(len(boxes), dtype=np.int32)
40
+
41
+ def __len__(self) -> int:
42
+ return len(self.personid)
43
+
44
+ def __getitem__(self, idx: int) -> Dict[str, np.array]:
45
+
46
+ center = self.center[idx].copy()
47
+ center_x = center[0]
48
+ center_y = center[1]
49
+
50
+ scale = self.scale[idx]
51
+ BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
52
+ bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()
53
+
54
+ patch_width = patch_height = self.img_size
55
+
56
+ # 3. generate image patch
57
+ # if use_skimage_antialias:
58
+ cvimg = self.img_cv2.copy()
59
+ if True:
60
+ # Blur image to avoid aliasing artifacts
61
+ downsampling_factor = ((bbox_size*1.0) / patch_width)
62
+ print(f'{downsampling_factor=}')
63
+ downsampling_factor = downsampling_factor / 2.0
64
+ if downsampling_factor > 1.1:
65
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True)
66
+
67
+
68
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
69
+ center_x, center_y,
70
+ bbox_size, bbox_size,
71
+ patch_width, patch_height,
72
+ False, 1.0, 0,
73
+ border_mode=cv2.BORDER_CONSTANT)
74
+ img_patch_cv = img_patch_cv[:, :, ::-1]
75
+ img_patch = convert_cvimg_to_tensor(img_patch_cv)
76
+
77
+ # apply normalization
78
+ for n_c in range(min(self.img_cv2.shape[2], 3)):
79
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
80
+
81
+
82
+ item = {
83
+ 'img': img_patch,
84
+ 'personid': int(self.personid[idx]),
85
+ }
86
+ item['box_center'] = self.center[idx].copy()
87
+ item['box_size'] = bbox_size
88
+ item['img_size'] = 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]])
89
+ return item
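A short usage sketch. Here model_cfg (a yacs CfgNode providing MODEL.IMAGE_SIZE, MODEL.IMAGE_MEAN and MODEL.IMAGE_STD), the detector output boxes, and the downstream model call are assumptions about how the demo wires things together.

import cv2
import numpy as np
import torch
from hmr2.datasets.vitdet_dataset import ViTDetDataset

img_cv2 = cv2.imread('example.jpg')                              # BGR image, as loaded by OpenCV
boxes = np.array([[50., 40., 210., 400.]], dtype=np.float32)     # (N, 4) person boxes as x1, y1, x2, y2
dataset = ViTDetDataset(model_cfg, img_cv2, boxes)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
for batch in loader:
    out = model(batch)   # hypothetical HMR2 forward pass on the normalized crops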
hmr2/models/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .smpl_wrapper import SMPL
2
+ from .hmr2 import HMR2
3
+ from .discriminator import Discriminator
hmr2/models/backbones/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .vit import vit
2
+
3
+ def create_backbone(cfg):
4
+ if cfg.MODEL.BACKBONE.TYPE == 'vit':
5
+ return vit(cfg)
6
+ else:
7
+ raise NotImplementedError('Backbone type is not implemented')
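A minimal sketch of the config node this factory expects (the field name is taken from the check above); any other backbone type raises.

from yacs.config import CfgNode
from hmr2.models.backbones import create_backbone

cfg = CfgNode()
cfg.MODEL = CfgNode()
cfg.MODEL.BACKBONE = CfgNode()
cfg.MODEL.BACKBONE.TYPE = 'vit'
backbone = create_backbone(cfg)   # returns the ViT-H/16 defined in vit.py below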
hmr2/models/backbones/vit.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ from functools import partial
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint as checkpoint
9
+
10
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
11
+
12
+ def vit(cfg):
13
+ return ViT(
14
+ img_size=(256, 192),
15
+ patch_size=16,
16
+ embed_dim=1280,
17
+ depth=32,
18
+ num_heads=16,
19
+ ratio=1,
20
+ use_checkpoint=False,
21
+ mlp_ratio=4,
22
+ qkv_bias=True,
23
+ drop_path_rate=0.55,
24
+ )
25
+
26
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
27
+ """
28
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
29
+ dimension for the original embeddings.
30
+ Args:
31
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
32
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
33
+ hw (Tuple): size of input image tokens.
34
+
35
+ Returns:
36
+ Absolute positional embeddings after processing with shape (1, H, W, C)
37
+ """
38
+ cls_token = None
39
+ B, L, C = abs_pos.shape
40
+ if has_cls_token:
41
+ cls_token = abs_pos[:, 0:1]
42
+ abs_pos = abs_pos[:, 1:]
43
+
44
+ if ori_h != h or ori_w != w:
45
+ new_abs_pos = F.interpolate(
46
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
47
+ size=(h, w),
48
+ mode="bicubic",
49
+ align_corners=False,
50
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
51
+
52
+ else:
53
+ new_abs_pos = abs_pos
54
+
55
+ if cls_token is not None:
56
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
57
+ return new_abs_pos
58
+
59
+ class DropPath(nn.Module):
60
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
61
+ """
62
+ def __init__(self, drop_prob=None):
63
+ super(DropPath, self).__init__()
64
+ self.drop_prob = drop_prob
65
+
66
+ def forward(self, x):
67
+ return drop_path(x, self.drop_prob, self.training)
68
+
69
+ def extra_repr(self):
70
+ return 'p={}'.format(self.drop_prob)
71
+
72
+ class Mlp(nn.Module):
73
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
74
+ super().__init__()
75
+ out_features = out_features or in_features
76
+ hidden_features = hidden_features or in_features
77
+ self.fc1 = nn.Linear(in_features, hidden_features)
78
+ self.act = act_layer()
79
+ self.fc2 = nn.Linear(hidden_features, out_features)
80
+ self.drop = nn.Dropout(drop)
81
+
82
+ def forward(self, x):
83
+ x = self.fc1(x)
84
+ x = self.act(x)
85
+ x = self.fc2(x)
86
+ x = self.drop(x)
87
+ return x
88
+
89
+ class Attention(nn.Module):
90
+ def __init__(
91
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
92
+ proj_drop=0., attn_head_dim=None,):
93
+ super().__init__()
94
+ self.num_heads = num_heads
95
+ head_dim = dim // num_heads
96
+ self.dim = dim
97
+
98
+ if attn_head_dim is not None:
99
+ head_dim = attn_head_dim
100
+ all_head_dim = head_dim * self.num_heads
101
+
102
+ self.scale = qk_scale or head_dim ** -0.5
103
+
104
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
105
+
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(all_head_dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ def forward(self, x):
111
+ B, N, C = x.shape
112
+ qkv = self.qkv(x)
113
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
114
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
115
+
116
+ q = q * self.scale
117
+ attn = (q @ k.transpose(-2, -1))
118
+
119
+ attn = attn.softmax(dim=-1)
120
+ attn = self.attn_drop(attn)
121
+
122
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
123
+ x = self.proj(x)
124
+ x = self.proj_drop(x)
125
+
126
+ return x
127
+
128
+ class Block(nn.Module):
129
+
130
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
131
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
132
+ norm_layer=nn.LayerNorm, attn_head_dim=None
133
+ ):
134
+ super().__init__()
135
+
136
+ self.norm1 = norm_layer(dim)
137
+ self.attn = Attention(
138
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
139
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
140
+ )
141
+
142
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
143
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
144
+ self.norm2 = norm_layer(dim)
145
+ mlp_hidden_dim = int(dim * mlp_ratio)
146
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
147
+
148
+ def forward(self, x):
149
+ x = x + self.drop_path(self.attn(self.norm1(x)))
150
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
151
+ return x
152
+
153
+
154
+ class PatchEmbed(nn.Module):
155
+ """ Image to Patch Embedding
156
+ """
157
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
158
+ super().__init__()
159
+ img_size = to_2tuple(img_size)
160
+ patch_size = to_2tuple(patch_size)
161
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
162
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
163
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
164
+ self.img_size = img_size
165
+ self.patch_size = patch_size
166
+ self.num_patches = num_patches
167
+
168
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio//2-1))
169
+
170
+ def forward(self, x, **kwargs):
171
+ B, C, H, W = x.shape
172
+ x = self.proj(x)
173
+ Hp, Wp = x.shape[2], x.shape[3]
174
+
175
+ x = x.flatten(2).transpose(1, 2)
176
+ return x, (Hp, Wp)
177
+
178
+
179
+ class HybridEmbed(nn.Module):
180
+ """ CNN Feature Map Embedding
181
+ Extract feature map from CNN, flatten, project to embedding dim.
182
+ """
183
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
184
+ super().__init__()
185
+ assert isinstance(backbone, nn.Module)
186
+ img_size = to_2tuple(img_size)
187
+ self.img_size = img_size
188
+ self.backbone = backbone
189
+ if feature_size is None:
190
+ with torch.no_grad():
191
+ training = backbone.training
192
+ if training:
193
+ backbone.eval()
194
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
195
+ feature_size = o.shape[-2:]
196
+ feature_dim = o.shape[1]
197
+ backbone.train(training)
198
+ else:
199
+ feature_size = to_2tuple(feature_size)
200
+ feature_dim = self.backbone.feature_info.channels()[-1]
201
+ self.num_patches = feature_size[0] * feature_size[1]
202
+ self.proj = nn.Linear(feature_dim, embed_dim)
203
+
204
+ def forward(self, x):
205
+ x = self.backbone(x)[-1]
206
+ x = x.flatten(2).transpose(1, 2)
207
+ x = self.proj(x)
208
+ return x
209
+
210
+
211
+ class ViT(nn.Module):
212
+
213
+ def __init__(self,
214
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
215
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
216
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
217
+ frozen_stages=-1, ratio=1, last_norm=True,
218
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
219
+ ):
220
+ # Protect mutable default arguments
221
+ super(ViT, self).__init__()
222
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
223
+ self.num_classes = num_classes
224
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
225
+ self.frozen_stages = frozen_stages
226
+ self.use_checkpoint = use_checkpoint
227
+ self.patch_padding = patch_padding
228
+ self.freeze_attn = freeze_attn
229
+ self.freeze_ffn = freeze_ffn
230
+ self.depth = depth
231
+
232
+ if hybrid_backbone is not None:
233
+ self.patch_embed = HybridEmbed(
234
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
235
+ else:
236
+ self.patch_embed = PatchEmbed(
237
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
238
+ num_patches = self.patch_embed.num_patches
239
+
240
+ # since the pretraining model has class token
241
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
242
+
243
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
244
+
245
+ self.blocks = nn.ModuleList([
246
+ Block(
247
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
248
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
249
+ )
250
+ for i in range(depth)])
251
+
252
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
253
+
254
+ if self.pos_embed is not None:
255
+ trunc_normal_(self.pos_embed, std=.02)
256
+
257
+ self._freeze_stages()
258
+
259
+ def _freeze_stages(self):
260
+ """Freeze parameters."""
261
+ if self.frozen_stages >= 0:
262
+ self.patch_embed.eval()
263
+ for param in self.patch_embed.parameters():
264
+ param.requires_grad = False
265
+
266
+ for i in range(1, self.frozen_stages + 1):
267
+ m = self.blocks[i]
268
+ m.eval()
269
+ for param in m.parameters():
270
+ param.requires_grad = False
271
+
272
+ if self.freeze_attn:
273
+ for i in range(0, self.depth):
274
+ m = self.blocks[i]
275
+ m.attn.eval()
276
+ m.norm1.eval()
277
+ for param in m.attn.parameters():
278
+ param.requires_grad = False
279
+ for param in m.norm1.parameters():
280
+ param.requires_grad = False
281
+
282
+ if self.freeze_ffn:
283
+ self.pos_embed.requires_grad = False
284
+ self.patch_embed.eval()
285
+ for param in self.patch_embed.parameters():
286
+ param.requires_grad = False
287
+ for i in range(0, self.depth):
288
+ m = self.blocks[i]
289
+ m.mlp.eval()
290
+ m.norm2.eval()
291
+ for param in m.mlp.parameters():
292
+ param.requires_grad = False
293
+ for param in m.norm2.parameters():
294
+ param.requires_grad = False
295
+
296
+ def init_weights(self):
297
+ """Initialize the weights in backbone.
298
+ Args:
299
+ pretrained (str, optional): Path to pre-trained weights.
300
+ Defaults to None.
301
+ """
302
+ def _init_weights(m):
303
+ if isinstance(m, nn.Linear):
304
+ trunc_normal_(m.weight, std=.02)
305
+ if isinstance(m, nn.Linear) and m.bias is not None:
306
+ nn.init.constant_(m.bias, 0)
307
+ elif isinstance(m, nn.LayerNorm):
308
+ nn.init.constant_(m.bias, 0)
309
+ nn.init.constant_(m.weight, 1.0)
310
+
311
+ self.apply(_init_weights)
312
+
313
+ def get_num_layers(self):
314
+ return len(self.blocks)
315
+
316
+ @torch.jit.ignore
317
+ def no_weight_decay(self):
318
+ return {'pos_embed', 'cls_token'}
319
+
320
+ def forward_features(self, x):
321
+ B, C, H, W = x.shape
322
+ x, (Hp, Wp) = self.patch_embed(x)
323
+
324
+ if self.pos_embed is not None:
325
+ # fit for multiple GPU training
326
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
327
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
328
+
329
+ for blk in self.blocks:
330
+ if self.use_checkpoint:
331
+ x = checkpoint.checkpoint(blk, x)
332
+ else:
333
+ x = blk(x)
334
+
335
+ x = self.last_norm(x)
336
+
337
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
338
+
339
+ return xp
340
+
341
+ def forward(self, x):
342
+ x = self.forward_features(x)
343
+ return x
344
+
345
+ def train(self, mode=True):
346
+ """Convert the model into training mode."""
347
+ super().train(mode)
348
+ self._freeze_stages()
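To make the shapes concrete: with the 256x192 input and patch stride 16 (padding 2), the backbone returns a 16x12 grid of 1280-dimensional features. A small sketch:

import torch
from hmr2.models.backbones.vit import vit

backbone = vit(cfg=None)              # cfg is accepted but not used by this constructor
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 256, 192))
print(feats.shape)                    # torch.Size([1, 1280, 16, 12])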
hmr2/models/backbones/vit_vitpose.py ADDED
@@ -0,0 +1,17 @@
1
+ # import mmcv
2
+ # import mmpose
3
+ # from mmpose.models import build_posenet
4
+ # from mmcv.runner import load_checkpoint
5
+ # from pathlib import Path
6
+
7
+ # def vit(cfg):
8
+ # vitpose_dir = Path(mmpose.__file__).parent.parent
9
+ # config = f'{vitpose_dir}/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/ViTPose_huge_coco_256x192.py'
10
+ # # checkpoint = f'{vitpose_dir}/models/vitpose-h-multi-coco.pth'
11
+
12
+ # config = mmcv.Config.fromfile(config)
13
+ # config.model.pretrained = None
14
+ # model = build_posenet(config.model)
15
+ # # load_checkpoint(model, checkpoint, map_location='cpu')
16
+
17
+ # return model.backbone
hmr2/models/components/__init__.py ADDED
File without changes
hmr2/models/components/pose_transformer.py ADDED
@@ -0,0 +1,358 @@
1
+ from inspect import isfunction
2
+ from typing import Callable, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from torch import nn
8
+
9
+ from .t_cond_mlp import (
10
+ AdaptiveLayerNorm1D,
11
+ FrequencyEmbedder,
12
+ normalization_layer,
13
+ )
14
+ # from .vit import Attention, FeedForward
15
+
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+
21
+ def default(val, d):
22
+ if exists(val):
23
+ return val
24
+ return d() if isfunction(d) else d
25
+
26
+
27
+ class PreNorm(nn.Module):
28
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
29
+ super().__init__()
30
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
31
+ self.fn = fn
32
+
33
+ def forward(self, x: torch.Tensor, *args, **kwargs):
34
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
35
+ return self.fn(self.norm(x, *args), **kwargs)
36
+ else:
37
+ return self.fn(self.norm(x), **kwargs)
38
+
39
+
40
+ class FeedForward(nn.Module):
41
+ def __init__(self, dim, hidden_dim, dropout=0.0):
42
+ super().__init__()
43
+ self.net = nn.Sequential(
44
+ nn.Linear(dim, hidden_dim),
45
+ nn.GELU(),
46
+ nn.Dropout(dropout),
47
+ nn.Linear(hidden_dim, dim),
48
+ nn.Dropout(dropout),
49
+ )
50
+
51
+ def forward(self, x):
52
+ return self.net(x)
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
57
+ super().__init__()
58
+ inner_dim = dim_head * heads
59
+ project_out = not (heads == 1 and dim_head == dim)
60
+
61
+ self.heads = heads
62
+ self.scale = dim_head**-0.5
63
+
64
+ self.attend = nn.Softmax(dim=-1)
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
68
+
69
+ self.to_out = (
70
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
71
+ if project_out
72
+ else nn.Identity()
73
+ )
74
+
75
+ def forward(self, x):
76
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
77
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
78
+
79
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
80
+
81
+ attn = self.attend(dots)
82
+ attn = self.dropout(attn)
83
+
84
+ out = torch.matmul(attn, v)
85
+ out = rearrange(out, "b h n d -> b n (h d)")
86
+ return self.to_out(out)
87
+
88
+
89
+ class CrossAttention(nn.Module):
90
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
91
+ super().__init__()
92
+ inner_dim = dim_head * heads
93
+ project_out = not (heads == 1 and dim_head == dim)
94
+
95
+ self.heads = heads
96
+ self.scale = dim_head**-0.5
97
+
98
+ self.attend = nn.Softmax(dim=-1)
99
+ self.dropout = nn.Dropout(dropout)
100
+
101
+ context_dim = default(context_dim, dim)
102
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
103
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
104
+
105
+ self.to_out = (
106
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
107
+ if project_out
108
+ else nn.Identity()
109
+ )
110
+
111
+ def forward(self, x, context=None):
112
+ context = default(context, x)
113
+ k, v = self.to_kv(context).chunk(2, dim=-1)
114
+ q = self.to_q(x)
115
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
116
+
117
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
118
+
119
+ attn = self.attend(dots)
120
+ attn = self.dropout(attn)
121
+
122
+ out = torch.matmul(attn, v)
123
+ out = rearrange(out, "b h n d -> b n (h d)")
124
+ return self.to_out(out)
125
+
126
+
127
+ class Transformer(nn.Module):
128
+ def __init__(
129
+ self,
130
+ dim: int,
131
+ depth: int,
132
+ heads: int,
133
+ dim_head: int,
134
+ mlp_dim: int,
135
+ dropout: float = 0.0,
136
+ norm: str = "layer",
137
+ norm_cond_dim: int = -1,
138
+ ):
139
+ super().__init__()
140
+ self.layers = nn.ModuleList([])
141
+ for _ in range(depth):
142
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
143
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
144
+ self.layers.append(
145
+ nn.ModuleList(
146
+ [
147
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
148
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
149
+ ]
150
+ )
151
+ )
152
+
153
+ def forward(self, x: torch.Tensor, *args):
154
+ for attn, ff in self.layers:
155
+ x = attn(x, *args) + x
156
+ x = ff(x, *args) + x
157
+ return x
158
+
159
+
160
+ class TransformerCrossAttn(nn.Module):
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ depth: int,
165
+ heads: int,
166
+ dim_head: int,
167
+ mlp_dim: int,
168
+ dropout: float = 0.0,
169
+ norm: str = "layer",
170
+ norm_cond_dim: int = -1,
171
+ context_dim: Optional[int] = None,
172
+ ):
173
+ super().__init__()
174
+ self.layers = nn.ModuleList([])
175
+ for _ in range(depth):
176
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
177
+ ca = CrossAttention(
178
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
179
+ )
180
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
181
+ self.layers.append(
182
+ nn.ModuleList(
183
+ [
184
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
185
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
186
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
187
+ ]
188
+ )
189
+ )
190
+
191
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
192
+ if context_list is None:
193
+ context_list = [context] * len(self.layers)
194
+ if len(context_list) != len(self.layers):
195
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
196
+
197
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
198
+ x = self_attn(x, *args) + x
199
+ x = cross_attn(x, *args, context=context_list[i]) + x
200
+ x = ff(x, *args) + x
201
+ return x
202
+
203
+
204
+ class DropTokenDropout(nn.Module):
205
+ def __init__(self, p: float = 0.1):
206
+ super().__init__()
207
+ if p < 0 or p > 1:
208
+ raise ValueError(
209
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
210
+ )
211
+ self.p = p
212
+
213
+ def forward(self, x: torch.Tensor):
214
+ # x: (batch_size, seq_len, dim)
215
+ if self.training and self.p > 0:
216
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
217
+ # TODO: permutation idx for each batch using torch.argsort
218
+ if zero_mask.any():
219
+ x = x[:, ~zero_mask, :]
220
+ return x
221
+
222
+
223
+ class ZeroTokenDropout(nn.Module):
224
+ def __init__(self, p: float = 0.1):
225
+ super().__init__()
226
+ if p < 0 or p > 1:
227
+ raise ValueError(
228
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
229
+ )
230
+ self.p = p
231
+
232
+ def forward(self, x: torch.Tensor):
233
+ # x: (batch_size, seq_len, dim)
234
+ if self.training and self.p > 0:
235
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
236
+ # Zero-out the masked tokens
237
+ x[zero_mask, :] = 0
238
+ return x
239
+
240
+
241
+ class TransformerEncoder(nn.Module):
242
+ def __init__(
243
+ self,
244
+ num_tokens: int,
245
+ token_dim: int,
246
+ dim: int,
247
+ depth: int,
248
+ heads: int,
249
+ mlp_dim: int,
250
+ dim_head: int = 64,
251
+ dropout: float = 0.0,
252
+ emb_dropout: float = 0.0,
253
+ emb_dropout_type: str = "drop",
254
+ emb_dropout_loc: str = "token",
255
+ norm: str = "layer",
256
+ norm_cond_dim: int = -1,
257
+ token_pe_numfreq: int = -1,
258
+ ):
259
+ super().__init__()
260
+ if token_pe_numfreq > 0:
261
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
262
+ self.to_token_embedding = nn.Sequential(
263
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
264
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
265
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
266
+ nn.Linear(token_dim_new, dim),
267
+ )
268
+ else:
269
+ self.to_token_embedding = nn.Linear(token_dim, dim)
270
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
271
+ if emb_dropout_type == "drop":
272
+ self.dropout = DropTokenDropout(emb_dropout)
273
+ elif emb_dropout_type == "zero":
274
+ self.dropout = ZeroTokenDropout(emb_dropout)
275
+ else:
276
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
277
+ self.emb_dropout_loc = emb_dropout_loc
278
+
279
+ self.transformer = Transformer(
280
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
281
+ )
282
+
283
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
284
+ x = inp
285
+
286
+ if self.emb_dropout_loc == "input":
287
+ x = self.dropout(x)
288
+ x = self.to_token_embedding(x)
289
+
290
+ if self.emb_dropout_loc == "token":
291
+ x = self.dropout(x)
292
+ b, n, _ = x.shape
293
+ x += self.pos_embedding[:, :n]
294
+
295
+ if self.emb_dropout_loc == "token_afterpos":
296
+ x = self.dropout(x)
297
+ x = self.transformer(x, *args)
298
+ return x
299
+
300
+
301
+ class TransformerDecoder(nn.Module):
302
+ def __init__(
303
+ self,
304
+ num_tokens: int,
305
+ token_dim: int,
306
+ dim: int,
307
+ depth: int,
308
+ heads: int,
309
+ mlp_dim: int,
310
+ dim_head: int = 64,
311
+ dropout: float = 0.0,
312
+ emb_dropout: float = 0.0,
313
+ emb_dropout_type: str = 'drop',
314
+ norm: str = "layer",
315
+ norm_cond_dim: int = -1,
316
+ context_dim: Optional[int] = None,
317
+ skip_token_embedding: bool = False,
318
+ ):
319
+ super().__init__()
320
+ if not skip_token_embedding:
321
+ self.to_token_embedding = nn.Linear(token_dim, dim)
322
+ else:
323
+ self.to_token_embedding = nn.Identity()
324
+ if token_dim != dim:
325
+ raise ValueError(
326
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
327
+ )
328
+
329
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
330
+ if emb_dropout_type == "drop":
331
+ self.dropout = DropTokenDropout(emb_dropout)
332
+ elif emb_dropout_type == "zero":
333
+ self.dropout = ZeroTokenDropout(emb_dropout)
334
+ elif emb_dropout_type == "normal":
335
+ self.dropout = nn.Dropout(emb_dropout)
336
+
337
+ self.transformer = TransformerCrossAttn(
338
+ dim,
339
+ depth,
340
+ heads,
341
+ dim_head,
342
+ mlp_dim,
343
+ dropout,
344
+ norm=norm,
345
+ norm_cond_dim=norm_cond_dim,
346
+ context_dim=context_dim,
347
+ )
348
+
349
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
350
+ x = self.to_token_embedding(inp)
351
+ b, n, _ = x.shape
352
+
353
+ x = self.dropout(x)
354
+ x += self.pos_embedding[:, :n]
355
+
356
+ x = self.transformer(x, *args, context=context, context_list=context_list)
357
+ return x
358
+
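A rough sketch of driving the decoder with cross-attention context; the token and context sizes below are illustrative assumptions, not the values configured for the SMPL head.

import torch
from hmr2.models.components.pose_transformer import TransformerDecoder

decoder = TransformerDecoder(num_tokens=1, token_dim=1, dim=1024,
                             depth=6, heads=8, mlp_dim=1024,
                             dim_head=64, context_dim=1280)
token = torch.zeros(2, 1, 1)             # one query token per sample
context = torch.randn(2, 192, 1280)      # e.g. flattened backbone feature tokens
out = decoder(token, context=context)
print(out.shape)                         # torch.Size([2, 1, 1024])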
hmr2/models/components/t_cond_mlp.py ADDED
@@ -0,0 +1,199 @@
1
+ import copy
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+
7
+ class AdaptiveLayerNorm1D(torch.nn.Module):
8
+ def __init__(self, data_dim: int, norm_cond_dim: int):
9
+ super().__init__()
10
+ if data_dim <= 0:
11
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
12
+ if norm_cond_dim <= 0:
13
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
14
+ self.norm = torch.nn.LayerNorm(
15
+ data_dim
16
+ ) # TODO: Check if elementwise_affine=True is correct
17
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
18
+ torch.nn.init.zeros_(self.linear.weight)
19
+ torch.nn.init.zeros_(self.linear.bias)
20
+
21
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
22
+ # x: (batch, ..., data_dim)
23
+ # t: (batch, norm_cond_dim)
24
+ # return: (batch, data_dim)
25
+ x = self.norm(x)
26
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
27
+
28
+ # Add singleton dimensions to alpha and beta
29
+ if x.dim() > 2:
30
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
31
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
32
+
33
+ return x * (1 + alpha) + beta
34
+
35
+
36
+ class SequentialCond(torch.nn.Sequential):
37
+ def forward(self, input, *args, **kwargs):
38
+ for module in self:
39
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
40
+ # print(f'Passing on args to {module}', [a.shape for a in args])
41
+ input = module(input, *args, **kwargs)
42
+ else:
43
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
44
+ input = module(input)
45
+ return input
46
+
47
+
48
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
49
+ if norm == "batch":
50
+ return torch.nn.BatchNorm1d(dim)
51
+ elif norm == "layer":
52
+ return torch.nn.LayerNorm(dim)
53
+ elif norm == "ada":
54
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
55
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
56
+ elif norm is None:
57
+ return torch.nn.Identity()
58
+ else:
59
+ raise ValueError(f"Unknown norm: {norm}")
60
+
61
+
62
+ def linear_norm_activ_dropout(
63
+ input_dim: int,
64
+ output_dim: int,
65
+ activation: torch.nn.Module = torch.nn.ReLU(),
66
+ bias: bool = True,
67
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
68
+ dropout: float = 0.0,
69
+ norm_cond_dim: int = -1,
70
+ ) -> SequentialCond:
71
+ layers = []
72
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
73
+ if norm is not None:
74
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
75
+ layers.append(copy.deepcopy(activation))
76
+ if dropout > 0.0:
77
+ layers.append(torch.nn.Dropout(dropout))
78
+ return SequentialCond(*layers)
79
+
80
+
81
+ def create_simple_mlp(
82
+ input_dim: int,
83
+ hidden_dims: List[int],
84
+ output_dim: int,
85
+ activation: torch.nn.Module = torch.nn.ReLU(),
86
+ bias: bool = True,
87
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
88
+ dropout: float = 0.0,
89
+ norm_cond_dim: int = -1,
90
+ ) -> SequentialCond:
91
+ layers = []
92
+ prev_dim = input_dim
93
+ for hidden_dim in hidden_dims:
94
+ layers.extend(
95
+ linear_norm_activ_dropout(
96
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
97
+ )
98
+ )
99
+ prev_dim = hidden_dim
100
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
101
+ return SequentialCond(*layers)
102
+
103
+
104
+ class ResidualMLPBlock(torch.nn.Module):
105
+ def __init__(
106
+ self,
107
+ input_dim: int,
108
+ hidden_dim: int,
109
+ num_hidden_layers: int,
110
+ output_dim: int,
111
+ activation: torch.nn.Module = torch.nn.ReLU(),
112
+ bias: bool = True,
113
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
114
+ dropout: float = 0.0,
115
+ norm_cond_dim: int = -1,
116
+ ):
117
+ super().__init__()
118
+ if not (input_dim == output_dim == hidden_dim):
119
+ raise NotImplementedError(
120
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
121
+ )
122
+
123
+ layers = []
124
+ prev_dim = input_dim
125
+ for i in range(num_hidden_layers):
126
+ layers.append(
127
+ linear_norm_activ_dropout(
128
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
129
+ )
130
+ )
131
+ prev_dim = hidden_dim
132
+ self.model = SequentialCond(*layers)
133
+ self.skip = torch.nn.Identity()
134
+
135
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
136
+ return x + self.model(x, *args, **kwargs)
137
+
138
+
139
+ class ResidualMLP(torch.nn.Module):
140
+ def __init__(
141
+ self,
142
+ input_dim: int,
143
+ hidden_dim: int,
144
+ num_hidden_layers: int,
145
+ output_dim: int,
146
+ activation: torch.nn.Module = torch.nn.ReLU(),
147
+ bias: bool = True,
148
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
149
+ dropout: float = 0.0,
150
+ num_blocks: int = 1,
151
+ norm_cond_dim: int = -1,
152
+ ):
153
+ super().__init__()
154
+ self.input_dim = input_dim
155
+ self.model = SequentialCond(
156
+ linear_norm_activ_dropout(
157
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
158
+ ),
159
+ *[
160
+ ResidualMLPBlock(
161
+ hidden_dim,
162
+ hidden_dim,
163
+ num_hidden_layers,
164
+ hidden_dim,
165
+ activation,
166
+ bias,
167
+ norm,
168
+ dropout,
169
+ norm_cond_dim,
170
+ )
171
+ for _ in range(num_blocks)
172
+ ],
173
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
174
+ )
175
+
176
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
177
+ return self.model(x, *args, **kwargs)
178
+
179
+
180
+ class FrequencyEmbedder(torch.nn.Module):
181
+ def __init__(self, num_frequencies, max_freq_log2):
182
+ super().__init__()
183
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
184
+ self.register_buffer("frequencies", frequencies)
185
+
186
+ def forward(self, x):
187
+ # x should be of size (N,) or (N, D)
188
+ N = x.size(0)
189
+ if x.dim() == 1: # (N,)
190
+ x = x.unsqueeze(1) # (N, D) where D=1
191
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
192
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
193
+ s = torch.sin(scaled)
194
+ c = torch.cos(scaled)
195
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
196
+ N, -1
197
+ ) # (N, D * 2 * num_frequencies + D)
198
+ return embedded
199
+
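For intuition, a small sketch of FrequencyEmbedder: each scalar feature is expanded into sin/cos pairs at num_frequencies frequencies plus the raw value.

import torch
from hmr2.models.components.t_cond_mlp import FrequencyEmbedder

emb = FrequencyEmbedder(num_frequencies=8, max_freq_log2=7)
x = torch.randn(4, 3)        # 4 samples, 3 scalar features each
out = emb(x)
print(out.shape)             # torch.Size([4, 51]) = 3 * (2 * 8 + 1)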
hmr2/models/discriminator.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Discriminator(nn.Module):
5
+
6
+ def __init__(self):
7
+ """
8
+ Pose + Shape discriminator proposed in HMR
9
+ """
10
+ super(Discriminator, self).__init__()
11
+
12
+ self.num_joints = 23
13
+ # poses_alone
14
+ self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1)
15
+ nn.init.xavier_uniform_(self.D_conv1.weight)
16
+ nn.init.zeros_(self.D_conv1.bias)
17
+ self.relu = nn.ReLU(inplace=True)
18
+ self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1)
19
+ nn.init.xavier_uniform_(self.D_conv2.weight)
20
+ nn.init.zeros_(self.D_conv2.bias)
21
+ pose_out = []
22
+ for i in range(self.num_joints):
23
+ pose_out_temp = nn.Linear(32, 1)
24
+ nn.init.xavier_uniform_(pose_out_temp.weight)
25
+ nn.init.zeros_(pose_out_temp.bias)
26
+ pose_out.append(pose_out_temp)
27
+ self.pose_out = nn.ModuleList(pose_out)
28
+
29
+ # betas
30
+ self.betas_fc1 = nn.Linear(10, 10)
31
+ nn.init.xavier_uniform_(self.betas_fc1.weight)
32
+ nn.init.zeros_(self.betas_fc1.bias)
33
+ self.betas_fc2 = nn.Linear(10, 5)
34
+ nn.init.xavier_uniform_(self.betas_fc2.weight)
35
+ nn.init.zeros_(self.betas_fc2.bias)
36
+ self.betas_out = nn.Linear(5, 1)
37
+ nn.init.xavier_uniform_(self.betas_out.weight)
38
+ nn.init.zeros_(self.betas_out.bias)
39
+
40
+ # poses_joint
41
+ self.D_alljoints_fc1 = nn.Linear(32*self.num_joints, 1024)
42
+ nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
43
+ nn.init.zeros_(self.D_alljoints_fc1.bias)
44
+ self.D_alljoints_fc2 = nn.Linear(1024, 1024)
45
+ nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
46
+ nn.init.zeros_(self.D_alljoints_fc2.bias)
47
+ self.D_alljoints_out = nn.Linear(1024, 1)
48
+ nn.init.xavier_uniform_(self.D_alljoints_out.weight)
49
+ nn.init.zeros_(self.D_alljoints_out.bias)
50
+
51
+
52
+ def forward(self, poses: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
53
+ """
54
+ Forward pass of the discriminator.
55
+ Args:
56
+ poses (torch.Tensor): Tensor of shape (B, 23, 3, 3) containing a batch of SMPL body poses (excluding the global orientation).
57
+ betas (torch.Tensor): Tensor of shape (B, 10) containing a batch of SMPL beta coefficients.
58
+ Returns:
59
+ torch.Tensor: Discriminator output with shape (B, 25)
60
+ """
61
+ #import ipdb; ipdb.set_trace()
62
+ #bn = poses.shape[0]
63
+ # poses B x 207
64
+ #poses = poses.reshape(bn, -1)
65
+ # poses B x num_joints x 1 x 9
66
+ poses = poses.reshape(-1, self.num_joints, 1, 9)
67
+ bn = poses.shape[0]
68
+ # poses B x 9 x num_joints x 1
69
+ poses = poses.permute(0, 3, 1, 2).contiguous()
70
+
71
+ # poses_alone
72
+ poses = self.D_conv1(poses)
73
+ poses = self.relu(poses)
74
+ poses = self.D_conv2(poses)
75
+ poses = self.relu(poses)
76
+
77
+ poses_out = []
78
+ for i in range(self.num_joints):
79
+ poses_out_ = self.pose_out[i](poses[:, :, i, 0])
80
+ poses_out.append(poses_out_)
81
+ poses_out = torch.cat(poses_out, dim=1)
82
+
83
+ # betas
84
+ betas = self.betas_fc1(betas)
85
+ betas = self.relu(betas)
86
+ betas = self.betas_fc2(betas)
87
+ betas = self.relu(betas)
88
+ betas_out = self.betas_out(betas)
89
+
90
+ # poses_joint
91
+ poses = poses.reshape(bn,-1)
92
+ poses_all = self.D_alljoints_fc1(poses)
93
+ poses_all = self.relu(poses_all)
94
+ poses_all = self.D_alljoints_fc2(poses_all)
95
+ poses_all = self.relu(poses_all)
96
+ poses_all_out = self.D_alljoints_out(poses_all)
97
+
98
+ disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
99
+ return disc_out
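
A small usage sketch for the Discriminator above (not part of the commit). Identity rotations and zero betas stand in for real SMPL parameters; the point is the expected input and output shapes.

import torch

disc = Discriminator()
B = 4
body_pose = torch.eye(3).reshape(1, 1, 3, 3).repeat(B, 23, 1, 1)  # (B, 23, 3, 3) rotation matrices
betas = torch.zeros(B, 10)                                        # (B, 10) shape coefficients
out = disc(body_pose, betas)
print(out.shape)  # torch.Size([4, 25]): 23 per-joint scores + 1 shape score + 1 all-joints score
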
hmr2/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .smpl_head import build_smpl_head
hmr2/models/heads/smpl_head.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import einops
6
+
7
+ from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
8
+ from ..components.pose_transformer import TransformerDecoder
9
+
10
+ def build_smpl_head(cfg):
11
+ smpl_head_type = cfg.MODEL.SMPL_HEAD.get('TYPE', 'hmr')
12
+ if smpl_head_type == 'transformer_decoder':
13
+ return SMPLTransformerDecoderHead(cfg)
14
+ else:
15
+ raise ValueError('Unknown SMPL head type: {}'.format(smpl_head_type))
16
+
17
+ class SMPLTransformerDecoderHead(nn.Module):
18
+ """ Cross-attention based SMPL Transformer decoder
19
+ """
20
+
21
+ def __init__(self, cfg):
22
+ super().__init__()
23
+ self.cfg = cfg
24
+ self.joint_rep_type = cfg.MODEL.SMPL_HEAD.get('JOINT_REP', '6d')
25
+ self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
26
+ npose = self.joint_rep_dim * (cfg.SMPL.NUM_BODY_JOINTS + 1)
27
+ self.npose = npose
28
+ self.input_is_mean_shape = cfg.MODEL.SMPL_HEAD.get('TRANSFORMER_INPUT', 'zero') == 'mean_shape'
29
+ transformer_args = dict(
30
+ num_tokens=1,
31
+ token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1,
32
+ dim=1024,
33
+ )
34
+ transformer_args = (transformer_args | dict(cfg.MODEL.SMPL_HEAD.TRANSFORMER_DECODER))
35
+ self.transformer = TransformerDecoder(
36
+ **transformer_args
37
+ )
38
+ dim=transformer_args['dim']
39
+ self.decpose = nn.Linear(dim, npose)
40
+ self.decshape = nn.Linear(dim, 10)
41
+ self.deccam = nn.Linear(dim, 3)
42
+
43
+ if cfg.MODEL.SMPL_HEAD.get('INIT_DECODER_XAVIER', False):
44
+ # True by default in MLP. False by default in Transformer
45
+ nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
46
+ nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
47
+ nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
48
+
49
+ mean_params = np.load(cfg.SMPL.MEAN_PARAMS)
50
+ init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
51
+ init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
52
+ init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
53
+ self.register_buffer('init_body_pose', init_body_pose)
54
+ self.register_buffer('init_betas', init_betas)
55
+ self.register_buffer('init_cam', init_cam)
56
+
57
+ def forward(self, x, **kwargs):
58
+
59
+ batch_size = x.shape[0]
60
+ # vit pretrained backbone is channel-first. Change to token-first
61
+ x = einops.rearrange(x, 'b c h w -> b (h w) c')
62
+
63
+ init_body_pose = self.init_body_pose.expand(batch_size, -1)
64
+ init_betas = self.init_betas.expand(batch_size, -1)
65
+ init_cam = self.init_cam.expand(batch_size, -1)
66
+
67
+ # TODO: Convert init_body_pose to aa rep if needed
68
+ if self.joint_rep_type == 'aa':
69
+ raise NotImplementedError
70
+
71
+ pred_body_pose = init_body_pose
72
+ pred_betas = init_betas
73
+ pred_cam = init_cam
74
+ pred_body_pose_list = []
75
+ pred_betas_list = []
76
+ pred_cam_list = []
77
+ for i in range(self.cfg.MODEL.SMPL_HEAD.get('IEF_ITERS', 1)):
78
+ # Input token to transformer is zero token
79
+ if self.input_is_mean_shape:
80
+ token = torch.cat([pred_body_pose, pred_betas, pred_cam], dim=1)[:,None,:]
81
+ else:
82
+ token = torch.zeros(batch_size, 1, 1).to(x.device)
83
+
84
+ # Pass through transformer
85
+ token_out = self.transformer(token, context=x)
86
+ token_out = token_out.squeeze(1) # (B, C)
87
+
88
+ # Readout from token_out
89
+ pred_body_pose = self.decpose(token_out) + pred_body_pose
90
+ pred_betas = self.decshape(token_out) + pred_betas
91
+ pred_cam = self.deccam(token_out) + pred_cam
92
+ pred_body_pose_list.append(pred_body_pose)
93
+ pred_betas_list.append(pred_betas)
94
+ pred_cam_list.append(pred_cam)
95
+
96
+ # Convert self.joint_rep_type -> rotmat
97
+ joint_conversion_fn = {
98
+ '6d': rot6d_to_rotmat,
99
+ 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
100
+ }[self.joint_rep_type]
101
+
102
+ pred_smpl_params_list = {}
103
+ pred_smpl_params_list['body_pose'] = torch.cat([joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_body_pose_list], dim=0)
104
+ pred_smpl_params_list['betas'] = torch.cat(pred_betas_list, dim=0)
105
+ pred_smpl_params_list['cam'] = torch.cat(pred_cam_list, dim=0)
106
+ pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, self.cfg.SMPL.NUM_BODY_JOINTS+1, 3, 3)
107
+
108
+ pred_smpl_params = {'global_orient': pred_body_pose[:, [0]],
109
+ 'body_pose': pred_body_pose[:, 1:],
110
+ 'betas': pred_betas}
111
+ return pred_smpl_params, pred_cam, pred_smpl_params_list
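
For orientation, a shape walkthrough of the readout above (a sketch, assuming the default '6d' joint representation and 23 SMPL body joints, so npose = 6 * 24 = 144). rot6d_to_rotmat is the helper imported at the top of this file.

import torch

B = 2
pred_body_pose_6d = torch.randn(B, 6 * 24)                  # raw head output, (B, 144)
rotmats = rot6d_to_rotmat(pred_body_pose_6d).view(B, 24, 3, 3)
global_orient, body_pose = rotmats[:, [0]], rotmats[:, 1:]  # (B, 1, 3, 3) and (B, 23, 3, 3)
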
hmr2/models/hmr2.py ADDED
@@ -0,0 +1,363 @@
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ from typing import Any, Dict, Mapping, Tuple
4
+
5
+ from yacs.config import CfgNode
6
+
7
+ from ..utils import SkeletonRenderer, MeshRenderer
8
+ from ..utils.geometry import aa_to_rotmat, perspective_projection
9
+ from .backbones import create_backbone
10
+ from .heads import build_smpl_head
11
+ from .discriminator import Discriminator
12
+ from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss
13
+ from . import SMPL
14
+
15
+
16
+ class HMR2(pl.LightningModule):
17
+
18
+ def __init__(self, cfg: CfgNode, init_renderer: bool = True):
19
+ """
20
+ Setup HMR2 model
21
+ Args:
22
+ cfg (CfgNode): Config file as a yacs CfgNode
23
+ """
24
+ super().__init__()
25
+
26
+ # Save hyperparameters
27
+ self.save_hyperparameters(logger=False, ignore=['init_renderer'])
28
+
29
+ self.cfg = cfg
30
+ # Create backbone feature extractor
31
+ self.backbone = create_backbone(cfg)
32
+
33
+ # Create SMPL head
34
+ self.smpl_head = build_smpl_head(cfg)
35
+
36
+ # Create discriminator
37
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
38
+ self.discriminator = Discriminator()
39
+
40
+ # Define loss functions
41
+ self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
42
+ self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
43
+ self.smpl_parameter_loss = ParameterLoss()
44
+
45
+ # Instantiate SMPL model
46
+ smpl_cfg = {k.lower(): v for k,v in dict(cfg.SMPL).items()}
47
+ self.smpl = SMPL(**smpl_cfg)
48
+
49
+ # Buffer that shows whether we need to initialize ActNorm layers
50
+ self.register_buffer('initialized', torch.tensor(False))
51
+ # Setup renderer for visualization
52
+ if init_renderer:
53
+ self.renderer = SkeletonRenderer(self.cfg)
54
+ self.mesh_renderer = MeshRenderer(self.cfg, faces=self.smpl.faces)
55
+ else:
56
+ self.renderer = None
57
+ self.mesh_renderer = None
58
+
59
+ # Disable automatic optimization since we use adversarial training
60
+ self.automatic_optimization = False
61
+
62
+ def get_parameters(self):
63
+ all_params = list(self.smpl_head.parameters())
64
+ all_params += list(self.backbone.parameters())
65
+ return all_params
66
+
67
+ def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
68
+ """
69
+ Set up the model and discriminator optimizers
70
+ Returns:
71
+ Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
72
+ """
73
+ param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]
74
+
75
+ optimizer = torch.optim.AdamW(params=param_groups,
76
+ # lr=self.cfg.TRAIN.LR,
77
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
78
+ optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
79
+ lr=self.cfg.TRAIN.LR,
80
+ weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
81
+
82
+ return optimizer, optimizer_disc
83
+
84
+ def forward_step(self, batch: Dict, train: bool = False) -> Dict:
85
+ """
86
+ Run a forward step of the network
87
+ Args:
88
+ batch (Dict): Dictionary containing batch data
89
+ train (bool): Flag indicating whether it is training or validation mode
90
+ Returns:
91
+ Dict: Dictionary containing the regression output
92
+ """
93
+
94
+ # Use RGB image as input
95
+ x = batch['img']
96
+ batch_size = x.shape[0]
97
+
98
+ # Compute conditioning features using the backbone
99
+ # if using ViT backbone, we need to use a different aspect ratio
100
+ conditioning_feats = self.backbone(x[:,:,:,32:-32])
101
+
102
+ pred_smpl_params, pred_cam, _ = self.smpl_head(conditioning_feats)
103
+
104
+ # Store useful regression outputs to the output dict
105
+ output = {}
106
+ output['pred_cam'] = pred_cam
107
+ output['pred_smpl_params'] = {k: v.clone() for k,v in pred_smpl_params.items()}
108
+
109
+ # Compute camera translation
110
+ device = pred_smpl_params['body_pose'].device
111
+ dtype = pred_smpl_params['body_pose'].dtype
112
+ focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)
113
+ pred_cam_t = torch.stack([pred_cam[:, 1],
114
+ pred_cam[:, 2],
115
+ 2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] +1e-9)],dim=-1)
116
+ output['pred_cam_t'] = pred_cam_t
117
+ output['focal_length'] = focal_length
118
+
119
+ # Compute model vertices, joints and the projected joints
120
+ pred_smpl_params['global_orient'] = pred_smpl_params['global_orient'].reshape(batch_size, -1, 3, 3)
121
+ pred_smpl_params['body_pose'] = pred_smpl_params['body_pose'].reshape(batch_size, -1, 3, 3)
122
+ pred_smpl_params['betas'] = pred_smpl_params['betas'].reshape(batch_size, -1)
123
+ smpl_output = self.smpl(**{k: v.float() for k,v in pred_smpl_params.items()}, pose2rot=False)
124
+ pred_keypoints_3d = smpl_output.joints
125
+ pred_vertices = smpl_output.vertices
126
+ output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
127
+ output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
128
+ pred_cam_t = pred_cam_t.reshape(-1, 3)
129
+ focal_length = focal_length.reshape(-1, 2)
130
+ pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
131
+ translation=pred_cam_t,
132
+ focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
133
+
134
+ output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
135
+ return output
136
+
137
+ def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
138
+ """
139
+ Compute losses given the input batch and the regression output
140
+ Args:
141
+ batch (Dict): Dictionary containing batch data
142
+ output (Dict): Dictionary containing the regression output
143
+ train (bool): Flag indicating whether it is training or validation mode
144
+ Returns:
145
+ torch.Tensor : Total loss for current batch
146
+ """
147
+
148
+ pred_smpl_params = output['pred_smpl_params']
149
+ pred_keypoints_2d = output['pred_keypoints_2d']
150
+ pred_keypoints_3d = output['pred_keypoints_3d']
151
+
152
+
153
+ batch_size = pred_smpl_params['body_pose'].shape[0]
154
+ device = pred_smpl_params['body_pose'].device
155
+ dtype = pred_smpl_params['body_pose'].dtype
156
+
157
+ # Get annotations
158
+ gt_keypoints_2d = batch['keypoints_2d']
159
+ gt_keypoints_3d = batch['keypoints_3d']
160
+ gt_smpl_params = batch['smpl_params']
161
+ has_smpl_params = batch['has_smpl_params']
162
+ is_axis_angle = batch['smpl_params_is_axis_angle']
163
+
164
+ # Compute 3D keypoint loss
165
+ loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
166
+ loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=25+14)
167
+
168
+ # Compute loss on SMPL parameters
169
+ loss_smpl_params = {}
170
+ for k, pred in pred_smpl_params.items():
171
+ gt = gt_smpl_params[k].view(batch_size, -1)
172
+ if is_axis_angle[k].all():
173
+ gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
174
+ has_gt = has_smpl_params[k]
175
+ loss_smpl_params[k] = self.smpl_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt)
176
+
177
+ # # Filter out images with corresponding SMPL parameter annotations
178
+ # smpl_params = {k: v.clone() for k,v in gt_smpl_params.items()}
179
+ # smpl_params['body_pose'] = aa_to_rotmat(smpl_params['body_pose'].reshape(-1, 3)).reshape(batch_size, -1, 3, 3)[:, :, :, :2].permute(0, 1, 3, 2).reshape(batch_size, -1)
180
+ # smpl_params['global_orient'] = aa_to_rotmat(smpl_params['global_orient'].reshape(-1, 3)).reshape(batch_size, -1, 3, 3)[:, :, :, :2].permute(0, 1, 3, 2).reshape(batch_size, -1)
181
+ # smpl_params['betas'] = smpl_params['betas']
182
+ # has_smpl_params = (batch['has_smpl_params']['body_pose'] > 0)
183
+ # smpl_params = {k: v[has_smpl_params] for k, v in smpl_params.items()}
184
+
185
+ loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\
186
+ self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\
187
+ sum([loss_smpl_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_smpl_params])
188
+
189
+ losses = dict(loss=loss.detach(),
190
+ loss_keypoints_2d=loss_keypoints_2d.detach(),
191
+ loss_keypoints_3d=loss_keypoints_3d.detach())
192
+
193
+ for k, v in loss_smpl_params.items():
194
+ losses['loss_' + k] = v.detach()
195
+
196
+ output['losses'] = losses
197
+
198
+ return loss
199
+
200
+ # Tensorboard logging should run from rank zero only
201
+ @pl.utilities.rank_zero.rank_zero_only
202
+ def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None:
203
+ """
204
+ Log results to Tensorboard
205
+ Args:
206
+ batch (Dict): Dictionary containing batch data
207
+ output (Dict): Dictionary containing the regression output
208
+ step_count (int): Global training step count
209
+ train (bool): Flag indicating whether it is training or validation mode
210
+ """
211
+
212
+ mode = 'train' if train else 'val'
213
+ batch_size = batch['keypoints_2d'].shape[0]
214
+ images = batch['img']
215
+ images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1)
216
+ images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1)
217
+ #images = 255*images.permute(0, 2, 3, 1).cpu().numpy()
218
+
219
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
220
+ pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
221
+ focal_length = output['focal_length'].detach().reshape(batch_size, 2)
222
+ gt_keypoints_3d = batch['keypoints_3d']
223
+ gt_keypoints_2d = batch['keypoints_2d']
224
+ losses = output['losses']
225
+ pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
226
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)
227
+
228
+ if write_to_summary_writer:
229
+ summary_writer = self.logger.experiment
230
+ for loss_name, val in losses.items():
231
+ summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count)
232
+ num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)
233
+
234
+ gt_keypoints_3d = batch['keypoints_3d']
235
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
236
+
237
+ # We render the skeletons instead of the full mesh because rendering a lot of meshes will make the training slow.
238
+ #predictions = self.renderer(pred_keypoints_3d[:num_images],
239
+ # gt_keypoints_3d[:num_images],
240
+ # 2 * gt_keypoints_2d[:num_images],
241
+ # images=images[:num_images],
242
+ # camera_translation=pred_cam_t[:num_images])
243
+ predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
244
+ pred_cam_t[:num_images].cpu().numpy(),
245
+ images[:num_images].cpu().numpy(),
246
+ pred_keypoints_2d[:num_images].cpu().numpy(),
247
+ gt_keypoints_2d[:num_images].cpu().numpy(),
248
+ focal_length=focal_length[:num_images].cpu().numpy())
249
+ if write_to_summary_writer:
250
+ summary_writer.add_image('%s/predictions' % mode, predictions, step_count)
251
+
252
+ return predictions
253
+
254
+ def forward(self, batch: Dict) -> Dict:
255
+ """
256
+ Run a forward step of the network in val mode
257
+ Args:
258
+ batch (Dict): Dictionary containing batch data
259
+ Returns:
260
+ Dict: Dictionary containing the regression output
261
+ """
262
+ return self.forward_step(batch, train=False)
263
+
264
+ def training_step_discriminator(self, batch: Dict,
265
+ body_pose: torch.Tensor,
266
+ betas: torch.Tensor,
267
+ optimizer: torch.optim.Optimizer) -> torch.Tensor:
268
+ """
269
+ Run a discriminator training step
270
+ Args:
271
+ batch (Dict): Dictionary containing mocap batch data
272
+ body_pose (torch.Tensor): Regressed body pose from current step
273
+ betas (torch.Tensor): Regressed betas from current step
274
+ optimizer (torch.optim.Optimizer): Discriminator optimizer
275
+ Returns:
276
+ torch.Tensor: Discriminator loss
277
+ """
278
+ batch_size = body_pose.shape[0]
279
+ gt_body_pose = batch['body_pose']
280
+ gt_betas = batch['betas']
281
+ gt_rotmat = aa_to_rotmat(gt_body_pose.view(-1,3)).view(batch_size, -1, 3, 3)
282
+ disc_fake_out = self.discriminator(body_pose.detach(), betas.detach())
283
+ loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
284
+ disc_real_out = self.discriminator(gt_rotmat, gt_betas)
285
+ loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
286
+ loss_disc = loss_fake + loss_real
287
+ loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
288
+ optimizer.zero_grad()
289
+ self.manual_backward(loss)
290
+ optimizer.step()
291
+ return loss_disc.detach()
292
+
293
+ def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
294
+ """
295
+ Run a full training step
296
+ Args:
297
+ joint_batch (Dict): Dictionary containing image and mocap batch data
298
+ batch_idx (int): Unused.
300
+ Returns:
301
+ Dict: Dictionary containing regression output.
302
+ """
303
+ batch = joint_batch['img']
304
+ mocap_batch = joint_batch['mocap']
305
+ optimizer = self.optimizers(use_pl_optimizer=True)
306
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
307
+ optimizer, optimizer_disc = optimizer
308
+
309
+ # Update learning rates
310
+ self.update_learning_rates(batch_idx)
311
+
312
+ batch_size = batch['img'].shape[0]
313
+ output = self.forward_step(batch, train=True)
314
+ pred_smpl_params = output['pred_smpl_params']
315
+ if self.cfg.get('UPDATE_GT_SPIN', False):
316
+ self.update_batch_gt_spin(batch, output)
317
+ loss = self.compute_loss(batch, output, train=True)
318
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
319
+ disc_out = self.discriminator(pred_smpl_params['body_pose'].reshape(batch_size, -1), pred_smpl_params['betas'].reshape(batch_size, -1))
320
+ loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
321
+ loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv
322
+
323
+ # Error if Nan
324
+ if torch.isnan(loss):
325
+ raise ValueError('Loss is NaN')
326
+
327
+ optimizer.zero_grad()
328
+ self.manual_backward(loss)
329
+ # Clip gradient
330
+ if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
331
+ gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
332
+ self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True)
333
+ optimizer.step()
334
+ if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
335
+ loss_disc = self.training_step_discriminator(mocap_batch, pred_smpl_params['body_pose'].reshape(batch_size, -1), pred_smpl_params['betas'].reshape(batch_size, -1), optimizer_disc)
336
+ output['losses']['loss_gen'] = loss_adv
337
+ output['losses']['loss_disc'] = loss_disc
338
+
339
+ if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
340
+ self.tensorboard_logging(batch, output, self.global_step, train=True)
341
+
342
+ self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False)
343
+
344
+ return output
345
+
346
+ def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
347
+ """
348
+ Run a validation step and log to Tensorboard
349
+ Args:
350
+ batch (Dict): Dictionary containing batch data
351
+ batch_idx (int): Unused.
352
+ Returns:
353
+ Dict: Dictionary containing regression output.
354
+ """
355
+ # batch_size = batch['img'].shape[0]
356
+ output = self.forward_step(batch, train=False)
357
+
358
+ pred_smpl_params = output['pred_smpl_params']
359
+ loss = self.compute_loss(batch, output, train=False)
360
+ output['loss'] = loss
361
+ self.tensorboard_logging(batch, output, self.global_step, train=False)
362
+
363
+ return output
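
The camera conversion inside forward_step is easy to miss, so here is a standalone sketch of it (the helper name cam_crop_to_translation and the values of 5000 px focal length and 256 px crop size are illustrative assumptions, not part of the commit): the head predicts a weak-perspective camera (s, tx, ty) in the crop, and the depth of the full translation follows tz = 2 * f / (img_size * s).

import torch

def cam_crop_to_translation(pred_cam, focal_length=5000.0, img_size=256.0, eps=1e-9):
    # pred_cam = (s, tx, ty) per sample; returns (tx, ty, tz)
    s, tx, ty = pred_cam.unbind(dim=-1)
    tz = 2 * focal_length / (img_size * s + eps)
    return torch.stack([tx, ty, tz], dim=-1)

print(cam_crop_to_translation(torch.tensor([[1.0, 0.0, 0.0]])))  # tensor([[ 0.0000,  0.0000, 39.0625]])
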
hmr2/models/losses.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class Keypoint2DLoss(nn.Module):
5
+
6
+ def __init__(self, loss_type: str = 'l1'):
7
+ """
8
+ 2D keypoint loss module.
9
+ Args:
10
+ loss_type (str): Choose between l1 and l2 losses.
11
+ """
12
+ super(Keypoint2DLoss, self).__init__()
13
+ if loss_type == 'l1':
14
+ self.loss_fn = nn.L1Loss(reduction='none')
15
+ elif loss_type == 'l2':
16
+ self.loss_fn = nn.MSELoss(reduction='none')
17
+ else:
18
+ raise NotImplementedError('Unsupported loss function')
19
+
20
+ def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
21
+ """
22
+ Compute 2D reprojection loss on the keypoints.
23
+ Args:
24
+ pred_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 2] containing projected 2D keypoints (B: batch_size, S: num_samples, N: num_keypoints)
25
+ gt_keypoints_2d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the ground truth 2D keypoints and confidence.
26
+ Returns:
27
+ torch.Tensor: 2D keypoint loss.
28
+ """
29
+ conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
30
+ batch_size = conf.shape[0]
31
+ loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1,2))
32
+ return loss.sum()
33
+
34
+
35
+ class Keypoint3DLoss(nn.Module):
36
+
37
+ def __init__(self, loss_type: str = 'l1'):
38
+ """
39
+ 3D keypoint loss module.
40
+ Args:
41
+ loss_type (str): Choose between l1 and l2 losses.
42
+ """
43
+ super(Keypoint3DLoss, self).__init__()
44
+ if loss_type == 'l1':
45
+ self.loss_fn = nn.L1Loss(reduction='none')
46
+ elif loss_type == 'l2':
47
+ self.loss_fn = nn.MSELoss(reduction='none')
48
+ else:
49
+ raise NotImplementedError('Unsupported loss function')
50
+
51
+ def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 39):
52
+ """
53
+ Compute 3D keypoint loss.
54
+ Args:
55
+ pred_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 3] containing the predicted 3D keypoints (B: batch_size, S: num_samples, N: num_keypoints)
56
+ gt_keypoints_3d (torch.Tensor): Tensor of shape [B, S, N, 4] containing the ground truth 3D keypoints and confidence.
57
+ Returns:
58
+ torch.Tensor: 3D keypoint loss.
59
+ """
60
+ batch_size = pred_keypoints_3d.shape[0]
61
+ gt_keypoints_3d = gt_keypoints_3d.clone()
62
+ pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1)
63
+ gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1)
64
+ conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
65
+ gt_keypoints_3d = gt_keypoints_3d[:, :, :-1]
66
+ loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1,2))
67
+ return loss.sum()
68
+
69
+ class ParameterLoss(nn.Module):
70
+
71
+ def __init__(self):
72
+ """
73
+ SMPL parameter loss module.
74
+ """
75
+ super(ParameterLoss, self).__init__()
76
+ self.loss_fn = nn.MSELoss(reduction='none')
77
+
78
+ def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor):
79
+ """
80
+ Compute SMPL parameter loss.
81
+ Args:
82
+ pred_param (torch.Tensor): Tensor of shape [B, S, ...] containing the predicted parameters (body pose / global orientation / betas)
83
+ gt_param (torch.Tensor): Tensor of shape [B, S, ...] containing the ground truth SMPL parameters.
84
+ Returns:
85
+ torch.Tensor: L2 parameter loss loss.
86
+ """
87
+ batch_size = pred_param.shape[0]
88
+ num_dims = len(pred_param.shape)
89
+ mask_dimension = [batch_size] + [1] * (num_dims-1)
90
+ has_param = has_param.type(pred_param.type()).view(*mask_dimension)
91
+ loss_param = (has_param * self.loss_fn(pred_param, gt_param))
92
+ return loss_param.sum()
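
A minimal sketch (not part of the commit) showing how the confidence channel weights the keypoint losses above: the last channel of the ground truth acts as a per-keypoint weight.

import torch

loss_fn = Keypoint2DLoss(loss_type='l1')
pred = torch.zeros(2, 44, 2)     # (B, N, 2) predicted keypoints
gt = torch.zeros(2, 44, 3)       # (B, N, 3) -> (x, y, confidence)
gt[:, :, -1] = 1.0               # fully confident annotations
gt[:, :, 0] = 0.1                # constant x error of 0.1
print(loss_fn(pred, gt))         # tensor(8.8000) == 2 * 44 * 0.1
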
hmr2/models/smpl_wrapper.py ADDED
@@ -0,0 +1,41 @@
1
+ import torch
2
+ import numpy as np
3
+ import pickle
4
+ from typing import Optional
5
+ import smplx
6
+ from smplx.lbs import vertices2joints
7
+ from smplx.utils import SMPLOutput
8
+
9
+
10
+ class SMPL(smplx.SMPLLayer):
11
+ def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
12
+ """
13
+ Extension of the official SMPL implementation to support more joints.
14
+ Args:
15
+ Same as SMPLLayer.
16
+ joint_regressor_extra (str): Path to extra joint regressor.
17
+ """
18
+ super(SMPL, self).__init__(*args, **kwargs)
19
+ smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
20
+ 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
21
+
22
+ if joint_regressor_extra is not None:
23
+ self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
24
+ self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long))
25
+ self.update_hips = update_hips
26
+
27
+ def forward(self, *args, **kwargs) -> SMPLOutput:
28
+ """
29
+ Run the forward pass. Same as SMPL, and also appends an extra set of joints if joint_regressor_extra is specified.
30
+ """
31
+ smpl_output = super(SMPL, self).forward(*args, **kwargs)
32
+ joints = smpl_output.joints[:, self.joint_map, :]
33
+ if self.update_hips:
34
+ joints[:,[9,12]] = joints[:,[9,12]] + \
35
+ 0.25*(joints[:,[9,12]]-joints[:,[12,9]]) + \
36
+ 0.5*(joints[:,[8]] - 0.5*(joints[:,[9,12]] + joints[:,[12,9]]))
37
+ if hasattr(self, 'joint_regressor_extra'):
38
+ extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
39
+ joints = torch.cat([joints, extra_joints], dim=1)
40
+ smpl_output.joints = joints
41
+ return smpl_output
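
A sketch of what the joint_map buffer above does (dummy tensors only; no SMPL model files are needed for the illustration): it selects and reorders the joints produced by smplx into the 25-joint OpenPose body layout before any extra regressed joints are appended.

import torch

smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
                    7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
joint_map = torch.tensor(smpl_to_openpose, dtype=torch.long)
joints = torch.randn(2, 45, 3)             # (B, J, 3) joints as produced by smplx
openpose_joints = joints[:, joint_map, :]  # (B, 25, 3) in OpenPose ordering
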
hmr2/utils/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ from typing import Any
3
+
4
+ from .renderer import Renderer
5
+ from .mesh_renderer import MeshRenderer
6
+ from .skeleton_renderer import SkeletonRenderer
7
+ from .pose_utils import eval_pose, Evaluator
8
+
9
+ def recursive_to(x: Any, target: torch.device):
10
+ """
11
+ Recursively transfer a batch of data to the target device
12
+ Args:
13
+ x (Any): Batch of data.
14
+ target (torch.device): Target device.
15
+ Returns:
16
+ Batch of data where all tensors are transferred to the target device.
17
+ """
18
+ if isinstance(x, dict):
19
+ return {k: recursive_to(v, target) for k, v in x.items()}
20
+ elif isinstance(x, torch.Tensor):
21
+ return x.to(target)
22
+ elif isinstance(x, list):
23
+ return [recursive_to(i, target) for i in x]
24
+ else:
25
+ return x
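
Usage sketch for recursive_to (the device choice is an assumption; pass torch.device('cuda') when a GPU is available). Non-tensor leaves such as strings are returned unchanged.

import torch

batch = {'img': torch.zeros(1, 3, 256, 256),
         'meta': {'ids': [torch.tensor(0), torch.tensor(1)], 'name': 'example'}}
batch = recursive_to(batch, torch.device('cpu'))
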
hmr2/utils/geometry.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import Optional
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+ def aa_to_rotmat(theta: torch.Tensor):
6
+ """
7
+ Convert axis-angle representation to rotation matrix.
8
+ Works by first converting it to a quaternion.
9
+ Args:
10
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
11
+ Returns:
12
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
13
+ """
14
+ norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
15
+ angle = torch.unsqueeze(norm, -1)
16
+ normalized = torch.div(theta, angle)
17
+ angle = angle * 0.5
18
+ v_cos = torch.cos(angle)
19
+ v_sin = torch.sin(angle)
20
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
21
+ return quat_to_rotmat(quat)
22
+
23
+ def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
24
+ """
25
+ Convert quaternion representation to rotation matrix.
26
+ Args:
27
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
28
+ Returns:
29
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
30
+ """
31
+ norm_quat = quat
32
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
33
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
34
+
35
+ B = quat.size(0)
36
+
37
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
38
+ wx, wy, wz = w*x, w*y, w*z
39
+ xy, xz, yz = x*y, x*z, y*z
40
+
41
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
42
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
43
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
44
+ return rotMat
45
+
46
+
47
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ Convert 6D rotation representation to 3x3 rotation matrix.
50
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
51
+ Args:
52
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
53
+ Returns:
54
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
55
+ """
56
+ x = x.reshape(-1,2,3).permute(0, 2, 1).contiguous()
57
+ a1 = x[:, :, 0]
58
+ a2 = x[:, :, 1]
59
+ b1 = F.normalize(a1)
60
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
61
+ b3 = torch.cross(b1, b2)
62
+ return torch.stack((b1, b2, b3), dim=-1)
63
+
64
+ def perspective_projection(points: torch.Tensor,
65
+ translation: torch.Tensor,
66
+ focal_length: torch.Tensor,
67
+ camera_center: Optional[torch.Tensor] = None,
68
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
69
+ """
70
+ Computes the perspective projection of a set of 3D points.
71
+ Args:
72
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
73
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
74
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
75
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
76
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
77
+ Returns:
78
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
79
+ """
80
+ batch_size = points.shape[0]
81
+ if rotation is None:
82
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
83
+ if camera_center is None:
84
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
85
+ # Populate intrinsic camera matrix K.
86
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
87
+ K[:,0,0] = focal_length[:,0]
88
+ K[:,1,1] = focal_length[:,1]
89
+ K[:,2,2] = 1.
90
+ K[:,:-1, -1] = camera_center
91
+
92
+ # Transform points
93
+ points = torch.einsum('bij,bkj->bki', rotation, points)
94
+ points = points + translation.unsqueeze(1)
95
+
96
+ # Apply perspective distortion
97
+ projected_points = points / points[:,:,-1].unsqueeze(-1)
98
+
99
+ # Apply camera intrinsics
100
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
101
+
102
+ return projected_points[:, :, :-1]
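
A quick sanity sketch for rot6d_to_rotmat above (not part of the commit): the 6D vector is the concatenation of the first two columns of the target rotation, so feeding the first two columns of the identity recovers the identity matrix.

import torch

x = torch.tensor([[1., 0., 0., 0., 1., 0.]])          # columns e1 and e2
R = rot6d_to_rotmat(x)                                 # (1, 3, 3)
print(torch.allclose(R, torch.eye(3).unsqueeze(0)))    # True
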
hmr2/utils/mesh_renderer.py ADDED
@@ -0,0 +1,149 @@
1
+ import os
2
+ if 'PYOPENGL_PLATFORM' not in os.environ:
3
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
4
+ import torch
5
+ from torchvision.utils import make_grid
6
+ import numpy as np
7
+ import pyrender
8
+ import trimesh
9
+ import cv2
10
+ import torch.nn.functional as F
11
+
12
+ from .render_openpose import render_openpose
13
+
14
+ def create_raymond_lights():
15
+ import pyrender
16
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
17
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
18
+
19
+ nodes = []
20
+
21
+ for phi, theta in zip(phis, thetas):
22
+ xp = np.sin(theta) * np.cos(phi)
23
+ yp = np.sin(theta) * np.sin(phi)
24
+ zp = np.cos(theta)
25
+
26
+ z = np.array([xp, yp, zp])
27
+ z = z / np.linalg.norm(z)
28
+ x = np.array([-z[1], z[0], 0.0])
29
+ if np.linalg.norm(x) == 0:
30
+ x = np.array([1.0, 0.0, 0.0])
31
+ x = x / np.linalg.norm(x)
32
+ y = np.cross(z, x)
33
+
34
+ matrix = np.eye(4)
35
+ matrix[:3,:3] = np.c_[x,y,z]
36
+ nodes.append(pyrender.Node(
37
+ light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
38
+ matrix=matrix
39
+ ))
40
+
41
+ return nodes
42
+
43
+ class MeshRenderer:
44
+
45
+ def __init__(self, cfg, faces=None):
46
+ self.cfg = cfg
47
+ self.focal_length = cfg.EXTRA.FOCAL_LENGTH
48
+ self.img_res = cfg.MODEL.IMAGE_SIZE
49
+ self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
50
+ viewport_height=self.img_res,
51
+ point_size=1.0)
52
+
53
+ self.camera_center = [self.img_res // 2, self.img_res // 2]
54
+ self.faces = faces
55
+
56
+ def visualize(self, vertices, camera_translation, images, focal_length=None, nrow=3, padding=2):
57
+ images_np = np.transpose(images, (0,2,3,1))
58
+ rend_imgs = []
59
+ for i in range(vertices.shape[0]):
60
+ fl = self.focal_length
61
+ rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2,0,1))).float()
62
+ rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2,0,1))).float()
63
+ rend_imgs.append(torch.from_numpy(images[i]))
64
+ rend_imgs.append(rend_img)
65
+ rend_imgs.append(rend_img_side)
66
+ rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
67
+ return rend_imgs
68
+
69
+ def visualize_tensorboard(self, vertices, camera_translation, images, pred_keypoints, gt_keypoints, focal_length=None, nrow=5, padding=2):
70
+ images_np = np.transpose(images, (0,2,3,1))
71
+ rend_imgs = []
72
+ pred_keypoints = np.concatenate((pred_keypoints, np.ones_like(pred_keypoints)[:, :, [0]]), axis=-1)
73
+ pred_keypoints = self.img_res * (pred_keypoints + 0.5)
74
+ gt_keypoints[:, :, :-1] = self.img_res * (gt_keypoints[:, :, :-1] + 0.5)
75
+ keypoint_matches = [(1, 12), (2, 8), (3, 7), (4, 6), (5, 9), (6, 10), (7, 11), (8, 14), (9, 2), (10, 1), (11, 0), (12, 3), (13, 4), (14, 5)]
76
+ for i in range(vertices.shape[0]):
77
+ fl = self.focal_length
78
+ rend_img = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=False), (2,0,1))).float()
79
+ rend_img_side = torch.from_numpy(np.transpose(self.__call__(vertices[i], camera_translation[i], images_np[i], focal_length=fl, side_view=True), (2,0,1))).float()
80
+ body_keypoints = pred_keypoints[i, :25]
81
+ extra_keypoints = pred_keypoints[i, -19:]
82
+ for pair in keypoint_matches:
83
+ body_keypoints[pair[0], :] = extra_keypoints[pair[1], :]
84
+ pred_keypoints_img = render_openpose(255 * images_np[i].copy(), body_keypoints) / 255
85
+ body_keypoints = gt_keypoints[i, :25]
86
+ extra_keypoints = gt_keypoints[i, -19:]
87
+ for pair in keypoint_matches:
88
+ if extra_keypoints[pair[1], -1] > 0 and body_keypoints[pair[0], -1] == 0:
89
+ body_keypoints[pair[0], :] = extra_keypoints[pair[1], :]
90
+ gt_keypoints_img = render_openpose(255*images_np[i].copy(), body_keypoints) / 255
91
+ rend_imgs.append(torch.from_numpy(images[i]))
92
+ rend_imgs.append(rend_img)
93
+ rend_imgs.append(rend_img_side)
94
+ rend_imgs.append(torch.from_numpy(pred_keypoints_img).permute(2,0,1))
95
+ rend_imgs.append(torch.from_numpy(gt_keypoints_img).permute(2,0,1))
96
+ rend_imgs = make_grid(rend_imgs, nrow=nrow, padding=padding)
97
+ return rend_imgs
98
+
99
+ def __call__(self, vertices, camera_translation, image, focal_length=5000, text=None, resize=None, side_view=False, baseColorFactor=(1.0, 1.0, 0.9, 1.0), rot_angle=90):
100
+ renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
101
+ viewport_height=image.shape[0],
102
+ point_size=1.0)
103
+ material = pyrender.MetallicRoughnessMaterial(
104
+ metallicFactor=0.0,
105
+ alphaMode='OPAQUE',
106
+ baseColorFactor=baseColorFactor)
107
+
108
+ camera_translation[0] *= -1.
109
+
110
+ mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
111
+ if side_view:
112
+ rot = trimesh.transformations.rotation_matrix(
113
+ np.radians(rot_angle), [0, 1, 0])
114
+ mesh.apply_transform(rot)
115
+ rot = trimesh.transformations.rotation_matrix(
116
+ np.radians(180), [1, 0, 0])
117
+ mesh.apply_transform(rot)
118
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
119
+
120
+ scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
121
+ ambient_light=(0.3, 0.3, 0.3))
122
+ scene.add(mesh, 'mesh')
123
+
124
+ camera_pose = np.eye(4)
125
+ camera_pose[:3, 3] = camera_translation
126
+ camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
127
+ camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
128
+ cx=camera_center[0], cy=camera_center[1])
129
+ scene.add(camera, pose=camera_pose)
130
+
131
+
132
+ light_nodes = create_raymond_lights()
133
+ for node in light_nodes:
134
+ scene.add_node(node)
135
+
136
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
137
+ color = color.astype(np.float32) / 255.0
138
+ valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
139
+ if not side_view:
140
+ output_img = (color[:, :, :3] * valid_mask +
141
+ (1 - valid_mask) * image)
142
+ else:
143
+ output_img = color[:, :, :3]
144
+ if resize is not None:
145
+ output_img = cv2.resize(output_img, resize)
146
+
147
+ output_img = output_img.astype(np.float32)
148
+ renderer.delete()
149
+ return output_img
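
The final compositing step of __call__ above can be illustrated in isolation (a sketch with fake arrays, not part of the commit): the rendered RGBA output is blended over the input crop using its alpha channel as a mask, so pixels the mesh does not cover keep the original image.

import numpy as np

color = np.zeros((4, 4, 4), dtype=np.float32)   # fake RGBA render, fully transparent
image = np.ones((4, 4, 3), dtype=np.float32)    # fake input crop
valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
output_img = color[:, :, :3] * valid_mask + (1 - valid_mask) * image
print(output_img.mean())                        # 1.0 -> transparent pixels keep the input
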
hmr2/utils/pose_utils.py ADDED
@@ -0,0 +1,306 @@
1
+ """
2
+ Code adapted from: https://github.com/akanazawa/hmr/blob/master/src/benchmark/eval_util.py
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ from typing import Optional, Dict, List, Tuple
8
+
9
+ def compute_similarity_transform(S1: torch.Tensor, S2: torch.Tensor) -> torch.Tensor:
10
+ """
11
+ Computes a similarity transform (sR, t) in a batched way that maps
12
+ a set of 3D points S1 (B, N, 3) as closely as possible onto a set of 3D points S2 (B, N, 3),
13
+ where R is a 3x3 rotation matrix, t a 3x1 translation and s a scale,
14
+ i.e. solves the orthogonal Procrustes problem.
15
+ Args:
16
+ S1 (torch.Tensor): First set of points of shape (B, N, 3).
17
+ S2 (torch.Tensor): Second set of points of shape (B, N, 3).
18
+ Returns:
19
+ (torch.Tensor): The first set of points after applying the similarity transformation.
20
+ """
21
+
22
+ batch_size = S1.shape[0]
23
+ S1 = S1.permute(0, 2, 1)
24
+ S2 = S2.permute(0, 2, 1)
25
+ # 1. Remove mean.
26
+ mu1 = S1.mean(dim=2, keepdim=True)
27
+ mu2 = S2.mean(dim=2, keepdim=True)
28
+ X1 = S1 - mu1
29
+ X2 = S2 - mu2
30
+
31
+ # 2. Compute variance of X1 used for scale.
32
+ var1 = (X1**2).sum(dim=(1,2))
33
+
34
+ # 3. The outer product of X1 and X2.
35
+ K = torch.matmul(X1, X2.permute(0, 2, 1))
36
+
37
+ # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are singular vectors of K.
38
+ U, s, V = torch.svd(K)
39
+ Vh = V.permute(0, 2, 1)
40
+
41
+ # Construct Z that fixes the orientation of R to get det(R)=1.
42
+ Z = torch.eye(U.shape[1], device=U.device).unsqueeze(0).repeat(batch_size, 1, 1)
43
+ Z[:, -1, -1] *= torch.sign(torch.linalg.det(torch.matmul(U, Vh)))
44
+
45
+ # Construct R.
46
+ R = torch.matmul(torch.matmul(V, Z), U.permute(0, 2, 1))
47
+
48
+ # 5. Recover scale.
49
+ trace = torch.matmul(R, K).diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1)
50
+ scale = (trace / var1).unsqueeze(dim=-1).unsqueeze(dim=-1)
51
+
52
+ # 6. Recover translation.
53
+ t = mu2 - scale*torch.matmul(R, mu1)
54
+
55
+ # 7. Error:
56
+ S1_hat = scale*torch.matmul(R, S1) + t
57
+
58
+ return S1_hat.permute(0, 2, 1)
59
+
60
+ def reconstruction_error(S1, S2) -> np.array:
61
+ """
62
+ Computes the mean Euclidean distance between two sets of points S1, S2 after performing Procrustes alignment.
63
+ Args:
64
+ S1 (torch.Tensor): First set of points of shape (B, N, 3).
65
+ S2 (torch.Tensor): Second set of points of shape (B, N, 3).
66
+ Returns:
67
+ (np.array): Reconstruction error.
68
+ """
69
+ S1_hat = compute_similarity_transform(S1, S2)
70
+ re = torch.sqrt( ((S1_hat - S2)** 2).sum(dim=-1)).mean(dim=-1)
71
+ return re
72
+
73
+ def eval_pose(pred_joints, gt_joints) -> Tuple[np.array, np.array]:
74
+ """
75
+ Compute joint errors in mm before and after Procrustes alignment.
76
+ Args:
77
+ pred_joints (torch.Tensor): Predicted 3D joints of shape (B, N, 3).
78
+ gt_joints (torch.Tensor): Ground truth 3D joints of shape (B, N, 3).
79
+ Returns:
80
+ Tuple[np.array, np.array]: Joint errors in mm before and after alignment.
81
+ """
82
+ # Absolute error (MPJPE)
83
+ mpjpe = torch.sqrt(((pred_joints - gt_joints) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy()
84
+
85
+ # Reconstruction error (after Procrustes alignment)
86
+ r_error = reconstruction_error(pred_joints, gt_joints).cpu().numpy()
87
+ return 1000 * mpjpe, 1000 * r_error
88
+
89
+ class Evaluator:
90
+
91
+ def __init__(self,
92
+ dataset_length: int,
93
+ keypoint_list: List,
94
+ pelvis_ind: int,
95
+ metrics: List = ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re'],
96
+ pck_thresholds: Optional[List] = None):
97
+ """
98
+ Class used for evaluating trained models on different 3D pose datasets.
99
+ Args:
100
+ dataset_length (int): Total dataset length.
101
+ keypoint_list [List]: List of keypoints used for evaluation.
102
+ pelvis_ind (int): Index of pelvis keypoint; used for aligning the predictions and ground truth.
103
+ metrics [List]: List of evaluation metrics to record.
104
+ """
105
+ self.dataset_length = dataset_length
106
+ self.keypoint_list = keypoint_list
107
+ self.pelvis_ind = pelvis_ind
108
+ self.metrics = metrics
109
+ for metric in self.metrics:
110
+ setattr(self, metric, np.zeros((dataset_length,)))
111
+ self.counter = 0
112
+ if pck_thresholds is None:
113
+ self.pck_evaluator = None
114
+ else:
115
+ self.pck_evaluator = EvaluatorPCK(pck_thresholds)
116
+
117
+ def log(self):
118
+ """
119
+ Print current evaluation metrics
120
+ """
121
+ if self.counter == 0:
122
+ print('Evaluation has not started')
123
+ return
124
+ print(f'{self.counter} / {self.dataset_length} samples')
125
+ if self.pck_evaluator is not None:
126
+ self.pck_evaluator.log()
127
+ for metric in self.metrics:
128
+ if metric in ['mode_mpjpe', 'mode_re', 'min_mpjpe', 'min_re']:
129
+ unit = 'mm'
130
+ else:
131
+ unit = ''
132
+ print(f'{metric}: {getattr(self, metric)[:self.counter].mean()} {unit}')
133
+ print('***')
134
+
135
+ def get_metrics_dict(self) -> Dict:
136
+ """
137
+ Returns:
138
+ Dict: Dictionary of evaluation metrics.
139
+ """
140
+ d1 = {metric: getattr(self, metric)[:self.counter].mean() for metric in self.metrics}
141
+ if self.pck_evaluator is not None:
142
+ d2 = self.pck_evaluator.get_metrics_dict()
143
+ d1.update(d2)
144
+ return d1
145
+
146
+ def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None):
147
+ """
148
+ Evaluate current batch.
149
+ Args:
150
+ output (Dict): Regression output.
151
+ batch (Dict): Dictionary containing images and their corresponding annotations.
152
+ opt_output (Dict): Optimization output.
153
+ """
154
+ if self.pck_evaluator is not None:
155
+ self.pck_evaluator(output, batch, opt_output)
156
+
157
+ pred_keypoints_3d = output['pred_keypoints_3d'].detach()
158
+ pred_keypoints_3d = pred_keypoints_3d[:,None,:,:]
159
+ batch_size = pred_keypoints_3d.shape[0]
160
+ num_samples = pred_keypoints_3d.shape[1]
161
+ gt_keypoints_3d = batch['keypoints_3d'][:, :, :-1].unsqueeze(1).repeat(1, num_samples, 1, 1)
162
+
163
+ # Align predictions and ground truth such that the pelvis location is at the origin
164
+ pred_keypoints_3d -= pred_keypoints_3d[:, :, [self.pelvis_ind]]
165
+ gt_keypoints_3d -= gt_keypoints_3d[:, :, [self.pelvis_ind]]
166
+
167
+ # Compute joint errors
168
+ mpjpe, re = eval_pose(pred_keypoints_3d.reshape(batch_size * num_samples, -1, 3)[:, self.keypoint_list], gt_keypoints_3d.reshape(batch_size * num_samples, -1 ,3)[:, self.keypoint_list])
169
+ mpjpe = mpjpe.reshape(batch_size, num_samples)
170
+ re = re.reshape(batch_size, num_samples)
171
+
172
+ # Compute 2d keypoint errors
173
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach()
174
+ pred_keypoints_2d = pred_keypoints_2d[:,None,:,:]
175
+ gt_keypoints_2d = batch['keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1)
176
+ conf = gt_keypoints_2d[:, :, :, -1].clone()
177
+ kp_err = torch.nn.functional.mse_loss(
178
+ pred_keypoints_2d,
179
+ gt_keypoints_2d[:, :, :, :-1],
180
+ reduction='none'
181
+ ).sum(dim=3)
182
+ kp_l2_loss = (conf * kp_err).mean(dim=2)
183
+ kp_l2_loss = kp_l2_loss.detach().cpu().numpy()
184
+
185
+ # Compute joint errors after optimization, if available.
186
+ if opt_output is not None:
187
+ opt_keypoints_3d = opt_output['model_joints']
188
+ opt_keypoints_3d -= opt_keypoints_3d[:, [self.pelvis_ind]]
189
+ opt_mpjpe, opt_re = eval_pose(opt_keypoints_3d[:, self.keypoint_list], gt_keypoints_3d[:, 0, self.keypoint_list])
190
+
191
+ # The 0-th sample always corresponds to the mode
192
+ if hasattr(self, 'mode_mpjpe'):
193
+ mode_mpjpe = mpjpe[:, 0]
194
+ self.mode_mpjpe[self.counter:self.counter+batch_size] = mode_mpjpe
195
+ if hasattr(self, 'mode_re'):
196
+ mode_re = re[:, 0]
197
+ self.mode_re[self.counter:self.counter+batch_size] = mode_re
198
+ if hasattr(self, 'mode_kpl2'):
199
+ mode_kpl2 = kp_l2_loss[:, 0]
200
+ self.mode_kpl2[self.counter:self.counter+batch_size] = mode_kpl2
201
+ if hasattr(self, 'min_mpjpe'):
202
+ min_mpjpe = mpjpe.min(axis=-1)
203
+ self.min_mpjpe[self.counter:self.counter+batch_size] = min_mpjpe
204
+ if hasattr(self, 'min_re'):
205
+ min_re = re.min(axis=-1)
206
+ self.min_re[self.counter:self.counter+batch_size] = min_re
207
+ if hasattr(self, 'min_kpl2'):
208
+ min_kpl2 = kp_l2_loss.min(axis=-1)
209
+ self.min_kpl2[self.counter:self.counter+batch_size] = min_kpl2
210
+ if hasattr(self, 'opt_mpjpe'):
211
+ self.opt_mpjpe[self.counter:self.counter+batch_size] = opt_mpjpe
212
+ if hasattr(self, 'opt_re'):
213
+ self.opt_re[self.counter:self.counter+batch_size] = opt_re
214
+
215
+ self.counter += batch_size
216
+
217
+ if hasattr(self, 'mode_mpjpe') and hasattr(self, 'mode_re'):
218
+ return {
219
+ 'mode_mpjpe': mode_mpjpe,
220
+ 'mode_re': mode_re,
221
+ }
222
+ else:
223
+ return {}
224
+
225
+
226
+ class EvaluatorPCK:
227
+
228
+ def __init__(self, thresholds: List = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5],):
229
+ """
230
+ Class used for evaluating trained models on different 3D pose datasets.
231
+ Args:
232
+ thresholds [List]: List of PCK thresholds to evaluate.
233
+ metrics [List]: List of evaluation metrics to record.
234
+ """
235
+ self.thresholds = thresholds
236
+ self.pred_kp_2d = []
237
+ self.gt_kp_2d = []
238
+ self.gt_conf_2d = []
239
+ self.counter = 0
240
+
241
+ def log(self):
242
+ """
243
+ Print current evaluation metrics
244
+ """
245
+ if self.counter == 0:
246
+ print('Evaluation has not started')
247
+ return
248
+ print(f'{self.counter} samples')
249
+ metrics_dict = self.get_metrics_dict()
250
+ for metric in metrics_dict:
251
+ print(f'{metric}: {metrics_dict[metric]}')
252
+ print('***')
253
+
254
+ def get_metrics_dict(self) -> Dict:
255
+ """
256
+ Returns:
257
+ Dict: Dictionary of evaluation metrics.
258
+ """
259
+ pcks = self.compute_pcks()
260
+ metrics = {}
261
+ for thr, (acc,avg_acc,cnt) in zip(self.thresholds, pcks):
262
+ metrics.update({f'kp{i}_pck_{thr}': float(a) for i, a in enumerate(acc) if a>=0})
263
+ metrics.update({f'kpAvg_pck_{thr}': float(avg_acc)})
264
+ return metrics
265
+
266
+ def compute_pcks(self):
267
+ pred_kp_2d = np.concatenate(self.pred_kp_2d, axis=0)
268
+ gt_kp_2d = np.concatenate(self.gt_kp_2d, axis=0)
269
+ gt_conf_2d = np.concatenate(self.gt_conf_2d, axis=0)
270
+ assert pred_kp_2d.shape == gt_kp_2d.shape
271
+ assert pred_kp_2d[..., 0].shape == gt_conf_2d.shape
272
+ assert pred_kp_2d.shape[1] == 1 # num_samples
273
+
274
+ from mmpose.core.evaluation import keypoint_pck_accuracy
275
+ pcks = [
276
+ keypoint_pck_accuracy(
277
+ pred_kp_2d[:, 0, :, :],
278
+ gt_kp_2d[:, 0, :, :],
279
+ gt_conf_2d[:, 0, :]>0.5,
280
+ thr=thr,
281
+ normalize = np.ones((len(pred_kp_2d),2)) # Already in [-0.5,0.5] range. No need to normalize
282
+ )
283
+ for thr in self.thresholds
284
+ ]
285
+ return pcks
286
+
287
+ def __call__(self, output: Dict, batch: Dict, opt_output: Optional[Dict] = None):
288
+ """
289
+ Evaluate current batch.
290
+ Args:
291
+ output (Dict): Regression output.
292
+ batch (Dict): Dictionary containing images and their corresponding annotations.
293
+ opt_output (Dict): Optimization output.
294
+ """
295
+ pred_keypoints_2d = output['pred_keypoints_2d'].detach()
296
+ num_samples = 1
297
+ batch_size = pred_keypoints_2d.shape[0]
298
+
299
+ pred_keypoints_2d = pred_keypoints_2d[:,None,:,:]
300
+ gt_keypoints_2d = batch['keypoints_2d'][:,None,:,:].repeat(1, num_samples, 1, 1)
301
+
302
+ self.pred_kp_2d.append(pred_keypoints_2d[:, :, :, :2].detach().cpu().numpy())
303
+ self.gt_conf_2d.append(gt_keypoints_2d[:, :, :, -1].detach().cpu().numpy())
304
+ self.gt_kp_2d.append(gt_keypoints_2d[:, :, :, :2].detach().cpu().numpy())
305
+
306
+ self.counter += batch_size
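
A minimal sketch (not part of the commit) of eval_pose above: identical predicted and ground-truth joints give zero MPJPE and a near-zero Procrustes-aligned reconstruction error, both returned per sample in millimetres.

import torch

joints = torch.randn(2, 14, 3)
mpjpe, r_error = eval_pose(joints, joints.clone())
print(mpjpe.shape, r_error.shape)   # (2,) (2,)
print(mpjpe.max())                  # 0.0
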
hmr2/utils/render_openpose.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Render OpenPose keypoints.
3
+ Code was ported to Python from the official C++ implementation https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/utilities/keypoint.cpp
4
+ """
5
+ import cv2
6
+ import math
7
+ import numpy as np
8
+ from typing import List, Tuple
9
+
10
+ def get_keypoints_rectangle(keypoints: np.array, threshold: float) -> Tuple[float, float, float]:
11
+ """
12
+ Compute rectangle enclosing keypoints above the threshold.
13
+ Args:
14
+ keypoints (np.array): Keypoint array of shape (N, 3).
15
+ threshold (float): Confidence visualization threshold.
16
+ Returns:
17
+ Tuple[float, float, float]: Rectangle width, height and area.
18
+ """
19
+ valid_ind = keypoints[:, -1] > threshold
20
+ if valid_ind.sum() > 0:
21
+ valid_keypoints = keypoints[valid_ind][:, :-1]
22
+ max_x = valid_keypoints[:,0].max()
23
+ max_y = valid_keypoints[:,1].max()
24
+ min_x = valid_keypoints[:,0].min()
25
+ min_y = valid_keypoints[:,1].min()
26
+ width = max_x - min_x
27
+ height = max_y - min_y
28
+ area = width * height
29
+ return width, height, area
30
+ else:
31
+ return 0,0,0
32
+
33
+ def render_keypoints(img: np.array,
34
+ keypoints: np.array,
35
+ pairs: List,
36
+ colors: List,
37
+ thickness_circle_ratio: float,
38
+ thickness_line_ratio_wrt_circle: float,
39
+ pose_scales: List,
40
+ threshold: float = 0.1) -> np.array:
41
+ """
42
+ Render keypoints on input image.
43
+ Args:
44
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
45
+ keypoints (np.array): Keypoint array of shape (N, 3).
46
+ pairs (List): List of keypoint pairs per limb.
47
+ colors (List): List of colors per keypoint.
48
+ thickness_circle_ratio (float): Circle thickness ratio.
49
+ thickness_line_ratio_wrt_circle (float): Line thickness ratio wrt the circle.
50
+ pose_scales (List): List of pose scales.
51
+ threshold (float): Only visualize keypoints with confidence above the threshold.
52
+ Returns:
53
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
54
+ """
55
+ img_orig = img.copy()
56
+ width, height = img.shape[1], img.shape[0]  # img is (H, W, 3)
57
+ area = width * height
58
+
59
+ lineType = 8
60
+ shift = 0
61
+ numberColors = len(colors)
62
+ thresholdRectangle = 0.1
63
+
64
+ person_width, person_height, person_area = get_keypoints_rectangle(keypoints, thresholdRectangle)
65
+ if person_area > 0:
66
+ ratioAreas = min(1, max(person_width / width, person_height / height))
67
+ thicknessRatio = np.maximum(np.round(math.sqrt(area) * thickness_circle_ratio * ratioAreas), 2)
68
+ thicknessCircle = np.maximum(1, thicknessRatio if ratioAreas > 0.05 else -np.ones_like(thicknessRatio))
69
+ thicknessLine = np.maximum(1, np.round(thicknessRatio * thickness_line_ratio_wrt_circle))
70
+ radius = thicknessRatio / 2
71
+
72
+ img = np.ascontiguousarray(img.copy())
73
+ for i, pair in enumerate(pairs):
74
+ index1, index2 = pair
75
+ if keypoints[index1, -1] > threshold and keypoints[index2, -1] > threshold:
76
+ thicknessLineScaled = int(round(min(thicknessLine[index1], thicknessLine[index2]) * pose_scales[0]))
77
+ colorIndex = index2
78
+ color = colors[colorIndex % numberColors]
79
+ keypoint1 = keypoints[index1, :-1].astype(np.int32)
80
+ keypoint2 = keypoints[index2, :-1].astype(np.int32)
81
+ cv2.line(img, tuple(keypoint1.tolist()), tuple(keypoint2.tolist()), tuple(color.tolist()), thicknessLineScaled, lineType, shift)
82
+ for part in range(len(keypoints)):
83
+ faceIndex = part
84
+ if keypoints[faceIndex, -1] > threshold:
85
+ radiusScaled = int(round(radius[faceIndex] * pose_scales[0]))
86
+ thicknessCircleScaled = int(round(thicknessCircle[faceIndex] * pose_scales[0]))
87
+ colorIndex = part
88
+ color = colors[colorIndex % numberColors]
89
+ center = keypoints[faceIndex, :-1].astype(np.int32)
90
+ cv2.circle(img, tuple(center.tolist()), radiusScaled, tuple(color.tolist()), thicknessCircleScaled, lineType, shift)
91
+ return img
92
+
93
+ def render_body_keypoints(img: np.array,
94
+ body_keypoints: np.array) -> np.array:
95
+ """
96
+ Render OpenPose body keypoints on input image.
97
+ Args:
98
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
99
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
100
+ Returns:
101
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
102
+ """
103
+
104
+ thickness_circle_ratio = 1./75. * np.ones(body_keypoints.shape[0])
105
+ thickness_line_ratio_wrt_circle = 0.75
106
+ pairs = []
107
+ pairs = [1,8,1,2,1,5,2,3,3,4,5,6,6,7,8,9,9,10,10,11,8,12,12,13,13,14,1,0,0,15,15,17,0,16,16,18,14,19,19,20,14,21,11,22,22,23,11,24]
108
+ pairs = np.array(pairs).reshape(-1,2)
109
+ colors = [255., 0., 85.,
110
+ 255., 0., 0.,
111
+ 255., 85., 0.,
112
+ 255., 170., 0.,
113
+ 255., 255., 0.,
114
+ 170., 255., 0.,
115
+ 85., 255., 0.,
116
+ 0., 255., 0.,
117
+ 255., 0., 0.,
118
+ 0., 255., 85.,
119
+ 0., 255., 170.,
120
+ 0., 255., 255.,
121
+ 0., 170., 255.,
122
+ 0., 85., 255.,
123
+ 0., 0., 255.,
124
+ 255., 0., 170.,
125
+ 170., 0., 255.,
126
+ 255., 0., 255.,
127
+ 85., 0., 255.,
128
+ 0., 0., 255.,
129
+ 0., 0., 255.,
130
+ 0., 0., 255.,
131
+ 0., 255., 255.,
132
+ 0., 255., 255.,
133
+ 0., 255., 255.]
134
+ colors = np.array(colors).reshape(-1,3)
135
+ pose_scales = [1]
136
+ return render_keypoints(img, body_keypoints, pairs, colors, thickness_circle_ratio, thickness_line_ratio_wrt_circle, pose_scales, 0.1)
137
+
138
+ def render_openpose(img: np.array,
139
+ body_keypoints: np.array) -> np.array:
140
+ """
141
+ Render keypoints in the OpenPose format on input image.
142
+ Args:
143
+ img (np.array): Input image of shape (H, W, 3) with pixel values in the [0,255] range.
144
+ body_keypoints (np.array): Keypoint array of shape (N, 3); 3 <====> (x, y, confidence).
145
+ Returns:
146
+ (np.array): Image of shape (H, W, 3) with keypoints drawn on top of the original image.
147
+ """
148
+ img = render_body_keypoints(img, body_keypoints)
149
+ return img
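A minimal usage sketch for the module above, assuming an image on disk and a (25, 3) array of OpenPose body keypoints in pixel (x, y, confidence) format; the file name and keypoint values are placeholders:

import cv2
import numpy as np
from hmr2.utils.render_openpose import render_openpose

# Load an image as RGB float in [0, 255] and build a dummy 25-joint keypoint array.
img = cv2.imread("example.jpg")[:, :, ::-1].astype(np.float32)
h, w = img.shape[:2]
body_keypoints = np.zeros((25, 3), dtype=np.float32)
body_keypoints[0] = [0.45 * w, 0.40 * h, 0.9]   # placeholder nose keypoint (x, y, confidence)
body_keypoints[1] = [0.55 * w, 0.55 * h, 0.9]   # placeholder neck keypoint

# Draw the limbs/joints and save; the output image is float-valued in [0, 255].
overlay = render_openpose(img.copy(), body_keypoints)
cv2.imwrite("overlay.jpg", overlay[:, :, ::-1].astype(np.uint8))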
hmr2/utils/renderer.py ADDED
@@ -0,0 +1,396 @@
1
+ import os
2
+ if 'PYOPENGL_PLATFORM' not in os.environ:
3
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
4
+ import torch
5
+ import numpy as np
6
+ import pyrender
7
+ import trimesh
8
+ import cv2
9
+ from yacs.config import CfgNode
10
+ from typing import List, Optional
11
+
12
+ def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.):
13
+ # Convert cam_bbox to full image
14
+ img_w, img_h = img_size[:, 0], img_size[:, 1]
15
+ cx, cy, b = box_center[:, 0], box_center[:, 1], box_size
16
+ w_2, h_2 = img_w / 2., img_h / 2.
17
+ bs = b * cam_bbox[:, 0] + 1e-9
18
+ tz = 2 * focal_length / bs
19
+ tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1]
20
+ ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2]
21
+ full_cam = torch.stack([tx, ty, tz], dim=-1)
22
+ return full_cam
23
+
24
+ def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12):
25
+ # get lights in a circle around origin at elevation
26
+ thetas = elevation * np.ones(n_lights)
27
+ phis = 2 * np.pi * np.arange(n_lights) / n_lights
28
+ poses = []
29
+ trans = make_translation(torch.tensor([0, 0, dist]))
30
+ for phi, theta in zip(phis, thetas):
31
+ rot = make_rotation(rx=-theta, ry=phi, order="xyz")
32
+ poses.append((rot @ trans).numpy())
33
+ return poses
34
+
35
+ def make_translation(t):
36
+ return make_4x4_pose(torch.eye(3), t)
37
+
38
+ def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
39
+ Rx = rotx(rx)
40
+ Ry = roty(ry)
41
+ Rz = rotz(rz)
42
+ if order == "xyz":
43
+ R = Rz @ Ry @ Rx
44
+ elif order == "xzy":
45
+ R = Ry @ Rz @ Rx
46
+ elif order == "yxz":
47
+ R = Rz @ Rx @ Ry
48
+ elif order == "yzx":
49
+ R = Rx @ Rz @ Ry
50
+ elif order == "zyx":
51
+ R = Rx @ Ry @ Rz
52
+ elif order == "zxy":
53
+ R = Ry @ Rx @ Rz
54
+ return make_4x4_pose(R, torch.zeros(3))
55
+
56
+ def make_4x4_pose(R, t):
57
+ """
58
+ :param R (*, 3, 3)
59
+ :param t (*, 3)
60
+ return (*, 4, 4)
61
+ """
62
+ dims = R.shape[:-2]
63
+ pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1)
64
+ bottom = (
65
+ torch.tensor([0, 0, 0, 1], device=R.device)
66
+ .reshape(*(1,) * len(dims), 1, 4)
67
+ .expand(*dims, 1, 4)
68
+ )
69
+ return torch.cat([pose_3x4, bottom], dim=-2)
70
+
71
+
72
+ def rotx(theta):
73
+ return torch.tensor(
74
+ [
75
+ [1, 0, 0],
76
+ [0, np.cos(theta), -np.sin(theta)],
77
+ [0, np.sin(theta), np.cos(theta)],
78
+ ],
79
+ dtype=torch.float32,
80
+ )
81
+
82
+
83
+ def roty(theta):
84
+ return torch.tensor(
85
+ [
86
+ [np.cos(theta), 0, np.sin(theta)],
87
+ [0, 1, 0],
88
+ [-np.sin(theta), 0, np.cos(theta)],
89
+ ],
90
+ dtype=torch.float32,
91
+ )
92
+
93
+
94
+ def rotz(theta):
95
+ return torch.tensor(
96
+ [
97
+ [np.cos(theta), -np.sin(theta), 0],
98
+ [np.sin(theta), np.cos(theta), 0],
99
+ [0, 0, 1],
100
+ ],
101
+ dtype=torch.float32,
102
+ )
103
+
104
+
105
+ def create_raymond_lights() -> List[pyrender.Node]:
106
+ """
107
+ Return raymond light nodes for the scene.
108
+ """
109
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
110
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
111
+
112
+ nodes = []
113
+
114
+ for phi, theta in zip(phis, thetas):
115
+ xp = np.sin(theta) * np.cos(phi)
116
+ yp = np.sin(theta) * np.sin(phi)
117
+ zp = np.cos(theta)
118
+
119
+ z = np.array([xp, yp, zp])
120
+ z = z / np.linalg.norm(z)
121
+ x = np.array([-z[1], z[0], 0.0])
122
+ if np.linalg.norm(x) == 0:
123
+ x = np.array([1.0, 0.0, 0.0])
124
+ x = x / np.linalg.norm(x)
125
+ y = np.cross(z, x)
126
+
127
+ matrix = np.eye(4)
128
+ matrix[:3,:3] = np.c_[x,y,z]
129
+ nodes.append(pyrender.Node(
130
+ light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
131
+ matrix=matrix
132
+ ))
133
+
134
+ return nodes
135
+
136
+ class Renderer:
137
+
138
+ def __init__(self, cfg: CfgNode, faces: np.array):
139
+ """
140
+ Wrapper around the pyrender renderer to render SMPL meshes.
141
+ Args:
142
+ cfg (CfgNode): Model config file.
143
+ faces (np.array): Array of shape (F, 3) containing the mesh faces.
144
+ """
145
+ self.cfg = cfg
146
+ self.focal_length = cfg.EXTRA.FOCAL_LENGTH
147
+ self.img_res = cfg.MODEL.IMAGE_SIZE
148
+ # self.renderer = pyrender.OffscreenRenderer(viewport_width=self.img_res,
149
+ # viewport_height=self.img_res,
150
+ # point_size=1.0)
151
+
152
+ self.camera_center = [self.img_res // 2, self.img_res // 2]
153
+ self.faces = faces
154
+
155
+ def __call__(self,
156
+ vertices: np.array,
157
+ camera_translation: np.array,
158
+ image: torch.Tensor,
159
+ full_frame: bool = False,
160
+ imgname: Optional[str] = None,
161
+ side_view=False, rot_angle=90,
162
+ mesh_base_color=(1.0, 1.0, 0.9),
163
+ scene_bg_color=(0,0,0),
164
+ return_rgba=False,
165
+ ) -> np.array:
166
+ """
167
+ Render meshes on input image
168
+ Args:
169
+ vertices (np.array): Array of shape (V, 3) containing the mesh vertices.
170
+ camera_translation (np.array): Array of shape (3,) with the camera translation.
171
+ image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values.
172
+ full_frame (bool): If True, then render on the full image.
173
+ imgname (Optional[str]): Contains the original image filename. Used only if full_frame == True.
174
+ """
175
+
176
+ if full_frame:
177
+ image = cv2.imread(imgname).astype(np.float32)[:, :, ::-1] / 255.
178
+ else:
179
+ image = image.clone() * torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3,1,1)
180
+ image = image + torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3,1,1)
181
+ image = image.permute(1, 2, 0).cpu().numpy()
182
+
183
+ renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
184
+ viewport_height=image.shape[0],
185
+ point_size=1.0)
186
+ material = pyrender.MetallicRoughnessMaterial(
187
+ metallicFactor=0.0,
188
+ alphaMode='OPAQUE',
189
+ baseColorFactor=(*mesh_base_color, 1.0))
190
+
191
+ camera_translation[0] *= -1.
192
+
193
+ mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
194
+ if side_view:
195
+ rot = trimesh.transformations.rotation_matrix(
196
+ np.radians(rot_angle), [0, 1, 0])
197
+ mesh.apply_transform(rot)
198
+ rot = trimesh.transformations.rotation_matrix(
199
+ np.radians(180), [1, 0, 0])
200
+ mesh.apply_transform(rot)
201
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
202
+
203
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
204
+ ambient_light=(0.3, 0.3, 0.3))
205
+ scene.add(mesh, 'mesh')
206
+
207
+ camera_pose = np.eye(4)
208
+ camera_pose[:3, 3] = camera_translation
209
+ camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
210
+ camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length,
211
+ cx=camera_center[0], cy=camera_center[1])
212
+ scene.add(camera, pose=camera_pose)
213
+
214
+
215
+ light_nodes = create_raymond_lights()
216
+ for node in light_nodes:
217
+ scene.add_node(node)
218
+
219
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
220
+ color = color.astype(np.float32) / 255.0
221
+ renderer.delete()
222
+
223
+ if return_rgba:
224
+ return color
225
+
226
+ valid_mask = (color[:, :, -1])[:, :, np.newaxis]
227
+ if not side_view:
228
+ output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image)
229
+ else:
230
+ output_img = color[:, :, :3]
231
+
232
+ output_img = output_img.astype(np.float32)
233
+ return output_img
234
+
235
+ def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9),
236
+ rot_axis=[1,0,0], rot_angle=0,):
237
+ # material = pyrender.MetallicRoughnessMaterial(
238
+ # metallicFactor=0.0,
239
+ # alphaMode='OPAQUE',
240
+ # baseColorFactor=(*mesh_base_color, 1.0))
241
+ vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0])
242
+ # print(vertices.shape, camera_translation.shape)  # debug output disabled
243
+ mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors)
244
+ # mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
245
+
246
+ rot = trimesh.transformations.rotation_matrix(
247
+ np.radians(rot_angle), rot_axis)
248
+ mesh.apply_transform(rot)
249
+
250
+ rot = trimesh.transformations.rotation_matrix(
251
+ np.radians(180), [1, 0, 0])
252
+ mesh.apply_transform(rot)
253
+ return mesh
254
+
255
+ def render_rgba(
256
+ self,
257
+ vertices: np.array,
258
+ cam_t = None,
259
+ rot=None,
260
+ rot_axis=[1,0,0],
261
+ rot_angle=0,
262
+ camera_z=3,
263
+ # camera_translation: np.array,
264
+ mesh_base_color=(1.0, 1.0, 0.9),
265
+ scene_bg_color=(0,0,0),
266
+ render_res=[256, 256],
267
+ ):
268
+
269
+ renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
270
+ viewport_height=render_res[1],
271
+ point_size=1.0)
272
+ # material = pyrender.MetallicRoughnessMaterial(
273
+ # metallicFactor=0.0,
274
+ # alphaMode='OPAQUE',
275
+ # baseColorFactor=(*mesh_base_color, 1.0))
276
+
277
+ if cam_t is not None:
278
+ camera_translation = cam_t.copy()
279
+ # camera_translation[0] *= -1.
280
+ else:
281
+ camera_translation = np.array([0, 0, camera_z * self.focal_length/render_res[1]])
282
+
283
+ mesh = self.vertices_to_trimesh(vertices, camera_translation, mesh_base_color, rot_axis, rot_angle)
284
+ mesh = pyrender.Mesh.from_trimesh(mesh)
285
+ # mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
286
+
287
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
288
+ ambient_light=(0.3, 0.3, 0.3))
289
+ scene.add(mesh, 'mesh')
290
+
291
+ camera_pose = np.eye(4)
292
+ # camera_pose[:3, 3] = camera_translation
293
+ camera_center = [render_res[0] / 2., render_res[1] / 2.]
294
+ camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length,
295
+ cx=camera_center[0], cy=camera_center[1])
296
+
297
+ # Create camera node and add it to pyRender scene
298
+ camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
299
+ scene.add_node(camera_node)
300
+ self.add_point_lighting(scene, camera_node)
301
+ self.add_lighting(scene, camera_node)
302
+
303
+ light_nodes = create_raymond_lights()
304
+ for node in light_nodes:
305
+ scene.add_node(node)
306
+
307
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
308
+ color = color.astype(np.float32) / 255.0
309
+ renderer.delete()
310
+
311
+ return color
312
+
313
+ def render_rgba_multiple(
314
+ self,
315
+ vertices: List[np.array],
316
+ cam_t: List[np.array],
317
+ rot_axis=[1,0,0],
318
+ rot_angle=0,
319
+ mesh_base_color=(1.0, 1.0, 0.9),
320
+ scene_bg_color=(0,0,0),
321
+ render_res=[256, 256],
322
+ ):
323
+
324
+ renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
325
+ viewport_height=render_res[1],
326
+ point_size=1.0)
327
+ # material = pyrender.MetallicRoughnessMaterial(
328
+ # metallicFactor=0.0,
329
+ # alphaMode='OPAQUE',
330
+ # baseColorFactor=(*mesh_base_color, 1.0))
331
+
332
+ mesh_list = [pyrender.Mesh.from_trimesh(self.vertices_to_trimesh(vvv, ttt.copy(), mesh_base_color, rot_axis, rot_angle)) for vvv,ttt in zip(vertices, cam_t)]
333
+
334
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
335
+ ambient_light=(0.3, 0.3, 0.3))
336
+ for i,mesh in enumerate(mesh_list):
337
+ scene.add(mesh, f'mesh_{i}')
338
+
339
+ camera_pose = np.eye(4)
340
+ # camera_pose[:3, 3] = camera_translation
341
+ camera_center = [render_res[0] / 2., render_res[1] / 2.]
342
+ camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length,
343
+ cx=camera_center[0], cy=camera_center[1])
344
+
345
+ # Create camera node and add it to pyRender scene
346
+ camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
347
+ scene.add_node(camera_node)
348
+ self.add_point_lighting(scene, camera_node)
349
+ self.add_lighting(scene, camera_node)
350
+
351
+ light_nodes = create_raymond_lights()
352
+ for node in light_nodes:
353
+ scene.add_node(node)
354
+
355
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
356
+ color = color.astype(np.float32) / 255.0
357
+ renderer.delete()
358
+
359
+ return color
360
+
361
+ def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
362
+ # from phalp.visualize.py_renderer import get_light_poses
363
+ light_poses = get_light_poses()
364
+ light_poses.append(np.eye(4))
365
+ cam_pose = scene.get_pose(cam_node)
366
+ for i, pose in enumerate(light_poses):
367
+ matrix = cam_pose @ pose
368
+ node = pyrender.Node(
369
+ name=f"light-{i:02d}",
370
+ light=pyrender.DirectionalLight(color=color, intensity=intensity),
371
+ matrix=matrix,
372
+ )
373
+ if scene.has_node(node):
374
+ continue
375
+ scene.add_node(node)
376
+
377
+ def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
378
+ # from phalp.visualize.py_renderer import get_light_poses
379
+ light_poses = get_light_poses(dist=0.5)
380
+ light_poses.append(np.eye(4))
381
+ cam_pose = scene.get_pose(cam_node)
382
+ for i, pose in enumerate(light_poses):
383
+ matrix = cam_pose @ pose
384
+ # node = pyrender.Node(
385
+ # name=f"light-{i:02d}",
386
+ # light=pyrender.DirectionalLight(color=color, intensity=intensity),
387
+ # matrix=matrix,
388
+ # )
389
+ node = pyrender.Node(
390
+ name=f"plight-{i:02d}",
391
+ light=pyrender.PointLight(color=color, intensity=intensity),
392
+ matrix=matrix,
393
+ )
394
+ if scene.has_node(node):
395
+ continue
396
+ scene.add_node(node)
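A hedged usage sketch for the Renderer above; cfg, smpl_faces, and the model outputs (pred_vertices, pred_cam_t) are placeholders that would come from the HMR2 model setup elsewhere in the repo:

import cv2
import numpy as np
from hmr2.utils.renderer import Renderer

renderer = Renderer(cfg, faces=smpl_faces)          # cfg must provide MODEL.IMAGE_SIZE and EXTRA.FOCAL_LENGTH
verts = pred_vertices[0].detach().cpu().numpy()     # (V, 3) SMPL vertices for one person
cam_t = pred_cam_t[0].detach().cpu().numpy()        # (3,) camera translation for that person

# Render an RGBA view of the mesh; the alpha channel can be used to composite over the input image.
rgba = renderer.render_rgba(verts, cam_t=cam_t, render_res=[256, 256])
cv2.imwrite("mesh.png", (255 * rgba[:, :, [2, 1, 0, 3]]).astype(np.uint8))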
hmr2/utils/skeleton_renderer.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch
2
+ import numpy as np
3
+ import trimesh
4
+ from typing import Optional
5
+ from yacs.config import CfgNode
6
+
7
+ from .geometry import perspective_projection
8
+ from .render_openpose import render_openpose
9
+
10
+ class SkeletonRenderer:
11
+
12
+ def __init__(self, cfg: CfgNode):
13
+ """
14
+ Object used to render 3D keypoints. Faster for use during training.
15
+ Args:
16
+ cfg (CfgNode): Model config file.
17
+ """
18
+ self.cfg = cfg
19
+
20
+ def __call__(self,
21
+ pred_keypoints_3d: torch.Tensor,
22
+ gt_keypoints_3d: torch.Tensor,
23
+ gt_keypoints_2d: torch.Tensor,
24
+ images: Optional[np.array] = None,
25
+ camera_translation: Optional[torch.Tensor] = None) -> np.array:
26
+ """
27
+ Render batch of 3D keypoints.
28
+ Args:
29
+ pred_keypoints_3d (torch.Tensor): Tensor of shape (B, S, N, 3) containing a batch of predicted 3D keypoints, with S samples per image.
30
+ gt_keypoints_3d (torch.Tensor): Tensor of shape (B, N, 4) containing corresponding ground truth 3D keypoints; last value is the confidence.
31
+ gt_keypoints_2d (torch.Tensor): Tensor of shape (B, N, 3) containing corresponding ground truth 2D keypoints.
32
+ images (torch.Tensor): Tensor of shape (B, H, W, 3) containing images with values in the [0,255] range.
33
+ camera_translation (torch.Tensor): Tensor of shape (B, 3) containing the camera translation.
34
+ Returns:
35
+ np.array : Image with the following layout. Each row contains: a) the input image,
36
+ b) image with gt 2D keypoints,
37
+ c) image with projected gt 3D keypoints,
38
+ d_1, ... , d_S) image with projected predicted 3D keypoints,
39
+ e) gt 3D keypoints rendered from a side view,
40
+ f_1, ... , f_S) predicted 3D keypoints from a side view
41
+ """
42
+ batch_size = pred_keypoints_3d.shape[0]
43
+ # num_samples = pred_keypoints_3d.shape[1]
44
+ pred_keypoints_3d = pred_keypoints_3d.clone().cpu().float()
45
+ gt_keypoints_3d = gt_keypoints_3d.clone().cpu().float()
46
+ gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, [25+14], :-1] + pred_keypoints_3d[:, [25+14]]
47
+ gt_keypoints_2d = gt_keypoints_2d.clone().cpu().float().numpy()
48
+ gt_keypoints_2d[:, :, :-1] = self.cfg.MODEL.IMAGE_SIZE * (gt_keypoints_2d[:, :, :-1] + 1.0) / 2.0
49
+
50
+ openpose_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
51
+ gt_indices = [12, 8, 7, 6, 9, 10, 11, 14, 2, 1, 0, 3, 4, 5]
52
+ gt_indices = [25 + i for i in gt_indices]
53
+ keypoints_to_render = torch.ones(batch_size, gt_keypoints_3d.shape[1], 1)
54
+ rotation = torch.eye(3).unsqueeze(0)
55
+ if camera_translation is None:
56
+ camera_translation = torch.tensor([0.0, 0.0, 2 * self.cfg.EXTRA.FOCAL_LENGTH / (0.8 * self.cfg.MODEL.IMAGE_SIZE)]).unsqueeze(0).repeat(batch_size, 1)
57
+ else:
58
+ camera_translation = camera_translation.cpu()
59
+
60
+ if images is None:
61
+ images = np.zeros((batch_size, self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE, 3))
62
+ focal_length = torch.tensor([self.cfg.EXTRA.FOCAL_LENGTH, self.cfg.EXTRA.FOCAL_LENGTH]).reshape(1, 2)
63
+ camera_center = torch.tensor([self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE], dtype=torch.float).reshape(1, 2) / 2.
64
+ gt_keypoints_3d_proj = perspective_projection(gt_keypoints_3d[:, :, :-1], rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation[:, :], focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1))
65
+ pred_keypoints_3d_proj = perspective_projection(pred_keypoints_3d.reshape(batch_size, -1, 3), rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation.reshape(batch_size, -1), focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1)).reshape(batch_size, -1, 2)
66
+ gt_keypoints_3d_proj = torch.cat([gt_keypoints_3d_proj, gt_keypoints_3d[:, :, [-1]]], dim=-1).cpu().numpy()
67
+ pred_keypoints_3d_proj = torch.cat([pred_keypoints_3d_proj, keypoints_to_render.reshape(batch_size, -1, 1)], dim=-1).cpu().numpy()
68
+ rows = []
69
+ # Rotate keypoints to visualize side view
70
+ R = torch.tensor(trimesh.transformations.rotation_matrix(np.radians(90), [0, 1, 0])[:3, :3]).float()
71
+ gt_keypoints_3d_side = gt_keypoints_3d.clone()
72
+ gt_keypoints_3d_side[:, :, :-1] = torch.einsum('bni,ij->bnj', gt_keypoints_3d_side[:, :, :-1], R)
73
+ pred_keypoints_3d_side = pred_keypoints_3d.clone()
74
+ pred_keypoints_3d_side = torch.einsum('bni,ij->bnj', pred_keypoints_3d_side, R)
75
+ gt_keypoints_3d_proj_side = perspective_projection(gt_keypoints_3d_side[:, :, :-1], rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation[:, :], focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1))
76
+ pred_keypoints_3d_proj_side = perspective_projection(pred_keypoints_3d_side.reshape(batch_size, -1, 3), rotation=rotation.repeat(batch_size, 1, 1), translation=camera_translation.reshape(batch_size, -1), focal_length=focal_length.repeat(batch_size, 1), camera_center=camera_center.repeat(batch_size, 1)).reshape(batch_size, -1, 2)
77
+ gt_keypoints_3d_proj_side = torch.cat([gt_keypoints_3d_proj_side, gt_keypoints_3d_side[:, :, [-1]]], dim=-1).cpu().numpy()
78
+ pred_keypoints_3d_proj_side = torch.cat([pred_keypoints_3d_proj_side, keypoints_to_render.reshape(batch_size, -1, 1)], dim=-1).cpu().numpy()
79
+ for i in range(batch_size):
80
+ img = images[i]
81
+ side_img = np.zeros((self.cfg.MODEL.IMAGE_SIZE, self.cfg.MODEL.IMAGE_SIZE, 3))
82
+ # gt 2D keypoints
83
+ body_keypoints_2d = gt_keypoints_2d[i, :25].copy()
84
+ for op, gt in zip(openpose_indices, gt_indices):
85
+ if gt_keypoints_2d[i, gt, -1] > body_keypoints_2d[op, -1]:
86
+ body_keypoints_2d[op] = gt_keypoints_2d[i, gt]
87
+ gt_keypoints_img = render_openpose(img, body_keypoints_2d) / 255.
88
+ # gt 3D keypoints
89
+ body_keypoints_3d_proj = gt_keypoints_3d_proj[i, :25].copy()
90
+ for op, gt in zip(openpose_indices, gt_indices):
91
+ if gt_keypoints_3d_proj[i, gt, -1] > body_keypoints_3d_proj[op, -1]:
92
+ body_keypoints_3d_proj[op] = gt_keypoints_3d_proj[i, gt]
93
+ gt_keypoints_3d_proj_img = render_openpose(img, body_keypoints_3d_proj) / 255.
94
+ # gt 3D keypoints from the side
95
+ body_keypoints_3d_proj = gt_keypoints_3d_proj_side[i, :25].copy()
96
+ for op, gt in zip(openpose_indices, gt_indices):
97
+ if gt_keypoints_3d_proj_side[i, gt, -1] > body_keypoints_3d_proj[op, -1]:
98
+ body_keypoints_3d_proj[op] = gt_keypoints_3d_proj_side[i, gt]
99
+ gt_keypoints_3d_proj_img_side = render_openpose(side_img, body_keypoints_3d_proj) / 255.
100
+ # pred 3D keypoints
101
+ pred_keypoints_3d_proj_imgs = []
102
+ body_keypoints_3d_proj = pred_keypoints_3d_proj[i, :25].copy()
103
+ for op, gt in zip(openpose_indices, gt_indices):
104
+ if pred_keypoints_3d_proj[i, gt, -1] >= body_keypoints_3d_proj[op, -1]:
105
+ body_keypoints_3d_proj[op] = pred_keypoints_3d_proj[i, gt]
106
+ pred_keypoints_3d_proj_imgs.append(render_openpose(img, body_keypoints_3d_proj) / 255.)
107
+ pred_keypoints_3d_proj_img = np.concatenate(pred_keypoints_3d_proj_imgs, axis=1)
108
+ # gt 3D keypoints from the side
109
+ pred_keypoints_3d_proj_imgs_side = []
110
+ body_keypoints_3d_proj = pred_keypoints_3d_proj_side[i, :25].copy()
111
+ for op, gt in zip(openpose_indices, gt_indices):
112
+ if pred_keypoints_3d_proj_side[i, gt, -1] >= body_keypoints_3d_proj[op, -1]:
113
+ body_keypoints_3d_proj[op] = pred_keypoints_3d_proj_side[i, gt]
114
+ pred_keypoints_3d_proj_imgs_side.append(render_openpose(side_img, body_keypoints_3d_proj) / 255.)
115
+ pred_keypoints_3d_proj_img_side = np.concatenate(pred_keypoints_3d_proj_imgs_side, axis=1)
116
+ rows.append(np.concatenate((gt_keypoints_img, gt_keypoints_3d_proj_img, pred_keypoints_3d_proj_img, gt_keypoints_3d_proj_img_side, pred_keypoints_3d_proj_img_side), axis=1))
117
+ # Concatenate images
118
+ img = np.concatenate(rows, axis=0)
119
+ img[:, ::self.cfg.MODEL.IMAGE_SIZE, :] = 1.0
120
+ img[::self.cfg.MODEL.IMAGE_SIZE, :, :] = 1.0
121
+ img[:, (1+1+1)*self.cfg.MODEL.IMAGE_SIZE, :] = 0.5
122
+ return img
hmr2/utils/texture_utils.py ADDED
@@ -0,0 +1,85 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.nn import functional as F
4
+ # from psbody.mesh.visibility import visibility_compute
5
+
6
+ def uv_to_xyz_and_normals(verts, f, fmap, bmap, ftov):
7
+ vn = estimate_vertex_normals(verts, f, ftov)
8
+ pixels_to_set = torch.nonzero(fmap+1)
9
+ x_to_set = pixels_to_set[:,0]
10
+ y_to_set = pixels_to_set[:,1]
11
+ b_coords = bmap[x_to_set, y_to_set, :]
12
+ f_coords = fmap[x_to_set, y_to_set]
13
+ v_ids = f[f_coords]
14
+ points = (b_coords[:,0,None]*verts[:,v_ids[:,0]]
15
+ + b_coords[:,1,None]*verts[:,v_ids[:,1]]
16
+ + b_coords[:,2,None]*verts[:,v_ids[:,2]])
17
+ normals = (b_coords[:,0,None]*vn[:,v_ids[:,0]]
18
+ + b_coords[:,1,None]*vn[:,v_ids[:,1]]
19
+ + b_coords[:,2,None]*vn[:,v_ids[:,2]])
20
+ return points, normals, vn, f_coords
21
+
22
+ def estimate_vertex_normals(v, f, ftov):
23
+ face_normals = TriNormalsScaled(v, f)
24
+ non_scaled_normals = torch.einsum('ij,bjk->bik', ftov, face_normals)
25
+ norms = torch.sum(non_scaled_normals ** 2.0, 2) ** 0.5
26
+ norms[norms == 0] = 1.0
27
+ return torch.div(non_scaled_normals, norms[:,:,None])
28
+
29
+ def TriNormalsScaled(v, f):
30
+ return torch.cross(_edges_for(v, f, 1, 0), _edges_for(v, f, 2, 0))
31
+
32
+ def _edges_for(v, f, cplus, cminus):
33
+ return v[:,f[:,cplus]] - v[:,f[:,cminus]]
34
+
35
+ def psbody_get_face_visibility(v, n, f, cams, normal_threshold=0.5):
+ from psbody.mesh.visibility import visibility_compute # local import; psbody-mesh is an optional dependency used only here
36
+ bn, nverts, _ = v.shape
37
+ nfaces, _ = f.shape
38
+ vis_f = np.zeros([bn, nfaces], dtype='float32')
39
+ for i in range(bn):
40
+ vis, n_dot_cam = visibility_compute(v=v[i], n=n[i], f=f, cams=cams)
41
+ vis_v = (vis == 1) & (n_dot_cam > normal_threshold)
42
+ vis_f[i] = np.all(vis_v[0,f],1)
43
+ return vis_f
44
+
45
+ def compute_uvsampler(vt, ft, tex_size=6):
46
+ """
47
+ For this mesh, pre-computes the UV coordinates for
48
+ F x T x T points.
49
+ Returns F x T x T x 2
50
+ """
51
+ uv = obj2nmr_uvmap(ft, vt, tex_size=tex_size)
52
+ uv = uv.reshape(-1, tex_size, tex_size, 2)
53
+ return uv
54
+
55
+ def obj2nmr_uvmap(ft, vt, tex_size=6):
56
+ """
57
+ Converts obj uv_map to NMR uv_map (F x T x T x 2),
58
+ where tex_size (T) is the sample rate on each face.
59
+ """
60
+ # This is F x 3 x 2
61
+ uv_map_for_verts = vt[ft]
62
+
63
+ # obj's y coordinate is [1-0], but image is [0-1]
64
+ uv_map_for_verts[:, :, 1] = 1 - uv_map_for_verts[:, :, 1]
65
+
66
+ # range [0, 1] -> [-1, 1]
67
+ uv_map_for_verts = (2 * uv_map_for_verts) - 1
68
+
69
+ alpha = np.arange(tex_size, dtype=float) / (tex_size - 1)
70
+ beta = np.arange(tex_size, dtype=float) / (tex_size - 1)
71
+ import itertools
72
+ # Barycentric coordinate values
73
+ coords = np.stack([p for p in itertools.product(*[alpha, beta])])
74
+
75
+ # Compute alpha, beta (this is the same order as NMR)
76
+ v2 = uv_map_for_verts[:, 2]
77
+ v0v2 = uv_map_for_verts[:, 0] - uv_map_for_verts[:, 2]
78
+ v1v2 = uv_map_for_verts[:, 1] - uv_map_for_verts[:, 2]
79
+ # Interpolate the vertex uv values: F x 2 x T*2
80
+ uv_map = np.dstack([v0v2, v1v2]).dot(coords.T) + v2.reshape(-1, 2, 1)
81
+
82
+ # F x T*2 x 2 -> F x T x T x 2
83
+ uv_map = np.transpose(uv_map, (0, 2, 1)).reshape(-1, tex_size, tex_size, 2)
84
+
85
+ return uv_map
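A small sketch of the UV sampler above, with random UV coordinates and faces used purely for illustration:

import numpy as np
from hmr2.utils.texture_utils import compute_uvsampler

vt = np.random.rand(32, 2)                    # (num_uv_vertices, 2) UV coordinates in [0, 1]
ft = np.random.randint(0, 32, size=(10, 3))   # (F, 3) per-face indices into vt
uv = compute_uvsampler(vt, ft, tex_size=6)    # per-face grid of UV samples in [-1, 1]
print(uv.shape)                               # (10, 6, 6, 2)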
hmr2/utils/utils_detectron2.py ADDED
@@ -0,0 +1,93 @@
1
+ import detectron2.data.transforms as T
2
+ import torch
3
+ from detectron2.checkpoint import DetectionCheckpointer
4
+ from detectron2.config import CfgNode, instantiate
5
+ from detectron2.data import MetadataCatalog
+ from detectron2.modeling import build_model
6
+ from omegaconf import OmegaConf
7
+
8
+
9
+ class DefaultPredictor_Lazy:
10
+ """Create a simple end-to-end predictor with the given config that runs on single device for a
11
+ single input image.
12
+
13
+ Compared to using the model directly, this class does the following additions:
14
+
15
+ 1. Load checkpoint from the weights specified in config (cfg.MODEL.WEIGHTS).
16
+ 2. Always take BGR image as the input and apply format conversion internally.
17
+ 3. Apply resizing defined by the config (`cfg.INPUT.{MIN,MAX}_SIZE_TEST`).
18
+ 4. Take one input image and produce a single output, instead of a batch.
19
+
20
+ This is meant for simple demo purposes, so it does the above steps automatically.
21
+ This is not meant for benchmarks or running complicated inference logic.
22
+ If you'd like to do anything more complicated, please refer to its source code as
23
+ examples to build and use the model manually.
24
+
25
+ Attributes:
26
+ metadata (Metadata): the metadata of the underlying dataset, obtained from
27
+ test dataset name in the config.
28
+
29
+
30
+ Examples:
31
+ ::
32
+ pred = DefaultPredictor_Lazy(cfg)
33
+ inputs = cv2.imread("input.jpg")
34
+ outputs = pred(inputs)
35
+ """
36
+
37
+ def __init__(self, cfg):
38
+ """
39
+ Args:
40
+ cfg: a yacs CfgNode or a omegaconf dict object.
41
+ """
42
+ if isinstance(cfg, CfgNode):
43
+ self.cfg = cfg.clone() # cfg can be modified by model
44
+ self.model = build_model(self.cfg)
45
+ if len(cfg.DATASETS.TEST):
46
+ test_dataset = cfg.DATASETS.TEST[0]
47
+
48
+ checkpointer = DetectionCheckpointer(self.model)
49
+ checkpointer.load(cfg.MODEL.WEIGHTS)
50
+
51
+ self.aug = T.ResizeShortestEdge(
52
+ [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
53
+ )
54
+
55
+ self.input_format = cfg.INPUT.FORMAT
56
+ else: # new LazyConfig
57
+ self.cfg = cfg
58
+ self.model = instantiate(cfg.model)
59
+ test_dataset = OmegaConf.select(cfg, "dataloader.test.dataset.names", default=None)
60
+ if isinstance(test_dataset, (list, tuple)):
61
+ test_dataset = test_dataset[0]
62
+
63
+ checkpointer = DetectionCheckpointer(self.model)
64
+ checkpointer.load(OmegaConf.select(cfg, "train.init_checkpoint", default=""))
65
+
66
+ mapper = instantiate(cfg.dataloader.test.mapper)
67
+ self.aug = mapper.augmentations
68
+ self.input_format = mapper.image_format
69
+
70
+ self.model.eval().cuda()
71
+ if test_dataset:
72
+ self.metadata = MetadataCatalog.get(test_dataset)
73
+ assert self.input_format in ["RGB", "BGR"], self.input_format
74
+
75
+ def __call__(self, original_image):
76
+ """
77
+ Args:
78
+ original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
79
+
80
+ Returns:
81
+ predictions (dict):
82
+ the output of the model for one image only.
83
+ See :doc:`/tutorials/models` for details about the format.
84
+ """
85
+ with torch.no_grad():
86
+ if self.input_format == "RGB":
87
+ original_image = original_image[:, :, ::-1]
88
+ height, width = original_image.shape[:2]
89
+ image = self.aug(T.AugInput(original_image)).apply_image(original_image)
90
+ image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
91
+ inputs = {"image": image, "height": height, "width": width}
92
+ predictions = self.model([inputs])[0]
93
+ return predictions
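A hedged example of driving the predictor above with a detectron2 LazyConfig; the config path, checkpoint path, and image file are placeholders (the demo scripts set these up differently):

import cv2
from detectron2.config import LazyConfig
from hmr2.utils.utils_detectron2 import DefaultPredictor_Lazy

cfg = LazyConfig.load("path/to/lazy_config.py")    # placeholder LazyConfig file
cfg.train.init_checkpoint = "path/to/model.pkl"    # placeholder checkpoint read by the predictor
predictor = DefaultPredictor_Lazy(cfg)             # note: the predictor moves the model to CUDA

img = cv2.imread("example.jpg")                    # BGR image, as the predictor expects
outputs = predictor(img)
print(outputs["instances"].pred_boxes)             # detections for the single input image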
requirements.txt ADDED
@@ -0,0 +1,29 @@
1
+ # torch
2
+ # pytorch-lightning
3
+ # smplx==0.1.28
4
+ # pyrender
5
+ # opencv-python
6
+ # yacs
7
+ # scikit-image
8
+ # einops
9
+ # timm
10
+ # OmegaConf
11
+
12
+ --extra-index-url https://download.pytorch.org/whl/cu116
13
+ torch==1.13.1+cu116
14
+ torchvision==0.14.1+cu116
15
+ pytorch-lightning
16
+ smplx==0.1.28
17
+ opencv-python
18
+ yacs
19
+ scikit-image
20
+ einops
21
+ timm
22
+ OmegaConf
23
+ trimesh
24
+ pyopengl==3.1.0
25
+ pyglet
26
+ PyOpenGL
27
+ PyOpenGL_accelerate
28
+ numpy==1.23.3
29
+ shapely
setup.py ADDED
@@ -0,0 +1,8 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ print('Found packages:', find_packages())
4
+ setup(
5
+ description='HMR2 as a package',
6
+ name='hmr2',
7
+ packages=find_packages()
8
+ )
vendor/detectron2/.circleci/config.yml ADDED
@@ -0,0 +1,271 @@
1
+ version: 2.1
2
+
3
+ # -------------------------------------------------------------------------------------
4
+ # Environments to run the jobs in
5
+ # -------------------------------------------------------------------------------------
6
+ cpu: &cpu
7
+ machine:
8
+ image: ubuntu-2004:202107-02
9
+ resource_class: medium
10
+
11
+ gpu: &gpu
12
+ machine:
13
+ # NOTE: use a cuda version that's supported by all our pytorch versions
14
+ image: ubuntu-1604-cuda-11.1:202012-01
15
+ resource_class: gpu.nvidia.small
16
+
17
+ windows-cpu: &windows_cpu
18
+ machine:
19
+ resource_class: windows.medium
20
+ image: windows-server-2019-vs2019:stable
21
+ shell: powershell.exe
22
+
23
+ # windows-gpu: &windows_gpu
24
+ # machine:
25
+ # resource_class: windows.gpu.nvidia.medium
26
+ # image: windows-server-2019-nvidia:stable
27
+
28
+ version_parameters: &version_parameters
29
+ parameters:
30
+ pytorch_version:
31
+ type: string
32
+ torchvision_version:
33
+ type: string
34
+ pytorch_index:
35
+ type: string
36
+ # use test wheels index to have access to RC wheels
37
+ # https://download.pytorch.org/whl/test/torch_test.html
38
+ default: "https://download.pytorch.org/whl/torch_stable.html"
39
+ python_version: # NOTE: only affect linux
40
+ type: string
41
+ default: '3.8.6'
42
+
43
+ environment:
44
+ PYTORCH_VERSION: << parameters.pytorch_version >>
45
+ TORCHVISION_VERSION: << parameters.torchvision_version >>
46
+ PYTORCH_INDEX: << parameters.pytorch_index >>
47
+ PYTHON_VERSION: << parameters.python_version>>
48
+ # point datasets to ~/.torch so it's cached in CI
49
+ DETECTRON2_DATASETS: ~/.torch/datasets
50
+
51
+ # -------------------------------------------------------------------------------------
52
+ # Re-usable commands
53
+ # -------------------------------------------------------------------------------------
54
+ # install_nvidia_driver: &install_nvidia_driver
55
+ # - run:
56
+ # name: Install nvidia driver
57
+ # working_directory: ~/
58
+ # command: |
59
+ # wget -q 'https://s3.amazonaws.com/ossci-linux/nvidia_driver/NVIDIA-Linux-x86_64-430.40.run'
60
+ # sudo /bin/bash ./NVIDIA-Linux-x86_64-430.40.run -s --no-drm
61
+ # nvidia-smi
62
+
63
+ add_ssh_keys: &add_ssh_keys
64
+ # https://circleci.com/docs/2.0/add-ssh-key/
65
+ - add_ssh_keys:
66
+ fingerprints:
67
+ - "e4:13:f2:22:d4:49:e8:e4:57:5a:ac:20:2f:3f:1f:ca"
68
+
69
+ install_python: &install_python
70
+ - run:
71
+ name: Install Python
72
+ working_directory: ~/
73
+ command: |
74
+ # upgrade pyenv
75
+ cd /opt/circleci/.pyenv/plugins/python-build/../.. && git pull && cd -
76
+ pyenv install -s $PYTHON_VERSION
77
+ pyenv global $PYTHON_VERSION
78
+ python --version
79
+ which python
80
+ pip install --upgrade pip
81
+
82
+ setup_venv: &setup_venv
83
+ - run:
84
+ name: Setup Virtual Env
85
+ working_directory: ~/
86
+ command: |
87
+ python -m venv ~/venv
88
+ echo ". ~/venv/bin/activate" >> $BASH_ENV
89
+ . ~/venv/bin/activate
90
+ python --version
91
+ which python
92
+ which pip
93
+ pip install --upgrade pip
94
+
95
+ setup_venv_win: &setup_venv_win
96
+ - run:
97
+ name: Setup Virtual Env for Windows
98
+ command: |
99
+ pip install virtualenv
100
+ python -m virtualenv env
101
+ .\env\Scripts\activate
102
+ python --version
103
+ which python
104
+ which pip
105
+
106
+ install_linux_dep: &install_linux_dep
107
+ - run:
108
+ name: Install Dependencies
109
+ command: |
110
+ # disable crash coredump, so unittests fail fast
111
+ sudo systemctl stop apport.service
112
+ # install from github to get latest; install iopath first since fvcore depends on it
113
+ pip install --progress-bar off -U 'git+https://github.com/facebookresearch/iopath'
114
+ pip install --progress-bar off -U 'git+https://github.com/facebookresearch/fvcore'
115
+ # Don't use pytest-xdist: cuda tests are unstable under multi-process workers.
116
+ # Don't use opencv 4.7.0.68: https://github.com/opencv/opencv-python/issues/765
117
+ pip install --progress-bar off ninja opencv-python-headless!=4.7.0.68 pytest tensorboard pycocotools onnx
118
+ pip install --progress-bar off torch==$PYTORCH_VERSION -f $PYTORCH_INDEX
119
+ if [[ "$TORCHVISION_VERSION" == "master" ]]; then
120
+ pip install git+https://github.com/pytorch/vision.git
121
+ else
122
+ pip install --progress-bar off torchvision==$TORCHVISION_VERSION -f $PYTORCH_INDEX
123
+ fi
124
+
125
+ python -c 'import torch; print("CUDA:", torch.cuda.is_available())'
126
+ gcc --version
127
+
128
+ install_detectron2: &install_detectron2
129
+ - run:
130
+ name: Install Detectron2
131
+ command: |
132
+ # Remove first, in case it's in the CI cache
133
+ pip uninstall -y detectron2
134
+
135
+ pip install --progress-bar off -e .[all]
136
+ python -m detectron2.utils.collect_env
137
+ ./datasets/prepare_for_tests.sh
138
+
139
+ run_unittests: &run_unittests
140
+ - run:
141
+ name: Run Unit Tests
142
+ command: |
143
+ pytest -sv --durations=15 tests # parallel causes some random failures
144
+
145
+ uninstall_tests: &uninstall_tests
146
+ - run:
147
+ name: Run Tests After Uninstalling
148
+ command: |
149
+ pip uninstall -y detectron2
150
+ # Remove built binaries
151
+ rm -rf build/ detectron2/*.so
152
+ # Tests that code is importable without installation
153
+ PYTHONPATH=. ./.circleci/import-tests.sh
154
+
155
+
156
+ # -------------------------------------------------------------------------------------
157
+ # Jobs to run
158
+ # -------------------------------------------------------------------------------------
159
+ jobs:
160
+ linux_cpu_tests:
161
+ <<: *cpu
162
+ <<: *version_parameters
163
+
164
+ working_directory: ~/detectron2
165
+
166
+ steps:
167
+ - checkout
168
+
169
+ # Cache the venv directory that contains python, dependencies, and checkpoints
170
+ # Refresh the key when dependencies should be updated (e.g. when pytorch releases)
171
+ - restore_cache:
172
+ keys:
173
+ - cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210827
174
+
175
+ - <<: *install_python
176
+ - <<: *install_linux_dep
177
+ - <<: *install_detectron2
178
+ - <<: *run_unittests
179
+ - <<: *uninstall_tests
180
+
181
+ - save_cache:
182
+ paths:
183
+ - /opt/circleci/.pyenv
184
+ - ~/.torch
185
+ key: cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210827
186
+
187
+
188
+ linux_gpu_tests:
189
+ <<: *gpu
190
+ <<: *version_parameters
191
+
192
+ working_directory: ~/detectron2
193
+
194
+ steps:
195
+ - checkout
196
+
197
+ - restore_cache:
198
+ keys:
199
+ - cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210827
200
+
201
+ - <<: *install_python
202
+ - <<: *install_linux_dep
203
+ - <<: *install_detectron2
204
+ - <<: *run_unittests
205
+ - <<: *uninstall_tests
206
+
207
+ - save_cache:
208
+ paths:
209
+ - /opt/circleci/.pyenv
210
+ - ~/.torch
211
+ key: cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210827
212
+
213
+ windows_cpu_build:
214
+ <<: *windows_cpu
215
+ <<: *version_parameters
216
+ steps:
217
+ - <<: *add_ssh_keys
218
+ - checkout
219
+ - <<: *setup_venv_win
220
+
221
+ # Cache the env directory that contains dependencies
222
+ - restore_cache:
223
+ keys:
224
+ - cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210404
225
+
226
+ - run:
227
+ name: Install Dependencies
228
+ command: |
229
+ pip install certifi --ignore-installed # required on windows to workaround some cert issue
230
+ pip install numpy cython # required on windows before pycocotools
231
+ pip install opencv-python-headless pytest-xdist pycocotools tensorboard onnx
232
+ pip install -U git+https://github.com/facebookresearch/iopath
233
+ pip install -U git+https://github.com/facebookresearch/fvcore
234
+ pip install torch==$env:PYTORCH_VERSION torchvision==$env:TORCHVISION_VERSION -f $env:PYTORCH_INDEX
235
+
236
+ - save_cache:
237
+ paths:
238
+ - env
239
+ key: cache-{{ arch }}-<< parameters.pytorch_version >>-{{ .Branch }}-20210404
240
+
241
+ - <<: *install_detectron2
242
+ # TODO: unittest fails for now
243
+
244
+ workflows:
245
+ version: 2
246
+ regular_test:
247
+ jobs:
248
+ - linux_cpu_tests:
249
+ name: linux_cpu_tests_pytorch1.10
250
+ pytorch_version: '1.10.0+cpu'
251
+ torchvision_version: '0.11.1+cpu'
252
+ - linux_gpu_tests:
253
+ name: linux_gpu_tests_pytorch1.8
254
+ pytorch_version: '1.8.1+cu111'
255
+ torchvision_version: '0.9.1+cu111'
256
+ - linux_gpu_tests:
257
+ name: linux_gpu_tests_pytorch1.9
258
+ pytorch_version: '1.9+cu111'
259
+ torchvision_version: '0.10+cu111'
260
+ - linux_gpu_tests:
261
+ name: linux_gpu_tests_pytorch1.10
262
+ pytorch_version: '1.10+cu111'
263
+ torchvision_version: '0.11.1+cu111'
264
+ - linux_gpu_tests:
265
+ name: linux_gpu_tests_pytorch1.10_python39
266
+ pytorch_version: '1.10+cu111'
267
+ torchvision_version: '0.11.1+cu111'
268
+ python_version: '3.9.6'
269
+ - windows_cpu_build:
270
+ pytorch_version: '1.10+cpu'
271
+ torchvision_version: '0.11.1+cpu'
vendor/detectron2/.circleci/import-tests.sh ADDED
@@ -0,0 +1,16 @@
1
+ #!/bin/bash -e
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+
4
+ # Test that import works without building detectron2.
5
+
6
+ # Check that _C is not importable
7
+ python -c "from detectron2 import _C" > /dev/null 2>&1 && {
8
+ echo "This test should be run without building detectron2."
9
+ exit 1
10
+ }
11
+
12
+ # Check that other modules are still importable, even when _C is not importable
13
+ python -c "from detectron2 import modeling"
14
+ python -c "from detectron2 import modeling, data"
15
+ python -c "from detectron2 import evaluation, export, checkpoint"
16
+ python -c "from detectron2 import utils, engine"
vendor/detectron2/.clang-format ADDED
@@ -0,0 +1,85 @@
1
+ AccessModifierOffset: -1
2
+ AlignAfterOpenBracket: AlwaysBreak
3
+ AlignConsecutiveAssignments: false
4
+ AlignConsecutiveDeclarations: false
5
+ AlignEscapedNewlinesLeft: true
6
+ AlignOperands: false
7
+ AlignTrailingComments: false
8
+ AllowAllParametersOfDeclarationOnNextLine: false
9
+ AllowShortBlocksOnASingleLine: false
10
+ AllowShortCaseLabelsOnASingleLine: false
11
+ AllowShortFunctionsOnASingleLine: Empty
12
+ AllowShortIfStatementsOnASingleLine: false
13
+ AllowShortLoopsOnASingleLine: false
14
+ AlwaysBreakAfterReturnType: None
15
+ AlwaysBreakBeforeMultilineStrings: true
16
+ AlwaysBreakTemplateDeclarations: true
17
+ BinPackArguments: false
18
+ BinPackParameters: false
19
+ BraceWrapping:
20
+ AfterClass: false
21
+ AfterControlStatement: false
22
+ AfterEnum: false
23
+ AfterFunction: false
24
+ AfterNamespace: false
25
+ AfterObjCDeclaration: false
26
+ AfterStruct: false
27
+ AfterUnion: false
28
+ BeforeCatch: false
29
+ BeforeElse: false
30
+ IndentBraces: false
31
+ BreakBeforeBinaryOperators: None
32
+ BreakBeforeBraces: Attach
33
+ BreakBeforeTernaryOperators: true
34
+ BreakConstructorInitializersBeforeComma: false
35
+ BreakAfterJavaFieldAnnotations: false
36
+ BreakStringLiterals: false
37
+ ColumnLimit: 80
38
+ CommentPragmas: '^ IWYU pragma:'
39
+ ConstructorInitializerAllOnOneLineOrOnePerLine: true
40
+ ConstructorInitializerIndentWidth: 4
41
+ ContinuationIndentWidth: 4
42
+ Cpp11BracedListStyle: true
43
+ DerivePointerAlignment: false
44
+ DisableFormat: false
45
+ ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
46
+ IncludeCategories:
47
+ - Regex: '^<.*\.h(pp)?>'
48
+ Priority: 1
49
+ - Regex: '^<.*'
50
+ Priority: 2
51
+ - Regex: '.*'
52
+ Priority: 3
53
+ IndentCaseLabels: true
54
+ IndentWidth: 2
55
+ IndentWrappedFunctionNames: false
56
+ KeepEmptyLinesAtTheStartOfBlocks: false
57
+ MacroBlockBegin: ''
58
+ MacroBlockEnd: ''
59
+ MaxEmptyLinesToKeep: 1
60
+ NamespaceIndentation: None
61
+ ObjCBlockIndentWidth: 2
62
+ ObjCSpaceAfterProperty: false
63
+ ObjCSpaceBeforeProtocolList: false
64
+ PenaltyBreakBeforeFirstCallParameter: 1
65
+ PenaltyBreakComment: 300
66
+ PenaltyBreakFirstLessLess: 120
67
+ PenaltyBreakString: 1000
68
+ PenaltyExcessCharacter: 1000000
69
+ PenaltyReturnTypeOnItsOwnLine: 200
70
+ PointerAlignment: Left
71
+ ReflowComments: true
72
+ SortIncludes: true
73
+ SpaceAfterCStyleCast: false
74
+ SpaceBeforeAssignmentOperators: true
75
+ SpaceBeforeParens: ControlStatements
76
+ SpaceInEmptyParentheses: false
77
+ SpacesBeforeTrailingComments: 1
78
+ SpacesInAngles: false
79
+ SpacesInContainerLiterals: true
80
+ SpacesInCStyleCastParentheses: false
81
+ SpacesInParentheses: false
82
+ SpacesInSquareBrackets: false
83
+ Standard: Cpp11
84
+ TabWidth: 8
85
+ UseTab: Never
vendor/detectron2/.flake8 ADDED
@@ -0,0 +1,15 @@
1
+ # This is an example .flake8 config, used when developing *Black* itself.
2
+ # Keep in sync with setup.cfg which is used for source packages.
3
+
4
+ [flake8]
5
+ ignore = W503, E203, E221, C901, C408, E741, C407, B017, F811, C101, EXE001, EXE002
6
+ max-line-length = 100
7
+ max-complexity = 18
8
+ select = B,C,E,F,W,T4,B9
9
+ exclude = build
10
+ per-file-ignores =
11
+ **/__init__.py:F401,F403,E402
12
+ **/configs/**.py:F401,E402
13
+ configs/**.py:F401,E402
14
+ **/tests/config/**.py:F401,E402
15
+ tests/config/**.py:F401,E402
vendor/detectron2/.github/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,5 @@
1
+ # Code of Conduct
2
+
3
+ Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4
+ Please read the [full text](https://code.fb.com/codeofconduct/)
5
+ so that you can understand what actions will and will not be tolerated.
vendor/detectron2/.github/CONTRIBUTING.md ADDED
@@ -0,0 +1,68 @@
1
+ # Contributing to detectron2
2
+
3
+ ## Issues
4
+ We use GitHub issues to track public bugs and questions.
5
+ Please make sure to follow one of the
6
+ [issue templates](https://github.com/facebookresearch/detectron2/issues/new/choose)
7
+ when reporting any issues.
8
+
9
+ Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
10
+ disclosure of security bugs. In those cases, please go through the process
11
+ outlined on that page and do not file a public issue.
12
+
13
+ ## Pull Requests
14
+ We actively welcome pull requests.
15
+
16
+ However, if you're adding any significant features (e.g. > 50 lines), please
17
+ make sure to discuss with maintainers about your motivation and proposals in an issue
18
+ before sending a PR. This is to save your time so you don't spend time on a PR that we'll not accept.
19
+
20
+ We do not always accept new features, and we take the following
21
+ factors into consideration:
22
+
23
+ 1. Whether the same feature can be achieved without modifying detectron2.
24
+ Detectron2 is designed so that you can implement many extensions from the outside, e.g.
25
+ those in [projects](https://github.com/facebookresearch/detectron2/tree/master/projects).
26
+ * If some part of detectron2 is not extensible enough, you can also bring up a more general issue to
27
+ improve it. Such feature request may be useful to more users.
28
+ 2. Whether the feature is potentially useful to a large audience (e.g. an impactful detection paper, a popular dataset,
29
+ a significant speedup, a widely useful utility),
30
+ or only to a small portion of users (e.g., a less-known paper, an improvement not in the object
31
+ detection field, a trick that's not very popular in the community, code to handle a non-standard type of data)
32
+ * Adoption of additional models, datasets, new task are by default not added to detectron2 before they
33
+ receive significant popularity in the community.
34
+ We sometimes accept such features in `projects/`, or as a link in `projects/README.md`.
35
+ 3. Whether the proposed solution has a good design / interface. This can be discussed in the issue prior to PRs, or
36
+ in the form of a draft PR.
37
+ 4. Whether the proposed solution adds extra mental/practical overhead to users who don't
38
+ need such feature.
39
+ 5. Whether the proposed solution breaks existing APIs.
40
+
41
+ To add a feature to an existing function/class `Func`, there are always two approaches:
42
+ (1) add new arguments to `Func`; (2) write a new `Func_with_new_feature`.
43
+ To meet the above criteria, we often prefer approach (2), because:
44
+
45
+ 1. It does not involve modifying or potentially breaking existing code.
46
+ 2. It does not add overhead to users who do not need the new feature.
47
+ 3. Adding new arguments to a function/class is not scalable w.r.t. all the possible new research ideas in the future.
48
+
49
+ When sending a PR, please do:
50
+
51
+ 1. If a PR contains multiple orthogonal changes, split it to several PRs.
52
+ 2. If you've added code that should be tested, add tests.
53
+ 3. For PRs that need experiments (e.g. adding a new model or new methods),
54
+ you don't need to update model zoo, but do provide experiment results in the description of the PR.
55
+ 4. If APIs are changed, update the documentation.
56
+ 5. We use the [Google style docstrings](https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html) in python.
57
+ 6. Make sure your code lints with `./dev/linter.sh`.
58
+
59
+
60
+ ## Contributor License Agreement ("CLA")
61
+ In order to accept your pull request, we need you to submit a CLA. You only need
62
+ to do this once to work on any of Facebook's open source projects.
63
+
64
+ Complete your CLA here: <https://code.facebook.com/cla>
65
+
66
+ ## License
67
+ By contributing to detectron2, you agree that your contributions will be licensed
68
+ under the LICENSE file in the root directory of this source tree.
vendor/detectron2/.github/Detectron2-Logo-Horz.svg ADDED
vendor/detectron2/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,5 @@
1
+
2
+ Please select an issue template from
3
+ https://github.com/facebookresearch/detectron2/issues/new/choose .
4
+
5
+ Otherwise your issue will be closed.
vendor/detectron2/.github/ISSUE_TEMPLATE/bugs.md ADDED
@@ -0,0 +1,38 @@
1
+ ---
2
+ name: "πŸ› Bugs"
3
+ about: Report bugs in detectron2
4
+ title: Please read & provide the following
5
+
6
+ ---
7
+
8
+ ## Instructions To Reproduce the πŸ› Bug:
9
+ 1. Full runnable code or full changes you made:
10
+ ```
11
+ If making changes to the project itself, please use output of the following command:
12
+ git rev-parse HEAD; git diff
13
+
14
+ <put code or diff here>
15
+ ```
16
+ 2. What exact command you run:
17
+ 3. __Full logs__ or other relevant observations:
18
+ ```
19
+ <put logs here>
20
+ ```
21
+ 4. please simplify the steps as much as possible so they do not require additional resources to
22
+ run, such as a private dataset.
23
+
24
+ ## Expected behavior:
25
+
26
+ If there are no obvious error in "full logs" provided above,
27
+ please tell us the expected behavior.
28
+
29
+ ## Environment:
30
+
31
+ Provide your environment information using the following command:
32
+ ```
33
+ wget -nc -q https://github.com/facebookresearch/detectron2/raw/main/detectron2/utils/collect_env.py && python collect_env.py
34
+ ```
35
+
36
+ If your issue looks like an installation issue / environment issue,
37
+ please first try to solve it yourself with the instructions in
38
+ https://detectron2.readthedocs.io/tutorials/install.html#common-installation-issues
vendor/detectron2/.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,17 @@
1
+ # require an issue template to be chosen
2
+ blank_issues_enabled: false
3
+
4
+ contact_links:
5
+ - name: How-To / All Other Questions
6
+ url: https://github.com/facebookresearch/detectron2/discussions
7
+ about: Use "github discussions" for community support on general questions that don't belong to the above issue categories
8
+ - name: Detectron2 Documentation
9
+ url: https://detectron2.readthedocs.io/index.html
10
+ about: Check if your question is answered in tutorials or API docs
11
+
12
+ # Unexpected behaviors & bugs are split to two templates.
13
+ # When they are one template, users think "it's not a bug" and don't choose the template.
14
+ #
15
+ # But the file name is still "unexpected-problems-bugs.md" so that old references
16
+ # to this issue template still works.
17
+ # It's ok since this template should be a superset of "bugs.md" (unexpected behaviors is a superset of bugs)
vendor/detectron2/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,14 @@
1
+ ---
2
+ name: "\U0001F4DA Documentation Issue"
3
+ about: Report a problem about existing documentation, comments, website or tutorials.
4
+ labels: documentation
5
+
6
+ ---
7
+
8
+ ## πŸ“š Documentation Issue
9
+
10
+ This issue category is for problems about existing documentation, not for asking how-to questions.
11
+
12
+ * Provide a link to an existing documentation/comment/tutorial:
13
+
14
+ * How should the above documentation/comment/tutorial improve:
vendor/detectron2/.github/ISSUE_TEMPLATE/feature-request.md ADDED
@@ -0,0 +1,31 @@
1
+ ---
2
+ name: "\U0001F680Feature Request"
3
+ about: Suggest an improvement or new feature
4
+ labels: enhancement
5
+
6
+ ---
7
+
8
+ ## πŸš€ Feature
9
+ A clear and concise description of the feature proposal.
10
+
11
+ ## Motivation & Examples
12
+
13
+ Tell us why the feature is useful.
14
+
15
+ Describe what the feature would look like, if it is implemented.
16
+ Best demonstrated using **code examples** in addition to words.
17
+
18
+ ## Note
19
+
20
+ We only consider adding new features if they are relevant to many users.
21
+
22
+ If you request implementation of research papers -- we only consider papers that have enough significance and prevalence in the object detection field.
23
+
24
+ We do not take requests for most projects in the `projects/` directory, because they are research code release that is mainly for other researchers to reproduce results.
25
+
26
+ "Make X faster/accurate" is not a valid feature request. "Implement a concrete feature that can make X faster/accurate" can be a valid feature request.
27
+
28
+ Instead of adding features inside detectron2,
29
+ you can implement many features by [extending detectron2](https://detectron2.readthedocs.io/tutorials/extend.html).
30
+ The [projects/](https://github.com/facebookresearch/detectron2/tree/main/projects/) directory contains many of such examples.
31
+