xymeow7 commited on
Commit
074bdce
·
verified ·
1 Parent(s): 9b0f9d2

Update sample/reconstruct_data_taco.py

Browse files
Files changed (1) hide show
  1. sample/reconstruct_data_taco.py +81 -76
sample/reconstruct_data_taco.py CHANGED
@@ -51,15 +51,15 @@ def get_penetration_masks(obj_verts, obj_faces, hand_verts):
51
  def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False, nn_hand_params=24, rt_vars=False, with_proj=False, obj_verts_trans=None, obj_faces=None, with_params_smoothing=False, dist_thres=0.005, with_ctx_mask=False):
52
  # obj_verts_trans, obj_faces
53
  joints = torch.from_numpy(joints).float() # # joints
54
- base_pts = torch.from_numpy(base_pts).float() # # base_pts
55
 
56
  if nn_hand_params < 45:
57
  use_pca = True
58
  else:
59
  use_pca = False
60
 
61
- tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float()
62
- tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float()
63
  ### start optimization ###
64
  # setup MANO layer
65
  # mano_path = "/data1/xueyi/mano_models/mano/models"
@@ -139,48 +139,48 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
139
  # )
140
 
141
  #
142
- dist_joints_to_base_pts = torch.sum(
143
- (joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
144
- )
145
 
146
- nn_base_pts = dist_joints_to_base_pts.size(-1)
147
- nn_joints = dist_joints_to_base_pts.size(1)
148
 
149
- dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
150
- minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
151
 
152
- nk_contact_pts = 2
153
- minn_dist[:, :-5] = 1e9
154
- minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
155
- # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0) ==
156
- minn_topk_mask = torch.zeros_like(minn_dist)
157
- # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
158
- minn_topk_mask[:, -5: -3] = 1.
159
- basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0)
160
- minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
161
- # for seq 101
162
- # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
163
- minn_dist_mask = minn_dist_mask.float()
164
 
165
- ## tot base pts
166
- tot_base_pts_trans_disp = torch.sum(
167
- (tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1 # (nf - 1) x nn_base_pts displacement
168
- )
169
- ### tot base pts trans disp ###
170
- tot_base_pts_trans_disp = torch.sqrt(tot_base_pts_trans_disp).mean(dim=-1) # (nf - 1)
171
- # tot_base_pts_trans_disp_mov_thres = 1e-20
172
- tot_base_pts_trans_disp_mov_thres = 3e-4
173
- tot_base_pts_trans_disp_mask = tot_base_pts_trans_disp >= tot_base_pts_trans_disp_mov_thres
174
- tot_base_pts_trans_disp_mask = torch.cat(
175
- [tot_base_pts_trans_disp_mask, tot_base_pts_trans_disp_mask[-1:]], dim=0
176
- )
177
 
178
- attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
179
 
180
 
181
 
182
- minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
183
- print(f"minn_dist_mask: {minn_dist_mask.size()}")
184
  s = 1.0
185
  # affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
186
 
@@ -231,21 +231,21 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
231
  print('\tRotation Smoothness Loss: {}'.format(joints_pred_loss.item()))
232
 
233
  #
234
- print(tot_base_pts_trans.size())
235
- diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
236
- print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
237
- diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
238
- diff_base_pts_trans_threshold = 1e-20
239
- diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the tranformed base pts
240
- diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
241
- print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
242
- diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
243
- diff_base_pts_trans_mask = torch.cat(
244
- [diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor
245
- )
246
  # attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5
247
- attraction_mask = minn_topk_mask.float()
248
- attraction_mask = attraction_mask.float()
249
 
250
  # the direction of the normal vector and the moving direction of the object point -> whether the point should be selected
251
  # the contact maps of the object should be like? #
@@ -1184,7 +1184,12 @@ def reconstruct_from_file(single_seq_path):
1184
  tot_data[cur_k].append(cur_data[cur_k][ :clip_ending_idxes[i_tag]])
1185
 
1186
  for cur_k in tot_data:
1187
- if cur_k in ["tot_base_pts", "tot_base_normals", "tot_obj_rot", "tot_obj_transl", "tot_obj_pcs", "tot_rhand_joints", "tot_gt_rhand_joints"]:
 
 
 
 
 
1188
  tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=1)
1189
  else:
1190
  tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=0)
@@ -1202,8 +1207,8 @@ def reconstruct_from_file(single_seq_path):
1202
 
1203
  targets = data['targets'] # # targets # #
1204
  outputs = data['outputs'] #
1205
- tot_base_pts = data["tot_base_pts"][0] # total base pts, total base normals #
1206
- tot_base_normals = data['tot_base_normals'][0] # nn_base_normals #
1207
 
1208
 
1209
 
