meow committed on
Commit
d204888
1 Parent(s): 26e9fd6
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.pkl
2
+ *.pt
app.py CHANGED
@@ -38,7 +38,10 @@ def predict(file_path: str):
38
  temp_bash_file = create_bash_file(temp_file_path)
39
 
40
  os.system(f"bash {temp_bash_file}")
41
- return temp_file_path
 
 
 
42
 
43
 
44
  demo = gr.Interface(
 
38
  temp_bash_file = create_bash_file(temp_file_path)
39
 
40
  os.system(f"bash {temp_bash_file}")
41
+
42
+ res_file_path = "/tmp/denoising/save/predicted_infos_seed_0_tag_20231104_017_jts_spatial_t_100__st_0.npy"
43
+
44
+ return res_file_path
45
 
46
 
47
  demo = gr.Interface(
data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-310.pyc CHANGED
Binary files a/data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-310.pyc and b/data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-310.pyc differ
 
data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-38.pyc CHANGED
Binary files a/data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-38.pyc and b/data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-38.pyc differ
 
data_loaders/humanml/data/dataset_ours_single_seq.py CHANGED
@@ -6184,6 +6184,14 @@ class GRAB_Dataset_V19_Arctic_from_Pred(torch.utils.data.Dataset): # GRAB datass
6184
  )
6185
 
6186
 
 
 
 
 
 
 
 
 
6187
  ### Load field data from root folders ### ## obj root folder ##
6188
  self.obj_root_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_objs"
6189
  self.obj_params_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_params"
@@ -7113,16 +7121,22 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7113
  self.start_idx = self.args.start_idx
7114
 
7115
  # load datas # grab path; grab sequences #
7116
- grab_path = "/data1/xueyi/GRAB_extracted"
7117
- ## grab contactmesh ## id2objmeshname
7118
- obj_mesh_path = os.path.join(grab_path, 'tools/object_meshes/contact_meshes')
7119
- id2objmeshname = []
7120
- obj_meshes = sorted(os.listdir(obj_mesh_path))
7121
- # objectmesh name #
7122
- id2objmeshname = [obj_meshes[i].split(".")[0] for i in range(len(obj_meshes))]
7123
- self.id2objmeshname = id2objmeshname
7124
-
 
 
 
7125
 
 
 
 
7126
 
7127
  self.aug_trans_T = 0.05
7128
  self.aug_rot_T = 0.3
@@ -7142,9 +7156,10 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7142
 
7143
  ## predicted infos fn ##
7144
  self.data_folder = data_folder
7145
- self.subj_data_folder = '/data1/xueyi/GRAB_processed_wsubj'
7146
  # self.subj_corr_data_folder = args.subj_corr_data_folder
7147
- self.mano_path = "/data1/xueyi/mano_models/mano/models" ### mano_path
 
7148
  ## mano paths ##
7149
  self.aug = True
7150
  self.use_anchors = False
@@ -7152,19 +7167,19 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7152
 
7153
  self.use_anchors = args.use_anchors
7154
 
7155
- self.grab_path = "/data1/xueyi/GRAB_extracted"
7156
- obj_mesh_path = os.path.join(self.grab_path, 'tools/object_meshes/contact_meshes')
7157
- id2objmesh = []
7158
- obj_meshes = sorted(os.listdir(obj_mesh_path))
7159
- for i, fn in enumerate(obj_meshes):
7160
- id2objmesh.append(os.path.join(obj_mesh_path, fn))
7161
- self.id2objmesh = id2objmesh
7162
- self.id2meshdata = {}
7163
 
7164
  ## obj root folder; ##
7165
- ### Load field data from root folders ###
7166
- self.obj_root_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_objs"
7167
- self.obj_params_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_params" # # and base points
7168
 
7169
  self.load_meta = True
7170
 
@@ -7204,36 +7219,6 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7204
  self.predicted_hand_joints = outputs # nf x nnjoints x 3 #
7205
  self.predicted_hand_joints = torch.from_numpy(self.predicted_hand_joints).float()
7206
 
7207
- # if 'rhand_trans' in data:
7208
- # # outputs = data['outputs']
7209
- # self.predicted_hand_trans = data['rhand_trans'] # nframes x 3
7210
- # self.predicted_hand_rot = data['rhand_rot'] # nframes x 3
7211
- # self.predicted_hand_theta = data['rhand_theta']
7212
- # self.predicted_hand_beta = data['rhand_beta']
7213
- # self.predicted_hand_trans = torch.from_numpy(self.predicted_hand_trans).float() # nframes x 3
7214
- # self.predicted_hand_rot = torch.from_numpy(self.predicted_hand_rot).float() # nframes x 3
7215
- # self.predicted_hand_theta = torch.from_numpy(self.predicted_hand_theta).float() # nframes x 24
7216
- # self.predicted_hand_beta = torch.from_numpy(self.predicted_hand_beta).float() # 10,
7217
-
7218
- # self.predicted_hand_trans_opt = data_opt['rhand_trans'] # nframes x 3
7219
- # self.predicted_hand_rot_opt = data_opt['rhand_rot'] # nframes x 3
7220
- # self.predicted_hand_theta_opt = data_opt['rhand_theta']
7221
- # self.predicted_hand_beta_opt = data_opt['rhand_beta']
7222
- # self.predicted_hand_trans_opt = torch.from_numpy(self.predicted_hand_trans_opt).float() # nframes x 3
7223
- # self.predicted_hand_rot_opt = torch.from_numpy(self.predicted_hand_rot_opt).float() # nframes x 3
7224
- # self.predicted_hand_theta_opt = torch.from_numpy(self.predicted_hand_theta_opt).float() # nframes x 24
7225
- # self.predicted_hand_beta_opt = torch.from_numpy(self.predicted_hand_beta_opt).float() # 10,
7226
-
7227
- # self.predicted_hand_trans[9:] = self.predicted_hand_trans_opt[9:]
7228
- # self.predicted_hand_rot[9:] = self.predicted_hand_rot_opt[9:]
7229
- # self.predicted_hand_theta[ 9:] = self.predicted_hand_theta_opt[ 9:]
7230
- # # self.predicted_hand_trans[:, 9:] = self.predicted_hand_trans_opt[:, 9:]
7231
-
7232
- # else:
7233
- # self.predicted_hand_trans = None
7234
- # self.predicted_hand_rot = None
7235
- # self.predicted_hand_theta = None
7236
- # self.predicted_hand_beta = None
7237
 
7238
  else:
7239
  self.predicted_hand_joints = None
@@ -7276,25 +7261,25 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7276
  )
7277
 
7278
 
7279
- ### Load field data from root folders ### ## obj root folder ##
7280
- self.obj_root_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_objs"
7281
- self.obj_params_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_params"
7282
 
7283
 
7284
- # anchor_load_driver, masking_load_driver #
7285
- # use_anchors, self.hand_palm_vertex_mask #
7286
- if self.use_anchors: # use anchors # anchor_load_driver, masking_load_driver #
7287
- # anchor_load_driver, masking_load_driver #
7288
- inpath = "/home/xueyi/sim/CPF/assets" # contact potential field; assets # ##
7289
- fvi, aw, _, _ = anchor_load_driver(inpath)
7290
- self.face_vertex_index = torch.from_numpy(fvi).long()
7291
- self.anchor_weight = torch.from_numpy(aw).float()
7292
 
7293
- anchor_path = os.path.join("/home/xueyi/sim/CPF/assets", "anchor")
7294
- palm_path = os.path.join("/home/xueyi/sim/CPF/assets", "hand_palm_full.txt")
7295
- hand_region_assignment, hand_palm_vertex_mask = masking_load_driver(anchor_path, palm_path)
7296
- # self.hand_palm_vertex_mask for hand palm mask #
7297
- self.hand_palm_vertex_mask = torch.from_numpy(hand_palm_vertex_mask).bool() ## the mask for hand palm to get hand anchors #
7298
 
7299
  files_clean = [self.seq_path]
7300
 
@@ -7325,57 +7310,6 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7325
  # (288, 4, 4)
7326
  return clip_data
7327
 
7328
- # if f is None:
7329
- # cur_clip = self.clips[clip_idx]
7330
- # if len(cur_clip) > 3:
7331
- # return cur_clip
7332
- # f = cur_clip[2]
7333
- # clip_clean = np.load(f)
7334
- # # pert_folder_nm = self.split + '_pert'
7335
- # pert_folder_nm = self.split
7336
- # # if not self.use_pert:
7337
- # # pert_folder_nm = self.split
7338
- # # clip_pert = np.load(os.path.join(self.data_folder, pert_folder_nm, os.path.basename(f)))
7339
-
7340
-
7341
- # ##### load subj params #####
7342
- # pure_file_name = f.split("/")[-1].split(".")[0]
7343
- # pure_subj_params_fn = f"{pure_file_name}_subj.npy"
7344
-
7345
- # subj_params_fn = os.path.join(self.subj_data_folder, self.split, pure_subj_params_fn)
7346
- # subj_params = np.load(subj_params_fn, allow_pickle=True).item()
7347
- # rhand_transl = subj_params["rhand_transl"]
7348
- # rhand_betas = subj_params["rhand_betas"]
7349
- # # rhand_pose = clip_clean['f2'] ## rhand pose ##
7350
-
7351
- # object_global_orient = clip_clean['f5'] ## clip_len x 3 --> orientation
7352
- # object_trcansl = clip_clean['f6'] ## cliplen x 3 --> translation
7353
-
7354
- # object_idx = clip_clean['f7'][0].item()
7355
-
7356
- # pert_subj_params_fn = os.path.join(self.subj_data_folder, pert_folder_nm, pure_subj_params_fn)
7357
- # pert_subj_params = np.load(pert_subj_params_fn, allow_pickle=True).item()
7358
- # ##### load subj params #####
7359
-
7360
- # # meta data -> lenght of the current clip -> construct meta data from those saved meta data -> load file on the fly # clip file name -> yes...
7361
- # # print(f"rhand_transl: {rhand_transl.shape},rhand_betas: {rhand_betas.shape}, rhand_pose: {rhand_pose.shape} ")
7362
- # ### pert and clean pair for encoding and decoding ###
7363
-
7364
- # # maxx_clip_len =
7365
- # loaded_clip = (
7366
- # 0, rhand_transl.shape[0], clip_clean,
7367
- # [clip_clean['f9'], clip_clean['f11'], clip_clean['f10'], clip_clean['f1'], clip_clean['f2'], rhand_transl, rhand_betas, object_global_orient, object_trcansl, object_idx], pert_subj_params,
7368
- # )
7369
- # # self.clips[clip_idx] = loaded_clip
7370
-
7371
- # return loaded_clip
7372
-
7373
- # self.clips.append((self.len, self.len+clip_len, clip_pert,
7374
- # [clip_clean['f9'], clip_clean['f11'], clip_clean['f10'], clip_clean['f1'], clip_clean['f2'], rhand_transl, rhand_betas], pert_subj_params,
7375
- # # subj_corr_data, pert_subj_corr_data
7376
- # ))
7377
-
7378
-
7379
 
7380
  def get_idx_to_mesh_data(self, obj_id):
7381
  if obj_id not in self.id2meshdata:
@@ -7520,9 +7454,7 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7520
  ### aug_global_orient_var, aug_pose_var, aug_transl_var ###
7521
  #### ==== get random augmented pose, rot, transl ==== ####
7522
  # rnd_aug_global_orient_var, rnd_aug_pose_var, rnd_aug_transl_var #
7523
- aug_trans, aug_rot, aug_pose = 0.01, 0.05, 0.3
7524
- aug_trans, aug_rot, aug_pose = 0.001, 0.05, 0.3
7525
- aug_trans, aug_rot, aug_pose = 0.000, 0.05, 0.3
7526
  aug_trans, aug_rot, aug_pose = 0.000, 0.00, 0.00
7527
  # noise scale #
7528
  # aug_trans, aug_rot, aug_pose = 0.01, 0.05, 0.3 # scale 1 for the standard scale
@@ -7570,17 +7502,19 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7570
  rhand_verts = rhand_verts * 0.001
7571
  rhand_joints = rhand_joints * 0.001
7572
 
 
 
7573
  offset_cur_to_raw = raw_hand_verts[0, 0] - rhand_verts[0, 0]
