README.md CHANGED
@@ -1,12 +1,12 @@
 ---
-title: Generalizable-HOI-Denoising
-emoji: 🐠
+title: Test
+emoji: 💻
 colorFrom: indigo
 colorTo: green
 sdk: gradio
-sdk_version: 4.36.0
+sdk_version: 4.17.0
 app_file: app.py
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -12,8 +12,6 @@ import shutil
 # from gradio_inter.predict_from_file import predict_from_file
 from gradio_inter.create_bash_file import create_bash_file
 
-from sample.reconstruct_data_taco import reconstruct_from_file
-
 def create_temp_file(path: str) -> str:
     temp_dir = tempfile.gettempdir()
     temp_folder = os.path.join(temp_dir, "denoising")
@@ -43,98 +41,16 @@ def predict(file_path: str):
 
     res_file_path = "/tmp/denoising/save/predicted_infos_seed_0_tag_20231104_017_jts_spatial_t_100__st_0.npy"
 
-    saved_path = reconstruct_from_file(temp_file_path)
-
-    return saved_path
-
-def create_demo():
-
-    USAGE = """# GeneOH Diffusion: Towards Generalizable Hand-Object Interaction Denoising via Denoising Diffusion
-**[Project](https://meowuu7.github.io/GeneOH-Diffusion/) | [Paper](https://openreview.net/pdf?id=FvK2noilxT) | [Github](https://github.com/Meowuu7/GeneOH-Diffusion)**
-## Input data format
-Currently, the demo accepts a `.pkl` file containing an hand-object sequence organized as the following format:
-```python
-{
-    "hand_pose": numpy.ndarray(seq_length, 48), # MANO pose at each frame
-    "hand_trans": numpy.ndarray(seq_length, 3), # hand global translation at each frmae
-    "hand_shape": numpy.ndarray(10), # MANO shape coefficients
-    "hand_verts": numpy.ndarray(seq_length, 778, 3), # MANO hand vertices
-    "hand_faces": numpy.ndarray(1538, 3), # MANO hand faces
-    "obj_verts": numpy.ndarray(seq_length, num_obj_verts, 3), # object vertices at each frame
-    "obj_faces": numpy.ndarray(num_obj_faces, 3), # object faces
-    "obj_pose": numpy.ndarray(seq_length, 4, 4), # object pose at each frame
-}
-```
-We provide an example [here](https://drive.google.com/file/d/17oqKMhQNpRqSdApyuuCmTrPkrFl0Cqp6/view?usp=sharing). **The demo is under developing and will support more data formats in the future.**
-
-
-## To run the demo,
-1. Upload a `pickle` file to the left box by draging your file or clicking the box to open the file explorer.
-2. Clik the `Submit` button to run the demo.
-3. The denoised sequence will be output as a `.npy` file and can be downloaded from the right box.
-
-Since the model runs on CPU currently, the speed is not very fast. For instance, it takes abount 1200s to process the [example](https://drive.google.com/file/d/17oqKMhQNpRqSdApyuuCmTrPkrFl0Cqp6/view?usp=sharing) mentioned above which contains 288 frames. Please be patient and wait for the result.
-
-To run the model faster, please visit our [github repo](https://github.com/Meowuu7/GeneOH-Diffusion), follow the instructions and run the model on your own server or local machine.
-
-## Output data format
-The output is a `.npy` file containing the denoised sequence organized as the following format:
-```python
-{
-    "predicted_info": {
-        "targets": numpy.ndarray(seq_length, num_mano_joints, 3), # input MANO joints
-        "outputs": numpy.ndarray(seq_length, num_mano_joints, 3), # denoised MANO joints
-        "obj_verts": numpy.ndarray(seq_length, num_obj_verts, 3), # object vertices at each frame
-        "obj_faces": numpy.ndarray(num_obj_faces, 3), # object faces
-        ... # others
-    }
-    "bf_ct_verts": numpy.ndarray(seq_length, 778, 3), # denoised MANO vertices
-    "bf_ct_rot_var": numpy.ndarray(seq_length, 3), # denoised MANO global rotation coefficients
-    "bf_ct_theta_var": numpy.ndarray(seq_length, 45), # denoised MANO global pose coefficients
-    "bf_ct_beta_var": numpy.ndarray(1, 10), # denoised MANO shape coefficients
-    "bf_ct_transl_var": numpy.ndarray(seq_length, 3), # denoised hand global translation
-}
-```
-The corresponding output file of the [example](https://drive.google.com/file/d/17oqKMhQNpRqSdApyuuCmTrPkrFl0Cqp6/view?usp=sharing) mentioned above can be downloaded [here](https://drive.google.com/file/d/1Ah-qwV6LXlOyaBBe0qQRu1lN-BpKt2Y3/view?usp=sharing).
-"""
-
-
-    with gr.Blocks() as demo:
-
-        gr.Markdown(USAGE)
-
-        # # demo =
-        # gr.Interface(
-        #     predict,
-        #     # gr.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
-        #     gr.File(type="filepath"),
-        #     gr.File(type="filepath"),
-        #     cache_examples=False
-        # )
-
-        input_file = gr.File(type="filepath")
-        output_file = gr.File(type="filepath")
-
-        gr.Interface(
-            predict,
-            # gr.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
-            input_file,
-            output_file,
-            cache_examples=False
-        )
-
-        inputs = input_file
-        outputs = output_file
-        gr.Examples(
-            examples=[os.path.join(os.path.dirname(__file__), "./gradio_inter/20231104_017.pkl"), os.path.join(os.path.dirname(__file__), "./gradio_inter/20231104_010.pkl")],
-            inputs=inputs,
-            fn=predict,
-            outputs=outputs,
-        )
-
-    return demo
-
+    return res_file_path
+
+
+demo = gr.Interface(
+    predict,
+    # gr.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
+    gr.File(type="filepath"),
+    gr.File(type="filepath"),
+    cache_examples=False
+)
+
 if __name__ == "__main__":
