SondosM commited on
Commit
c038c49
·
verified ·
1 Parent(s): 4f59c55

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. WiLoR/README.md +93 -0
  3. WiLoR/assets/teaser.png +3 -0
  4. WiLoR/demo.py +142 -0
  5. WiLoR/demo_img/test1.jpg +0 -0
  6. WiLoR/demo_img/test2.png +3 -0
  7. WiLoR/demo_img/test3.jpg +0 -0
  8. WiLoR/demo_img/test4.jpg +3 -0
  9. WiLoR/demo_img/test5.jpeg +3 -0
  10. WiLoR/demo_img/test6.jpg +3 -0
  11. WiLoR/demo_img/test7.jpg +0 -0
  12. WiLoR/demo_img/test8.jpg +3 -0
  13. WiLoR/download_videos.py +58 -0
  14. WiLoR/gradio_demo.py +192 -0
  15. WiLoR/license.txt +402 -0
  16. WiLoR/mano_data/mano_mean_params.npz +3 -0
  17. WiLoR/pretrained_models/dataset_config.yaml +58 -0
  18. WiLoR/pretrained_models/model_config.yaml +119 -0
  19. WiLoR/requirements.txt +20 -0
  20. WiLoR/whim/Dataset_instructions.md +31 -0
  21. WiLoR/whim/test_video_ids.json +1 -0
  22. WiLoR/whim/train_video_ids.json +0 -0
  23. WiLoR/wilor/configs/__init__.py +114 -0
  24. WiLoR/wilor/configs/__pycache__/__init__.cpython-311.pyc +0 -0
  25. WiLoR/wilor/datasets/utils.py +994 -0
  26. WiLoR/wilor/datasets/vitdet_dataset.py +95 -0
  27. WiLoR/wilor/models/__init__.py +36 -0
  28. WiLoR/wilor/models/__pycache__/__init__.cpython-311.pyc +0 -0
  29. WiLoR/wilor/models/__pycache__/discriminator.cpython-311.pyc +0 -0
  30. WiLoR/wilor/models/__pycache__/losses.cpython-311.pyc +0 -0
  31. WiLoR/wilor/models/__pycache__/mano_wrapper.cpython-311.pyc +0 -0
  32. WiLoR/wilor/models/__pycache__/wilor.cpython-311.pyc +0 -0
  33. WiLoR/wilor/models/backbones/__init__.py +17 -0
  34. WiLoR/wilor/models/backbones/__pycache__/__init__.cpython-310.pyc +0 -0
  35. WiLoR/wilor/models/backbones/__pycache__/__init__.cpython-311.pyc +0 -0
  36. WiLoR/wilor/models/backbones/__pycache__/vit.cpython-310.pyc +0 -0
  37. WiLoR/wilor/models/backbones/__pycache__/vit.cpython-311.pyc +0 -0
  38. WiLoR/wilor/models/backbones/vit.py +410 -0
  39. WiLoR/wilor/models/discriminator.py +98 -0
  40. WiLoR/wilor/models/heads/__init__.py +1 -0
  41. WiLoR/wilor/models/heads/__pycache__/__init__.cpython-310.pyc +0 -0
  42. WiLoR/wilor/models/heads/__pycache__/__init__.cpython-311.pyc +0 -0
  43. WiLoR/wilor/models/heads/__pycache__/refinement_net.cpython-310.pyc +0 -0
  44. WiLoR/wilor/models/heads/__pycache__/refinement_net.cpython-311.pyc +0 -0
  45. WiLoR/wilor/models/heads/refinement_net.py +204 -0
  46. WiLoR/wilor/models/losses.py +92 -0
  47. WiLoR/wilor/models/mano_wrapper.py +40 -0
  48. WiLoR/wilor/models/wilor.py +376 -0
  49. WiLoR/wilor/utils/__init__.py +25 -0
  50. WiLoR/wilor/utils/__pycache__/__init__.cpython-311.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ WiLoR/assets/teaser.png filter=lfs diff=lfs merge=lfs -text
