vr-hmr / debug_unity_data.py
zirobtc's picture
Upload folder using huggingface_hub
7e120dd
#!/usr/bin/env python3
"""
Diagnostic script to check Unity dataset coordinate consistency.
Checks:
- Rotation consistency between `smpl_params_c`, `smpl_params_w`, and `T_w2c`
- `cam_angvel` matches the convention used in preprocessing
- SMPL forward-kinematics consistency between camera/world parameters
Run with the gvhmr env python:
/root/miniconda3/envs/gvhmr/bin/python debug_unity_data.py
"""
import torch
import numpy as np
from pathlib import Path
from scipy.spatial.transform import Rotation as R
def axis_angle_to_matrix(aa):
    """Return the rotation matrix (numpy) equivalent to the axis-angle vector ``aa``.

    ``aa`` is handed unchanged to SciPy's ``Rotation.from_rotvec``, so a
    single rotation vector or a stack of them is accepted.
    """
    rotation = R.from_rotvec(aa)
    return rotation.as_matrix()
def _check_rotation(smpl_c, smpl_w, R_w2c, idx):
    """Check that frame ``idx`` satisfies ``R_w = R_w2c^T @ R_c``; True if within 1 degree."""
    # Ground truth
    go_c_gt = smpl_c["global_orient"][idx].numpy()  # (3,) axis-angle
    go_w_gt = smpl_w["global_orient"][idx].numpy()  # (3,) axis-angle

    # Convert to matrices
    R_c_gt = axis_angle_to_matrix(go_c_gt)  # Pelvis in camera frame
    R_w_gt = axis_angle_to_matrix(go_w_gt)  # Pelvis in world frame

    # Camera-to-world is the transpose of the world-to-camera rotation.
    R_w_reconstructed = R_w2c.T @ R_c_gt

    # Residual rotation between reconstruction and ground truth; the norm of
    # its rotation vector is the angular error in radians.
    R_diff = R_w_reconstructed @ R_w_gt.T
    angle_err_deg = np.linalg.norm(R.from_matrix(R_diff).as_rotvec()) * 180.0 / np.pi

    # One verdict drives both the printed message and the return value.
    # Written as "<=" so a NaN error counts as a failure.
    rot_ok = bool(angle_err_deg <= 1.0)

    print(f"Ground truth global_orient_c (axis-angle): {go_c_gt}")
    print(f"Ground truth global_orient_w (axis-angle): {go_w_gt}")
    print("\nReconstruction test: R_w = R_c2w @ R_c")
    print(f" Rotation error: {angle_err_deg:.4f}°")
    if rot_ok:
        print(" ✅ OK: Rotations are consistent")
    else:
        print(" ❌ ERROR: Rotation mismatch > 1°!")
        print(f" R_w (ground truth):\n{R_w_gt}")
        print(f" R_w (reconstructed):\n{R_w_reconstructed}")
    return rot_ok


def _check_cam_angvel(data, T_w2c):
    """Compare the stored ``cam_angvel`` with one recomputed from ``T_w2c``; True if they agree."""
    if "cam_angvel" not in data:
        # Best-effort: a clip without the field is not treated as a failure,
        # but say so instead of passing silently.
        print(" ⚠️ cam_angvel not present in file; check skipped")
        return True

    cam_angvel = data["cam_angvel"]  # (L, 6) - 6D rotation
    print(f" cam_angvel shape: {cam_angvel.shape}")

    num_frames = int(T_w2c.shape[0])
    expected_shape = (num_frames, 6)
    if tuple(cam_angvel.shape) != expected_shape:
        # Report the mismatch instead of failing later with a broadcast error.
        print(f" ❌ WARNING: cam_angvel shape {tuple(cam_angvel.shape)} != expected {expected_shape}")
        return False
    print(f" cam_angvel[0]: {cam_angvel[0].numpy()}")

    try:
        from genmo.utils.rotation_conversions import matrix_to_rotation_6d
    except ImportError as e:
        # Same policy as the SMPL FK check: a check that cannot run is
        # reported and counted as not passed rather than aborting the script.
        print(f" ⚠️ Could not run cam_angvel check: {e}")
        return False

    # Manually compute cam_angvel and compare.
    # Convention (see `tools/demo/process_dataset.py:compute_velocity`):
    #   cam_angvel[0] = [1,0,0, 0,1,0] (identity, rotation6d)
    #   cam_angvel[i] = rot6d(R_i @ R_{i-1}^T)
    R_w2c_t = torch.from_numpy(T_w2c[:, :3, :3]).float()
    cam_angvel_manual = torch.zeros((num_frames, 6), dtype=torch.float32)
    cam_angvel_manual[0] = cam_angvel_manual.new_tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
    if num_frames > 1:
        R_diff_manual = R_w2c_t[1:] @ R_w2c_t[:-1].transpose(-1, -2)
        cam_angvel_manual[1:] = matrix_to_rotation_6d(R_diff_manual)

    max_diff = float((cam_angvel - cam_angvel_manual).abs().max())
    print(f" Manual vs stored cam_angvel max diff: {max_diff:.6f}")
    # "<=" so that a NaN difference is reported as a mismatch.
    if max_diff <= 1e-4:
        print(" ✅ OK: cam_angvel matches manual computation")
        return True
    print(" ❌ WARNING: cam_angvel mismatch!")
    return False