7574
  rhand_verts = rhand_verts + offset_cur_to_raw.unsqueeze(0)
7575
  rhand_joints = rhand_joints + offset_cur_to_raw.unsqueeze(0)
7576
 
7577
- # rhand_anchors, pert_rhand_anchors #
7578
- # rhand_anchors, canon_rhand_anchors #
7579
- # use_anchors, self.hand_palm_vertex_mask #
7580
- if self.use_anchors: # # rhand_anchors: bsz x nn_hand_anchors x 3 #
7581
- # rhand_anchors = rhand_verts[:, self.hand_palm_vertex_mask] # nf x nn_anchors x 3 --> for the anchor points ##
7582
- rhand_anchors = recover_anchor_batch(rhand_verts, self.face_vertex_index, self.anchor_weight.unsqueeze(0).repeat(self.window_size, 1, 1))
7583
- # print(f"rhand_anchors: {rhand_anchors.size()}") ### recover rhand verts here ###
7584
 
7585
 
7586
 
@@ -7600,13 +7534,16 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7600
  )
7601
  pert_rhand_verts = pert_rhand_verts * 0.001 # verts
7602
  pert_rhand_joints = pert_rhand_joints * 0.001 # joints
 
 
 
7603
  pert_rhand_verts = pert_rhand_verts + offset_cur_to_raw.unsqueeze(0)
7604
  pert_rhand_joints = pert_rhand_joints + offset_cur_to_raw.unsqueeze(0)
7605
 
7606
- if self.use_anchors:
7607
- # pert_rhand_anchors = pert_rhand_verts[:, self.hand_palm_vertex_mask]
7608
- pert_rhand_anchors = recover_anchor_batch(pert_rhand_verts, self.face_vertex_index, self.anchor_weight.unsqueeze(0).repeat(self.window_size, 1, 1))
7609
- # print(f"rhand_anchors: {rhand_anchors.size()}") ### recover rhand verts here ###
7610
 
7611
  # use_canon_joints
7612
 
@@ -7617,6 +7554,8 @@ class GRAB_Dataset_V19_HHO(torch.utils.data.Dataset): # GRAB datasset #
7617
  canon_pert_rhand_verts = canon_pert_rhand_verts * 0.001 # verts
7618
  canon_pert_rhand_joints = canon_pert_rhand_joints * 0.001 # joints
7619
 
 
 
7620
  # if self.use_anchors:
7621
  # # canon_pert_rhand_anchors = canon_pert_rhand_verts[:, self.hand_palm_vertex_mask]
7622
  # canon_pert_rhand_anchors = recover_anchor_batch(canon_pert_rhand_verts, self.face_vertex_index, self.anchor_weight.unsqueeze(0).repeat(self.window_size, 1, 1))
@@ -8155,15 +8094,15 @@ class GRAB_Dataset_V19_From_Evaluated_Info(torch.utils.data.Dataset):
8155
  self.inst_normalization = args.inst_normalization
8156
 
8157
 
8158
- # load datas # grab path; grab sequences #
8159
- grab_path = "/data1/xueyi/GRAB_extracted"
8160
- ## grab contactmesh ## id2objmeshname
8161
- obj_mesh_path = os.path.join(grab_path, 'tools/object_meshes/contact_meshes')
8162
- id2objmeshname = []
8163
- obj_meshes = sorted(os.listdir(obj_mesh_path))
8164
- # objectmesh name #
8165
- id2objmeshname = [obj_meshes[i].split(".")[0] for i in range(len(obj_meshes))]
8166
- self.id2objmeshname = id2objmeshname
8167
 
8168
 
8169
  ## the predicted_info_fn
@@ -8234,14 +8173,14 @@ class GRAB_Dataset_V19_From_Evaluated_Info(torch.utils.data.Dataset):
8234
  self.use_anchors = False
8235
 
8236
 
8237
- self.grab_path = "/data1/xueyi/GRAB_extracted"
8238
- obj_mesh_path = os.path.join(self.grab_path, 'tools/object_meshes/contact_meshes')
8239
- id2objmesh = []
8240
- obj_meshes = sorted(os.listdir(obj_mesh_path))
8241
- for i, fn in enumerate(obj_meshes):
8242
- id2objmesh.append(os.path.join(obj_mesh_path, fn))
8243
- self.id2objmesh = id2objmesh
8244
- self.id2meshdata = {}
8245
 
8246
  ## obj root folder; ##
8247
  ### Load field data from root folders ###
 
6184
  )
6185
 
6186
 
6187
+ self.reversed_list = torch.tensor(
6188
+ [0, 5, 6, 7, 9, 10, 11, 17, 18, 19, 13, 14, 15, 1, 2, 3, 4, 8, 12, 16, 20], dtype=torch.long
6189
+ )
6190
+
6191
+ self.forward_list = torch.tensor(
6192
+ [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20], dtype=torch.long
6193
+ )
6194
+
6195
  ### Load field data from root folders ### ## obj root folder ##
6196
  self.obj_root_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_objs"
6197
  self.obj_params_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_params"
 
7121
  self.start_idx = self.args.start_idx
7122
 
7123
  # load datas # grab path; grab sequences #
7124
+ # grab_path = "/data1/xueyi/GRAB_extracted"
7125
+ # ## grab contactmesh ## id2objmeshname
7126
+ # obj_mesh_path = os.path.join(grab_path, 'tools/object_meshes/contact_meshes')
7127
+ # id2objmeshname = []
7128
+ # obj_meshes = sorted(os.listdir(obj_mesh_path))
7129
+ # # objectmesh name #
7130
+ # id2objmeshname = [obj_meshes[i].split(".")[0] for i in range(len(obj_meshes))]
7131
+ # self.id2objmeshname = id2objmeshname
7132
+
7133
+ self.reversed_list = torch.tensor(
7134
+ [0, 5, 6, 7, 9, 10, 11, 17, 18, 19, 13, 14, 15, 1, 2, 3, 4, 8, 12, 16, 20], dtype=torch.long
7135
+ )
7136
 
7137
+ self.forward_list = torch.tensor(
7138
+ [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20], dtype=torch.long
7139
+ )
7140
 
7141
  self.aug_trans_T = 0.05
7142
  self.aug_rot_T = 0.3
 
7156
 
7157
  ## predicted infos fn ##
7158
  self.data_folder = data_folder
7159
+ # self.subj_data_folder = '/data1/xueyi/GRAB_processed_wsubj'
7160
  # self.subj_corr_data_folder = args.subj_corr_data_folder
7161
+ # self.mano_path = "/data1/xueyi/mano_models/mano/models" ### mano_path
7162
+ self.mano_path = "./"
7163
  ## mano paths ##
7164
  self.aug = True
7165
  self.use_anchors = False
 
7167
 
7168
  self.use_anchors = args.use_anchors
7169
 
7170
+ # self.grab_path = "/data1/xueyi/GRAB_extracted"
7171
+ # obj_mesh_path = os.path.join(self.grab_path, 'tools/object_meshes/contact_meshes')
7172
+ # id2objmesh = []
7173
+ # obj_meshes = sorted(os.listdir(obj_mesh_path))
7174
+ # for i, fn in enumerate(obj_meshes):
7175
+ # id2objmesh.append(os.path.join(obj_mesh_path, fn))
7176
+ # self.id2objmesh = id2objmesh
7177
+ # self.id2meshdata = {}
7178
 
7179
  ## obj root folder; ##
7180
+ # ### Load field data from root folders ###
7181
+ # self.obj_root_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_objs"
7182
+ # self.obj_params_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_params" # # and base points
7183
 
7184
  self.load_meta = True
7185
 
 
7219
  self.predicted_hand_joints = outputs # nf x nnjoints x 3 #
7220
  self.predicted_hand_joints = torch.from_numpy(self.predicted_hand_joints).float()
7221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7222
 
7223
  else:
7224
  self.predicted_hand_joints = None
 
7261
  )
7262
 
7263
 
7264
+ # ### Load field data from root folders ### ## obj root folder ##
7265
+ # self.obj_root_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_objs"
7266
+ # self.obj_params_folder = "/data1/xueyi/GRAB_extracted/tools/object_meshes/contact_meshes_params"
7267
 
7268
 
7269
+ # # anchor_load_driver, masking_load_driver #
7270
+ # # use_anchors, self.hand_palm_vertex_mask #
7271
+ # if self.use_anchors: # use anchors # anchor_load_driver, masking_load_driver #
7272
+ # # anchor_load_driver, masking_load_driver #
7273
+ # inpath = "/home/xueyi/sim/CPF/assets" # contact potential field; assets # ##
7274
+ # fvi, aw, _, _ = anchor_load_driver(inpath)
7275
+ # self.face_vertex_index = torch.from_numpy(fvi).long()
7276
+ # self.anchor_weight = torch.from_numpy(aw).float()
7277
 
7278
+ # anchor_path = os.path.join("/home/xueyi/sim/CPF/assets", "anchor")
7279
+ # palm_path = os.path.join("/home/xueyi/sim/CPF/assets", "hand_palm_full.txt")
7280
+ # hand_region_assignment, hand_palm_vertex_mask = masking_load_driver(anchor_path, palm_path)
7281
+ # # self.hand_palm_vertex_mask for hand palm mask #
7282
+ # self.hand_palm_vertex_mask = torch.from_numpy(hand_palm_vertex_mask).bool() ## the mask for hand palm to get hand anchors #
7283
 
7284
  files_clean = [self.seq_path]
7285
 
 
7310
  # (288, 4, 4)
7311
  return clip_data
7312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7313
 
7314
  def get_idx_to_mesh_data(self, obj_id):
7315
  if obj_id not in self.id2meshdata:
 
7454
  ### aug_global_orient_var, aug_pose_var, aug_transl_var ###
7455
  #### ==== get random augmented pose, rot, transl ==== ####
7456
  # rnd_aug_global_orient_var, rnd_aug_pose_var, rnd_aug_transl_var #
7457
+
 
 
7458
  aug_trans, aug_rot, aug_pose = 0.000, 0.00, 0.00
7459
  # noise scale #
7460
  # aug_trans, aug_rot, aug_pose = 0.01, 0.05, 0.3 # scale 1 for the standard scale
 
7502
  rhand_verts = rhand_verts * 0.001
7503
  rhand_joints = rhand_joints * 0.001
7504
 
7505
+ rhand_joints = rhand_joints[:, self.reversed_list]
7506
+
7507
  offset_cur_to_raw = raw_hand_verts[0, 0] - rhand_verts[0, 0]
7508
  rhand_verts = rhand_verts + offset_cur_to_raw.unsqueeze(0)
7509
  rhand_joints = rhand_joints + offset_cur_to_raw.unsqueeze(0)
7510
 
7511
+ # # rhand_anchors, pert_rhand_anchors #
7512
+ # # rhand_anchors, canon_rhand_anchors #
7513
+ # # use_anchors, self.hand_palm_vertex_mask #
7514
+ # if self.use_anchors: # # rhand_anchors: bsz x nn_hand_anchors x 3 #
7515
+ # # rhand_anchors = rhand_verts[:, self.hand_palm_vertex_mask] # nf x nn_anchors x 3 --> for the anchor points ##
7516
+ # rhand_anchors = recover_anchor_batch(rhand_verts, self.face_vertex_index, self.anchor_weight.unsqueeze(0).repeat(self.window_size, 1, 1))
7517
+ # # print(f"rhand_anchors: {rhand_anchors.size()}") ### recover rhand verts here ###
7518
 
7519
 
7520
 
 
7534
  )
7535
  pert_rhand_verts = pert_rhand_verts * 0.001 # verts
7536
  pert_rhand_joints = pert_rhand_joints * 0.001 # joints
7537
+
7538
+ pert_rhand_joints = pert_rhand_joints[:, self.reversed_list]
7539
+
7540
  pert_rhand_verts = pert_rhand_verts + offset_cur_to_raw.unsqueeze(0)
7541
  pert_rhand_joints = pert_rhand_joints + offset_cur_to_raw.unsqueeze(0)
7542
 
7543
+ # if self.use_anchors:
7544
+ # # pert_rhand_anchors = pert_rhand_verts[:, self.hand_palm_vertex_mask]
7545
+ # pert_rhand_anchors = recover_anchor_batch(pert_rhand_verts, self.face_vertex_index, self.anchor_weight.unsqueeze(0).repeat(self.window_size, 1, 1))
7546
+ # # print(f"rhand_anchors: {rhand_anchors.size()}") ### recover rhand verts here ###
7547
 