37
+ WiLoR/demo_img/test2.png filter=lfs diff=lfs merge=lfs -text
38
+ WiLoR/demo_img/test4.jpg filter=lfs diff=lfs merge=lfs -text
39
+ WiLoR/demo_img/test5.jpeg filter=lfs diff=lfs merge=lfs -text
40
+ WiLoR/demo_img/test6.jpg filter=lfs diff=lfs merge=lfs -text
41
+ WiLoR/demo_img/test8.jpg filter=lfs diff=lfs merge=lfs -text
WiLoR/README.md ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # WiLoR: End-to-end 3D hand localization and reconstruction in-the-wild
4
+
5
+ [Rolandos Alexandros Potamias](https://rolpotamias.github.io)<sup>1</sup> &emsp; [Jinglei Zhang]()<sup>2</sup> &emsp; [Jiankang Deng](https://jiankangdeng.github.io/)<sup>1</sup> &emsp; [Stefanos Zafeiriou](https://www.imperial.ac.uk/people/s.zafeiriou)<sup>1</sup>
6
+
7
+ <sup>1</sup>Imperial College London, UK <br>
8
+ <sup>2</sup>Shanghai Jiao Tong University, China
9
+
10
+ <font color="blue"><strong>CVPR 2025</strong></font>
11
+
12
+ <a href='https://rolpotamias.github.io/WiLoR/'><img src='https://img.shields.io/badge/Project-Page-blue'></a>
13
+ <a href='https://arxiv.org/abs/2409.12259'><img src='https://img.shields.io/badge/Paper-arXiv-red'></a>
14
+ <a href='https://huggingface.co/spaces/rolpotamias/WiLoR'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Demo-green'></a>
15
+ <a href='https://colab.research.google.com/drive/1bNnYFECmJbbvCNZAKtQcxJGxf0DZppsB?usp=sharing'><img src='https://colab.research.google.com/assets/colab-badge.svg'></a>
16
+ </div>
17
+
18
+ <div align="center">
19
+
20
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/wilor-end-to-end-3d-hand-localization-and/3d-hand-pose-estimation-on-freihand)](https://paperswithcode.com/sota/3d-hand-pose-estimation-on-freihand?p=wilor-end-to-end-3d-hand-localization-and)
21
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/wilor-end-to-end-3d-hand-localization-and/3d-hand-pose-estimation-on-ho-3d)](https://paperswithcode.com/sota/3d-hand-pose-estimation-on-ho-3d?p=wilor-end-to-end-3d-hand-localization-and)
22
+
23
+ </div>
24
+
25
+ This is the official implementation of **[WiLoR](https://rolpotamias.github.io/WiLoR/)**, a state-of-the-art hand localization and reconstruction model:
26
+
27
+ ![teaser](assets/teaser.png)
28
+
29
+ ## Installation
30
+ ### [Update] Quick Installation
31
+ Thanks to [@warmshao](https://github.com/warmshao) WiLoR can now be installed using a single pip command:
32
+ ```
33
+ pip install git+https://github.com/warmshao/WiLoR-mini
34
+ ```
35
+ Please head to [WiLoR-mini](https://github.com/warmshao/WiLoR-mini) for additional details.
36
+
37
+ **Note:** the above code is a simplified version of WiLoR and can be used for demo only.
38
+ If you wish to use WiLoR for other tasks it is suggested to follow the original installation instructions below:
39
+ ### Original Installation
40
+ ```
41
+ git clone --recursive https://github.com/rolpotamias/WiLoR.git
42
+ cd WiLoR
43
+ ```
44
+
45
+ The code has been tested with PyTorch 2.0.0 and CUDA 11.7. It is suggested to use an anaconda environment to install the required dependencies:
46
+ ```bash
47
+ conda create --name wilor python=3.10
48
+ conda activate wilor
49
+
50
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu117
51
+ # Install requirements
52
+ pip install -r requirements.txt
53
+ ```
54
+ Download the pretrained models using:
55
+ ```bash
56
+ wget https://huggingface.co/spaces/rolpotamias/WiLoR/resolve/main/pretrained_models/detector.pt -P ./pretrained_models/
57
+ wget https://huggingface.co/spaces/rolpotamias/WiLoR/resolve/main/pretrained_models/wilor_final.ckpt -P ./pretrained_models/
58
+ ```
59
+ It is also required to download MANO model from [MANO website](https://mano.is.tue.mpg.de).
60
+ Create an account by clicking Sign Up and download the models (mano_v*_*.zip). Unzip and place the right hand model `MANO_RIGHT.pkl` under the `mano_data/` folder.
61
+ Note that MANO model falls under the [MANO license](https://mano.is.tue.mpg.de/license.html).
62
+ ## Demo
63
+ ```bash
64
+ python demo.py --img_folder demo_img --out_folder demo_out --save_mesh
65
+ ```
66
+ ## Start a local gradio demo
67
+ You can start a local demo for inference by running:
68
+ ```bash
69
+ python gradio_demo.py
70
+ ```
71
+ ## WHIM Dataset
72
+ To download WHIM dataset please follow the instructions [here](./whim/Dataset_instructions.md)
73
+
74
+ ## Acknowledgements
75
+ Parts of the code are taken or adapted from the following repos:
76
+ - [HaMeR](https://github.com/geopavlakos/hamer/)
77
+ - [Ultralytics](https://github.com/ultralytics/ultralytics)
78
+
79
+ ## License
80
+ WiLoR models fall under the [CC-BY-NC--ND License](./license.txt). This repository also depends on the [Ultralytics library](https://github.com/ultralytics/ultralytics) and the [MANO Model](https://mano.is.tue.mpg.de/license.html), which fall under their own licenses. By using this repository, you must also comply with the terms of these external licenses.
81
+ ## Citing
82
+ If you find WiLoR useful for your research, please consider citing our paper:
83
+
84
+ ```bibtex
85
+ @misc{potamias2024wilor,
86
+ title={WiLoR: End-to-end 3D Hand Localization and Reconstruction in-the-wild},
87
+ author={Rolandos Alexandros Potamias and Jinglei Zhang and Jiankang Deng and Stefanos Zafeiriou},
88
+ year={2024},
89
+ eprint={2409.12259},
90
+ archivePrefix={arXiv},
91
+ primaryClass={cs.CV}
92
+ }
93
+ ```
WiLoR/assets/teaser.png ADDED

Git LFS Details

  • SHA256: d5f07ada2f470af0619716c0ce4f60d9dfd3da1673d06c28c97d85abb84eadc0
  • Pointer size: 132 Bytes
  • Size of remote file: 9.21 MB
WiLoR/demo.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import torch
3
+ import argparse
4
+ import os
5
+ import cv2
6
+ import numpy as np
7
+ import json
8
+ from typing import Dict, Optional
9
+
10
+ from wilor.models import WiLoR, load_wilor
11
+ from wilor.utils import recursive_to
12
+ from wilor.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
13
+ from wilor.utils.renderer import Renderer, cam_crop_to_full
14
+ from ultralytics import YOLO
15
+ LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353)
16
+
17
def main():
    """Run the WiLoR demo: detect hands in every image of a folder, regress a
    3D MANO mesh per hand, and write rendered overlays (and optionally .obj
    meshes) to the output folder."""
    parser = argparse.ArgumentParser(description='WiLoR demo code')
    parser.add_argument('--img_folder', type=str, default='images', help='Folder with input images')
    parser.add_argument('--out_folder', type=str, default='out_demo', help='Output folder to save rendered results')
    parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='If set, save meshes to disk also')
    parser.add_argument('--rescale_factor', type=float, default=2.0, help='Factor for padding the bbox')
    parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg'], help='List of file extensions to consider')

    args = parser.parse_args()

    # Download and load checkpoints
    model, model_cfg = load_wilor(checkpoint_path = './pretrained_models/wilor_final.ckpt' , cfg_path= './pretrained_models/model_config.yaml')
    detector = YOLO('./pretrained_models/detector.pt')
    # Setup the renderer
    renderer = Renderer(model_cfg, faces=model.mano.faces)
    # NOTE(review): renderer_side is created but never used below — presumably
    # intended for a side-view render; confirm before removing.
    renderer_side = Renderer(model_cfg, faces=model.mano.faces)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = model.to(device)
    detector = detector.to(device)
    model.eval()

    # Make output directory if it does not exist
    os.makedirs(args.out_folder, exist_ok=True)

    # Get all demo images ends with .jpg or .png
    img_paths = [img for end in args.file_type for img in Path(args.img_folder).glob(end)]
    # Iterate over all images in folder
    for img_path in img_paths:
        img_cv2 = cv2.imread(str(img_path))
        # YOLO hand detector; each detection's class encodes handedness
        # (the cls value feeds `is_right` — 0 presumably left, 1 right;
        # confirm against the detector's training labels).
        detections = detector(img_cv2, conf = 0.3, verbose=False)[0]
        bboxes = []
        is_right = []
        for det in detections:
            Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
            is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
            bboxes.append(Bbox[:4].tolist())

        # No hands found in this image; skip it.
        if len(bboxes) == 0:
            continue
        boxes = np.stack(bboxes)
        right = np.stack(is_right)
        # Crops around the detected boxes, padded by rescale_factor.
        dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=args.rescale_factor)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)

        # Accumulators over all hand crops of the current image.
        all_verts = []
        all_cam_t = []
        all_right = []
        all_joints= []
        all_kpts = []

        for batch in dataloader:
            batch = recursive_to(batch, device)

            with torch.no_grad():
                out = model(batch)

            # batch['right'] is 1 for right hands and 0 for left, so the
            # multiplier is +/-1; it flips the predicted camera x-offset for
            # left hands.
            multiplier = (2*batch['right']-1)
            pred_cam = out['pred_cam']
            pred_cam[:,1] = multiplier*pred_cam[:,1]
            box_center = batch["box_center"].float()
            box_size = batch["box_size"].float()
            img_size = batch["img_size"].float()
            # Convert the crop-space weak-perspective camera to a translation
            # in full-image coordinates.
            scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
            pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy()


            # Render the result
            batch_size = batch['img'].shape[0]
            for n in range(batch_size):
                # Get filename from path img_path
                img_fn, _ = os.path.splitext(os.path.basename(img_path))

                verts = out['pred_vertices'][n].detach().cpu().numpy()
                joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()

                # NOTE: this shadows the detection list `is_right` above with
                # the per-sample handedness scalar; the list is no longer
                # needed at this point.
                is_right = batch['right'][n].cpu().numpy()
                # Mirror the x-axis for left hands (is_right == 0).
                verts[:,0] = (2*is_right-1)*verts[:,0]
                joints[:,0] = (2*is_right-1)*joints[:,0]
                cam_t = pred_cam_t_full[n]
                # Project the mesh vertices onto the full image plane.
                kpts_2d = project_full_img(verts, cam_t, scaled_focal_length, img_size[n])

                all_verts.append(verts)
                all_cam_t.append(cam_t)
                all_right.append(is_right)
                all_joints.append(joints)
                all_kpts.append(kpts_2d)


                # Save all meshes to disk
                if args.save_mesh:
                    camera_translation = cam_t.copy()
                    tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_PURPLE, is_right=is_right)
                    tmesh.export(os.path.join(args.out_folder, f'{img_fn}_{n}.obj'))

        # Render front view
        if len(all_verts) > 0:
            misc_args = dict(
                mesh_base_color=LIGHT_PURPLE,
                scene_bg_color=(1, 1, 1),
                focal_length=scaled_focal_length,
            )
            cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=img_size[n], is_right=all_right, **misc_args)

            # Overlay image
            input_img = img_cv2.astype(np.float32)[:,:,::-1]/255.0
            input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
            input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]

            cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}.jpg'), 255*input_img_overlay[:, :, ::-1])
127
+
128
def project_full_img(points, cam_trans, focal_length, img_res):
    """Project 3D points onto the full image plane with a pinhole camera.

    Args:
        points: (N, 3) array of 3D vertices/joints in camera-relative space.
        cam_trans: (3,) camera translation added to every point.
        focal_length: scalar focal length (Python float or 0-d torch tensor).
        img_res: (width, height) of the full image; the principal point is
            placed at the image centre.

    Returns:
        (N, 2) numpy array of 2D pixel coordinates.
    """
    # Bug fix: the original built K with torch.eye(3) and computed
    # `K @ points.T` against a numpy `points` array (the caller passes
    # `.cpu().numpy()` vertices), which raises a TypeError because
    # torch.matmul does not accept numpy operands. The projection is now done
    # entirely in numpy; float() also accepts 0-d torch tensors, so the tensor
    # focal length / image size passed by the demo keep working.
    camera_center = [float(img_res[0]) / 2., float(img_res[1]) / 2.]
    K = np.eye(3, dtype=np.float64)
    K[0, 0] = float(focal_length)
    K[1, 1] = float(focal_length)
    K[0, 2] = camera_center[0]
    K[1, 2] = camera_center[1]
    points = np.asarray(points, dtype=np.float64) + np.asarray(cam_trans, dtype=np.float64)
    # Perspective division by depth (z).
    points = points / points[..., -1:]

    V_2d = (K @ points.T).T
    # Drop the homogeneous coordinate.
    return V_2d[..., :-1]
140
+
141
# Script entry point: run the demo over the configured image folder.
if __name__ == '__main__':
    main()
WiLoR/demo_img/test1.jpg ADDED
WiLoR/demo_img/test2.png ADDED

Git LFS Details

  • SHA256: 589f5d12593acbcbcb9ec07b288b04f6d7e70542e1312ceee3ea992ba0f41ff9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
WiLoR/demo_img/test3.jpg ADDED
WiLoR/demo_img/test4.jpg ADDED

Git LFS Details

  • SHA256: efb16543caa936aa671ad1cb28ca2c6129ba8cba58d08476ed9538fd12de9265
  • Pointer size: 131 Bytes
  • Size of remote file: 315 kB
WiLoR/demo_img/test5.jpeg ADDED

Git LFS Details

  • SHA256: 84d161aa4f1a335ec3971c5d050338e7c13b9e3c90231c0de7e677094a172eae
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
WiLoR/demo_img/test6.jpg ADDED

Git LFS Details

  • SHA256: 617a3a3d04a1e17e4285dab5bca2003080923df66953df93c85ddfdaa383e8f5
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
WiLoR/demo_img/test7.jpg ADDED
WiLoR/demo_img/test8.jpg ADDED

Git LFS Details

  • SHA256: 886ef1a8981bef175707353b2adea60168657a926c1dd5a95789c4907d881907
  • Pointer size: 131 Bytes
  • Size of remote file: 398 kB
WiLoR/download_videos.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import json
import os

import numpy as np
from pytubefix import YouTube
import cv2  # was missing: cv2 is used below (VideoCapture/imwrite) but never imported
+
7
# Download the WHIM YouTube videos for one split and extract the annotated
# frames as JPEGs next to their annotations.
parser = argparse.ArgumentParser()

parser.add_argument("--root", type=str, help="Directory of WiLoR")
parser.add_argument("--mode", type=str, choices=['train', 'test'], default='train', help="Train/Test set")

args = parser.parse_args()

# Video IDs plus their resolution / fps metadata for the requested split.
with open(os.path.join(args.root, f'./whim/{args.mode}_video_ids.json')) as f:
    video_dict = json.load(f)

Video_IDs = video_dict.keys()
failed_IDs = []
os.makedirs(os.path.join(args.root, 'Videos'), exist_ok=True)

for Video_ID in Video_IDs:
    res = video_dict[Video_ID]['res'][0]
    try:
        # Download the video-only mp4 stream at the annotated resolution.
        YouTube('https://youtu.be/' + Video_ID).streams.filter(only_video=True,
                                                               file_extension='mp4',
                                                               res=f'{res}p'
                                                               ).order_by('resolution').desc().first().download(
            output_path=os.path.join(args.root, 'Videos'),
            filename=Video_ID + '.mp4')
    except Exception as e:
        # Bug fix: this was a bare `except:` that also swallowed
        # KeyboardInterrupt/SystemExit. Keep the best-effort behaviour, but
        # only for real errors, and report what went wrong.
        print(f'Failed {Video_ID}: {e}')
        failed_IDs.append(Video_ID)
        continue

    video_path = os.path.join(args.root, 'Videos', Video_ID + '.mp4')
    # Requires `import cv2` at the top of the file (it was missing in the
    # original, making every line below raise NameError).
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error opening video stream {video_path}")

    VIDEO_LEN = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Map annotation frame indices (at the dataset's original fps) to frame
    # indices of the downloaded stream.
    fps_org = video_dict[Video_ID]['fps']
    fps_rate = round(fps / fps_org)

    all_frames = os.listdir(os.path.join(args.root, 'WHIM', args.mode, 'anno', Video_ID))

    for frame in all_frames:
        frame_gt = int(frame[:-4])  # strip the 4-char extension (e.g. '.jpg')
        frame_idx = (frame_gt * fps_rate)

        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, img_cv2 = cap.read()

        # Bug fix: guard against failed reads — the original ignored `ret` and
        # crashed on `None` frames. Also write the frame as decoded (uint8)
        # instead of the original's needless cast to float32 before imwrite.
        if not ret:
            print(f'Failed to read frame {frame_idx} of {Video_ID}')
            continue

        # NOTE(review): `frame` already carries an extension, so this writes
        # e.g. '000123.jpg.jpg' — confirm the intended output naming before
        # changing it, as downstream code may expect the double extension.
        cv2.imwrite(os.path.join(args.root, 'WHIM', args.mode, 'anno', Video_ID, frame + '.jpg'), img_cv2)

    # Release the capture handle (the original leaked one per video).
    cap.release()

np.save(os.path.join(args.root, 'failed_videos.npy'), failed_IDs)
WiLoR/gradio_demo.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ os.environ["PYOPENGL_PLATFORM"] = "egl"
4
+ os.environ["MESA_GL_VERSION_OVERRIDE"] = "4.1"
5
+ # os.system('pip install /home/user/app/pyrender')
6
+ # sys.path.append('/home/user/app/pyrender')
7
+
8
+ import gradio as gr
9
+ #import spaces
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ from ultralytics import YOLO
14
+ from pathlib import Path
15
+ import argparse
16
+ import json
17
+ from typing import Dict, Optional
18
+
19
+ from wilor.models import WiLoR, load_wilor
20
+ from wilor.utils import recursive_to
21
+ from wilor.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
22
+ from wilor.utils.renderer import Renderer, cam_crop_to_full
23
# Bug fix: the original ternary was inverted —
# `torch.device('cpu') if torch.cuda.is_available() else torch.device('cuda')`
# — which picked the CPU on CUDA machines and crashed on CPU-only machines.
# demo.py in this repo uses the correct form; made consistent here.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Mesh base colour used for all renders (RGB in [0, 1]).
LIGHT_PURPLE = (0.25098039, 0.274117647, 0.65882353)

# Load the WiLoR reconstruction network and its config.
model, model_cfg = load_wilor(checkpoint_path='./pretrained_models/wilor_final.ckpt',
                              cfg_path='./pretrained_models/model_config.yaml')
# The renderer shares the MANO face topology with the model.
renderer = Renderer(model_cfg, faces=model.mano.faces)
model = model.to(device)
model.eval()

# YOLO-based hand detector.
detector = YOLO('./pretrained_models/detector.pt').to(device)
34
+
35
def render_reconstruction(image, conf, IoU_threshold=0.3):
    """Reconstruct hands in `image` and composite the renders over it.

    Args:
        image: RGB numpy image from the Gradio input widget.
        conf: detector confidence threshold (from the UI slider).
        IoU_threshold: detector NMS IoU threshold.

    Returns:
        (overlay or annotated image, status string) for the Gradio outputs.
    """
    # Bug fix: the original hard-coded IoU_threshold=0.5 in this call,
    # silently ignoring the function's IoU_threshold parameter.
    input_img, num_dets, reconstructions = run_wilow_model(image, conf, IoU_threshold=IoU_threshold)
    if num_dets > 0:
        # Render the front view of every detected hand in one pass.
        misc_args = dict(
            mesh_base_color=LIGHT_PURPLE,
            scene_bg_color=(1, 1, 1),
            focal_length=reconstructions['focal'],
        )

        cam_view = renderer.render_rgba_multiple(reconstructions['verts'],
                                                 cam_t=reconstructions['cam_t'],
                                                 render_res=reconstructions['img_size'],
                                                 is_right=reconstructions['right'], **misc_args)

        # Alpha-blend the rendering over the input image.
        input_img = np.concatenate([input_img, np.ones_like(input_img[:, :, :1])], axis=2)  # Add alpha channel
        input_img_overlay = input_img[:, :, :3] * (1 - cam_view[:, :, 3:]) + cam_view[:, :, :3] * cam_view[:, :, 3:]

        return input_img_overlay, f'{num_dets} hands detected'
    else:
        return input_img, f'{num_dets} hands detected'
59
+
60
#@spaces.GPU()
def run_wilow_model(image, conf, IoU_threshold=0.5):
    """Detect hands in `image`, draw the detections, and reconstruct 3D meshes.

    Args:
        image: RGB numpy image (as provided by Gradio).
        conf: detector confidence threshold.
        IoU_threshold: detector NMS IoU threshold.

    Returns:
        Tuple of (annotated image as float32 RGB in [0, 1], number of
        detections, reconstructions dict or None). The dict holds per-hand
        'verts', 'cam_t' and 'right' lists plus the shared 'img_size' and
        'focal' values needed for rendering.
    """
    # Reverse the channel order (RGB -> BGR, cv2 convention) for the detector;
    # keep an unmodified copy for the visualisation overlays.
    img_cv2 = image[...,::-1]
    img_vis = image.copy()

    detections = detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]

    bboxes = []
    is_right = []
    for det in detections:
        Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
        Conf = det.boxes.conf.data.cpu().detach()[0].numpy().reshape(-1).astype(np.float16)
        Side = det.boxes.cls.data.cpu().detach()
        #Bbox[:2] -= np.int32(0.1 * Bbox[:2])
        #Bbox[2:] += np.int32(0.1 * Bbox[ 2:])
        # Detector class encodes handedness (0 = left, 1 = right, per the
        # 'L'/'R' labels drawn below).
        is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
        bboxes.append(Bbox[:4].tolist())

        # Draw the box plus an 'L'/'R' + confidence label on the visualisation.
        color = (255*0.208, 255*0.647 ,255*0.603 ) if Side==0. else (255*1, 255*0.78039, 255*0.2353)
        label = f'L - {Conf[0]:.3f}' if Side==0 else f'R - {Conf[0]:.3f}'

        cv2.rectangle(img_vis, (int(Bbox[0]), int(Bbox[1])), (int(Bbox[2]), int(Bbox[3])), color , 3)
        (w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
        cv2.rectangle(img_vis, (int(Bbox[0]), int(Bbox[1]) - 20), (int(Bbox[0]) + w, int(Bbox[1])), color, -1)
        cv2.putText(img_vis, label, (int(Bbox[0]), int(Bbox[1]) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2)

    if len(bboxes) != 0:
        boxes = np.stack(bboxes)
        right = np.stack(is_right)
        # Build padded crops around every detection for the reconstruction net.
        dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=2.0 )
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)

        # Accumulators over all hand crops of this image.
        all_verts = []
        all_cam_t = []
        all_right = []
        all_joints= []

        for batch in dataloader:
            batch = recursive_to(batch, device)

            with torch.no_grad():
                out = model(batch)

            # batch['right'] is 1 for right hands and 0 for left, so the
            # multiplier is +/-1; it flips the camera x-offset for left hands.
            multiplier = (2*batch['right']-1)
            pred_cam = out['pred_cam']
            pred_cam[:,1] = multiplier*pred_cam[:,1]
            box_center = batch["box_center"].float()
            box_size = batch["box_size"].float()
            img_size = batch["img_size"].float()
            # Convert the crop-space camera to a full-image translation.
            scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
            pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy()


            batch_size = batch['img'].shape[0]
            for n in range(batch_size):

                verts = out['pred_vertices'][n].detach().cpu().numpy()
                joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()

                # Per-sample handedness; mirror the x-axis for left hands.
                is_right = batch['right'][n].cpu().numpy()
                verts[:,0] = (2*is_right-1)*verts[:,0]
                joints[:,0] = (2*is_right-1)*joints[:,0]

                cam_t = pred_cam_t_full[n]

                all_verts.append(verts)
                all_cam_t.append(cam_t)
                all_right.append(is_right)
                all_joints.append(joints)

        # NOTE(review): img_size[n] / scaled_focal_length come from the last
        # processed batch — fine for a single image, but worth confirming if
        # batching behaviour ever changes.
        reconstructions = {'verts': all_verts, 'cam_t': all_cam_t, 'right': all_right, 'img_size': img_size[n], 'focal': scaled_focal_length}
        return img_vis.astype(np.float32)/255.0, len(detections), reconstructions
    else:
        return img_vis.astype(np.float32)/255.0, len(detections), None
134
+
135
+
136
+
137
# HTML header shown at the top of the Gradio page (title, authors,
# affiliations, badge links).
header = ('''
<div class="embed_hidden" style="text-align: center;">
<h1> <b>WiLoR</b>: End-to-end 3D hand localization and reconstruction in-the-wild</h1>
<h3>
<a href="https://rolpotamias.github.io" target="_blank" rel="noopener noreferrer">Rolandos Alexandros Potamias</a><sup>1</sup>,
<a href="" target="_blank" rel="noopener noreferrer">Jinglei Zhang</a><sup>2</sup>,
<br>
<a href="https://jiankangdeng.github.io/" target="_blank" rel="noopener noreferrer">Jiankang Deng</a><sup>1</sup>,
<a href="https://wp.doc.ic.ac.uk/szafeiri/" target="_blank" rel="noopener noreferrer">Stefanos Zafeiriou</a><sup>1</sup>
</h3>
<h3>
<sup>1</sup>Imperial College London;
<sup>2</sup>Shanghai Jiao Tong University
</h3>
</div>
<div style="display:flex; gap: 0.3rem; justify-content: center; align-items: center;" align="center">
<a href=''><img src='https://img.shields.io/badge/Arxiv-......-A42C25?style=flat&logo=arXiv&logoColor=A42C25'></a>
<a href='https://rolpotamias.github.io/pdfs/WiLoR.pdf'><img src='https://img.shields.io/badge/Paper-PDF-yellow?style=flat&logo=arXiv&logoColor=yellow'></a>
<a href='https://rolpotamias.github.io/WiLoR/'><img src='https://img.shields.io/badge/Project-Page-%23df5b46?style=flat&logo=Google%20chrome&logoColor=%23df5b46'></a>
<a href='https://github.com/rolpotamias/WiLoR'><img src='https://img.shields.io/badge/GitHub-Code-black?style=flat&logo=github&logoColor=white'></a>
''')


# Build the Gradio interface: image input + confidence slider on the left,
# rendered reconstruction + detection count on the right.
with gr.Blocks(title="WiLoR: End-to-end 3D hand localization and reconstruction in-the-wild", css=".gradio-container") as demo:

    gr.Markdown(header)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input image", type="numpy")
            threshold = gr.Slider(value=0.3, minimum=0.05, maximum=0.95, step=0.05, label='Detection Confidence Threshold')
            #nms = gr.Slider(value=0.5, minimum=0.05, maximum=0.95, step=0.05, label='IoU NMS Threshold')
            submit = gr.Button("Submit", variant="primary")


        with gr.Column():
            reconstruction = gr.Image(label="Reconstructions", type="numpy")
            hands_detected = gr.Textbox(label="Hands Detected")

    # Run detection + reconstruction on click.
    submit.click(fn=render_reconstruction, inputs=[input_image, threshold], outputs=[reconstruction, hands_detected])

    # Clickable example images bundled with the repo.
    with gr.Row():
        example_images = gr.Examples([

            ['./demo_img/test1.jpg'],
            ['./demo_img/test2.png'],
            ['./demo_img/test3.jpg'],
            ['./demo_img/test4.jpg'],
            ['./demo_img/test5.jpeg'],
            ['./demo_img/test6.jpg'],
            ['./demo_img/test7.jpg'],
            ['./demo_img/test8.jpg'],
            ],
            inputs=input_image)

demo.launch()
WiLoR/license.txt ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial-NoDerivatives 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-NoDerivatives 4.0
58
+ International Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-NoDerivatives 4.0 International Public
63
+ License ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Copyright and Similar Rights means copyright and/or similar rights
84
+ closely related to copyright including, without limitation,
85
+ performance, broadcast, sound recording, and Sui Generis Database
86
+ Rights, without regard to how the rights are labeled or
87
+ categorized. For purposes of this Public License, the rights
88
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
89
+ Rights.
90
+
91
+ c. Effective Technological Measures means those measures that, in the
92
+ absence of proper authority, may not be circumvented under laws
93
+ fulfilling obligations under Article 11 of the WIPO Copyright
94
+ Treaty adopted on December 20, 1996, and/or similar international
95
+ agreements.
96
+
97
+ d. Exceptions and Limitations means fair use, fair dealing, and/or
98
+ any other exception or limitation to Copyright and Similar Rights
99
+ that applies to Your use of the Licensed Material.
100
+
101
+ e. Licensed Material means the artistic or literary work, database,
102
+ or other material to which the Licensor applied this Public
103
+ License.
104
+
105
+ f. Licensed Rights means the rights granted to You subject to the
106
+ terms and conditions of this Public License, which are limited to
107
+ all Copyright and Similar Rights that apply to Your use of the
108
+ Licensed Material and that the Licensor has authority to license.
109
+
110
+ g. Licensor means the individual(s) or entity(ies) granting rights
111
+ under this Public License.
112
+
113
+ h. NonCommercial means not primarily intended for or directed towards
114
+ commercial advantage or monetary compensation. For purposes of
115
+ this Public License, the exchange of the Licensed Material for
116
+ other material subject to Copyright and Similar Rights by digital
117
+ file-sharing or similar means is NonCommercial provided there is
118
+ no payment of monetary compensation in connection with the
119
+ exchange.
120
+
121
+ i. Share means to provide material to the public by any means or
122
+ process that requires permission under the Licensed Rights, such
123
+ as reproduction, public display, public performance, distribution,
124
+ dissemination, communication, or importation, and to make material
125
+ available to the public including in ways that members of the
126
+ public may access the material from a place and at a time
127
+ individually chosen by them.
128
+
129
+ j. Sui Generis Database Rights means rights other than copyright
130
+ resulting from Directive 96/9/EC of the European Parliament and of
131
+ the Council of 11 March 1996 on the legal protection of databases,
132
+ as amended and/or succeeded, as well as other essentially
133
+ equivalent rights anywhere in the world.
134
+
135
+ k. You means the individual or entity exercising the Licensed Rights
136
+ under this Public License. Your has a corresponding meaning.
137
+
138
+
139
+ Section 2 -- Scope.
140
+
141
+ a. License grant.
142
+
143
+ 1. Subject to the terms and conditions of this Public License,
144
+ the Licensor hereby grants You a worldwide, royalty-free,
145
+ non-sublicensable, non-exclusive, irrevocable license to
146
+ exercise the Licensed Rights in the Licensed Material to:
147
+
148
+ a. reproduce and Share the Licensed Material, in whole or
149
+ in part, for NonCommercial purposes only; and
150
+
151
+ b. produce and reproduce, but not Share, Adapted Material
152
+ for NonCommercial purposes only.
153
+
154
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
155
+ Exceptions and Limitations apply to Your use, this Public
156
+ License does not apply, and You do not need to comply with
157
+ its terms and conditions.
158
+
159
+ 3. Term. The term of this Public License is specified in Section
160
+ 6(a).
161
+
162
+ 4. Media and formats; technical modifications allowed. The
163
+ Licensor authorizes You to exercise the Licensed Rights in
164
+ all media and formats whether now known or hereafter created,
165
+ and to make technical modifications necessary to do so. The
166
+ Licensor waives and/or agrees not to assert any right or
167
+ authority to forbid You from making technical modifications
168
+ necessary to exercise the Licensed Rights, including
169
+ technical modifications necessary to circumvent Effective
170
+ Technological Measures. For purposes of this Public License,
171
+ simply making modifications authorized by this Section 2(a)
172
+ (4) never produces Adapted Material.
173
+
174
+ 5. Downstream recipients.
175
+
176
+ a. Offer from the Licensor -- Licensed Material. Every
177
+ recipient of the Licensed Material automatically
178
+ receives an offer from the Licensor to exercise the
179
+ Licensed Rights under the terms and conditions of this
180
+ Public License.
181
+
182
+ b. No downstream restrictions. You may not offer or impose
183
+ any additional or different terms or conditions on, or
184
+ apply any Effective Technological Measures to, the
185
+ Licensed Material if doing so restricts exercise of the
186
+ Licensed Rights by any recipient of the Licensed
187
+ Material.
188
+
189
+ 6. No endorsement. Nothing in this Public License constitutes or
190
+ may be construed as permission to assert or imply that You
191
+ are, or that Your use of the Licensed Material is, connected
192
+ with, or sponsored, endorsed, or granted official status by,
193
+ the Licensor or others designated to receive attribution as
194
+ provided in Section 3(a)(1)(A)(i).
195
+
196
+ b. Other rights.
197
+
198
+ 1. Moral rights, such as the right of integrity, are not
199
+ licensed under this Public License, nor are publicity,
200
+ privacy, and/or other similar personality rights; however, to
201
+ the extent possible, the Licensor waives and/or agrees not to
202
+ assert any such rights held by the Licensor to the limited
203
+ extent necessary to allow You to exercise the Licensed
204
+ Rights, but not otherwise.
205
+
206
+ 2. Patent and trademark rights are not licensed under this
207
+ Public License.
208
+
209
+ 3. To the extent possible, the Licensor waives any right to
210
+ collect royalties from You for the exercise of the Licensed
211
+ Rights, whether directly or through a collecting society
212
+ under any voluntary or waivable statutory or compulsory
213
+ licensing scheme. In all other cases the Licensor expressly
214
+ reserves any right to collect such royalties, including when
215
+ the Licensed Material is used other than for NonCommercial
216
+ purposes.
217
+
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material, You must:
227
+
228
+ a. retain the following if it is supplied by the Licensor
229
+ with the Licensed Material:
230
+
231
+ i. identification of the creator(s) of the Licensed
232
+ Material and any others designated to receive
233
+ attribution, in any reasonable manner requested by
234
+ the Licensor (including by pseudonym if
235
+ designated);
236
+
237
+ ii. a copyright notice;
238
+
239
+ iii. a notice that refers to this Public License;
240
+
241
+ iv. a notice that refers to the disclaimer of
242
+ warranties;
243
+
244
+ v. a URI or hyperlink to the Licensed Material to the
245
+ extent reasonably practicable;
246
+
247
+ b. indicate if You modified the Licensed Material and
248
+ retain an indication of any previous modifications; and
249
+
250
+ c. indicate the Licensed Material is licensed under this
251
+ Public License, and include the text of, or the URI or
252
+ hyperlink to, this Public License.
253
+
254
+ For the avoidance of doubt, You do not have permission under
255
+ this Public License to Share Adapted Material.
256
+
257
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
258
+ reasonable manner based on the medium, means, and context in
259
+ which You Share the Licensed Material. For example, it may be
260
+ reasonable to satisfy the conditions by providing a URI or
261
+ hyperlink to a resource that includes the required
262
+ information.
263
+
264
+ 3. If requested by the Licensor, You must remove any of the
265
+ information required by Section 3(a)(1)(A) to the extent
266
+ reasonably practicable.
267
+
268
+
269
+ Section 4 -- Sui Generis Database Rights.
270
+
271
+ Where the Licensed Rights include Sui Generis Database Rights that
272
+ apply to Your use of the Licensed Material:
273
+
274
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
275
+ to extract, reuse, reproduce, and Share all or a substantial
276
+ portion of the contents of the database for NonCommercial purposes
277
+ only and provided You do not Share Adapted Material;
278
+
279
+ b. if You include all or a substantial portion of the database
280
+ contents in a database in which You have Sui Generis Database
281
+ Rights, then the database in which You have Sui Generis Database
282
+ Rights (but not its individual contents) is Adapted Material; and
283
+
284
+ c. You must comply with the conditions in Section 3(a) if You Share
285
+ all or a substantial portion of the contents of the database.
286
+
287
+ For the avoidance of doubt, this Section 4 supplements and does not
288
+ replace Your obligations under this Public License where the Licensed
289
+ Rights include other Copyright and Similar Rights.
290
+
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+
321
+ Section 6 -- Term and Termination.
322
+
323
+ a. This Public License applies for the term of the Copyright and
324
+ Similar Rights licensed here. However, if You fail to comply with
325
+ this Public License, then Your rights under this Public License
326
+ terminate automatically.
327
+
328
+ b. Where Your right to use the Licensed Material has terminated under
329
+ Section 6(a), it reinstates:
330
+
331
+ 1. automatically as of the date the violation is cured, provided
332
+ it is cured within 30 days of Your discovery of the
333
+ violation; or
334
+
335
+ 2. upon express reinstatement by the Licensor.
336
+
337
+ For the avoidance of doubt, this Section 6(b) does not affect any
338
+ right the Licensor may have to seek remedies for Your violations
339
+ of this Public License.
340
+
341
+ c. For the avoidance of doubt, the Licensor may also offer the
342
+ Licensed Material under separate terms or conditions or stop
343
+ distributing the Licensed Material at any time; however, doing so
344
+ will not terminate this Public License.
345
+
346
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347
+ License.
348
+
349
+
350
+ Section 7 -- Other Terms and Conditions.
351
+
352
+ a. The Licensor shall not be bound by any additional or different
353
+ terms or conditions communicated by You unless expressly agreed.
354
+
355
+ b. Any arrangements, understandings, or agreements regarding the
356
+ Licensed Material not stated herein are separate from and
357
+ independent of the terms and conditions of this Public License.
358
+
359
+
360
+ Section 8 -- Interpretation.
361
+
362
+ a. For the avoidance of doubt, this Public License does not, and
363
+ shall not be interpreted to, reduce, limit, restrict, or impose
364
+ conditions on any use of the Licensed Material that could lawfully
365
+ be made without permission under this Public License.
366
+
367
+ b. To the extent possible, if any provision of this Public License is
368
+ deemed unenforceable, it shall be automatically reformed to the
369
+ minimum extent necessary to make it enforceable. If the provision
370
+ cannot be reformed, it shall be severed from this Public License
371
+ without affecting the enforceability of the remaining terms and
372
+ conditions.
373
+
374
+ c. No term or condition of this Public License will be waived and no
375
+ failure to comply consented to unless expressly agreed to by the
376
+ Licensor.
377
+
378
+ d. Nothing in this Public License constitutes or may be interpreted
379
+ as a limitation upon, or waiver of, any privileges and immunities
380
+ that apply to the Licensor or You, including from the legal
381
+ processes of any jurisdiction or authority.
382
+
383
+ =======================================================================
384
+
385
+ Creative Commons is not a party to its public
386
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
387
+ its public licenses to material it publishes and in those instances
388
+ will be considered the “Licensor.” The text of the Creative Commons
389
+ public licenses is dedicated to the public domain under the CC0 Public
390
+ Domain Dedication. Except for the limited purpose of indicating that
391
+ material is shared under a Creative Commons public license or as
392
+ otherwise permitted by the Creative Commons policies published at
393
+ creativecommons.org/policies, Creative Commons does not authorize the
394
+ use of the trademark "Creative Commons" or any other trademark or logo
395
+ of Creative Commons without its prior written consent including,
396
+ without limitation, in connection with any unauthorized modifications
397
+ to any of its public licenses or any other arrangements,
398
+ understandings, or agreements concerning use of licensed material. For
399
+ the avoidance of doubt, this paragraph does not form part of the
400
+ public licenses.
401
+
402
+ Creative Commons may be contacted at creativecommons.org.
WiLoR/mano_data/mano_mean_params.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:efc0ec58e4a5cef78f3abfb4e8f91623b8950be9eff8b8e0dbb0d036ebc63988
3
+ size 1178
WiLoR/pretrained_models/dataset_config.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARCTIC-TRAIN:
2
+ TYPE: ImageDataset
3
+ URLS: wilor_training_data/dataset_tars/arctic-train/{000000..000176}.tar
4
+ epoch_size: 177000
5
+ BEDLAM-TRAIN:
6
+ TYPE: ImageDataset
7
+ URLS: wilor_training_data/dataset_tars/bedlam-train/{000000..000300}.tar
8
+ epoch_size: 301000
9
+ COCOW-TRAIN:
10
+ TYPE: ImageDataset
11
+ URLS: wilor_training_data/dataset_tars/cocow-train/{000000..000036}.tar
12
+ epoch_size: 78666
13
+ DEX-TRAIN:
14
+ TYPE: ImageDataset
15
+ URLS: wilor_training_data/dataset_tars/dex-train/{000000..000406}.tar
16
+ epoch_size: 406888
17
+ FREIHAND-MOCAP:
18
+ DATASET_FILE: wilor_training_data/freihand_mocap.npz
19
+ FREIHAND-TRAIN:
20
+ TYPE: ImageDataset
21
+ URLS: wilor_training_data/dataset_tars/freihand-train/{000000..000130}.tar
22
+ epoch_size: 130240
23
+ H2O3D-TRAIN:
24
+ TYPE: ImageDataset
25
+ URLS: wilor_training_data/dataset_tars/h2o3d-train/{000000..000060}.tar
26
+ epoch_size: 121996
27
+ HALPE-TRAIN:
28
+ TYPE: ImageDataset
29
+ URLS: wilor_training_data/dataset_tars/halpe-train/{000000..000022}.tar
30
+ epoch_size: 34289
31
+ HO3D-TRAIN:
32
+ TYPE: ImageDataset
33
+ URLS: wilor_training_data/dataset_tars/ho3d-train/{000000..000083}.tar
34
+ epoch_size: 83325
35
+ HOT3D-TRAIN:
36
+ TYPE: ImageDataset
37
+ URLS: wilor_training_data/dataset_tars/hot3d-train/{000000..000571}.tar
38
+ epoch_size: 572000
39
+ INTERHAND26M-TRAIN:
40
+ TYPE: ImageDataset
41
+ URLS: wilor_training_data/dataset_tars/interhand26m-train/{000000..001056}.tar
42
+ epoch_size: 1424632
43
+ MPIINZSL-TRAIN:
44
+ TYPE: ImageDataset
45
+ URLS: wilor_training_data/dataset_tars/mpiinzsl-train/{000000..000015}.tar
46
+ epoch_size: 15184
47
+ MTC-TRAIN:
48
+ TYPE: ImageDataset
49
+ URLS: wilor_training_data/dataset_tars/mtc-train/{000000..000306}.tar
50
+ epoch_size: 363947
51
+ REINTER-TRAIN:
52
+ TYPE: ImageDataset
53
+ URLS: wilor_training_data/dataset_tars/reinter-train/{000000..000418}.tar
54
+ epoch_size: 419000
55
+ RHD-TRAIN:
56
+ TYPE: ImageDataset
57
+ URLS: wilor_training_data/dataset_tars/rhd-train/{000000..000041}.tar
58
+ epoch_size: 61705
WiLoR/pretrained_models/model_config.yaml ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ task_name: train
2
+ tags:
3
+ - dev
4
+ train: true
5
+ test: false
6
+ ckpt_path: null
7
+ seed: null
8
+ DATASETS:
9
+ TRAIN:
10
+ FREIHAND-TRAIN:
11
+ WEIGHT: 0.2
12
+ INTERHAND26M-TRAIN:
13
+ WEIGHT: 0.1
14
+ MTC-TRAIN:
15
+ WEIGHT: 0.05
16
+ RHD-TRAIN:
17
+ WEIGHT: 0.05
18
+ COCOW-TRAIN:
19
+ WEIGHT: 0.05
20
+ HALPE-TRAIN:
21
+ WEIGHT: 0.05
22
+ MPIINZSL-TRAIN:
23
+ WEIGHT: 0.05
24
+ HO3D-TRAIN:
25
+ WEIGHT: 0.05
26
+ H2O3D-TRAIN:
27
+ WEIGHT: 0.05
28
+ DEX-TRAIN:
29
+ WEIGHT: 0.05
30
+ BEDLAM-TRAIN:
31
+ WEIGHT: 0.05
32
+ REINTER-TRAIN:
33
+ WEIGHT: 0.1
34
+ HOT3D-TRAIN:
35
+ WEIGHT: 0.05
36
+ ARCTIC-TRAIN:
37
+ WEIGHT: 0.1
38
+ VAL:
39
+ FREIHAND-TRAIN:
40
+ WEIGHT: 1.0
41
+ MOCAP: FREIHAND-MOCAP
42
+ BETAS_REG: true
43
+ CONFIG:
44
+ SCALE_FACTOR: 0.3
45
+ ROT_FACTOR: 30
46
+ TRANS_FACTOR: 0.02
47
+ COLOR_SCALE: 0.2
48
+ ROT_AUG_RATE: 0.6
49
+ TRANS_AUG_RATE: 0.5
50
+ DO_FLIP: false
51
+ FLIP_AUG_RATE: 0.0
52
+ EXTREME_CROP_AUG_RATE: 0.0
53
+ EXTREME_CROP_AUG_LEVEL: 1
54
+ extras:
55
+ ignore_warnings: false
56
+ enforce_tags: true
57
+ print_config: true
58
+ exp_name: WiLoR
59
+ MANO:
60
+ DATA_DIR: mano_data
61
+ MODEL_PATH: ${MANO.DATA_DIR}
62
+ GENDER: neutral
63
+ NUM_HAND_JOINTS: 15
64
+ MEAN_PARAMS: ${MANO.DATA_DIR}/mano_mean_params.npz
65
+ CREATE_BODY_POSE: false
66
+ EXTRA:
67
+ FOCAL_LENGTH: 5000
68
+ NUM_LOG_IMAGES: 4
69
+ NUM_LOG_SAMPLES_PER_IMAGE: 8
70
+ PELVIS_IND: 0
71
+ GENERAL:
72
+ TOTAL_STEPS: 1000000
73
+ LOG_STEPS: 1000
74
+ VAL_STEPS: 1000
75
+ CHECKPOINT_STEPS: 1000
76
+ CHECKPOINT_SAVE_TOP_K: 1
77
+ NUM_WORKERS: 8
78
+ PREFETCH_FACTOR: 2
79
+ TRAIN:
80
+ LR: 1.0e-05
81
+ WEIGHT_DECAY: 0.0001
82
+ BATCH_SIZE: 32
83
+ LOSS_REDUCTION: mean
84
+ NUM_TRAIN_SAMPLES: 2
85
+ NUM_TEST_SAMPLES: 64
86
+ POSE_2D_NOISE_RATIO: 0.01
87
+ SMPL_PARAM_NOISE_RATIO: 0.005
88
+ MODEL:
89
+ IMAGE_SIZE: 256
90
+ IMAGE_MEAN:
91
+ - 0.485
92
+ - 0.456
93
+ - 0.406
94
+ IMAGE_STD:
95
+ - 0.229
96
+ - 0.224
97
+ - 0.225
98
+ BACKBONE:
99
+ TYPE: vit
100
+ PRETRAINED_WEIGHTS: training_data/vitpose_backbone.pth
101
+ MANO_HEAD:
102
+ TYPE: transformer_decoder
103
+ IN_CHANNELS: 2048
104
+ TRANSFORMER_DECODER:
105
+ depth: 6
106
+ heads: 8
107
+ mlp_dim: 1024
108
+ dim_head: 64
109
+ dropout: 0.0
110
+ emb_dropout: 0.0
111
+ norm: layer
112
+ context_dim: 1280
113
+ LOSS_WEIGHTS:
114
+ KEYPOINTS_3D: 0.05
115
+ KEYPOINTS_2D: 0.01
116
+ GLOBAL_ORIENT: 0.001
117
+ HAND_POSE: 0.001
118
+ BETAS: 0.0005
119
+ ADVERSARIAL: 0.0005
WiLoR/requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ opencv-python
3
+ pyrender
4
+ pytorch-lightning
5
+ scikit-image
6
+ smplx==0.1.28
7
+ yacs
8
+ chumpy @ git+https://github.com/mattloper/chumpy
9
+ timm
10
+ einops
11
+ xtcocotools
12
+ pandas
13
+ hydra-core
14
+ hydra-submitit-launcher
15
+ hydra-colorlog
16
+ pyrootutils
17
+ rich
18
+ webdataset
19
+ gradio
20
+ ultralytics==8.1.34
WiLoR/whim/Dataset_instructions.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## WHIM Dataset
2
+
3
+ **Annotations**
4
+
5
+ The image annotations can be downloaded from the following Drive:
6
+
7
+ ```
8
+ https://drive.google.com/drive/folders/1d9Fw7LfnF5oJuA6yE8T3xA-u9p6H5ObZ
9
+ ```
10
+
11
+ **[Alternative]**: The image annotations can be also downloaded from Hugging Face:
12
+ ```
13
+ https://huggingface.co/datasets/rolpotamias/WHIM
14
+ ```
15
+ If you are using Hugging Face you might need to merge the training zip files into a single file before uncompressing:
16
+ ```
17
+ cat train_split.zip* > ~/train_split.zip
18
+ ```
19
+
20
+ **Images**
21
+
22
+ To download the corresponding images you need to first download the YouTube videos and extract the specific frames.
23
+ You will need to install ''pytubefix'' or any similar package to download YouTube videos:
24
+ ```
25
+ pip install -Iv pytubefix==8.12.2
26
+ ```
27
+ You can then run the following command to download the corresponding train/test images:
28
+ ```
29
+ python download_videos.py --mode {train/test}
30
+ ```
31
+ Please make sure that the data are downloaded in the same directory.
WiLoR/whim/test_video_ids.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"YynYZyoETto": {"res": [360, 480], "length": 4678, "fps": 29.97002997002997}, "_iirwC_DvJ0": {"res": [480, 854], "length": 7994, "fps": 29.97002997002997}, "ZMnb9TTsx98": {"res": [1080, 1920], "length": 8109, "fps": 29.97002997002997}, "IrSIHJ0-AaU": {"res": [360, 640], "length": 880, "fps": 30.0}, "w2ULyzWkZ3k": {"res": [1080, 1920], "length": 17032, "fps": 29.97002997002997}, "ivyqQreoVQA": {"res": [1080, 1440], "length": 9610, "fps": 29.97002997002997}, "R07f8kg1h8o": {"res": [1080, 1920], "length": 10726, "fps": 23.976023976023978}, "7S9q1kAVmc0": {"res": [720, 1280], "length": 53757, "fps": 25.0}, "_Ce7G35GIqA": {"res": [720, 1280], "length": 1620, "fps": 30.0}, "lhHkJ3InQOE": {"res": [240, 320], "length": 1600, "fps": 11.988011988011989}, "NXRHcCScubA": {"res": [1080, 1920], "length": 9785, "fps": 29.97002997002997}, "DjFX4idkS3o": {"res": [720, 1280], "length": 5046, "fps": 29.97002997002997}, "06kKvQp4SfM": {"res": [720, 1280], "length": 2661, "fps": 30.0}, "8NqJiAu9W3Y": {"res": [720, 1280], "length": 4738, "fps": 29.97002997002997}, "nN5Y--biYv4": {"res": [720, 1280], "length": 38380, "fps": 29.97}, "OiAlJIaWOBg": {"res": [720, 1280], "length": 10944, "fps": 30.0}, "nJa_omJBzoU": {"res": [720, 1280], "length": 4311, "fps": 29.97002997002997}, "ff_xcsFJ8Pw": {"res": [720, 1280], "length": 5631, "fps": 29.97}, "Y1mNu5iFwMg": {"res": [720, 1280], "length": 7060, "fps": 30.0}, "Ipe9xJCfuTM": {"res": [1080, 1920], "length": 52419, "fps": 29.97002997002997}, "vRkcw9SRems": {"res": [1080, 1920], "length": 10282, "fps": 23.976023976023978}, "ChIJjJyBjQ0": {"res": [1080, 1920], "length": 20228, "fps": 29.97002997002997}, "bxZtXdVvfpc": {"res": [1080, 1920], "length": 2369, "fps": 23.976023976023978}, "MPeXy2U4yJM": {"res": [1080, 1920], "length": 6760, "fps": 24.0}, "wnKnoui3THA": {"res": [1080, 1920], "length": 7934, "fps": 25.0}, "gnArvcWaH6I": {"res": [480, 720], "length": 6864, "fps": 29.97002997002997}}
WiLoR/whim/train_video_ids.json ADDED
The diff for this file is too large to render. See raw diff
 