-    demo = create_demo()
     demo.launch()
 
data_loaders/humanml/data/dataset_ours_single_seq.py CHANGED
@@ -7252,7 +7252,7 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
             flat_hand_mean=True,
             side='left',
             # mano_root=self.mano_path, # mano_root #
-            mano_root=self.mano_path,
+            mano_root="/home/hlyang/HOI/HOI/manopth/mano/models",
             ncomps=45,
             use_pca=False,
             # center_idx=0
 
gradio_inter/20231104_010.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:2224d2863679dee3f24820538be025cf4c11bb9117b99797f83ca267741a9642
-size 2978083
gradio_inter/20231104_017.pkl DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6c432a154c16c0b135b73162e4d822516fbe3c1d36933f1a54b791cbfa9365ed
-size 3215727
gradio_inter/predict_from_file.py CHANGED
@@ -199,7 +199,7 @@ def main():
     st_idxes = list(range(0, num_ending_clearning_frames, nn_st_skip))
     if st_idxes[-1] + num_cleaning_frames < nn_frames:
         st_idxes.append(nn_frames - num_cleaning_frames)
-    # st_idxes = [st_idxes[0]]
+    st_idxes = [st_idxes[0]]
     print(f"st_idxes: {st_idxes}")
 
 
 
sample/reconstruct_data_taco.py CHANGED
@@ -51,15 +51,15 @@ def get_penetration_masks(obj_verts, obj_faces, hand_verts):
 def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False, nn_hand_params=24, rt_vars=False, with_proj=False, obj_verts_trans=None, obj_faces=None, with_params_smoothing=False, dist_thres=0.005, with_ctx_mask=False):
     # obj_verts_trans, obj_faces
     joints = torch.from_numpy(joints).float() # # joints
-    # base_pts = torch.from_numpy(base_pts).float() # # base_pts
+    base_pts = torch.from_numpy(base_pts).float() # # base_pts
 
     if nn_hand_params < 45:
         use_pca = True
     else:
         use_pca = False
 
-    # tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float()
-    # tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float()
+    tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float()
+    tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float()
     ### start optimization ###
     # setup MANO layer
     # mano_path = "/data1/xueyi/mano_models/mano/models"
@@ -139,48 +139,48 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
     # )
 
     #
-    # dist_joints_to_base_pts = torch.sum(
-    #     (joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
-    # )
+    dist_joints_to_base_pts = torch.sum(
+        (joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
+    )
 
-    # nn_base_pts = dist_joints_to_base_pts.size(-1)
-    # nn_joints = dist_joints_to_base_pts.size(1)
+    nn_base_pts = dist_joints_to_base_pts.size(-1)
+    nn_joints = dist_joints_to_base_pts.size(1)
 
-    # dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
-    # minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
+    dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
+    minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
 
-    # nk_contact_pts = 2
-    # minn_dist[:, :-5] = 1e9
-    # minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
-    # # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0) ==
-    # minn_topk_mask = torch.zeros_like(minn_dist)
-    # # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
-    # minn_topk_mask[:, -5: -3] = 1.
-    # basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0)
-    # minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
-    # # for seq 101
-    # # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
-    # minn_dist_mask = minn_dist_mask.float()
+    nk_contact_pts = 2
+    minn_dist[:, :-5] = 1e9
+    minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
+    # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0) ==
+    minn_topk_mask = torch.zeros_like(minn_dist)
+    # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
+    minn_topk_mask[:, -5: -3] = 1.
+    basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0)
+    minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
+    # for seq 101
+    # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
+    minn_dist_mask = minn_dist_mask.float()
 