7548
  # use_canon_joints
7549
 
 
7554
  canon_pert_rhand_verts = canon_pert_rhand_verts * 0.001 # verts
7555
  canon_pert_rhand_joints = canon_pert_rhand_joints * 0.001 # joints
7556
 
7557
+ canon_pert_rhand_joints = canon_pert_rhand_joints[:, self.reversed_list]
7558
+
7559
  # if self.use_anchors:
7560
  # # canon_pert_rhand_anchors = canon_pert_rhand_verts[:, self.hand_palm_vertex_mask]
7561
  # canon_pert_rhand_anchors = recover_anchor_batch(canon_pert_rhand_verts, self.face_vertex_index, self.anchor_weight.unsqueeze(0).repeat(self.window_size, 1, 1))
 
8094
  self.inst_normalization = args.inst_normalization
8095
 
8096
 
8097
+ # # load datas # grab path; grab sequences #
8098
+ # grab_path = "/data1/xueyi/GRAB_extracted"
8099
+ # ## grab contactmesh ## id2objmeshname
8100
+ # obj_mesh_path = os.path.join(grab_path, 'tools/object_meshes/contact_meshes')
8101
+ # id2objmeshname = []
8102
+ # obj_meshes = sorted(os.listdir(obj_mesh_path))
8103
+ # # objectmesh name #
8104
+ # id2objmeshname = [obj_meshes[i].split(".")[0] for i in range(len(obj_meshes))]
8105
+ # self.id2objmeshname = id2objmeshname
8106
 
8107
 
8108
  ## the predicted_info_fn
 
8173
  self.use_anchors = False
8174
 
8175
 
8176
+ # self.grab_path = "/data1/xueyi/GRAB_extracted"
8177
+ # obj_mesh_path = os.path.join(self.grab_path, 'tools/object_meshes/contact_meshes')
8178
+ # id2objmesh = []
8179
+ # obj_meshes = sorted(os.listdir(obj_mesh_path))
8180
+ # for i, fn in enumerate(obj_meshes):
8181
+ # id2objmesh.append(os.path.join(obj_mesh_path, fn))
8182
+ # self.id2objmesh = id2objmesh
8183
+ # self.id2meshdata = {}
8184
 
8185
  ## obj root folder; ##
8186
  ### Load field data from root folders ###
gradio_inter/__pycache__/predict_from_file.cpython-310.pyc CHANGED
Binary files a/gradio_inter/__pycache__/predict_from_file.cpython-310.pyc and b/gradio_inter/__pycache__/predict_from_file.cpython-310.pyc differ
 
gradio_inter/__pycache__/predict_from_file.cpython-38.pyc CHANGED
Binary files a/gradio_inter/__pycache__/predict_from_file.cpython-38.pyc and b/gradio_inter/__pycache__/predict_from_file.cpython-38.pyc differ
 
gradio_inter/predict_from_file.py CHANGED
@@ -144,7 +144,14 @@ import pickle as pkl
144
 
145
 
146
  def main():
147
-
 
 
 
 
 
 
 
148
  args = train_args()
149
 
150
  fixseed(args.seed)
@@ -192,6 +199,7 @@ def main():
192
  st_idxes = list(range(0, num_ending_clearning_frames, nn_st_skip))
193
  if st_idxes[-1] + num_cleaning_frames < nn_frames:
194
  st_idxes.append(nn_frames - num_cleaning_frames)
 
195
  print(f"st_idxes: {st_idxes}")
196
 
197
 
@@ -303,6 +311,11 @@ def main():
303
  cur_ins_targets = cur_targets[cur_ins_rel_idx] / data_scale_factor
304
 
305
  cur_ins_outputs = cur_outputs[cur_ins_rel_idx] / data_scale_factor
 
 
 
 
 
306
  cur_ins_pert_verts = cur_pert_verts[cur_ins_rel_idx, ...]
307
  cur_ins_verts = cur_verts[cur_ins_rel_idx, ...]
308
 
 
144
 
145
 
146
  def main():
147
+
148
+ forward_map = torch.tensor(
149
+ [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20], dtype=torch.long
150
+ )
151
+ reversed_map = torch.tensor(
152
+ [0, 5, 6, 7, 9, 10, 11, 17, 18, 19, 13, 14, 15, 1, 2, 3, 4, 8, 12, 16, 20], dtype=torch.long
153
+ )
154
+
155
  args = train_args()
156
 
157
  fixseed(args.seed)
 
199
  st_idxes = list(range(0, num_ending_clearning_frames, nn_st_skip))
200
  if st_idxes[-1] + num_cleaning_frames < nn_frames:
201
  st_idxes.append(nn_frames - num_cleaning_frames)
202
+ st_idxes = [st_idxes[0]]
203
  print(f"st_idxes: {st_idxes}")
204
 
205
 
 
311
  cur_ins_targets = cur_targets[cur_ins_rel_idx] / data_scale_factor
312
 
313
  cur_ins_outputs = cur_outputs[cur_ins_rel_idx] / data_scale_factor
314
+
315
+ cur_ins_targets = cur_ins_targets[..., forward_map, :]
316
+ cur_ins_outputs = cur_ins_outputs[..., forward_map, :]
317
+
318
+
319
  cur_ins_pert_verts = cur_pert_verts[cur_ins_rel_idx, ...]
320
  cur_ins_verts = cur_verts[cur_ins_rel_idx, ...]
321
 
model/__pycache__/mdm_ours.cpython-38.pyc CHANGED
Binary files a/model/__pycache__/mdm_ours.cpython-38.pyc and b/model/__pycache__/mdm_ours.cpython-38.pyc differ
 
requirements.txt CHANGED
@@ -7,9 +7,9 @@ torch==1.12.1
7
  blobfile==2.0.1
8
  manopth @ git+https://github.com/hassony2/manopth.git
9
  numpy==1.23.1
10
- psutil
11
  scikit-learn
12
- scipy
13
  tensorboard
14
  tensorboardx
15
  tqdm
 
7
  blobfile==2.0.1
8
  manopth @ git+https://github.com/hassony2/manopth.git
9
  numpy==1.23.1
10
+ psutil==5.9.2
11
  scikit-learn
12
+ scipy==1.9.3
13
  tensorboard
14
  tensorboardx
15
  tqdm
sample/__pycache__/reconstruct_data.cpython-38.pyc CHANGED
Binary files a/sample/__pycache__/reconstruct_data.cpython-38.pyc and b/sample/__pycache__/reconstruct_data.cpython-38.pyc differ
 
sample/__pycache__/reconstruct_data_taco.cpython-310.pyc ADDED
Binary file (13.8 kB). View file
 
sample/__pycache__/reconstruct_data_taco.cpython-38.pyc CHANGED
Binary files a/sample/__pycache__/reconstruct_data_taco.cpython-38.pyc and b/sample/__pycache__/reconstruct_data_taco.cpython-38.pyc differ
 
sample/predict_taco.py CHANGED
@@ -163,8 +163,8 @@ def main():
163
  with open(args_path, 'w') as fw:
164
  json.dump(vars(args), fw, indent=4, sort_keys=True)
165
 
166
- tot_hho_seq_paths = [args.single_seq_path]
167
- tot_hho_seq_tags = ["test"]
168
 
169
  data_dict = pkl.load(open(args.single_seq_path, 'rb'))
170
  data_hand_verts = data_dict['hand_verts']
@@ -180,9 +180,14 @@ def main():
180
  st_idxes = list(range(0, num_ending_clearning_frames, nn_st_skip))
181
  if st_idxes[-1] + num_cleaning_frames < nn_frames:
182
  st_idxes.append(nn_frames - num_cleaning_frames)
 
 
183
  print(f"st_idxes: {st_idxes}")
184
 
185
- for cur_seed in range(0, 122, 11):
 
 
 
186
  args.seed = cur_seed
187
 
188
 
 
163
  with open(args_path, 'w') as fw:
164
  json.dump(vars(args), fw, indent=4, sort_keys=True)
165
 
166
+ # tot_hho_seq_paths = [args.single_seq_path]
167
+ # tot_hho_seq_tags = ["test"]
168
 
169
  data_dict = pkl.load(open(args.single_seq_path, 'rb'))
170
  data_hand_verts = data_dict['hand_verts']
 
180
  st_idxes = list(range(0, num_ending_clearning_frames, nn_st_skip))
181
  if st_idxes[-1] + num_cleaning_frames < nn_frames:
182
  st_idxes.append(nn_frames - num_cleaning_frames)
183
+
184
+ st_idxes = [st_idxes[0]]
185
  print(f"st_idxes: {st_idxes}")
186
 
187
+ tot_seeds = [0]
188
+
189
+ # for cur_seed in range(0, 122, 11):
190
+ for cur_seed in tot_seeds:
191
  args.seed = cur_seed
192
 
193
 
sample/reconstruct_data_taco.py CHANGED
@@ -9,14 +9,14 @@ import numpy as np
9
  import os, argparse, copy, json
10
  import pickle as pkl
11
  from scipy.spatial.transform import Rotation as R
12
- from psbody.mesh import Mesh
13
  from manopth.manolayer import ManoLayer
14
 
15
  import trimesh
16
  from utils import *
17
- import utils
18
  import utils.model_util as model_util
19
- from utils.anchor_utils import masking_load_driver, anchor_load_driver, recover_anchor_batch
20
 
21
  from utils.parser_util import train_args
22
 
@@ -50,41 +50,42 @@ def get_penetration_masks(obj_verts, obj_faces, hand_verts):
50
  # interpolatio for unsmooth hand parameters? #
51
  def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False, nn_hand_params=24, rt_vars=False, with_proj=False, obj_verts_trans=None, obj_faces=None, with_params_smoothing=False, dist_thres=0.005, with_ctx_mask=False):
52
  # obj_verts_trans, obj_faces
53
- joints = torch.from_numpy(joints).float().cuda() # joints
54
- base_pts = torch.from_numpy(base_pts).float().cuda() # base_pts
55
 
56
  if nn_hand_params < 45:
57
  use_pca = True
58
  else:
59
  use_pca = False
60
 
61
- tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float().cuda()
62
- tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float().cuda()
63
  ### start optimization ###
64
  # setup MANO layer
65
- mano_path = "/data1/xueyi/mano_models/mano/models"
 
66
  nn_hand_params = 24
67
  use_pca = True
68
- if not use_left:
69
- mano_layer = ManoLayer(
70
- flat_hand_mean=False,
71
- side='right',
72
- mano_root=mano_path, # mano_root #
73
- ncomps=nn_hand_params, # hand params #
74
- use_pca=use_pca, # pca for pca #
75
- root_rot_mode='axisang',
76
- joint_rot_mode='axisang'
77
- ).cuda()
78
- else:
79
- mano_layer = ManoLayer(
80
- flat_hand_mean=True,
81
- side='left',
82
- mano_root=mano_path, # mano_root #
83
- ncomps=nn_hand_params, # hand params #
84
- use_pca=use_pca, # pca for pca #
85
- root_rot_mode='axisang',
86
- joint_rot_mode='axisang'
87
- ).cuda()
88
 
89
  # mano_layer = ManoLayer(
90
  # flat_hand_mean=False,
@@ -94,31 +95,31 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
94
  # use_pca=use_pca, # pca for pca #
95
  # root_rot_mode='axisang',
96
  # joint_rot_mode='axisang'
97
- # ).cuda()
98
  nn_frames = joints.size(0)
99
 
100
 
101
  # anchor_load_driver, masking_load_driver #
102
- inpath = "/home/xueyi/sim/CPF/assets" # contact potential field; assets # ##
103
- fvi, aw, _, _ = anchor_load_driver(inpath)
104
- face_vertex_index = torch.from_numpy(fvi).long().cuda()
105
- anchor_weight = torch.from_numpy(aw).float().cuda()
106
 