WiLoR/wilor/configs/__init__.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict
3
+ from yacs.config import CfgNode as CN
4
+
5
+ CACHE_DIR_PRETRAINED = "./pretrained_models/"
6
+
7
+ def to_lower(x: Dict) -> Dict:
8
+ """
9
+ Convert all dictionary keys to lowercase
10
+ Args:
11
+ x (dict): Input dictionary
12
+ Returns:
13
+ dict: Output dictionary with all keys converted to lowercase
14
+ """
15
+ return {k.lower(): v for k, v in x.items()}
16
+
17
+ _C = CN(new_allowed=True)
18
+
19
+ _C.GENERAL = CN(new_allowed=True)
20
+ _C.GENERAL.RESUME = True
21
+ _C.GENERAL.TIME_TO_RUN = 3300
22
+ _C.GENERAL.VAL_STEPS = 100
23
+ _C.GENERAL.LOG_STEPS = 100
24
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
25
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
26
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
27
+ _C.GENERAL.NUM_GPUS = 1
28
+ _C.GENERAL.NUM_WORKERS = 4
29
+ _C.GENERAL.MIXED_PRECISION = True
30
+ _C.GENERAL.ALLOW_CUDA = True
31
+ _C.GENERAL.PIN_MEMORY = False
32
+ _C.GENERAL.DISTRIBUTED = False
33
+ _C.GENERAL.LOCAL_RANK = 0
34
+ _C.GENERAL.USE_SYNCBN = False
35
+ _C.GENERAL.WORLD_SIZE = 1
36
+
37
+ _C.TRAIN = CN(new_allowed=True)
38
+ _C.TRAIN.NUM_EPOCHS = 100
39
+ _C.TRAIN.BATCH_SIZE = 32
40
+ _C.TRAIN.SHUFFLE = True
41
+ _C.TRAIN.WARMUP = False
42
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
43
+ _C.TRAIN.CLIP_GRAD = False
44
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
45
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
46
+
47
+ _C.DATASETS = CN(new_allowed=True)
48
+
49
+ _C.MODEL = CN(new_allowed=True)
50
+ _C.MODEL.IMAGE_SIZE = 224
51
+
52
+ _C.EXTRA = CN(new_allowed=True)
53
+ _C.EXTRA.FOCAL_LENGTH = 5000
54
+
55
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
56
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
57
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
58
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
59
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
60
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
61
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
62
+ _C.DATASETS.CONFIG.DO_FLIP = False
63
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
64
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
65
+
66
+ def default_config() -> CN:
67
+ """
68
+ Get a yacs CfgNode object with the default config values.
69
+ """
70
+ # Return a clone so that the defaults will not be altered
71
+ # This is for the "local variable" use pattern
72
+ return _C.clone()
73
+
74
+ def dataset_config(name='datasets_tar.yaml') -> CN:
75
+ """
76
+ Get dataset config file
77
+ Returns:
78
+ CfgNode: Dataset config as a yacs CfgNode object.
79
+ """
80
+ cfg = CN(new_allowed=True)
81
+ config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
82
+ cfg.merge_from_file(config_file)
83
+ cfg.freeze()
84
+ return cfg
85
+
86
+ def dataset_eval_config() -> CN:
87
+ return dataset_config('datasets_eval.yaml')
88
+
89
+ def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
90
+ """
91
+ Read a config file and optionally merge it with the default config file.
92
+ Args:
93
+ config_file (str): Path to config file.
94
+ merge (bool): Whether to merge with the default config or not.
95
+ Returns:
96
+ CfgNode: Config as a yacs CfgNode object.
97
+ """
98
+ if merge:
99
+ cfg = default_config()
100
+ else:
101
+ cfg = CN(new_allowed=True)
102
+ cfg.merge_from_file(config_file)
103
+
104
+ if update_cachedir:
105
+ def update_path(path: str) -> str:
106
+ if os.path.isabs(path):
107
+ return path
108
+ return os.path.join(CACHE_DIR_PRETRAINED, path)
109
+
110
+ cfg.MANO.MODEL_PATH = update_path(cfg.MANO.MODEL_PATH)
111
+ cfg.MANO.MEAN_PARAMS = update_path(cfg.MANO.MEAN_PARAMS)
112
+
113
+ cfg.freeze()
114
+ return cfg
WiLoR/wilor/configs/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (6.02 kB). View file
 