@@ -1213,22 +1218,22 @@ def reconstruct_from_file(single_seq_path):
1213
  tot_obj_transl = data['tot_obj_transl'][0]
1214
  print(f"tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
1215
 
1216
- if len(tot_base_pts.shape) == 2:
1217
- # numpy array # # tot base pts #
1218
- tot_base_pts_trans = np.matmul(tot_base_pts.reshape(1, tot_base_pts.shape[0], 3), tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
1219
- tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot[0]) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])[0]
1220
 
1221
- tot_base_normals_trans = np.matmul( # #
1222
- tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot
1223
- )
1224
- else:
1225
- print(f"tot_base_pts: {tot_base_pts.shape}, tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
1226
- tot_base_pts_trans = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
1227
- tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
1228
 
1229
- tot_base_normals_trans = np.matmul(
1230
- tot_base_normals, tot_obj_rot
1231
- )
1232
 
1233
 
1234
 
@@ -1237,7 +1242,7 @@ def reconstruct_from_file(single_seq_path):
1237
 
1238
  targets = np.matmul(targets, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 #
1239
  # denoise relative positions
1240
- print(f"tot_base_pts: {tot_base_pts.shape}")
1241
 
1242
 
1243
  #### obj_verts_trans, obj_faces ####
@@ -1264,7 +1269,7 @@ def reconstruct_from_file(single_seq_path):
1264
  with_contact_opt = True
1265
  with_ctx_mask = False
1266
 
1267
- bf_ct_optimized_dict, bf_proj_optimized_dict, optimized_dict = get_optimized_hand_fr_joints_v4_anchors(outputs, tot_base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=with_contact_opt, nn_hand_params=nn_hand_params, rt_vars=True, with_proj=with_proj, obj_verts_trans=obj_verts_trans, obj_faces=obj_faces, with_params_smoothing=with_params_smoothing, dist_thres=dist_thres, with_ctx_mask=with_ctx_mask)
1268
 
1269
 
1270
 
@@ -1274,12 +1279,12 @@ def reconstruct_from_file(single_seq_path):
1274
  optimized_sv_infos.update(bf_ct_optimized_dict)
1275
  optimized_sv_infos.update(bf_proj_optimized_dict)
1276
  optimized_sv_infos.update(optimized_dict)
1277
- optimized_sv_infos.update(
1278
- {
1279
- 'tot_base_pts_trans': tot_base_pts_trans,
1280
- 'tot_base_normals_trans': tot_base_normals_trans
1281
- }
1282
- )
1283
 
1284
 
1285
  optimized_sv_infos.update({'predicted_info': data})
 
51
  def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False, nn_hand_params=24, rt_vars=False, with_proj=False, obj_verts_trans=None, obj_faces=None, with_params_smoothing=False, dist_thres=0.005, with_ctx_mask=False):
52
  # obj_verts_trans, obj_faces
53
  joints = torch.from_numpy(joints).float() # # joints
54
+ # base_pts = torch.from_numpy(base_pts).float() # # base_pts
55
 
56
  if nn_hand_params < 45:
57
  use_pca = True
58
  else:
59
  use_pca = False
60
 
61
+ # tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float()
62
+ # tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float()
63
  ### start optimization ###
64
  # setup MANO layer
65
  # mano_path = "/data1/xueyi/mano_models/mano/models"
 
139
  # )
140
 
141
  #
142
+ # dist_joints_to_base_pts = torch.sum(
143
+ # (joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
144
+ # )
145
 
146
+ # nn_base_pts = dist_joints_to_base_pts.size(-1)
147
+ # nn_joints = dist_joints_to_base_pts.size(1)
148
 
149
+ # dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
150
+ # minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
151
 
152
+ # nk_contact_pts = 2
153
+ # minn_dist[:, :-5] = 1e9
154
+ # minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
155
+ # # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0) ==
156
+ # minn_topk_mask = torch.zeros_like(minn_dist)
157
+ # # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
158
+ # minn_topk_mask[:, -5: -3] = 1.
159
+ # basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0)
160
+ # minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
161
+ # # for seq 101
162
+ # # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
163
+ # minn_dist_mask = minn_dist_mask.float()
164
 
165
+ # ## tot base pts
166
+ # tot_base_pts_trans_disp = torch.sum(
167
+ # (tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1 # (nf - 1) x nn_base_pts displacement
168
+ # )
169
+ # ### tot base pts trans disp ###
170
+ # tot_base_pts_trans_disp = torch.sqrt(tot_base_pts_trans_disp).mean(dim=-1) # (nf - 1)
171
+ # # tot_base_pts_trans_disp_mov_thres = 1e-20
172
+ # tot_base_pts_trans_disp_mov_thres = 3e-4
173
+ # tot_base_pts_trans_disp_mask = tot_base_pts_trans_disp >= tot_base_pts_trans_disp_mov_thres
174
+ # tot_base_pts_trans_disp_mask = torch.cat(
175
+ # [tot_base_pts_trans_disp_mask, tot_base_pts_trans_disp_mask[-1:]], dim=0
176
+ # )
177
 
178
+ # attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
179
 
180
 
181
 
182
+ # minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
183
+ # print(f"minn_dist_mask: {minn_dist_mask.size()}")
184
  s = 1.0
185
  # affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
186
 
 
231
  print('\tRotation Smoothness Loss: {}'.format(joints_pred_loss.item()))
232
 
233
  #
234
+ # print(tot_base_pts_trans.size())
235
+ # diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
236
+ # print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
237
+ # diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
238
+ # diff_base_pts_trans_threshold = 1e-20
239
+ # diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the tranformed base pts
240
+ # diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
241
+ # print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
242
+ # diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
243
+ # diff_base_pts_trans_mask = torch.cat(
244
+ # [diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor
245
+ # )
246
  # attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5
247
+ # attraction_mask = minn_topk_mask.float()
248
+ # attraction_mask = attraction_mask.float()
249
 
250
  # the direction of the normal vector and the moving direction of the object point -> whether the point should be selected
251
  # the contact maps of the object should be like? #
 
1184
  tot_data[cur_k].append(cur_data[cur_k][ :clip_ending_idxes[i_tag]])
1185
 
1186
  for cur_k in tot_data:
1187
+ print(f"cur_k: {cur_k}")
1188
+ for aa in tot_data[cur_k]:
1189
+ print(aa.shape)
1190
+ if cur_k in ['tot_base_pts', 'tot_base_normals']:
1191
+ continue
1192
+ elif cur_k in ["tot_base_pts", "tot_base_normals", "tot_obj_rot", "tot_obj_transl", "tot_obj_pcs", "tot_rhand_joints", "tot_gt_rhand_joints"]:
1193
  tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=1)