-    # ## tot base pts
-    # tot_base_pts_trans_disp = torch.sum(
-    #     (tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1 # (nf - 1) x nn_base_pts displacement
-    # )
-    # ### tot base pts trans disp ###
-    # tot_base_pts_trans_disp = torch.sqrt(tot_base_pts_trans_disp).mean(dim=-1) # (nf - 1)
-    # # tot_base_pts_trans_disp_mov_thres = 1e-20
-    # tot_base_pts_trans_disp_mov_thres = 3e-4
-    # tot_base_pts_trans_disp_mask = tot_base_pts_trans_disp >= tot_base_pts_trans_disp_mov_thres
-    # tot_base_pts_trans_disp_mask = torch.cat(
-    #     [tot_base_pts_trans_disp_mask, tot_base_pts_trans_disp_mask[-1:]], dim=0
-    # )
+    ## tot base pts
+    tot_base_pts_trans_disp = torch.sum(
+        (tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1 # (nf - 1) x nn_base_pts displacement
+    )
+    ### tot base pts trans disp ###
+    tot_base_pts_trans_disp = torch.sqrt(tot_base_pts_trans_disp).mean(dim=-1) # (nf - 1)
+    # tot_base_pts_trans_disp_mov_thres = 1e-20
+    tot_base_pts_trans_disp_mov_thres = 3e-4
+    tot_base_pts_trans_disp_mask = tot_base_pts_trans_disp >= tot_base_pts_trans_disp_mov_thres
+    tot_base_pts_trans_disp_mask = torch.cat(
+        [tot_base_pts_trans_disp_mask, tot_base_pts_trans_disp_mask[-1:]], dim=0
+    )
 
-    # attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
+    attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
 
 
 
-    # minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
-    # print(f"minn_dist_mask: {minn_dist_mask.size()}")
+    minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
+    print(f"minn_dist_mask: {minn_dist_mask.size()}")
     s = 1.0
     # affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
 
@@ -231,21 +231,21 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
     print('\tRotation Smoothness Loss: {}'.format(joints_pred_loss.item()))
 
     #
-    # print(tot_base_pts_trans.size())
-    # diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
-    # print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
-    # diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
-    # diff_base_pts_trans_threshold = 1e-20
-    # diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the tranformed base pts
-    # diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
-    # print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
-    # diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
-    # diff_base_pts_trans_mask = torch.cat(
-    #     [diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor
-    # )
+    print(tot_base_pts_trans.size())
+    diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
+    print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
+    diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
+    diff_base_pts_trans_threshold = 1e-20
+    diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the tranformed base pts
+    diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
+    print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
+    diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
+    diff_base_pts_trans_mask = torch.cat(
+        [diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor
+    )
     # attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5
-    # attraction_mask = minn_topk_mask.float()
-    # attraction_mask = attraction_mask.float()
+    attraction_mask = minn_topk_mask.float()
+    attraction_mask = attraction_mask.float()
 
     # the direction of the normal vector and the moving direction of the object point -> whether the point should be selected
     # the contact maps of the object should be like? #
@@ -1121,7 +1121,7 @@ def reconstruct_from_file(single_seq_path):
     if st_idxes[-1] + num_cleaning_frames < nn_frames:
         st_idxes.append(nn_frames - num_cleaning_frames)
 
-    # st_idxes = [st_idxes[0]]
+    st_idxes = [st_idxes[0]]
     print(f"st_idxes: {st_idxes}")
 
 
@@ -1184,12 +1184,7 @@ def reconstruct_from_file(single_seq_path):
                 tot_data[cur_k].append(cur_data[cur_k][ :clip_ending_idxes[i_tag]])
 
     for cur_k in tot_data:
-        print(f"cur_k: {cur_k}")
-        for aa in tot_data[cur_k]:
-            print(aa.shape)
-        if cur_k in ['tot_base_pts', 'tot_base_normals']:
-            continue
-        elif cur_k in ["tot_base_pts", "tot_base_normals", "tot_obj_rot", "tot_obj_transl", "tot_obj_pcs", "tot_rhand_joints", "tot_gt_rhand_joints"]:
+        if cur_k in ["tot_base_pts", "tot_base_normals", "tot_obj_rot", "tot_obj_transl", "tot_obj_pcs", "tot_rhand_joints", "tot_gt_rhand_joints"]:
             tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=1)
         else:
             tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=0)