WiLoR/wilor/datasets/utils.py ADDED
@@ -0,0 +1,994 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parts of the code are taken or adapted from
3
+ https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
4
+ """
5
+ import torch
6
+ import numpy as np
7
+ from skimage.transform import rotate, resize
8
+ from skimage.filters import gaussian
9
+ import random
10
+ import cv2
11
+ from typing import List, Dict, Tuple
12
+ from yacs.config import CfgNode
13
+
14
+ def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
15
+ """Increase the size of the bounding box to match the target shape."""
16
+ if target_aspect_ratio is None:
17
+ return input_shape
18
+
19
+ try:
20
+ w , h = input_shape
21
+ except (ValueError, TypeError):
22
+ return input_shape
23
+
24
+ w_t, h_t = target_aspect_ratio
25
+ if h / w < h_t / w_t:
26
+ h_new = max(w * h_t / w_t, h)
27
+ w_new = w
28
+ else:
29
+ h_new = h
30
+ w_new = max(h * w_t / h_t, w)
31
+ if h_new < h or w_new < w:
32
+ breakpoint()
33
+ return np.array([w_new, h_new])
34
+
35
def do_augmentation(aug_config: CfgNode) -> Tuple:
    """
    Draw one random set of augmentation parameters.

    Args:
        aug_config (CfgNode): Config with the augmentation hyper-parameters.
    Returns:
        scale (float): Box rescaling factor.
        rot (float): Random in-plane rotation in degrees.
        do_flip (bool): Whether to mirror the image.
        do_extreme_crop (bool): Whether to apply extreme cropping (EFT-style).
        extreme_crop_lvl (int): Extreme cropping aggressiveness level.
        color_scale (List): Per-channel color rescaling factors.
        tx (float): Random translation along x (fraction of box width).
        ty (float): Random translation along y (fraction of box height).
    """
    # NOTE: the RNG call order below intentionally matches the original
    # implementation so that seeded runs stay reproducible.
    tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
    ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
    scale = 1.0 + np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR
    if random.random() <= aug_config.ROT_AUG_RATE:
        rot = np.clip(np.random.randn(), -2.0, 2.0) * aug_config.ROT_FACTOR
    else:
        rot = 0
    do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
    do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
    extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
    c_low = 1.0 - aug_config.COLOR_SCALE
    c_up = 1.0 + aug_config.COLOR_SCALE
    color_scale = [random.uniform(c_low, c_up) for _ in range(3)]
    return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
63
+
64
def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
    """
    Rotate a point on the x-y plane by rot_rad radians.

    Args:
        pt_2d (np.array): Input 2D point with shape (2,).
        rot_rad (float): Rotation angle in radians.
    Returns:
        np.array: Rotated 2D point, float32.
    """
    sin_a = np.sin(rot_rad)
    cos_a = np.cos(rot_rad)
    rot_mat = np.array([[cos_a, -sin_a],
                        [sin_a, cos_a]], dtype=np.float32)
    point = np.asarray([pt_2d[0], pt_2d[1]], dtype=np.float32)
    return (rot_mat @ point).astype(np.float32)
79
+
80
+
81
def gen_trans_from_patch_cv(c_x: float, c_y: float,
                            src_width: float, src_height: float,
                            dst_width: float, dst_height: float,
                            scale: float, rot: float) -> np.array:
    """
    Create the 2x3 affine transform for the bounding box crop.

    Args:
        c_x (float): Box center x in the original image.
        c_y (float): Box center y in the original image.
        src_width (float): Bounding box width.
        src_height (float): Bounding box height.
        dst_width (float): Output patch width.
        dst_height (float): Output patch height.
        scale (float): Box rescaling factor (augmentation).
        rot (float): Rotation applied to the box, in degrees.
    Returns:
        trans (np.array): 2x3 affine transformation matrix.
    """
    rot_rad = np.pi * rot / 180
    half_w = 0.5 * src_width * scale
    half_h = 0.5 * src_height * scale

    # Three point correspondences (center, center+down, center+right) fully
    # determine the affine map from the rotated source box to the patch.
    src_pts = np.zeros((3, 2), dtype=np.float32)
    src_pts[0, :] = (c_x, c_y)
    src_pts[1, :] = src_pts[0, :] + rotate_2d(np.array([0, half_h], dtype=np.float32), rot_rad)
    src_pts[2, :] = src_pts[0, :] + rotate_2d(np.array([half_w, 0], dtype=np.float32), rot_rad)

    dst_pts = np.zeros((3, 2), dtype=np.float32)
    dst_pts[0, :] = (dst_width * 0.5, dst_height * 0.5)
    dst_pts[1, :] = dst_pts[0, :] + (0, dst_height * 0.5)
    dst_pts[2, :] = dst_pts[0, :] + (dst_width * 0.5, 0)

    return cv2.getAffineTransform(np.float32(src_pts), np.float32(dst_pts))
129
+
130
+
131
def trans_point2d(pt_2d: np.array, trans: np.array):
    """
    Apply a 2x3 affine transform to a single 2D point.

    Args:
        pt_2d (np.array): Input 2D point with shape (2,).
        trans (np.array): 2x3 affine transformation matrix.
    Returns:
        np.array: Transformed 2D point with shape (2,).
    """
    homogeneous = np.array([pt_2d[0], pt_2d[1], 1.])
    return (trans @ homogeneous)[:2]
143
+
144
def get_transform(center, scale, res, rot=0):
    """
    Generate the 3x3 transform from image coordinates to crop coordinates.

    Adapted from PARE (pare/utils/image_utils.py). `scale` follows the
    200-pixel box convention: the source box side length is 200 * scale.

    Args:
        center: (x, y) crop center in the original image.
        scale: Box scale (side = 200 * scale).
        res: Output resolution as (width-indexed res[1], height res[0]).
        rot: Rotation in degrees.
    Returns:
        np.array: 3x3 transformation matrix.
    """
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + .5)
    t[1, 2] = res[0] * (-float(center[1]) / h + .5)
    t[2, 2] = 1
    if rot != 0:
        # Negate to match the rotation direction used by the cropping code.
        rot_rad = -rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat = np.zeros((3, 3))
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Rotate about the patch center: shift, rotate, shift back.
        shift = np.eye(3)
        shift[0, 2] = -res[1] / 2
        shift[1, 2] = -res[0] / 2
        unshift = shift.copy()
        unshift[:2, 2] *= -1
        t = unshift @ (rot_mat @ (shift @ t))
    return t
170
+
171
+
172
def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
    """
    Map a pixel location between original-image and crop reference frames.

    Adapted from PARE (pare/utils/image_utils.py). Coordinates are treated
    as 1-based on both sides; `invert=1` maps crop -> original image.

    Args:
        pt: (x, y) pixel location.
        center, scale, res, rot: Crop definition (see get_transform).
        invert: If truthy, apply the inverse transform.
        as_int: If True, truncate the result to integer coordinates.
    Returns:
        np.array: Transformed (x, y) location.
    """
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    mapped = t @ np.array([pt[0] - 1, pt[1] - 1, 1.])
    if as_int:
        mapped = mapped.astype(int)
    return mapped[:2] + 1
183
+
184
def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
    """
    Crop `img` to the box spanned by upper-left `ul` and bottom-right `br`,
    padding out-of-bounds regions according to `border_mode`.
    """
    patch_w = br[0] - ul[0]
    patch_h = br[1] - ul[1]
    center_x = (ul[0] + br[0]) / 2
    center_y = (ul[1] + br[1]) / 2
    trans = gen_trans_from_patch_cv(center_x, center_y, patch_w, patch_h,
                                    patch_w, patch_h, 1.0, 0)
    patch_size = (int(patch_w), int(patch_h))
    img_patch = cv2.warpAffine(img, trans, patch_size,
                               flags=cv2.INTER_LINEAR,
                               borderMode=border_mode,
                               borderValue=border_value
                               )

    # The alpha channel must never be padded with replicated content:
    # always use a constant border for it.
    if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
        img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, patch_size,
                                            flags=cv2.INTER_LINEAR,
                                            borderMode=cv2.BORDER_CONSTANT,
                                            )

    return img_patch
204
+
205
def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
                                 bb_width: float, bb_height: float,
                                 patch_width: float, patch_height: float,
                                 do_flip: bool, scale: float, rot: float,
                                 border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
    """
    Crop the image to the supplied bounding box using skimage ops.

    Args:
        img (np.array): Input image of shape (H, W, C).
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        bb_width (float): Bounding box width (must equal bb_height).
        bb_height (float): Bounding box height.
        patch_width (float): Output box width (must equal patch_height).
        patch_height (float): Output box height.
        do_flip (bool): Whether to flip the image horizontally first.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Random rotation applied to the box, in degrees.
    Returns:
        img_patch (np.array): Cropped uint8 patch of shape (patch_height, patch_width, C).
        trans (np.array): 2x3 affine transformation matrix of the crop.
    Raises:
        ValueError: If the crop degenerates to an empty image.
    """
    # Cleanups vs the original: removed a dead `if False:` legacy-cropping
    # branch, a bare `except: breakpoint()` and a debug print/breakpoint
    # block (now an explicit ValueError).
    img_height, img_width, img_channels = img.shape
    if do_flip:
        img = img[:, ::-1, :]
        c_x = img_width - c_x - 1

    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)

    # This path assumes square boxes and square patches.
    assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
    assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'

    center = np.array([c_x, c_y], dtype=float)
    res = np.array([patch_width, patch_height], dtype=float)
    # 200-pixel box convention used by transform()/get_transform().
    scale1 = scale * bb_width / 200.

    # Upper-left / bottom-right corners of the crop in original image coords.
    ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
    br = np.array(transform([res[0] + 1,
                             res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1

    # Pad so that after rotation the full crop content stays inside the box.
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
    if rot != 0:
        ul -= pad
        br += pad

    new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)

    if rot != 0:
        # Rotate, then strip the padding added above.
        new_img = rotate(new_img, rot)
        new_img = new_img[pad:-pad, pad:-pad]

    if new_img.shape[0] < 1 or new_img.shape[1] < 1:
        raise ValueError(
            f'Degenerate crop: img.shape={img.shape}, new_img.shape={new_img.shape}, '
            f'ul={ul}, br={br}, pad={pad}, rot={rot}')

    # Resize to the requested patch resolution.
    new_img = resize(new_img, res)

    new_img = np.clip(new_img, 0, 255).astype(np.uint8)

    return new_img, trans
315
+
316
+
317
def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
                             bb_width: float, bb_height: float,
                             patch_width: float, patch_height: float,
                             do_flip: bool, scale: float, rot: float,
                             border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
    """
    Crop the input image with cv2.warpAffine and return the crop together
    with the affine transform that produced it.

    Args:
        img (np.array): Input image of shape (H, W, C).
        c_x (float): Bounding box center x coordinate in the original image.
        c_y (float): Bounding box center y coordinate in the original image.
        bb_width (float): Bounding box width.
        bb_height (float): Bounding box height.
        patch_width (float): Output box width.
        patch_height (float): Output box height.
        do_flip (bool): Whether to flip the image horizontally first.
        scale (float): Rescaling factor for the bounding box (augmentation).
        rot (float): Random rotation applied to the box, in degrees.
    Returns:
        img_patch (np.array): Cropped patch of shape (patch_height, patch_width, C).
        trans (np.array): 2x3 affine transformation matrix.
    """
    img_height, img_width, img_channels = img.shape
    if do_flip:
        # Mirror horizontally and move the box center accordingly.
        img = img[:, ::-1, :]
        c_x = img_width - c_x - 1

    trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height,
                                    patch_width, patch_height, scale, rot)

    patch_size = (int(patch_width), int(patch_height))
    img_patch = cv2.warpAffine(img, trans, patch_size,
                               flags=cv2.INTER_LINEAR,
                               borderMode=border_mode,
                               borderValue=border_value,
                               )
    # The alpha channel always gets a constant border, never replicated
    # content from the image edge.
    if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
        img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, patch_size,
                                            flags=cv2.INTER_LINEAR,
                                            borderMode=cv2.BORDER_CONSTANT,
                                            )

    return img_patch, trans
361
+
362
+
363
def convert_cvimg_to_tensor(cvimg: np.array):
    """
    Convert an OpenCV image from (H, W, C) layout to float32 (C, H, W).

    Args:
        cvimg (np.array): Image of shape (H, W, C) as loaded by OpenCV.
    Returns:
        np.array: Output image of shape (C, H, W), dtype float32.
    """
    chw = np.transpose(cvimg.copy(), (2, 0, 1))
    return chw.astype(np.float32)
377
+
378
def fliplr_params(mano_params: Dict, has_mano_params: Dict) -> Tuple[Dict, Dict]:
    """
    Mirror MANO parameters to match a horizontally flipped image.

    Axis-angle vectors are mirrored by negating their y and z components;
    the shape coefficients (betas) are unaffected by a flip.

    Args:
        mano_params (Dict): MANO parameter annotations.
        has_mano_params (Dict): Whether MANO annotations are valid.
    Returns:
        Dict, Dict: Flipped MANO parameters and valid flags.
    """
    global_orient = mano_params['global_orient'].copy()
    hand_pose = mano_params['hand_pose'].copy()
    betas = mano_params['betas'].copy()

    # Negate the y and z components of every axis-angle triplet.
    for pose_vec in (global_orient, hand_pose):
        pose_vec[1::3] *= -1
        pose_vec[2::3] *= -1

    flipped_params = {
        'global_orient': global_orient.astype(np.float32),
        'hand_pose': hand_pose.astype(np.float32),
        'betas': betas.astype(np.float32),
    }
    flipped_valid = {
        'global_orient': has_mano_params['global_orient'].copy(),
        'hand_pose': has_mano_params['hand_pose'].copy(),
        'betas': has_mano_params['betas'].copy(),
    }
    return flipped_params, flipped_valid
410
+
411
+
412
def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
    """
    Mirror 2D/3D keypoints horizontally and reorder left/right pairs.

    Args:
        joints (np.array): (N, 3) or (N, 4) keypoints with confidence.
        width (float): Mirror width (image width for 2D keypoints).
        flip_permutation (List): Index permutation applied after mirroring.
    Returns:
        np.array: Flipped keypoints with the same shape as the input.
    """
    flipped = joints.copy()
    # Mirror the x coordinate about the image.
    flipped[:, 0] = width - flipped[:, 0] - 1
    # Swap left/right joints.
    return flipped[flip_permutation, :]
427
+
428
def keypoint_3d_processing(keypoints_3d: np.array, flip_permutation: List[int], rot: float, do_flip: float) -> np.array:
    """
    Apply flip/rotation augmentation to 3D keypoints.

    Args:
        keypoints_3d (np.array): (N, 4) keypoints with confidence.
        flip_permutation (List): Index permutation applied after flipping.
        rot (float): In-plane rotation in degrees.
        do_flip (bool): Whether to flip the keypoints.
    Returns:
        np.array: Transformed (N, 4) keypoints, float32.
    """
    if do_flip:
        keypoints_3d = fliplr_keypoints(keypoints_3d, 1, flip_permutation)
    if rot != 0:
        # In-plane rotation about the z axis; negated to match the image
        # rotation direction used by the cropping code.
        angle = -rot * np.pi / 180
        sin_a, cos_a = np.sin(angle), np.cos(angle)
        rot_mat = np.eye(3)
        rot_mat[0, :2] = [cos_a, -sin_a]
        rot_mat[1, :2] = [sin_a, cos_a]
        keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
    return keypoints_3d.astype('float32')
452
+
453
def rot_aa(aa: np.array, rot: float) -> np.array:
    """
    Rotate an axis-angle vector by `rot` degrees about the camera z axis.

    Args:
        aa (np.array): Axis-angle vector of shape (3,).
        rot (float): Rotation angle in degrees.
    Returns:
        np.array: Rotated axis-angle vector, float32.
    """
    # Negative sign: the image rotation corresponds to the inverse camera
    # rotation applied to the global orientation.
    theta = np.deg2rad(-rot)
    R = np.array([[np.cos(theta), -np.sin(theta), 0],
                  [np.sin(theta), np.cos(theta), 0],
                  [0, 0, 1]])
    # Compose in rotation-matrix space, then convert back to axis-angle.
    aa_mat, _ = cv2.Rodrigues(aa)
    rotated, _ = cv2.Rodrigues(np.dot(R, aa_mat))
    return rotated.T[0].astype(np.float32)
472
+
473
def mano_param_processing(mano_params: Dict, has_mano_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
    """
    Apply flip/rotation augmentation to the MANO annotations.

    Args:
        mano_params (Dict): MANO parameter annotations.
        has_mano_params (Dict): Whether MANO annotations are valid.
        rot (float): In-plane rotation in degrees.
        do_flip (bool): Whether to flip the parameters.
    Returns:
        Dict, Dict: Transformed MANO parameters and valid flags.
    """
    if do_flip:
        mano_params, has_mano_params = fliplr_params(mano_params, has_mano_params)
    # Only the global orientation lives in camera space, so only it rotates.
    mano_params['global_orient'] = rot_aa(mano_params['global_orient'], rot)
    return mano_params, has_mano_params
488
+
489
+
490
+
491
def get_example(img_path: str | np.ndarray, center_x: float, center_y: float,
                width: float, height: float,
                keypoints_2d: np.array, keypoints_3d: np.array,
                mano_params: Dict, has_mano_params: Dict,
                flip_kp_permutation: List[int],
                patch_width: int, patch_height: int,
                mean: np.array, std: np.array,
                do_augment: bool, is_right: bool, augm_config: CfgNode,
                is_bgr: bool = True,
                use_skimage_antialias: bool = False,
                border_mode: int = cv2.BORDER_CONSTANT,
                return_trans: bool = False) -> Tuple:
    """
    Get an example from the dataset and (possibly) apply random augmentations.

    Args:
        img_path (str | np.ndarray): Image filename, or an already-loaded image array.
        center_x (float): Bounding box center x coordinate in the original image.
        center_y (float): Bounding box center y coordinate in the original image.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): (N, 3) 2D keypoints in original image coordinates.
        keypoints_3d (np.array): (N, 4) 3D keypoints.
        mano_params (Dict): MANO parameter annotations.
        has_mano_params (Dict): Whether MANO annotations are valid.
        flip_kp_permutation (List): Keypoint permutation to apply after flipping.
        patch_width (int): Output patch width.
        patch_height (int): Output patch height.
        mean (np.array): Per-channel mean for input normalization (or None).
        std (np.array): Per-channel std for input normalization (or None).
        do_augment (bool): Whether to apply data augmentation.
        is_right (bool): Whether the hand is a right hand; left hands are flipped.
        augm_config (CfgNode): Config containing augmentation parameters.
    Returns:
        img_patch (np.array): (3, patch_height, patch_width) normalized patch.
        keypoints_2d (np.array): (N, 3) keypoints in normalized patch coords.
        keypoints_3d (np.array): (N, 4) transformed 3D keypoints.
        mano_params (Dict): Transformed MANO parameters.
        has_mano_params (Dict): Valid flags for transformed MANO parameters.
        img_size (np.array): (height, width) of the original image.
        trans (np.array): Only if return_trans: the 2x3 crop transform.
    Raises:
        IOError: If the image file cannot be read.
        TypeError: If img_path is neither a string nor a numpy array.
    """
    # Cleanups vs the original: removed two `breakpoint()` debug traps (which
    # would suspend dataloader workers) and large commented-out debug blocks.

    # 1. Load (or reuse) the image.
    if isinstance(img_path, str):
        cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(cvimg, np.ndarray):
            raise IOError("Fail to read %s" % img_path)
    elif isinstance(img_path, np.ndarray):
        cvimg = img_path
    else:
        raise TypeError('img_path must be either a string or a numpy array')
    img_height, img_width, img_channels = cvimg.shape

    img_size = np.array([img_height, img_width])

    # 2. Sample augmentation parameters.
    if do_augment:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
    else:
        scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0, 1.0, 1.0], 0., 0.

    # Left hands are always flipped so the model only ever sees right hands.
    if not is_right:
        do_flip = True

    if do_extreme_crop:
        if extreme_crop_lvl == 0:
            center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
        elif extreme_crop_lvl == 1:
            center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height, keypoints_2d)

        # Accept the extreme crop only when it is not degenerate.
        THRESH = 4
        if width1 >= THRESH and height1 >= THRESH:
            center_x, center_y, width, height = center_x1, center_y1, width1, height1

    # Random translation of the box, proportional to its size.
    center_x += width * tx
    center_y += height * ty

    # Process 3D keypoints.
    keypoints_3d = keypoint_3d_processing(keypoints_3d, flip_kp_permutation, rot, do_flip)

    # 3. Generate the image patch.
    if use_skimage_antialias:
        downsampling_factor = patch_width / (width * scale)
        if downsampling_factor > 1.1:
            # Pre-blur so that downsampling does not alias.
            cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True, truncate=3.0)

    img_patch_cv, trans = generate_image_patch_cv2(cvimg,
                                                   center_x, center_y,
                                                   width, height,
                                                   patch_width, patch_height,
                                                   do_flip, scale, rot,
                                                   border_mode=border_mode)

    image = img_patch_cv.copy()
    if is_bgr:
        image = image[:, :, ::-1]
    img_patch_cv = image.copy()
    img_patch = convert_cvimg_to_tensor(image)

    mano_params, has_mano_params = mano_param_processing(mano_params, has_mano_params, rot, do_flip)

    # Apply color augmentation and (optional) normalization channel-wise.
    for n_c in range(min(img_channels, 3)):
        img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
        if mean is not None and std is not None:
            img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]

    if do_flip:
        keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, flip_kp_permutation)

    # Map 2D keypoints into the patch and normalize to [-0.5, 0.5].
    for n_jt in range(len(keypoints_2d)):
        keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
    keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5

    if not return_trans:
        return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size
    return img_patch, keypoints_2d, keypoints_3d, mano_params, has_mano_params, img_size, trans
631
+
632
def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
    """
    Extreme cropping: shrink the box so it covers the body only down to the hips.

    Args:
        center_x, center_y (float): Current bounding box center.
        width, height (float): Current bounding box size.
        keypoints_2d (np.array): (N, 3) keypoints with confidences.
    Returns:
        Tuple of (center_x, center_y, width, height) for the new box.
    """
    keypoints_2d = keypoints_2d.copy()
    # Zero out lower-body keypoints so they do not contribute to the bbox.
    lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25+0, 25+1, 25+4, 25+5]
    keypoints_2d[lower_body_keypoints, :] = 0
    # Refit the box only when at least two confident keypoints remain.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x, center_y = center[0], center[1]
        width, height = 1.1 * scale[0], 1.1 * scale[1]
    return center_x, center_y, width, height
657
+
658
+
659
def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: crop the box up to the shoulder locations.

    Args:
        center_x, center_y (float): Current bounding box center.
        width, height (float): Current bounding box size.
        keypoints_2d (np.array): (N, 3) keypoints with confidences.
    Returns:
        Tuple of (center_x, center_y, width, height) for the new box.
    """
    keypoints_2d = keypoints_2d.copy()
    # Zero out everything below the shoulders so it is ignored by the bbox fit.
    lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16]]
    keypoints_2d[lower_body_keypoints, :] = 0
    # Fix: compute the bbox only inside the confidence guard. The previous
    # code also called get_bbox unconditionally before the guard, which was
    # redundant and inconsistent with the sibling crop_* functions.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x = center[0]
        center_y = center[1]
        width = 1.2 * scale[0]
        height = 1.2 * scale[1]
    return center_x, center_y, width, height
