chris10 committed
Commit
15bc41b
1 Parent(s): e6c4101
Files changed (16)
  1. .gitignore +162 -0
  2. Dockerfile +1 -6
  3. README.md +0 -1
  4. app.py +25 -14
  5. arg_parser.py +72 -0
  6. demo.py +263 -0
  7. model/TEHNet.py +208 -0
  8. model/__init__.py +1 -0
  9. model/model.py +64 -0
  10. model/pointnet2_utils.py +315 -0
  11. model/utils.py +42 -0
  12. record.py +20 -0
  13. requirements.txt +5 -0
  14. settings.py +45 -0
  15. test.py +93 -0
  16. vis.py +13 -0
.gitignore ADDED
@@ -0,0 +1,162 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+ src/Ev2Hands/outputs
+ src/HandSimulator/logs
Dockerfile CHANGED
@@ -15,9 +15,4 @@ RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
  RUN cd esim_py && pip install .
 
-
- EXPOSE 8501
-
- # CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
- # streamlit run app.py
- CMD ["streamlit", "run", "app.py", "0.0.0.0", "--port", "7860"]
+ CMD ["python3", "app.py"]
README.md CHANGED
@@ -6,7 +6,6 @@ colorTo: indigo
  sdk: docker
  pinned: false
  license: cc-by-4.0
- app_port: 8501
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,14 +1,25 @@
- import cv2
- import streamlit as st
-
- st.title("Webcam Live Feed")
- run = st.checkbox('Run')
- FRAME_WINDOW = st.image([])
- camera = cv2.VideoCapture(0)
-
- while run:
-     _, frame = camera.read()
-     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-     FRAME_WINDOW.image(frame)
- else:
-     st.write('Stopped')
+ import gradio as gr
+ import requests
+
+
+
+ import gradio as gr
+ import os
+
+
+ def video_identity(video):
+     print(video)
+     return video
+
+
+ demo = gr.Interface(video_identity,
+                     gr.Video(),
+                     "playable_video",
+                     examples=[
+                         os.path.join(os.path.dirname(__file__),
+                                      "example/video.mp4")],
+                     cache_examples=True)
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
+
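Note: the new app.py imports gradio twice and pulls in requests without using it. A minimal deduplicated sketch of the same interface (behaviour unchanged, assuming the example/video.mp4 asset ships with the Space):

import os
import gradio as gr


def video_identity(video):
    # Gradio hands the uploaded video over as a file path; echo it back unchanged.
    print(video)
    return video


demo = gr.Interface(
    video_identity,
    gr.Video(),
    "playable_video",
    examples=[os.path.join(os.path.dirname(__file__), "example/video.mp4")],
    cache_examples=True,
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)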
arg_parser.py ADDED
@@ -0,0 +1,72 @@
+ import argparse
+ import os
+
+ def demo():
+     parser = argparse.ArgumentParser(description='Demo for Ev2Hands')
+
+     parser.add_argument('--batch_size', dest='batch_size', required=False,
+                         help='Set the batch_size (default: 128)', default='32')
+
+     parser.add_argument('--checkpoint_path', dest='checkpoint', required=False,
+                         help='path of checkpoint_path', default='./savedmodels/best_model_state_dict.pth')
+
+     args = parser.parse_args()
+
+     os.environ['CHECKPOINT_PATH'] = args.checkpoint
+     os.environ['BATCH_SIZE'] = args.batch_size
+
+     return args
+
+
+ def evaluate():
+     parser = argparse.ArgumentParser(description='Evaluation of Ev2Hands')
+
+     parser.add_argument('--batch_size', dest='batch_size', required=False,
+                         help='Set the batch_size (default: 128)', default='128')
+
+     parser.add_argument('--checkpoint_path', dest='checkpoint', required=False,
+                         help='path of checkpoint',
+                         default='./savedmodels/best_model_state_dict.pth')
+
+     args = parser.parse_args()
+
+     os.environ['CHECKPOINT_PATH'] = args.checkpoint
+     os.environ['BATCH_SIZE'] = args.batch_size
+
+     return args
+
+
+ def train():
+     parser = argparse.ArgumentParser(description='Trainer of Ev2Hands')
+
+     parser.add_argument('--batch_size', dest='batch_size', required=False,
+                         help='Set the batch_size (default: 8)', default='8')
+
+     parser.add_argument('--checkpoint_path', dest='checkpoint', required=False,
+                         help='path of checkpoint', default='')
+
+     args = parser.parse_args()
+
+     os.environ['CHECKPOINT_PATH'] = args.checkpoint
+     os.environ['BATCH_SIZE'] = args.batch_size
+
+     return args
+
+
+ def finetune():
+     parser = argparse.ArgumentParser(description='FineTuner of Ev2Hands for real data')
+
+     parser.add_argument('--batch_size', dest='batch_size', required=False,
+                         help='Set the batch_size (default: 8)', default='8')
+
+     parser.add_argument('--checkpoint_path', dest='checkpoint', required=False,
+                         help='path of checkpoint',
+                         default='./savedmodels/best_model_state_dict.pth')
+
+     args = parser.parse_args()
+
+     os.environ['CHECKPOINT_PATH'] = args.checkpoint
+     os.environ['BATCH_SIZE'] = args.batch_size
+
+     return args
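Each helper above parses the command line and mirrors the values into environment variables, which the entry points then read back instead of passing an args object around. A short usage sketch of how demo.py consumes it (checkpoint path as in the defaults above):

import os
import arg_parser

# invoked as: python demo.py --batch_size 32 --checkpoint_path ./savedmodels/best_model_state_dict.pth
args = arg_parser.demo()

# the same values are now visible process-wide via the environment
checkpoint_path = os.environ['CHECKPOINT_PATH']
batch_size = int(os.environ['BATCH_SIZE'])  # stored as a string, so cast before use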
demo.py ADDED
@@ -0,0 +1,263 @@
+ import sys
+ import os
+ os.environ['ERPC'] = '1'
+
+ import esim_py
+
+ import torch
+ import cv2
+ import time
+ import pyrender
+ import numpy as np
+ import trimesh
+
+ import arg_parser
+
+ from model import TEHNetWrapper
+ from settings import OUTPUT_HEIGHT, OUTPUT_WIDTH, MAIN_CAMERA, REAL_TEST_DATA_PATH
+
+
+ def pc_normalize(pc):
+     pc[:, 0] /= OUTPUT_WIDTH
+     pc[:, 1] /= OUTPUT_HEIGHT
+     pc[:, :2] = 2 * pc[:, :2] - 1
+
+     ts = pc[:, 2:]
+
+     t_max = ts.max(0).values
+     t_min = ts.min(0).values
+
+     ts = (2 * ((ts - t_min) / (t_max - t_min))) - 1
+
+     pc[:, 2:] = ts
+
+     return pc
+
+
+
+ def process_events(events):
+     n_events = 2048
+
+     events[:, 2] -= events[0, 2]  # normalize ts
+
+     event_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.float32)
+     count_grid = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH), dtype=np.float32)
+
+     x, y, t, p = events.T
+     x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)
+
+     np.add.at(event_grid, (y, x, 0), t)
+     np.add.at(event_grid, (y, x, 1), p == 1)
+     np.add.at(event_grid, (y, x, 2), p != 1)
+
+     np.add.at(count_grid, (y, x), 1)
+
+     yi, xi = np.nonzero(count_grid)
+     t_avg = event_grid[yi, xi, 0] / count_grid[yi, xi]
+     p_evn = event_grid[yi, xi, 1]
+     n_evn = event_grid[yi, xi, 2]
+
+     events = np.hstack([xi[:, None], yi[:, None], t_avg[:, None], p_evn[:, None], n_evn[:, None]])
+
+     sampled_indices = np.random.choice(events.shape[0], n_events)
+     events = events[sampled_indices]
+
+     events = torch.tensor(events, dtype=torch.float32)
+
+     coordinates = np.zeros((events.shape[0], 2))
+     event_frame = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
+     for idx, (x, y, t_avg, p_evn, n_evn) in enumerate(events):
+         y, x = y.int(), x.int()
+
+         coordinates[idx] = (y, x)
+         event_frame[y, x, 0] = (p_evn / (p_evn + n_evn)) * 255
+         event_frame[y, x, -1] = (n_evn / (p_evn + n_evn)) * 255
+
+     events[:, :3] = pc_normalize(events[:, :3])
+
+     hand_data = {
+         'event_frame': torch.tensor(event_frame, dtype=torch.uint8),
+         'events': events.permute(1, 0).unsqueeze(0),
+         'coordinates': torch.tensor(coordinates, dtype=torch.float32)
+     }
+
+     return hand_data
+
+
+
+ def demo(net, device, data):
+     net.eval()
+
+     events = data['events']
+     events = events.to(device=device, dtype=torch.float32)
+
+     start_time = time.time()
+     with torch.no_grad():
+         outputs = net(events)
+
+     end_time = time.time()
+
+     N = events.shape[0]
+     print(end_time - start_time)
+
+     outputs['class_logits'] = outputs['class_logits'].softmax(1).argmax(1).int().cpu()
+
+     frames = list()
+     for idx in range(N):
+         hands = dict()
+
+         hands['left'] = {
+             'vertices': outputs['left']['vertices'][idx].cpu(),
+             'j3d': outputs['left']['j3d'][idx].cpu(),
+         }
+
+         hands['right'] = {
+             'vertices': outputs['right']['vertices'][idx].cpu(),
+             'j3d': outputs['right']['j3d'][idx].cpu(),
+         }
+
+         coordinates = data['coordinates']
+
+         seg_mask = np.zeros((OUTPUT_HEIGHT, OUTPUT_WIDTH, 3), dtype=np.uint8)
+         for edx, (y, x) in enumerate(coordinates):
+             y, x = y.int(), x.int()
+
+             cid = outputs['class_logits'][idx][edx]
+
+             if cid == 3:
+                 seg_mask[y, x] = 255
+             else:
+                 seg_mask[y, x, cid] = 255
+
+         hands['seg_mask'] = seg_mask
+
+         frames.append(hands)
+
+     return frames
+
+
+
+ def main():
+     arg_parser.demo()
+     os.makedirs('outputs', exist_ok=True)
+
+     device = torch.device('cpu')
+
+     net = TEHNetWrapper(device=device)
+
+     save_path = os.environ['CHECKPOINT_PATH']
+     batch_size = int(os.environ['BATCH_SIZE'])
+
+     checkpoint = torch.load(save_path, map_location=device)
+     net.load_state_dict(checkpoint['state_dict'], strict=True)
+
+     renderer = pyrender.OffscreenRenderer(viewport_width=OUTPUT_WIDTH, viewport_height=OUTPUT_HEIGHT)
+
+     scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
+     light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
+     light_pose = np.eye(4)
+     light_pose[:3, 3] = np.array([0, -1, 1])
+     scene.add(light, pose=light_pose)
+     light_pose[:3, 3] = np.array([0, 1, 1])
+     scene.add(light, pose=light_pose)
+     light_pose[:3, 3] = np.array([1, 1, 2])
+     scene.add(light, pose=light_pose)
+
+     rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
+
+     mano_hands = net.hands
+
+     # camera = cv2.VideoCapture(0)
+     input_video_stream = cv2.VideoCapture('video.mp4')
+
+
+     video_fps = 25
+     video = cv2.VideoWriter('outputs/video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), video_fps, (3 * OUTPUT_WIDTH, OUTPUT_HEIGHT))
+
+
+     POS_THRESHOLD = 0.5
+     NEG_THRESHOLD = 0.5
+     REF_PERIOD = 0.000
+
+     esim = esim_py.EventSimulator(POS_THRESHOLD, NEG_THRESHOLD, REF_PERIOD, 1e-4, True)
+
+
+     fps = cv2.CAP_PROP_FPS
+     ts_s = 1 / fps
+     ts_ns = ts_s * 1e9  # convert s to ns
+
+     is_init = False
+     idx = 0
+     while True:
+         _, frame_bgr = input_video_stream.read()
+         frame_bgr = cv2.resize(frame_bgr, (OUTPUT_WIDTH, OUTPUT_HEIGHT))
+         frame_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
+         frame_log = np.log(frame_gray.astype("float32") / 255 + 1e-4)
+         height, width = frame_log.shape[:2]
+
+         current_ts_ns = idx * ts_ns
+
+         if not is_init:
+             esim.init(frame_log, current_ts_ns)
+             is_init = True
+             idx += 1
+
+             continue
+         idx += 1
+
+         events = esim.generateEventFromCVImage(frame_log, current_ts_ns)
+         data = process_events(events)
+
+         event_frame = data['event_frame'].cpu().numpy().astype(dtype=np.uint8)
+
+         cv2.imwrite(f"outputs/event_frame_{idx}.png", event_frame)
+
+         print(idx, event_frame.shape)
+
+         frame = demo(net=net, device=device, data=data)[0]
+         seg_mask = frame['seg_mask']
+
+         pred_meshes = list()
+         for hand_type in ['left', 'right']:
+             faces = mano_hands[hand_type].faces
+
+             pred_mesh = trimesh.Trimesh(frame[hand_type]['vertices'].cpu().numpy() * 1000, faces)
+             pred_mesh.visual.vertex_colors = [255, 0, 0]
+             pred_meshes.append(pred_mesh)
+
+         pred_meshes = trimesh.util.concatenate(pred_meshes)
+         pred_meshes.apply_transform(rot)
+
+         camera = MAIN_CAMERA
+
+         nc = pyrender.Node(camera=camera, matrix=np.eye(4))
+         scene.add_node(nc)
+
+         mesh_node = pyrender.Node(mesh=pyrender.Mesh.from_trimesh(pred_meshes))
+         scene.add_node(mesh_node)
+         pred_rgb, depth = renderer.render(scene)
+         scene.remove_node(mesh_node)
+         scene.remove_node(nc)
+
+         pred_rgb = cv2.cvtColor(pred_rgb, cv2.COLOR_RGB2BGR)
+         pred_rgb[pred_rgb == 255] = 0
+
+         img_stack = np.hstack([event_frame, seg_mask, pred_rgb])
+         video.write(img_stack)
+
+         cv2.imshow('image', img_stack)
+         c = cv2.waitKey(1)
+
+         if c == ord('q'):
+             video.release()
+             exit(0)
+
+     video.release()
+
+
+ if __name__ == '__main__':
+     main()
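One caveat worth flagging in this file: fps = cv2.CAP_PROP_FPS assigns the value of the OpenCV property identifier (a constant), not the frame rate of the opened stream, so the timestamp step ts_ns is not derived from the actual video; the writer separately hardcodes video_fps = 25. A hedged sketch of the usual way to query the real rate (not part of the commit):

import cv2

input_video_stream = cv2.VideoCapture('video.mp4')

# CAP_PROP_FPS is only the property ID; get() returns the stream's reported frame rate.
fps = input_video_stream.get(cv2.CAP_PROP_FPS) or 25  # fall back if the container reports 0
ts_ns = (1 / fps) * 1e9  # nanoseconds between consecutive frames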
model/TEHNet.py ADDED
@@ -0,0 +1,208 @@
+ import numpy as np
+ import torch.nn as nn
+ import torch
+ import os
+ import torch.nn.functional as F
+ from .pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction, PointNetFeaturePropagation
+
+
+ class AttentionBlock(nn.Module):
+     def __init__(self):
+         super(AttentionBlock, self).__init__()
+
+     def forward(self, key, value, query):
+         query = query.permute(0, 2, 1)
+         N, KC = key.shape[:2]
+         key = key.view(N, KC, -1)
+
+         N, KC = value.shape[:2]
+         value = value.view(N, KC, -1)
+
+         sim_map = torch.bmm(key, query)
+         sim_map = (KC ** -.5) * sim_map
+         sim_map = F.softmax(sim_map, dim=1)
+
+         context = torch.bmm(sim_map, value)
+
+         return context
+
+
+ class MANORegressor(nn.Module):
+     def __init__(self, n_inp_features=4, n_pose_params=6, n_shape_params=10):
+         super(MANORegressor, self).__init__()
+
+         normal_channel = True
+
+         if normal_channel:
+             additional_channel = n_inp_features
+         else:
+             additional_channel = 0
+
+         self.normal_channel = normal_channel
+
+         self.sa1 = PointNetSetAbstractionMsg(128, [0.4, 0.8], [64, 128], additional_channel, [[128, 128, 256], [128, 196, 256]])
+         self.sa2 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512], group_all=True)
+
+         self.n_pose_params = n_pose_params
+         self.n_mano_params = n_pose_params + n_shape_params
+
+         self.mano_regressor = nn.Sequential(
+             nn.Linear(512, 1024),
+             nn.ReLU(),
+             nn.BatchNorm1d(1024),
+             nn.Dropout(0.3),
+             nn.Linear(1024, 3 + self.n_mano_params + 3),
+         )
+
+
+     def J3dtoJ2d(self, j3d, scale):
+         B, N = j3d.shape[:2]
+         device = j3d.device
+
+         j2d = torch.zeros(B, N, 2, device=device)
+         j2d[:, :, 0] = scale[:, :, 0] * j3d[:, :, 0]
+         j2d[:, :, 1] = scale[:, :, 1] * j3d[:, :, 1]
+
+         return j2d
+
+     def forward(self, xyz, features, mano_hand, previous_mano_params=None):
+         device = xyz.device
+         batch_size = xyz.shape[0]
+
+         l0_xyz = xyz
+         l0_points = features
+
+         l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
+         l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
+
+         l2_xyz = l2_xyz.squeeze(-1)
+         l2_points = l2_points.squeeze(-1)
+
+         if previous_mano_params is None:
+             previous_mano_params = torch.zeros(self.n_mano_params).unsqueeze(0).expand(batch_size, -1).to(device)
+             previous_rot_trans_params = torch.zeros(6).unsqueeze(0).expand(batch_size, -1).to(device)
+
+         mano_params = self.mano_regressor(l2_points)
+
+         global_orient = mano_params[:, :3]
+         hand_pose = mano_params[:, 3:3+self.n_pose_params]
+         betas = mano_params[:, 3+self.n_pose_params:-3]
+         transl = mano_params[:, -3:]
+
+         device = mano_hand.shapedirs.device
+
+         mano_args = {
+             'global_orient': global_orient.to(device),
+             'hand_pose': hand_pose.to(device),
+             'betas': betas.to(device),
+             'transl': transl.to(device),
+         }
+
+         mano_outs = dict()
+
+         output = mano_hand(**mano_args)
+         mano_outs['vertices'] = output.vertices
+         mano_outs['j3d'] = output.joints
+
+         mano_outs.update(mano_args)
+
+         if not self.training:
+             mano_outs['faces'] = np.tile(mano_hand.faces, (batch_size, 1, 1))
+
+         return mano_outs
+
+
+ class TEHNet(nn.Module):
+     def __init__(self, n_pose_params, num_classes=4):
+         super(TEHNet, self).__init__()
+
+         normal_channel = True
+
+         if normal_channel:
+             additional_channel = 1 + int(os.getenv('ERPC', 0))
+         else:
+             additional_channel = 0
+
+         self.normal_channel = normal_channel
+         self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3+additional_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
+         self.sa2 = PointNetSetAbstractionMsg(128, [0.4, 0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
+         self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
+         self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256])
+         self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128])
+
+         self.fp1 = PointNetFeaturePropagation(128, [128, 128, 256])
+
+         self.classifier = nn.Sequential(
+             nn.Conv1d(256, 256, 1),
+             nn.ReLU(),
+             nn.BatchNorm1d(256),
+             nn.Dropout(0.3),
+             nn.Conv1d(256, num_classes, 1)
+         )
+
+         self.attention_block = AttentionBlock()
+
+         self.left_mano_regressor = MANORegressor(n_pose_params=n_pose_params)
+         self.right_mano_regressor = MANORegressor(n_pose_params=n_pose_params)
+
+         self.mhlnes = int(os.getenv('MHLNES', 0))
+
+         self.left_query_conv = nn.Sequential(
+             nn.Conv1d(256, 256, 3, 1, 3//2),
+             nn.ReLU(),
+             nn.BatchNorm1d(256),
+             nn.Dropout(0.1),
+             nn.Conv1d(256, 256, 3, 1, 3//2),
+             nn.BatchNorm1d(256),
+         )
+
+         self.right_query_conv = nn.Sequential(
+             nn.Conv1d(256, 256, 3, 1, 3//2),
+             nn.ReLU(),
+             nn.BatchNorm1d(256),
+             nn.Dropout(0.1),
+             nn.Conv1d(256, 256, 3, 1, 3//2),
+             nn.BatchNorm1d(256),
+         )
+
+     def forward(self, xyz, mano_hands):
+         device = xyz.device
+
+         # Set Abstraction layers
+         l0_points = xyz
+
+         l0_xyz = xyz[:, :3, :]
+
+         if self.mhlnes:
+             l0_xyz[:, -1, :] = xyz[:, 3:, :].mean(1)
+
+         l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)
+         l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
+         l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
+
+         # Feature Propagation layers
+         l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
+         l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
+         l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)
+
+         seg_out = self.classifier(l0_points)
+         feat_fuse = l0_points
+
+         left_hand_features = self.attention_block(seg_out, feat_fuse, self.left_query_conv(feat_fuse))
+         right_hand_features = self.attention_block(seg_out, feat_fuse, self.right_query_conv(feat_fuse))
+
+         left = self.left_mano_regressor(l0_xyz, left_hand_features, mano_hands['left'])
+         right = self.right_mano_regressor(l0_xyz, right_hand_features, mano_hands['right'])
+
+         return {'class_logits': seg_out, 'left': left, 'right': right}
+
+
+ def main():
+
+     net = TEHNet(n_pose_params=6)
+     points = torch.rand(4, 4, 128)
+     net(points)
+
+
+ if __name__ == '__main__':
+     main()
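Note that TEHNet.forward takes both the event point cloud and the MANO layers, while the main() smoke test above calls net(points) with a single argument, so as written it would raise a TypeError; also, with ERPC=1 (which demo.py sets before importing the model) the first set-abstraction layer expects five input channels (x, y, t plus the positive/negative event counts) rather than four. A hedged sketch of a call that matches the forward signature, going through the wrapper so the MANO hands are supplied automatically:

import os
os.environ['ERPC'] = '1'          # demo.py sets this before importing the model

import torch
from model import TEHNetWrapper   # requires the MANO model files under settings.MANO_PATH

device = torch.device('cpu')
net = TEHNetWrapper(device=device)

# (batch, channels, points): x, y, t plus positive/negative event counts, as built by process_events
events = torch.rand(2, 5, 2048)
outputs = net(events)             # dict with 'class_logits', 'left', 'right'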
model/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import TEHNetWrapper
model/model.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ import torch.nn as nn
+ import trimesh
+ import numpy as np
+ from .TEHNet import TEHNet
+ from .utils import create_mano_layers
+ from settings import MANO_PATH, MANO_CMPS
+
+
+ class TEHNetWrapper():
+     def state_dict(self):
+         return self.net.state_dict()
+
+     def load_state_dict(self, params, *args, **kwargs):
+         modified_params = dict()
+
+         for k, v in params.items():
+             if k.startswith('module.'):
+                 k = k[len('module.'):]
+
+             modified_params[k] = v
+
+         self.net.load_state_dict(modified_params, *args, **kwargs)
+
+     def parameters(self):
+         return self.net.parameters()
+
+     def train(self):
+         self.training = True
+         return self.net.train()
+
+     def eval(self):
+         self.training = False
+         return self.net.eval()
+
+     def P3dtoP2d(self, j3d, scale, translation):
+         B, N = j3d.shape[:2]
+
+         homogeneous_j3d = torch.cat([j3d, torch.ones(B, N, 1, device=j3d.device)], 2)
+         homogeneous_j3d = homogeneous_j3d @ self.rot.detach()
+
+         translation = translation.unsqueeze(1)
+         scale = scale.unsqueeze(1)
+
+         j2d = torch.zeros(B, N, 2, device=j3d.device)
+         j2d[:, :, 0] = translation[:, :, 0] + scale[:, :, 0] * homogeneous_j3d[:, :, 0]
+         j2d[:, :, 1] = translation[:, :, 1] + scale[:, :, 1] * homogeneous_j3d[:, :, 1]
+
+         return j2d
+
+     def __init__(self, device):
+         net = TEHNet(n_pose_params=MANO_CMPS).to(device)
+
+         self.net = net
+         self.training = False
+
+         self.hands = create_mano_layers(MANO_PATH, device, MANO_CMPS)
+
+         self.rot = torch.tensor(trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0]), device=device).float()
+
+     def __call__(self, inp):
+         outputs = self.net(inp, self.hands)
+
+         return outputs
model/pointnet2_utils.py ADDED
@@ -0,0 +1,315 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from time import time
+ import numpy as np
+
+ def timeit(tag, t):
+     print("{}: {}s".format(tag, time() - t))
+     return time()
+
+ def pc_normalize(pc):
+     l = pc.shape[0]
+     centroid = np.mean(pc, axis=0)
+     pc = pc - centroid
+     m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
+     pc = pc / m
+     return pc
+
+ def square_distance(src, dst):
+     """
+     Calculate Euclid distance between each two points.
+
+     src^T * dst = xn * xm + yn * ym + zn * zm;
+     sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
+     sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
+     dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
+          = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
+
+     Input:
+         src: source points, [B, N, C]
+         dst: target points, [B, M, C]
+     Output:
+         dist: per-point square distance, [B, N, M]
+     """
+     B, N, _ = src.shape
+     _, M, _ = dst.shape
+     dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
+     dist += torch.sum(src ** 2, -1).view(B, N, 1)
+     dist += torch.sum(dst ** 2, -1).view(B, 1, M)
+     return dist
+
+
+ def index_points(points, idx):
+     """
+     Input:
+         points: input points data, [B, N, C]
+         idx: sample index data, [B, S]
+     Return:
+         new_points: indexed points data, [B, S, C]
+     """
+     device = points.device
+     B = points.shape[0]
+     view_shape = list(idx.shape)
+     view_shape[1:] = [1] * (len(view_shape) - 1)
+     repeat_shape = list(idx.shape)
+     repeat_shape[0] = 1
+     batch_indices = torch.arange(B, dtype=torch.long).view(view_shape).repeat(repeat_shape)
+     new_points = points[batch_indices, idx, :]
+     return new_points
+
+
+ def farthest_point_sample(xyz, npoint):
+     """
+     Input:
+         xyz: pointcloud data, [B, N, 3]
+         npoint: number of samples
+     Return:
+         centroids: sampled pointcloud index, [B, npoint]
+     """
+     device = xyz.device
+     B, N, C = xyz.shape
+     centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
+     distance = torch.ones(B, N).to(device) * 1e10
+     farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
+     batch_indices = torch.arange(B, dtype=torch.long).to(device)
+     for i in range(npoint):
+         centroids[:, i] = farthest
+         centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
+         dist = torch.sum((xyz - centroid) ** 2, -1)
+         mask = dist < distance
+         distance[mask] = dist[mask]
+         farthest = torch.max(distance, -1)[1]
+     return centroids
+
+
+ def query_ball_point(radius, nsample, xyz, new_xyz):
+     """
+     Input:
+         radius: local region radius
+         nsample: max sample number in local region
+         xyz: all points, [B, N, 3]
+         new_xyz: query points, [B, S, 3]
+     Return:
+         group_idx: grouped points index, [B, S, nsample]
+     """
+     device = xyz.device
+     B, N, C = xyz.shape
+     _, S, _ = new_xyz.shape
+     group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
+     sqrdists = square_distance(new_xyz, xyz)
+     group_idx[sqrdists > radius ** 2] = N
+     group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
+     group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
+     mask = group_idx == N
+     group_idx[mask] = group_first[mask]
+     return group_idx
+
+
+ def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
+     """
+     Input:
+         npoint:
+         radius:
+         nsample:
+         xyz: input points position data, [B, N, 3]
+         points: input points data, [B, N, D]
+     Return:
+         new_xyz: sampled points position data, [B, npoint, nsample, 3]
+         new_points: sampled points data, [B, npoint, nsample, 3+D]
+     """
+     B, N, C = xyz.shape
+     S = npoint
+     fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint, C]
+     new_xyz = index_points(xyz, fps_idx)
+     idx = query_ball_point(radius, nsample, xyz, new_xyz)
+     grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
+     grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
+
+     if points is not None:
+         grouped_points = index_points(points, idx)
+         new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
+     else:
+         new_points = grouped_xyz_norm
+     if returnfps:
+         return new_xyz, new_points, grouped_xyz, fps_idx
+     else:
+         return new_xyz, new_points
+
+
+ def sample_and_group_all(xyz, points):
+     """
+     Input:
+         xyz: input points position data, [B, N, 3]
+         points: input points data, [B, N, D]
+     Return:
+         new_xyz: sampled points position data, [B, 1, 3]
+         new_points: sampled points data, [B, 1, N, 3+D]
+     """
+     device = xyz.device
+     B, N, C = xyz.shape
+     new_xyz = torch.zeros(B, 1, C).to(device)
+     grouped_xyz = xyz.view(B, 1, N, C)
+     if points is not None:
+         new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
+     else:
+         new_points = grouped_xyz
+     return new_xyz, new_points
+
+
+ class PointNetSetAbstraction(nn.Module):
+     def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
+         super(PointNetSetAbstraction, self).__init__()
+         self.npoint = npoint
+         self.radius = radius
+         self.nsample = nsample
+         self.mlp_convs = nn.ModuleList()
+         self.mlp_bns = nn.ModuleList()
+         last_channel = in_channel
+         for out_channel in mlp:
+             self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
+             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
+             last_channel = out_channel
+         self.group_all = group_all
+
+     def forward(self, xyz, points):
+         """
+         Input:
+             xyz: input points position data, [B, C, N]
+             points: input points data, [B, D, N]
+         Return:
+             new_xyz: sampled points position data, [B, C, S]
+             new_points_concat: sample points feature data, [B, D', S]
+         """
+         xyz = xyz.permute(0, 2, 1).contiguous()
+         if points is not None:
+             points = points.permute(0, 2, 1).contiguous()
+
+         if self.group_all:
+             new_xyz, new_points = sample_and_group_all(xyz, points)
+         else:
+             new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
+         # new_xyz: sampled points position data, [B, npoint, C]
+         # new_points: sampled points data, [B, npoint, nsample, C+D]
+         new_points = new_points.permute(0, 3, 2, 1).contiguous()  # [B, C+D, nsample, npoint]
+         for i, conv in enumerate(self.mlp_convs):
+             bn = self.mlp_bns[i]
+             new_points = F.relu(bn(conv(new_points)))
+
+         new_points = torch.max(new_points, 2)[0]
+         new_xyz = new_xyz.permute(0, 2, 1).contiguous()
+         return new_xyz, new_points
+
+
+ class PointNetSetAbstractionMsg(nn.Module):
+     def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
+         super(PointNetSetAbstractionMsg, self).__init__()
+         self.npoint = npoint
+         self.radius_list = radius_list
+         self.nsample_list = nsample_list
+         self.conv_blocks = nn.ModuleList()
+         self.bn_blocks = nn.ModuleList()
+         for i in range(len(mlp_list)):
+             convs = nn.ModuleList()
+             bns = nn.ModuleList()
+             last_channel = in_channel + 3
+             for out_channel in mlp_list[i]:
+                 convs.append(nn.Conv2d(last_channel, out_channel, 1))
+                 bns.append(nn.BatchNorm2d(out_channel))
+                 last_channel = out_channel
+             self.conv_blocks.append(convs)
+             self.bn_blocks.append(bns)
+
+     def forward(self, xyz, points):
+         """
+         Input:
+             xyz: input points position data, [B, C, N]
+             points: input points data, [B, D, N]
+         Return:
+             new_xyz: sampled points position data, [B, C, S]
+             new_points_concat: sample points feature data, [B, D', S]
+         """
+         xyz = xyz.permute(0, 2, 1).contiguous()
+         if points is not None:
+             points = points.permute(0, 2, 1).contiguous()
+
+         B, N, C = xyz.shape
+         S = self.npoint
+         new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
+         new_points_list = []
+         for i, radius in enumerate(self.radius_list):
+             K = self.nsample_list[i]
+             group_idx = query_ball_point(radius, K, xyz, new_xyz)
+             grouped_xyz = index_points(xyz, group_idx)
+             grouped_xyz -= new_xyz.view(B, S, 1, C)
+             if points is not None:
+                 grouped_points = index_points(points, group_idx)
+                 grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
+             else:
+                 grouped_points = grouped_xyz
+
+             grouped_points = grouped_points.permute(0, 3, 2, 1).contiguous()  # [B, D, K, S]
+             for j in range(len(self.conv_blocks[i])):
+                 conv = self.conv_blocks[i][j]
+                 bn = self.bn_blocks[i][j]
+                 grouped_points = F.relu(bn(conv(grouped_points)))
+             new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
+             new_points_list.append(new_points)
+
+         new_xyz = new_xyz.permute(0, 2, 1).contiguous()
+         new_points_concat = torch.cat(new_points_list, dim=1)
+         return new_xyz, new_points_concat
+
+
+ class PointNetFeaturePropagation(nn.Module):
+     def __init__(self, in_channel, mlp):
+         super(PointNetFeaturePropagation, self).__init__()
+         self.mlp_convs = nn.ModuleList()
+         self.mlp_bns = nn.ModuleList()
+         last_channel = in_channel
+         for out_channel in mlp:
+             self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
+             self.mlp_bns.append(nn.BatchNorm1d(out_channel))
+             last_channel = out_channel
+
+     def forward(self, xyz1, xyz2, points1, points2):
+         """
+         Input:
+             xyz1: input points position data, [B, C, N]
+             xyz2: sampled input points position data, [B, C, S]
+             points1: input points data, [B, D, N]
+             points2: input points data, [B, D, S]
+         Return:
+             new_points: upsampled points data, [B, D', N]
+         """
+         xyz1 = xyz1.permute(0, 2, 1).contiguous()
+         xyz2 = xyz2.permute(0, 2, 1).contiguous()
+
+         points2 = points2.permute(0, 2, 1).contiguous()
+         B, N, C = xyz1.shape
+         _, S, _ = xyz2.shape
+
+         if S == 1:
+             interpolated_points = points2.repeat(1, N, 1)
+         else:
+             dists = square_distance(xyz1, xyz2)
+             dists, idx = dists.sort(dim=-1)
+             dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]
+
+             dist_recip = 1.0 / (dists + 1e-8)
+             norm = torch.sum(dist_recip, dim=2, keepdim=True)
+             weight = dist_recip / norm
+             interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
+
+         if points1 is not None:
+             points1 = points1.permute(0, 2, 1).contiguous()
+             new_points = torch.cat([points1, interpolated_points], dim=-1)
+         else:
+             new_points = interpolated_points
+
+         new_points = new_points.permute(0, 2, 1).contiguous()
+         for i, conv in enumerate(self.mlp_convs):
+             bn = self.mlp_bns[i]
+             new_points = F.relu(bn(conv(new_points)))
+         return new_points
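These are the standard PointNet++ building blocks (farthest point sampling, ball query, multi-scale set abstraction, feature propagation) operating on channel-first tensors. A small shape-check sketch of the two layer types TEHNet composes, with made-up sizes:

import torch
from model.pointnet2_utils import PointNetSetAbstractionMsg, PointNetFeaturePropagation

xyz = torch.rand(2, 3, 512)     # (B, 3, N) point coordinates
points = torch.rand(2, 5, 512)  # (B, D, N) per-point features

# multi-scale grouping: downsample 512 points to 128 seed points at two radii
sa = PointNetSetAbstractionMsg(128, [0.4, 0.8], [64, 128], 5, [[32, 32, 64], [64, 64, 128]])
new_xyz, new_points = sa(xyz, points)            # (2, 3, 128), (2, 64 + 128, 128)

# feature propagation interpolates coarse features back onto the full point set
fp = PointNetFeaturePropagation(in_channel=64 + 128 + 5, mlp=[128, 128])
upsampled = fp(xyz, new_xyz, points, new_points)  # (2, 128, 512)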
model/utils.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.parallel
+ import torch.utils.data
+ from torch.autograd import Variable
+ import numpy as np
+ import torch.nn.functional as F
+
+ from manopth.manolayer import ManoLayer
+
+
+
+ def create_mano_layers(mano_path, device, n_cmps):
+     class Output:
+         def __init__(self, vertices, joints):
+             self.vertices = vertices
+             self.joints = joints
+
+     class SmplxAdapter:
+         def __init__(self, side):
+             self.m = ManoLayer(mano_root=f'{mano_path}/mano', use_pca=True, ncomps=n_cmps, side=side, flat_hand_mean=False, robust_rot=True).to(device)
+             self.faces = self.m.th_faces.cpu().numpy()
+             self.shapedirs = self.m.th_shapedirs
+
+         def __call__(self, global_orient, hand_pose, betas, transl):
+             vertices, joints = self.m(torch.cat([global_orient, hand_pose], 1), betas, transl)
+
+             vertices /= 1000
+             joints /= 1000
+
+             return Output(vertices, joints)
+
+     mano_layer = {
+         'left': SmplxAdapter(side='left'),
+         'right': SmplxAdapter(side='right')
+     }
+
+     if torch.sum(torch.abs(mano_layer['left'].m.th_shapedirs[:, 0, :] - mano_layer['right'].m.th_shapedirs[:, 0, :])) < 1:
+         print('Fix th_shapedirs bug of MANO')
+         mano_layer['left'].m.th_shapedirs[:, 0, :] *= -1
+
+     return mano_layer
record.py ADDED
@@ -0,0 +1,20 @@
+ import cv2
+
+ camera = cv2.VideoCapture(0)
+ fps = cv2.CAP_PROP_FPS
+
+ video = cv2.VideoWriter('video.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (640, 480))
+
+ while True:
+     _, frame = camera.read()
+
+     video.write(frame)
+
+     cv2.imshow("Frame", frame)
+     c = cv2.waitKey(1)
+
+     if c == ord('q'):
+         break
+
+ video.release()
+ camera.release()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ opencv-python
+ git+https://github.com/hassony2/manopth
+ pyrender
+ git+https://github.com/mattloper/chumpy.git
+ gradio
settings.py ADDED
@@ -0,0 +1,45 @@
+ import os
+ if os.name != 'nt': os.environ["PYOPENGL_PLATFORM"] = "egl"
+
+ import pyrender
+ import numpy as np
+ import os
+ import platform
+
+
+ ESIM_REFRACTORY_PERIOD_NS = 0
+ ESIM_POSITIVE_THRESHOLD = 0.4
+ ESIM_NEGATIVE_THRESHOLD = 0.4
+
+ RENDER_SMPLX = False
+
+ AUGMENTATED_SEQUENCE = True
+ NUMBER_OF_AUGMENTATED_SEQUENCES = 10
+
+ SIMULATOR_FPS = 1000  # fps for event generation using ESIM
+ INTERPOLATION_FPS = 30
+ OUTPUT_WIDTH = 346
+ OUTPUT_HEIGHT = 260
+ LNES_WINDOW_MS = 5
+
+
+ INTERHAND_ROOT_PATH = '/CT/datasets01/static00/InterHand2.6m/InterHand2.6M_5fps_batch1'
+ ROOT_TRAIN_DATA_PATH = '/CT/datasets07/nobackup/Ev2Hands/Ev2Hands-S'
+
+ REAL_TRAIN_DATA_PATH = '/CT/datasets07/nobackup/Ev2Hands/Ev2Hands-R/train_data'
+ REAL_TEST_DATA_PATH = '/CT/datasets07/nobackup/Ev2Hands/Ev2Hands-R/test_data'
+
+ DATA_PATH = '../data'
+ MANO_PATH = 'data/models'
+
+
+ GENERATION_MODE = str(os.getenv('GENERATION_MODE', 'train'))
+
+ MANO_CMPS = 6
+
+ SEGMENTAION_COLOR = {'left': [0, 1, 0], 'right': [0, 0, 1]}
+
+ MAIN_CAMERA = pyrender.PerspectiveCamera(yfov=np.deg2rad(30), aspectRatio=OUTPUT_WIDTH / OUTPUT_HEIGHT)
+ PROJECTION_MATRIX = MAIN_CAMERA.get_projection_matrix(OUTPUT_WIDTH, OUTPUT_HEIGHT)
+
+ HAND_COLOR = [198/255, 134/255, 66/255]
test.py ADDED
@@ -0,0 +1,93 @@
+ import numpy as np
+ import cv2
+ import esim_py
+
+
+ # camera = cv2.VideoCapture(0)
+ camera = cv2.VideoCapture('video.mp4')
+
+ POS_THRESHOLD = 0.5
+ NEG_THRESHOLD = 0.5
+ REF_PERIOD = 0.000
+
+ esim = esim_py.EventSimulator(POS_THRESHOLD, NEG_THRESHOLD, REF_PERIOD, 1e-4, True)
+
+ # # generate events from list of images and timestamps
+ # events_list_of_images = esim.generateFromStampedImageSequence(
+ #     list_of_image_files,  # list of absolute paths to images
+ #     list_of_timestamps    # list of timestamps in ascending order
+ # )
+
+ fps = cv2.CAP_PROP_FPS
+ ts_s = 1 / fps
+ ts_ns = ts_s * 1e9  # convert s to ns
+
+ is_init = False
+ idx = 0
+ while True:
+     _, frame_bgr = camera.read()
+     frame_gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
+     frame_log = np.log(frame_gray.astype("float32") / 255 + 1e-4)
+     height, width = frame_log.shape[:2]
+
+     current_ts_ns = idx * ts_ns
+
+     if not is_init:
+         esim.init(frame_log, current_ts_ns)
+         is_init = True
+         idx += 1
+
+         continue
+
+     events = esim.generateEventFromCVImage(frame_log, current_ts_ns)
+     x, y, t, p = events.T
+     t = t.astype(dtype=np.float32) * 1e-6  # convert ns to milliseconds
+
+     last_timestamp = t[-1]
+
+     event_frame = np.zeros((height, width, 3), dtype=np.float32)
+
+     x = x.astype(dtype=np.int32)
+     y = y.astype(dtype=np.int32)
+     p = p.astype(dtype=np.int32)
+
+     print(idx, events.shape)
+
+     if last_timestamp <= 0:
+         continue
+
+     event_frame[y, x, 1 - p] = (last_timestamp - t) / (last_timestamp - t[0])
+
+     event_frame *= 255
+     event_frame = event_frame.astype(dtype=np.uint8)
+
+     stack = np.hstack([frame_bgr, event_frame])
+     cv2.imwrite(f"outputs/stack_{idx}.png", stack)
+
+     # cv2.imwrite("frame.png", frame_bgr)
+
+     # input(idx)
+     # # t, x, y, p = event
+     # # x, y = x.astype(dtype=np.int32), y.astype(dtype=np.int32)
+
+     # # events = np.hstack([x[..., None], y[..., None], t[..., None], p[..., None]])
+     # # event_labels = segmentation[y, x].astype(dtype=np.uint8)
+
+     # # write_frame = False
+     # # show_frame = False
+
+     # # if write_frame or show_frame:
+     # #     ts, xs, ys, ps = event
+     # #     h, w = frame_color.shape[:2]
+     # #     event_bgr = np.zeros((h, w, 3), dtype=np.uint8)
+     # #     for x, y, p in zip(xs, ys, ps):
+     # #         event_bgr[y, x, 0 if p == -1 else 2] = 255
+
+     # #     image_path = image_paths[frame_keys[frame_index]]
+     # #     rgb_image = cv2.imread(image_path)
+
+
+     # cv2.imshow("Frame", frame)
+     # cv2.waitKey(1)
+
+     idx += 1
vis.py ADDED
@@ -0,0 +1,13 @@
+ import cv2
+ import os
+
+ files = os.listdir("outputs")
+ n_files = len(files)
+
+ for i in range(1, n_files):
+     frame = cv2.imread(f"outputs/stack_{i}.png")
+     cv2.imshow("Frame", frame)
+     c = cv2.waitKey(1)
+
+     if c == ord('q'):
+         break