@@ -1207,8 +1202,8 @@ def reconstruct_from_file(single_seq_path):
 
     targets = data['targets'] # # targets # #
     outputs = data['outputs'] #
-    # tot_base_pts = data["tot_base_pts"][0] # total base pts, total base normals #
-    # tot_base_normals = data['tot_base_normals'][0] # nn_base_normals #
+    tot_base_pts = data["tot_base_pts"][0] # total base pts, total base normals #
+    tot_base_normals = data['tot_base_normals'][0] # nn_base_normals #
 
 
 
@@ -1218,22 +1213,22 @@ def reconstruct_from_file(single_seq_path):
     tot_obj_transl = data['tot_obj_transl'][0]
     print(f"tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
 
-    # if len(tot_base_pts.shape) == 2:
-    #     # numpy array # # tot base pts #
-    #     tot_base_pts_trans = np.matmul(tot_base_pts.reshape(1, tot_base_pts.shape[0], 3), tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
-    #     tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot[0]) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])[0]
+    if len(tot_base_pts.shape) == 2:
+        # numpy array # # tot base pts #
+        tot_base_pts_trans = np.matmul(tot_base_pts.reshape(1, tot_base_pts.shape[0], 3), tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
+        tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot[0]) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])[0]
 
-    #     tot_base_normals_trans = np.matmul( # #
-    #         tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot
-    #     )
-    # else:
-    #     print(f"tot_base_pts: {tot_base_pts.shape}, tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
-    #     tot_base_pts_trans = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
-    #     tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
+        tot_base_normals_trans = np.matmul( # #
+            tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot
+        )
+    else:
+        print(f"tot_base_pts: {tot_base_pts.shape}, tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
+        tot_base_pts_trans = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
+        tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
 
-    #     tot_base_normals_trans = np.matmul(
-    #         tot_base_normals, tot_obj_rot
-    #     )
+        tot_base_normals_trans = np.matmul(
+            tot_base_normals, tot_obj_rot
+        )
 
 
 
@@ -1242,7 +1237,7 @@ def reconstruct_from_file(single_seq_path):
 
     targets = np.matmul(targets, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 #
     # denoise relative positions
-    # print(f"tot_base_pts: {tot_base_pts.shape}")
+    print(f"tot_base_pts: {tot_base_pts.shape}")
 
 
     #### obj_verts_trans, obj_faces ####
@@ -1269,7 +1264,7 @@ def reconstruct_from_file(single_seq_path):
     with_contact_opt = True
     with_ctx_mask = False
 
-    bf_ct_optimized_dict, bf_proj_optimized_dict, optimized_dict = get_optimized_hand_fr_joints_v4_anchors(outputs, None, None, None, with_contact_opt=with_contact_opt, nn_hand_params=nn_hand_params, rt_vars=True, with_proj=with_proj, obj_verts_trans=obj_verts_trans, obj_faces=obj_faces, with_params_smoothing=with_params_smoothing, dist_thres=dist_thres, with_ctx_mask=with_ctx_mask)
+    bf_ct_optimized_dict, bf_proj_optimized_dict, optimized_dict = get_optimized_hand_fr_joints_v4_anchors(outputs, tot_base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=with_contact_opt, nn_hand_params=nn_hand_params, rt_vars=True, with_proj=with_proj, obj_verts_trans=obj_verts_trans, obj_faces=obj_faces, with_params_smoothing=with_params_smoothing, dist_thres=dist_thres, with_ctx_mask=with_ctx_mask)
 
 
 
@@ -1279,12 +1274,12 @@ def reconstruct_from_file(single_seq_path):
     optimized_sv_infos.update(bf_ct_optimized_dict)
     optimized_sv_infos.update(bf_proj_optimized_dict)
     optimized_sv_infos.update(optimized_dict)
-    # optimized_sv_infos.update(
-    #     {
-    #         'tot_base_pts_trans': tot_base_pts_trans,
-    #         'tot_base_normals_trans': tot_base_normals_trans
-    #     }
-    # )
+    optimized_sv_infos.update(
+        {
+            'tot_base_pts_trans': tot_base_pts_trans,
+            'tot_base_normals_trans': tot_base_normals_trans
+        }
+    )
 
 
     optimized_sv_infos.update({'predicted_info': data})
 