685
+
686
def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: crop the box down to the head only.

    Args:
        center_x, center_y (float): Current bounding box center.
        width, height (float): Current bounding box size.
        keypoints_2d (np.array): (N, 3) keypoints with confidences.
    Returns:
        Tuple of (center_x, center_y, width, height) for the new box.
    """
    keypoints_2d = keypoints_2d.copy()
    # Zero out every non-head keypoint so it is ignored by the bbox fit.
    lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, 16]]
    keypoints_2d[lower_body_keypoints, :] = 0
    # Refit the box only when at least two confident keypoints remain.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x, center_y = center[0], center[1]
        width, height = 1.3 * scale[0], 1.3 * scale[1]
    return center_x, center_y, width, height
711
+
712
def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: crop the box down to the torso only.

    Args:
        center_x, center_y (float): Current bounding box center.
        width, height (float): Current bounding box size.
        keypoints_2d (np.array): (N, 3) keypoints with confidences.
    Returns:
        Tuple of (center_x, center_y, width, height) for the new box.
    """
    keypoints_2d = keypoints_2d.copy()
    # Zero out every keypoint that is not part of the torso.
    nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 4, 5, 6, 7, 10, 11, 13, 17, 18]]
    keypoints_2d[nontorso_body_keypoints, :] = 0
    # Refit the box only when at least two confident keypoints remain.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x, center_y = center[0], center[1]
        width, height = 1.1 * scale[0], 1.1 * scale[1]
    return center_x, center_y, width, height
737
+
738
def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: crop the box down to the right arm only.

    Args:
        center_x, center_y (float): Current bounding box center.
        width, height (float): Current bounding box size.
        keypoints_2d (np.array): (N, 3) keypoints with confidences.
    Returns:
        Tuple of (center_x, center_y, width, height) for the new box.
    """
    keypoints_2d = keypoints_2d.copy()
    # Zero out every keypoint that is not part of the right arm.
    nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
    keypoints_2d[nonrightarm_body_keypoints, :] = 0
    # Refit the box only when at least two confident keypoints remain.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x, center_y = center[0], center[1]
        width, height = 1.1 * scale[0], 1.1 * scale[1]
    return center_x, center_y, width, height
763
+
764
def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: crop the box down to the left arm only.

    Args:
        center_x, center_y (float): Current bounding box center.
        width, height (float): Current bounding box size.
        keypoints_2d (np.array): (N, 3) keypoints with confidences.
    Returns:
        Tuple of (center_x, center_y, width, height) for the new box.
    """
    keypoints_2d = keypoints_2d.copy()
    # Zero out every keypoint that is not part of the left arm.
    nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
    keypoints_2d[nonleftarm_body_keypoints, :] = 0
    # Refit the box only when at least two confident keypoints remain.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x, center_y = center[0], center[1]
        width, height = 1.1 * scale[0], 1.1 * scale[1]
    return center_x, center_y, width, height
789
+
790
def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: crop the box down to the legs only.

    Args:
        center_x, center_y (float): Current bounding box center.
        width, height (float): Current bounding box size.
        keypoints_2d (np.array): (N, 3) keypoints with confidences.
    Returns:
        Tuple of (center_x, center_y, width, height) for the new box.
    """
    keypoints_2d = keypoints_2d.copy()
    # Zero out every keypoint that is not part of the legs.
    nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
    keypoints_2d[nonlegs_body_keypoints, :] = 0
    # Refit the box only when at least two confident keypoints remain.
    if keypoints_2d[:, -1].sum() > 1:
        center, scale = get_bbox(keypoints_2d)
        center_x, center_y = center[0], center[1]
        width, height = 1.1 * scale[0], 1.1 * scale[1]
    return center_x, center_y, width, height
815
+
816
def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: shrink the bounding box so that only the right leg remains.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    kps = keypoints_2d.copy()
    # Zero-out everything except the right-leg keypoints.
    ignored = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
    ignored += [25 + i for i in [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
    kps[ignored, :] = 0
    # Only recompute the box when total remaining confidence exceeds 1.
    if kps[:, -1].sum() > 1:
        center, scale = get_bbox(kps)
        center_x, center_y = center[0], center[1]
        width, height = 1.1 * scale[0], 1.1 * scale[1]
    return center_x, center_y, width, height
841
+
842
def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
    """
    Extreme cropping: shrink the bounding box so that only the left leg remains.

    Args:
        center_x (float): x coordinate of the bounding box center.
        center_y (float): y coordinate of the bounding box center.
        width (float): Bounding box width.
        height (float): Bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        center_x (float): x coordinate of the new bounding box center.
        center_y (float): y coordinate of the new bounding box center.
        width (float): New bounding box width.
        height (float): New bounding box height.
    """
    kps = keypoints_2d.copy()
    # Zero-out everything except the left-leg keypoints.
    ignored = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24]
    ignored += [25 + i for i in [0, 1, 2, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
    kps[ignored, :] = 0
    # Only recompute the box when total remaining confidence exceeds 1.
    if kps[:, -1].sum() > 1:
        center, scale = get_bbox(kps)
        center_x, center_y = center[0], center[1]
        width, height = 1.1 * scale[0], 1.1 * scale[1]
    return center_x, center_y, width, height
867
+
868
def full_body(keypoints_2d: np.array) -> bool:
    """
    Check whether all main body joints are visible.

    A joint counts as visible when either its OpenPose detection or the
    corresponding extra joint (offset by 25) has positive confidence.

    Args:
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        bool: True if all main body joints are visible.
    """
    openpose_ids = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
    extra_ids = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
    # Take the better of the two confidence sources per joint.
    conf = np.maximum(keypoints_2d[extra_ids, -1], keypoints_2d[openpose_ids, -1])
    return (conf > 0).sum() == len(extra_ids)
880
+
881
def upper_body(keypoints_2d: np.array):
    """
    Check whether the detection shows only the upper body.

    True when no lower-body joint is visible and at least two upper-body
    joints (head/neck area) are visible, in either the OpenPose or the
    extra-joint (offset by 25) confidence channels.

    Args:
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        bool: True if the person is visible from the waist up only.
    """
    lower_openpose = [10, 11, 13, 14]
    lower_extra = [25 + i for i in [1, 0, 4, 5]]
    upper_openpose = [0, 1, 15, 16, 17, 18]
    upper_extra = [25 + 8, 25 + 9, 25 + 12, 25 + 13, 25 + 17, 25 + 18]
    no_lower = (keypoints_2d[lower_extra + lower_openpose, -1] > 0).sum() == 0
    enough_upper = (keypoints_2d[upper_extra + upper_openpose, -1] > 0).sum() >= 2
    return no_lower and enough_upper
895
+
896
def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
    """
    Compute a bounding box center and scale from keypoint detections.

    Only keypoints with positive confidence (last column) contribute.

    Args:
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
        rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
    Returns:
        center (np.array): Array of shape (2,) containing the bounding box center.
        scale (np.array): Array of shape (2,) with the rescaled box extents.
    """
    visible = keypoints_2d[keypoints_2d[:, -1] > 0][:, :-1]
    lo = visible.min(axis=0)
    hi = visible.max(axis=0)
    center = 0.5 * (hi + lo)
    # Loosen the tight keypoint box by the rescale factor.
    scale = rescale * (hi - lo)
    return center, scale
914
+
915
def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
    """
    Randomly replace the box with a tighter crop (hips / shoulders / head).

    Args:
        center_x (float): x coordinate of bounding box center.
        center_y (float): y coordinate of bounding box center.
        width (float): bounding box width.
        height (float): bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        Tuple: (center_x, center_y, side, side) of the new square box.
    """
    p = torch.rand(1).item()
    crop_fn = None
    if full_body(keypoints_2d):
        # 70% hips, 20% shoulders, 10% head
        crop_fn = crop_to_hips if p < 0.7 else (crop_to_shoulders if p < 0.9 else crop_to_head)
    elif upper_body(keypoints_2d):
        # 90% shoulders, 10% head
        crop_fn = crop_to_shoulders if p < 0.9 else crop_to_head
    if crop_fn is not None:
        center_x, center_y, width, height = crop_fn(center_x, center_y, width, height, keypoints_2d)
    # Always return a square box.
    side = max(width, height)
    return center_x, center_y, side, side
946
+
947
def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
    """
    Perform aggressive extreme cropping: randomly crop down to one body part.

    Args:
        center_x (float): x coordinate of bounding box center.
        center_y (float): y coordinate of bounding box center.
        width (float): bounding box width.
        height (float): bounding box height.
        keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
    Returns:
        Tuple: (center_x, center_y, side, side) of the new square box.
    """
    p = torch.rand(1).item()
    # (upper threshold, crop function) pairs; the first entry with p < threshold wins.
    # The final 1.1 threshold acts as a catch-all since p < 1 always.
    if full_body(keypoints_2d):
        schedule = [(0.2, crop_to_hips), (0.3, crop_to_shoulders), (0.4, crop_to_head),
                    (0.5, crop_torso_only), (0.6, crop_rightarm_only), (0.7, crop_leftarm_only),
                    (0.8, crop_legs_only), (0.9, crop_rightleg_only), (1.1, crop_leftleg_only)]
    elif upper_body(keypoints_2d):
        schedule = [(0.2, crop_to_shoulders), (0.4, crop_to_head), (0.6, crop_torso_only),
                    (0.8, crop_rightarm_only), (1.1, crop_leftarm_only)]
    else:
        schedule = []
    for threshold, crop_fn in schedule:
        if p < threshold:
            center_x, center_y, width, height = crop_fn(center_x, center_y, width, height, keypoints_2d)
            break
    # Always return a square box.
    side = max(width, height)
    return center_x, center_y, side, side
WiLoR/wilor/datasets/vitdet_dataset.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from skimage.filters import gaussian
6
+ from yacs.config import CfgNode
7
+ import torch
8
+
9
+ from .utils import (convert_cvimg_to_tensor,
10
+ expand_to_aspect_ratio,
11
+ generate_image_patch_cv2)
12
+
13
+ DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
14
+ DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
15
+
16
class ViTDetDataset(torch.utils.data.Dataset):
    """Inference-only dataset that crops hand boxes out of one image.

    Each element is one detected hand: the image patch around its bounding
    box, resized to the model input size and normalized, plus the box
    metadata needed to project predictions back into the full image.
    """

    def __init__(self,
                 cfg: CfgNode,
                 img_cv2: np.array,
                 boxes: np.array,
                 right: np.array,
                 rescale_factor=2.5,
                 train: bool = False,
                 **kwargs):
        """
        Args:
            cfg (CfgNode): Model config; MODEL.IMAGE_SIZE / IMAGE_MEAN / IMAGE_STD are read.
            img_cv2 (np.array): Full input image, OpenCV channel order (H, W, 3).
            boxes (np.array): (N, 4) hand boxes as [x1, y1, x2, y2] pixel coordinates.
            right (np.array): (N,) flags; 0 means the hand patch gets mirrored in __getitem__.
            rescale_factor (float): Padding factor applied around each box.
            train (bool): Must be False; this dataset implements no augmentation.
        """
        super().__init__()
        self.cfg = cfg
        self.img_cv2 = img_cv2
        # self.boxes = boxes

        assert train == False, "ViTDetDataset is only for inference"
        self.train = train
        self.img_size = cfg.MODEL.IMAGE_SIZE
        # Normalization stats are stored in [0, 1]; scale to pixel range once.
        self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
        self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)

        # Preprocess annotations: box corners -> (center, scale) representation.
        # The /200 matches the scale convention used by generate_image_patch_cv2.
        boxes = boxes.astype(np.float32)
        self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
        self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
        self.personid = np.arange(len(boxes), dtype=np.int32)
        self.right = right.astype(np.float32)

    def __len__(self) -> int:
        """Number of detected hands (one dataset item per box)."""
        return len(self.personid)

    def __getitem__(self, idx: int) -> Dict[str, np.array]:
        """Crop, optionally flip, resize and normalize the idx-th hand patch."""
        center = self.center[idx].copy()
        center_x = center[0]
        center_y = center[1]

        scale = self.scale[idx]
        BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
        # Pad the box to the model's aspect ratio, then take the longer side.
        bbox_size = expand_to_aspect_ratio(scale*200, target_aspect_ratio=BBOX_SHAPE).max()

        patch_width = patch_height = self.img_size

        right = self.right[idx].copy()
        # Left hands (right == 0) are mirrored so the model always sees a right hand.
        flip = right == 0

        # 3. generate image patch
        # if use_skimage_antialias:
        cvimg = self.img_cv2.copy()
        if True:
            # Blur image to avoid aliasing artifacts when the crop is much
            # larger than the network input (significant downsampling).
            downsampling_factor = ((bbox_size*1.0) / patch_width)
            downsampling_factor = downsampling_factor / 2.0
            if downsampling_factor > 1.1:
                cvimg = gaussian(cvimg, sigma=(downsampling_factor-1)/2, channel_axis=2, preserve_range=True)


        img_patch_cv, trans = generate_image_patch_cv2(cvimg,
                                                       center_x, center_y,
                                                       bbox_size, bbox_size,
                                                       patch_width, patch_height,
                                                       flip, 1.0, 0,
                                                       border_mode=cv2.BORDER_CONSTANT)
        # Reverse channel order (OpenCV BGR -> RGB expected by the model).
        img_patch_cv = img_patch_cv[:, :, ::-1]
        img_patch = convert_cvimg_to_tensor(img_patch_cv)

        # apply normalization per channel using the (pixel-range) mean/std
        for n_c in range(min(self.img_cv2.shape[2], 3)):
            img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]

        item = {
            'img': img_patch,
            'personid': int(self.personid[idx]),
        }
        # Metadata used downstream to map predictions back to full-image coords.
        item['box_center'] = self.center[idx].copy()
        item['box_size'] = bbox_size
        item['img_size'] = 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]])
        item['right'] = self.right[idx].copy()
        return item
WiLoR/wilor/models/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mano_wrapper import MANO
2
+ from .wilor import WiLoR
3
+
4
+ from .discriminator import Discriminator
5
+
6
def load_wilor(checkpoint_path, cfg_path):
    """
    Load a WiLoR model and its configuration from disk.

    Args:
        checkpoint_path (str): Path to the Lightning checkpoint file.
        cfg_path (str): Path to the model YAML configuration.
    Returns:
        model (WiLoR): Model restored from the checkpoint (non-strict load,
            so training-only keys in the checkpoint are ignored).
        model_cfg (CfgNode): The patched, frozen configuration.
    """
    # Fix: removed unused `from pathlib import Path` (dead import).
    from wilor.configs import get_config
    print('Loading ', checkpoint_path)
    model_cfg = get_config(cfg_path, update_cachedir=True)

    # Override some config values, to crop bbox correctly for ViT backbones.
    if ('vit' in model_cfg.MODEL.BACKBONE.TYPE) and ('BBOX_SHAPE' not in model_cfg.MODEL):
        model_cfg.defrost()
        assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
        model_cfg.MODEL.BBOX_SHAPE = [192, 256]
        model_cfg.freeze()

    # Drop training-time pretrained-weight paths so inference does not try
    # to load them from a machine-specific location.
    if 'PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE:
        model_cfg.defrost()
        model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
        model_cfg.freeze()

    # Point MANO assets at the local demo data directory.
    if 'DATA_DIR' in model_cfg.MANO:
        model_cfg.defrost()
        model_cfg.MANO.DATA_DIR = './mano_data/'
        model_cfg.MANO.MODEL_PATH = './mano_data/'
        model_cfg.MANO.MEAN_PARAMS = './mano_data/mano_mean_params.npz'
        model_cfg.freeze()

    model = WiLoR.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
    return model, model_cfg
WiLoR/wilor/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.02 kB). View file
 
WiLoR/wilor/models/__pycache__/discriminator.cpython-311.pyc ADDED
Binary file (6.52 kB). View file
 
WiLoR/wilor/models/__pycache__/losses.cpython-311.pyc ADDED
Binary file (6.87 kB). View file
 
WiLoR/wilor/models/__pycache__/mano_wrapper.cpython-311.pyc ADDED
Binary file (3.43 kB). View file
 
WiLoR/wilor/models/__pycache__/wilor.cpython-311.pyc ADDED
Binary file (24.1 kB). View file
 
WiLoR/wilor/models/backbones/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .vit import vit
2
+
3
def create_backbone(cfg):
    """
    Instantiate the image backbone specified by cfg.MODEL.BACKBONE.TYPE.

    Args:
        cfg (CfgNode): Model configuration; only MODEL.BACKBONE.TYPE is read here
            (the 'vit' builder receives the full cfg).
    Returns:
        nn.Module: The backbone network.
    Raises:
        NotImplementedError: If the backbone type is not supported.
    """
    backbone_type = cfg.MODEL.BACKBONE.TYPE
    if backbone_type == 'vit':
        return vit(cfg)
    elif backbone_type == 'fast_vit':
        # Fix: removed unused `import sys` and dead commented import.
        import torch
        from timm.models import create_model
        fast_vit = create_model("fastvit_ma36", drop_path_rate=0.2)
        # Fix: map_location='cpu' lets the checkpoint load on machines without
        # the GPU it was saved from; the caller moves the model to its device.
        checkpoint = torch.load('./pretrained_models/fastvit_ma36.pt', map_location='cpu')
        fast_vit.load_state_dict(checkpoint['state_dict'])
        return fast_vit
    else:
        raise NotImplementedError('Backbone type is not implemented')
WiLoR/wilor/models/backbones/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (737 Bytes). View file
 
WiLoR/wilor/models/backbones/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.09 kB). View file
 
WiLoR/wilor/models/backbones/__pycache__/vit.cpython-310.pyc ADDED
Binary file (13.2 kB). View file
 
WiLoR/wilor/models/backbones/__pycache__/vit.cpython-311.pyc ADDED
Binary file (26.9 kB). View file
 
WiLoR/wilor/models/backbones/vit.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from functools import partial
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint as checkpoint
9
+ from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
10
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
11
+
12
def vit(cfg):
    """Build the ViT-Huge backbone (256x192 input, 16x16 patches, 32 blocks) used by WiLoR."""
    backbone_kwargs = dict(
        img_size=(256, 192),
        patch_size=16,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        ratio=1,
        use_checkpoint=False,
        mlp_ratio=4,
        qkv_bias=True,
        drop_path_rate=0.55,
    )
    return ViT(cfg=cfg, **backbone_kwargs)
26
+
27
def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
    """
    Resize absolute positional embeddings to a new token-grid size.

    Args:
        abs_pos (Tensor): Positional embeddings of shape (1, num_positions, C).
        h (int): Target token-grid height.
        w (int): Target token-grid width.
        ori_h (int): Token-grid height the embeddings were trained for.
        ori_w (int): Token-grid width the embeddings were trained for.
        has_cls_token (bool): If True, abs_pos[:, 0] is the cls-token embedding;
            it is excluded from resizing and re-attached at the front.

    Returns:
        Tensor: Positional embeddings matching the (h, w) grid, cls token first
        when present.
    """
    B, L, C = abs_pos.shape
    cls_token = None
    if has_cls_token:
        cls_token, abs_pos = abs_pos[:, 0:1], abs_pos[:, 1:]

    if (ori_h, ori_w) == (h, w):
        resized = abs_pos
    else:
        # Reshape to a 2D grid, bicubically resample, flatten back to tokens.
        grid = abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(h, w), mode="bicubic", align_corners=False)
        resized = grid.permute(0, 2, 3, 1).reshape(B, -1, C)

    if cls_token is not None:
        resized = torch.cat([cls_token, resized], dim=1)
    return resized
59
+
60
class DropPath(nn.Module):
    """Stochastic-depth regularization: randomly drop the residual branch per sample."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # timm's functional drop_path is a no-op when not training or prob is falsy.
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        return f'p={self.drop_prob}'
72
+
73
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> activation -> Linear -> Dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Note: dropout is applied once, after the second projection only.
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))
89
+
90
class Attention(nn.Module):
    """Multi-head self-attention with an optional per-head dimension override."""

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None,):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        # Head dimension can be forced via attn_head_dim instead of dim // heads.
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # Project to fused q/k/v then split: each becomes (B, heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention.
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = self.attn_drop(attn.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj_drop(self.proj(out))
128
+
129
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual path."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
                 drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, attn_head_dim=None
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim)
        # Stochastic depth on both residual branches; identity when rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