107
- anchor_path = os.path.join("/home/xueyi/sim/CPF/assets", "anchor")
108
- palm_path = os.path.join("/home/xueyi/sim/CPF/assets", "hand_palm_full.txt")
109
- hand_region_assignment, hand_palm_vertex_mask = masking_load_driver(anchor_path, palm_path)
110
- # self.hand_palm_vertex_mask for hand palm mask #
111
- hand_palm_vertex_mask = torch.from_numpy(hand_palm_vertex_mask).bool().cuda() ## the mask for hand palm to get hand anchors #
112
 
113
 
114
 
115
 
116
  # initialize variables
117
- beta_var = torch.randn([1, 10]).cuda()
118
  # first 3 global orientation
119
- rot_var = torch.randn([nn_frames, 3]).cuda()
120
- theta_var = torch.randn([nn_frames, nn_hand_params]).cuda()
121
- transl_var = torch.randn([nn_frames, 3]).cuda()
122
 
123
  # 3 + 45 + 3 = 51 for computing #
124
  # transl_var = tot_rhand_transl.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
@@ -151,11 +152,11 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
151
  nk_contact_pts = 2
152
  minn_dist[:, :-5] = 1e9
153
  minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
154
- # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0).cuda() ==
155
  minn_topk_mask = torch.zeros_like(minn_dist)
156
  # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
157
  minn_topk_mask[:, -5: -3] = 1.
158
- basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0).cuda()
159
  minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
160
  # for seq 101
161
  # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
@@ -349,7 +350,7 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
349
  else:
350
  smoothed_theta_var.append(cur_theta_var.detach().clone())
351
  smoothed_theta_var = torch.stack(smoothed_theta_var, dim=0) # smoothed_theta_var: nf x nn_theta_dim
352
- theta_var = torch.randn_like(smoothed_theta_var).cuda()
353
  theta_var.data = smoothed_theta_var.data
354
  # theta_var = smoothed_theta_var.clone()
355
  theta_var.requires_grad_() # for
@@ -395,258 +396,260 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
395
 
396
 
397
  window_size = hand_verts.size(0)
