Spaces:
Runtime error
Runtime error
Update sample/reconstruct_data_taco.py
Browse files- sample/reconstruct_data_taco.py +81 -76
sample/reconstruct_data_taco.py
CHANGED
@@ -51,15 +51,15 @@ def get_penetration_masks(obj_verts, obj_faces, hand_verts):
|
|
51 |
def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False, nn_hand_params=24, rt_vars=False, with_proj=False, obj_verts_trans=None, obj_faces=None, with_params_smoothing=False, dist_thres=0.005, with_ctx_mask=False):
|
52 |
# obj_verts_trans, obj_faces
|
53 |
joints = torch.from_numpy(joints).float() # # joints
|
54 |
-
base_pts = torch.from_numpy(base_pts).float() # # base_pts
|
55 |
|
56 |
if nn_hand_params < 45:
|
57 |
use_pca = True
|
58 |
else:
|
59 |
use_pca = False
|
60 |
|
61 |
-
tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float()
|
62 |
-
tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float()
|
63 |
### start optimization ###
|
64 |
# setup MANO layer
|
65 |
# mano_path = "/data1/xueyi/mano_models/mano/models"
|
@@ -139,48 +139,48 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
|
|
139 |
# )
|
140 |
|
141 |
#
|
142 |
-
dist_joints_to_base_pts = torch.sum(
|
143 |
-
|
144 |
-
)
|
145 |
|
146 |
-
nn_base_pts = dist_joints_to_base_pts.size(-1)
|
147 |
-
nn_joints = dist_joints_to_base_pts.size(1)
|
148 |
|
149 |
-
dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
|
150 |
-
minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
|
151 |
|
152 |
-
nk_contact_pts = 2
|
153 |
-
minn_dist[:, :-5] = 1e9
|
154 |
-
minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
|
155 |
-
# joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0) ==
|
156 |
-
minn_topk_mask = torch.zeros_like(minn_dist)
|
157 |
-
# minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
|
158 |
-
minn_topk_mask[:, -5: -3] = 1.
|
159 |
-
basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0)
|
160 |
-
minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
|
161 |
-
# for seq 101
|
162 |
-
# minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
|
163 |
-
minn_dist_mask = minn_dist_mask.float()
|
164 |
|
165 |
-
## tot base pts
|
166 |
-
tot_base_pts_trans_disp = torch.sum(
|
167 |
-
|
168 |
-
)
|
169 |
-
### tot base pts trans disp ###
|
170 |
-
tot_base_pts_trans_disp = torch.sqrt(tot_base_pts_trans_disp).mean(dim=-1) # (nf - 1)
|
171 |
-
# tot_base_pts_trans_disp_mov_thres = 1e-20
|
172 |
-
tot_base_pts_trans_disp_mov_thres = 3e-4
|
173 |
-
tot_base_pts_trans_disp_mask = tot_base_pts_trans_disp >= tot_base_pts_trans_disp_mov_thres
|
174 |
-
tot_base_pts_trans_disp_mask = torch.cat(
|
175 |
-
|
176 |
-
)
|
177 |
|
178 |
-
attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
|
179 |
|
180 |
|
181 |
|
182 |
-
minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
|
183 |
-
print(f"minn_dist_mask: {minn_dist_mask.size()}")
|
184 |
s = 1.0
|
185 |
# affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
|
186 |
|
@@ -231,21 +231,21 @@ def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans
|
|
231 |
print('\tRotation Smoothness Loss: {}'.format(joints_pred_loss.item()))
|
232 |
|
233 |
#
|
234 |
-
print(tot_base_pts_trans.size())
|
235 |
-
diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
|
236 |
-
print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
|
237 |
-
diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
|
238 |
-
diff_base_pts_trans_threshold = 1e-20
|
239 |
-
diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the tranformed base pts
|
240 |
-
diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
|
241 |
-
print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
|
242 |
-
diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
|
243 |
-
diff_base_pts_trans_mask = torch.cat(
|
244 |
-
|
245 |
-
)
|
246 |
# attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5
|
247 |
-
attraction_mask = minn_topk_mask.float()
|
248 |
-
attraction_mask = attraction_mask.float()
|
249 |
|
250 |
# the direction of the normal vector and the moving direction of the object point -> whether the point should be selected
|
251 |
# the contact maps of the object should be like? #
|
@@ -1184,7 +1184,12 @@ def reconstruct_from_file(single_seq_path):
|
|
1184 |
tot_data[cur_k].append(cur_data[cur_k][ :clip_ending_idxes[i_tag]])
|
1185 |
|
1186 |
for cur_k in tot_data:
|
1187 |
-
|
|
|
|
|
|
|
|
|
|
|
1188 |
tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=1)
|
1189 |
else:
|
1190 |
tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=0)
|
@@ -1202,8 +1207,8 @@ def reconstruct_from_file(single_seq_path):
|
|
1202 |
|
1203 |
targets = data['targets'] # # targets # #
|
1204 |
outputs = data['outputs'] #
|
1205 |
-
tot_base_pts = data["tot_base_pts"][0] # total base pts, total base normals #
|
1206 |
-
tot_base_normals = data['tot_base_normals'][0] # nn_base_normals #
|
1207 |
|
1208 |
|
1209 |
|
@@ -1213,22 +1218,22 @@ def reconstruct_from_file(single_seq_path):
|
|
1213 |
tot_obj_transl = data['tot_obj_transl'][0]
|
1214 |
print(f"tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
|
1215 |
|
1216 |
-
if len(tot_base_pts.shape) == 2:
|
1217 |
-
|
1218 |
-
|
1219 |
-
|
1220 |
|
1221 |
-
|
1222 |
-
|
1223 |
-
|
1224 |
-
else:
|
1225 |
-
|
1226 |
-
|
1227 |
-
|
1228 |
|
1229 |
-
|
1230 |
-
|
1231 |
-
|
1232 |
|
1233 |
|
1234 |
|
@@ -1237,7 +1242,7 @@ def reconstruct_from_file(single_seq_path):
|
|
1237 |
|
1238 |
targets = np.matmul(targets, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 #
|
1239 |
# denoise relative positions
|
1240 |
-
print(f"tot_base_pts: {tot_base_pts.shape}")
|
1241 |
|
1242 |
|
1243 |
#### obj_verts_trans, obj_faces ####
|
@@ -1264,7 +1269,7 @@ def reconstruct_from_file(single_seq_path):
|
|
1264 |
with_contact_opt = True
|
1265 |
with_ctx_mask = False
|
1266 |
|
1267 |
-
bf_ct_optimized_dict, bf_proj_optimized_dict, optimized_dict = get_optimized_hand_fr_joints_v4_anchors(outputs,
|
1268 |
|
1269 |
|
1270 |
|
@@ -1274,12 +1279,12 @@ def reconstruct_from_file(single_seq_path):
|
|
1274 |
optimized_sv_infos.update(bf_ct_optimized_dict)
|
1275 |
optimized_sv_infos.update(bf_proj_optimized_dict)
|
1276 |
optimized_sv_infos.update(optimized_dict)
|
1277 |
-
optimized_sv_infos.update(
|
1278 |
-
|
1279 |
-
|
1280 |
-
|
1281 |
-
|
1282 |
-
)
|
1283 |
|
1284 |
|
1285 |
optimized_sv_infos.update({'predicted_info': data})
|
|
|
51 |
def get_optimized_hand_fr_joints_v4_anchors(joints, base_pts, tot_base_pts_trans, tot_base_normals_trans, with_contact_opt=False, nn_hand_params=24, rt_vars=False, with_proj=False, obj_verts_trans=None, obj_faces=None, with_params_smoothing=False, dist_thres=0.005, with_ctx_mask=False):
|
52 |
# obj_verts_trans, obj_faces
|
53 |
joints = torch.from_numpy(joints).float() # # joints
|
54 |
+
# base_pts = torch.from_numpy(base_pts).float() # # base_pts
|
55 |
|
56 |
if nn_hand_params < 45:
|
57 |
use_pca = True
|
58 |
else:
|
59 |
use_pca = False
|
60 |
|
61 |
+
# tot_base_pts_trans = torch.from_numpy(tot_base_pts_trans).float()
|
62 |
+
# tot_base_normals_trans = torch.from_numpy(tot_base_normals_trans).float()
|
63 |
### start optimization ###
|
64 |
# setup MANO layer
|
65 |
# mano_path = "/data1/xueyi/mano_models/mano/models"
|
|
|
139 |
# )
|
140 |
|
141 |
#
|
142 |
+
# dist_joints_to_base_pts = torch.sum(
|
143 |
+
# (joints.unsqueeze(-2) - tot_base_pts_trans.unsqueeze(1)) ** 2, dim=-1 # nf x nnjoints x nnbasepts #
|
144 |
+
# )
|
145 |
|
146 |
+
# nn_base_pts = dist_joints_to_base_pts.size(-1)
|
147 |
+
# nn_joints = dist_joints_to_base_pts.size(1)
|
148 |
|
149 |
+
# dist_joints_to_base_pts = torch.sqrt(dist_joints_to_base_pts) # nf x nnjoints x nnbasepts #
|
150 |
+
# minn_dist, minn_dist_idx = torch.min(dist_joints_to_base_pts, dim=-1) # nf x nnjoints #
|
151 |
|
152 |
+
# nk_contact_pts = 2
|
153 |
+
# minn_dist[:, :-5] = 1e9
|
154 |
+
# minn_topk_dist, minn_topk_idx = torch.topk(minn_dist, k=nk_contact_pts, largest=False) #
|
155 |
+
# # joints_idx_rng_exp = torch.arange(nn_joints).unsqueeze(0) ==
|
156 |
+
# minn_topk_mask = torch.zeros_like(minn_dist)
|
157 |
+
# # minn_topk_mask[minn_topk_idx] = 1. # nf x nnjoints #
|
158 |
+
# minn_topk_mask[:, -5: -3] = 1.
|
159 |
+
# basepts_idx_range = torch.arange(nn_base_pts).unsqueeze(0).unsqueeze(0)
|
160 |
+
# minn_dist_mask = basepts_idx_range == minn_dist_idx.unsqueeze(-1) # nf x nnjoints x nnbasepts
|
161 |
+
# # for seq 101
|
162 |
+
# # minn_dist_mask[31:, -5, :] = minn_dist_mask[30: 31, -5, :]
|
163 |
+
# minn_dist_mask = minn_dist_mask.float()
|
164 |
|
165 |
+
# ## tot base pts
|
166 |
+
# tot_base_pts_trans_disp = torch.sum(
|
167 |
+
# (tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1 # (nf - 1) x nn_base_pts displacement
|
168 |
+
# )
|
169 |
+
# ### tot base pts trans disp ###
|
170 |
+
# tot_base_pts_trans_disp = torch.sqrt(tot_base_pts_trans_disp).mean(dim=-1) # (nf - 1)
|
171 |
+
# # tot_base_pts_trans_disp_mov_thres = 1e-20
|
172 |
+
# tot_base_pts_trans_disp_mov_thres = 3e-4
|
173 |
+
# tot_base_pts_trans_disp_mask = tot_base_pts_trans_disp >= tot_base_pts_trans_disp_mov_thres
|
174 |
+
# tot_base_pts_trans_disp_mask = torch.cat(
|
175 |
+
# [tot_base_pts_trans_disp_mask, tot_base_pts_trans_disp_mask[-1:]], dim=0
|
176 |
+
# )
|
177 |
|
178 |
+
# attraction_mask_new = (tot_base_pts_trans_disp_mask.float().unsqueeze(-1).unsqueeze(-1) + minn_dist_mask.float()) > 1.5
|
179 |
|
180 |
|
181 |
|
182 |
+
# minn_topk_mask = (minn_dist_mask + minn_topk_mask.float().unsqueeze(-1)) > 1.5
|
183 |
+
# print(f"minn_dist_mask: {minn_dist_mask.size()}")
|
184 |
s = 1.0
|
185 |
# affinity_scores = get_affinity_fr_dist(dist_joints_to_base_pts, s=s)
|
186 |
|
|
|
231 |
print('\tRotation Smoothness Loss: {}'.format(joints_pred_loss.item()))
|
232 |
|
233 |
#
|
234 |
+
# print(tot_base_pts_trans.size())
|
235 |
+
# diff_base_pts_trans = torch.sum((tot_base_pts_trans[1:, :, :] - tot_base_pts_trans[:-1, :, :]) ** 2, dim=-1) # (nf - 1) x nn_base_pts
|
236 |
+
# print(f"diff_base_pts_trans: {diff_base_pts_trans.size()}")
|
237 |
+
# diff_base_pts_trans = diff_base_pts_trans.mean(dim=-1)
|
238 |
+
# diff_base_pts_trans_threshold = 1e-20
|
239 |
+
# diff_base_pts_trans_mask = diff_base_pts_trans > diff_base_pts_trans_threshold # (nf - 1) ### the mask of the tranformed base pts
|
240 |
+
# diff_base_pts_trans_mask = diff_base_pts_trans_mask.float()
|
241 |
+
# print(f"diff_base_pts_trans_mask: {diff_base_pts_trans_mask.size()}, diff_base_pts_trans: {diff_base_pts_trans.size()}")
|
242 |
+
# diff_last_frame_mask = torch.tensor([0,], dtype=torch.float32).to(diff_base_pts_trans_mask.device) + diff_base_pts_trans_mask[-1]
|
243 |
+
# diff_base_pts_trans_mask = torch.cat(
|
244 |
+
# [diff_base_pts_trans_mask, diff_last_frame_mask], dim=0 # nf tensor
|
245 |
+
# )
|
246 |
# attraction_mask = (diff_base_pts_trans_mask.unsqueeze(-1).unsqueeze(-1) + minn_topk_mask.float()) > 1.5
|
247 |
+
# attraction_mask = minn_topk_mask.float()
|
248 |
+
# attraction_mask = attraction_mask.float()
|
249 |
|
250 |
# the direction of the normal vector and the moving direction of the object point -> whether the point should be selected
|
251 |
# the contact maps of the object should be like? #
|
|
|
1184 |
tot_data[cur_k].append(cur_data[cur_k][ :clip_ending_idxes[i_tag]])
|
1185 |
|
1186 |
for cur_k in tot_data:
|
1187 |
+
print(f"cur_k: {cur_k}")
|
1188 |
+
for aa in tot_data[cur_k]:
|
1189 |
+
print(aa.shape)
|
1190 |
+
if cur_k in ['tot_base_pts', 'tot_base_normals']:
|
1191 |
+
continue
|
1192 |
+
elif cur_k in ["tot_base_pts", "tot_base_normals", "tot_obj_rot", "tot_obj_transl", "tot_obj_pcs", "tot_rhand_joints", "tot_gt_rhand_joints"]:
|
1193 |
tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=1)
|
1194 |
else:
|
1195 |
tot_data[cur_k] = np.concatenate(tot_data[cur_k], axis=0)
|
|
|
1207 |
|
1208 |
targets = data['targets'] # # targets # #
|
1209 |
outputs = data['outputs'] #
|
1210 |
+
# tot_base_pts = data["tot_base_pts"][0] # total base pts, total base normals #
|
1211 |
+
# tot_base_normals = data['tot_base_normals'][0] # nn_base_normals #
|
1212 |
|
1213 |
|
1214 |
|
|
|
1218 |
tot_obj_transl = data['tot_obj_transl'][0]
|
1219 |
print(f"tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
|
1220 |
|
1221 |
+
# if len(tot_base_pts.shape) == 2:
|
1222 |
+
# # numpy array # # tot base pts #
|
1223 |
+
# tot_base_pts_trans = np.matmul(tot_base_pts.reshape(1, tot_base_pts.shape[0], 3), tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
|
1224 |
+
# tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot[0]) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])[0]
|
1225 |
|
1226 |
+
# tot_base_normals_trans = np.matmul( # #
|
1227 |
+
# tot_base_normals.reshape(1, tot_base_normals.shape[0], 3), tot_obj_rot
|
1228 |
+
# )
|
1229 |
+
# else:
|
1230 |
+
# print(f"tot_base_pts: {tot_base_pts.shape}, tot_obj_rot: {tot_obj_rot.shape}, tot_obj_transl: {tot_obj_transl.shape}")
|
1231 |
+
# tot_base_pts_trans = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
|
1232 |
+
# tot_base_pts = np.matmul(tot_base_pts, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1])
|
1233 |
|
1234 |
+
# tot_base_normals_trans = np.matmul(
|
1235 |
+
# tot_base_normals, tot_obj_rot
|
1236 |
+
# )
|
1237 |
|
1238 |
|
1239 |
|
|
|
1242 |
|
1243 |
targets = np.matmul(targets, tot_obj_rot) + tot_obj_transl.reshape(tot_obj_transl.shape[0], 1, tot_obj_transl.shape[1]) # ws x nn_verts x 3 #
|
1244 |
# denoise relative positions
|
1245 |
+
# print(f"tot_base_pts: {tot_base_pts.shape}")
|
1246 |
|
1247 |
|
1248 |
#### obj_verts_trans, obj_faces ####
|
|
|
1269 |
with_contact_opt = True
|
1270 |
with_ctx_mask = False
|
1271 |
|
1272 |
+
bf_ct_optimized_dict, bf_proj_optimized_dict, optimized_dict = get_optimized_hand_fr_joints_v4_anchors(outputs, None, None, None, with_contact_opt=with_contact_opt, nn_hand_params=nn_hand_params, rt_vars=True, with_proj=with_proj, obj_verts_trans=obj_verts_trans, obj_faces=obj_faces, with_params_smoothing=with_params_smoothing, dist_thres=dist_thres, with_ctx_mask=with_ctx_mask)
|
1273 |
|
1274 |
|
1275 |
|
|
|
1279 |
optimized_sv_infos.update(bf_ct_optimized_dict)
|
1280 |
optimized_sv_infos.update(bf_proj_optimized_dict)
|
1281 |
optimized_sv_infos.update(optimized_dict)
|
1282 |
+
# optimized_sv_infos.update(
|
1283 |
+
# {
|
1284 |
+
# 'tot_base_pts_trans': tot_base_pts_trans,
|
1285 |
+
# 'tot_base_normals_trans': tot_base_normals_trans
|
1286 |
+
# }
|
1287 |
+
# )
|
1288 |
|
1289 |
|
1290 |
optimized_sv_infos.update({'predicted_info': data})
|