153
+
154
+
155
class PatchEmbed(nn.Module):
    """Image-to-patch embedding via a strided convolution."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # `ratio` increases the token-grid density by shrinking the conv stride.
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
        self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
        self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=(patch_size[0] // ratio), padding=4 + 2 * (ratio // 2 - 1))

    def forward(self, x, **kwargs):
        feat = self.proj(x)
        grid_h, grid_w = feat.shape[2], feat.shape[3]
        # Flatten the spatial grid into a token sequence: (B, Hp*Wp, C).
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, (grid_h, grid_w)
178
+
179
+
180
class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
        """
        Args:
            backbone (nn.Module): CNN whose forward returns a list/tuple of
                feature maps; only the last one is used.
            img_size (int or tuple): Input image size fed to the backbone.
            feature_size (int or tuple, optional): Backbone output resolution.
                If None, it is discovered with a dummy forward pass.
            in_chans (int): Number of input image channels.
            embed_dim (int): Output token embedding dimension.
        """
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                # Probe the backbone with a zero image to discover its output
                # resolution and channel count; restore training mode afterwards.
                # NOTE: this is not ideal, but most reliable across backbones.
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            # Resolution given explicitly; channel count comes from timm's
            # feature_info metadata (assumes a timm features-only backbone).
            feature_size = to_2tuple(feature_size)
            feature_dim = self.backbone.feature_info.channels()[-1]
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Linear(feature_dim, embed_dim)

    def forward(self, x):
        # Last feature map -> flatten spatial dims into tokens -> project.
        x = self.backbone(x)[-1]
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
210
+
211
+
212
class ViT(nn.Module):
    """Vision Transformer backbone with MANO-regression tokens.

    In addition to the image patch tokens, three groups of learnable-init
    tokens (hand pose, shape, camera) are prepended to the sequence; after
    the transformer, dedicated linear heads decode them into MANO parameters
    as residuals on top of the mean parameters loaded from cfg.MANO.MEAN_PARAMS.
    """

    def __init__(self,
                 img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
                 frozen_stages=-1, ratio=1, last_norm=True,
                 patch_padding='pad', freeze_attn=False, freeze_ffn=False,cfg=None,
                 ):
        # Protect mutable default arguments
        super(ViT, self).__init__()
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.frozen_stages = frozen_stages
        self.use_checkpoint = use_checkpoint
        self.patch_padding = patch_padding
        self.freeze_attn = freeze_attn
        self.freeze_ffn = freeze_ffn
        self.depth = depth

        # Patch embedding: either a CNN hybrid stem or a plain conv patchifier.
        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
        num_patches = self.patch_embed.num_patches

        ##########################################
        # MANO regression heads and mean-parameter buffers.
        self.cfg = cfg
        self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
        self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
        # Total pose dimension: one rotation per joint plus the global orient.
        npose = self.joint_rep_dim * (cfg.MANO.NUM_HAND_JOINTS + 1)
        self.npose = npose
        # Mean MANO parameters act as the initialization the heads regress from.
        mean_params = np.load(cfg.MANO.MEAN_PARAMS)
        init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
        self.register_buffer('init_cam', init_cam)
        init_hand_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
        init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
        self.register_buffer('init_hand_pose', init_hand_pose)
        self.register_buffer('init_betas', init_betas)

        # Embeddings turning the mean params into extra transformer tokens.
        self.pose_emb = nn.Linear(self.joint_rep_dim , embed_dim)
        self.shape_emb = nn.Linear(10 , embed_dim)
        self.cam_emb = nn.Linear(3 , embed_dim)

        # Per-token decoders back to parameter space.
        self.decpose = nn.Linear(self.num_features, 6)
        self.decshape = nn.Linear(self.num_features, 10)
        self.deccam = nn.Linear(self.num_features, 3)
        if cfg.MODEL.MANO_HEAD.get('INIT_DECODER_XAVIER', False):
            # True by default in MLP. False by default in Transformer
            nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
            nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
            nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)


        ##########################################

        # since the pretraining model has class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
            )
            for i in range(depth)])

        self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)

        self._freeze_stages()

    def _freeze_stages(self):
        """Freeze parameters."""
        # Freeze the patch embedding plus the first `frozen_stages` blocks.
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            for i in range(1, self.frozen_stages + 1):
                m = self.blocks[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

        # Optionally freeze only the attention sub-modules across all blocks.
        if self.freeze_attn:
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.attn.eval()
                m.norm1.eval()
                for param in m.attn.parameters():
                    param.requires_grad = False
                for param in m.norm1.parameters():
                    param.requires_grad = False

        # Optionally freeze the feed-forward path (incl. pos embed and patchifier).
        if self.freeze_ffn:
            self.pos_embed.requires_grad = False
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False
            for i in range(0, self.depth):
                m = self.blocks[i]
                m.mlp.eval()
                m.norm2.eval()
                for param in m.mlp.parameters():
                    param.requires_grad = False
                for param in m.norm2.parameters():
                    param.requires_grad = False

    def init_weights(self):
        """Initialize the weights in backbone.
        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        def _init_weights(m):
            # Truncated-normal linear weights, zero biases, unit LayerNorm.
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        """Run the backbone and decode MANO parameters.

        Args:
            x (Tensor): Input image batch (B, C, H, W).
        Returns:
            pred_mano_params (dict): 'global_orient', 'hand_pose' (rotation
                matrices) and 'betas'.
            pred_cam (Tensor): (B, 3) weak-perspective camera.
            pred_mano_feats (dict): Raw residual-added params before rotation
                conversion ('hand_pose', 'betas', 'cam').
            img_feat (Tensor): Patch features as a (B, C, Hp, Wp) map.
        """
        B, C, H, W = x.shape
        x, (Hp, Wp) = self.patch_embed(x)

        if self.pos_embed is not None:
            # fit for multiple GPU training
            # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
            x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
        # X [B, 192, 1280]
        # Prepend [mean_pose, mean_shape, mean_cam] tokens to the patch tokens.
        pose_tokens = self.pose_emb(self.init_hand_pose.reshape(1, self.cfg.MANO.NUM_HAND_JOINTS + 1, self.joint_rep_dim)).repeat(B, 1, 1)
        shape_tokens = self.shape_emb(self.init_betas).unsqueeze(1).repeat(B, 1, 1)
        cam_tokens = self.cam_emb(self.init_cam).unsqueeze(1).repeat(B, 1, 1)

        x = torch.cat([pose_tokens, shape_tokens, cam_tokens, x], 1)
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x = self.last_norm(x)


        # Slice the parameter tokens back out in the order they were prepended.
        pose_feat = x[:, :(self.cfg.MANO.NUM_HAND_JOINTS + 1)]
        shape_feat = x[:, (self.cfg.MANO.NUM_HAND_JOINTS + 1):1+(self.cfg.MANO.NUM_HAND_JOINTS + 1)]
        cam_feat = x[:, 1+(self.cfg.MANO.NUM_HAND_JOINTS + 1):2+(self.cfg.MANO.NUM_HAND_JOINTS + 1)]

        # Heads predict residuals added to the mean parameters.
        pred_hand_pose = self.decpose(pose_feat).reshape(B, -1) + self.init_hand_pose  # B , 96
        pred_betas = self.decshape(shape_feat).reshape(B, -1) + self.init_betas  # B , 10
        pred_cam = self.deccam(cam_feat).reshape(B, -1) + self.init_cam  # B , 3

        pred_mano_feats = {}
        pred_mano_feats['hand_pose'] = pred_hand_pose
        pred_mano_feats['betas'] = pred_betas
        pred_mano_feats['cam'] = pred_cam


        # Convert the pose representation (6D or axis-angle) to rotation matrices.
        joint_conversion_fn = {
            '6d': rot6d_to_rotmat,
            'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
        }[self.joint_rep_type]

        pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3)
        pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
                            'hand_pose': pred_hand_pose[:, 1:],
                            'betas': pred_betas}

        # Remaining tokens are the image patches; restore the spatial layout.
        img_feat = x[:, 2+(self.cfg.MANO.NUM_HAND_JOINTS + 1):].reshape(B, Hp, Wp, -1).permute(0, 3, 1, 2)
        return pred_mano_params, pred_cam, pred_mano_feats, img_feat

    def forward(self, x):
        """See forward_features."""
        x = self.forward_features(x)
        return x

    def train(self, mode=True):
        """Convert the model into training mode."""
        # Re-apply freezing so frozen stages stay in eval mode after .train().
        super().train(mode)
        self._freeze_stages()