398
- if with_contact_opt:
399
- num_iters = 2000
400
- num_iters = 1000 # seq 77 # if with contact opt #
401
- # num_iters = 500 # seq 77
402
- ori_theta_var = theta_var.detach().clone()
403
-
404
- # tot_base_pts_trans # nf x nn_base_pts x 3
405
- disp_base_pts_trans = tot_base_pts_trans[1:] - tot_base_pts_trans[:-1] # (nf - 1) x nn_base_pts x 3
406
- disp_base_pts_trans = torch.cat( # nf x nn_base_pts x 3
407
- [disp_base_pts_trans, disp_base_pts_trans[-1:]], dim=0
408
- )
409
-
410
- rhand_anchors = recover_anchor_batch(hand_verts.detach(), face_vertex_index, anchor_weight.unsqueeze(0).repeat(window_size, 1, 1))
411
-
412
- dist_joints_to_base_pts = torch.sum(
413
- (rhand_anchors.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
414
- )
415
-
416
- nn_base_pts = dist_joints_to_base_pts.size(-1)
417
- nn_joints = dist_joints_to_base_pts.size(1)
418
-
419
- dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
420
- minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
421
-
422
- nk_contact_pts = 2
423
- # minn_dist[:, :-5] = 1e9
424
- # minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
425
- # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0).cuda() ==
426
- # minn_topk_mask = torch.zeros_like(minn_dist)
427
- # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
428
- # minn_topk_mask[:, -5: -3] = 1.
429
- basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0).cuda()
430
- minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
431
- # for seq 101
432
- # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
433
- minn_dist_mask = minn_dist_mask.float()
434
-
435
- #
436
- # for seq 47
437
- # if cat_nm in ["Scissors"]:
438
- # # minn_dist_mask[:] = minn_dist_mask[11:12, :, :]
439
- # # minn_dist_mask[:11] = False
440
-
441
- # # if i_test_seq == 24:
442
- # # minn_dist_mask[20:] = minn_dist_mask[20:21, :, :]
443
- # # else:
444
-
445
- # # minn_dist_mask[:] = minn_dist_mask[11:12, :, :]
446
- # minn_dist_mask[:] = minn_dist_mask[20:21, :, :]
447
-
448
- attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
449
-
450
-
451
-
452
- # joints: nf x nn_jts_pts x 3; nf x nn_base_pts x 3
453
- dist_joints_to_base_pts_trans = torch.sum(
454
- (rhand_anchors.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nn_jts_pts x nn_base_pts
455
- )
456
- minn_dist_joints_to_base_pts, minn_dist_idxes = torch.min(dist_joints_to_base_pts_trans, dim=-1) # nf x nn_jts_pts # nf x nn_jts_pts #
457
- nearest_base_normals = model_util.batched_index_select_ours(tot_base_normals_trans, indices=minn_dist_idxes, dim=1) # nf x nn_base_pts x 3 --> nf x nn_jts_pts x 3 # # nf x nn_jts_pts x 3 #
458
- nearest_base_pts_trans = model_util.batched_index_select_ours(disp_base_pts_trans, indices=minn_dist_idxes, dim=1) # nf x nn_jts_ts x 3 #
459
- dot_nearest_base_normals_trans = torch.sum(
460
- nearest_base_normals * nearest_base_pts_trans, dim=-1 # nf x nn_jts
461
- )
462
- trans_normals_mask = dot_nearest_base_normals_trans < 0. # nf x nn_jts # nf x nn_jts #
463
- nearest_dist = torch.sqrt(minn_dist_joints_to_base_pts)
464
-
465
- # dist_thres
466
- nearest_dist_mask = nearest_dist < dist_thres # hoi seq
467
- # nearest_dist_mask = nearest_dist < 0.005 # hoi seq
468
-
469
- # nearest_dist_mask = nearest_dist < 0.03 # hoi seq
470
- # nearest_dist_mask = nearest_dist < 0.5 # hoi seq # seq 47
471
- # nearest_dist_mask = nearest_dist < 0.1 # hoi seq # seq 47
472
-
473
-
474
-
475
- # nearest_dist_mask = nearest_dist < 0.1
476
- k_attr = 100.
477
- joint_attraction_k = torch.exp(-1. * k_attr * nearest_dist)
478
- # attraction_mask_new_new = (attraction_mask_new.float() + trans_normals_mask.float().unsqueeze(-1) + nearest_dist_mask.float().unsqueeze(-1)) > 2.5
479
-
480
- attraction_mask_new_new = (attraction_mask_new.float() + nearest_dist_mask.float().unsqueeze(-1)) > 1.5
481
-
482
- # if cat_nm in ["ToyCar", "Pliers", "Bottle", "Mug", "Scissors"]:
483
- anchor_masks = [2, 3, 4, 9, 10, 11, 15, 16, 17, 22, 23, 24, 29, 30, 31]
484
- anchor_nmasks = [iid for iid in range(attraction_mask_new_new.size(1)) if iid not in anchor_masks]
485
- anchor_nmasks = torch.tensor(anchor_nmasks, dtype=torch.long).cuda()
486
- attraction_mask_new_new[:, anchor_nmasks, :] = False # scissors
487
-
488
- # # for seq 47
489
- # elif cat_nm in ["Scissors"]:
490
- # anchor_masks = [2, 3, 4, 15, 16, 17, 22, 23, 24]
491
- # anchor_nmasks = [iid for iid in range(attraction_mask_new_new.size(1)) if iid not in anchor_masks]
492
- # anchor_nmasks = torch.tensor(anchor_nmasks, dtype=torch.long).cuda()
493
- # attraction_mask_new_new[:, anchor_nmasks, :] = False
494
-
495
- # anchor_masks = torch.array([2, 3, 4, 15, 16, 17, 22, 23, 24], dtype=torch.long).cuda()
496
- # anchor_masks = torch.arange(attraction_mask_new_new.size(1)).unsqueeze(0).unsqueeze(-1).cuda() !=
497
-
498
- # [2, 3, 4]
499
- # [9, 10, 11]
500
- # [15, 16, 17]
501
- # [22, 23, 24]
502
- # seq 47: [2, 3, 4, 15, 16, 17, 22, 23, 24]
503
- # motion planning? #
504
-
505
-
506
- transl_var_ori = transl_var.clone().detach()
507
- # transl_var, theta_var, rot_var, beta_var #
508
- # opt = optim.Adam([rot_var, transl_var, theta_var], lr=learning_rate)
509
- # opt = optim.Adam([transl_var, theta_var], lr=learning_rate)
510
- opt = optim.Adam([transl_var, theta_var, rot_var], lr=learning_rate)
511
- # opt = optim.Adam([theta_var, rot_var], lr=learning_rate)
512
- scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
513
- for i in range(num_iters):
514
- opt.zero_grad()
515
- # mano_layer
516
- hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
517
- beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
518
- hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
519
- hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
 
 
520
 
521
- joints_pred_loss = torch.sum(
522
- (hand_joints - joints) ** 2, dim=-1
523
- ).mean()
524
 
525
- #
526
- rhand_anchors = recover_anchor_batch(hand_verts, face_vertex_index, anchor_weight.unsqueeze(0).repeat(window_size, 1, 1))
527
 
528
 
529
- # dist_joints_to_base_pts_sqr = torch.sum(
530
- # (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
531
- # ) # nf x nnb x 3 ---- nf x nnj x 1 x 3
532
- dist_joints_to_base_pts_sqr = torch.sum(
533
- (rhand_anchors.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1
534
- )
535
- # attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
536
- attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
537
- # attaction_loss = attaction_loss
538
- # attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
539
 
540
- # attaction_loss = torch.mean(attaction_loss * attraction_mask)
541
- # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
542
 
543
- # seq 80
544
- # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
545
 
546
- # seq 70
547
- # attaction_loss = torch.mean(attaction_loss[10:, -5:-3, :] * minn_dist_mask[10:, -5:-3, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
548
 
549
- # new version relying on new mask #
550
- # attaction_loss = torch.mean(attaction_loss[:, -5:-3, :] * attraction_mask_new[:, -5:-3, :])
551
- ### original version ###
552
- # attaction_loss = torch.mean(attaction_loss[20:, -3:, :] * attraction_mask_new[20:, -3:, :])
553
 
554
- # attaction_loss = torch.mean(attaction_loss[:, -5:, :] * attraction_mask_new_new[:, -5:, :] * joint_attraction_k[:, -5:].unsqueeze(-1))
555
 
556
 
557
 
558
- attaction_loss = torch.mean(attaction_loss[:, :, :] * attraction_mask_new_new[:, :, :] * joint_attraction_k[:, :].unsqueeze(-1))
559
 
560
 
561
- # seq mug
562
- # attaction_loss = torch.mean(attaction_loss[4:, -5:-4, :] * minn_dist_mask[4:, -5:-4, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
563
 
564
 
565
- # opt.zero_grad()
566
- pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
567
- # joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
568
- shape_prior_loss = torch.mean(beta_var**2)
569
- pose_prior_loss = torch.mean(theta_var**2)
570
- joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
571
- # =0.05
572
- # # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
573
- # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
574
 
575
- # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
576
 
577
- # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200.
578
 
579
- theta_smoothness_loss = F.mse_loss(theta_var, ori_theta_var)
580
 
581
- transl_smoothness_loss = F.mse_loss(transl_var_ori, transl_var)
582
- # loss = attaction_loss * 1000. + theta_smoothness_loss * 0.00001
583
 
584
- # attraction loss, joint prediction loss, joints smoothness loss #
585
- # loss = attaction_loss * 1000. + joints_pred_loss
586
- ### general ###
587
- # loss = attaction_loss * 1000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
588
 
589
- # tune for seq 140
590
- loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
591
- # ToyCar #
592
- loss = attaction_loss * 10000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
593
 
594
- # if cat_nm in ["Scissors"]:
595
- # # scissors and
596
- # # if dist_thres < 0.05:
597
- # # loss = transl_smoothness_loss * 0.5 + attaction_loss * 10000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5
598
- # # else:
599
- # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.05
600
- # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.005
601
- # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.005
602
-
603
- # if dist_thres < 0.05:
604
- # # loss = transl_smoothness_loss * 0.5 + attaction_loss * 10000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5
605
- # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.5
606
-
607
- # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + pose_smoothness_loss * 0.0005 + joints_smoothness_loss * 0.5
608
-
609
- # nearest_dist_shape, _ = torch.min(nearest_dist, dim=-1)
610
- # nearest_dist_shape_mask = nearest_dist_shape > 0.01
611
- # transl_smoothness_loss_v2 = torch.sum((transl_var_ori - transl_var) ** 2, dim=-1)
612
- # transl_smoothness_loss_v2 = torch.mean(transl_smoothness_loss_v2[nearest_dist_shape_mask])
613
- # loss = transl_smoothness_loss * 0.5 + attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.5
614
-
615
- # else:
616
- # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.05
617
- # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.005
618
 
619
- # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.005
620
 
621
- # loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
622
- # loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
623
 
624
- # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
625
 
626
- # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
627
 
628
- # loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
629
- # loss = joints_pred_loss * 30 + attaction_loss * 0.001
630
 
631
- opt.zero_grad()
632
- loss.backward()
633
- opt.step()
634
- scheduler.step()
635
 
636
- print('Iter {}: {}'.format(i, loss.item()), flush=True)
637
- print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
638
- print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
639
- print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
640
- print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
641
- print('\tAttraction Loss: {}'.format(attaction_loss.item()))
642
- print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
643
- # theta_smoothness_loss
644
- print('\tTheta Smoothness Loss: {}'.format(theta_smoothness_loss.item()))
645
- # transl_smoothness_loss
646
- print('\tTransl Smoothness Loss: {}'.format(transl_smoothness_loss.item()))
647
 
648
- rhand_anchors_np = rhand_anchors.detach().cpu().numpy()
649
- np.save("out_anchors.npy", rhand_anchors_np)
650
  ### optimized dict before projection ###
651
  bf_proj_optimized_dict = {
652
  'bf_ctx_mask_verts': hand_verts.detach().cpu().numpy(),
@@ -659,48 +662,48 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
659
 
660
 
661
 
662
- ###
663
- if with_ctx_mask:
664
- tot_penetration_masks_bf_contact_opt_frame_nmask_np = tot_penetration_masks_bf_contact_opt_frame_nmask.detach().cpu().numpy()
665
- hand_verts = hand_verts.detach()
666
- hand_joints = hand_joints.detach()
667
- rot_var = rot_var.detach()
668
- theta_var = theta_var.detach()
669
- beta_var = beta_var.detach()
670
- transl_var = transl_var.detach()
671
- hand_verts = hand_verts.detach()
672
-
673
- # tot_base_pts_trans_disp_mask ### total penetration masks
674
- tot_base_pts_trans_disp_mask_n = (1. - tot_base_pts_trans_disp_mask.float()) > 0.5 ### that object do not move
675
- tot_penetration_masks_bf_contact_opt_frame_nmask = (tot_penetration_masks_bf_contact_opt_frame_nmask.float() + tot_base_pts_trans_disp_mask_n.float()) > 1.5
676
- hand_verts[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_verts_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
677
- hand_joints[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_joints_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
678
- rot_var[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_rot_var_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
679
- theta_var[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_theta_var_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
680
- # beta_var[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_beta_var_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
681
- transl_var[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_transl_var_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
682
-
683
- rot_var_tmp = torch.randn_like(rot_var)
684
- theta_var_tmp = torch.randn_like(theta_var)
685
- transl_var_tmp = torch.randn_like(transl_var)
686
-
687
- rot_var_tmp.data = rot_var.data.clone()
688
- theta_var_tmp.data = theta_var.data.clone()
689
- transl_var_tmp.data = transl_var.data.clone()
690
-
691
- rot_var = torch.randn_like(rot_var)
692
- theta_var = torch.randn_like(theta_var)
693
- transl_var = torch.randn_like(transl_var)
694
-
695
- rot_var.data = rot_var_tmp.data.clone()
696
- theta_var.data = theta_var_tmp.data.clone()
697
- transl_var.data = transl_var_tmp.data.clone()
698
-
699
-
700
- rot_var = rot_var.requires_grad_()
701
- theta_var = theta_var.requires_grad_()
702
- beta_var = beta_var.requires_grad_()
703
- transl_var = transl_var.requires_grad_()
704
 
705
 
706
 
@@ -722,142 +725,142 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
722
  'bf_proj_transl_var': bf_proj_transl_var,
723
  } )
724
 
725
- if with_proj:
726
- num_iters = 2000
727
- num_iters = 1000 # seq 77 # if with contact opt #
728
- # num_iters = 500 # seq 77
729
- ori_theta_var = theta_var.detach().clone()
730
-
731
- nearest_base_pts=None
732
- nearest_base_normals=None
733
-
734
- # obj_verts_trans, obj_faces
735
- if cat_nm in ["Mug"]:
736
- # tot_penetration_masks = None
737
- tot_penetration_masks = get_penetration_masks(obj_verts_trans, obj_faces, hand_verts)
738
- else:
739
- tot_penetration_masks = get_penetration_masks(obj_verts_trans, obj_faces, hand_verts)
740
- # tot_penetration_masks = None
741
-
742
- # opt = optim.Adam([rot_var, transl_var, theta_var], lr=learning_rate)
743
- # opt = optim.Adam([transl_var, theta_var], lr=learning_rate)
744
- opt = optim.Adam([transl_var, theta_var, rot_var], lr=learning_rate)
745
- scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
746
- for i in range(num_iters):
747
- opt.zero_grad()
748
- # mano_layer
749
- hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
750
- beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
751
- hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
752
- hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
753
 
754
- # tot_base_pts_trans, tot_base_normals_trans #
755
- # obj_verts_trans, ### nearest base pts ####
756
- proj_loss, nearest_base_pts, nearest_base_normals = get_proj_losses(hand_verts, tot_base_pts_trans, tot_base_normals_trans, tot_penetration_masks, nearest_base_pts=nearest_base_pts, nearest_base_normals=nearest_base_normals)
757
 
758
- # rhand_anchors = recover_anchor_batch(hand_verts, face_vertex_index, anchor_weight.unsqueeze(0).repeat(window_size, 1, 1))
759
 
760
 
761
- # hand_joints = rhand_anchors
762
 
763
- joints_pred_loss = torch.sum(
764
- (hand_joints - joints) ** 2, dim=-1
765
- ).mean()
766
 
767
 
768
 
769
 
770
- # # dist_joints_to_base_pts_sqr = torch.sum(
771
- # # (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
772
- # # ) # nf x nnb x 3 ---- nf x nnj x 1 x 3
773
- # dist_joints_to_base_pts_sqr = torch.sum(
774
- # (rhand_anchors.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1
775
- # )
776
- # # attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
777
- # attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
778
- # # attaction_loss = attaction_loss
779
- # # attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
780
 
781
- # # attaction_loss = torch.mean(attaction_loss * attraction_mask)
782
- # # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
783
 
784
- # # seq 80
785
- # # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
786
 
787
- # # seq 70
788
- # # attaction_loss = torch.mean(attaction_loss[10:, -5:-3, :] * minn_dist_mask[10:, -5:-3, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
789
 
790
- # # new version relying on new mask #
791
- # # attaction_loss = torch.mean(attaction_loss[:, -5:-3, :] * attraction_mask_new[:, -5:-3, :])
792
- # ### original version ###
793
- # # attaction_loss = torch.mean(attaction_loss[20:, -3:, :] * attraction_mask_new[20:, -3:, :])
794
 
795
- # # attaction_loss = torch.mean(attaction_loss[:, -5:, :] * attraction_mask_new_new[:, -5:, :] * joint_attraction_k[:, -5:].unsqueeze(-1))
796
 
797
 
798
 
799
- # attaction_loss = torch.mean(attaction_loss[:, :, :] * attraction_mask_new_new[:, :, :] * joint_attraction_k[:, :].unsqueeze(-1))
800
 
801
 
802
- # seq mug
803
- # attaction_loss = torch.mean(attaction_loss[4:, -5:-4, :] * minn_dist_mask[4:, -5:-4, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
804
 
805
 
806
- # opt.zero_grad()
807
- pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
808
- # joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
809
- shape_prior_loss = torch.mean(beta_var**2)
810
- pose_prior_loss = torch.mean(theta_var**2)
811
- joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
812
- # =0.05
813
- # # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
814
- # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
815
 
816
- # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
817
 
818
- # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200.
819
 
820
- theta_smoothness_loss = F.mse_loss(theta_var, ori_theta_var)
821
- # loss = attaction_loss * 1000. + theta_smoothness_loss * 0.00001
822
 
823
- # attraction loss, joint prediction loss, joints smoothness loss #
824
- # loss = attaction_loss * 1000. + joints_pred_loss
825
- ### general ###
826
- # loss = attaction_loss * 1000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
827
 
828
- # tune for seq 140
829
- # loss = proj_loss * 1. + joints_pred_loss * 0.05 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
830
- loss = proj_loss * 1. + joints_pred_loss * 1.0 + joints_smoothness_loss * 0.5
831
- if cat_nm in ["Pliers"]:
832
- # loss = proj_loss * 1. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.05
833
- loss = proj_loss * 1. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.05
834
-
835
- elif cat_nm in ["Bottle"]:
836
- loss = proj_loss * 1. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.05
837
- # loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
838
- # loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
839
 
840
- # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
841
 
842
- # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
843
 
844
- # loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
845
- # loss = joints_pred_loss * 30 + attaction_loss * 0.001
846
 
847
- opt.zero_grad()
848
- loss.backward()
849
- opt.step()
850
- scheduler.step()
851
 
852
- print('Iter {}: {}'.format(i, loss.item()), flush=True)
853
- print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
854
- print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
855
- print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
856
- print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
857
- print('\tproj_loss Loss: {}'.format(proj_loss.item()))
858
- print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
859
- # theta_smoothness_loss
860
- print('\tTheta Smoothness Loss: {}'.format(theta_smoothness_loss.item()))
861
 
862
  ### ### verts and joints before contact opt ### ###
863
  # bf_ct_verts, bf_ct_joints #
@@ -1091,12 +1094,17 @@ def get_proj_losses(hand_verts, base_pts, base_normals, tot_penetration_masks, n
1091
 
1092
 
1093
  ## optimization ##
1094
- if __name__=='__main__':
 
 
1095
 
1096
- args = train_args()
1097
 
1098
-
1099
- single_seq_path = args.single_seq_path
 
 
 
 
1100
 
1101
  print(f"Reconstructing meshes for sequence: {single_seq_path}")
1102
 
@@ -1112,20 +1120,27 @@ if __name__=='__main__':
1112
  st_idxes = list(range(0, num_ending_clearning_frames, nn_st_skip))
1113
  if st_idxes[-1] + num_cleaning_frames < nn_frames:
1114
  st_idxes.append(nn_frames - num_cleaning_frames)
 
 
1115
  print(f"st_idxes: {st_idxes}")
1116
 
 
 
 
1117
  clip_ending_idxes = [st_idxes[cur_ts + 1] - st_idxes[cur_ts] for cur_ts in range(len(st_idxes) - 1)]
1118
 
1119
 
1120
- test_tags = [f"{args.test_tag}_st_{cur_st}" for cur_st in st_idxes]
1121
 
1122
 
1123
- pred_infos_sv_folder = args.save_dir
1124
- new_pred_infos_sv_folder = args.save_dir
1125
 
1126
 
1127
 
1128
- tot_rnd_seeds = range(0, 121, 11)
 
 
1129
 
1130
  # st_idxx = 0
1131
 
@@ -1227,8 +1242,8 @@ if __name__=='__main__':
1227
 
1228
  #### obj_verts_trans, obj_faces ####
1229
  obj_verts_trans = np.matmul(obj_verts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
1230
- obj_faces = data['template_obj_fs']
1231
- print(f"obj_verts_trans: {obj_verts_trans.shape}, obj_faces: {obj_faces.shape}")
1232
 
1233
 
1234
 
@@ -1267,10 +1282,12 @@ if __name__=='__main__':
1267
  )
1268
 
1269
 
 
 
1270
  optimized_sv_infos_sv_fn_nm = f"optimized_infos_sv_dict_seed_{seed}_tag_{test_tags[0]}_ntag_{len(test_tags)}.npy"
1271
 
1272
  optimized_sv_infos_sv_fn = os.path.join(new_pred_infos_sv_folder, optimized_sv_infos_sv_fn_nm)
1273
  np.save(optimized_sv_infos_sv_fn, optimized_sv_infos)
1274
  print(f"optimized infos saved to {optimized_sv_infos_sv_fn}")
1275
 
1276
-
 
9
  import os, argparse, copy, json
10
  import pickle as pkl
11
  from scipy.spatial.transform import Rotation as R
12
+ # from psbody.mesh import Mesh
13
  from manopth.manolayer import ManoLayer
14
 
15
  import trimesh
16
  from utils import *
17
+ # import utils
18
  import utils.model_util as model_util
19
+ # from utils.anchor_utils import masking_load_driver, anchor_load_driver, recover_anchor_batch
20
 
21
  from utils.parser_util import train_args
22
 
 
50
  # interpolatio for unsmooth hand parameters? #
51
  def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False, nn_hand_params=24, rt_vars=False, with_proj=False, obj_verts_trans=None, obj_faces=None, with_params_smoothing=False, dist_thres=0.005, with_ctx_mask=False):
52
  # obj_verts_trans, obj_faces
53
+ joints = torch.from_numpy(joints).float() # # joints
54
+ base_pts = torch.from_numpy(base_pts).float() # # base_pts
55
 
56
  if nn_hand_params < 45:
57
  use_pca = True
58
  else:
59
  use_pca = False
60
 
61
+ tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float()
62
+ tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float()
63
  ### start optimization ###
64
  # setup MANO layer
65
+ # mano_path = "/data1/xueyi/mano_models/mano/models"
66
+ mano_path = "./"
67
  nn_hand_params = 24
68
  use_pca = True
69
+ # if not use_left:
70
+ # mano_layer = ManoLayer(
71
+ # flat_hand_mean=False,
72
+ # side='right',
73
+ # mano_root=mano_path, # mano_root #
74
+ # ncomps=nn_hand_params, # hand params #
75
+ # use_pca=use_pca, # pca for pca #
76
+ # root_rot_mode='axisang',
77
+ # joint_rot_mode='axisang'
78
+ # )
79
+ # else:
80
+ mano_layer = ManoLayer(
81
+ flat_hand_mean=True,
82
+ side='left',
83
+ mano_root=mano_path, # mano_root #
84
+ ncomps=nn_hand_params, # hand params #
85
+ use_pca=use_pca, # pca for pca #
86
+ root_rot_mode='axisang',
87
+ joint_rot_mode='axisang'
88
+ )
89
 
90
  # mano_layer = ManoLayer(
91
  # flat_hand_mean=False,
 
95
  # use_pca=use_pca, # pca for pca #
96
  # root_rot_mode='axisang',
97
  # joint_rot_mode='axisang'
98
+ # )
99
  nn_frames = joints.size(0)
100
 
101
 
102
  # anchor_load_driver, masking_load_driver #
103
+ # inpath = "/home/xueyi/sim/CPF/assets" # contact potential field; assets # ##
104
+ # fvi, aw, _, _ = anchor_load_driver(inpath)
105
+ # face_vertex_index = torch.from_numpy(fvi).long()
106
+ # anchor_weight = torch.from_numpy(aw).float()
107
 
108
+ # anchor_path = os.path.join("/home/xueyi/sim/CPF/assets", "anchor")
109
+ # palm_path = os.path.join("/home/xueyi/sim/CPF/assets", "hand_palm_full.txt")
110
+ # hand_region_assignment, hand_palm_vertex_mask = masking_load_driver(anchor_path, palm_path)
111
+ # # self.hand_palm_vertex_mask for hand palm mask #
112
+ # hand_palm_vertex_mask = torch.from_numpy(hand_palm_vertex_mask).bool() ## the mask for hand palm to get hand anchors #
113
 
114
 
115
 
116
 
117
  # initialize variables
118
+ beta_var = torch.randn([1, 10])
119
  # first 3 global orientation
120
+ rot_var = torch.randn([nn_frames, 3])
121
+ theta_var = torch.randn([nn_frames, nn_hand_params])
122
+ transl_var = torch.randn([nn_frames, 3])
123
 
124
  # 3 + 45 + 3 = 51 for computing #
125
  # transl_var = tot_rhand_transl.unsqueeze(0).repeat(args.num_init, 1, 1).contiguous().to(device).view(args.num_init * num_frames, 3).contiguous()
 
152
  nk_contact_pts = 2
153
  minn_dist[:, :-5] = 1e9
154
  minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
155
+ # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0) ==
156
  minn_topk_mask = torch.zeros_like(minn_dist)
157
  # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
158
  minn_topk_mask[:, -5: -3] = 1.
159
+ basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0)
160
  minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
161
  # for seq 101
162
  # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
 
350
  else:
351
  smoothed_theta_var.append(cur_theta_var.detach().clone())
352
  smoothed_theta_var = torch.stack(smoothed_theta_var, dim=0) # smoothed_theta_var: nf x nn_theta_dim
353
+ theta_var = torch.randn_like(smoothed_theta_var)
354
  theta_var.data = smoothed_theta_var.data
355
  # theta_var = smoothed_theta_var.clone()
356
  theta_var.requires_grad_() # for
 
396
 
397
 
398
  window_size = hand_verts.size(0)
399
+
400
+
401
+ # if with_contact_opt:
402
+ # num_iters = 2000
403
+ # num_iters = 1000 # seq 77 # if with contact opt #
404
+ # # num_iters = 500 # seq 77
405
+ # ori_theta_var = theta_var.detach().clone()
406
+
407
+ # # tot_base_pts_trans # nf x nn_base_pts x 3
408
+ # disp_base_pts_trans = tot_base_pts_trans[1:] - tot_base_pts_trans[:-1] # (nf - 1) x nn_base_pts x 3
409
+ # disp_base_pts_trans = torch.cat( # nf x nn_base_pts x 3
410
+ # [disp_base_pts_trans, disp_base_pts_trans[-1:]], dim=0
411
+ # )
412
+
413
+ # rhand_anchors = recover_anchor_batch(hand_verts.detach(), face_vertex_index, anchor_weight.unsqueeze(0).repeat(window_size, 1, 1))
414
+
415
+ # dist_joints_to_base_pts = torch.sum(
416
+ # (rhand_anchors.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
417
+ # )
418
+
419
+ # nn_base_pts = dist_joints_to_base_pts.size(-1)
420
+ # nn_joints = dist_joints_to_base_pts.size(1)
421
+
422
+ # dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
423
+ # minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
424
+
425
+ # nk_contact_pts = 2
426
+ # # minn_dist[:, :-5] = 1e9
427
+ # # minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
428
+ # # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0) ==
429
+ # # minn_topk_mask = torch.zeros_like(minn_dist)
430
+ # # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
431
+ # # minn_topk_mask[:, -5: -3] = 1.
432
+ # basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0)
433
+ # minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
434
+ # # for seq 101
435
+ # # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
436
+ # minn_dist_mask = minn_dist_mask.float()
437
+
438
+ # #
439
+ # # for seq 47
440
+ # # if cat_nm in ["Scissors"]:
441
+ # # # minn_dist_mask[:] = minn_dist_mask[11:12, :, :]
442
+ # # # minn_dist_mask[:11] = False
443
+
444
+ # # # if i_test_seq == 24:
445
+ # # # minn_dist_mask[20:] = minn_dist_mask[20:21, :, :]
446
+ # # # else:
447
+
448
+ # # # minn_dist_mask[:] = minn_dist_mask[11:12, :, :]
449
+ # # minn_dist_mask[:] = minn_dist_mask[20:21, :, :]
450
+
451
+ # attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
452
+
453
+
454
+
455
+ # # joints: nf x nn_jts_pts x 3; nf x nn_base_pts x 3
456
+ # dist_joints_to_base_pts_trans = torch.sum(
457
+ # (rhand_anchors.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nn_jts_pts x nn_base_pts
458
+ # )
459
+ # minn_dist_joints_to_base_pts, minn_dist_idxes = torch.min(dist_joints_to_base_pts_trans, dim=-1) # nf x nn_jts_pts # nf x nn_jts_pts #
460
+ # nearest_base_normals = model_util.batched_index_select_ours(tot_base_normals_trans, indices=minn_dist_idxes, dim=1) # nf x nn_base_pts x 3 --> nf x nn_jts_pts x 3 # # nf x nn_jts_pts x 3 #
461
+ # nearest_base_pts_trans = model_util.batched_index_select_ours(disp_base_pts_trans, indices=minn_dist_idxes, dim=1) # nf x nn_jts_ts x 3 #
462
+ # dot_nearest_base_normals_trans = torch.sum(
463
+ # nearest_base_normals * nearest_base_pts_trans, dim=-1 # nf x nn_jts
464
+ # )
465
+ # trans_normals_mask = dot_nearest_base_normals_trans < 0. # nf x nn_jts # nf x nn_jts #
466
+ # nearest_dist = torch.sqrt(minn_dist_joints_to_base_pts)
467
+
468
+ # # dist_thres
469
+ # nearest_dist_mask = nearest_dist < dist_thres # hoi seq
470
+ # # nearest_dist_mask = nearest_dist < 0.005 # hoi seq
471
+
472
+ # # nearest_dist_mask = nearest_dist < 0.03 # hoi seq
473
+ # # nearest_dist_mask = nearest_dist < 0.5 # hoi seq # seq 47
474
+ # # nearest_dist_mask = nearest_dist < 0.1 # hoi seq # seq 47
475
+
476
+
477
+
478
+ # # nearest_dist_mask = nearest_dist < 0.1
479
+ # k_attr = 100.
480
+ # joint_attraction_k = torch.exp(-1. * k_attr * nearest_dist)
481
+ # # attraction_mask_new_new = (attraction_mask_new.float() + trans_normals_mask.float().unsqueeze(-1) + nearest_dist_mask.float().unsqueeze(-1)) > 2.5
482
+
483
+ # attraction_mask_new_new = (attraction_mask_new.float() + nearest_dist_mask.float().unsqueeze(-1)) > 1.5
484
+
485
+ # # if cat_nm in ["ToyCar", "Pliers", "Bottle", "Mug", "Scissors"]:
486
+ # anchor_masks = [2, 3, 4, 9, 10, 11, 15, 16, 17, 22, 23, 24, 29, 30, 31]
487
+ # anchor_nmasks = [iid for iid in range(attraction_mask_new_new.size(1)) if iid not in anchor_masks]
488
+ # anchor_nmasks = torch.tensor(anchor_nmasks, dtype=torch.long)
489
+ # attraction_mask_new_new[:, anchor_nmasks, :] = False # scissors
490
+
491
+ # # # for seq 47
492
+ # # elif cat_nm in ["Scissors"]:
493
+ # # anchor_masks = [2, 3, 4, 15, 16, 17, 22, 23, 24]
494
+ # # anchor_nmasks = [iid for iid in range(attraction_mask_new_new.size(1)) if iid not in anchor_masks]
495
+ # # anchor_nmasks = torch.tensor(anchor_nmasks, dtype=torch.long)
496
+ # # attraction_mask_new_new[:, anchor_nmasks, :] = False
497
+
498
+ # # anchor_masks = torch.array([2, 3, 4, 15, 16, 17, 22, 23, 24], dtype=torch.long)
499
+ # # anchor_masks = torch.arange(attraction_mask_new_new.size(1)).unsqueeze(0).unsqueeze(-1) !=
500
+
501
+ # # [2, 3, 4]
502
+ # # [9, 10, 11]
503
+ # # [15, 16, 17]
504
+ # # [22, 23, 24]
505
+ # # seq 47: [2, 3, 4, 15, 16, 17, 22, 23, 24]
506
+ # # motion planning? #
507
+
508
+
509
+ # transl_var_ori = transl_var.clone().detach()
510
+ # # transl_var, theta_var, rot_var, beta_var #
511
+ # # opt = optim.Adam([rot_var, transl_var, theta_var], lr=learning_rate)
512
+ # # opt = optim.Adam([transl_var, theta_var], lr=learning_rate)
513
+ # opt = optim.Adam([transl_var, theta_var, rot_var], lr=learning_rate)
514
+ # # opt = optim.Adam([theta_var, rot_var], lr=learning_rate)
515
+ # scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
516
+ # for i in range(num_iters):
517
+ # opt.zero_grad()
518
+ # # mano_layer
519
+ # hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
520
+ # beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
521
+ # hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
522
+ # hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
523
 
524
+ # joints_pred_loss = torch.sum(
525
+ # (hand_joints - joints) ** 2, dim=-1
526
+ # ).mean()
527
 
528
+ # #
529
+ # rhand_anchors = recover_anchor_batch(hand_verts, face_vertex_index, anchor_weight.unsqueeze(0).repeat(window_size, 1, 1))
530
 
531
 
532
+ # # dist_joints_to_base_pts_sqr = torch.sum(
533
+ # # (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
534
+ # # ) # nf x nnb x 3 ---- nf x nnj x 1 x 3
535
+ # dist_joints_to_base_pts_sqr = torch.sum(
536
+ # (rhand_anchors.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1
537
+ # )
538
+ # # attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
539
+ # attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
540
+ # # attaction_loss = attaction_loss
541
+ # # attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
542
 
543
+ # # attaction_loss = torch.mean(attaction_loss * attraction_mask)
544
+ # # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
545
 
546
+ # # seq 80
547
+ # # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
548
 
549
+ # # seq 70
550
+ # # attaction_loss = torch.mean(attaction_loss[10:, -5:-3, :] * minn_dist_mask[10:, -5:-3, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
551
 
552
+ # # new version relying on new mask #
553
+ # # attaction_loss = torch.mean(attaction_loss[:, -5:-3, :] * attraction_mask_new[:, -5:-3, :])
554
+ # ### original version ###
555
+ # # attaction_loss = torch.mean(attaction_loss[20:, -3:, :] * attraction_mask_new[20:, -3:, :])
556
 
557
+ # # attaction_loss = torch.mean(attaction_loss[:, -5:, :] * attraction_mask_new_new[:, -5:, :] * joint_attraction_k[:, -5:].unsqueeze(-1))
558
 
559
 
560
 
561
+ # attaction_loss = torch.mean(attaction_loss[:, :, :] * attraction_mask_new_new[:, :, :] * joint_attraction_k[:, :].unsqueeze(-1))
562
 
563
 
564
+ # # seq mug
565
+ # # attaction_loss = torch.mean(attaction_loss[4:, -5:-4, :] * minn_dist_mask[4:, -5:-4, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
566
 
567
 
568
+ # # opt.zero_grad()
569
+ # pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
570
+ # # joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
571
+ # shape_prior_loss = torch.mean(beta_var**2)
572
+ # pose_prior_loss = torch.mean(theta_var**2)
573
+ # joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
574
+ # # =0.05
575
+ # # # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
576
+ # # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
577
 
578
+ # # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
579
 
580
+ # # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200.
581
 
582
+ # theta_smoothness_loss = F.mse_loss(theta_var, ori_theta_var)
583
 
584
+ # transl_smoothness_loss = F.mse_loss(transl_var_ori, transl_var)
585
+ # # loss = attaction_loss * 1000. + theta_smoothness_loss * 0.00001
586
 
587
+ # # attraction loss, joint prediction loss, joints smoothness loss #
588
+ # # loss = attaction_loss * 1000. + joints_pred_loss
589
+ # ### general ###
590
+ # # loss = attaction_loss * 1000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
591
 
592
+ # # tune for seq 140
593
+ # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
594
+ # # ToyCar #
595
+ # loss = attaction_loss * 10000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
596
 
597
+ # # if cat_nm in ["Scissors"]:
598
+ # # # scissors and
599
+ # # # if dist_thres < 0.05:
600
+ # # # loss = transl_smoothness_loss * 0.5 + attaction_loss * 10000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5
601
+ # # # else:
602
+ # # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.05
603
+ # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.005
604
+ # # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.005
605
+
606
+ # # if dist_thres < 0.05:
607
+ # # # loss = transl_smoothness_loss * 0.5 + attaction_loss * 10000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5
608
+ # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.5
609
+
610
+ # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + pose_smoothness_loss * 0.0005 + joints_smoothness_loss * 0.5
611
+
612
+ # # nearest_dist_shape, _ = torch.min(nearest_dist, dim=-1)
613
+ # # nearest_dist_shape_mask = nearest_dist_shape > 0.01
614
+ # # transl_smoothness_loss_v2 = torch.sum((transl_var_ori - transl_var) ** 2, dim=-1)
615
+ # # transl_smoothness_loss_v2 = torch.mean(transl_smoothness_loss_v2[nearest_dist_shape_mask])
616
+ # # loss = transl_smoothness_loss * 0.5 + attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.5
617
+
618
+ # # else:
619
+ # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.05
620
+ # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.005
621
 
622
+ # # # loss = attaction_loss * 10000. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.005
623
 
624
+ # # loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
625
+ # # loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
626
 
627
+ # # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
628
 
629
+ # # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
630
 
631
+ # # loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
632
+ # # loss = joints_pred_loss * 30 + attaction_loss * 0.001
633
 
634
+ # opt.zero_grad()
635
+ # loss.backward()
636
+ # opt.step()
637
+ # scheduler.step()
638
 
639
+ # print('Iter {}: {}'.format(i, loss.item()), flush=True)
640
+ # print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
641
+ # print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
642
+ # print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
643
+ # print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
644
+ # print('\tAttraction Loss: {}'.format(attaction_loss.item()))
645
+ # print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
646
+ # # theta_smoothness_loss
647
+ # print('\tTheta Smoothness Loss: {}'.format(theta_smoothness_loss.item()))
648
+ # # transl_smoothness_loss
649
+ # print('\tTransl Smoothness Loss: {}'.format(transl_smoothness_loss.item()))
650
 
651
+ # rhand_anchors_np = rhand_anchors.detach().cpu().numpy()
652
+ # np.save("out_anchors.npy", rhand_anchors_np)
653
  ### optimized dict before projection ###
654
  bf_proj_optimized_dict = {
655
  'bf_ctx_mask_verts': hand_verts.detach().cpu().numpy(),
 
662
 
663
 
664
 
665
+ # ###
666
+ # if with_ctx_mask:
667
+ # tot_penetration_masks_bf_contact_opt_frame_nmask_np = tot_penetration_masks_bf_contact_opt_frame_nmask.detach().cpu().numpy()
668
+ # hand_verts = hand_verts.detach()
669
+ # hand_joints = hand_joints.detach()
670
+ # rot_var = rot_var.detach()
671
+ # theta_var = theta_var.detach()
672
+ # beta_var = beta_var.detach()
673
+ # transl_var = transl_var.detach()
674
+ # hand_verts = hand_verts.detach()
675
+
676
+ # # tot_base_pts_trans_disp_mask ### total penetration masks
677
+ # tot_base_pts_trans_disp_mask_n = (1. - tot_base_pts_trans_disp_mask.float()) > 0.5 ### that object do not move
678
+ # tot_penetration_masks_bf_contact_opt_frame_nmask = (tot_penetration_masks_bf_contact_opt_frame_nmask.float() + tot_base_pts_trans_disp_mask_n.float()) > 1.5
679
+ # hand_verts[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_verts_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
680
+ # hand_joints[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_joints_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
681
+ # rot_var[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_rot_var_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
682
+ # theta_var[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_theta_var_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
683
+ # # beta_var[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_beta_var_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
684
+ # transl_var[tot_penetration_masks_bf_contact_opt_frame_nmask] = bf_ct_transl_var_th[tot_penetration_masks_bf_contact_opt_frame_nmask]
685
+
686
+ # rot_var_tmp = torch.randn_like(rot_var)
687
+ # theta_var_tmp = torch.randn_like(theta_var)
688
+ # transl_var_tmp = torch.randn_like(transl_var)
689
+
690
+ # rot_var_tmp.data = rot_var.data.clone()
691
+ # theta_var_tmp.data = theta_var.data.clone()
692
+ # transl_var_tmp.data = transl_var.data.clone()
693
+
694
+ # rot_var = torch.randn_like(rot_var)
695
+ # theta_var = torch.randn_like(theta_var)
696
+ # transl_var = torch.randn_like(transl_var)
697
+
698
+ # rot_var.data = rot_var_tmp.data.clone()
699
+ # theta_var.data = theta_var_tmp.data.clone()
700
+ # transl_var.data = transl_var_tmp.data.clone()
701
+
702
+
703
+ # rot_var = rot_var.requires_grad_()
704
+ # theta_var = theta_var.requires_grad_()
705
+ # beta_var = beta_var.requires_grad_()
706
+ # transl_var = transl_var.requires_grad_()
707
 
708
 
709
 
 
725
  'bf_proj_transl_var': bf_proj_transl_var,
726
  } )
727
 
728
+ # if with_proj:
729
+ # num_iters = 2000
730
+ # num_iters = 1000 # seq 77 # if with contact opt #
731
+ # # num_iters = 500 # seq 77
732
+ # ori_theta_var = theta_var.detach().clone()
733
+
734
+ # nearest_base_pts=None
735
+ # nearest_base_normals=None
736
+
737
+ # # obj_verts_trans, obj_faces
738
+ # if cat_nm in ["Mug"]:
739
+ # # tot_penetration_masks = None
740
+ # tot_penetration_masks = get_penetration_masks(obj_verts_trans, obj_faces, hand_verts)
741
+ # else:
742
+ # tot_penetration_masks = get_penetration_masks(obj_verts_trans, obj_faces, hand_verts)
743
+ # # tot_penetration_masks = None
744
+
745
+ # # opt = optim.Adam([rot_var, transl_var, theta_var], lr=learning_rate)
746
+ # # opt = optim.Adam([transl_var, theta_var], lr=learning_rate)
747
+ # opt = optim.Adam([transl_var, theta_var, rot_var], lr=learning_rate)
748
+ # scheduler = optim.lr_scheduler.StepLR(opt, step_size=num_iters, gamma=0.5)
749
+ # for i in range(num_iters):
750
+ # opt.zero_grad()
751
+ # # mano_layer
752
+ # hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
753
+ # beta_var.unsqueeze(1).repeat(1, nn_frames, 1).view(-1, 10), transl_var)
754
+ # hand_verts = hand_verts.view( nn_frames, 778, 3) * 0.001
755
+ # hand_joints = hand_joints.view(nn_frames, -1, 3) * 0.001
756
 
757
+ # # tot_base_pts_trans, tot_base_normals_trans #
758
+ # # obj_verts_trans, ### nearest base pts ####
759
+ # proj_loss, nearest_base_pts, nearest_base_normals = get_proj_losses(hand_verts, tot_base_pts_trans, tot_base_normals_trans, tot_penetration_masks, nearest_base_pts=nearest_base_pts, nearest_base_normals=nearest_base_normals)
760
 
761
+ # # rhand_anchors = recover_anchor_batch(hand_verts, face_vertex_index, anchor_weight.unsqueeze(0).repeat(window_size, 1, 1))
762
 
763
 
764
+ # # hand_joints = rhand_anchors
765
 
766
+ # joints_pred_loss = torch.sum(
767
+ # (hand_joints - joints) ** 2, dim=-1
768
+ # ).mean()
769
 
770
 
771
 
772
 
773
+ # # # dist_joints_to_base_pts_sqr = torch.sum(
774
+ # # # (hand_joints.unsqueeze(2) - base_pts.unsqueeze(0).unsqueeze(1)) ** 2, dim=-1
775
+ # # # ) # nf x nnb x 3 ---- nf x nnj x 1 x 3
776
+ # # dist_joints_to_base_pts_sqr = torch.sum(
777
+ # # (rhand_anchors.unsqueeze(2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1
778
+ # # )
779
+ # # # attaction_loss = 0.5 * affinity_scores * dist_joints_to_base_pts_sqr
780
+ # # attaction_loss = 0.5 * dist_joints_to_base_pts_sqr
781
+ # # # attaction_loss = attaction_loss
782
+ # # # attaction_loss = torch.mean(attaction_loss[..., -5:, :] * minn_dist_mask[..., -5:, :])
783
 
784
+ # # # attaction_loss = torch.mean(attaction_loss * attraction_mask)
785
+ # # # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
786
 
787
+ # # # seq 80
788
+ # # # attaction_loss = torch.mean(attaction_loss[46:, -5:-3, :] * minn_dist_mask[46:, -5:-3, :]) + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
789
 
790
+ # # # seq 70
791
+ # # # attaction_loss = torch.mean(attaction_loss[10:, -5:-3, :] * minn_dist_mask[10:, -5:-3, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
792
 
793
+ # # # new version relying on new mask #
794
+ # # # attaction_loss = torch.mean(attaction_loss[:, -5:-3, :] * attraction_mask_new[:, -5:-3, :])
795
+ # # ### original version ###
796
+ # # # attaction_loss = torch.mean(attaction_loss[20:, -3:, :] * attraction_mask_new[20:, -3:, :])
797
 
798
+ # # # attaction_loss = torch.mean(attaction_loss[:, -5:, :] * attraction_mask_new_new[:, -5:, :] * joint_attraction_k[:, -5:].unsqueeze(-1))
799
 
800
 
801
 
802
+ # # attaction_loss = torch.mean(attaction_loss[:, :, :] * attraction_mask_new_new[:, :, :] * joint_attraction_k[:, :].unsqueeze(-1))
803
 
804
 
805
+ # # seq mug
806
+ # # attaction_loss = torch.mean(attaction_loss[4:, -5:-4, :] * minn_dist_mask[4:, -5:-4, :]) # + torch.mean(attaction_loss[:40, -5:-3, :] * minn_dist_mask[:40, -5:-3, :])
807
 
808
 
809
+ # # opt.zero_grad()
810
+ # pose_smoothness_loss = F.mse_loss(theta_var.view(nn_frames, -1)[1:], theta_var.view(nn_frames, -1)[:-1])
811
+ # # joints_smoothness_loss = joint_acc_loss(hand_verts, J_regressor.to(device))
812
+ # shape_prior_loss = torch.mean(beta_var**2)
813
+ # pose_prior_loss = torch.mean(theta_var**2)
814
+ # joints_smoothness_loss = F.mse_loss(hand_joints.view(nn_frames, -1, 3)[1:], hand_joints.view(nn_frames, -1, 3)[:-1])
815
+ # # =0.05
816
+ # # # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001 + joints_smoothness_loss * 100.
817
+ # # loss = joints_pred_loss * 30 + pose_smoothness_loss * 0.000001 + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001 + joints_smoothness_loss * 200.
818
 
819
+ # # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.05 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + joints_smoothness_loss * 200.
820
 
821
+ # # loss = joints_pred_loss * 5000 + pose_smoothness_loss * 0.03 + shape_prior_loss * 0.0002 + pose_prior_loss * 0.0005 # + attaction_loss * 10000 # + joints_smoothness_loss * 200.
822
 
823
+ # theta_smoothness_loss = F.mse_loss(theta_var, ori_theta_var)
824
+ # # loss = attaction_loss * 1000. + theta_smoothness_loss * 0.00001
825
 
826
+ # # attraction loss, joint prediction loss, joints smoothness loss #
827
+ # # loss = attaction_loss * 1000. + joints_pred_loss
828
+ # ### general ###
829
+ # # loss = attaction_loss * 1000. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
830
 
831
+ # # tune for seq 140
832
+ # # loss = proj_loss * 1. + joints_pred_loss * 0.05 + joints_smoothness_loss * 0.5 # + pose_prior_loss * 0.00005 # + shape_prior_loss * 0.001 # + pose_smoothness_loss * 0.5
833
+ # loss = proj_loss * 1. + joints_pred_loss * 1.0 + joints_smoothness_loss * 0.5
834
+ # if cat_nm in ["Pliers"]:
835
+ # # loss = proj_loss * 1. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.05
836
+ # loss = proj_loss * 1. + joints_pred_loss * 0.0001 + joints_smoothness_loss * 0.05
837
+
838
+ # elif cat_nm in ["Bottle"]:
839
+ # loss = proj_loss * 1. + joints_pred_loss * 0.01 + joints_smoothness_loss * 0.05
840
+ # # loss = joints_pred_loss * 20 + joints_smoothness_loss * 200. + shape_prior_loss * 0.0001 + pose_prior_loss * 0.00001
841
+ # # loss = joints_pred_loss * 30 # + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
842
 
843
+ # # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 0.001 + joints_smoothness_loss * 1.0
844
 
845
+ # # loss = joints_pred_loss * 20 + pose_smoothness_loss * 0.5 + attaction_loss * 2000000. # + joints_smoothness_loss * 10.0
846
 
847
+ # # loss = joints_pred_loss * 20 # + pose_smoothness_loss * 0.5 + attaction_loss * 100. + joints_smoothness_loss * 10.0
848
+ # # loss = joints_pred_loss * 30 + attaction_loss * 0.001
849
 
850
+ # opt.zero_grad()
851
+ # loss.backward()
852
+ # opt.step()
853
+ # scheduler.step()
854
 
855
+ # print('Iter {}: {}'.format(i, loss.item()), flush=True)
856
+ # print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
857
+ # print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
858
+ # print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
859
+ # print('\tJoints Prediction Loss: {}'.format(joints_pred_loss.item()))
860
+ # print('\tproj_loss Loss: {}'.format(proj_loss.item()))
861
+ # print('\tJoint Smoothness Loss: {}'.format(joints_smoothness_loss.item()))
862
+ # # theta_smoothness_loss
863
+ # print('\tTheta Smoothness Loss: {}'.format(theta_smoothness_loss.item()))
864
 
865
  ### ### verts and joints before contact opt ### ###
866
  # bf_ct_verts, bf_ct_joints #
 
1094
 
1095
 
1096
  ## optimization ##
1097
+ # if __name__=='__main__':
1098
+ def reconstruct_from_file(single_seq_path):
1099
+ # args = train_args()
1100
 
 
1101
 
1102
+ # use_left = True
1103
+ test_tag = "20231104_017_jts_spatial_t_100_"
1104
+ save_dir = "/tmp/denoising/save"
1105
+ # single_seq_path =
1106
+
1107
+ single_seq_path = single_seq_path
1108
 
1109
  print(f"Reconstructing meshes for sequence: {single_seq_path}")
1110
 
 
1120
  st_idxes = list(range(0, num_ending_clearning_frames, nn_st_skip))
1121
  if st_idxes[-1] + num_cleaning_frames < nn_frames:
1122
  st_idxes.append(nn_frames - num_cleaning_frames)
1123
+
1124
+ st_idxes = [st_idxes[0]]
1125
  print(f"st_idxes: {st_idxes}")
1126
 
1127
+
1128
+
1129
+
1130
  clip_ending_idxes = [st_idxes[cur_ts + 1] - st_idxes[cur_ts] for cur_ts in range(len(st_idxes) - 1)]
1131
 
1132
 
1133
+ test_tags = [f"{test_tag}_st_{cur_st}" for cur_st in st_idxes]
1134
 
1135
 
1136
+ pred_infos_sv_folder = save_dir
1137
+ new_pred_infos_sv_folder = save_dir
1138
 
1139
 
1140
 
1141
+ # tot_rnd_seeds = range(0, 121, 11)
1142
+
1143
+ tot_rnd_seeds = [0]
1144
 
1145
  # st_idxx = 0
1146
 
 
1242
 
1243
  #### obj_verts_trans, obj_faces ####
1244
  obj_verts_trans = np.matmul(obj_verts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
1245
+ obj_faces = None
1246
+ # print(f"obj_verts_trans: {obj_verts_trans.shape}, obj_faces: {obj_faces.shape}")
1247
 
1248
 
1249
 
 
1282
  )
1283
 
1284
 
1285
+ optimized_sv_infos.update({'predicted_info': data})
1286
+
1287
  optimized_sv_infos_sv_fn_nm = f"optimized_infos_sv_dict_seed_{seed}_tag_{test_tags[0]}_ntag_{len(test_tags)}.npy"
1288
 
1289
  optimized_sv_infos_sv_fn = os.path.join(new_pred_infos_sv_folder, optimized_sv_infos_sv_fn_nm)
1290
  np.save(optimized_sv_infos_sv_fn, optimized_sv_infos)
1291
  print(f"optimized infos saved to {optimized_sv_infos_sv_fn}")
1292
 
1293
+ return optimized_sv_infos_sv_fn
scripts/val_gradio/predict_taco_rndseed_spatial.sh CHANGED
@@ -94,7 +94,7 @@ export diff_hand_params=""
94
  export diff_basejtsrel=""
95
  export diff_realbasejtsrel="--diff_realbasejtsrel"
96
  # export model_path=./ckpts/model000519000.pt
97
- export model_path=/home/xueyi/sim/Generalizable-HOI-Denoising/ckpt/model001039000.pt
98
 
99
 
100
 
@@ -120,7 +120,7 @@ export use_reverse=""
120
  export seed=0
121
 
122
 
123
- export save_dir=./data/taco/result
124
 
125
 
126
  export cuda_ids=2
 
94
  export diff_basejtsrel=""
95
  export diff_realbasejtsrel="--diff_realbasejtsrel"
96
  # export model_path=./ckpts/model000519000.pt
97
+ export model_path=./model001039000.pt
98
 
99
 
100
 
 
120
  export seed=0
121
 
122
 
123
+ export save_dir=/tmp/denoising
124
 
125
 
126
  export cuda_ids=2
scripts/val_gradio/predict_taco_rndseed_spatial_temp.sh CHANGED
@@ -94,7 +94,7 @@ export diff_hand_params=""
94
  export diff_basejtsrel=""
95
  export diff_realbasejtsrel="--diff_realbasejtsrel"
96
  # export model_path=./ckpts/model000519000.pt
97
- export model_path=/home/xueyi/sim/Generalizable-HOI-Denoising/ckpt/model001039000.pt
98
 
99
 
100
 
@@ -120,7 +120,7 @@ export use_reverse=""
120
  export seed=0
121
 
122
 
123
- export save_dir=./data/taco/result
124
 
125
 
126
  export cuda_ids=2
 
94
  export diff_basejtsrel=""
95
  export diff_realbasejtsrel="--diff_realbasejtsrel"
96
  # export model_path=./ckpts/model000519000.pt
97
+ export model_path=./model001039000.pt
98
 
99
 
100
 
 
120
  export seed=0
121
 
122
 
123
+ export save_dir=/tmp/denoising
124
 
125
 
126
  export cuda_ids=2
test_predict_from_file.py CHANGED
@@ -12,6 +12,8 @@ import shutil
12
  # from gradio_inter.predict_from_file import predict_from_file
13
  from gradio_inter.create_bash_file import create_bash_file
14
 
 
 
15
  def create_temp_file(path: str) -> str:
16
  temp_dir = tempfile.gettempdir()
17
  temp_folder = os.path.join(temp_dir, "denoising")
@@ -36,11 +38,15 @@ def predict(file_path: str):
36
  temp_file_path = create_temp_file(file_path)
37
  # predict_from_file
38
  print(f"temp_path: {temp_file_path}")
 
39
  temp_bash_file = create_bash_file(temp_file_path)
40
  print(f"temp_bash_file: {temp_bash_file}")
41
- os.system(f"bash {temp_bash_file}")
42
 
43
-
 
 
 
 
44
  # demo = gr.Interface(
45
  # predict,
46
  # # gr.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
 
12
  # from gradio_inter.predict_from_file import predict_from_file
13
  from gradio_inter.create_bash_file import create_bash_file
14
 
15
+ from sample.reconstruct_data_taco import reconstruct_from_file
16
+
17
  def create_temp_file(path: str) -> str:
18
  temp_dir = tempfile.gettempdir()
19
  temp_folder = os.path.join(temp_dir, "denoising")
 
38
  temp_file_path = create_temp_file(file_path)
39
  # predict_from_file
40
  print(f"temp_path: {temp_file_path}")
41
+
42
  temp_bash_file = create_bash_file(temp_file_path)
43
  print(f"temp_bash_file: {temp_bash_file}")
 
44
 
45
+ # os.system(f"bash {temp_bash_file}")
46
+
47
+ saved_path = reconstruct_from_file(temp_file_path)
48
+ print(saved_path)
49
+
50
  # demo = gr.Interface(
51
  # predict,
52
  # # gr.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
train/__pycache__/training_loop_ours.cpython-38.pyc CHANGED
Binary files a/train/__pycache__/training_loop_ours.cpython-38.pyc and b/train/__pycache__/training_loop_ours.cpython-38.pyc differ
 
utils/__pycache__/dist_util.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/dist_util.cpython-38.pyc and b/utils/__pycache__/dist_util.cpython-38.pyc differ