def _check_smpl_fk(smpl_c, smpl_w, T_w2c, idx):
    """Run SMPL on camera/world params of frame ``idx`` and compare joints; True if within 5 cm."""
    R_c2w = T_w2c[idx, :3, :3].T
    t_w2c = T_w2c[idx, :3, 3]
    frame = slice(idx, idx + 1)  # keep a leading batch dimension of one

    try:
        from third_party.GVHMR.hmr4d.utils.smplx_utils import make_smplx

        smplx_model = make_smplx("supermotion").eval()

        def first_22_joints(params):
            out = smplx_model(
                global_orient=params["global_orient"][frame],
                body_pose=params["body_pose"][frame],
                betas=params["betas"][frame],
                transl=params["transl"][frame],
            )
            return out.joints[0, :22].numpy()  # (22, 3)

        with torch.no_grad():
            joints_c = first_22_joints(smpl_c)  # Incam SMPL
            joints_w = first_22_joints(smpl_w)  # Global SMPL

        # Transform camera->world using T_w2c (world->camera):
        #   x_c = R_w2c x_w + t_w2c
        #   => x_w = R_w2c^T (x_c - t_w2c)
        joints_c2w = (R_c2w @ (joints_c - t_w2c).T).T

        joint_err = np.linalg.norm(joints_c2w - joints_w, axis=-1).mean()
        print(f" Mean joint error (incam→world vs world GT): {joint_err:.4f}m")
        # "<=" so that a NaN error is reported as a mismatch.
        if joint_err <= 0.05:
            print(" ✅ OK: SMPL joints are consistent")
            return True
        print(" ❌ ERROR: Joint mismatch > 5cm!")
        return False
    except Exception as e:
        # Deliberately broad: the body-model code is optional for this
        # diagnostic, and any failure there is reported as "check not passed".
        print(f" ⚠️ Could not run SMPL FK check: {e}")
        return False


def check_single_sequence(pt_path, idx=0):
    """Check a single .pt file for coordinate consistency.

    Three checks are run and their results printed:
      1. rotation consistency of ``global_orient`` between ``smpl_params_c``
         and ``smpl_params_w`` through ``T_w2c`` (frame ``idx``),
      2. stored ``cam_angvel`` against a value recomputed from ``T_w2c``
         (whole sequence; skipped when the key is absent),
      3. SMPL joint consistency between camera and world parameters
         (frame ``idx``).

    Args:
        pt_path: Path (``str`` or ``pathlib.Path``) of the feature file.
        idx: Frame index used by the per-frame checks. Defaults to 0.

    Returns:
        True when every check that ran passed. False otherwise, including
        when ``idx`` is out of range (e.g. an empty sequence) or a check
        could not be executed.
    """
    pt_path = Path(pt_path)
    print(f"\n{'='*80}")
    print(f"Checking: {pt_path.name}")
    print(f"{'='*80}")

    # NOTE(review): torch.load unpickles the file; only run this on feature
    # files you produced yourself or otherwise trust.
    data = torch.load(pt_path, map_location="cpu")

    # Extract key data
    smpl_c = data["smpl_params_c"]
    smpl_w = data["smpl_params_w"]
    T_w2c = data["T_w2c"].numpy()

    # Guard the per-frame checks against empty sequences and bad indices
    # (a negative idx would also produce an empty idx:idx+1 slice).
    num_frames = int(T_w2c.shape[0])
    if not 0 <= idx < num_frames:
        print(f" ❌ ERROR: Frame {idx} is out of range for a sequence of {num_frames} frames")
        return False

    print(f"\n[Frame {idx}]")
    rot_ok = _check_rotation(smpl_c, smpl_w, T_w2c[idx, :3, :3], idx)

    # Check cam_angvel computation (should match preprocess convention)
    print("\n[Camera Angular Velocity Check]")
    cam_ok = _check_cam_angvel(data, T_w2c)

    # Check SMPL forward kinematics consistency
    print("\n[SMPL FK Check]")
    fk_ok = _check_smpl_fk(smpl_c, smpl_w, T_w2c, idx)

    # Consider the clip consistent only if all checks are reasonable.
    return rot_ok and cam_ok and fk_ok
def main(dataset_root="./processed_dataset", num_check=3):
    """Run the consistency checks on the first sequences of a processed dataset.

    Feature files are read from ``<dataset_root>/genmo_features/*.pt`` in
    sorted order, and a summary is printed after the last checked sequence.

    Args:
        dataset_root: Directory containing the ``genmo_features`` folder.
        num_check: Maximum number of sequences to check; must be >= 1.

    Raises:
        ValueError: If ``num_check`` is smaller than 1 (checking nothing
            would otherwise be reported as "all checks passed").
    """
    if num_check < 1:
        raise ValueError("num_check must be a positive integer")

    feat_dir = Path(dataset_root) / "genmo_features"
    if not feat_dir.exists():
        print(f"Error: {feat_dir} not found!")
        return

    pt_files = sorted(feat_dir.glob("*.pt"))
    print(f"Found {len(pt_files)} sequences")
    if not pt_files:
        print("No .pt files found!")
        return

    # Build the full result list first (no short-circuiting) so that every
    # selected sequence is checked and reported even after a failure.
    results = [check_single_sequence(pt_file) for pt_file in pt_files[:num_check]]
    all_ok = all(results)

    print(f"\n{'='*80}")
    if all_ok:
        print("✅ All checks passed! Data appears consistent.")
        print("\nIf training still has high loss, the issue is likely:")
        print(" 1. Model architecture/hyperparameters")
        print(" 2. Normalization statistics mismatch")
        print(" 3. Sequence length handling during training")
    else:
        print("❌ Data consistency issues found!")
        print("\nThis explains the high training loss.")
        print("You need to fix the coordinate system in process_dataset.py")
    print(f"{'='*80}\n")
# Script entry point: run the dataset diagnostics only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()