1194
  else:
1195
  tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=0)
 
1207
 
1208
  targets = data['targets'] # # targets # #
1209
  outputs = data['outputs'] #
1210
+ # tot_base_pts = data["tot_base_pts"][0] # total base pts, total base normals #
1211
+ # tot_base_normals = data['tot_base_normals'][0] # nn_base_normals #
1212
 
1213
 
1214
 
 
1218
  tot_obj_transl = data['tot_obj_transl'][0]
1219
  print(f"tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
1220
 
1221
+ # if len(tot_base_pts.shape) == 2:
1222
+ # # numpy array # # tot base pts #
1223
+ # tot_base_pts_trans = np.matmul(tot_base_pts.reshape(1, tot_base_pts.shape[0], 3), tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
1224
+ # tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot[0]) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])[0]
1225
 
1226
+ # tot_base_normals_trans = np.matmul( # #
1227
+ # tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot
1228
+ # )
1229
+ # else:
1230
+ # print(f"tot_base_pts: {tot_base_pts.shape}, tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
1231
+ # tot_base_pts_trans = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
1232
+ # tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
1233
 
1234
+ # tot_base_normals_trans = np.matmul(
1235
+ # tot_base_normals, tot_obj_rot
1236
+ # )
1237
 
1238
 
1239
 
 
1242
 
1243
  targets = np.matmul(targets, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 #
1244
  # denoise relative positions
1245
+ # print(f"tot_base_pts: {tot_base_pts.shape}")
1246
 
1247
 
1248
  #### obj_verts_trans, obj_faces ####
 
1269
  with_contact_opt = True
1270
  with_ctx_mask = False
1271
 
1272
+ bf_ct_optimized_dict, bf_proj_optimized_dict, optimized_dict = get_optimized_hand_fr_joints_v4_anchors(outputs, None, None, None, with_contact_opt=with_contact_opt, nn_hand_params=nn_hand_params, rt_vars=True, with_proj=with_proj, obj_verts_trans=obj_verts_trans, obj_faces=obj_faces, with_params_smoothing=with_params_smoothing, dist_thres=dist_thres, with_ctx_mask=with_ctx_mask)
1273
 
1274
 
1275
 
 
1279
  optimized_sv_infos.update(bf_ct_optimized_dict)
1280
  optimized_sv_infos.update(bf_proj_optimized_dict)
1281
  optimized_sv_infos.update(optimized_dict)
1282
+ # optimized_sv_infos.update(
1283
+ # {
1284
+ # 'tot_base_pts_trans': tot_base_pts_trans,
1285
+ # 'tot_base_normals_trans': tot_base_normals_trans
1286
+ # }
1287
+ # )
1288
 
1289
 
1290
  optimized_sv_infos.update({'predicted_info': data})