WiLoR/wilor/models/discriminator.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class Discriminator(nn.Module):

    def __init__(self):
        """
        Pose + Shape discriminator proposed in HMR, adapted here to MANO
        hands (15 articulated joints instead of the SMPL body joints).
        """
        super(Discriminator, self).__init__()

        self.num_joints = 15
        # poses_alone: shared per-joint feature extractor. 1x1 convolutions
        # over a (B, 9, num_joints, 1) tensor act as a small MLP applied
        # independently to each joint's flattened 3x3 rotation matrix.
        self.D_conv1 = nn.Conv2d(9, 32, kernel_size=1)
        nn.init.xavier_uniform_(self.D_conv1.weight)
        nn.init.zeros_(self.D_conv1.bias)
        self.relu = nn.ReLU(inplace=True)
        self.D_conv2 = nn.Conv2d(32, 32, kernel_size=1)
        nn.init.xavier_uniform_(self.D_conv2.weight)
        nn.init.zeros_(self.D_conv2.bias)
        # One scalar discriminator head per joint.
        pose_out = []
        for i in range(self.num_joints):
            pose_out_temp = nn.Linear(32, 1)
            nn.init.xavier_uniform_(pose_out_temp.weight)
            nn.init.zeros_(pose_out_temp.bias)
            pose_out.append(pose_out_temp)
        self.pose_out = nn.ModuleList(pose_out)

        # betas: small MLP discriminating the 10 MANO shape coefficients.
        self.betas_fc1 = nn.Linear(10, 10)
        nn.init.xavier_uniform_(self.betas_fc1.weight)
        nn.init.zeros_(self.betas_fc1.bias)
        self.betas_fc2 = nn.Linear(10, 5)
        nn.init.xavier_uniform_(self.betas_fc2.weight)
        nn.init.zeros_(self.betas_fc2.bias)
        self.betas_out = nn.Linear(5, 1)
        nn.init.xavier_uniform_(self.betas_out.weight)
        nn.init.zeros_(self.betas_out.bias)

        # poses_joint: discriminator over the concatenated per-joint features,
        # judging the full pose jointly.
        self.D_alljoints_fc1 = nn.Linear(32*self.num_joints, 1024)
        nn.init.xavier_uniform_(self.D_alljoints_fc1.weight)
        nn.init.zeros_(self.D_alljoints_fc1.bias)
        self.D_alljoints_fc2 = nn.Linear(1024, 1024)
        nn.init.xavier_uniform_(self.D_alljoints_fc2.weight)
        nn.init.zeros_(self.D_alljoints_fc2.bias)
        self.D_alljoints_out = nn.Linear(1024, 1)
        nn.init.xavier_uniform_(self.D_alljoints_out.weight)
        nn.init.zeros_(self.D_alljoints_out.bias)


    def forward(self, poses: torch.Tensor, betas: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the discriminator.
        Args:
            poses (torch.Tensor): Tensor reshapeable to (B, 15, 3, 3) containing a batch of MANO hand poses (excluding the global orientation).
            betas (torch.Tensor): Tensor of shape (B, 10) containing a batch of MANO beta coefficients.
        Returns:
            torch.Tensor: Discriminator output with shape (B, 17): 15 per-joint scores, 1 shape score, 1 all-joints score.
        """
        #bn = poses.shape[0]
        # poses B x 207
        #poses = poses.reshape(bn, -1)
        # poses B x num_joints x 1 x 9
        poses = poses.reshape(-1, self.num_joints, 1, 9)
        bn = poses.shape[0]
        # poses B x 9 x num_joints x 1
        poses = poses.permute(0, 3, 1, 2).contiguous()

        # poses_alone
        poses = self.D_conv1(poses)
        poses = self.relu(poses)
        poses = self.D_conv2(poses)
        poses = self.relu(poses)

        # Per-joint scalar scores from each joint's 32-dim feature.
        poses_out = []
        for i in range(self.num_joints):
            poses_out_ = self.pose_out[i](poses[:, :, i, 0])
            poses_out.append(poses_out_)
        poses_out = torch.cat(poses_out, dim=1)

        # betas
        betas = self.betas_fc1(betas)
        betas = self.relu(betas)
        betas = self.betas_fc2(betas)
        betas = self.relu(betas)
        betas_out = self.betas_out(betas)

        # poses_joint: score the concatenation of all per-joint features.
        poses = poses.reshape(bn,-1)
        poses_all = self.D_alljoints_fc1(poses)
        poses_all = self.relu(poses_all)
        poses_all = self.D_alljoints_fc2(poses_all)
        poses_all = self.relu(poses_all)
        poses_all_out = self.D_alljoints_out(poses_all)

        disc_out = torch.cat((poses_out, betas_out, poses_all_out), 1)
        return disc_out
WiLoR/wilor/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .refinement_net import RefineNet
WiLoR/wilor/models/heads/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (200 Bytes). View file
 
WiLoR/wilor/models/heads/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (219 Bytes). View file
 
WiLoR/wilor/models/heads/__pycache__/refinement_net.cpython-310.pyc ADDED
Binary file (7.63 kB). View file
 
WiLoR/wilor/models/heads/__pycache__/refinement_net.cpython-311.pyc ADDED
Binary file (15 kB). View file
 
WiLoR/wilor/models/heads/refinement_net.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
6
+ from typing import Optional
7
+
8
def make_linear_layers(feat_dims, relu_final=True, use_bn=False):
    """Assemble an MLP from a list of layer widths.

    Args:
        feat_dims: Sequence of widths; one ``nn.Linear`` is created for each
            consecutive pair.
        relu_final: When True the last Linear is also followed by
            (BatchNorm +) ReLU; otherwise the network output stays linear.
        use_bn: Insert ``nn.BatchNorm1d`` before each ReLU.

    Returns:
        nn.Sequential with the assembled layers.
    """
    modules = []
    last = len(feat_dims) - 2
    for idx, (d_in, d_out) in enumerate(zip(feat_dims[:-1], feat_dims[1:])):
        modules.append(nn.Linear(d_in, d_out))
        # Skip the activation after the final layer unless explicitly asked.
        if idx < last or relu_final:
            if use_bn:
                modules.append(nn.BatchNorm1d(d_out))
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
20
+
21
def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
    """Assemble a stack of ``Conv2d`` (+ BatchNorm2d + ReLU) layers.

    Args:
        feat_dims: Sequence of channel counts; one conv per consecutive pair.
        kernel, stride, padding: Shared Conv2d hyper-parameters.
        bnrelu_final: When True, BatchNorm+ReLU also follow the last conv.

    Returns:
        nn.Sequential with the assembled layers.
    """
    modules = []
    last = len(feat_dims) - 2
    for idx in range(len(feat_dims) - 1):
        modules.append(
            nn.Conv2d(
                in_channels=feat_dims[idx],
                out_channels=feat_dims[idx + 1],
                kernel_size=kernel,
                stride=stride,
                padding=padding,
            )
        )
        # Leave the final conv un-normalized/un-activated unless requested.
        if idx < last or bnrelu_final:
            modules.append(nn.BatchNorm2d(feat_dims[idx + 1]))
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
38
+
39
def make_deconv_layers(feat_dims, bnrelu_final=True):
    """Assemble a stack of 2x-upsampling transposed convolutions.

    Each ``ConvTranspose2d`` uses kernel 4 / stride 2 / padding 1, which
    exactly doubles the spatial resolution.

    Args:
        feat_dims: Sequence of channel counts; one deconv per consecutive pair.
        bnrelu_final: When True, BatchNorm+ReLU also follow the last deconv.

    Returns:
        nn.Sequential with the assembled layers.
    """
    modules = []
    last = len(feat_dims) - 2
    for idx in range(len(feat_dims) - 1):
        modules.append(
            nn.ConvTranspose2d(
                in_channels=feat_dims[idx],
                out_channels=feat_dims[idx + 1],
                kernel_size=4,
                stride=2,
                padding=1,
                output_padding=0,
                bias=False,
            )
        )
        # Leave the final deconv un-normalized/un-activated unless requested.
        if idx < last or bnrelu_final:
            modules.append(nn.BatchNorm2d(feat_dims[idx + 1]))
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
58
+
59
def sample_joint_features(img_feat, joint_xy):
    """Bilinearly sample per-joint features from a feature map.

    Args:
        img_feat: (B, C, H, W) feature map.
        joint_xy: (B, J, 2) pixel coordinates (x, y) in the map's resolution.

    Returns:
        (B, J, C) tensor of features sampled at each joint location.
    """
    h, w = img_feat.shape[2:]
    # Map pixel coordinates to the [-1, 1] range expected by grid_sample
    # (align_corners=True: -1/+1 land exactly on the corner pixel centers).
    norm_x = joint_xy[:, :, 0] / (w - 1) * 2 - 1
    norm_y = joint_xy[:, :, 1] / (h - 1) * 2 - 1
    grid = torch.stack((norm_x, norm_y), 2).unsqueeze(2)
    sampled = F.grid_sample(img_feat, grid, align_corners=True)[:, :, :, 0]  # (B, C, J)
    return sampled.permute(0, 2, 1).contiguous()
67
+
68
def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Computes the perspective projection of a set of 3D points.
    Args:
        points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
        translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
        focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
        camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels (defaults to 0).
        rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation (defaults to identity).
    Returns:
        torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
    """
    B = points.shape[0]
    device, dtype = points.device, points.dtype
    if rotation is None:
        rotation = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(B, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(B, 2, device=device, dtype=dtype)

    # Intrinsics matrix K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    K = torch.zeros([B, 3, 3], device=device, dtype=dtype)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 1]
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # World -> camera: rotate, then translate.
    cam_points = torch.einsum('bij,bkj->bki', rotation, points) + translation.unsqueeze(1)

    # Perspective divide by depth, then apply the intrinsics.
    projected = cam_points / cam_points[:, :, -1].unsqueeze(-1)
    projected = torch.einsum('bij,bkj->bki', K, projected)

    # Drop the (now constant) homogeneous coordinate.
    return projected[:, :, :-1]
106
+
107
class DeConvNet(nn.Module):
    """Multi-branch deconvolution pyramid.

    Projects the backbone feature map to ``feat_dim//2`` channels with a 1x1
    conv, then upsamples that base map through several independent deconv
    branches (branch i upsamples by 2**(i+1)). ``forward`` returns the base
    map plus all branch outputs, ordered high resolution -> low resolution.
    """

    def __init__(self, feat_dim=768, upscale=4):
        """
        Args:
            feat_dim (int): Channel count of the input feature map; must be
                divisible by 8.
            upscale (int): Target upsampling factor; int(log2(upscale)) + 1
                branches are built (supported values produce 1-3 branches).
        """
        super(DeConvNet, self).__init__()
        self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False)
        self.deconv = nn.ModuleList([])
        for i in range(int(math.log2(upscale))+1):
            if i==0:
                self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4]))
            elif i==1:
                self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8]))
            elif i==2:
                self.deconv.append(make_deconv_layers([feat_dim//2, feat_dim//4, feat_dim//8, feat_dim//8]))

    def forward(self, img_feat):
        """Return feature maps ordered high resolution -> low resolution.

        Args:
            img_feat: (B, feat_dim, H, W) feature map.
        Returns:
            list[torch.Tensor]: branch outputs (largest spatial size first)
            followed by the base ``feat_dim//2``-channel map at (H, W).
        """
        img_feat = self.first_conv(img_feat)
        face_img_feats = [img_feat]
        for deconv in self.deconv:
            # Each branch upsamples the same base features (branches are
            # parallel, not chained).
            face_img_feats.append(deconv(img_feat))
        return face_img_feats[::-1]  # high resolution -> low resolution
131
+
132
class DeConvNet_v2(nn.Module):
    """Single-branch variant of :class:`DeConvNet`.

    Projects the input to ``feat_dim//2`` channels with a 1x1 conv, then
    applies one 4x-upsampling transposed convolution. Returns a one-element
    list so it is interchangeable with :class:`DeConvNet` in callers.
    """

    def __init__(self, feat_dim=768):
        """
        Args:
            feat_dim (int): Channel count of the input feature map; must be
                divisible by 4.
        """
        super(DeConvNet_v2, self).__init__()
        self.first_conv = make_conv_layers([feat_dim, feat_dim//2], kernel=1, stride=1, padding=0, bnrelu_final=False)
        # kernel 4 / stride 4: upsamples the spatial resolution by 4x.
        self.deconv = nn.Sequential(nn.ConvTranspose2d(in_channels=feat_dim//2, out_channels=feat_dim//4, kernel_size=4, stride=4, padding=0, output_padding=0, bias=False),
                                    nn.BatchNorm2d(feat_dim//4),
                                    nn.ReLU(inplace=True))

    def forward(self, img_feat):
        """Return a single-element list with the 4x-upsampled feature map.

        Args:
            img_feat: (B, feat_dim, H, W) feature map.
        Returns:
            list[torch.Tensor]: [(B, feat_dim//4, 4H, 4W)].
        """
        img_feat = self.first_conv(img_feat)
        img_feat = self.deconv(img_feat)
        return [img_feat]
147
+
148
class RefineNet(nn.Module):
    """Refinement head of WiLoR.

    Samples multi-scale image features at the 2D projections of the initial
    MANO vertex estimate and regresses residual updates for hand pose, shape
    and camera parameters.
    """
    def __init__(self, cfg, feat_dim=1280, upscale=3):
        """
        Args:
            cfg: Project config node; MODEL.MANO_HEAD.JOINT_REP selects the
                rotation parameterization ('6d' or 'aa').
            feat_dim (int): Channel count of the backbone feature map.
            upscale (int): Upsampling factor for the deconv pyramid.
        """
        super(RefineNet, self).__init__()
        #self.deconv = DeConvNet_v2(feat_dim=feat_dim)
        #self.out_dim = feat_dim//4

        self.deconv = DeConvNet(feat_dim=feat_dim, upscale=upscale)
        # Concatenated per-vertex feature size across the three pyramid levels.
        self.out_dim = feat_dim//8 + feat_dim//4 + feat_dim//2
        # Residual decoders: 96 = 16 joints x 6D rotation representation.
        self.dec_pose = nn.Linear(self.out_dim, 96)
        self.dec_cam = nn.Linear(self.out_dim, 3)
        self.dec_shape = nn.Linear(self.out_dim, 10)

        self.cfg = cfg
        self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
        self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]

    def forward(self, img_feat, verts_3d, pred_cam, pred_mano_feats, focal_length):
        """Refine the initial MANO estimate.

        Args:
            img_feat: Backbone feature map (B, feat_dim, H, W).
            verts_3d: Initial MANO vertices (B, V, 3).
            pred_cam: Initial weak-perspective camera (B, 3) as (s, tx, ty).
            pred_mano_feats: Dict with unconverted 'hand_pose', 'betas',
                'cam' predictions to which residuals are added.
            focal_length: (B, 2) focal length in pixels.
        Returns:
            (pred_mano_params dict with rotation matrices, refined pred_cam)
        """
        B = img_feat.shape[0]

        img_feats = self.deconv(img_feat)

        img_feat_sizes = [img_feat.shape[2] for img_feat in img_feats]

        # Convert the weak-perspective camera to a translation at each
        # pyramid resolution (z from similar triangles; 1e-9 avoids /0).
        temp_cams = [torch.stack([pred_cam[:, 1], pred_cam[:, 2],
                                  2*focal_length[:, 0]/(img_feat_size * pred_cam[:, 0] +1e-9)],dim=-1) for img_feat_size in img_feat_sizes]

        # Project the vertices into each feature map's pixel grid.
        verts_2d = [perspective_projection(verts_3d,
                                           translation=temp_cams[i],
                                           focal_length=focal_length / img_feat_sizes[i]) for i in range(len(img_feat_sizes))]

        # Per-level vertex features, max-pooled over vertices -> (B, C_i).
        vert_feats = [sample_joint_features(img_feats[i], verts_2d[i]).max(1).values for i in range(len(img_feat_sizes))]

        vert_feats = torch.cat(vert_feats, dim=-1)

        # Residual updates on top of the initial (unconverted) predictions.
        delta_pose = self.dec_pose(vert_feats)
        delta_betas = self.dec_shape(vert_feats)
        delta_cam = self.dec_cam(vert_feats)


        pred_hand_pose = pred_mano_feats['hand_pose'] + delta_pose
        pred_betas = pred_mano_feats['betas'] + delta_betas
        pred_cam = pred_mano_feats['cam'] + delta_cam

        joint_conversion_fn = {
            '6d': rot6d_to_rotmat,
            'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
        }[self.joint_rep_type]

        pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(B, self.cfg.MANO.NUM_HAND_JOINTS+1, 3, 3)

        # Joint 0 is the global orientation; the rest are articulated joints.
        pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
                            'hand_pose': pred_hand_pose[:, 1:],
                            'betas': pred_betas}

        return pred_mano_params, pred_cam
203
+
204
+
WiLoR/wilor/models/losses.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class Keypoint2DLoss(nn.Module):

    def __init__(self, loss_type: str = 'l1'):
        """
        2D keypoint loss module.
        Args:
            loss_type (str): Choose between l1 and l2 losses.
        Raises:
            NotImplementedError: If loss_type is neither 'l1' nor 'l2'.
        """
        super(Keypoint2DLoss, self).__init__()
        if loss_type == 'l1':
            self.loss_fn = nn.L1Loss(reduction='none')
        elif loss_type == 'l2':
            self.loss_fn = nn.MSELoss(reduction='none')
        else:
            raise NotImplementedError('Unsupported loss function')

    def forward(self, pred_keypoints_2d: torch.Tensor, gt_keypoints_2d: torch.Tensor) -> torch.Tensor:
        """
        Compute 2D reprojection loss on the keypoints.
        Args:
            pred_keypoints_2d (torch.Tensor): Tensor of shape [B, N, 2] containing projected 2D keypoints (B: batch size, N: number of keypoints).
            gt_keypoints_2d (torch.Tensor): Tensor of shape [B, N, 3] containing the ground truth 2D keypoints and, in the last channel, a per-keypoint confidence.
        Returns:
            torch.Tensor: Scalar confidence-weighted 2D keypoint loss, summed over the batch.
        """
        # Last channel of the ground truth is a confidence weight per keypoint.
        conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone()
        loss = (conf * self.loss_fn(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).sum(dim=(1, 2))
        return loss.sum()
33
+
34
+
35
class Keypoint3DLoss(nn.Module):

    def __init__(self, loss_type: str = 'l1'):
        """
        3D keypoint loss module.
        Args:
            loss_type (str): Choose between l1 and l2 losses.
        Raises:
            NotImplementedError: If loss_type is neither 'l1' nor 'l2'.
        """
        super(Keypoint3DLoss, self).__init__()
        if loss_type == 'l1':
            self.loss_fn = nn.L1Loss(reduction='none')
        elif loss_type == 'l2':
            self.loss_fn = nn.MSELoss(reduction='none')
        else:
            raise NotImplementedError('Unsupported loss function')

    def forward(self, pred_keypoints_3d: torch.Tensor, gt_keypoints_3d: torch.Tensor, pelvis_id: int = 0):
        """
        Compute root-aligned 3D keypoint loss.
        Args:
            pred_keypoints_3d (torch.Tensor): Tensor of shape [B, N, 3] containing the predicted 3D keypoints (B: batch size, N: number of keypoints).
            gt_keypoints_3d (torch.Tensor): Tensor of shape [B, N, 4] containing the ground truth 3D keypoints and, in the last channel, a per-keypoint confidence.
            pelvis_id (int): Index of the root joint used to align both sets.
        Returns:
            torch.Tensor: Scalar confidence-weighted 3D keypoint loss, summed over the batch.
        """
        # Clone so the caller's ground truth is not mutated by the alignment.
        gt_keypoints_3d = gt_keypoints_3d.clone()
        # Center both predictions and ground truth on the root joint so the
        # loss is invariant to global translation.
        pred_keypoints_3d = pred_keypoints_3d - pred_keypoints_3d[:, pelvis_id, :].unsqueeze(dim=1)
        gt_keypoints_3d[:, :, :-1] = gt_keypoints_3d[:, :, :-1] - gt_keypoints_3d[:, pelvis_id, :-1].unsqueeze(dim=1)
        conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone()
        gt_keypoints_3d = gt_keypoints_3d[:, :, :-1]
        loss = (conf * self.loss_fn(pred_keypoints_3d, gt_keypoints_3d)).sum(dim=(1, 2))
        return loss.sum()
68
+
69
class ParameterLoss(nn.Module):

    def __init__(self):
        """
        MANO parameter loss module: masked squared error between predicted
        and ground-truth parameter tensors.
        """
        super(ParameterLoss, self).__init__()
        self.loss_fn = nn.MSELoss(reduction='none')

    def forward(self, pred_param: torch.Tensor, gt_param: torch.Tensor, has_param: torch.Tensor):
        """
        Compute MANO parameter loss.
        Args:
            pred_param (torch.Tensor): Tensor of shape [B, ...] containing the predicted parameters (hand pose / global orientation / betas).
            gt_param (torch.Tensor): Tensor of shape [B, ...] containing the ground truth MANO parameters.
            has_param (torch.Tensor): Per-sample validity mask of shape [B].
        Returns:
            torch.Tensor: Scalar masked L2 parameter loss.
        """
        batch_size = pred_param.shape[0]
        # Reshape the per-sample mask to [B, 1, 1, ...] so it broadcasts
        # over every parameter dimension.
        mask_shape = [batch_size] + [1] * (len(pred_param.shape) - 1)
        mask = has_param.type(pred_param.type()).view(*mask_shape)
        return (mask * self.loss_fn(pred_param, gt_param)).sum()
WiLoR/wilor/models/mano_wrapper.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import pickle
4
+ from typing import Optional
5
+ import smplx
6
+ from smplx.lbs import vertices2joints
7
+ from smplx.utils import MANOOutput, to_tensor
8
+ from smplx.vertex_ids import vertex_ids
9
+
10
+
11
class MANO(smplx.MANOLayer):
    def __init__(self, *args, joint_regressor_extra: Optional[str] = None, **kwargs):
        """
        Extension of the official MANO implementation to support more joints.
        Args:
            Same as MANOLayer.
            joint_regressor_extra (str): Path to extra joint regressor
                (pickled matrix applied to the mesh vertices).
        """
        super(MANO, self).__init__(*args, **kwargs)
        # Permutation from MANO joint ordering to OpenPose hand ordering,
        # applied after the fingertip vertices are appended below.
        mano_to_openpose = [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]

        #2, 3, 5, 4, 1
        if joint_regressor_extra is not None:
            # NOTE(review): pickled file is trusted project data; the file
            # handle from open() is not explicitly closed here.
            self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
        # Mesh vertex indices of the fingertips (from smplx's vertex_ids).
        self.register_buffer('extra_joints_idxs', to_tensor(list(vertex_ids['mano'].values()), dtype=torch.long))
        self.register_buffer('joint_map', torch.tensor(mano_to_openpose, dtype=torch.long))

    def forward(self, *args, **kwargs) -> MANOOutput:
        """
        Run forward pass. Same as MANO and also append an extra set of joints if joint_regressor_extra is specified.
        """
        mano_output = super(MANO, self).forward(*args, **kwargs)
        # Append fingertip vertices as joints, then reorder to OpenPose.
        extra_joints = torch.index_select(mano_output.vertices, 1, self.extra_joints_idxs)
        joints = torch.cat([mano_output.joints, extra_joints], dim=1)
        joints = joints[:, self.joint_map, :]
        if hasattr(self, 'joint_regressor_extra'):
            # Optionally regress additional joints from the mesh vertices.
            extra_joints = vertices2joints(self.joint_regressor_extra, mano_output.vertices)
            joints = torch.cat([joints, extra_joints], dim=1)
        mano_output.joints = joints
        return mano_output
WiLoR/wilor/models/wilor.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ from typing import Any, Dict, Mapping, Tuple
4
+
5
+ from yacs.config import CfgNode
6
+
7
+ from ..utils import SkeletonRenderer, MeshRenderer
8
+ from ..utils.geometry import aa_to_rotmat, perspective_projection
9
+ from ..utils.pylogger import get_pylogger
10
+ from .backbones import create_backbone
11
+ from .heads import RefineNet
12
+ from .discriminator import Discriminator
13
+ from .losses import Keypoint3DLoss, Keypoint2DLoss, ParameterLoss
14
+ from . import MANO
15
+
16
+ log = get_pylogger(__name__)
17
+
18
class WiLoR(pl.LightningModule):
    """WiLoR hand-reconstruction Lightning module.

    Wraps the ViT backbone + RefineNet head, the MANO layer, the losses and
    (optionally) an adversarial pose/shape discriminator. Uses manual
    optimization because generator and discriminator are updated separately.
    """

    def __init__(self, cfg: CfgNode, init_renderer: bool = True):
        """
        Setup WiLoR model
        Args:
            cfg (CfgNode): Config file as a yacs CfgNode
            init_renderer (bool): Build the visualization renderers
                (disable for headless inference-only use).
        """
        super().__init__()

        # Save hyperparameters
        self.save_hyperparameters(logger=False, ignore=['init_renderer'])

        self.cfg = cfg
        # Create backbone feature extractor
        self.backbone = create_backbone(cfg)
        if cfg.MODEL.BACKBONE.get('PRETRAINED_WEIGHTS', None):
            log.info(f'Loading backbone weights from {cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS}')
            # strict=False: the checkpoint may lack the extra token/head weights.
            self.backbone.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAINED_WEIGHTS, map_location='cpu')['state_dict'], strict = False)

        # Create RefineNet head
        self.refine_net = RefineNet(cfg, feat_dim=1280, upscale=3)

        # Create discriminator
        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
            self.discriminator = Discriminator()

        # Define loss functions
        self.keypoint_3d_loss = Keypoint3DLoss(loss_type='l1')
        self.keypoint_2d_loss = Keypoint2DLoss(loss_type='l1')
        self.mano_parameter_loss = ParameterLoss()

        # Instantiate MANO model
        mano_cfg = {k.lower(): v for k,v in dict(cfg.MANO).items()}
        self.mano = MANO(**mano_cfg)

        # Buffer that shows whether we need to initialize ActNorm layers
        self.register_buffer('initialized', torch.tensor(False))
        # Setup renderer for visualization
        if init_renderer:
            self.renderer = SkeletonRenderer(self.cfg)
            self.mesh_renderer = MeshRenderer(self.cfg, faces=self.mano.faces)
        else:
            self.renderer = None
            self.mesh_renderer = None


        # Disable automatic optimization since we use adversarial training
        self.automatic_optimization = False

    def on_after_backward(self):
        # Debug hook: report any parameter that received no gradient.
        for name, param in self.named_parameters():
            if param.grad is None:
                print(param.shape)
                print(name)


    def get_parameters(self):
        """Return the generator parameters to optimize (backbone only;
        NOTE(review): refine_net parameters are not included here — confirm
        this is intentional)."""
        #all_params = list(self.mano_head.parameters())
        all_params = list(self.backbone.parameters())
        return all_params

    def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
        """
        Setup model and discriminator optimizers
        Returns:
            Tuple[torch.optim.Optimizer, torch.optim.Optimizer]: Model and discriminator optimizers
        """
        param_groups = [{'params': filter(lambda p: p.requires_grad, self.get_parameters()), 'lr': self.cfg.TRAIN.LR}]

        optimizer = torch.optim.AdamW(params=param_groups,
                            # lr=self.cfg.TRAIN.LR,
                            weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)
        optimizer_disc = torch.optim.AdamW(params=self.discriminator.parameters(),
                                           lr=self.cfg.TRAIN.LR,
                                           weight_decay=self.cfg.TRAIN.WEIGHT_DECAY)

        return optimizer, optimizer_disc

    def forward_step(self, batch: Dict, train: bool = False) -> Dict:
        """
        Run a forward step of the network
        Args:
            batch (Dict): Dictionary containing batch data
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            Dict: Dictionary containing the regression output
        """
        # Use RGB image as input
        x = batch['img']
        batch_size = x.shape[0]
        # Compute conditioning features using the backbone
        # if using ViT backbone, we need to use a different aspect ratio
        # (crop 32px from each side horizontally).
        temp_mano_params, pred_cam, pred_mano_feats, vit_out = self.backbone(x[:,:,:,32:-32]) # B, 1280, 16, 12


        # Compute camera translation
        device = temp_mano_params['hand_pose'].device
        dtype = temp_mano_params['hand_pose'].dtype
        focal_length = self.cfg.EXTRA.FOCAL_LENGTH * torch.ones(batch_size, 2, device=device, dtype=dtype)


        ## Temp MANO: run MANO on the initial backbone estimate to obtain
        ## vertices used by RefineNet for feature sampling.
        temp_mano_params['global_orient'] = temp_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
        temp_mano_params['hand_pose'] = temp_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
        temp_mano_params['betas'] = temp_mano_params['betas'].reshape(batch_size, -1)
        temp_mano_output = self.mano(**{k: v.float() for k,v in temp_mano_params.items()}, pose2rot=False)
        #temp_keypoints_3d = temp_mano_output.joints
        temp_vertices = temp_mano_output.vertices

        pred_mano_params, pred_cam = self.refine_net(vit_out, temp_vertices, pred_cam, pred_mano_feats, focal_length)
        # Store useful regression outputs to the output dict


        output = {}
        output['pred_cam'] = pred_cam
        output['pred_mano_params'] = {k: v.clone() for k,v in pred_mano_params.items()}

        # Weak-perspective (s, tx, ty) -> camera translation (1e-9 avoids /0).
        pred_cam_t = torch.stack([pred_cam[:, 1],
                                  pred_cam[:, 2],
                                  2*focal_length[:, 0]/(self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] +1e-9)],dim=-1)
        output['pred_cam_t'] = pred_cam_t
        output['focal_length'] = focal_length

        # Compute model vertices, joints and the projected joints
        pred_mano_params['global_orient'] = pred_mano_params['global_orient'].reshape(batch_size, -1, 3, 3)
        pred_mano_params['hand_pose'] = pred_mano_params['hand_pose'].reshape(batch_size, -1, 3, 3)
        pred_mano_params['betas'] = pred_mano_params['betas'].reshape(batch_size, -1)
        mano_output = self.mano(**{k: v.float() for k,v in pred_mano_params.items()}, pose2rot=False)
        pred_keypoints_3d = mano_output.joints
        pred_vertices = mano_output.vertices

        output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
        output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
        pred_cam_t = pred_cam_t.reshape(-1, 3)
        focal_length = focal_length.reshape(-1, 2)

        # 2D keypoints in normalized image coordinates (focal / image size).
        pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
                                                   translation=pred_cam_t,
                                                   focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
        output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)

        return output

    def compute_loss(self, batch: Dict, output: Dict, train: bool = True) -> torch.Tensor:
        """
        Compute losses given the input batch and the regression output
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            train (bool): Flag indicating whether it is training or validation mode
        Returns:
            torch.Tensor : Total loss for current batch
        """

        pred_mano_params = output['pred_mano_params']
        pred_keypoints_2d = output['pred_keypoints_2d']
        pred_keypoints_3d = output['pred_keypoints_3d']


        batch_size = pred_mano_params['hand_pose'].shape[0]
        device = pred_mano_params['hand_pose'].device
        dtype = pred_mano_params['hand_pose'].dtype

        # Get annotations
        gt_keypoints_2d = batch['keypoints_2d']
        gt_keypoints_3d = batch['keypoints_3d']
        gt_mano_params = batch['mano_params']
        has_mano_params = batch['has_mano_params']
        is_axis_angle = batch['mano_params_is_axis_angle']

        # Compute 2D and 3D keypoint losses
        loss_keypoints_2d = self.keypoint_2d_loss(pred_keypoints_2d, gt_keypoints_2d)
        loss_keypoints_3d = self.keypoint_3d_loss(pred_keypoints_3d, gt_keypoints_3d, pelvis_id=0)

        # Compute loss on MANO parameters; axis-angle ground truth is
        # converted to rotation matrices to match the predictions.
        loss_mano_params = {}
        for k, pred in pred_mano_params.items():
            gt = gt_mano_params[k].view(batch_size, -1)
            if is_axis_angle[k].all():
                gt = aa_to_rotmat(gt.reshape(-1, 3)).view(batch_size, -1, 3, 3)
            has_gt = has_mano_params[k]
            loss_mano_params[k] = self.mano_parameter_loss(pred.reshape(batch_size, -1), gt.reshape(batch_size, -1), has_gt)

        loss = self.cfg.LOSS_WEIGHTS['KEYPOINTS_3D'] * loss_keypoints_3d+\
               self.cfg.LOSS_WEIGHTS['KEYPOINTS_2D'] * loss_keypoints_2d+\
               sum([loss_mano_params[k] * self.cfg.LOSS_WEIGHTS[k.upper()] for k in loss_mano_params])


        losses = dict(loss=loss.detach(),
                      loss_keypoints_2d=loss_keypoints_2d.detach(),
                      loss_keypoints_3d=loss_keypoints_3d.detach())

        for k, v in loss_mano_params.items():
            losses['loss_' + k] = v.detach()

        output['losses'] = losses

        return loss

    # Tensoroboard logging should run from first rank only
    @pl.utilities.rank_zero.rank_zero_only
    def tensorboard_logging(self, batch: Dict, output: Dict, step_count: int, train: bool = True, write_to_summary_writer: bool = True) -> None:
        """
        Log results to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            output (Dict): Dictionary containing the regression output
            step_count (int): Global training step count
            train (bool): Flag indicating whether it is training or validation mode
            write_to_summary_writer (bool): When False, only render and
                return the visualization without logging it.
        """

        mode = 'train' if train else 'val'
        batch_size = batch['keypoints_2d'].shape[0]
        images = batch['img']
        # Undo the ImageNet normalization for visualization.
        images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1,3,1,1)
        images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1,3,1,1)
        #images = 255*images.permute(0, 2, 3, 1).cpu().numpy()

        pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)
        pred_vertices = output['pred_vertices'].detach().reshape(batch_size, -1, 3)
        focal_length = output['focal_length'].detach().reshape(batch_size, 2)
        gt_keypoints_3d = batch['keypoints_3d']
        gt_keypoints_2d = batch['keypoints_2d']

        losses = output['losses']
        pred_cam_t = output['pred_cam_t'].detach().reshape(batch_size, 3)
        pred_keypoints_2d = output['pred_keypoints_2d'].detach().reshape(batch_size, -1, 2)
        if write_to_summary_writer:
            summary_writer = self.logger.experiment
            for loss_name, val in losses.items():
                summary_writer.add_scalar(mode +'/' + loss_name, val.detach().item(), step_count)
        num_images = min(batch_size, self.cfg.EXTRA.NUM_LOG_IMAGES)

        gt_keypoints_3d = batch['keypoints_3d']
        pred_keypoints_3d = output['pred_keypoints_3d'].detach().reshape(batch_size, -1, 3)

        # We render the skeletons instead of the full mesh because rendering a lot of meshes will make the training slow.
        #predictions = self.renderer(pred_keypoints_3d[:num_images],
        #                            gt_keypoints_3d[:num_images],
        #                            2 * gt_keypoints_2d[:num_images],
        #                            images=images[:num_images],
        #                            camera_translation=pred_cam_t[:num_images])
        predictions = self.mesh_renderer.visualize_tensorboard(pred_vertices[:num_images].cpu().numpy(),
                                                               pred_cam_t[:num_images].cpu().numpy(),
                                                               images[:num_images].cpu().numpy(),
                                                               pred_keypoints_2d[:num_images].cpu().numpy(),
                                                               gt_keypoints_2d[:num_images].cpu().numpy(),
                                                               focal_length=focal_length[:num_images].cpu().numpy())
        if write_to_summary_writer:
            summary_writer.add_image('%s/predictions' % mode, predictions, step_count)

        return predictions

    def forward(self, batch: Dict) -> Dict:
        """
        Run a forward step of the network in val mode
        Args:
            batch (Dict): Dictionary containing batch data
        Returns:
            Dict: Dictionary containing the regression output
        """
        return self.forward_step(batch, train=False)

    def training_step_discriminator(self, batch: Dict,
                                    hand_pose: torch.Tensor,
                                    betas: torch.Tensor,
                                    optimizer: torch.optim.Optimizer) -> torch.Tensor:
        """
        Run a discriminator training step
        Args:
            batch (Dict): Dictionary containing mocap batch data
            hand_pose (torch.Tensor): Regressed hand pose from current step
            betas (torch.Tensor): Regressed betas from current step
            optimizer (torch.optim.Optimizer): Discriminator optimizer
        Returns:
            torch.Tensor: Discriminator loss
        """
        batch_size = hand_pose.shape[0]
        gt_hand_pose = batch['hand_pose']
        gt_betas = batch['betas']
        gt_rotmat = aa_to_rotmat(gt_hand_pose.view(-1,3)).view(batch_size, -1, 3, 3)
        # LSGAN objective: fake -> 0, real -> 1. Predictions are detached so
        # only the discriminator is updated here.
        disc_fake_out = self.discriminator(hand_pose.detach(), betas.detach())
        loss_fake = ((disc_fake_out - 0.0) ** 2).sum() / batch_size
        disc_real_out = self.discriminator(gt_rotmat, gt_betas)
        loss_real = ((disc_real_out - 1.0) ** 2).sum() / batch_size
        loss_disc = loss_fake + loss_real
        loss = self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_disc
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        return loss_disc.detach()

    def training_step(self, joint_batch: Dict, batch_idx: int) -> Dict:
        """
        Run a full training step
        Args:
            joint_batch (Dict): Dictionary containing image and mocap batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        batch = joint_batch['img']
        mocap_batch = joint_batch['mocap']
        optimizer = self.optimizers(use_pl_optimizer=True)
        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
            optimizer, optimizer_disc = optimizer

        batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=True)
        pred_mano_params = output['pred_mano_params']
        if self.cfg.get('UPDATE_GT_SPIN', False):
            self.update_batch_gt_spin(batch, output)
        loss = self.compute_loss(batch, output, train=True)
        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
            # Generator adversarial term: push discriminator output to 1.
            disc_out = self.discriminator(pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1))
            loss_adv = ((disc_out - 1.0) ** 2).sum() / batch_size
            loss = loss + self.cfg.LOSS_WEIGHTS.ADVERSARIAL * loss_adv

        # Error if Nan
        if torch.isnan(loss):
            raise ValueError('Loss is NaN')

        optimizer.zero_grad()
        self.manual_backward(loss)
        # Clip gradient
        if self.cfg.TRAIN.get('GRAD_CLIP_VAL', 0) > 0:
            gn = torch.nn.utils.clip_grad_norm_(self.get_parameters(), self.cfg.TRAIN.GRAD_CLIP_VAL, error_if_nonfinite=True)
            self.log('train/grad_norm', gn, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        optimizer.step()
        if self.cfg.LOSS_WEIGHTS.ADVERSARIAL > 0:
            loss_disc = self.training_step_discriminator(mocap_batch, pred_mano_params['hand_pose'].reshape(batch_size, -1), pred_mano_params['betas'].reshape(batch_size, -1), optimizer_disc)
            output['losses']['loss_gen'] = loss_adv
            output['losses']['loss_disc'] = loss_disc

        if self.global_step > 0 and self.global_step % self.cfg.GENERAL.LOG_STEPS == 0:
            self.tensorboard_logging(batch, output, self.global_step, train=True)

        self.log('train/loss', output['losses']['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=False)

        return output

    def validation_step(self, batch: Dict, batch_idx: int, dataloader_idx=0) -> Dict:
        """
        Run a validation step and log to Tensorboard
        Args:
            batch (Dict): Dictionary containing batch data
            batch_idx (int): Unused.
        Returns:
            Dict: Dictionary containing regression output.
        """
        # batch_size = batch['img'].shape[0]
        output = self.forward_step(batch, train=False)
        loss = self.compute_loss(batch, output, train=False)
        output['loss'] = loss
        self.tensorboard_logging(batch, output, self.global_step, train=False)

        return output
WiLoR/wilor/utils/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Any
3
+
4
+ from .renderer import Renderer
5
+ from .mesh_renderer import MeshRenderer
6
+ from .skeleton_renderer import SkeletonRenderer
7
+ from .pose_utils import eval_pose, Evaluator
8
+
9
def recursive_to(x: Any, target: torch.device):
    """
    Recursively transfer a batch of data to the target device.
    Args:
        x (Any): Batch of data (tensors, possibly nested in dicts/lists).
        target (torch.device): Target device.
    Returns:
        Batch of data where all tensors are transferred to the target device;
        non-tensor leaves are returned unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.to(target)
    if isinstance(x, dict):
        return {key: recursive_to(value, target) for key, value in x.items()}
    if isinstance(x, list):
        return [recursive_to(item, target) for item in x]
    # Anything else (ints, strings, None, ...) passes through untouched.
    return x
WiLoR/wilor/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.84 kB). View file