Add phantom project with submodules and dependencies
Binary files are tracked with Git LFS.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This view is limited to 50 files because it contains too many changes.
- .gitattributes +13 -0
- app.py +398 -80
- phantom +0 -1
- phantom/.gitignore +11 -0
- phantom/.gitmodules +15 -0
- phantom/LICENSE +21 -0
- phantom/README.md +168 -0
- phantom/configs/default.yaml +30 -0
- phantom/configs/epic.yaml +31 -0
- phantom/configs/sam2_hiera_l.yaml +117 -0
- phantom/data/__init__.py +0 -0
- phantom/docs/teaser_masquerade.png +3 -0
- phantom/docs/teaser_phantom.png +3 -0
- phantom/install.sh +67 -0
- phantom/phantom/__init__.py +0 -0
- phantom/phantom/camera/__init__.py +0 -0
- phantom/phantom/camera/camera_extrinsics.json +42 -0
- phantom/phantom/camera/camera_extrinsics_ego_bimanual_shoulders.json +52 -0
- phantom/phantom/camera/camera_intrinsics_HD1080.json +48 -0
- phantom/phantom/camera/camera_intrinsics_epic.json +48 -0
- phantom/phantom/detectors/detector_detectron2.py +121 -0
- phantom/phantom/detectors/detector_dino.py +108 -0
- phantom/phantom/detectors/detector_hamer.py +447 -0
- phantom/phantom/detectors/detector_sam2.py +240 -0
- phantom/phantom/hand.py +805 -0
- phantom/phantom/process_data.py +243 -0
- phantom/phantom/processors/__init__.py +0 -0
- phantom/phantom/processors/action_processor.py +478 -0
- phantom/phantom/processors/base_processor.py +209 -0
- phantom/phantom/processors/bbox_processor.py +851 -0
- phantom/phantom/processors/hand_processor.py +675 -0
- phantom/phantom/processors/handinpaint_processor.py +485 -0
- phantom/phantom/processors/paths.py +219 -0
- phantom/phantom/processors/phantom_data.py +340 -0
- phantom/phantom/processors/robotinpaint_processor.py +785 -0
- phantom/phantom/processors/segmentation_processor.py +1056 -0
- phantom/phantom/processors/smoothing_processor.py +303 -0
- phantom/phantom/twin_bimanual_robot.py +597 -0
- phantom/phantom/twin_robot.py +490 -0
- phantom/phantom/utils/__init__.py +0 -0
- phantom/phantom/utils/bbox_utils.py +38 -0
- phantom/phantom/utils/data_utils.py +38 -0
- phantom/phantom/utils/image_utils.py +103 -0
- phantom/phantom/utils/pcd_utils.py +210 -0
- phantom/phantom/utils/transform_utils.py +43 -0
- phantom/setup.py +7 -0
- phantom/submodules/phantom-E2FGVI/.gitignore +136 -0
- phantom/submodules/phantom-E2FGVI/E2FGVI/__init__.py +0 -0
- phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi.json +41 -0
- phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi_hq.json +41 -0
.gitattributes
CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+phantom/submodules/phantom-robosuite/robosuite/models/assets/textures/plywood-4k.jpg filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.hdf5 filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.obj filter=lfs diff=lfs merge=lfs -text
+*.mtl filter=lfs diff=lfs merge=lfs -text
+*.stl filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.dae filter=lfs diff=lfs merge=lfs -text
+*.hdr filter=lfs diff=lfs merge=lfs -text
+*.msh filter=lfs diff=lfs merge=lfs -text

app.py
CHANGED
@@ -1,109 +1,427 @@
"""
Phantom Video Processor - Hugging Face Space Demo
Converts human hand videos into robot demonstration data.
"""

import gradio as gr
import spaces
import subprocess
import sys
import os
import shutil
import tempfile
from pathlib import Path

# ========== Path configuration ==========
PHANTOM_DIR = Path("/home/user/app/phantom")
DATA_RAW_DIR = PHANTOM_DIR / "data" / "raw"
DATA_PROCESSED_DIR = PHANTOM_DIR / "data" / "processed"
MANO_DIR = PHANTOM_DIR / "submodules" / "phantom-hamer" / "_DATA" / "data" / "mano"

# Add Phantom to the Python path
if PHANTOM_DIR.exists():
    sys.path.insert(0, str(PHANTOM_DIR))
    sys.path.insert(0, str(PHANTOM_DIR / "phantom"))

# ========== Environment checks ==========
def check_environment():
    """Check the current environment status."""
    status = {
        "phantom_installed": Path("/tmp/.phantom_ready").exists(),
        "mano_ready": (MANO_DIR / "MANO_LEFT.pkl").exists() and (MANO_DIR / "MANO_RIGHT.pkl").exists(),
        "sample_data": (DATA_RAW_DIR / "pick_and_place").exists(),
        "cuda_available": False,
        "gpu_name": None
    }

    try:
        import torch
        status["cuda_available"] = torch.cuda.is_available()
        if status["cuda_available"]:
            status["gpu_name"] = torch.cuda.get_device_name(0)
    except:
        pass

    return status

def get_status_text():
    """Build a human-readable status report."""
    status = check_environment()
    lines = []
    lines.append("=" * 40)
    lines.append("Environment status")
    lines.append("=" * 40)
    lines.append(f"Phantom installed: {'✅' if status['phantom_installed'] else '❌ initialization required on first run'}")
    lines.append(f"MANO models: {'✅' if status['mano_ready'] else '❌ please upload the MANO model files'}")
    lines.append(f"Sample data: {'✅' if status['sample_data'] else '⏳ will be downloaded automatically'}")
    lines.append(f"CUDA: {'✅ ' + (status['gpu_name'] or '') if status['cuda_available'] else '⏳ a GPU is allocated at processing time'}")
    lines.append("=" * 40)
    return "\n".join(lines)

# ========== MANO model upload ==========
def upload_mano_files(left_file, right_file):
    """Save the uploaded MANO model files."""
    MANO_DIR.mkdir(parents=True, exist_ok=True)

    messages = []

    if left_file is not None:
        dest = MANO_DIR / "MANO_LEFT.pkl"
        shutil.copy(left_file.name, dest)
        messages.append("✅ MANO_LEFT.pkl saved")

    if right_file is not None:
        dest = MANO_DIR / "MANO_RIGHT.pkl"
        shutil.copy(right_file.name, dest)
        messages.append("✅ MANO_RIGHT.pkl saved")

    if not messages:
        return "⚠️ Please select files to upload"

    return "\n".join(messages) + "\n\n" + get_status_text()

# ========== Environment initialization ==========
def initialize_environment(progress=gr.Progress()):
    """Initialize the Phantom environment."""
    if Path("/tmp/.phantom_ready").exists():
        return "✅ Environment is ready\n\n" + get_status_text()

    progress(0, desc="Starting initialization...")

    setup_script = Path("/home/user/app/setup.sh")
    if not setup_script.exists():
        return "❌ setup.sh not found"

    try:
        # Run setup.sh
        progress(0.1, desc="Running the install script...")
        process = subprocess.Popen(
            ["bash", str(setup_script)],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1
        )

        output_lines = []
        for line in iter(process.stdout.readline, ''):
            output_lines.append(line.strip())
            if len(output_lines) > 50:
                output_lines = output_lines[-50:]  # keep only the last 50 lines

        process.wait()

        if process.returncode == 0:
            progress(1.0, desc="Done!")
            return "✅ Initialization complete!\n\n" + "\n".join(output_lines[-20:]) + "\n\n" + get_status_text()
        else:
            return f"❌ Initialization failed (return code: {process.returncode})\n\n" + "\n".join(output_lines[-30:])

    except Exception as e:
        return f"❌ Initialization error: {str(e)}"

# ========== Video processing ==========
@spaces.GPU(duration=300)
def process_video(
    video_file,
    robot_type,
    target_hand,
    processing_mode,
    use_sample_data,
    progress=gr.Progress()
):
    """
    Process a video: replace the human hand with a robot.
    """
    import torch

    # Status messages
    status_lines = []

    # GPU check
    if torch.cuda.is_available():
        gpu = torch.cuda.get_device_name(0)
        status_lines.append(f"✅ GPU: {gpu}")
        status_lines.append(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        status_lines.append("❌ GPU not available")
        return None, None, "\n".join(status_lines)

    # Check that the environment is initialized
    if not Path("/tmp/.phantom_ready").exists():
        status_lines.append("❌ Please click the 'Initialize environment' button first")
        return None, None, "\n".join(status_lines)

    # Check the MANO models
    if not (MANO_DIR / "MANO_LEFT.pkl").exists():
        status_lines.append("❌ Please upload the MANO model files first")
        return None, None, "\n".join(status_lines)

    progress(0.1, desc="Preparing...")

    # Decide which input data to use
    if use_sample_data:
        demo_name = "pick_and_place"
        data_root = str(DATA_RAW_DIR)
        status_lines.append(f"📂 Using sample data: {demo_name}")
    else:
        if video_file is None:
            status_lines.append("❌ Please upload a video or select the sample data")
            return None, None, "\n".join(status_lines)

        # Create a directory for the uploaded video
        demo_name = "user_upload"
        user_data_dir = DATA_RAW_DIR / demo_name / "0"
        user_data_dir.mkdir(parents=True, exist_ok=True)

        # Copy the video to the expected location
        video_dest = user_data_dir / "video.mkv"
        shutil.copy(video_file, video_dest)
        data_root = str(DATA_RAW_DIR)
        status_lines.append(f"📂 Processing uploaded video: {video_file}")

    status_lines.append(f"🤖 Robot type: {robot_type}")
    status_lines.append(f"✋ Target hand: {target_hand}")
    status_lines.append(f"⚙️ Processing mode: {processing_mode}")
    status_lines.append("-" * 40)

    progress(0.2, desc="Starting processing...")

    # Build the processing command
    cmd = [
        sys.executable,
        str(PHANTOM_DIR / "phantom" / "process_data.py"),
        f"demo_name={demo_name}",
        f"data_root_dir={data_root}",
        f"processed_data_root_dir={str(DATA_PROCESSED_DIR)}",
        f"mode={processing_mode}",
        f"robot={robot_type}",
        f"target_hand={target_hand}",
        "bimanual_setup=single_arm",
        "demo_num=0",  # only process the first demo
    ]

    status_lines.append(f"Command: {' '.join(cmd)}")

    try:
        # Run the processing pipeline
        progress(0.3, desc="Processing...")

        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            cwd=str(PHANTOM_DIR / "phantom"),
            env={**os.environ, "PYTHONPATH": str(PHANTOM_DIR)}
        )

        output_lines = []
        for line in iter(process.stdout.readline, ''):
            line = line.strip()
            if line:
                output_lines.append(line)
                # Update progress based on pipeline stage markers in the log
                if "BBOX" in line:
                    progress(0.4, desc="Detecting bounding boxes...")
                elif "HAND2D" in line:
                    progress(0.5, desc="Extracting 2D hand poses...")
                elif "SEGMENTATION" in line:
                    progress(0.6, desc="Segmenting the arm...")
                elif "ACTION" in line:
                    progress(0.7, desc="Extracting actions...")
                elif "INPAINT" in line:
                    progress(0.8, desc="Inpainting the video...")
                elif "ROBOT" in line:
                    progress(0.9, desc="Overlaying the robot...")

        process.wait()

        progress(1.0, desc="Done!")

        # Append the processing log
        status_lines.append("-" * 40)
        status_lines.append("Processing log (last 20 lines):")
        status_lines.extend(output_lines[-20:])

        # Locate the output files
        output_video = None
        output_data = None

        processed_dir = DATA_PROCESSED_DIR / demo_name / "0"

        # Find the generated video
        video_pattern = f"video_overlay_{robot_type}_single_arm.mkv"
        for f in processed_dir.glob("**/*.mkv"):
            if robot_type.lower() in f.name.lower():
                output_video = str(f)
                break

        # Find the training data
        for f in processed_dir.glob("**/training_data*.npz"):
            output_data = str(f)
            break

        if output_video:
            status_lines.append(f"\n✅ Output video: {output_video}")
        if output_data:
            status_lines.append(f"✅ Training data: {output_data}")

        if process.returncode == 0:
            status_lines.insert(0, "✅ Processing complete!")
        else:
            status_lines.insert(0, f"⚠️ Processing finished with warnings (return code: {process.returncode})")

        return output_video, output_data, "\n".join(status_lines)

    except Exception as e:
        import traceback
        status_lines.append(f"\n❌ Processing error: {str(e)}")
        status_lines.append(traceback.format_exc())
        return None, None, "\n".join(status_lines)

# ========== Gradio UI ==========
with gr.Blocks(
    title="Phantom - Robot Video Generator",
    theme=gr.themes.Soft()
) as demo:

    gr.Markdown("""
    # 🤖 Phantom - Convert Human Videos into Robot Demonstrations

    **Paper**: [Phantom: Training Robots Without Robots Using Only Human Videos](https://phantom-human-videos.github.io/)

    Automatically converts human hand-manipulation videos into robot demonstration data for training robot policies.
    """)

    with gr.Tabs():
        # ========== Environment setup tab ==========
        with gr.TabItem("1️⃣ Environment Setup"):
            gr.Markdown("""
            ### First-time setup:

            1. **Initialize the environment** - installs dependencies and downloads models (about 5-10 minutes on first run)
            2. **Upload the MANO models** - requires registering on the official website
            """)

            with gr.Row():
                with gr.Column():
                    init_btn = gr.Button("🔧 Initialize environment", variant="primary", size="lg")
                    init_output = gr.Textbox(
                        label="Initialization status",
                        lines=15,
                        value=get_status_text()
                    )

                with gr.Column():
                    gr.Markdown("""
                    ### Downloading the MANO models

                    1. Visit the [MANO website](https://mano.is.tue.mpg.de/)
                    2. Register an account and download the models
                    3. Upload `MANO_LEFT.pkl` and `MANO_RIGHT.pkl`
                    """)

                    mano_left = gr.File(label="MANO_LEFT.pkl", file_types=[".pkl"])
                    mano_right = gr.File(label="MANO_RIGHT.pkl", file_types=[".pkl"])
                    upload_btn = gr.Button("📤 Upload MANO models")
                    upload_output = gr.Textbox(label="Upload status", lines=5)

            init_btn.click(fn=initialize_environment, outputs=init_output)
            upload_btn.click(fn=upload_mano_files, inputs=[mano_left, mano_right], outputs=upload_output)

        # ========== Video processing tab ==========
        with gr.TabItem("2️⃣ Video Processing"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Input settings")

                    use_sample = gr.Checkbox(
                        label="Use sample data (pick_and_place)",
                        value=True,
                        info="Recommended for first use; processes the bundled sample video"
                    )

                    video_input = gr.Video(
                        label="Or upload your own video",
                        interactive=True
                    )

                    robot_type = gr.Dropdown(
                        choices=["Panda", "Kinova3", "UR5e", "IIWA", "Jaco"],
                        value="Panda",
                        label="Robot type"
                    )

                    target_hand = gr.Radio(
                        choices=["left", "right"],
                        value="left",
                        label="Target hand"
                    )

                    processing_mode = gr.Dropdown(
                        choices=[
                            "bbox",
                            "hand2d",
                            "arm_segmentation",
                            "hand_inpaint",
                            "robot_inpaint",
                            "all"
                        ],
                        value="bbox",
                        label="Processing mode",
                        info="Recommended to run step by step: bbox -> hand2d -> arm_segmentation -> hand_inpaint -> robot_inpaint"
                    )

                    process_btn = gr.Button("🚀 Start processing", variant="primary", size="lg")

                with gr.Column():
                    gr.Markdown("### Output")

                    video_output = gr.Video(label="Generated robot video")
                    data_output = gr.File(label="Training data (NPZ)")
                    status_output = gr.Textbox(label="Processing status", lines=20)

            process_btn.click(
                fn=process_video,
                inputs=[video_input, robot_type, target_hand, processing_mode, use_sample],
                outputs=[video_output, data_output, status_output]
            )

        # ========== Help tab ==========
        with gr.TabItem("📖 Help"):
            gr.Markdown("""
            ## Processing pipeline

            Phantom converts human hand videos into robot demonstration data in the following steps:

            | Step | Mode | Description |
            |------|------|-------------|
            | 1 | `bbox` | Detect hand bounding boxes |
            | 2 | `hand2d` | Extract 2D hand poses |
            | 3 | `arm_segmentation` | Segment the human arm |
            | 4 | `hand_inpaint` | Remove the arm and inpaint the background |
            | 5 | `robot_inpaint` | Overlay a virtual robot |

            ## Input requirements

            - **Video format**: common formats such as MKV and MP4
            - **Resolution**: 1080p recommended
            - **Content**: single-hand manipulation with the hand clearly visible

            ## ZeroGPU limits

            - Per-run time limit: 300 seconds
            - Run the processing modes one step at a time
            - Complex videos may require several runs

            ## References

            - [Phantom paper](https://arxiv.org/abs/2503.00779)
            - [GitHub repository](https://github.com/MarionLepert/phantom)
            - [MANO hand model](https://mano.is.tue.mpg.de/)
            """)

# Launch
if __name__ == "__main__":
    demo.queue().launch()

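For reference, the `process_video` handler above ultimately shells out to `phantom/process_data.py` with Hydra-style overrides. A standalone sketch of the same invocation outside Gradio is shown below; the paths and the `pick_and_place` demo name follow the defaults used in this Space, and the snippet is illustrative rather than part of the commit.

```python
# Minimal sketch: run the same pipeline stage that process_video() launches,
# but from a plain script. Paths assume the repo layout used in this Space.
import os
import subprocess
import sys
from pathlib import Path

PHANTOM_DIR = Path("/home/user/app/phantom")  # adjust to your checkout

cmd = [
    sys.executable,
    str(PHANTOM_DIR / "phantom" / "process_data.py"),
    "demo_name=pick_and_place",
    f"data_root_dir={PHANTOM_DIR / 'data' / 'raw'}",
    f"processed_data_root_dir={PHANTOM_DIR / 'data' / 'processed'}",
    "mode=bbox",              # run one stage at a time: bbox, hand2d, ...
    "robot=Panda",
    "target_hand=left",
    "bimanual_setup=single_arm",
    "demo_num=0",
]
subprocess.run(
    cmd,
    check=True,
    cwd=PHANTOM_DIR / "phantom",
    env={**os.environ, "PYTHONPATH": str(PHANTOM_DIR)},
)
```
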
phantom
DELETED
@@ -1 +0,0 @@
-Subproject commit a8bb81c1bbe6ade129a1f6f0906482f510354a5e

phantom/.gitignore
ADDED
@@ -0,0 +1,11 @@
*.egg-info
**/_DATA/*
data/raw/*
!data/raw/.gitkeep
data/processed/*
!data/processed/.gitkeep
**/__pycache__/*
*.pyc
*.pth
outputs/*
phantom/outputs/*

phantom/.gitmodules
ADDED
@@ -0,0 +1,15 @@
[submodule "submodules/phantom-E2FGVI"]
    path = submodules/phantom-E2FGVI
    url = git@github.com:MarionLepert/phantom-E2FGVI.git
[submodule "submodules/sam2"]
    path = submodules/sam2
    url = git@github.com:facebookresearch/sam2.git
[submodule "submodules/phantom-robosuite"]
    path = submodules/phantom-robosuite
    url = git@github.com:MarionLepert/phantom-robosuite.git
[submodule "submodules/phantom-robomimic"]
    path = submodules/phantom-robomimic
    url = git@github.com:MarionLepert/phantom-robomimic.git
[submodule "submodules/phantom-hamer"]
    path = submodules/phantom-hamer
    url = git@github.com:MarionLepert/phantom-hamer.git

phantom/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Stanford Interactive Perception and Robot Learning Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

phantom/README.md
ADDED
@@ -0,0 +1,168 @@
# Code for Phantom and Masquerade
[](https://www.python.org)
[](https://opensource.org/licenses/MIT)
<hr style="border: 2px solid gray;"></hr>

This repository contains the code used to process human videos in [Phantom: Training Robots Without Robots Using Only Human Videos](https://phantom-human-videos.github.io/) and [Masquerade: Learning from In-the-wild Human Videos using Data-Editing](https://masquerade-robot.github.io/).

<table>
  <tr>
    <td align="center" width="50%">
      <h3><a href="https://phantom-human-videos.github.io/">Phantom: Training Robots Without Robots Using Only Human Videos</a></h3>
      <p><em><a href=https://marionlepert.github.io/>Marion Lepert</a></em>, <em><a href=https://jiayingfang.github.io/>Jiaying Fang</a></em>, <em><a href=https://web.stanford.edu/~bohg/>Jeannette Bohg</a></em></p>
      <a href="https://phantom-human-videos.github.io/">
        <img src="docs/teaser_phantom.png" alt="Phantom Teaser" width="90%">
      </a>
    </td>
    <td align="center" width="50%">
      <h3><a href="https://masquerade-robot.github.io/">Masquerade: Learning from In-the-wild Human Videos using Data-Editing</a></h3>
      <p><em><a href=https://marionlepert.github.io/>Marion Lepert*</a></em>, <em><a href=https://jiayingfang.github.io/>Jiaying Fang*</a></em>, <em><a href=https://web.stanford.edu/~bohg/>Jeannette Bohg</a></em></p>
      <img src="docs/teaser_masquerade.png" alt="Masquerade Teaser" width="90%">
    </td>
  </tr>
</table>

Both projects use data editing to convert human videos into “robotized” demonstrations. They share much of the same codebase, with some differences in the processing pipeline:

**Phantom**
* Input: RGBD videos with a single left hand visible in every frame.
* Data editing: inpaint the single human arm and overlay a rendered robot arm in the same pose.
* Action labels: extract the full 3D end-effector pose (position, orientation, gripper).

**Masquerade**
* Input: RGB videos from [Epic Kitchens](https://epic-kitchens.github.io/2025); one or both hands may be visible, sometimes occluded.
* Data editing: segment and inpaint both arms, then overlay a bimanual robot whose end-effectors follow the estimated poses (with a 3-4 cm error along the depth direction due to the lack of depth data).
* Action labels: use 2D projected waypoints as auxiliary supervision only (not full 3D actions).


## Installation
1. Clone this repo recursively:

```bash
git clone --recursive git@github.com:MarionLepert/phantom.git
```

2. Run the following script from the root directory to install the required conda environment.
```bash
./install.sh
```

3. Download the MANO hand models. Go to the [MANO website](https://mano.is.tue.mpg.de/) and register to be able to download the models. Download the left and right hand models and move `MANO_LEFT.pkl` and `MANO_RIGHT.pkl` into the `$ROOT_DIR/submodules/phantom-hamer/_DATA/data/mano/` folder.

## Getting Started
Process **Phantom** sample data (manually collected in-lab videos):
```bash
conda activate phantom

python process_data.py demo_name=pick_and_place data_root_dir=../data/raw processed_data_root_dir=../data/processed mode=all
```

Process **Masquerade** sample data ([Epic Kitchens](https://epic-kitchens.github.io/2025) video):
```bash
conda activate phantom

python process_data.py demo_name=epic data_root_dir=../data/raw processed_data_root_dir=../data/processed mode=all --config-name=epic
```


## Codebase Overview

### Process data
Each video is processed using the following steps:

1. **Extract human hand bounding boxes**: `bbox_processor.py`
   * `mode=bbox`

2. **Extract 2D human hand poses**: `hand_processor.py`
   * `mode=hand2d`: extract the 2D hand pose

3. **Extract hand and arm segmentation masks**: `segmentation_processor.py`
   * `mode=hand_segmentation`: used for depth alignment in hand pose refinement (only works for hand3d)
   * `mode=arm_segmentation`: needed in all cases to inpaint the human

4. **Extract 3D human hand poses**: `hand_processor.py`
   * `mode=hand3d`: extract the 3D hand pose (note: requires depth, and was only tested on the left hand)

5. **Retarget human actions to robot actions**: `action_processor.py`
   * `mode=action`

6. **Smooth human poses**: `smoothing_processor.py`
   * `mode=smoothing`

7. **Remove the hand from videos using inpainting**: `handinpaint_processor.py`
   * `mode=hand_inpaint`
   * The inpainting method [E2FGVI](https://arxiv.org/pdf/2204.02663) is used.

8. **Overlay a virtual robot on the video**: `robotinpaint_processor.py`
   * `mode=robot_inpaint`: overlay a single-arm robot (default) or a bimanual robot (epic mode) on the image


### Config reference (see configuration files in `configs/`)

| Flag | Type | Required | Choices | Description |
|------|------|----------|---------|-------------|
| `--demo_name` | `str` | ✅ | - | Name of the demonstration/dataset to process |
| `--mode` | `str` (multiple) | ✅ | `bbox`, `hand2d`, `hand3d`, `hand_segmentation`, `arm_segmentation`, `action`, `smoothing`, `hand_inpaint`, `robot_inpaint`, `all` | Processing modes to run (can specify multiple with e.g. `'mode=[bbox,hand2d]'`) |
| `--robot_name` | `str` | ✅ | `Panda`, `Kinova3`, `UR5e`, `IIWA`, `Jaco` | Type of robot to use for overlays |
| `--gripper_name` | `str` | ❌ | `Robotiq85` | Type of gripper to use |
| `--data_root_dir` | `str` | ❌ | - | Root directory containing raw video data |
| `--processed_data_root_dir` | `str` | ❌ | - | Root directory to save processed data |
| `--epic` | `bool` | ❌ | - | Use Epic-Kitchens dataset processing mode |
| `--bimanual_setup` | `str` | ❌ | `single_arm`, `shoulders` | Bimanual setup configuration to use (`shoulders` corresponds to the bimanual hardware configuration used in Masquerade) |
| `--target_hand` | `str` | ❌ | `left`, `right`, `both` | Which hand(s) to target for processing |
| `--camera_intrinsics` | `str` | ❌ | - | Path to camera intrinsics file |
| `--camera_extrinsics` | `str` | ❌ | - | Path to camera extrinsics file |
| `--input_resolution` | `int` | ❌ | - | Resolution of input videos |
| `--output_resolution` | `int` | ❌ | - | Resolution of output videos |
| `--depth_for_overlay` | `bool` | ❌ | - | Use depth information for overlays |
| `--demo_num` | `str` | ❌ | - | Process a single demo number instead of all demos |
| `--debug_cameras` | `str` (multiple) | ❌ | - | Additional camera names to include for debugging |
| `--constrained_hand` | `bool` | ❌ | - | Use constrained hand processing |
| `--render` | `bool` | ❌ | - | Render the robot overlay on the video |

**Note:** Please specify `--bimanual_setup single_arm` along with `--target_hand left` or `--target_hand right` if you are using a single arm. For bimanual setups, use `--bimanual_setup shoulders`.

### Camera details
* **Phantom**: A Zed2 camera was used to capture the sample data at HD1080 resolution.
* **Masquerade**: Epic-Kitchens videos were used with the camera intrinsics provided by the dataset. To use videos captured with a different camera or resolution, update the camera intrinsics and extrinsics files in `$ROOT_DIR/phantom/camera/`.

### Train policy
After processing the video data, the edited data can be used to train a policy. The following files should be used:

* Observations
  * Phantom samples: extract RGB images from `data/processed/pick_and_place/*/video_overlay_Panda_single_arm.mkv`
  * Epic (in-the-wild data) samples: extract RGB images from `data/processed/epic/*/video_overlay_Kinova3_shoulders.mkv`

* Actions
  * Phantom samples: all data stored in `data/processed/pick_and_place/*/inpaint_processor/training_data_single_arm.npz`
  * Epic (in-the-wild data) samples: all data stored in `data/processed/epic/*/inpaint_processor/training_data_shoulders.npz`


In Phantom, [Diffusion Policy](https://github.com/real-stanford/diffusion_policy) was used for policy training.


## Citation
```bibtex
@article{lepert2025phantomtrainingrobotsrobots,
  title={Phantom: Training Robots Without Robots Using Only Human Videos},
  author={Marion Lepert and Jiaying Fang and Jeannette Bohg},
  year={2025},
  eprint={2503.00779},
  archivePrefix={arXiv},
  primaryClass={cs.RO},
  url={https://arxiv.org/abs/2503.00779},
}
```

```bibtex
@misc{lepert2025masqueradelearninginthewildhuman,
  title={Masquerade: Learning from In-the-wild Human Videos using Data-Editing},
  author={Marion Lepert and Jiaying Fang and Jeannette Bohg},
  year={2025},
  eprint={2508.09976},
  archivePrefix={arXiv},
  primaryClass={cs.RO},
  url={https://arxiv.org/abs/2508.09976},
}
```

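The Train policy section of the README above points at two artifacts per processed demo: an overlay video (observations) and a `training_data_*.npz` file (actions). A minimal loading sketch for the Phantom samples follows; it uses `mediapy`, which `install.sh` installs, and since the `.npz` key names are not documented in this diff, it simply enumerates them instead of assuming any.

```python
# Sketch: inspect one processed Phantom demo (paths follow the README above).
import numpy as np
import mediapy as media

demo_dir = "data/processed/pick_and_place/0"  # one demo folder (example)

# Observations: RGB frames from the robot-overlay video
frames = media.read_video(f"{demo_dir}/video_overlay_Panda_single_arm.mkv")
print("frames:", frames.shape)  # (num_frames, H, W, 3)

# Actions / labels: everything stored by the inpaint processor
data = np.load(f"{demo_dir}/inpaint_processor/training_data_single_arm.npz")
for key in data.files:
    arr = data[key]
    print(key, getattr(arr, "shape", type(arr)))
```
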
phantom/configs/default.yaml
ADDED
@@ -0,0 +1,30 @@
# Default configuration (PHANTOM paper settings)
debug: false
verbose: false
skip_existing: false
n_processes: 1
data_root_dir: "../data/raw_data/"
processed_data_root_dir: "../data/processed_data/"
demo_name: ""

# Processing settings
mode: ["bbox"]  # Default processing mode - must be one of: bbox, hand2d, hand3d, hand_segmentation, arm_segmentation, action, smoothing, hand_inpaint, robot_inpaint, all
demo_num: null  # Process specific demo number (null = process all)

# Additional settings
debug_cameras: []

# PHANTOM paper configuration (default)
input_resolution: 1080
output_resolution: 240
robot: "Panda"
gripper: "Robotiq85"
square: true
epic: false
bimanual_setup: "single_arm"
target_hand: "left"
constrained_hand: true
depth_for_overlay: true
render: false
camera_intrinsics: "camera/camera_intrinsics_HD1080.json"
camera_extrinsics: "camera/camera_extrinsics.json"

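`process_data.py` reads these YAML files through Hydra (`hydra-core` and `omegaconf` are pinned in `install.sh`), which is why the README passes overrides as `key=value` pairs. How the script loads the config internally is not shown in this diff; the sketch below only illustrates the equivalent OmegaConf calls.

```python
# Illustrative only: load default.yaml and apply the same kind of dotlist
# overrides that the README passes on the command line.
from omegaconf import OmegaConf

cfg = OmegaConf.load("phantom/configs/default.yaml")
overrides = OmegaConf.from_dotlist([
    "demo_name=pick_and_place",
    "mode=[bbox,hand2d]",
    "robot=Panda",
    "target_hand=left",
])
cfg = OmegaConf.merge(cfg, overrides)
print(OmegaConf.to_yaml(cfg))
```
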
phantom/configs/epic.yaml
ADDED
@@ -0,0 +1,31 @@
# EPIC-KITCHENS configuration (Masquerade settings)
debug: false
verbose: false
skip_existing: false
n_processes: 1
data_root_dir: "../data/raw_data/"
processed_data_root_dir: "../data/processed_data/"
demo_name: ""

# Processing settings
mode: ["bbox"]  # Default processing mode
demo_num: null  # Process specific demo number (null = process all videos in the root folder)

# Additional settings
debug_cameras: []  # Add other robomimic cameras like sideview, etc. Warning: this significantly slows down the processing time


# EPIC-KITCHENS configuration override
input_resolution: 256
output_resolution: 256
robot: "Kinova3"
gripper: "Robotiq85"
square: false
epic: true
bimanual_setup: "shoulders"
target_hand: "both"
constrained_hand: false
depth_for_overlay: false
render: false
camera_intrinsics: "camera/camera_intrinsics_epic.json"
camera_extrinsics: "camera/camera_extrinsics_ego_bimanual_shoulders.json"

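`epic.yaml` is selected via `--config-name=epic` (see the README's Getting Started command) and mainly flips the Masquerade-specific switches relative to `default.yaml`. A small illustrative sketch for listing exactly which top-level keys differ:

```python
# Sketch: diff the Masquerade (epic) config against the Phantom default.
from omegaconf import OmegaConf

default_cfg = OmegaConf.load("phantom/configs/default.yaml")
epic_cfg = OmegaConf.load("phantom/configs/epic.yaml")

for key in default_cfg:
    if default_cfg.get(key) != epic_cfg.get(key):
        print(f"{key}: {default_cfg.get(key)!r} -> {epic_cfg.get(key)!r}")
```
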
phantom/configs/sam2_hiera_l.yaml
ADDED
@@ -0,0 +1,117 @@
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 144
      num_heads: 2
      stages: [2, 6, 36, 4]
      global_att_blocks: [23, 33, 43]
      window_pos_embed_bkg_spatial_size: [7, 7]
      window_spec: [8, 4, 16, 8]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [1152, 576, 288, 144]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [64, 64]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: false
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  compile_image_encoder: False

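This is the standard SAM 2 Hiera-L model config, vendored so the segmentation step can build the model. The exact call site in `segmentation_processor.py` is not visible in this truncated diff; the sketch below only shows the usual way such a config is consumed through the `sam2` package, with a placeholder checkpoint path and a placeholder prompt.

```python
# Sketch only (not the repo's code): build SAM 2 from this config and run it
# on a single frame with one positive point prompt.
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

model = build_sam2("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt")  # placeholder ckpt
predictor = SAM2ImagePredictor(model)

image = np.zeros((1080, 1920, 3), dtype=np.uint8)  # stand-in for a video frame
predictor.set_image(image)
masks, scores, _ = predictor.predict(
    point_coords=np.array([[960, 540]]),  # placeholder click near the hand
    point_labels=np.array([1]),
)
print(masks.shape, scores)
```
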
phantom/data/__init__.py
ADDED
File without changes

phantom/docs/teaser_masquerade.png
ADDED
Git LFS Details

phantom/docs/teaser_phantom.png
ADDED
Git LFS Details

phantom/install.sh
ADDED
@@ -0,0 +1,67 @@
eval "$(conda shell.bash hook)"
# ######################## Phantom Env ###############################
conda create -n phantom python=3.10 -y
conda activate phantom
conda install nvidia/label/cuda-12.1.0::cuda-toolkit -c nvidia/label/cuda-12.1.0 -y

# Install SAM2
cd submodules/sam2
pip install -v -e ".[notebooks]"
cd ../..

# Install Hamer
cd submodules/phantom-hamer
pip install -e .\[all\]
pip install -v -e third-party/ViTPose
wget https://www.cs.utexas.edu/~pavlakos/hamer/data/hamer_demo_data.tar.gz
tar --warning=no-unknown-keyword --exclude=".*" -xvf hamer_demo_data.tar.gz
cd ../..

# Install mmcv
pip install --index-url https://download.pytorch.org/whl/cu121 torch==2.1.0 torchvision==0.16.0
pip install mmcv==1.3.9
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html
pip install numpy==1.26.4

# Install phantom-robosuite
cd submodules/phantom-robosuite
pip install -e .
cd ../..

# Install phantom-robomimic
cd submodules/phantom-robomimic
pip install -e .
cd ../..

# Install additional packages
pip install joblib mediapy open3d pandas
pip install transformers==4.42.4
pip install PyOpenGL==3.1.4
pip install Rtree
pip install git+https://github.com/epic-kitchens/epic-kitchens-100-hand-object-bboxes.git
pip install protobuf==3.20.0
pip install hydra-core==1.3.2
pip install omegaconf==2.3.0

# Download E2FGVI weights
cd submodules/phantom-E2FGVI/E2FGVI/release_model/
pip install gdown
gdown --fuzzy https://drive.google.com/file/d/10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3/view?usp=sharing
cd ../..

# Install phantom-E2FGVI
pip install -e .
cd ../..

# Install phantom
pip install -e .

# Download sample data
cd data/raw
wget https://download.cs.stanford.edu/juno/phantom/pick_and_place.zip
unzip pick_and_place.zip
rm pick_and_place.zip
wget https://download.cs.stanford.edu/juno/phantom/epic.zip
unzip epic.zip
rm epic.zip
cd ../..

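After `install.sh` completes, a quick import check of the new `phantom` conda environment can catch missing submodule installs before the first processing run. The snippet below is only a convenience sketch; the module names match the packages installed above.

```python
# Quick sanity check for the freshly created "phantom" conda environment.
import importlib

import torch

print("torch:", torch.__version__)          # install.sh pins 2.1.0 (cu121)
print("CUDA available:", torch.cuda.is_available())

for module in ("sam2", "hamer", "robosuite", "robomimic", "hydra", "omegaconf"):
    try:
        importlib.import_module(module)
        print(f"{module}: OK")
    except ImportError as exc:
        print(f"{module}: MISSING ({exc})")
```
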
phantom/phantom/__init__.py
ADDED
File without changes

phantom/phantom/camera/__init__.py
ADDED
File without changes

phantom/phantom/camera/camera_extrinsics.json
ADDED
@@ -0,0 +1,42 @@
[
  {
    "camera_base_ori": [
      [0.9842690634302423, -0.053375086066005106, 0.1684206369825258],
      [-0.1763762231197722, -0.35235905397979306, 0.9190944048336218],
      [0.010287793357058851, -0.934341584895969, -0.3562302121408726]
    ],
    "camera_base_ori_rotvec": [-1.930138005212092, 0.16467696378244215, -0.12809137765065973],
    "camera_base_pos": [0.3407932803063093, -0.40868423448040403, 0.39911982578151795],
    "camera_base_quat": [0.8204965462375373, -0.07000374049084156, 0.054451304871138306, -0.564729979129313],
    "p_marker_ee": [-0.01874144739551215, 0.029611448317719172, -0.013687685723932594]
  }
]

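The extrinsics file stores the camera-to-base orientation redundantly as a rotation matrix, a rotation vector, and a quaternion. The quaternion ordering is not documented here; the sketch below assumes SciPy's scalar-last `(x, y, z, w)` convention and compares rotation matrices, so the sign ambiguity of quaternions does not matter.

```python
# Sketch: verify that the redundant orientation fields in camera_extrinsics.json
# describe the same rotation. Quaternion ordering (x, y, z, w) is an assumption.
import json
import numpy as np
from scipy.spatial.transform import Rotation as R

with open("phantom/phantom/camera/camera_extrinsics.json") as f:
    extr = json.load(f)[0]

R_mat = np.array(extr["camera_base_ori"])
R_vec = R.from_rotvec(extr["camera_base_ori_rotvec"]).as_matrix()
R_quat = R.from_quat(extr["camera_base_quat"]).as_matrix()  # scalar-last assumed

print(np.allclose(R_mat, R_vec, atol=1e-6))
print(np.allclose(R_mat, R_quat, atol=1e-6))
```
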
phantom/phantom/camera/camera_extrinsics_ego_bimanual_shoulders.json
ADDED
@@ -0,0 +1,52 @@
[
  {
    "num_marker_seen": 114,
    "stage2_retry": 11,
    "pixel_error": 2.1157278874907863,
    "proj_func": "hand_marker_proj_world_camera",
    "intrinsics": {
      "fx": 731.4708862304688,
      "fy": 731.4708862304688,
      "ppx": 646.266357421875,
      "ppy": 355.9967956542969
    },
    "camera_base_ori": [
      [-0.7220417114840215, 0.37764981440725887, 0.579686453658689],
      [0.020370475586732495, 0.8491206965938227, -0.527805917303316],
      [-0.6915495720493177, -0.3692893991088662, -0.6207934673498243]
    ],
    "camera_base_ori_rotvec": [0.2877344548443808, 2.3075097094104504, -0.6485227972051454],
    "camera_base_pos": [-0.5123627783256401, -0.11387480700266536, 0.3151264229148423],
    "p_marker_ee": [-0.041990731174163416, -0.02636865486252487, -0.01442948433864288],
    "camera_base_quat": [0.11139014686225811, 0.8933022830245745, -0.25106152012025673, 0.35576871621882866]
  }
]

phantom/phantom/camera/camera_intrinsics_HD1080.json
ADDED
@@ -0,0 +1,48 @@
{
  "left": {
    "fx": 1057.7322998046875,
    "fy": 1057.7322998046875,
    "cx": 972.5150756835938,
    "cy": 552.568359375,
    "disto": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    "v_fov": 54.09259796142578,
    "h_fov": 84.45639038085938,
    "d_fov": 92.32276916503906
  },
  "right": {
    "fx": 1057.7322998046875,
    "fy": 1057.7322998046875,
    "cx": 972.5150756835938,
    "cy": 552.568359375,
    "disto": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    "v_fov": 54.09259796142578,
    "h_fov": 84.45639038085938,
    "d_fov": 92.32276916503906
  }
}

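These intrinsics follow the usual pinhole model, with identical left/right entries and zero distortion. A short sketch of assembling the 3x3 `K` matrix and projecting a camera-frame point with it; the 3D point is made up for illustration.

```python
# Sketch: build the pinhole intrinsics matrix K from the HD1080 file and
# project a (made-up) 3D point given in the camera frame.
import json
import numpy as np

with open("phantom/phantom/camera/camera_intrinsics_HD1080.json") as f:
    intr = json.load(f)["left"]

K = np.array([
    [intr["fx"], 0.0,        intr["cx"]],
    [0.0,        intr["fy"], intr["cy"]],
    [0.0,        0.0,        1.0],
])

point_cam = np.array([0.1, -0.05, 0.8])  # X, Y, Z in meters (illustrative)
uv = K @ point_cam
u, v = uv[:2] / uv[2]                     # perspective divide
print(f"pixel: ({u:.1f}, {v:.1f})")
```
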
phantom/phantom/camera/camera_intrinsics_epic.json
ADDED
@@ -0,0 +1,48 @@
{
  "left": {
    "fx": 248.7892127911359,
    "fy": 248.7892127911359,
    "cx": 228,
    "cy": 128,
    "disto": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    "v_fov": 54.6,
    "h_fov": 83.21271514892578,
    "d_fov": 91.07240295410156
  },
  "right": {
    "fx": 248.7892127911359,
    "fy": 248.7892127911359,
    "cx": 228,
    "cy": 128,
    "disto": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    "v_fov": 54.6,
    "h_fov": 83.21271514892578,
    "d_fov": 91.07240295410156
  }
}

phantom/phantom/detectors/detector_detectron2.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper around detectron2 for object detection
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Tuple
|
| 8 |
+
import cv2
|
| 9 |
+
import logging
|
| 10 |
+
import mediapy as media
|
| 11 |
+
import requests
|
| 12 |
+
import hamer # type: ignore
|
| 13 |
+
from hamer.utils.utils_detectron2 import DefaultPredictor_Lazy # type: ignore
|
| 14 |
+
from detectron2.config import LazyConfig # type: ignore
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
def download_detectron_ckpt(root_dir: str, ckpt_path: str) -> None:
|
| 19 |
+
url = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl"
|
| 20 |
+
save_path = Path(root_dir, ckpt_path)
|
| 21 |
+
save_path.parent.mkdir(exist_ok=True, parents=True)
|
| 22 |
+
response = requests.get(url, stream=True)
|
| 23 |
+
if response.status_code == 200:
|
| 24 |
+
with open(save_path, "wb") as file:
|
| 25 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 26 |
+
file.write(chunk)
|
| 27 |
+
logger.info(f"File downloaded successfully and saved to {save_path}")
|
| 28 |
+
else:
|
| 29 |
+
logger.info(f"Failed to download the file. Status code: {response.status_code}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DetectorDetectron2:
|
| 33 |
+
def __init__(self, root_dir: str):
|
| 34 |
+
cfg_path = (Path(hamer.__file__).parent / "configs" / "cascade_mask_rcnn_vitdet_h_75ep.py")
|
| 35 |
+
detectron2_cfg = LazyConfig.load(str(cfg_path))
|
| 36 |
+
|
| 37 |
+
detectron2_cfg.train.init_checkpoint = os.path.join(
|
| 38 |
+
root_dir, "_DATA/detectron_ckpts/model_final_f05665.pkl"
|
| 39 |
+
)
|
| 40 |
+
if not os.path.exists(detectron2_cfg.train.init_checkpoint):
|
| 41 |
+
download_detectron_ckpt(
|
| 42 |
+
root_dir, "_DATA/detectron_ckpts/model_final_f05665.pkl"
|
| 43 |
+
)
|
| 44 |
+
for predictor in detectron2_cfg.model.roi_heads.box_predictors:
|
| 45 |
+
predictor.test_score_thresh = 0.25
|
| 46 |
+
self.detectron2 = DefaultPredictor_Lazy(detectron2_cfg)
|
| 47 |
+
|
| 48 |
+
def get_bboxes(self, img: np.ndarray, visualize: bool=False,
|
| 49 |
+
visualize_wait: bool=True) -> Tuple[np.ndarray, np.ndarray]:
|
| 50 |
+
""" Get bounding boxes and scores for the detected hand in the image """
|
| 51 |
+
det_out = self.detectron2(img)
|
| 52 |
+
|
| 53 |
+
det_instances = det_out["instances"]
|
| 54 |
+
valid_idx = (det_instances.pred_classes == 0) & (det_instances.scores > 0.5)
|
| 55 |
+
pred_bboxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy()
|
| 56 |
+
pred_scores = det_instances.scores[valid_idx].cpu().numpy()
|
| 57 |
+
|
| 58 |
+
if visualize:
|
| 59 |
+
img_rgb = img.copy()
|
| 60 |
+
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
| 61 |
+
for bbox, score in zip(pred_bboxes, pred_scores):
|
| 62 |
+
cv2.rectangle(
|
| 63 |
+
img_bgr,
|
| 64 |
+
(int(bbox[0]), int(bbox[1])),
|
| 65 |
+
(int(bbox[2]), int(bbox[3])),
|
| 66 |
+
(0, 255, 0),
|
| 67 |
+
2,
|
| 68 |
+
)
|
| 69 |
+
cv2.putText(img_bgr,
|
| 70 |
+
f"{score:.4f}",
|
| 71 |
+
(int(bbox[0]), int(bbox[1])),
|
| 72 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 73 |
+
1,
|
| 74 |
+
(0, 255, 0),
|
| 75 |
+
2,
|
| 76 |
+
cv2.LINE_AA)
|
| 77 |
+
|
| 78 |
+
cv2.imshow(f"Detected bounding boxes", img_bgr)
|
| 79 |
+
if visualize_wait:
|
| 80 |
+
cv2.waitKey(0)
|
| 81 |
+
else:
|
| 82 |
+
cv2.waitKey(1)
|
| 83 |
+
|
| 84 |
+
return pred_bboxes, pred_scores
|
| 85 |
+
|
| 86 |
+
def get_best_bbox(self, img: np.ndarray, visualize: bool=False,
|
| 87 |
+
visualize_wait: bool=True) -> Tuple[np.ndarray, float]:
|
| 88 |
+
""" Get the best bounding box and score for the detected hand in the image """
|
| 89 |
+
bboxes, scores = self.get_bboxes(img)
|
| 90 |
+
if len(bboxes) == 0:
|
| 91 |
+
logger.info("No bbox found with Detectron")
|
| 92 |
+
return np.array([]), 0
|
| 93 |
+
best_idx = scores.argmax()
|
| 94 |
+
best_bbox, best_score = bboxes[best_idx], scores[best_idx]
|
| 95 |
+
|
| 96 |
+
if visualize:
|
| 97 |
+
img_rgb = img.copy()
|
| 98 |
+
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
| 99 |
+
cv2.rectangle(
|
| 100 |
+
img_bgr,
|
| 101 |
+
(int(best_bbox[0]), int(best_bbox[1])),
|
| 102 |
+
(int(best_bbox[2]), int(best_bbox[3])),
|
| 103 |
+
(0, 255, 0),
|
| 104 |
+
2,
|
| 105 |
+
)
|
| 106 |
+
cv2.putText(img_bgr,
|
| 107 |
+
f"{best_score:.4f}",
|
| 108 |
+
(int(best_bbox[0]), int(best_bbox[1])),
|
| 109 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 110 |
+
1,
|
| 111 |
+
(0, 255, 0),
|
| 112 |
+
2,
|
| 113 |
+
cv2.LINE_AA)
|
| 114 |
+
|
| 115 |
+
cv2.imshow(f"Best detected bounding box", img_bgr)
|
| 116 |
+
if visualize_wait:
|
| 117 |
+
cv2.waitKey(0)
|
| 118 |
+
else:
|
| 119 |
+
cv2.waitKey(1)
|
| 120 |
+
|
| 121 |
+
return best_bbox, best_score
|
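A minimal usage sketch for the detector above (not part of the diff), assuming hamer and detectron2 are installed and the ViTDet checkpoint exists under <root_dir>/_DATA/detectron_ckpts/ or can be downloaded as in __init__; the frame path is illustrative:

import cv2
from phantom.detectors.detector_detectron2 import DetectorDetectron2

detector = DetectorDetectron2(root_dir=".")            # downloads the checkpoint on first use
img_bgr = cv2.imread("frame_000000.jpg")               # illustrative input frame
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)     # the detector is fed RGB images

bboxes, scores = detector.get_bboxes(img_rgb)          # all boxes with score > 0.5
best_bbox, best_score = detector.get_best_bbox(img_rgb)
if best_bbox.size:
    print(f"best box {best_bbox.astype(int)} with score {best_score:.3f}")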
phantom/phantom/detectors/detector_dino.py
ADDED
|
@@ -0,0 +1,108 @@
| 1 |
+
"""
|
| 2 |
+
Wrapper around Grounding DINO for zero-shot object detection
|
| 3 |
+
"""
|
| 4 |
+
from typing import Sequence, Tuple, Optional
|
| 5 |
+
import numpy as np
|
| 6 |
+
from transformers import pipeline # type: ignore
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import cv2
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from phantom.utils.image_utils import DetectionResult
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
class DetectorDino:
|
| 16 |
+
def __init__(self, detector_id: str):
|
| 17 |
+
self.detector = pipeline(
|
| 18 |
+
model=detector_id,
|
| 19 |
+
task="zero-shot-object-detection",
|
| 20 |
+
device="cuda",
|
| 21 |
+
batch_size=4,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def get_bboxes(self, frame: np.ndarray, object_name: str, threshold: float = 0.4,
|
| 25 |
+
visualize: bool = False, pause_visualization: bool = True) -> Tuple[np.ndarray, np.ndarray]:
|
| 26 |
+
"""
|
| 27 |
+
Detect objects in a frame and return their bounding boxes and confidence scores.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
frame: Input image as numpy array in RGB format
|
| 31 |
+
object_name: Target object category to detect
|
| 32 |
+
threshold: Confidence threshold for detection (0.0-1.0)
|
| 33 |
+
visualize: If True, displays detection results visually
|
| 34 |
+
pause_visualization: If True, waits for key press when visualizing
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Tuple of (bounding_boxes, confidence_scores) as numpy arrays
|
| 38 |
+
Empty arrays if no objects detected
|
| 39 |
+
"""
|
| 40 |
+
img_pil = Image.fromarray(frame)
|
| 41 |
+
labels = [f"{object_name}."]
|
| 42 |
+
results = self.detector(img_pil, candidate_labels=labels, threshold=threshold)
|
| 43 |
+
results = [DetectionResult.from_dict(result) for result in results]
|
| 44 |
+
if not results:
|
| 45 |
+
return np.array([]), np.array([])
|
| 46 |
+
bboxes = np.array([np.array(result.box.xyxy) for result in results])
|
| 47 |
+
scores = np.array([result.score for result in results])
|
| 48 |
+
|
| 49 |
+
if visualize:
|
| 50 |
+
img_rgb = frame.copy()
|
| 51 |
+
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
| 52 |
+
for bbox, score in zip(bboxes, scores):
|
| 53 |
+
cv2.rectangle(
|
| 54 |
+
img_bgr,
|
| 55 |
+
(int(bbox[0]), int(bbox[1])),
|
| 56 |
+
(int(bbox[2]), int(bbox[3])),
|
| 57 |
+
(0, 255, 0),
|
| 58 |
+
2,
|
| 59 |
+
)
|
| 60 |
+
cv2.putText(img_bgr,
|
| 61 |
+
f"{score:.4f}",
|
| 62 |
+
(int(bbox[0]), int(bbox[1])),
|
| 63 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 64 |
+
1,
|
| 65 |
+
(0, 255, 0),
|
| 66 |
+
2,
|
| 67 |
+
cv2.LINE_AA)
|
| 68 |
+
cv2.imshow("Detection", img_bgr)
|
| 69 |
+
if pause_visualization:
|
| 70 |
+
cv2.waitKey(0)
|
| 71 |
+
else:
|
| 72 |
+
cv2.waitKey(1)
|
| 73 |
+
return bboxes, scores
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_best_bbox(self, frame: np.ndarray, object_name: str, threshold: float = 0.4,
|
| 77 |
+
visualize: bool = False, pause_visualization: bool = True) -> Optional[np.ndarray]:
|
| 78 |
+
bboxes, scores = self.get_bboxes(frame, object_name, threshold)
|
| 79 |
+
if len(bboxes) == 0:
|
| 80 |
+
return None
|
| 81 |
+
best_idx = np.array(scores).argmax()
|
| 82 |
+
best_bbox, best_score = bboxes[best_idx], scores[best_idx]
|
| 83 |
+
|
| 84 |
+
if visualize:
|
| 85 |
+
img_rgb = frame.copy()
|
| 86 |
+
img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
|
| 87 |
+
cv2.rectangle(
|
| 88 |
+
img_bgr,
|
| 89 |
+
(int(best_bbox[0]), int(best_bbox[1])),
|
| 90 |
+
(int(best_bbox[2]), int(best_bbox[3])),
|
| 91 |
+
(0, 255, 0),
|
| 92 |
+
2,
|
| 93 |
+
)
|
| 94 |
+
cv2.putText(img_bgr,
|
| 95 |
+
f"{best_score:.4f}",
|
| 96 |
+
(int(best_bbox[0]), int(best_bbox[1])),
|
| 97 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 98 |
+
1,
|
| 99 |
+
(0, 255, 0),
|
| 100 |
+
2,
|
| 101 |
+
cv2.LINE_AA)
|
| 102 |
+
cv2.imshow("Detection", img_bgr)
|
| 103 |
+
if pause_visualization:
|
| 104 |
+
cv2.waitKey(0)
|
| 105 |
+
else:
|
| 106 |
+
cv2.waitKey(1)
|
| 107 |
+
return best_bbox
|
| 108 |
+
|
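A minimal usage sketch for DetectorDino (not part of the diff), assuming a CUDA device (the pipeline is pinned to "cuda") and a Hugging Face zero-shot-object-detection checkpoint; the model id and frame path are illustrative:

import cv2
from phantom.detectors.detector_dino import DetectorDino

detector = DetectorDino(detector_id="IDEA-Research/grounding-dino-tiny")     # illustrative model id
frame_rgb = cv2.cvtColor(cv2.imread("frame_000000.jpg"), cv2.COLOR_BGR2RGB)

bboxes, scores = detector.get_bboxes(frame_rgb, object_name="mug", threshold=0.4)
best = detector.get_best_bbox(frame_rgb, object_name="mug")
if best is not None:
    x0, y0, x1, y1 = best
    print(f"mug at ({x0}, {y0})-({x1}, {y1})")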
phantom/phantom/detectors/detector_hamer.py
ADDED
|
@@ -0,0 +1,447 @@
| 1 |
+
"""
|
| 2 |
+
Wrapper around HaMeR for hand pose estimation
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
import numpy as np
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Optional, Tuple
|
| 9 |
+
|
| 10 |
+
import cv2
|
| 11 |
+
import torch
|
| 12 |
+
from hamer.utils import recursive_to # type: ignore
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
|
| 15 |
+
from hamer.models import HAMER, DEFAULT_CHECKPOINT # type: ignore
|
| 16 |
+
import sys
|
| 17 |
+
import os
|
| 18 |
+
# Add the phantom-hamer directory to Python path for vitpose_model import
|
| 19 |
+
hamer_path = os.path.join(os.path.dirname(__file__), '..', '..', 'submodules', 'phantom-hamer')
|
| 20 |
+
if hamer_path not in sys.path:
|
| 21 |
+
sys.path.insert(0, hamer_path)
|
| 22 |
+
from vitpose_model import ViTPoseModel # type: ignore
|
| 23 |
+
from hamer.datasets.vitdet_dataset import ViTDetDataset # type: ignore
|
| 24 |
+
from hamer.utils.renderer import cam_crop_to_full # type: ignore
|
| 25 |
+
from hamer.utils.geometry import perspective_projection # type: ignore
|
| 26 |
+
from hamer.configs import get_config # type: ignore
|
| 27 |
+
from yacs.config import CfgNode as CN # type: ignore
|
| 28 |
+
|
| 29 |
+
from phantom.utils.data_utils import get_parent_folder_of_package
|
| 30 |
+
|
| 31 |
+
logger = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
THUMB_VERTEX = 756
|
| 34 |
+
INDEX_FINGER_VERTEX = 350
|
| 35 |
+
|
| 36 |
+
class DetectorHamer:
|
| 37 |
+
"""
|
| 38 |
+
Detector using the HaMeR model for 3D hand pose estimation.
|
| 39 |
+
|
| 40 |
+
The detection pipeline consists of:
|
| 41 |
+
- Initial hand detection using general object detectors
|
| 42 |
+
- Hand type classification (left/right) using ViTPose
|
| 43 |
+
- 3D pose estimation using HaMeR
|
| 44 |
+
- MANO parameters estimation for mesh reconstruction
|
| 45 |
+
|
| 46 |
+
Dependencies:
|
| 47 |
+
- HaMeR model for 3D pose estimation
|
| 48 |
+
- ViTPose for keypoint detection
|
| 49 |
+
- DINO and Detectron2 for initial hand detection
|
| 50 |
+
"""
|
| 51 |
+
def __init__(self):
|
| 52 |
+
root_dir = get_parent_folder_of_package("hamer")
|
| 53 |
+
checkpoint_path = Path(root_dir, DEFAULT_CHECKPOINT)
|
| 54 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 55 |
+
|
| 56 |
+
self.rescale_factor = 2.0 # Factor for padding the box
|
| 57 |
+
self.batch_size = 1 # Batch size for inference
|
| 58 |
+
|
| 59 |
+
self.model, self.model_cfg = self.load_hamer_model(checkpoint_path, root_dir)
|
| 60 |
+
self.model.to(self.device)
|
| 61 |
+
self.model.eval()
|
| 62 |
+
|
| 63 |
+
root_dir = "../submodules/phantom-hamer/"
|
| 64 |
+
vit_dir = os.path.join(root_dir, "third-party/ViTPose/")
|
| 65 |
+
self.cpm = ViTPoseModel(device=self.device, root_dir=root_dir, vit_dir=vit_dir)
|
| 66 |
+
|
| 67 |
+
self.faces_right = self.model.mano.faces
|
| 68 |
+
self.faces_left = self.faces_right[:,[0,2,1]]
|
| 69 |
+
|
| 70 |
+
def detect_hand_keypoints(self,
|
| 71 |
+
img: np.ndarray,
|
| 72 |
+
hand_side: str,
|
| 73 |
+
visualize: bool=False,
|
| 74 |
+
visualize_3d: bool=False,
|
| 75 |
+
pause_visualization: bool=True,
|
| 76 |
+
bboxes: Optional[np.ndarray]=None,
|
| 77 |
+
is_right: Optional[np.ndarray]=None,
|
| 78 |
+
kpts_2d_only: Optional[bool]=False,
|
| 79 |
+
camera_params: Optional[dict]=None) -> Optional[dict]:
|
| 80 |
+
"""
|
| 81 |
+
Detect hand keypoints in the input image.
|
| 82 |
+
|
| 83 |
+
The method performs the following steps:
|
| 84 |
+
1. Detect hand bounding boxes using object detectors
|
| 85 |
+
2. Optionally refine boxes using ViTPose to determine hand type (left/right)
|
| 86 |
+
3. Run HaMeR model to estimate 3D hand pose
|
| 87 |
+
4. Project 3D keypoints back to 2D for visualization
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
img: Input RGB image as numpy array
|
| 91 |
+
hand_side: Target hand side to detect (left or right)
|
| 92 |
+
visualize: If True, displays detection results in a window
|
| 93 |
+
visualize_3d: If True, shows 3D visualization of keypoints and mesh
|
| 94 |
+
pause_visualization: If True, waits for key press when visualizing
|
| 95 |
+
bboxes: Bounding boxes of the hands
|
| 96 |
+
is_right: Whether the hand is right
|
| 97 |
+
kpts_2d_only: If True, only 2D keypoints are needed, so HaMeR's default
|
| 98 |
+
focal length is used instead of the real camera intrinsics
|
| 99 |
+
camera_params: Optional camera intrinsics (fx, fy, cx, cy)
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
Dictionary containing:
|
| 103 |
+
- annotated_img: Image with keypoints drawn
|
| 104 |
+
- success: Whether detection was successful (21 keypoints found)
|
| 105 |
+
- kpts_3d: 3D keypoints in camera space
|
| 106 |
+
- kpts_2d: 2D keypoints projected onto image
|
| 107 |
+
- verts: 3D mesh vertices
|
| 108 |
+
- T_cam_pred: Camera transformation matrix
|
| 109 |
+
- Various camera parameters and MANO pose parameters
|
| 110 |
+
"""
|
| 111 |
+
if not kpts_2d_only:
|
| 112 |
+
scaled_focal_length, camera_center = self.get_image_params(img, camera_params)
|
| 113 |
+
else:
|
| 114 |
+
scaled_focal_length, camera_center = self.get_image_params(img, camera_params=None)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
dataset = ViTDetDataset(self.model_cfg, img, bboxes, is_right, rescale_factor=self.rescale_factor)
|
| 118 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0)
|
| 119 |
+
|
| 120 |
+
list_2d_kpts, list_3d_kpts, list_verts = [], [], []
|
| 121 |
+
T_cam_pred_all: list[torch.Tensor] = []
|
| 122 |
+
list_global_orient = []
|
| 123 |
+
kpts_2d_hamer = None
|
| 124 |
+
for batch in dataloader:
|
| 125 |
+
batch = recursive_to(batch, "cuda")
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
out = self.model(batch)
|
| 128 |
+
|
| 129 |
+
batch_T_cam_pred_all = DetectorHamer.get_all_T_cam_pred(batch, out, scaled_focal_length)
|
| 130 |
+
|
| 131 |
+
for idx in range(len(batch_T_cam_pred_all)):
|
| 132 |
+
kpts_3d = out["pred_keypoints_3d"][idx].detach().cpu().numpy() # [21, 3]
|
| 133 |
+
verts = out["pred_vertices"][idx].detach().cpu().numpy() # [778, 3]
|
| 134 |
+
is_right = batch["right"][idx].cpu().numpy()
|
| 135 |
+
global_orient = out["pred_mano_params"]["global_orient"][idx].detach().cpu().numpy()
|
| 136 |
+
hand_pose = out["pred_mano_params"]["hand_pose"][idx].detach().cpu().numpy()
|
| 137 |
+
list_global_orient.append(global_orient)
|
| 138 |
+
|
| 139 |
+
if hand_side == "left":
|
| 140 |
+
kpts_3d, verts = DetectorHamer.convert_right_hand_keypoints_to_left_hand(kpts_3d, verts)
|
| 141 |
+
|
| 142 |
+
T_cam_pred = batch_T_cam_pred_all[idx]
|
| 143 |
+
|
| 144 |
+
img_w, img_h = batch["img_size"][idx].float()
|
| 145 |
+
|
| 146 |
+
kpts_2d_hamer = DetectorHamer.project_3d_kpt_to_2d(kpts_3d, img_w, img_h, scaled_focal_length,
|
| 147 |
+
camera_center, T_cam_pred)
|
| 148 |
+
|
| 149 |
+
# Keep T_cam_pred as tensor
|
| 150 |
+
list_2d_kpts.append(kpts_2d_hamer)
|
| 151 |
+
list_3d_kpts.append(kpts_3d + T_cam_pred.cpu().numpy())
|
| 152 |
+
list_verts.append(verts + T_cam_pred.cpu().numpy())
|
| 153 |
+
|
| 154 |
+
T_cam_pred_all += batch_T_cam_pred_all
|
| 155 |
+
|
| 156 |
+
annotated_img = DetectorHamer.visualize_2d_kpt_on_img(
|
| 157 |
+
kpts_2d=list_2d_kpts[0],
|
| 158 |
+
img=img,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
if visualize:
|
| 162 |
+
if bboxes is not None:
|
| 163 |
+
cv2.rectangle(annotated_img, (int(bboxes[0][0]), int(bboxes[0][1])), (int(bboxes[0][2]), int(bboxes[0][3])), (0, 255, 0), 2)
|
| 164 |
+
cv2.imshow("Annotated Image", annotated_img)
|
| 165 |
+
cv2.waitKey(0 if pause_visualization else 1)
|
| 166 |
+
|
| 167 |
+
if visualize_3d:
|
| 168 |
+
DetectorHamer.visualize_keypoints_3d(annotated_img, list_3d_kpts[0], list_verts[0])
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
return {
|
| 172 |
+
"annotated_img": annotated_img,
|
| 173 |
+
"success": len(list_2d_kpts[0]) == 21,
|
| 174 |
+
"kpts_3d": list_3d_kpts[0],
|
| 175 |
+
"kpts_2d": np.rint(list_2d_kpts[0]).astype(np.int32),
|
| 176 |
+
"verts": list_verts[0],
|
| 177 |
+
"T_cam_pred": T_cam_pred_all[0],
|
| 178 |
+
"scaled_focal_length": scaled_focal_length,
|
| 179 |
+
"camera_center": camera_center,
|
| 180 |
+
"img_w": img_w,
|
| 181 |
+
"img_h": img_h,
|
| 182 |
+
"global_orient": list_global_orient[0],
|
| 183 |
+
"hand_pose": hand_pose,
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
def get_image_params(self, img: np.ndarray, camera_params: Optional[dict]) -> Tuple[float, torch.Tensor]:
|
| 187 |
+
"""
|
| 188 |
+
Get the scaled focal length and camera center.
|
| 189 |
+
"""
|
| 190 |
+
img_w = img.shape[1]
|
| 191 |
+
img_h = img.shape[0]
|
| 192 |
+
if camera_params is not None:
|
| 193 |
+
scaled_focal_length = camera_params["fx"]
|
| 194 |
+
cx = camera_params["cx"]
|
| 195 |
+
cy = camera_params["cy"]
|
| 196 |
+
camera_center = torch.tensor([img_w-cx, img_h-cy])
|
| 197 |
+
else:
|
| 198 |
+
scaled_focal_length = (self.model_cfg.EXTRA.FOCAL_LENGTH / self.model_cfg.MODEL.IMAGE_SIZE
|
| 199 |
+
* max(img_w, img_h))
|
| 200 |
+
camera_center = torch.tensor([img_w, img_h], dtype=torch.float).reshape(1, 2) / 2.0
|
| 201 |
+
return scaled_focal_length, camera_center
|
| 202 |
+
|
| 203 |
+
@staticmethod
|
| 204 |
+
def convert_right_hand_keypoints_to_left_hand(kpts, verts):
|
| 205 |
+
"""
|
| 206 |
+
Convert right hand keypoints/vertices to left hand by mirroring across the Y-Z plane.
|
| 207 |
+
|
| 208 |
+
This is done by flipping the X coordinates of both keypoints and vertices.
|
| 209 |
+
The MANO model internally uses right hand, so this conversion is needed
|
| 210 |
+
when processing left hands.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
kpts: 3D keypoints [21, 3]
|
| 214 |
+
verts: 3D mesh vertices [778, 3]
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
Transformed keypoints and vertices
|
| 218 |
+
"""
|
| 219 |
+
kpts[:,0] = -kpts[:,0]
|
| 220 |
+
verts[:,0] = -verts[:,0]
|
| 221 |
+
return kpts, verts
|
| 222 |
+
|
| 223 |
+
@staticmethod
|
| 224 |
+
def visualize_keypoints_3d(annotated_img: np.ndarray, kpts_3d: np.ndarray, verts: np.ndarray) -> None:
|
| 225 |
+
nfingers = len(kpts_3d) - 1
|
| 226 |
+
npts_per_finger = 4
|
| 227 |
+
list_fingers = [np.vstack([kpts_3d[0], kpts_3d[i:i + npts_per_finger]]) for i in range(1, nfingers, npts_per_finger)]
|
| 228 |
+
finger_colors_bgr = [(0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 0, 255), (0, 255, 255)]
|
| 229 |
+
finger_colors_rgb = [(color[2], color[1], color[0]) for color in finger_colors_bgr]
|
| 230 |
+
fig, axs = plt.subplots(1,2, figsize=(20, 10))
|
| 231 |
+
axs[0] = fig.add_subplot(111, projection='3d')
|
| 232 |
+
for finger_idx, finger_pts in enumerate(list_fingers):
|
| 233 |
+
for i in range(len(finger_pts) - 1):
|
| 234 |
+
color = finger_colors_rgb[finger_idx]
|
| 235 |
+
axs[0].plot(
|
| 236 |
+
[finger_pts[i][0], finger_pts[i + 1][0]],
|
| 237 |
+
[finger_pts[i][1], finger_pts[i + 1][1]],
|
| 238 |
+
[finger_pts[i][2], finger_pts[i + 1][2]],
|
| 239 |
+
color=np.array(color)/255.0,
|
| 240 |
+
)
|
| 241 |
+
axs[0].scatter(kpts_3d[:, 0], kpts_3d[:, 1], kpts_3d[:, 2])
|
| 242 |
+
axs[0].scatter(verts[:, 0], verts[:, 1], verts[:, 2])
|
| 243 |
+
annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
|
| 244 |
+
axs[1].imshow(annotated_img_rgb)
|
| 245 |
+
|
| 246 |
+
fig = plt.figure()
|
| 247 |
+
ax = fig.add_subplot(111)
|
| 248 |
+
ax.imshow(annotated_img_rgb)
|
| 249 |
+
|
| 250 |
+
plt.show()
|
| 251 |
+
|
| 252 |
+
@staticmethod
|
| 253 |
+
def get_all_T_cam_pred(batch: dict, out: dict, scaled_focal_length: float) -> torch.Tensor:
|
| 254 |
+
"""
|
| 255 |
+
Get the camera transformation matrix
|
| 256 |
+
"""
|
| 257 |
+
multiplier = 2 * batch["right"] - 1
|
| 258 |
+
pred_cam = out["pred_cam"]
|
| 259 |
+
pred_cam[:, 1] = multiplier * pred_cam[:, 1]
|
| 260 |
+
box_center = batch["box_center"].float()
|
| 261 |
+
box_size = batch["box_size"].float()
|
| 262 |
+
# NOTE: FOR HaMeR, they are using the img_size as (W, H)
|
| 263 |
+
W_H_shapes = batch["img_size"].float()
|
| 264 |
+
|
| 265 |
+
multiplier = 2 * batch["right"] - 1
|
| 266 |
+
T_cam_pred_all = cam_crop_to_full(
|
| 267 |
+
pred_cam, box_center, box_size, W_H_shapes, scaled_focal_length
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return T_cam_pred_all
|
| 271 |
+
|
| 272 |
+
@staticmethod
|
| 273 |
+
def visualize_2d_kpt_on_img(kpts_2d: np.ndarray, img: np.ndarray) -> np.ndarray:
|
| 274 |
+
"""
|
| 275 |
+
Plot 2D hand keypoints on the image with finger connections.
|
| 276 |
+
|
| 277 |
+
Each finger is drawn with a different color:
|
| 278 |
+
- Thumb: Green
|
| 279 |
+
- Index: Blue
|
| 280 |
+
- Middle: Red
|
| 281 |
+
- Ring: Magenta
|
| 282 |
+
- Pinky: Cyan
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
kpts_2d: 2D keypoints as integers [21, 2]
|
| 286 |
+
img: Input RGB image
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
Image with keypoints and connections drawn (BGR format)
|
| 290 |
+
"""
|
| 291 |
+
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 292 |
+
pts = kpts_2d.astype(np.int32)
|
| 293 |
+
nfingers = len(pts) - 1
|
| 294 |
+
npts_per_finger = 4
|
| 295 |
+
list_fingers = [np.vstack([pts[0], pts[i:i + npts_per_finger]]) for i in range(1, nfingers, npts_per_finger)]
|
| 296 |
+
finger_colors = [(0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 0, 255), (0, 255, 255)]
|
| 297 |
+
thickness = 5 if img_bgr.shape[0] > 1000 else 2
|
| 298 |
+
for finger_idx, finger_pts in enumerate(list_fingers):
|
| 299 |
+
for i in range(len(finger_pts) - 1):
|
| 300 |
+
color = finger_colors[finger_idx]
|
| 301 |
+
cv2.line(
|
| 302 |
+
img_bgr,
|
| 303 |
+
tuple(finger_pts[i]),
|
| 304 |
+
tuple(finger_pts[i + 1]),
|
| 305 |
+
color,
|
| 306 |
+
thickness=thickness,
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
cv2.line(img_bgr, [1787, 1522], [1656,1400], (255,0,0), thickness=thickness)
|
| 310 |
+
|
| 311 |
+
for pt in pts:
|
| 312 |
+
cv2.circle(img_bgr, (pt[0], pt[1]), radius=thickness, color=(0,0,0), thickness=thickness-1)
|
| 313 |
+
|
| 314 |
+
return img_bgr
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@staticmethod
|
| 318 |
+
def project_3d_kpt_to_2d(kpts_3d: torch.Tensor, img_w: int, img_h: int, scaled_focal_length: float,
|
| 319 |
+
camera_center: torch.Tensor, T_cam: Optional[torch.Tensor] = None,) -> np.ndarray:
|
| 320 |
+
"""
|
| 321 |
+
Project 3D keypoints to 2D image coordinates using perspective projection.
|
| 322 |
+
"""
|
| 323 |
+
batch_size = 1
|
| 324 |
+
|
| 325 |
+
rotation = torch.eye(3).unsqueeze(0)
|
| 326 |
+
assert T_cam is not None
|
| 327 |
+
|
| 328 |
+
T_cam = T_cam.cpu()
|
| 329 |
+
kpts_3d = torch.tensor(kpts_3d).cpu()
|
| 330 |
+
|
| 331 |
+
T_cam = T_cam.clone().cuda()
|
| 332 |
+
kpts_3d = kpts_3d.clone().cuda()
|
| 333 |
+
rotation = rotation.cuda()
|
| 334 |
+
|
| 335 |
+
scaled_focal_length_full = torch.tensor([scaled_focal_length, scaled_focal_length]).reshape(1, 2)
|
| 336 |
+
|
| 337 |
+
# IMPORTANT: The perspective_projection function assumes T_cam has not been added to kpts_3d already!
|
| 338 |
+
kpts_2d = perspective_projection(
|
| 339 |
+
kpts_3d.reshape(batch_size, -1, 3),
|
| 340 |
+
rotation=rotation.repeat(batch_size, 1, 1),
|
| 341 |
+
translation=T_cam.reshape(batch_size, -1),
|
| 342 |
+
focal_length=scaled_focal_length_full.repeat(batch_size, 1),
|
| 343 |
+
camera_center=camera_center.repeat(batch_size, 1),
|
| 344 |
+
).reshape(batch_size, -1, 2)
|
| 345 |
+
kpts_2d = kpts_2d[0].cpu().numpy()
|
| 346 |
+
|
| 347 |
+
return np.rint(kpts_2d).astype(np.int32)
|
| 348 |
+
|
| 349 |
+
@staticmethod
|
| 350 |
+
def annotate_bboxes_on_img(img: np.ndarray, debug_bboxes: dict) -> np.ndarray:
|
| 351 |
+
"""
|
| 352 |
+
Annotate bounding boxes on the image.
|
| 353 |
+
|
| 354 |
+
:param img: Input image (numpy array)
|
| 355 |
+
:param debug_bboxes: Dictionary containing different sets of bounding boxes and optional scores
|
| 356 |
+
:return: Annotated image
|
| 357 |
+
"""
|
| 358 |
+
color_dict = {
|
| 359 |
+
"dino_bboxes": (0, 255, 0),
|
| 360 |
+
"det_bboxes": (0, 0, 255),
|
| 361 |
+
"refined_bboxes": (255, 0, 0),
|
| 362 |
+
"filtered_bboxes": (255, 255, 0),
|
| 363 |
+
}
|
| 364 |
+
corner_dict = {
|
| 365 |
+
"dino_bboxes": "top_left",
|
| 366 |
+
"det_bboxes": "top_right",
|
| 367 |
+
"refined_bboxes": "bottom_left",
|
| 368 |
+
"filtered_bboxes": "bottom_right",
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
def draw_bbox_and_label(bbox, label, color, label_pos, include_label=True):
|
| 372 |
+
""" Helper function to draw the bounding box and add label """
|
| 373 |
+
cv2.rectangle(
|
| 374 |
+
img,
|
| 375 |
+
(int(bbox[0]), int(bbox[1])),
|
| 376 |
+
(int(bbox[2]), int(bbox[3])),
|
| 377 |
+
color,
|
| 378 |
+
2,
|
| 379 |
+
)
|
| 380 |
+
if include_label:
|
| 381 |
+
cv2.putText(
|
| 382 |
+
img, label, label_pos,
|
| 383 |
+
cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2, cv2.LINE_AA
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
label_pos_dict = {
|
| 387 |
+
"top_left": lambda bbox: (int(bbox[0]), int(bbox[1]) - 10),
|
| 388 |
+
"bottom_right": lambda bbox: (int(bbox[2]) - 150, int(bbox[3]) - 10),
|
| 389 |
+
"top_right": lambda bbox: (int(bbox[2]) - 150, int(bbox[1]) - 10),
|
| 390 |
+
"bottom_left": lambda bbox: (int(bbox[0]), int(bbox[3]) - 10),
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
for key, value in debug_bboxes.items():
|
| 394 |
+
# Unpack bboxes and scores
|
| 395 |
+
if key in ["dino_bboxes", "det_bboxes"]:
|
| 396 |
+
bboxes, scores = value
|
| 397 |
+
else:
|
| 398 |
+
bboxes = value
|
| 399 |
+
scores = [None] * len(bboxes)
|
| 400 |
+
|
| 401 |
+
color = color_dict.get(key, (0, 0, 0))
|
| 402 |
+
label_pos_fn = label_pos_dict[corner_dict.get(key, "top_left")]
|
| 403 |
+
|
| 404 |
+
# Draw each bounding box and its label
|
| 405 |
+
for idx, bbox in enumerate(bboxes):
|
| 406 |
+
score_text = f" {scores[idx]:.3f}" if scores[idx] is not None else ""
|
| 407 |
+
label = key.split("_")[0] + score_text
|
| 408 |
+
|
| 409 |
+
# Draw bounding box and label on the image
|
| 410 |
+
label_pos = label_pos_fn(bbox)
|
| 411 |
+
if key in ["dino_bboxes", "det_bboxes"] or idx == 0:
|
| 412 |
+
draw_bbox_and_label(bbox, label, color, label_pos)
|
| 413 |
+
return img
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
@staticmethod
|
| 417 |
+
def load_hamer_model(checkpoint_path: str, root_dir: Optional[str] = None) -> Tuple[HAMER, CN]:
|
| 418 |
+
"""
|
| 419 |
+
Load the HaMeR model from the checkpoint path.
|
| 420 |
+
"""
|
| 421 |
+
model_cfg_path = str(Path(checkpoint_path).parent.parent / "model_config.yaml")
|
| 422 |
+
model_cfg = get_config(model_cfg_path, update_cachedir=True)
|
| 423 |
+
# update model and params path
|
| 424 |
+
if root_dir:
|
| 425 |
+
model_cfg.defrost()
|
| 426 |
+
model_cfg.MANO.DATA_DIR = os.path.join(root_dir, model_cfg.MANO.DATA_DIR)
|
| 427 |
+
model_cfg.MANO.MODEL_PATH = os.path.join(root_dir, model_cfg.MANO.MODEL_PATH.replace("./", ""))
|
| 428 |
+
model_cfg.MANO.MEAN_PARAMS = os.path.join(root_dir, model_cfg.MANO.MEAN_PARAMS.replace("./", ""))
|
| 429 |
+
model_cfg.freeze()
|
| 430 |
+
|
| 431 |
+
# Override some config values, to crop bbox correctly
|
| 432 |
+
if (model_cfg.MODEL.BACKBONE.TYPE == "vit") and ("BBOX_SHAPE" not in model_cfg.MODEL):
|
| 433 |
+
model_cfg.defrost()
|
| 434 |
+
assert (
|
| 435 |
+
model_cfg.MODEL.IMAGE_SIZE == 256
|
| 436 |
+
), f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
|
| 437 |
+
model_cfg.MODEL.BBOX_SHAPE = [192, 256]
|
| 438 |
+
model_cfg.freeze()
|
| 439 |
+
|
| 440 |
+
# Update config to be compatible with demo
|
| 441 |
+
if "PRETRAINED_WEIGHTS" in model_cfg.MODEL.BACKBONE:
|
| 442 |
+
model_cfg.defrost()
|
| 443 |
+
model_cfg.MODEL.BACKBONE.pop("PRETRAINED_WEIGHTS")
|
| 444 |
+
model_cfg.freeze()
|
| 445 |
+
|
| 446 |
+
model = HAMER.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
|
| 447 |
+
return model, model_cfg
|
phantom/phantom/detectors/detector_sam2.py
ADDED
|
@@ -0,0 +1,240 @@
| 1 |
+
"""
|
| 2 |
+
Wrapper around SAM2 for object segmentation
|
| 3 |
+
"""
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pdb
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
import requests
|
| 9 |
+
from typing import Tuple, Optional
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
from matplotlib.axes import Axes
|
| 13 |
+
import cv2
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import torch
|
| 16 |
+
from sam2.build_sam import build_sam2 # type: ignore
|
| 17 |
+
from sam2.sam2_image_predictor import SAM2ImagePredictor # type: ignore
|
| 18 |
+
from sam2.build_sam import build_sam2_video_predictor # type: ignore
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
def download_sam2_ckpt(ckpt_path: str) -> None:
|
| 23 |
+
url = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"
|
| 24 |
+
save_path = Path(ckpt_path)
|
| 25 |
+
save_path.parent.mkdir(exist_ok=True, parents=True)
|
| 26 |
+
response = requests.get(url, stream=True)
|
| 27 |
+
if response.status_code == 200:
|
| 28 |
+
with open(save_path, "wb") as file:
|
| 29 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 30 |
+
file.write(chunk)
|
| 31 |
+
logger.info(f"File downloaded successfully and saved to {save_path}")
|
| 32 |
+
else:
|
| 33 |
+
logger.info(f"Failed to download the file. Status code: {response.status_code}")
|
| 34 |
+
|
| 35 |
+
class DetectorSam2:
|
| 36 |
+
"""
|
| 37 |
+
A detector that uses the SAM2 model for object segmentation in images and videos.
|
| 38 |
+
"""
|
| 39 |
+
def __init__(self):
|
| 40 |
+
checkpoint = "../submodules/sam2/checkpoints/sam2_hiera_large.pt"
|
| 41 |
+
model_cfg = "sam2_hiera_l.yaml"
|
| 42 |
+
|
| 43 |
+
if not os.path.exists(checkpoint):
|
| 44 |
+
download_sam2_ckpt(checkpoint)
|
| 45 |
+
self.device = "cuda"
|
| 46 |
+
|
| 47 |
+
self.video_predictor = build_sam2_video_predictor(model_cfg, checkpoint, device=self.device)
|
| 48 |
+
|
| 49 |
+
def segment_video(self, video_dir: Path, bbox: np.ndarray, points: np.ndarray,
|
| 50 |
+
indices: np.ndarray, reverse: bool=False, output_bboxes: Optional[np.ndarray]=None):
|
| 51 |
+
"""
|
| 52 |
+
Segment an object across video frames using SAM2's video tracking capabilities.
|
| 53 |
+
|
| 54 |
+
Parameters:
|
| 55 |
+
video_dir: Directory containing video frames as image files
|
| 56 |
+
bbox: Bounding box coordinates [x0, y0, x1, y1] for the object to track
|
| 57 |
+
points: Point(s) on the object to track
|
| 58 |
+
start_idx: Frame index to start tracking from
|
| 59 |
+
|
| 60 |
+
Returns:
|
| 61 |
+
video_segments: Dictionary mapping frame indices to segmentation masks
|
| 62 |
+
list_annotated_imgs: Array of frames with the segmented object masked out
|
| 63 |
+
"""
|
| 64 |
+
frame_names = os.listdir(video_dir)
|
| 65 |
+
frame_names = sorted(frame_names)
|
| 66 |
+
with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
|
| 67 |
+
state = self.video_predictor.init_state(video_path=str(video_dir))
|
| 68 |
+
self.video_predictor.reset_state(state)
|
| 69 |
+
|
| 70 |
+
for point, idx in zip(points, indices):
|
| 71 |
+
try:
|
| 72 |
+
if bbox is None or np.all(bbox == 0):
|
| 73 |
+
self.video_predictor.add_new_points_or_box(
|
| 74 |
+
state,
|
| 75 |
+
frame_idx=int(idx),
|
| 76 |
+
obj_id=0,
|
| 77 |
+
points=np.array(point),
|
| 78 |
+
labels=np.ones(len(point)),
|
| 79 |
+
)
|
| 80 |
+
else:
|
| 81 |
+
self.video_predictor.add_new_points_or_box(
|
| 82 |
+
state,
|
| 83 |
+
frame_idx=int(idx),
|
| 84 |
+
obj_id=0,
|
| 85 |
+
box=np.array(bbox),
|
| 86 |
+
points=np.array(point),
|
| 87 |
+
labels=np.ones(len(point)),
|
| 88 |
+
)
|
| 89 |
+
except Exception as e:
|
| 90 |
+
print("Error in adding new points or box:", e)
|
| 91 |
+
pdb.set_trace()
|
| 92 |
+
|
| 93 |
+
video_segments = {}
|
| 94 |
+
for (
|
| 95 |
+
out_frame_idx,
|
| 96 |
+
out_obj_ids,
|
| 97 |
+
out_mask_logits,
|
| 98 |
+
) in self.video_predictor.propagate_in_video(state, reverse=reverse):
|
| 99 |
+
video_segments[out_frame_idx] = {
|
| 100 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
| 101 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
frame_indices = list(video_segments.keys())
|
| 105 |
+
frame_indices.sort()
|
| 106 |
+
list_annotated_imgs = {}
|
| 107 |
+
for out_frame_idx in frame_indices:
|
| 108 |
+
img = Image.open(os.path.join(video_dir, frame_names[out_frame_idx]))
|
| 109 |
+
img_arr = np.array(img)
|
| 110 |
+
mask = video_segments[out_frame_idx][0]
|
| 111 |
+
if output_bboxes is not None:
|
| 112 |
+
# Crop the mask to the bounding box
|
| 113 |
+
output_bbox = output_bboxes[out_frame_idx].astype(np.int32)
|
| 114 |
+
if output_bbox.sum() > 0:
|
| 115 |
+
bbox_mask = np.zeros_like(mask)
|
| 116 |
+
bbox_mask = self._crop_mask_to_bbox(mask, output_bbox)
|
| 117 |
+
mask = mask * bbox_mask
|
| 118 |
+
img_arr[mask[0]] = (0, 0, 0)
|
| 119 |
+
list_annotated_imgs[out_frame_idx] = img_arr
|
| 120 |
+
|
| 121 |
+
if output_bboxes is not None:
|
| 122 |
+
for out_frame_idx in frame_indices:
|
| 123 |
+
output_bbox = output_bboxes[out_frame_idx].astype(np.int32)
|
| 124 |
+
mask = video_segments[out_frame_idx][0]
|
| 125 |
+
mask_ori = mask.copy()
|
| 126 |
+
if output_bbox.sum() > 0:
|
| 127 |
+
bbox_mask = np.zeros_like(mask)
|
| 128 |
+
bbox_mask = self._crop_mask_to_bbox(mask, output_bbox)
|
| 129 |
+
mask = mask * bbox_mask
|
| 130 |
+
video_segments[out_frame_idx] = {
|
| 131 |
+
0: mask
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
# Fix gpu memory leak
|
| 135 |
+
torch.cuda.empty_cache()
|
| 136 |
+
|
| 137 |
+
return video_segments, list_annotated_imgs
|
| 138 |
+
|
| 139 |
+
def _crop_mask_to_bbox(self, mask: np.ndarray, bbox: np.ndarray) -> np.ndarray:
|
| 140 |
+
"""
|
| 141 |
+
Crop a mask to a bounding box.
|
| 142 |
+
"""
|
| 143 |
+
margin = 20
|
| 144 |
+
bbox = bbox.astype(np.int32)
|
| 145 |
+
x0, y0, x1, y1 = bbox
|
| 146 |
+
x0 = max(0, x0 - margin)
|
| 147 |
+
x1 = min(mask.shape[2], x1 + margin)
|
| 148 |
+
y0 = max(0, y0 - margin)
|
| 149 |
+
y1 = min(mask.shape[1], y1 + margin)
|
| 150 |
+
bbox_mask = np.zeros_like(mask)
|
| 151 |
+
bbox_mask[:, y0:y1, x0:x1] = 1
|
| 152 |
+
return bbox_mask
|
| 153 |
+
|
| 154 |
+
def segment_video_from_mask(self, video_dir: str, mask: np.ndarray, frame_idx: int, reverse=False):
|
| 155 |
+
"""
|
| 156 |
+
Propagate a segmentation mask through video frames (forward or backward).
|
| 157 |
+
|
| 158 |
+
Parameters:
|
| 159 |
+
video_dir: Directory containing video frames
|
| 160 |
+
mask: Initial segmentation mask to propagate
|
| 161 |
+
frame_idx: Frame index where the mask is defined
|
| 162 |
+
reverse: If True, propagate backward in time; if False, propagate forward
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
frame_indices: List of frame indices where masks were generated
|
| 166 |
+
video_segments: Dictionary mapping frame indices to segmentation masks
|
| 167 |
+
"""
|
| 168 |
+
with torch.inference_mode(), torch.autocast(self.device, dtype=torch.bfloat16):
|
| 169 |
+
state = self.video_predictor.init_state(video_path=video_dir)
|
| 170 |
+
self.video_predictor.reset_state(state)
|
| 171 |
+
|
| 172 |
+
self.video_predictor.add_new_mask(state, frame_idx, 0, mask)
|
| 173 |
+
|
| 174 |
+
video_segments = {}
|
| 175 |
+
mask_prob = {}
|
| 176 |
+
for (
|
| 177 |
+
out_frame_idx,
|
| 178 |
+
out_obj_ids,
|
| 179 |
+
out_mask_logits,
|
| 180 |
+
) in self.video_predictor.propagate_in_video(state, reverse=reverse):
|
| 181 |
+
mask_prob[out_frame_idx] = torch.mean(torch.sigmoid(out_mask_logits))
|
| 182 |
+
video_segments[out_frame_idx] = {
|
| 183 |
+
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
|
| 184 |
+
for i, out_obj_id in enumerate(out_obj_ids)
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
frame_indices = list(video_segments.keys())
|
| 188 |
+
frame_indices.sort()
|
| 189 |
+
return frame_indices, video_segments
|
| 190 |
+
|
| 191 |
+
@staticmethod
|
| 192 |
+
def show_mask(mask: np.ndarray, ax: Axes, random_color: bool=False, borders: bool = True) -> None:
|
| 193 |
+
if random_color:
|
| 194 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
| 195 |
+
else:
|
| 196 |
+
color = np.array([30/255, 144/255, 255/255, 0.6])
|
| 197 |
+
h, w = mask.shape[-2:]
|
| 198 |
+
mask = mask.astype(np.uint8)
|
| 199 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
| 200 |
+
if borders:
|
| 201 |
+
contours, _ = cv2.findContours(mask,cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
|
| 202 |
+
# Try to smooth contours
|
| 203 |
+
contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
|
| 204 |
+
mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
|
| 205 |
+
ax.imshow(mask_image)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@staticmethod
|
| 209 |
+
def show_masks(image: np.ndarray, masks: np.ndarray, scores: np.ndarray, point_coords: Optional[np.ndarray]=None,
|
| 210 |
+
box_coords: Optional[np.ndarray]=None, input_labels: Optional[np.ndarray]=None, borders: bool=True) -> None:
|
| 211 |
+
n_masks = len(masks)
|
| 212 |
+
fig, axs = plt.subplots(1, n_masks, figsize=(10*n_masks, 10))
|
| 213 |
+
for i, (mask, score) in enumerate(zip(masks, scores)):
|
| 214 |
+
axs[i].imshow(image)
|
| 215 |
+
DetectorSam2.show_mask(mask, axs[i], borders=borders)
|
| 216 |
+
if point_coords is not None:
|
| 217 |
+
assert input_labels is not None
|
| 218 |
+
DetectorSam2.show_points(point_coords, input_labels, axs[i])
|
| 219 |
+
if box_coords is not None:
|
| 220 |
+
DetectorSam2.show_box(box_coords, axs[i])
|
| 221 |
+
if len(scores) > 1:
|
| 222 |
+
axs[i].set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
|
| 223 |
+
axs[i].axis('off')
|
| 224 |
+
plt.show()
|
| 225 |
+
|
| 226 |
+
@staticmethod
|
| 227 |
+
def show_box(box: np.ndarray, ax: Axes) -> None:
|
| 228 |
+
x0, y0 = box[0], box[1]
|
| 229 |
+
w, h = box[2] - box[0], box[3] - box[1]
|
| 230 |
+
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@staticmethod
|
| 234 |
+
def show_points(coords: np.ndarray, labels: np.ndarray, ax: Axes, marker_size: int=375) -> None:
|
| 235 |
+
pos_points = coords[labels==1]
|
| 236 |
+
neg_points = coords[labels==0]
|
| 237 |
+
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
|
| 238 |
+
s=marker_size, edgecolor='white', linewidth=1.25)
|
| 239 |
+
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
|
| 240 |
+
s=marker_size, edgecolor='white', linewidth=1.25)
|
phantom/phantom/hand.py
ADDED
|
@@ -0,0 +1,805 @@
| 1 |
+
"""
|
| 2 |
+
Hand Model Module
|
| 3 |
+
|
| 4 |
+
This module provides hand modeling for action processors. It converts detected hand
|
| 5 |
+
keypoints into kinematic models that can be used for robot control
|
| 6 |
+
|
| 7 |
+
Key Components:
|
| 8 |
+
- HandModel: Base class for unconstrained hand kinematic modeling
|
| 9 |
+
- PhysicallyConstrainedHandModel: Extended class that enforces joint angle and velocity limits
|
| 10 |
+
- Grasp point and orientation calculation for robot end-effector control
|
| 11 |
+
|
| 12 |
+
The hand model follows the MediaPipe hand landmark convention with 21 keypoints:
|
| 13 |
+
- Wrist (1 point)
|
| 14 |
+
- Thumb (4 points: MCP, PIP, DIP, TIP)
|
| 15 |
+
- Index finger (4 points: MCP, PIP, DIP, TIP)
|
| 16 |
+
- Middle finger (4 points: MCP, PIP, DIP, TIP)
|
| 17 |
+
- Ring finger (4 points: MCP, PIP, DIP, TIP)
|
| 18 |
+
- Pinky finger (4 points: MCP, PIP, DIP, TIP)
|
| 19 |
+
|
| 20 |
+
Coordinate System:
|
| 21 |
+
- All calculations performed in robot coordinate frame
|
| 22 |
+
- Grasp orientations aligned with robot end-effector conventions
|
| 23 |
+
- Joint rotations represented as rotation matrices and Euler angles
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from typing import Optional, List, Dict, Tuple, Union, Any
|
| 27 |
+
import numpy as np
|
| 28 |
+
import pdb
|
| 29 |
+
import torch
|
| 30 |
+
from scipy.spatial.transform import Rotation
|
| 31 |
+
import logging
|
| 32 |
+
|
| 33 |
+
from phantom.utils.transform_utils import *
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
class HandModel:
|
| 37 |
+
"""
|
| 38 |
+
Base class for hand kinematic modeling and trajectory analysis.
|
| 39 |
+
|
| 40 |
+
This class provides a kinematic representation of a human hand using 21 keypoints
|
| 41 |
+
from hand pose estimation. It calculates joint rotations, tracks hand motion over
|
| 42 |
+
time, and computes grasp points and orientations suitable for robot control.
|
| 43 |
+
|
| 44 |
+
The kinematic structure follows a tree topology with the wrist as the root,
|
| 45 |
+
and each finger as a separate chain. Joint rotations are calculated relative
|
| 46 |
+
to parent joints using vector alignment methods.
|
| 47 |
+
|
| 48 |
+
Key Features:
|
| 49 |
+
- 21-point hand keypoint processing
|
| 50 |
+
- Joint rotation calculation using vector alignment
|
| 51 |
+
- Grasp point computation from thumb-index / thumb-middle finger positioning
|
| 52 |
+
- End-effector orientation calculation for robot control
|
| 53 |
+
|
| 54 |
+
Attributes:
|
| 55 |
+
robot_name (str): Name of the target robot for coordinate frame alignment
|
| 56 |
+
kinematic_tree (List[Tuple[int, int]]): Parent-child relationships for hand joints
|
| 57 |
+
joint_to_neighbors_mapping (Dict[int, Tuple[int, int, int]]): Mapping of joints to their neighbors
|
| 58 |
+
vertex_positions (List[np.ndarray]): Time series of hand keypoint positions
|
| 59 |
+
joint_rotations (List[List[np.ndarray]]): Time series of joint rotation matrices
|
| 60 |
+
grasp_points (List[np.ndarray]): Time series of computed grasp points
|
| 61 |
+
grasp_oris (List[np.ndarray]): Time series of grasp orientation matrices
|
| 62 |
+
timestamps (List[float]): Time stamps for each frame
|
| 63 |
+
num_joints (int): Total number of joints in the hand model
|
| 64 |
+
joint_rotations_xyz (List[List[np.ndarray]]): Time series of Euler angle representations
|
| 65 |
+
"""
|
| 66 |
+
def __init__(self, robot_name: str) -> None:
|
| 67 |
+
"""
|
| 68 |
+
Initialize the hand model with kinematic structure.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
robot_name: Name of the target robot for coordinate alignment
|
| 72 |
+
"""
|
| 73 |
+
self.robot_name: str = robot_name
|
| 74 |
+
|
| 75 |
+
# Define the kinematic tree structure for hand joints
|
| 76 |
+
# Format: (joint_index, parent_index) where -1 indicates root (wrist)
|
| 77 |
+
self.kinematic_tree: List[Tuple[int, int]] = [
|
| 78 |
+
(0, -1), # wrist base (root of the kinematic tree)
|
| 79 |
+
|
| 80 |
+
# Thumb chain (4 joints)
|
| 81 |
+
(1, 0), # thumb mcp
|
| 82 |
+
(2, 1), # thumb pip
|
| 83 |
+
(3, 2), # thumb dip
|
| 84 |
+
(4, 3), # thumb tip
|
| 85 |
+
|
| 86 |
+
# Index finger chain (4 joints)
|
| 87 |
+
(5, 0), # index mcp
|
| 88 |
+
(6, 5), # index pip
|
| 89 |
+
(7, 6), # index dip
|
| 90 |
+
(8, 7), # index tip
|
| 91 |
+
|
| 92 |
+
# Middle finger chain (4 joints)
|
| 93 |
+
(9, 0), # middle mcp
|
| 94 |
+
(10, 9), # middle pip
|
| 95 |
+
(11, 10), # middle dip
|
| 96 |
+
(12, 11), # middle tip
|
| 97 |
+
|
| 98 |
+
# Ring finger chain (4 joints)
|
| 99 |
+
(13, 0), # ring mcp
|
| 100 |
+
(14, 13), # ring pip
|
| 101 |
+
(15, 14), # ring dip
|
| 102 |
+
(16, 15), # ring tip
|
| 103 |
+
|
| 104 |
+
# Pinky finger chain (4 joints)
|
| 105 |
+
(17, 0), # pinky mcp
|
| 106 |
+
(18, 17), # pinky pip
|
| 107 |
+
(19, 18), # pinky dip
|
| 108 |
+
(20, 19), # pinky tip
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
# Mapping from joint index to (current_vertex, child_vertex, parent_vertex)
|
| 112 |
+
# This defines the local coordinate system for each joint rotation calculation
|
| 113 |
+
self.joint_to_neighbors_mapping: Dict[int, Tuple[int, int, int]] = {
|
| 114 |
+
# Thumb joint mappings
|
| 115 |
+
0: (0, 1, -1), # wrist to thumb mcp (no parent)
|
| 116 |
+
1: (1, 2, 0), # thumb mcp to pip (parent: wrist)
|
| 117 |
+
2: (2, 3, 1), # thumb pip to dip (parent: thumb mcp)
|
| 118 |
+
3: (3, 4, 2), # thumb dip to tip (parent: thumb pip)
|
| 119 |
+
|
| 120 |
+
# Index finger joint mappings
|
| 121 |
+
4: (0, 5, -1), # wrist to index mcp (no parent)
|
| 122 |
+
5: (5, 6, 0), # index mcp to pip (parent: wrist)
|
| 123 |
+
6: (6, 7, 5), # index pip to dip (parent: index mcp)
|
| 124 |
+
7: (7, 8, 6), # index dip to tip (parent: index pip)
|
| 125 |
+
|
| 126 |
+
# Middle finger joint mappings
|
| 127 |
+
8: (0, 9, -1), # wrist to middle mcp (no parent)
|
| 128 |
+
9: (9, 10, 0), # middle mcp to pip (parent: wrist)
|
| 129 |
+
10: (10, 11, 9), # middle pip to dip (parent: middle mcp)
|
| 130 |
+
11: (11, 12, 10),# middle dip to tip (parent: middle pip)
|
| 131 |
+
|
| 132 |
+
# Ring finger joint mappings
|
| 133 |
+
12: (0, 13, -1), # wrist to ring mcp (no parent)
|
| 134 |
+
13: (13, 14, 0),# ring mcp to pip (parent: wrist)
|
| 135 |
+
14: (14, 15, 13),# ring pip to dip (parent: ring mcp)
|
| 136 |
+
15: (15, 16, 14),# ring dip to tip (parent: ring pip)
|
| 137 |
+
|
| 138 |
+
# Pinky finger joint mappings
|
| 139 |
+
16: (0, 17, -1), # wrist to pinky mcp (no parent)
|
| 140 |
+
17: (17, 18, 0),# pinky mcp to pip (parent: wrist)
|
| 141 |
+
18: (18, 19, 17),# pinky pip to dip (parent: pinky mcp)
|
| 142 |
+
19: (19, 20, 18),# pinky dip to tip (parent: pinky pip)
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
self.num_joints: int = len(self.joint_to_neighbors_mapping)
|
| 146 |
+
|
| 147 |
+
# Time series data storage
|
| 148 |
+
self.vertex_positions: List[np.ndarray] = [] # List of (21, 3) arrays for each timestep
|
| 149 |
+
self.joint_rotations: List[List[np.ndarray]] = [] # List of rotation matrices for each joint
|
| 150 |
+
self.joint_rotations_xyz: List[List[np.ndarray]] = [] # List of Euler angle representations
|
| 151 |
+
self.grasp_points: List[np.ndarray] = [] # List of computed grasp points (3D positions)
|
| 152 |
+
self.grasp_oris: List[np.ndarray] = [] # List of grasp orientation matrices (3x3)
|
| 153 |
+
self.timestamps: List[float] = [] # List of timestamps for temporal analysis
|
| 154 |
+
|
| 155 |
+
def calculate_joint_rotation(self, current_pos: np.ndarray, child_pos: np.ndarray, parent_pos: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
|
| 156 |
+
"""
|
| 157 |
+
Calculate the rotation matrix for a single joint using vector alignment.
|
| 158 |
+
|
| 159 |
+
This method computes the rotation that aligns the previous direction vector
|
| 160 |
+
with the current direction vector. For root joints (no parent), it uses
|
| 161 |
+
a default upward direction as the reference.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
current_pos: 3D position of the current joint
|
| 165 |
+
child_pos: 3D position of the child joint
|
| 166 |
+
parent_pos: 3D position of the parent joint
|
| 167 |
+
|
| 168 |
+
Returns:
|
| 169 |
+
Tuple containing:
|
| 170 |
+
- rotation_matrix: 3x3 rotation matrix
|
| 171 |
+
- euler_angles: Rotation as XYZ Euler angles
|
| 172 |
+
"""
|
| 173 |
+
# Calculate current direction vector (current -> child)
|
| 174 |
+
current_dir = child_pos - current_pos
|
| 175 |
+
current_norm = np.linalg.norm(current_dir)
|
| 176 |
+
if current_norm < 1e-10:
|
| 177 |
+
return np.eye(3), np.array([0,0,0])
|
| 178 |
+
current_dir /= current_norm
|
| 179 |
+
|
| 180 |
+
# Calculate previous direction vector (parent -> current, or default up)
|
| 181 |
+
prev_dir = np.array([0.0, 0.0, 1.0]) if parent_pos is None else current_pos - parent_pos
|
| 182 |
+
prev_norm = np.linalg.norm(prev_dir)
|
| 183 |
+
if prev_norm < 1e-10:
|
| 184 |
+
return np.eye(3), np.array([0,0,0])
|
| 185 |
+
prev_dir /= prev_norm
|
| 186 |
+
|
| 187 |
+
# Check if vectors are already aligned (no rotation needed)
|
| 188 |
+
if np.abs((np.abs(np.dot(current_dir, prev_dir)) - 1)) < 1e-8:
|
| 189 |
+
return np.eye(3), np.array([0,0,0])
|
| 190 |
+
|
| 191 |
+
# Calculate rotation that aligns prev_dir with current_dir
|
| 192 |
+
rotation, _ = Rotation.align_vectors([current_dir], [prev_dir])
|
| 193 |
+
return rotation.as_matrix(), rotation.as_euler('xyz')
|
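As a quick illustration of the alignment step above, the following self-contained sketch (toy joint positions, not real hand data) recovers the rotation that maps a parent-to-current direction onto a current-to-child direction using the same scipy call:

import numpy as np
from scipy.spatial.transform import Rotation

# Hypothetical joint positions: parent at the origin, current straight above it, child bent outward
parent_pos = np.array([0.0, 0.0, 0.0])
current_pos = np.array([0.0, 0.0, 1.0])
child_pos = np.array([1.0, 0.0, 2.0])

prev_dir = (current_pos - parent_pos) / np.linalg.norm(current_pos - parent_pos)   # [0, 0, 1]
current_dir = (child_pos - current_pos) / np.linalg.norm(child_pos - current_pos)  # [0.707, 0, 0.707]

# align_vectors returns the rotation R with R.apply(prev_dir) ~ current_dir
rotation, _ = Rotation.align_vectors([current_dir], [prev_dir])
print(rotation.as_euler('xyz', degrees=True))  # ~ [0, 45, 0]: a 45-degree bend about the y-axis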
| 194 |
+
|
| 195 |
+
def calculate_frame_rotations(self, vertices: np.ndarray) -> Tuple[List[np.ndarray], List[np.ndarray]]:
|
| 196 |
+
"""
|
| 197 |
+
Calculate rotation matrices for all joints in a single frame.
|
| 198 |
+
|
| 199 |
+
This method processes all joints in the hand and computes their rotations
|
| 200 |
+
based on the kinematic structure and current vertex positions.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
vertices: Hand keypoints, shape (21, 3)
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
Tuple containing:
|
| 207 |
+
- rotation_matrices: List of 3x3 rotation matrices
|
| 208 |
+
- euler_angles: List of XYZ Euler angle arrays
|
| 209 |
+
"""
|
| 210 |
+
rotations, rotations_xyz = zip(*[
|
| 211 |
+
self.calculate_joint_rotation(vertices[m[0]], vertices[m[1]],
|
| 212 |
+
None if m[2] == -1 else vertices[m[2]])
|
| 213 |
+
for m in self.joint_to_neighbors_mapping.values()
|
| 214 |
+
])
|
| 215 |
+
return list(rotations), list(rotations_xyz)
|
| 216 |
+
|
| 217 |
+
def calculate_angular_velocity(self, joint_idx: int, t1: int, t2: int) -> np.ndarray:
|
| 218 |
+
"""
|
| 219 |
+
Calculate angular velocity for a specific joint between two time frames.
|
| 220 |
+
|
| 221 |
+
Angular velocity is computed as the rotation vector difference divided
|
| 222 |
+
by the time difference between frames.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
joint_idx: Index of the joint
|
| 226 |
+
t1: Index of the first time frame
|
| 227 |
+
t2: Index of the second time frame
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
Angular velocity vector (3,) in rad/s
|
| 231 |
+
"""
|
| 232 |
+
dt = self.timestamps[t2] - self.timestamps[t1]
|
| 233 |
+
if dt == 0:
|
| 234 |
+
return np.zeros(3)
|
| 235 |
+
|
| 236 |
+
# Get rotation matrices for the two time frames
|
| 237 |
+
R1, R2 = self.joint_rotations[t1][joint_idx], self.joint_rotations[t2][joint_idx]
|
| 238 |
+
|
| 239 |
+
# Calculate relative rotation and convert to angular velocity
|
| 240 |
+
R_relative = Rotation.from_matrix(R2) * Rotation.from_matrix(R1).inv()
|
| 241 |
+
return R_relative.as_rotvec() / dt
|
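The relative-rotation formula above can be checked in isolation: the rotation vector of R2 * R1^-1 divided by dt is the average angular velocity between the two frames. A standalone sketch with toy values (dt matches the 15 Hz rate used elsewhere in the pipeline):

import numpy as np
from scipy.spatial.transform import Rotation

dt = 1 / 15.0
R1 = Rotation.from_euler('z', 10, degrees=True).as_matrix()
R2 = Rotation.from_euler('z', 25, degrees=True).as_matrix()

R_relative = Rotation.from_matrix(R2) * Rotation.from_matrix(R1).inv()
omega = R_relative.as_rotvec() / dt  # rad/s, here ~3.93 rad/s about +z

print(np.degrees(np.linalg.norm(omega) * dt))  # ~15.0 degrees rotated over one frame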
| 242 |
+
|
| 243 |
+
def calculate_frame_angular_velocities(self, current_frame_idx: int) -> np.ndarray:
|
| 244 |
+
"""
|
| 245 |
+
Calculate angular velocities for all joints at the current frame.
|
| 246 |
+
|
| 247 |
+
This method computes the angular velocity vectors for all joints by
|
| 248 |
+
comparing rotations with the previous frame. Returns zeros for the
|
| 249 |
+
first frame since no previous frame exists.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
current_frame_idx: Index of the current frame (zeros are returned for frame 0).
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
Array of angular velocity vectors (shape: num_joints x 3)
|
| 256 |
+
Each row contains [wx, wy, wz] for one joint.
|
| 257 |
+
Returns zeros if current_frame_idx == 0.
|
| 258 |
+
"""
|
| 259 |
+
if current_frame_idx == 0:
|
| 260 |
+
return np.zeros((self.num_joints, 3))
|
| 261 |
+
|
| 262 |
+
prev_frame_idx = current_frame_idx - 1
|
| 263 |
+
|
| 264 |
+
# Angular velocity of every joint relative to the previous frame
|
| 265 |
+
velocities = np.array([
|
| 266 |
+
self.calculate_angular_velocity(joint_idx, prev_frame_idx, current_frame_idx)
|
| 267 |
+
for joint_idx in range(self.num_joints)
|
| 268 |
+
])
|
| 269 |
+
|
| 270 |
+
return velocities
|
| 271 |
+
|
| 272 |
+
def calculate_grasp_plane(self, vertices: np.ndarray) -> np.ndarray:
|
| 273 |
+
"""
|
| 274 |
+
Calculate the plane that best fits through a set of hand vertices.
|
| 275 |
+
|
| 276 |
+
This method uses Singular Value Decomposition (SVD) to find the plane.
|
| 277 |
+
The plane is typically fitted through thumb and index finger points.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
vertices: Set of 3D points to fit plane through, shape (N, 3)
|
| 281 |
+
|
| 282 |
+
Returns:
|
| 283 |
+
Plane coefficients [a, b, c, d] for ax + by + cz + d = 0
|
| 284 |
+
"""
|
| 285 |
+
# Create augmented matrix with homogeneous coordinates for plane fitting
|
| 286 |
+
A = np.c_[vertices[:, 0], vertices[:, 1], vertices[:, 2], np.ones(vertices.shape[0])]
|
| 287 |
+
|
| 288 |
+
# Right-hand side is zeros for the plane equation ax + by + cz + d = 0
|
| 289 |
+
b = np.zeros(vertices.shape[0])  # unused: the homogeneous system is solved directly via the SVD below
|
| 290 |
+
|
| 291 |
+
# Use SVD to solve the least squares problem
|
| 292 |
+
U, S, Vt = np.linalg.svd(A)
|
| 293 |
+
|
| 294 |
+
# Plane coefficients are in the last row of Vt (smallest singular value)
|
| 295 |
+
plane_coeffs = Vt[-1, :]
|
| 296 |
+
|
| 297 |
+
# Normalize coefficients for easier interpretation (unit normal vector)
|
| 298 |
+
plane_coeffs = plane_coeffs / np.linalg.norm(plane_coeffs[:3])
|
| 299 |
+
|
| 300 |
+
return plane_coeffs # [a, b, c, d]
|
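A simple way to sanity-check the SVD fit is to run it on synthetic points that are known to be coplanar; minimal standalone sketch (synthetic points, not hand data):

import numpy as np

# Points on the plane z = 2, i.e. 0*x + 0*y + 1*z - 2 = 0
pts = np.array([[0.0, 0.0, 2.0], [1.0, 0.0, 2.0], [0.0, 1.0, 2.0], [1.0, 1.0, 2.0]])

A = np.c_[pts[:, 0], pts[:, 1], pts[:, 2], np.ones(len(pts))]
_, _, Vt = np.linalg.svd(A)
coeffs = Vt[-1] / np.linalg.norm(Vt[-1][:3])

print(np.round(coeffs, 6))  # [0, 0, 1, -2] up to an overall sign, i.e. the plane z = 2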
| 301 |
+
|
| 302 |
+
def calculate_grasp_point(self, grasp_plane: np.ndarray, vertices: np.ndarray) -> np.ndarray:
|
| 303 |
+
"""
|
| 304 |
+
Calculate the optimal grasp point for robot end-effector positioning.
|
| 305 |
+
|
| 306 |
+
The grasp point is computed as the midpoint between projected thumb tip
|
| 307 |
+
and index finger tip on the grasp plane. This provides a stable reference
|
| 308 |
+
point for robot grasping operations.
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
grasp_plane: Plane coefficients [a, b, c, d]
|
| 312 |
+
vertices: Hand keypoints, shape (21, 3)
|
| 313 |
+
|
| 314 |
+
Returns:
|
| 315 |
+
3D grasp point coordinates
|
| 316 |
+
"""
|
| 317 |
+
# Project fingertips onto the grasp plane
|
| 318 |
+
thumb_pt = project_point_to_plane(vertices[4], grasp_plane)
|
| 319 |
+
index_pt = project_point_to_plane(vertices[8], grasp_plane)
|
| 320 |
+
|
| 321 |
+
# Compute midpoint as the grasp reference
|
| 322 |
+
hand_ee_pt = np.mean([thumb_pt, index_pt], axis=0)
|
| 323 |
+
return hand_ee_pt
|
| 324 |
+
|
| 325 |
+
def add_frame(self, vertices: np.ndarray, timestamp: float, hand_detected: bool = True) -> None:
|
| 326 |
+
"""
|
| 327 |
+
Add a new frame of vertex positions and calculate corresponding data.
|
| 328 |
+
|
| 329 |
+
This is the main method for processing hand data over time. It computes
|
| 330 |
+
grasp points, orientations, and stores all relevant information for
|
| 331 |
+
the current timestep.
|
| 332 |
+
|
| 333 |
+
Args:
|
| 334 |
+
vertices: Array of 21 3D vertex positions
|
| 335 |
+
timestamp: Time of the frame in seconds
|
| 336 |
+
hand_detected: Whether hand was successfully detected
|
| 337 |
+
"""
|
| 338 |
+
if len(vertices) != 21:
|
| 339 |
+
raise ValueError(f"Expected 21 vertices, got {len(vertices)}")
|
| 340 |
+
|
| 341 |
+
# Handle frames without hand detection
|
| 342 |
+
if not hand_detected:
|
| 343 |
+
self.vertex_positions.append(np.zeros((21, 3)))
|
| 344 |
+
self.grasp_points.append(np.zeros(3))
|
| 345 |
+
self.grasp_oris.append(np.eye(3))
|
| 346 |
+
self.timestamps.append(timestamp)
|
| 347 |
+
return
|
| 348 |
+
|
| 349 |
+
# Extract key finger tip positions
|
| 350 |
+
thumb_tip = vertices[4]
|
| 351 |
+
index_tip = vertices[8]
|
| 352 |
+
middle_tip = vertices[12]
|
| 353 |
+
|
| 354 |
+
# Calculate grasp point as midpoint between thumb and middle finger tips
|
| 355 |
+
control_point = (thumb_tip + middle_tip) / 2
|
| 356 |
+
grasp_pt = control_point
|
| 357 |
+
|
| 358 |
+
# Calculate gripper orientation from thumb-index finger configuration
|
| 359 |
+
gripper_ori, _ = HandModel.get_gripper_orientation(thumb_tip, index_tip, vertices)
|
| 360 |
+
|
| 361 |
+
# Apply 90-degree rotation to align with robot gripper convention
|
| 362 |
+
rot_90_deg = Rotation.from_euler('Z', 90, degrees=True).as_matrix()
|
| 363 |
+
grasp_ori = gripper_ori @ rot_90_deg
|
| 364 |
+
|
| 365 |
+
# Store all frame data
|
| 366 |
+
self.vertex_positions.append(vertices)
|
| 367 |
+
self.grasp_points.append(grasp_pt)
|
| 368 |
+
self.grasp_oris.append(grasp_ori)
|
| 369 |
+
self.timestamps.append(timestamp)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def get_joint_data(self, joint_idx: int) -> Dict[str, Union[List[float], List[np.ndarray]]]:
|
| 373 |
+
"""
|
| 374 |
+
Get all trajectory data for a specific joint across all frames.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
joint_idx: Index of the joint
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
Dictionary containing joint trajectory data with keys:
|
| 381 |
+
- 'timestamps': List of time stamps
|
| 382 |
+
- 'rotations': List of rotation matrices for this joint
|
| 383 |
+
"""
|
| 384 |
+
return {
|
| 385 |
+
'timestamps': self.timestamps,
|
| 386 |
+
'rotations': [frame[joint_idx] for frame in self.joint_rotations],
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
@staticmethod
|
| 390 |
+
def get_parallel_plane(a: float, b: float, c: float, d: float, dist: float) -> Tuple[float, float, float, float]:
|
| 391 |
+
"""
|
| 392 |
+
Calculate coefficients of a plane parallel to the given plane at specified distance.
|
| 393 |
+
|
| 394 |
+
This utility method is useful for creating offset grasp planes that account
|
| 395 |
+
for gripper thickness or provide clearance during grasping operations.
|
| 396 |
+
|
| 397 |
+
Parameters:
|
| 398 |
+
a, b, c, d: Coefficients of the original plane ax + by + cz + d = 0
|
| 399 |
+
dist: Distance between planes (positive moves in normal direction)
|
| 400 |
+
|
| 401 |
+
Returns:
|
| 402 |
+
(a, b, c, d_new) coefficients of the parallel plane
|
| 403 |
+
"""
|
| 404 |
+
# Calculate the magnitude of the normal vector
|
| 405 |
+
normal_magnitude = np.sqrt(a**2 + b**2 + c**2)
|
| 406 |
+
|
| 407 |
+
# Parallel plane has same normal direction, only d changes
|
| 408 |
+
d_new = d - dist * normal_magnitude
|
| 409 |
+
|
| 410 |
+
return (a, b, c, d_new)
|
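A worked example of the offset formula: shifting the plane z = 0 (which already has a unit normal) by dist = 0.01 along its normal should give z = 0.01.

import numpy as np

a, b, c, d, dist = 0.0, 0.0, 1.0, 0.0, 0.01   # plane z = 0, shifted by 1 cm
d_new = d - dist * np.sqrt(a**2 + b**2 + c**2)
# d_new = -0.01, so the offset plane is 1*z - 0.01 = 0, i.e. z = 0.01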
| 411 |
+
|
| 412 |
+
@staticmethod
|
| 413 |
+
def get_gripper_orientation(thumb_tip: np.ndarray, index_tip: np.ndarray, vertices: np.ndarray, grasp_plane: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
|
| 414 |
+
"""
|
| 415 |
+
Compute robot gripper orientation matrix from hand keypoints and fingertip positions.
|
| 416 |
+
|
| 417 |
+
This method calculates a coordinate frame suitable for robot gripper control
|
| 418 |
+
based on the relative positions of thumb, index finger, and wrist. The resulting
|
| 419 |
+
orientation matrix can be directly used for robot end-effector control.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
thumb_tip: 3D position of thumb tip
|
| 423 |
+
index_tip: 3D position of index finger tip
|
| 424 |
+
vertices: All hand keypoints, shape (21, 3)
|
| 425 |
+
grasp_plane: Plane coefficients [a,b,c,d]
|
| 426 |
+
|
| 427 |
+
Returns:
|
| 428 |
+
Tuple containing:
|
| 429 |
+
- gripper_orientation: 3x3 rotation matrix
|
| 430 |
+
- z_axis: Z-axis direction vector of the gripper frame
|
| 431 |
+
"""
|
| 432 |
+
# Calculate gripper opening direction (thumb to index finger)
|
| 433 |
+
gripper_direction = thumb_tip - index_tip
|
| 434 |
+
|
| 435 |
+
# Calculate gripper reference point (midpoint of fingertips)
|
| 436 |
+
midpoint = (thumb_tip + index_tip) / 2
|
| 437 |
+
|
| 438 |
+
if grasp_plane is None:
|
| 439 |
+
# Use palm geometry when no plane is provided
|
| 440 |
+
palm_axis = vertices[5] - midpoint  # vector from the fingertip midpoint to the index MCP
|
| 441 |
+
x_axis = gripper_direction / max(np.linalg.norm(gripper_direction), 1e-10)
|
| 442 |
+
z_axis = -palm_axis / max(np.linalg.norm(palm_axis), 1e-10)
|
| 443 |
+
else:
|
| 444 |
+
# Use grasp plane for orientation calculation
|
| 445 |
+
palm_axis = project_point_to_plane(vertices[0], grasp_plane) - project_point_to_plane(vertices[1], grasp_plane)
|
| 446 |
+
z_axis = -palm_axis / max(np.linalg.norm(palm_axis), 1e-10)
|
| 447 |
+
x_axis = np.cross(grasp_plane[:3], z_axis)
|
| 448 |
+
x_axis /= max(np.linalg.norm(x_axis), 1e-10)
|
| 449 |
+
|
| 450 |
+
# Compute y-axis
|
| 451 |
+
y_axis = np.cross(z_axis, x_axis)
|
| 452 |
+
y_axis /= max(np.linalg.norm(y_axis), 1e-10)
|
| 453 |
+
|
| 454 |
+
# Ensure orthogonality by recalculating z_axis
|
| 455 |
+
z_axis = np.cross(x_axis, y_axis)
|
| 456 |
+
z_axis /= max(np.linalg.norm(z_axis), 1e-10)
|
| 457 |
+
|
| 458 |
+
# Check orientation consistency with palm direction
|
| 459 |
+
if isinstance(palm_axis, torch.Tensor):
|
| 460 |
+
palm_axis = palm_axis.cpu().numpy()
|
| 461 |
+
if z_axis @ palm_axis > 0:
|
| 462 |
+
x_axis, y_axis, z_axis = -x_axis, -y_axis, -z_axis
|
| 463 |
+
|
| 464 |
+
# Construct orientation matrix
|
| 465 |
+
gripper_ori = np.column_stack([x_axis, y_axis, z_axis])
|
| 466 |
+
|
| 467 |
+
# Ensure proper handedness (right-handed coordinate system)
|
| 468 |
+
if np.linalg.det(gripper_ori) < 0:
|
| 469 |
+
x_axis = -x_axis # Flip one axis to fix handedness
|
| 470 |
+
gripper_ori = np.column_stack([x_axis, y_axis, z_axis])
|
| 471 |
+
|
| 472 |
+
# Sanity-check that the result is a proper rotation (determinant close to 1)
|
| 473 |
+
det = np.linalg.det(gripper_ori)
|
| 474 |
+
if det < 0.9:
|
| 475 |
+
logger.warning(f"Gripper orientation determinant {det:.3f} deviates from 1; frame may be degenerate")
|
| 476 |
+
|
| 477 |
+
return gripper_ori, z_axis
|
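The frame construction above reduces to building a right-handed orthonormal basis from two roughly independent directions via cross products; a minimal standalone sketch of that pattern (toy vectors, not fingertip data):

import numpy as np

x_axis = np.array([1.0, 0.2, 0.0])    # e.g. the gripper opening direction
z_hint = np.array([0.0, 0.1, -1.0])   # e.g. the negated palm axis

x_axis = x_axis / np.linalg.norm(x_axis)
y_axis = np.cross(z_hint, x_axis)
y_axis = y_axis / np.linalg.norm(y_axis)
z_axis = np.cross(x_axis, y_axis)     # re-orthogonalized third axis

R = np.column_stack([x_axis, y_axis, z_axis])
print(np.allclose(R.T @ R, np.eye(3)), np.isclose(np.linalg.det(R), 1.0))  # True True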
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class PhysicallyConstrainedHandModel(HandModel):
|
| 481 |
+
"""
|
| 482 |
+
Extended hand model with physical constraints and realistic joint limits.
|
| 483 |
+
|
| 484 |
+
This class builds upon the base HandModel by adding realistic constraints
|
| 485 |
+
that enforce physically plausible hand poses and motion. It includes:
|
| 486 |
+
- Joint angle limits based on human hand anatomy
|
| 487 |
+
- Angular velocity constraints for smooth motion
|
| 488 |
+
- Pose reconstruction with constraint enforcement
|
| 489 |
+
- Enhanced grasp point calculation with plane-based refinement
|
| 490 |
+
|
| 491 |
+
This constrained variant is the hand model used in the Phantom pipeline.
|
| 492 |
+
|
| 493 |
+
Key Constraints:
|
| 494 |
+
- Anatomically correct joint limits for each finger joint
|
| 495 |
+
- Velocity limiting to prevent jerky motions
|
| 496 |
+
- Iterative pose refinement with constraint satisfaction
|
| 497 |
+
- More robust grasp plane calculation and orientation alignment
|
| 498 |
+
|
| 499 |
+
Attributes:
|
| 500 |
+
joint_limits (Dict[int, Tuple[float, ...]]): Joint angle limits for each joint in radians
|
| 501 |
+
max_angular_velocity (float): Maximum allowed angular velocity in rad/s
|
| 502 |
+
"""
|
| 503 |
+
def __init__(self, robot_name: str) -> None:
|
| 504 |
+
"""
|
| 505 |
+
Initialize the physically constrained hand model.
|
| 506 |
+
|
| 507 |
+
Args:
|
| 508 |
+
robot_name: Name of the target robot for coordinate alignment
|
| 509 |
+
"""
|
| 510 |
+
super().__init__(robot_name)
|
| 511 |
+
|
| 512 |
+
# Define joint rotation limits (in radians) for each joint
|
| 513 |
+
# Format: (min_x, max_x, min_y, max_y, min_z, max_z) for XYZ Euler angles
|
| 514 |
+
small_angle = np.pi/40 # Small constraint for fine motor control
|
| 515 |
+
|
| 516 |
+
self.joint_limits: Dict[int, Tuple[float, float, float, float, float, float]] = {
|
| 517 |
+
# Thumb joints - more flexible due to opposable nature
|
| 518 |
+
0: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to thumb mcp
|
| 519 |
+
1: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # thumb mcp to pip
|
| 520 |
+
2: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # thumb pip to dip
|
| 521 |
+
3: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # thumb dip to tip
|
| 522 |
+
|
| 523 |
+
# Index finger joints - distal joints (PIP-DIP, DIP-tip) are tightly constrained
|
| 524 |
+
4: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to index mcp
|
| 525 |
+
5: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # index mcp to pip
|
| 526 |
+
6: (-small_angle, small_angle, -np.pi/8, np.pi/8, -small_angle, small_angle), # index pip to dip
|
| 527 |
+
7: (-small_angle, small_angle, -np.pi/8, np.pi/8, -small_angle, small_angle), # index dip to tip
|
| 528 |
+
|
| 529 |
+
# Middle finger joints - full range except limited twist at the distal joints
|
| 530 |
+
8: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to middle mcp
|
| 531 |
+
9: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # middle mcp to pip
|
| 532 |
+
10: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # middle pip to dip
|
| 533 |
+
11: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # middle dip to tip
|
| 534 |
+
|
| 535 |
+
# Ring finger joints - similar to middle finger
|
| 536 |
+
12: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to ring mcp
|
| 537 |
+
13: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # ring mcp to pip
|
| 538 |
+
14: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # ring pip to dip
|
| 539 |
+
15: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # ring dip to tip
|
| 540 |
+
|
| 541 |
+
# Pinky finger joints - same distal-twist limits as the middle and ring fingers
|
| 542 |
+
16: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # wrist to pinky mcp
|
| 543 |
+
17: (-np.pi, np.pi, -np.pi, np.pi, -np.pi, np.pi), # pinky mcp to pip
|
| 544 |
+
18: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # pinky pip to dip
|
| 545 |
+
19: (-np.pi, np.pi, -np.pi, np.pi, -np.pi/4, np.pi/4), # pinky dip to tip
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
# Maximum angular velocity constraint (2π rad/s = 360°/s)
|
| 549 |
+
self.max_angular_velocity: float = np.pi * 2
|
| 550 |
+
|
| 551 |
+
def reconstruct_vertices(self, input_vertices: np.ndarray, rotations: List[np.ndarray]) -> np.ndarray:
|
| 552 |
+
"""
|
| 553 |
+
Reconstruct vertex positions from base vertex and constrained rotations.
|
| 554 |
+
|
| 555 |
+
This method applies the kinematic chain to reconstruct hand vertex positions
|
| 556 |
+
while respecting the calculated bone lengths from the input vertices.
|
| 557 |
+
This ensures consistent hand proportions while applying constraints.
|
| 558 |
+
|
| 559 |
+
Args:
|
| 560 |
+
input_vertices: Original vertex positions, shape (21, 3)
|
| 561 |
+
rotations: List of constrained rotation matrices
|
| 562 |
+
|
| 563 |
+
Returns:
|
| 564 |
+
Reconstructed vertex positions, shape (21, 3)
|
| 565 |
+
"""
|
| 566 |
+
vertices = np.zeros((21, 3))
|
| 567 |
+
vertices[0] = input_vertices[0] # Wrist position remains fixed
|
| 568 |
+
|
| 569 |
+
# Calculate bone lengths from original vertices to maintain proportions
|
| 570 |
+
bone_lengths: Dict[Tuple[int, int], float] = {}
|
| 571 |
+
min_bone_length = 1e-6 # Minimum length to avoid numerical issues
|
| 572 |
+
|
| 573 |
+
# Extract bone lengths from the kinematic chain
|
| 574 |
+
for current in range(self.num_joints):
|
| 575 |
+
mapping = self.joint_to_neighbors_mapping[current]
|
| 576 |
+
current_vertex = mapping[0]
|
| 577 |
+
child_vertex = mapping[1]
|
| 578 |
+
parent_vertex = mapping[2]
|
| 579 |
+
|
| 580 |
+
# Calculate bone length for current->child connection
|
| 581 |
+
if child_vertex != -2:
|
| 582 |
+
length = np.linalg.norm(input_vertices[child_vertex] - input_vertices[current_vertex])
|
| 583 |
+
bone_lengths[(current_vertex, child_vertex)] = max(length, min_bone_length)
|
| 584 |
+
|
| 585 |
+
# Reconstruct positions following the kinematic chain
|
| 586 |
+
for current in range(self.num_joints):
|
| 587 |
+
mapping = self.joint_to_neighbors_mapping[current]
|
| 588 |
+
current_vertex = mapping[0]
|
| 589 |
+
child_vertex = mapping[1]
|
| 590 |
+
parent_vertex = mapping[2]
|
| 591 |
+
|
| 592 |
+
if child_vertex == -2:
|
| 593 |
+
continue
|
| 594 |
+
|
| 595 |
+
# Get positions and rotation for this joint
|
| 596 |
+
parent_pos = vertices[parent_vertex]
|
| 597 |
+
current_pos = vertices[current_vertex]
|
| 598 |
+
rotation = rotations[current]
|
| 599 |
+
|
| 600 |
+
# Determine reference direction for rotation application
|
| 601 |
+
if parent_vertex == -1:
|
| 602 |
+
# Root joints use upward direction as reference
|
| 603 |
+
prev_dir = np.array([0, 0, 1])
|
| 604 |
+
else:
|
| 605 |
+
# Use direction from parent to current vertex
|
| 606 |
+
prev_dir = vertices[current_vertex] - vertices[parent_vertex]
|
| 607 |
+
prev_dir = prev_dir / np.linalg.norm(prev_dir)
|
| 608 |
+
|
| 609 |
+
# Apply rotation to get new direction
|
| 610 |
+
current_dir = rotation @ prev_dir
|
| 611 |
+
|
| 612 |
+
# Position child vertex using calculated bone length
|
| 613 |
+
bone_length = bone_lengths[(current_vertex, child_vertex)]
|
| 614 |
+
vertices[child_vertex] = current_pos + current_dir * bone_length
|
| 615 |
+
|
| 616 |
+
return vertices
|
| 617 |
+
|
| 618 |
+
def constrain_rotation(self, rotation_matrix: np.ndarray, joint_idx: int) -> np.ndarray:
|
| 619 |
+
"""
|
| 620 |
+
Apply joint angle constraints to a rotation matrix.
|
| 621 |
+
|
| 622 |
+
This method converts the rotation to Euler angles, clips them to the
|
| 623 |
+
joint limits, and converts back to a rotation matrix. This ensures
|
| 624 |
+
all joint angles remain within anatomically realistic ranges.
|
| 625 |
+
|
| 626 |
+
Args:
|
| 627 |
+
rotation_matrix: 3x3 rotation matrix to constrain
|
| 628 |
+
joint_idx: Index of the joint for limit lookup
|
| 629 |
+
|
| 630 |
+
Returns:
|
| 631 |
+
Constrained 3x3 rotation matrix
|
| 632 |
+
"""
|
| 633 |
+
try:
|
| 634 |
+
# Convert rotation matrix to Euler angles
|
| 635 |
+
rot = Rotation.from_matrix(rotation_matrix)
|
| 636 |
+
euler = rot.as_euler('xyz')
|
| 637 |
+
|
| 638 |
+
# Get joint limits for this joint
|
| 639 |
+
limits = self.joint_limits[joint_idx]
|
| 640 |
+
|
| 641 |
+
# Clip Euler angles to the specified limits
|
| 642 |
+
constrained_euler = np.clip(euler,
|
| 643 |
+
[limits[0], limits[2], limits[4]], # min limits
|
| 644 |
+
[limits[1], limits[3], limits[5]]) # max limits
|
| 645 |
+
|
| 646 |
+
# Convert back to rotation matrix if any clipping occurred
|
| 647 |
+
if not np.allclose(euler, constrained_euler):
|
| 648 |
+
return Rotation.from_euler('xyz', constrained_euler).as_matrix()
|
| 649 |
+
return rotation_matrix
|
| 650 |
+
|
| 651 |
+
except ValueError:
|
| 652 |
+
logger.error("Error constraining rotation")
|
| 653 |
+
# Return identity matrix if rotation is invalid
|
| 654 |
+
return np.eye(3)
|
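The Euler-angle clipping used above can be exercised on its own; a small sketch with made-up limits (not the anatomical limits defined in __init__):

import numpy as np
from scipy.spatial.transform import Rotation

limits = (-np.pi/6, np.pi/6, -np.pi/6, np.pi/6, -np.pi/6, np.pi/6)  # hypothetical +/-30 degrees per axis

rot = Rotation.from_euler('xyz', [50, 0, 10], degrees=True)  # x rotation exceeds the limit
euler = rot.as_euler('xyz')
clipped = np.clip(euler, [limits[0], limits[2], limits[4]], [limits[1], limits[3], limits[5]])

constrained = Rotation.from_euler('xyz', clipped)
print(np.degrees(constrained.as_euler('xyz')))  # ~ [30, 0, 10]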
| 655 |
+
|
| 656 |
+
def constrain_velocity(self, velocity: np.ndarray) -> np.ndarray:
|
| 657 |
+
"""
|
| 658 |
+
Apply angular velocity constraints to limit motion speed.
|
| 659 |
+
|
| 660 |
+
This method ensures that joint angular velocities don't exceed the
|
| 661 |
+
maximum allowed velocity, preventing jerky or unrealistic motions.
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
velocity: Angular velocity vector to constrain
|
| 665 |
+
|
| 666 |
+
Returns:
|
| 667 |
+
Constrained angular velocity vector
|
| 668 |
+
"""
|
| 669 |
+
velocity_magnitude = np.linalg.norm(velocity)
|
| 670 |
+
if velocity_magnitude > self.max_angular_velocity:
|
| 671 |
+
# Scale velocity to maximum while preserving direction
|
| 672 |
+
return velocity * (self.max_angular_velocity / velocity_magnitude)
|
| 673 |
+
return velocity
|
| 674 |
+
|
| 675 |
+
def add_frame(self, vertices: np.ndarray, timestamp: float, finger_pts: Any) -> None:
|
| 676 |
+
"""
|
| 677 |
+
Add a new frame with physical constraints applied.
|
| 678 |
+
|
| 679 |
+
This method extends the base add_frame functionality by applying
|
| 680 |
+
joint limits, velocity constraints, and enhanced grasp calculations.
|
| 681 |
+
The result is a more realistic and stable hand model suitable for
|
| 682 |
+
robot control applications.
|
| 683 |
+
|
| 684 |
+
Args:
|
| 685 |
+
vertices: Hand keypoints, shape (21, 3)
|
| 686 |
+
timestamp: Time of the frame in seconds
|
| 687 |
+
finger_pts: Additional finger point data (currently unused)
|
| 688 |
+
"""
|
| 689 |
+
# Calculate initial rotations from raw vertex positions
|
| 690 |
+
rotations, rotations_xyz = self.calculate_frame_rotations(vertices)
|
| 691 |
+
|
| 692 |
+
# Apply joint angle constraints to all rotations
|
| 693 |
+
constrained_rotations: List[np.ndarray] = []
|
| 694 |
+
for joint_idx, rotation in enumerate(rotations):
|
| 695 |
+
constrained_rot = self.constrain_rotation(rotation, joint_idx)
|
| 696 |
+
constrained_rotations.append(constrained_rot)
|
| 697 |
+
|
| 698 |
+
# Apply velocity constraints if this is not the first frame
|
| 699 |
+
if len(self.timestamps) > 0:
|
| 700 |
+
dt = timestamp - self.timestamps[-1]
|
| 701 |
+
for joint_idx in range(self.num_joints):
|
| 702 |
+
# Calculate angular velocity for this joint
|
| 703 |
+
prev_rot = Rotation.from_matrix(self.joint_rotations[-1][joint_idx])
|
| 704 |
+
curr_rot = Rotation.from_matrix(constrained_rotations[joint_idx])
|
| 705 |
+
rel_rot = curr_rot * prev_rot.inv()
|
| 706 |
+
velocity = rel_rot.as_rotvec() / dt
|
| 707 |
+
|
| 708 |
+
# Apply velocity constraint if needed
|
| 709 |
+
if np.linalg.norm(velocity) > self.max_angular_velocity:
|
| 710 |
+
# Constrain velocity and reconstruct rotation
|
| 711 |
+
constrained_velocity = self.constrain_velocity(velocity)
|
| 712 |
+
delta_rot = Rotation.from_rotvec(constrained_velocity * dt)
|
| 713 |
+
new_rot = delta_rot * prev_rot
|
| 714 |
+
constrained_rotations[joint_idx] = new_rot.as_matrix()
|
| 715 |
+
|
| 716 |
+
# Reconstruct vertices with constrained rotations
|
| 717 |
+
constrained_vertices = self.reconstruct_vertices(vertices, constrained_rotations)
|
| 718 |
+
|
| 719 |
+
# Extract key points for grasp calculation
|
| 720 |
+
thumb_tip = constrained_vertices[4]
|
| 721 |
+
index_tip = constrained_vertices[8]
|
| 722 |
+
|
| 723 |
+
# Calculate grasp plane using thumb and index finger regions
|
| 724 |
+
grasp_plane = self.calculate_grasp_plane(constrained_vertices[3:9])
|
| 725 |
+
|
| 726 |
+
# Organize fingers for direction analysis
|
| 727 |
+
n_fingers = len(constrained_vertices) - 1
|
| 728 |
+
npts_per_finger = 4
|
| 729 |
+
list_fingers = [np.vstack([constrained_vertices[0], constrained_vertices[i:i + npts_per_finger]])
|
| 730 |
+
for i in range(1, n_fingers, npts_per_finger)]
|
| 731 |
+
|
| 732 |
+
# Calculate finger direction vector for plane orientation
|
| 733 |
+
dir_vec = list_fingers[1][1] - list_fingers[-1][1]  # vector from pinky MCP to index MCP
|
| 734 |
+
dir_vec = dir_vec / np.linalg.norm(dir_vec)
|
| 735 |
+
|
| 736 |
+
# Ensure consistent plane orientation (normal pointing away from palm)
|
| 737 |
+
if np.dot(dir_vec, grasp_plane[:3]) > 0:
|
| 738 |
+
grasp_plane = -grasp_plane
|
| 739 |
+
|
| 740 |
+
# Create slightly offset plane for grasp point calculation
|
| 741 |
+
shifted_grasp_plane = self.get_parallel_plane(*grasp_plane, 0.01)
|
| 742 |
+
grasp_pt = self.calculate_grasp_point(shifted_grasp_plane, constrained_vertices)
|
| 743 |
+
|
| 744 |
+
# Calculate gripper orientation using the grasp plane
|
| 745 |
+
gripper_ori, _ = HandModel.get_gripper_orientation(thumb_tip, index_tip, constrained_vertices, grasp_plane)
|
| 746 |
+
|
| 747 |
+
# Apply coordinate frame transformations for robot compatibility
|
| 748 |
+
rot_90_deg = Rotation.from_euler('Z', 90, degrees=True).as_matrix()
|
| 749 |
+
grasp_ori = gripper_ori @ rot_90_deg
|
| 750 |
+
|
| 751 |
+
# Apply pitch adjustment
|
| 752 |
+
angle = -np.pi/18 * 1.0 # -10 degrees
|
| 753 |
+
grasp_ori = Rotation.from_rotvec(angle * np.array([1, 0, 0])).apply(grasp_ori)
|
| 754 |
+
|
| 755 |
+
# Offset grasp point along gripper Z-axis for clearance
|
| 756 |
+
unit_vectors = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
|
| 757 |
+
transformed_vectors = unit_vectors @ grasp_ori.T
|
| 758 |
+
grasp_pt = grasp_pt - transformed_vectors[2] * 0.015 # 1.5cm offset
|
| 759 |
+
|
| 760 |
+
# Store all frame data
|
| 761 |
+
self.joint_rotations.append(constrained_rotations)
|
| 762 |
+
self.joint_rotations_xyz.append(rotations_xyz)
|
| 763 |
+
self.vertex_positions.append(constrained_vertices)
|
| 764 |
+
self.grasp_points.append(grasp_pt)
|
| 765 |
+
self.grasp_oris.append(grasp_ori)
|
| 766 |
+
self.timestamps.append(timestamp)
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
def get_list_finger_pts_from_skeleton(skeleton_pts: np.ndarray) -> Dict[str, np.ndarray]:
|
| 770 |
+
"""
|
| 771 |
+
Organize hand skeleton points into finger-specific groups.
|
| 772 |
+
|
| 773 |
+
This utility function takes the 21-point hand skeleton and organizes
|
| 774 |
+
it into a dictionary with separate arrays for each finger. This makes
|
| 775 |
+
it easier to perform finger-specific calculations and analysis.
|
| 776 |
+
|
| 777 |
+
Args:
|
| 778 |
+
skeleton_pts: Hand skeleton points, shape (21, 3)
|
| 779 |
+
Points are ordered as: wrist, thumb(4), index(4), middle(4), ring(4), pinky(4)
|
| 780 |
+
|
| 781 |
+
Returns:
|
| 782 |
+
Dictionary with finger names as keys and point arrays as values:
|
| 783 |
+
- "thumb": Wrist + 4 thumb points, shape (5, 3)
|
| 784 |
+
- "index": Wrist + 4 index points, shape (5, 3)
|
| 785 |
+
- "middle": Wrist + 4 middle points, shape (5, 3)
|
| 786 |
+
- "ring": Wrist + 4 ring points, shape (5, 3)
|
| 787 |
+
- "pinky": Wrist + 4 pinky points, shape (5, 3)
|
| 788 |
+
"""
|
| 789 |
+
n_fingers = len(skeleton_pts) - 1 # Exclude wrist point
|
| 790 |
+
npts_per_finger = 4 # MCP, PIP, DIP, TIP for each finger
|
| 791 |
+
|
| 792 |
+
# Create finger arrays by combining wrist with each finger's points
|
| 793 |
+
list_fingers = [
|
| 794 |
+
np.vstack([skeleton_pts[0], skeleton_pts[i : i + npts_per_finger]])
|
| 795 |
+
for i in range(1, n_fingers, npts_per_finger)
|
| 796 |
+
]
|
| 797 |
+
|
| 798 |
+
# Return organized finger dictionary
|
| 799 |
+
return {
|
| 800 |
+
"thumb": list_fingers[0],
|
| 801 |
+
"index": list_fingers[1],
|
| 802 |
+
"middle": list_fingers[2],
|
| 803 |
+
"ring": list_fingers[3],
|
| 804 |
+
"pinky": list_fingers[4]
|
| 805 |
+
}
|
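A minimal usage sketch for the helper above, with a dummy (21, 3) array in the wrist-then-fingers order documented in the docstring (illustrative only; assumes the phantom package is importable):

import numpy as np
from phantom.hand import get_list_finger_pts_from_skeleton

skeleton = np.random.rand(21, 3)  # wrist + 5 fingers x 4 joints
fingers = get_list_finger_pts_from_skeleton(skeleton)

print(fingers["index"].shape)                            # (5, 3): wrist + MCP/PIP/DIP/TIP
print(np.allclose(fingers["index"][1:], skeleton[5:9]))  # True: index points are keypoints 5-8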
phantom/phantom/process_data.py
ADDED
|
@@ -0,0 +1,243 @@
|
| 1 |
+
import logging
|
| 2 |
+
from enum import Enum
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from joblib import Parallel, delayed # type: ignore
|
| 5 |
+
import hydra
|
| 6 |
+
from omegaconf import DictConfig
|
| 7 |
+
|
| 8 |
+
from phantom.processors.base_processor import BaseProcessor
|
| 9 |
+
|
| 10 |
+
logging.basicConfig(level=logging.WARNING, format="%(name)s - %(levelname)s - %(message)s")
|
| 11 |
+
|
| 12 |
+
class ProcessingMode(Enum):
|
| 13 |
+
"""Enumeration of valid processing modes."""
|
| 14 |
+
BBOX = "bbox"
|
| 15 |
+
HAND2D = "hand2d"
|
| 16 |
+
HAND3D = "hand3d"
|
| 17 |
+
HAND_SEGMENTATION = "hand_segmentation"
|
| 18 |
+
ARM_SEGMENTATION = "arm_segmentation"
|
| 19 |
+
ACTION = "action"
|
| 20 |
+
SMOOTHING = "smoothing"
|
| 21 |
+
HAND_INPAINT = "hand_inpaint"
|
| 22 |
+
ROBOT_INPAINT = "robot_inpaint"
|
| 23 |
+
ALL = "all"
|
| 24 |
+
|
| 25 |
+
PROCESSING_ORDER = [
|
| 26 |
+
"bbox",
|
| 27 |
+
"hand2d",
|
| 28 |
+
"arm_segmentation",
|
| 29 |
+
"hand_segmentation",
|
| 30 |
+
"hand3d",
|
| 31 |
+
"action",
|
| 32 |
+
"smoothing",
|
| 33 |
+
"hand_inpaint",
|
| 34 |
+
"robot_inpaint",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
PROCESSING_ORDER_EPIC = [
|
| 38 |
+
"bbox",
|
| 39 |
+
"hand2d",
|
| 40 |
+
"arm_segmentation",
|
| 41 |
+
"action",
|
| 42 |
+
"smoothing",
|
| 43 |
+
"hand_inpaint",
|
| 44 |
+
"robot_inpaint",
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
def process_one_demo(data_sub_folder: str, cfg: DictConfig, processor_classes: dict) -> None:
|
| 48 |
+
# Choose processing order based on epic flag
|
| 49 |
+
processing_order = PROCESSING_ORDER_EPIC if cfg.epic else PROCESSING_ORDER
|
| 50 |
+
|
| 51 |
+
# Handle both string and list modes
|
| 52 |
+
if isinstance(cfg.mode, str):
|
| 53 |
+
# Handle comma-separated string format
|
| 54 |
+
if ',' in cfg.mode:
|
| 55 |
+
selected_modes = []
|
| 56 |
+
for mode in cfg.mode.split(','):
|
| 57 |
+
mode = mode.strip()
|
| 58 |
+
if mode == "all":
|
| 59 |
+
selected_modes.extend(processing_order)
|
| 60 |
+
elif mode in processing_order:
|
| 61 |
+
selected_modes.append(mode)
|
| 62 |
+
else:
|
| 63 |
+
selected_modes = [m for m in processing_order if m in cfg.mode or "all" in cfg.mode]
|
| 64 |
+
else:
|
| 65 |
+
# For list of modes, use the order provided by user
|
| 66 |
+
selected_modes = []
|
| 67 |
+
for mode in cfg.mode:
|
| 68 |
+
if mode == "all":
|
| 69 |
+
selected_modes.extend(processing_order)
|
| 70 |
+
elif mode in processing_order:
|
| 71 |
+
selected_modes.append(mode)
|
| 72 |
+
|
| 73 |
+
for mode in selected_modes:
|
| 74 |
+
print(f"----------------- {mode.upper()} PROCESSOR -----------------")
|
| 75 |
+
processor_cls = processor_classes[mode]
|
| 76 |
+
processor = processor_cls(cfg)
|
| 77 |
+
try:
|
| 78 |
+
processor.process_one_demo(data_sub_folder)
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f"Error in {mode} processing: {e}")
|
| 81 |
+
if cfg.debug:
|
| 82 |
+
raise
|
| 83 |
+
|
| 84 |
+
def process_all_demos(cfg: DictConfig, processor_classes: dict) -> None:
|
| 85 |
+
# Choose processing order based on epic flag
|
| 86 |
+
processing_order = PROCESSING_ORDER_EPIC if cfg.epic else PROCESSING_ORDER
|
| 87 |
+
|
| 88 |
+
# Handle both string and list modes
|
| 89 |
+
if isinstance(cfg.mode, str):
|
| 90 |
+
# Handle comma-separated string format
|
| 91 |
+
if ',' in cfg.mode:
|
| 92 |
+
selected_modes = []
|
| 93 |
+
for mode in cfg.mode.split(','):
|
| 94 |
+
mode = mode.strip()
|
| 95 |
+
if mode == "all":
|
| 96 |
+
selected_modes.extend(processing_order)
|
| 97 |
+
elif mode in processing_order:
|
| 98 |
+
selected_modes.append(mode)
|
| 99 |
+
else:
|
| 100 |
+
selected_modes = [m for m in processing_order if m in cfg.mode or "all" in cfg.mode]
|
| 101 |
+
else:
|
| 102 |
+
# For list of modes, use the order provided by user
|
| 103 |
+
selected_modes = []
|
| 104 |
+
for mode in cfg.mode:
|
| 105 |
+
if mode == "all":
|
| 106 |
+
selected_modes.extend(processing_order)
|
| 107 |
+
elif mode in processing_order:
|
| 108 |
+
selected_modes.append(mode)
|
| 109 |
+
|
| 110 |
+
base_processor = BaseProcessor(cfg)
|
| 111 |
+
all_data_folders = base_processor.all_data_folders.copy()
|
| 112 |
+
for mode in selected_modes:
|
| 113 |
+
print(f"----------------- {mode.upper()} PROCESSOR -----------------")
|
| 114 |
+
processor_cls = processor_classes[mode]
|
| 115 |
+
processor = processor_cls(cfg)
|
| 116 |
+
for data_sub_folder in tqdm(all_data_folders):
|
| 117 |
+
try:
|
| 118 |
+
processor.process_one_demo(data_sub_folder)
|
| 119 |
+
except Exception as e:
|
| 120 |
+
print(f"Error in {mode} processing: {e}")
|
| 121 |
+
if cfg.debug:
|
| 122 |
+
raise
|
| 123 |
+
|
| 124 |
+
def process_all_demos_parallel(cfg: DictConfig, processor_classes: dict) -> None:
|
| 125 |
+
# Choose processing order based on epic flag
|
| 126 |
+
processing_order = PROCESSING_ORDER_EPIC if cfg.epic else PROCESSING_ORDER
|
| 127 |
+
|
| 128 |
+
# Handle both string and list modes
|
| 129 |
+
if isinstance(cfg.mode, str):
|
| 130 |
+
# Handle comma-separated string format
|
| 131 |
+
if ',' in cfg.mode:
|
| 132 |
+
selected_modes = []
|
| 133 |
+
for mode in cfg.mode.split(','):
|
| 134 |
+
mode = mode.strip()
|
| 135 |
+
if mode == "all":
|
| 136 |
+
selected_modes.extend(processing_order)
|
| 137 |
+
elif mode in processing_order:
|
| 138 |
+
selected_modes.append(mode)
|
| 139 |
+
else:
|
| 140 |
+
selected_modes = [m for m in processing_order if m in cfg.mode or "all" in cfg.mode]
|
| 141 |
+
else:
|
| 142 |
+
# For list of modes, use the order provided by user
|
| 143 |
+
selected_modes = []
|
| 144 |
+
for mode in cfg.mode:
|
| 145 |
+
if mode == "all":
|
| 146 |
+
selected_modes.extend(processing_order)
|
| 147 |
+
elif mode in processing_order:
|
| 148 |
+
selected_modes.append(mode)
|
| 149 |
+
|
| 150 |
+
base_processor = BaseProcessor(cfg)
|
| 151 |
+
all_data_folders = base_processor.all_data_folders.copy()
|
| 152 |
+
for mode in selected_modes:
|
| 153 |
+
print(f"----------------- {mode.upper()} PROCESSOR -----------------")
|
| 154 |
+
processor_cls = processor_classes[mode]
|
| 155 |
+
processor = processor_cls(cfg)
|
| 156 |
+
Parallel(n_jobs=cfg.n_processes)(
|
| 157 |
+
delayed(processor.process_one_demo)(data_sub_folder) for data_sub_folder in all_data_folders
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
def get_processor_classes(cfg: DictConfig) -> dict:
|
| 161 |
+
"""Initialize the processor classes"""
|
| 162 |
+
from phantom.processors.bbox_processor import BBoxProcessor
|
| 163 |
+
from phantom.processors.segmentation_processor import HandSegmentationProcessor, ArmSegmentationProcessor
|
| 164 |
+
from phantom.processors.hand_processor import Hand2DProcessor, Hand3DProcessor
|
| 165 |
+
from phantom.processors.action_processor import ActionProcessor
|
| 166 |
+
from phantom.processors.smoothing_processor import SmoothingProcessor
|
| 167 |
+
from phantom.processors.robotinpaint_processor import RobotInpaintProcessor
|
| 168 |
+
from phantom.processors.handinpaint_processor import HandInpaintProcessor
|
| 169 |
+
|
| 170 |
+
return {
|
| 171 |
+
"bbox": BBoxProcessor,
|
| 172 |
+
"hand2d": Hand2DProcessor,
|
| 173 |
+
"hand3d": Hand3DProcessor,
|
| 174 |
+
"hand_segmentation": HandSegmentationProcessor,
|
| 175 |
+
"arm_segmentation": ArmSegmentationProcessor,
|
| 176 |
+
"action": ActionProcessor,
|
| 177 |
+
"smoothing": SmoothingProcessor,
|
| 178 |
+
"robot_inpaint": RobotInpaintProcessor,
|
| 179 |
+
"hand_inpaint": HandInpaintProcessor,
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
def validate_mode(cfg: DictConfig) -> None:
|
| 183 |
+
"""
|
| 184 |
+
Validate that the mode parameter contains only valid processing modes.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
cfg: Configuration object containing mode parameter
|
| 188 |
+
|
| 189 |
+
Raises:
|
| 190 |
+
ValueError: If mode contains invalid options
|
| 191 |
+
"""
|
| 192 |
+
if isinstance(cfg.mode, str):
|
| 193 |
+
# Handle comma-separated string format
|
| 194 |
+
if ',' in cfg.mode:
|
| 195 |
+
modes = [mode.strip() for mode in cfg.mode.split(',')]
|
| 196 |
+
else:
|
| 197 |
+
modes = [cfg.mode]
|
| 198 |
+
else:
|
| 199 |
+
modes = cfg.mode
|
| 200 |
+
|
| 201 |
+
# Get valid modes from enum
|
| 202 |
+
valid_modes = {mode.value for mode in ProcessingMode}
|
| 203 |
+
invalid_modes = [mode for mode in modes if mode not in valid_modes]
|
| 204 |
+
|
| 205 |
+
if invalid_modes:
|
| 206 |
+
valid_mode_list = [mode.value for mode in ProcessingMode]
|
| 207 |
+
raise ValueError(
|
| 208 |
+
f"Invalid mode(s): {invalid_modes}. "
|
| 209 |
+
f"Valid modes are: {valid_mode_list}"
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def main(cfg: DictConfig):
|
| 213 |
+
# Validate mode parameter
|
| 214 |
+
validate_mode(cfg)
|
| 215 |
+
|
| 216 |
+
# Get processor classes
|
| 217 |
+
processor_classes = get_processor_classes(cfg)
|
| 218 |
+
|
| 219 |
+
if cfg.n_processes > 1:
|
| 220 |
+
process_all_demos_parallel(cfg, processor_classes)
|
| 221 |
+
elif cfg.demo_num is not None:
|
| 222 |
+
process_one_demo(cfg.demo_num, cfg, processor_classes)
|
| 223 |
+
else:
|
| 224 |
+
process_all_demos(cfg, processor_classes)
|
| 225 |
+
|
| 226 |
+
@hydra.main(version_base=None, config_path="../configs", config_name="default")
|
| 227 |
+
def hydra_main(cfg: DictConfig):
|
| 228 |
+
"""
|
| 229 |
+
Main entry point using Hydra configuration.
|
| 230 |
+
|
| 231 |
+
Example usage:
|
| 232 |
+
- Process all demos with bbox: python process_data.py mode=bbox
|
| 233 |
+
- Process single demo: python process_data.py mode=bbox demo_num=0
|
| 234 |
+
- Use EPIC dataset: python process_data.py dataset=epic mode=bbox
|
| 235 |
+
- Parallel processing: python process_data.py mode=bbox n_processes=4
|
| 236 |
+
- Process multiple modes sequentially: python process_data.py mode=bbox,hand3d
|
| 237 |
+
- Process with custom order: python process_data.py mode=hand3d,bbox,action
|
| 238 |
+
- Process with bracket notation (use quotes): python process_data.py "mode=[bbox,hand3d]"
|
| 239 |
+
"""
|
| 240 |
+
main(cfg)
|
| 241 |
+
|
| 242 |
+
if __name__ == "__main__":
|
| 243 |
+
hydra_main()
|
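The mode-selection logic near the top of this file is repeated verbatim in process_one_demo, process_all_demos, and process_all_demos_parallel; one possible consolidation is sketched below (an untested sketch, and note it uses exact mode matching where the original single-string branch used a substring check):

def select_modes(mode, processing_order):
    """Expand a mode spec (string, comma-separated string, or list) into an ordered mode list."""
    if isinstance(mode, str):
        modes = [m.strip() for m in mode.split(',')] if ',' in mode else [mode]
    else:
        modes = list(mode)
    selected = []
    for m in modes:
        if m == "all":
            selected.extend(processing_order)
        elif m in processing_order:
            selected.append(m)
    return selected

# e.g. inside process_one_demo:
#     selected_modes = select_modes(cfg.mode, processing_order)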
phantom/phantom/processors/__init__.py
ADDED
|
File without changes
|
phantom/phantom/processors/action_processor.py
ADDED
|
@@ -0,0 +1,478 @@
|
| 1 |
+
"""
|
| 2 |
+
Action Processor Module
|
| 3 |
+
|
| 4 |
+
This module processes hand motion capture data and converts it into robot-executable actions.
|
| 5 |
+
It handles both single-arm and bimanual robotic setups, converting detected hand keypoints
|
| 6 |
+
into end-effector positions, orientations, and gripper widths that can be used for robot control.
|
| 7 |
+
|
| 8 |
+
Key Features:
|
| 9 |
+
- Converts hand keypoints from camera frame to robot frame
|
| 10 |
+
- Supports both unconstrained and physically constrained hand models
|
| 11 |
+
- Handles missing hand detections with interpolation
|
| 12 |
+
- Processes bimanual data with union-based frame selection
|
| 13 |
+
- Generates neutral poses when no hand data is available
|
| 14 |
+
|
| 15 |
+
The processor follows this pipeline:
|
| 16 |
+
1. Load hand sequence data (keypoints, detection flags)
|
| 17 |
+
2. Convert keypoints to robot coordinate frame
|
| 18 |
+
3. Apply hand model constraints (optional)
|
| 19 |
+
4. Extract end-effector poses and gripper states
|
| 20 |
+
5. Refine actions to handle missing detections
|
| 21 |
+
6. Save processed actions for robot execution
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import numpy as np
|
| 26 |
+
from typing import Tuple, Optional
|
| 27 |
+
from dataclasses import dataclass
|
| 28 |
+
import logging
|
| 29 |
+
from scipy.spatial.transform import Rotation
|
| 30 |
+
|
| 31 |
+
from phantom.processors.base_processor import BaseProcessor
|
| 32 |
+
from phantom.processors.phantom_data import HandSequence
|
| 33 |
+
from phantom.processors.paths import Paths
|
| 34 |
+
from phantom.hand import HandModel, PhysicallyConstrainedHandModel, get_list_finger_pts_from_skeleton
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class EEActions:
|
| 40 |
+
"""
|
| 41 |
+
Container for bimanual end-effector action data.
|
| 42 |
+
|
| 43 |
+
This dataclass holds the processed robot actions for a sequence of timesteps,
|
| 44 |
+
including 3D positions, 3D orientations, and gripper opening widths.
|
| 45 |
+
|
| 46 |
+
Attributes:
|
| 47 |
+
ee_pts (np.ndarray): End-effector positions, shape (N, 3) in robot frame coordinates
|
| 48 |
+
ee_oris (np.ndarray): End-effector orientations as rotation matrices, shape (N, 3, 3)
|
| 49 |
+
ee_widths (np.ndarray): Gripper opening widths in meters, shape (N,)
|
| 50 |
+
"""
|
| 51 |
+
ee_pts: np.ndarray # End-effector positions (N, 3)
|
| 52 |
+
ee_oris: np.ndarray # End-effector orientations (N, 3, 3) as rotation matrices
|
| 53 |
+
ee_widths: np.ndarray # Gripper widths (N,)
|
| 54 |
+
|
| 55 |
+
class ActionProcessor(BaseProcessor):
|
| 56 |
+
"""
|
| 57 |
+
Processor for converting hand motion capture data into robot-executable actions.
|
| 58 |
+
|
| 59 |
+
This class handles the complete pipeline from raw hand keypoints to refined robot actions.
|
| 60 |
+
It supports both single-arm and bimanual robotic setups, with intelligent handling of
|
| 61 |
+
missing hand detections and physically realistic constraints.
|
| 62 |
+
|
| 63 |
+
The processor can operate in different modes:
|
| 64 |
+
- Single arm: Processes only left or right hand data
|
| 65 |
+
- Bimanual: Processes both hands with union-based frame selection
|
| 66 |
+
|
| 67 |
+
Key processing steps:
|
| 68 |
+
1. Load hand sequences with 3D keypoints and detection flags
|
| 69 |
+
2. Transform keypoints from camera frame to robot frame
|
| 70 |
+
3. Fit hand model (optionally with physical constraints)
|
| 71 |
+
4. Extract end-effector poses and gripper states
|
| 72 |
+
5. Refine actions using last-valid-value interpolation
|
| 73 |
+
6. Generate neutral poses for undetected periods
|
| 74 |
+
|
| 75 |
+
Attributes:
|
| 76 |
+
dt (float): Time delta between frames (1/15 seconds for 15Hz processing)
|
| 77 |
+
bimanual_setup (str): Setup type ("single_arm", "shoulders", etc.)
|
| 78 |
+
target_hand (str): Which hand to process in single-arm mode ("left"/"right")
|
| 79 |
+
constrained_hand (bool): Whether to use physically constrained hand model
|
| 80 |
+
T_cam2robot (np.ndarray): 4x4 transformation matrix from camera to robot frame
|
| 81 |
+
"""
|
| 82 |
+
def __init__(self, args):
|
| 83 |
+
# Set processing frequency to 15Hz
|
| 84 |
+
self.dt = 1/15
|
| 85 |
+
super().__init__(args)
|
| 86 |
+
|
| 87 |
+
def process_one_demo(self, data_sub_folder: str) -> None:
|
| 88 |
+
"""
|
| 89 |
+
Process a single demonstration recording into robot actions.
|
| 90 |
+
|
| 91 |
+
This is the main entry point for processing one demo. It handles both
|
| 92 |
+
single-arm and bimanual processing modes, loading the raw hand data,
|
| 93 |
+
converting it to robot actions, and saving the results.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
data_sub_folder (str): Path to the folder containing this demo's data
|
| 97 |
+
"""
|
| 98 |
+
save_folder = self.get_save_folder(data_sub_folder)
|
| 99 |
+
paths = self.get_paths(save_folder)
|
| 100 |
+
|
| 101 |
+
# Load hand sequence data for both hands
|
| 102 |
+
left_sequence, right_sequence = self._load_sequences(paths)
|
| 103 |
+
|
| 104 |
+
# Handle single-arm processing mode
|
| 105 |
+
if self.bimanual_setup == "single_arm":
|
| 106 |
+
self._process_single_arm(left_sequence, right_sequence, paths)
|
| 107 |
+
else:
|
| 108 |
+
self._process_bimanual(left_sequence, right_sequence, paths)
|
| 109 |
+
|
| 110 |
+
def _process_single_arm(self, left_sequence: HandSequence, right_sequence: HandSequence, paths) -> None:
|
| 111 |
+
"""Process single-arm setup with one target hand."""
|
| 112 |
+
# Select target hand based on configuration
|
| 113 |
+
target_sequence = left_sequence if self.target_hand == "left" else right_sequence
|
| 114 |
+
|
| 115 |
+
# Process the selected hand sequence
|
| 116 |
+
target_actions = self._process_hand_sequence(target_sequence, self.T_cam2robot)
|
| 117 |
+
|
| 118 |
+
# Get indices where hand was detected for this sequence
|
| 119 |
+
union_indices = np.where(target_sequence.hand_detected)[0]
|
| 120 |
+
|
| 121 |
+
# Refine actions to handle missing detections
|
| 122 |
+
target_actions_refined = self._refine_actions(target_sequence, target_actions, union_indices, self.target_hand)
|
| 123 |
+
|
| 124 |
+
# Save results for the selected hand only
|
| 125 |
+
if self.target_hand == "left":
|
| 126 |
+
self._save_results(paths, union_indices=union_indices, left_actions=target_actions_refined)
|
| 127 |
+
else:
|
| 128 |
+
self._save_results(paths, union_indices=union_indices, right_actions=target_actions_refined)
|
| 129 |
+
|
| 130 |
+
def _process_bimanual(self, left_sequence: HandSequence, right_sequence: HandSequence, paths) -> None:
|
| 131 |
+
"""Process bimanual setup with both hands."""
|
| 132 |
+
# Process both hand sequences
|
| 133 |
+
left_actions = self._process_hand_sequence(left_sequence, self.T_cam2robot)
|
| 134 |
+
right_actions = self._process_hand_sequence(right_sequence, self.T_cam2robot)
|
| 135 |
+
|
| 136 |
+
# Combine detection results using OR logic - frame is valid if either hand detected
|
| 137 |
+
union_indices = np.where(left_sequence.hand_detected | right_sequence.hand_detected)[0]
|
| 138 |
+
|
| 139 |
+
# Refine actions for both hands using the union indices
|
| 140 |
+
left_actions_refined = self._refine_actions(left_sequence, left_actions, union_indices, "left")
|
| 141 |
+
right_actions_refined = self._refine_actions(right_sequence, right_actions, union_indices, "right")
|
| 142 |
+
|
| 143 |
+
# Save results for both hands
|
| 144 |
+
self._save_results(paths, union_indices, left_actions_refined, right_actions_refined)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _load_sequences(self, paths) -> Tuple[HandSequence, HandSequence]:
|
| 148 |
+
"""
|
| 149 |
+
Load hand sequences from disk for both left and right hands.
|
| 150 |
+
|
| 151 |
+
HandSequence objects contain the processed keypoint data, detection flags,
|
| 152 |
+
and other metadata needed for action processing.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
paths: Paths object containing file locations for hand data
|
| 156 |
+
|
| 157 |
+
Returns:
|
| 158 |
+
Tuple[HandSequence, HandSequence]: Left and right hand sequences
|
| 159 |
+
"""
|
| 160 |
+
return (
|
| 161 |
+
HandSequence.load(paths.hand_data_left),
|
| 162 |
+
HandSequence.load(paths.hand_data_right)
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def _process_hand_sequence(
|
| 166 |
+
self,
|
| 167 |
+
sequence: HandSequence,
|
| 168 |
+
T_cam2robot: np.ndarray,
|
| 169 |
+
) -> EEActions:
|
| 170 |
+
"""
|
| 171 |
+
Process a single hand sequence into end-effector actions.
|
| 172 |
+
|
| 173 |
+
This method performs the following processing pipeline for one hand:
|
| 174 |
+
1. Transform keypoints from camera frame to robot frame
|
| 175 |
+
2. Fit a hand model to the keypoint sequence
|
| 176 |
+
3. Extract end-effector poses and gripper states
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
sequence (HandSequence): Hand keypoint sequence with detection flags
|
| 180 |
+
T_cam2robot (np.ndarray): 4x4 transformation matrix from camera to robot frame
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
EEActions: Processed end-effector positions, orientations, and gripper widths
|
| 184 |
+
"""
|
| 185 |
+
# Convert keypoints from camera frame to robot frame coordinates
|
| 186 |
+
kpts_3d_cf = sequence.kpts_3d # Camera frame keypoints
|
| 187 |
+
kpts_3d_rf = ActionProcessor._convert_pts_to_robot_frame(
|
| 188 |
+
kpts_3d_cf,
|
| 189 |
+
T_cam2robot
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# Create and fit hand model to the keypoint sequence
|
| 193 |
+
hand_model = self._get_hand_model(kpts_3d_rf, sequence.hand_detected)
|
| 194 |
+
|
| 195 |
+
# Extract end-effector poses and gripper states from fitted model
|
| 196 |
+
kpts_3d, ee_pts, ee_oris = self._get_model_keypoints(hand_model)
|
| 197 |
+
|
| 198 |
+
# Compute gripper opening distances from fingertip positions
|
| 199 |
+
ee_widths = self._compute_gripper_distances(
|
| 200 |
+
kpts_3d,
|
| 201 |
+
sequence.hand_detected
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
return EEActions(
|
| 205 |
+
ee_pts=ee_pts,
|
| 206 |
+
ee_oris=ee_oris,
|
| 207 |
+
ee_widths=ee_widths,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
def _get_hand_model(self, kpts_3d_rf: np.ndarray, hand_detected: np.ndarray) -> HandModel | PhysicallyConstrainedHandModel:
|
| 211 |
+
"""
|
| 212 |
+
Create and fit a hand model to the keypoint sequence.
|
| 213 |
+
|
| 214 |
+
The hand model can be either unconstrained (simple fitting) or physically
|
| 215 |
+
constrained (enforces realistic hand poses and robot constraints).
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
kpts_3d_rf (np.ndarray): Hand keypoints in robot frame, shape (N, 21, 3)
|
| 219 |
+
hand_detected (np.ndarray): Boolean array indicating valid detections, shape (N,)
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
HandModel | PhysicallyConstrainedHandModel: Fitted hand model with trajectory data
|
| 223 |
+
"""
|
| 224 |
+
# Choose hand model type based on configuration
|
| 225 |
+
if self.constrained_hand:
|
| 226 |
+
hand_model = PhysicallyConstrainedHandModel(self.robot)
|
| 227 |
+
else:
|
| 228 |
+
hand_model = HandModel(self.robot)
|
| 229 |
+
|
| 230 |
+
# Add each frame to the model for trajectory fitting
|
| 231 |
+
for t_idx in range(len(kpts_3d_rf)):
|
| 232 |
+
hand_model.add_frame(
|
| 233 |
+
kpts_3d_rf[t_idx],
|
| 234 |
+
t_idx * self.dt, # Convert frame index to time
|
| 235 |
+
hand_detected[t_idx]
|
| 236 |
+
)
|
| 237 |
+
return hand_model
|
| 238 |
+
|
| 239 |
+
def _get_model_keypoints(self, model: HandModel | PhysicallyConstrainedHandModel) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
| 240 |
+
"""
|
| 241 |
+
Extract keypoints and end-effector data from fitted hand model.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
model (HandModel | PhysicallyConstrainedHandModel): Fitted hand model
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
Tuple containing:
|
| 248 |
+
- kpts_3d (np.ndarray): Model keypoint positions, shape (N, 21, 3)
|
| 249 |
+
- ee_pts (np.ndarray): End-effector positions, shape (N, 3)
|
| 250 |
+
- ee_oris (np.ndarray): End-effector orientations, shape (N, 3, 3)
|
| 251 |
+
"""
|
| 252 |
+
kpts_3d = np.array(model.vertex_positions) # All hand keypoints
|
| 253 |
+
ee_pts = np.array(model.grasp_points) # End-effector positions (fingertip-midpoint grasp points)
|
| 254 |
+
ee_oris = np.array(model.grasp_oris) # End-effector orientations (rotation matrices)
|
| 255 |
+
return kpts_3d, ee_pts, ee_oris
|
| 256 |
+
|
| 257 |
+
def _compute_gripper_distances(
|
| 258 |
+
self,
|
| 259 |
+
kpts_3d_rf: np.ndarray,
|
| 260 |
+
hand_detected: np.ndarray
|
| 261 |
+
) -> np.ndarray:
|
| 262 |
+
"""
|
| 263 |
+
Compute gripper opening distances for all frames in the sequence.
|
| 264 |
+
|
| 265 |
+
The gripper distance is calculated as the Euclidean distance between
|
| 266 |
+
the thumb tip and index finger tip, providing a proxy for gripper state.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
kpts_3d_rf (np.ndarray): Hand keypoints in robot frame, shape (N, 21, 3)
|
| 270 |
+
hand_detected (np.ndarray): Boolean flags for valid detections, shape (N,)
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
np.ndarray: Gripper distances for each frame, shape (N,)
|
| 274 |
+
"""
|
| 275 |
+
gripper_dists = np.zeros(len(kpts_3d_rf))
|
| 276 |
+
|
| 277 |
+
for idx in range(len(kpts_3d_rf)):
|
| 278 |
+
if hand_detected[idx]:
|
| 279 |
+
# Only compute distance for frames with valid hand detection
|
| 280 |
+
gripper_dists[idx] = ActionProcessor._compute_gripper_opening(
|
| 281 |
+
kpts_3d_rf[idx]
|
| 282 |
+
)
|
| 283 |
+
# Note: Invalid frames remain at 0.0, will be refined later
|
| 284 |
+
return gripper_dists
|
| 285 |
+
|
| 286 |
+
def _refine_actions(
|
| 287 |
+
self,
|
| 288 |
+
sequence: HandSequence,
|
| 289 |
+
actions: EEActions,
|
| 290 |
+
union_indices: np.ndarray,
|
| 291 |
+
hand_side: str
|
| 292 |
+
) -> EEActions:
|
| 293 |
+
"""
|
| 294 |
+
Refine actions to handle missing hand detections using last-valid-value interpolation.
|
| 295 |
+
|
| 296 |
+
When hand detection fails, this method fills in missing values by carrying forward
|
| 297 |
+
the last valid pose and gripper state. This creates smooth, executable trajectories
|
| 298 |
+
even when the vision system temporarily loses tracking.
|
| 299 |
+
|
| 300 |
+
Args:
|
| 301 |
+
sequence (HandSequence): Original hand sequence with detection flags
|
| 302 |
+
actions (EEActions): Raw actions from hand model
|
| 303 |
+
union_indices (np.ndarray): Frame indices to include in final trajectory
|
| 304 |
+
hand_side (str): "left" or "right" for neutral pose generation
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
EEActions: Refined actions with interpolated values for missing detections
|
| 308 |
+
"""
|
| 309 |
+
# Find frames where this hand was actually detected
|
| 310 |
+
hand_detected_indices = np.where(sequence.hand_detected)[0]
|
| 311 |
+
|
| 312 |
+
# If no valid detections, return neutral pose for entire sequence
|
| 313 |
+
if len(hand_detected_indices) == 0:
|
| 314 |
+
return self._get_neutral_actions(hand_side, len(union_indices))
|
| 315 |
+
|
| 316 |
+
# Apply carry-forward interpolation
|
| 317 |
+
return self._apply_carry_forward_interpolation(sequence, actions, union_indices, hand_detected_indices)
|
| 318 |
+
|
| 319 |
+
def _apply_carry_forward_interpolation(
|
| 320 |
+
self,
|
| 321 |
+
sequence: HandSequence,
|
| 322 |
+
actions: EEActions,
|
| 323 |
+
union_indices: np.ndarray,
|
| 324 |
+
hand_detected_indices: np.ndarray
|
| 325 |
+
) -> EEActions:
|
| 326 |
+
"""Apply last-valid-value interpolation to fill missing detections."""
|
| 327 |
+
# Initialize with first valid detection values
|
| 328 |
+
first_valid_idx = hand_detected_indices[0]
|
| 329 |
+
last_valid_pt = actions.ee_pts[first_valid_idx]
|
| 330 |
+
last_valid_ori = actions.ee_oris[first_valid_idx]
|
| 331 |
+
last_valid_width = actions.ee_widths[first_valid_idx]
|
| 332 |
+
|
| 333 |
+
# Process each frame in the union sequence
|
| 334 |
+
ee_pts_refined = []
|
| 335 |
+
ee_oris_refined = []
|
| 336 |
+
ee_widths_refined = []
|
| 337 |
+
|
| 338 |
+
for idx in union_indices:
|
| 339 |
+
if sequence.hand_detected[idx]:
|
| 340 |
+
# Update with new valid values when available
|
| 341 |
+
last_valid_pt = actions.ee_pts[idx]
|
| 342 |
+
last_valid_ori = actions.ee_oris[idx]
|
| 343 |
+
last_valid_width = actions.ee_widths[idx]
|
| 344 |
+
|
| 345 |
+
# Always append the last valid values (carry-forward for missing frames)
|
| 346 |
+
ee_pts_refined.append(last_valid_pt)
|
| 347 |
+
ee_oris_refined.append(last_valid_ori)
|
| 348 |
+
ee_widths_refined.append(last_valid_width)
|
| 349 |
+
|
| 350 |
+
return EEActions(
|
| 351 |
+
ee_pts=np.array(ee_pts_refined),
|
| 352 |
+
ee_oris=np.array(ee_oris_refined),
|
| 353 |
+
ee_widths=np.array(ee_widths_refined),
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
def _get_neutral_actions(self, hand_side: str, n_frames: int) -> EEActions:
|
| 357 |
+
"""
|
| 358 |
+
Generate neutral pose actions when no hand detection is available.
|
| 359 |
+
|
| 360 |
+
Neutral poses place the robot arms in out-of-frame positions.
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
hand_side (str): "left" or "right" to determine which neutral pose to use
|
| 364 |
+
n_frames (int): Number of frames to generate
|
| 365 |
+
|
| 366 |
+
Returns:
|
| 367 |
+
EEActions: Neutral pose actions for the specified number of frames
|
| 368 |
+
"""
|
| 369 |
+
# Define neutral pose configurations
|
| 370 |
+
neutral_configs = {
|
| 371 |
+
"single_arm": {
|
| 372 |
+
"right": {"pos": [0.2, -0.8, 0.3], "quat": [1, 0.0, 0.0, 0.0]},
|
| 373 |
+
"left": {"pos": [0.2, 0.8, 0.3], "quat": [1, 0.0, 0.0, 0.0]}
|
| 374 |
+
},
|
| 375 |
+
"shoulders": {
|
| 376 |
+
"right": {"pos": [0.4, -0.5, 0.3], "quat": [-0.7071, 0.0, 0.0, 0.7071]},
|
| 377 |
+
"left": {"pos": [0.4, 0.5, 0.3], "quat": [0.7071, 0.0, 0.0, 0.7071]}
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
# Get configuration for current setup and hand
|
| 382 |
+
config = neutral_configs[self.bimanual_setup][hand_side]
|
| 383 |
+
|
| 384 |
+
# Convert to numpy arrays and create rotation matrix
|
| 385 |
+
neutral_pos = np.array(config["pos"])
|
| 386 |
+
neutral_ori = Rotation.from_quat(config["quat"], scalar_first=False).as_matrix()
|
| 387 |
+
neutral_width = 0.085 # Standard gripper opening (8.5cm)
|
| 388 |
+
|
| 389 |
+
# Create arrays replicated for all frames
|
| 390 |
+
return EEActions(
|
| 391 |
+
ee_pts=np.repeat(neutral_pos.reshape(1, 3), n_frames, axis=0),
|
| 392 |
+
ee_oris=np.repeat(neutral_ori.reshape(1, 3, 3), n_frames, axis=0),
|
| 393 |
+
ee_widths=np.full(n_frames, neutral_width)
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
def _save_results(
|
| 397 |
+
self,
|
| 398 |
+
paths: Paths,
|
| 399 |
+
union_indices: np.ndarray,
|
| 400 |
+
left_actions: Optional[EEActions] = None,
|
| 401 |
+
right_actions: Optional[EEActions] = None,
|
| 402 |
+
) -> None:
|
| 403 |
+
"""
|
| 404 |
+
Save processed action results to disk in NPZ format.
|
| 405 |
+
|
| 406 |
+
The saved files contain all necessary data for robot execution:
|
| 407 |
+
- union_indices: Valid frame indices in the original sequence
|
| 408 |
+
- ee_pts: End-effector positions
|
| 409 |
+
- ee_oris: End-effector orientations (rotation matrices)
|
| 410 |
+
- ee_widths: Gripper opening widths
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
paths (Paths): File path configuration object
|
| 414 |
+
union_indices (np.ndarray): Valid frame indices
|
| 415 |
+
left_actions (Optional[EEActions]): Left hand actions to save
|
| 416 |
+
right_actions (Optional[EEActions]): Right hand actions to save
|
| 417 |
+
"""
|
| 418 |
+
# Create output directory if it doesn't exist
|
| 419 |
+
os.makedirs(paths.action_processor, exist_ok=True)
|
| 420 |
+
|
| 421 |
+
# Save actions for each hand if provided
|
| 422 |
+
if left_actions is not None:
|
| 423 |
+
self._save_hand_actions(paths.actions_left, union_indices, left_actions)
|
| 424 |
+
if right_actions is not None:
|
| 425 |
+
self._save_hand_actions(paths.actions_right, union_indices, right_actions)
|
| 426 |
+
|
| 427 |
+
def _save_hand_actions(self, base_path: str, union_indices: np.ndarray, actions: EEActions) -> None:
|
| 428 |
+
"""Save actions for a single hand to NPZ file."""
|
| 429 |
+
file_path = str(base_path).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 430 |
+
np.savez(
|
| 431 |
+
file_path,
|
| 432 |
+
union_indices=union_indices,
|
| 433 |
+
ee_pts=actions.ee_pts,
|
| 434 |
+
ee_oris=actions.ee_oris,
|
| 435 |
+
ee_widths=actions.ee_widths
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
@staticmethod
|
| 439 |
+
def _compute_gripper_opening(skeleton_pts: np.ndarray) -> float:
|
| 440 |
+
"""
|
| 441 |
+
Compute gripper opening distance from hand keypoints for a single frame.
|
| 442 |
+
|
| 443 |
+
The gripper distance is calculated as the Euclidean distance between
|
| 444 |
+
the thumb tip and index finger tip.
|
| 445 |
+
|
| 446 |
+
Args:
|
| 447 |
+
skeleton_pts (np.ndarray): Hand keypoints for one frame, shape (21, 3)
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
float: Distance between thumb tip and index finger tip in meters
|
| 451 |
+
"""
|
| 452 |
+
# Extract finger tip positions from the hand skeleton
|
| 453 |
+
finger_dict = get_list_finger_pts_from_skeleton(skeleton_pts)
|
| 454 |
+
|
| 455 |
+
# Compute distance between thumb tip and index finger tip
|
| 456 |
+
return np.linalg.norm(finger_dict["thumb"][-1] - finger_dict["index"][-1])
|
| 457 |
+
|
| 458 |
+
@staticmethod
|
| 459 |
+
def _convert_pts_to_robot_frame(skeleton_poses_cf: np.ndarray, T_cam2robot: np.ndarray) -> np.ndarray:
|
| 460 |
+
"""
|
| 461 |
+
Convert hand keypoints from camera frame to robot frame coordinates.
|
| 462 |
+
|
| 463 |
+
Args:
|
| 464 |
+
skeleton_poses_cf (np.ndarray): Hand poses in camera frame, shape (N, 21, 3)
|
| 465 |
+
T_cam2robot (np.ndarray): 4x4 transformation matrix from camera to robot frame
|
| 466 |
+
|
| 467 |
+
Returns:
|
| 468 |
+
np.ndarray: Hand poses in robot frame, shape (N, 21, 3)
|
| 469 |
+
"""
|
| 470 |
+
# Convert to homogeneous coordinates by adding ones
|
| 471 |
+
pts_h = np.ones((skeleton_poses_cf.shape[0], skeleton_poses_cf.shape[1], 1))
|
| 472 |
+
skeleton_poses_cf_h = np.concatenate([skeleton_poses_cf, pts_h], axis=-1)
|
| 473 |
+
|
| 474 |
+
# Apply transformation matrix to convert coordinate frames
|
| 475 |
+
skeleton_poses_rf_h0 = np.einsum('ij,bpj->bpi', T_cam2robot, skeleton_poses_cf_h)
|
| 476 |
+
|
| 477 |
+
# Remove homogeneous coordinate and return 3D points
|
| 478 |
+
return skeleton_poses_rf_h0[..., :3]
|
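Aside (not part of the diff): `_convert_pts_to_robot_frame` above is a batched rigid-body transform in homogeneous coordinates. The short sketch below repeats the same einsum on made-up data and cross-checks it against an explicit `R @ p + t`; the transform matrix and array shapes here are invented purely for illustration.

# Minimal standalone check of the homogeneous-coordinate transform (illustrative only).
import numpy as np

# Hypothetical camera-to-robot extrinsics: 90-degree yaw plus a translation.
T_cam2robot = np.array([
    [0.0, -1.0, 0.0, 0.5],
    [1.0,  0.0, 0.0, 0.0],
    [0.0,  0.0, 1.0, 0.2],
    [0.0,  0.0, 0.0, 1.0],
])

# Two frames of 21 hand keypoints in the camera frame (random stand-in data).
kpts_cf = np.random.rand(2, 21, 3)

# Same steps as the processor: append ones, apply T with einsum, drop the ones.
kpts_h = np.concatenate([kpts_cf, np.ones((2, 21, 1))], axis=-1)
kpts_rf = np.einsum('ij,bpj->bpi', T_cam2robot, kpts_h)[..., :3]

# Cross-check against a per-point rotation plus translation.
ref = (T_cam2robot[:3, :3] @ kpts_cf.reshape(-1, 3).T).T + T_cam2robot[:3, 3]
assert np.allclose(kpts_rf.reshape(-1, 3), ref)
print(kpts_rf.shape)  # (2, 21, 3)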
phantom/phantom/processors/base_processor.py
ADDED
@@ -0,0 +1,209 @@
import os
import json
import logging
import numpy as np
import shutil
import errno
from typing import Tuple
from pathlib import Path
from omegaconf import DictConfig

from phantom.utils.data_utils import get_parent_folder_of_package
from phantom.utils.image_utils import get_intrinsics_from_json, get_transformation_matrix_from_extrinsics
from phantom.processors.paths import Paths, PathsConfig

logger = logging.getLogger(__name__)


class BaseProcessor:
    def __init__(self, cfg: DictConfig):
        # Store configuration for potential future use
        self.cfg = cfg

        # Apply configuration to instance attributes
        self._apply_config(cfg)

        # Validate configuration
        self._validate_config(cfg)

        # Set up paths and data folders
        self._setup_paths_and_folders(cfg)

        # Initialize camera parameters
        self._init_camera_parameters()

    def _apply_config(self, cfg: DictConfig) -> None:
        """Apply configuration to instance attributes."""
        # Basic attributes
        self.input_resolution = cfg.input_resolution
        self.output_resolution = cfg.output_resolution
        self.project_folder = get_parent_folder_of_package("phantom")
        self.debug = cfg.debug
        self.n_processes = cfg.n_processes
        self.verbose = cfg.verbose
        self.skip_existing = cfg.skip_existing
        self.robot = cfg.robot
        self.gripper = cfg.gripper
        self.square = cfg.square
        self.epic = cfg.epic
        self.bimanual_setup = cfg.bimanual_setup
        self.target_hand = cfg.target_hand
        self.constrained_hand = cfg.constrained_hand
        self.depth_for_overlay = cfg.depth_for_overlay
        self.render = cfg.render
        self.debug_cameras = getattr(cfg, 'debug_cameras', [])

        # Apply bimanual setup logic
        if self.bimanual_setup != "single_arm":
            self.target_hand = "both"

    def _validate_config(self, cfg: DictConfig) -> None:
        """Validate critical configuration parameters."""
        if cfg.input_resolution <= 0 or cfg.output_resolution <= 0:
            raise ValueError(f"Resolutions must be positive: input={cfg.input_resolution}, output={cfg.output_resolution}")

        if not os.path.exists(cfg.data_root_dir):
            raise FileNotFoundError(f"Data root directory not found: {cfg.data_root_dir}")

        if not os.path.exists(cfg.camera_intrinsics):
            raise FileNotFoundError(f"Camera intrinsics file not found: {cfg.camera_intrinsics}")

    def _setup_paths_and_folders(self, cfg: DictConfig) -> None:
        """Set up paths configuration and create necessary directories."""
        # Set up paths configuration
        self.paths_config = PathsConfig()
        self.paths_config.config['data_root'] = cfg.data_root_dir
        self.paths_config.config['processed_root'] = cfg.processed_data_root_dir

        self.data_folder = os.path.join(cfg.data_root_dir, cfg.demo_name)
        self.processed_data_folder = os.path.join(cfg.processed_data_root_dir, cfg.demo_name)

        # Validate that data folder exists
        if not os.path.exists(self.data_folder):
            raise FileNotFoundError(f"Data folder not found: {self.data_folder}")

        os.makedirs(self.processed_data_folder, exist_ok=True)

        # Get all folders in data_folder
        try:
            all_data_folders = [d1 for d1 in os.listdir(self.data_folder) if os.path.isdir(os.path.join(self.data_folder, d1))]
            self.all_data_folders = sorted(all_data_folders, key=lambda x: int(x))
            self.all_data_folders_idx = {x: idx for idx, x in enumerate(self.all_data_folders)}
        except OSError as e:
            if e.errno == errno.EACCES:
                raise PermissionError(f"Permission denied accessing data folder: {self.data_folder}")
            elif e.errno == errno.ENOENT:
                raise FileNotFoundError(f"Data folder not found: {self.data_folder}")
            else:
                raise RuntimeError(f"OS error accessing data folder {self.data_folder}: {e}")
        except ValueError as e:
            raise ValueError(f"Invalid folder name format in {self.data_folder}. Folders should be numbered: {e}")

    def _init_camera_parameters(self) -> None:
        """Initialize camera intrinsics and extrinsics."""
        # Get camera intrinsics and extrinsics
        self.intrinsics_dict, self.intrinsics_matrix = self.get_intrinsics(self.cfg.camera_intrinsics)

        # Use camera_extrinsics from config if available, otherwise determine from bimanual_setup
        if hasattr(self.cfg, 'camera_extrinsics') and self.cfg.camera_extrinsics:
            camera_extrinsics_path = self.cfg.camera_extrinsics
        else:
            camera_extrinsics_path = self._get_camera_extrinsics_path()

        self.T_cam2robot, self.extrinsics = self.get_extrinsics(camera_extrinsics_path)

    def _get_camera_extrinsics_path(self) -> str:
        """Get the appropriate camera extrinsics path based on bimanual setup."""
        if self.bimanual_setup == "shoulders":
            return "camera/camera_extrinsics_ego_bimanual_shoulders.json"
        elif self.bimanual_setup == "single_arm":
            return "camera/camera_extrinsics.json"
        else:
            raise ValueError(f"Invalid bimanual setup: {self.bimanual_setup}. Must be 'single_arm' or 'shoulders'.")

    def get_paths(self, data_path: str) -> Paths:
        """
        Get all file paths for a demo.

        Args:
            data_path: Path to the demo data

        Returns:
            Paths object containing all file paths
        """
        paths = Paths(
            data_path=Path(data_path),
            robot_name=self.robot
        )
        paths.ensure_directories_exist()
        return paths

    def get_save_folder(self, data_sub_folder: str) -> str:
        data_sub_folder_fullpath = os.path.join(self.data_folder, str(data_sub_folder))
        save_folder = os.path.join(self.processed_data_folder, str(data_sub_folder))
        # Check existing dirs using os.scandir
        with os.scandir(self.processed_data_folder) as it:
            existing_dirs = {entry.name for entry in it if entry.is_dir()}
        if str(data_sub_folder) not in existing_dirs:
            shutil.copytree(data_sub_folder_fullpath, save_folder)
        return save_folder

    def process_one_demo(self, data_sub_folder: str):
        raise NotImplementedError

    def get_intrinsics(self, intrinsics_path: str) -> Tuple[dict, np.ndarray]:
        intrinsics_matrix, intrinsics_dict = get_intrinsics_from_json(intrinsics_path)
        if self.square:
            intrinsics_dict, intrinsics_matrix = self.update_intrinsics_for_square_image(self.input_resolution,
                                                                                         intrinsics_dict,
                                                                                         intrinsics_matrix)
        return intrinsics_dict, intrinsics_matrix

    def get_extrinsics(self, extrinsics_path: str) -> Tuple[np.ndarray, dict]:
        """Load and process camera extrinsics from JSON file.

        Args:
            extrinsics_path: Path to the extrinsics JSON file

        Returns:
            Tuple of (transformation_matrix, extrinsics_dict)

        Raises:
            FileNotFoundError: If extrinsics file doesn't exist
            json.JSONDecodeError: If extrinsics file is invalid JSON
            ValueError: If extrinsics data is invalid
        """
        if not os.path.exists(extrinsics_path):
            raise FileNotFoundError(f"Camera extrinsics file not found: {extrinsics_path}")

        try:
            with open(extrinsics_path, "r") as f:
                camera_extrinsics = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON in extrinsics file {extrinsics_path}: {str(e)}")

        try:
            T_cam2robot = get_transformation_matrix_from_extrinsics(camera_extrinsics)
        except Exception as e:
            raise ValueError(f"Failed to process extrinsics data from {extrinsics_path}: {str(e)}")

        return T_cam2robot, camera_extrinsics

    @staticmethod
    def update_intrinsics_for_square_image(img_h: int, intrinsics_dict: dict,
                                           intrinsics_matrix: np.ndarray) -> Tuple[dict, np.ndarray]:
        """
        Adjusts camera intrinsic parameters for a square image by modifying the principal point offset.

        Args:
            img_h (int): Height of the image (assumed to be square).
            intrinsics_dict (dict): Dictionary of intrinsic parameters.
            intrinsics_matrix (np.ndarray): Intrinsic matrix.

        Returns:
            Tuple[dict, np.ndarray]: Updated intrinsic parameters and matrix.
        """
        img_w = img_h * 16 // 9
        offset = (img_w - img_h) // 2
        intrinsics_dict["cx"] -= offset
        intrinsics_matrix[0, 2] -= offset
        return intrinsics_dict, intrinsics_matrix
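Aside (not part of the diff): `update_intrinsics_for_square_image` assumes the square frame was center-cropped out of a 16:9 source, so only the horizontal principal point `cx` shifts; the focal lengths and `cy` are unchanged. A minimal numeric check, with made-up intrinsics, for a 1080-pixel square crop of a 1920x1080 frame:

# Illustrative only: principal-point shift for a center crop (values are invented).
import numpy as np

img_h = 1080                   # square crop side; original frame is 1920x1080
img_w = img_h * 16 // 9        # 1920
offset = (img_w - img_h) // 2  # 420 pixels removed from each side

K = np.array([
    [1050.0,    0.0, 960.0],   # fx, 0, cx (cx at the center of the full frame)
    [   0.0, 1050.0, 540.0],   # 0, fy, cy
    [   0.0,    0.0,   1.0],
])
K_square = K.copy()
K_square[0, 2] -= offset       # cx: 960 -> 540, the center of the square crop

print(K_square[0, 2])          # 540.0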
phantom/phantom/processors/bbox_processor.py
ADDED
@@ -0,0 +1,851 @@
"""
Bounding Box Processor Module

This module provides video processing capabilities for detecting and tracking hand bounding boxes
in demonstration videos. It serves as the first stage in the hand processing pipeline, providing
spatial localization data for downstream pose estimation and segmentation tasks.

Key Features:
- Multiple hand detection methods (DINO, EPIC-KITCHENS integration)
- Bimanual hand tracking with left/right classification
- Temporal consistency through outlier filtering and interpolation
- Spatial constraint validation (edge detection, center positioning)
- Visualization and annotation generation

Processing Pipeline:
1. Video loading and validation
2. Frame-by-frame hand detection using configured detectors
3. Bounding box classification (left/right) based on spatial positioning
4. Temporal filtering to remove outliers and large jumps
5. Gap interpolation for smooth trajectories
6. Edge distance calculation for quality assessment
7. Result visualization and storage

The processor supports multiple detection backends:
- DINO-based detection for general hand detection
- EPIC-KITCHENS pre-computed detections
- Configurable confidence thresholds and spatial constraints

Output Data:
- Hand detection flags per frame (boolean arrays)
- Bounding box coordinates [x1, y1, x2, y2] per frame
- Bounding box centers [x, y] per frame
- Distance metrics to image edges
- Annotated visualization videos
"""

import os
import pickle
import logging
import numpy as np
import mediapy as media
import cv2
import itertools
import time
import matplotlib.pyplot as plt
from typing import List, Tuple, Optional, Any, Dict
from typing_extensions import Literal
import numpy.typing as npt
from omegaconf import DictConfig

from phantom.processors.base_processor import BaseProcessor
from phantom.processors.paths import Paths
from phantom.processors.phantom_data import hand_side_dict

from phantom.utils.bbox_utils import get_bbox_center, get_bbox_center_min_dist_to_edge

logger = logging.getLogger(__name__)

# Type aliases for better readability
DetectionResults = Dict[str, npt.NDArray]
BBoxArray = npt.NDArray[np.float32]        # [x1, y1, x2, y2]
CenterArray = npt.NDArray[np.float32]      # [x, y]
DetectionFlagArray = npt.NDArray[np.bool_]
HandSide = Literal["left", "right"]


class BBoxProcessor(BaseProcessor):
    """
    Bounding box detection and tracking processor for hand localization in videos.

    This processor serves as the foundation of the hand processing pipeline by detecting
    and tracking hand bounding boxes across video frames. It handles both single-arm
    and bimanual setups.

    The processor employs multiple strategies for reliable detection:
    - Primary detection using DINO or pre-computed EPIC data
    - Spatial reasoning for left/right hand classification
    - Temporal filtering to maintain trajectory consistency
    - Gap interpolation for handling missing detections
    - Quality assessment through edge distance metrics

    Attributes:
        H (int): Video frame height (set during processing)
        W (int): Video frame width (set during processing)
        center (int): Horizontal center of the frame for left/right classification
        margin (int): Pixel margin for hand side classification tolerance
        confidence_threshold (float): Minimum confidence for valid detections
        dino_detector: DINO-based hand detector (if not using EPIC data)
        filtered_hand_detection_data (dict): Processed EPIC detection data
        sorted_keys (list): Sorted frame indices for EPIC data processing
    """

    # Detection configuration constants
    HAND_SIDE_MARGIN = 50           # Pixel margin for hand side classification tolerance
    OVERLAP_THRESHOLD = 0.3         # Threshold for considering bboxes as overlapping
    MAX_INTERPOLATION_GAP = 10      # Maximum frames to interpolate over
    MAX_SPATIAL_JUMP = 200.0        # Maximum allowed pixel jump between detections
    MAX_JUMP_LOOKAHEAD = 10         # Maximum consecutive distant points to filter
    DINO_CONFIDENCE_THRESH = 0.2    # Default confidence threshold

    # Visualization constants
    LEFT_HAND_COLOR = (0, 0, 255)   # BGR format - Red for left hand
    RIGHT_HAND_COLOR = (0, 255, 0)  # BGR format - Green for right hand
    BBOX_THICKNESS = 2              # Thickness of bounding box lines

    def __init__(self, cfg: DictConfig) -> None:
        """
        Initialize the bounding box processor with configuration parameters.

        Args:
            cfg: Hydra configuration object containing processing configuration
                including confidence thresholds, target hands, and dataset type
        """
        super().__init__(cfg)
        # Image dimensions (set when processing video)
        self.H: int = 0
        self.W: int = 0

        # Initialize detection backend based on dataset type.
        # The annotation uses a forward reference so it stays valid when the
        # lazy DetectorDino import is skipped in EPIC mode.
        self.dino_detector: Optional["DetectorDino"] = None
        if not self.epic:
            from phantom.detectors.detector_dino import DetectorDino
            self.dino_detector = DetectorDino("IDEA-Research/grounding-dino-base")

        # EPIC-specific attributes
        self.filtered_hand_detection_data: Dict[str, List[Any]] = {}
        self.sorted_keys: List[str] = []

    # ============================================================================
    # COMMON/SHARED METHODS (Used by both Phantom and EPIC modes)
    # ============================================================================

    def process_one_demo(self, data_sub_folder: str) -> None:
        """
        Process a single demonstration video to extract hand bounding boxes.

        Args:
            data_sub_folder: Path to the demonstration data folder containing the video
                and any pre-computed hand detection data.

        The method performs the following steps:
        1. Loads and validates input video and detection data
        2. Processes each frame to detect and classify hand positions
        3. Applies post-processing filters for temporal consistency
        4. Generates quality metrics and visualizations
        5. Saves all results in standardized format

        Raises:
            FileNotFoundError: If required input files (video, detection data) are not found
            ValueError: If video frames or hand detection data are invalid
        """
        # Setup and validation
        save_folder = self.get_save_folder(data_sub_folder)

        paths = self.get_paths(save_folder)

        # Load and validate input data
        imgs_rgb = self._load_video(paths)

        # Process frames based on dataset type
        if self.epic:
            self._load_epic_hand_data(paths)
            detection_results = self._process_epic_frames(imgs_rgb)
        else:
            detection_results = self._process_frames(imgs_rgb)

        # Post-process results for temporal consistency
        processed_results = self._post_process_detections(detection_results)

        # Generate visualization for quality assessment
        visualization_results = self._generate_visualization(imgs_rgb, processed_results)

        # Save all results to disk
        self._save_results(paths, processed_results, visualization_results)

    def _load_video(self, paths: Paths) -> np.ndarray:
        """
        Load and validate video data from the specified path.

        Args:
            paths: Paths object containing video file locations

        Returns:
            RGB video frames as array

        Raises:
            FileNotFoundError: If video file doesn't exist
            ValueError: If video is empty or corrupted
        """
        if not os.path.exists(paths.video_left):
            raise FileNotFoundError(f"Video file not found: {paths.video_left}")

        imgs_rgb = media.read_video(paths.video_left)
        if len(imgs_rgb) == 0:
            raise ValueError("Empty video file")

        # Store video dimensions for coordinate calculations
        self.H, self.W, _ = imgs_rgb[0].shape
        self.center: int = self.W // 2  # Center line for left/right classification
        return imgs_rgb

    # ============================================================================
    # PHANTOM-SPECIFIC METHODS (DINO Detection)
    # ============================================================================

    def _process_frames(self, imgs_rgb: np.ndarray) -> Dict[str, np.ndarray]:
        """
        Process RGB frames using DINO detector for hand detection and classification.

        This method handles the core detection pipeline for non-EPIC datasets,
        using DINO for hand detection and implementing spatial reasoning for
        left/right classification.

        Args:
            imgs_rgb: Array of RGB images with shape (num_frames, height, width, 3)

        Returns:
            Dictionary containing:
            - left/right_hand_detected: Boolean arrays indicating hand detection per frame
            - left/right_bboxes: Bounding box coordinates [x1,y1,x2,y2] per frame
            - left/right_bboxes_ctr: Bounding box centers [x,y] per frame
        """
        num_frames = len(imgs_rgb)

        detection_arrays = self._initialize_detection_arrays(num_frames)

        for idx in range(num_frames):
            try:
                # Run DINO detection on current frame
                bboxes, scores = self.dino_detector.get_bboxes(imgs_rgb[idx], "a hand", threshold=self.DINO_CONFIDENCE_THRESH, visualize=False)
                if len(bboxes) == 0:
                    continue

                bboxes = np.array(bboxes)
                scores = np.array(scores)

                # Process detections for current frame
                self._process_frame_detections(idx, bboxes, scores, detection_arrays)
            except Exception as e:
                logger.warning(f"Frame {idx} processing failed: {str(e)}")
                continue

        return {
            'left_hand_detected': detection_arrays['left_hand_detected'],
            'right_hand_detected': detection_arrays['right_hand_detected'],
            'left_bboxes': detection_arrays['left_bboxes'],
            'right_bboxes': detection_arrays['right_bboxes'],
            'left_bboxes_ctr': detection_arrays['left_bboxes_ctr'],
            'right_bboxes_ctr': detection_arrays['right_bboxes_ctr'],
        }

    def _initialize_detection_arrays(self, num_frames: int) -> Dict[str, npt.NDArray]:
        """
        Initialize arrays for storing detection results.

        Args:
            num_frames: Number of frames in the video

        Returns:
            Dictionary containing pre-allocated arrays for left/right hand detections,
            bounding boxes, centers, and detection flags
        """
        return {
            'left_bboxes': np.zeros((num_frames, 4)),
            'right_bboxes': np.zeros((num_frames, 4)),
            'left_bboxes_ctr': np.zeros((num_frames, 2)),
            'right_bboxes_ctr': np.zeros((num_frames, 2)),
            'left_hand_detected': np.zeros(num_frames, dtype=bool),
            'right_hand_detected': np.zeros(num_frames, dtype=bool)
        }

    def _process_frame_detections(self, idx: int, bboxes: npt.NDArray, scores: npt.NDArray,
                                  detection_arrays: Dict[str, npt.NDArray]) -> None:
        """
        Process detections for a single frame.

        Args:
            idx: Frame index
            bboxes: Array of detected bounding boxes
            scores: Array of detection confidence scores
            detection_arrays: Dictionary to store detection results
        """
        if len(bboxes) == 0:
            return

        # Always select the bounding box with the highest score
        best_idx = np.argmax(scores)
        best_bbox = bboxes[best_idx]
        best_bbox_ctr = get_bbox_center(best_bbox)

        # Assign hand type directly based on self.target_hand
        if self.target_hand == "left":
            detection_arrays['left_bboxes'][idx] = best_bbox
            detection_arrays['left_bboxes_ctr'][idx] = best_bbox_ctr
            detection_arrays['left_hand_detected'][idx] = True
        elif self.target_hand == "right":
            detection_arrays['right_bboxes'][idx] = best_bbox
            detection_arrays['right_bboxes_ctr'][idx] = best_bbox_ctr
            detection_arrays['right_hand_detected'][idx] = True

    # ============================================================================
    # EPIC-SPECIFIC METHODS (EPIC Dataset Processing)
    # ============================================================================

    def _validate_epic_data_structure(self, epic_data: List[Any]) -> bool:
        """Validate EPIC data structure before processing."""
        if not epic_data:
            return False

        # Check if first item has required attributes
        try:
            first_item = epic_data[0]
            if not hasattr(first_item, 'side') or not hasattr(first_item, 'bbox'):
                logger.warning("EPIC data missing required attributes: 'side' or 'bbox'")
                return False

            # Check if bbox has required attributes
            bbox = first_item.bbox
            required_attrs = ['left', 'right', 'top', 'bottom']
            if not all(hasattr(bbox, attr) for attr in required_attrs):
                logger.warning("EPIC bbox missing required attributes: left, right, top, bottom")
                return False

            return True
        except Exception as e:
            logger.warning(f"Error validating EPIC data structure: {str(e)}")
            return False

    def _load_epic_hand_data(self, paths: Paths) -> Dict[str, Any]:
        """
        Load and validate pre-computed hand detection data from EPIC-KITCHENS dataset.

        EPIC-KITCHENS provides pre-computed hand detection annotations that we can
        use directly instead of running our own detection. This method filters and
        sorts the data for efficient frame-by-frame processing.

        Args:
            paths: Paths object containing detection data file location

        Returns:
            Dictionary of filtered and sorted hand detection data

        Raises:
            FileNotFoundError: If detection data file doesn't exist
        """
        if not os.path.exists(paths.hand_detection_data):
            raise FileNotFoundError(f"Hand detection data not found: {paths.hand_detection_data}")

        with open(paths.hand_detection_data, 'rb') as f:
            hand_detection_data = dict(pickle.load(f))

        # Filter out detection objects without valid side information
        filtered_data = {
            key: [obj for obj in obj_list if hasattr(obj, 'side')]
            for key, obj_list in hand_detection_data.items()
        }

        # Sort by frame index for sequential processing
        self.filtered_hand_detection_data = dict(sorted(filtered_data.items(), key=lambda x: int(x[0])))
        self.sorted_keys = sorted(self.filtered_hand_detection_data.keys(), key=lambda k: int(k))

        return self.filtered_hand_detection_data

    def _process_epic_frames(self, imgs_rgb: npt.NDArray[np.uint8]) -> DetectionResults:
        """
        Process frames using pre-computed EPIC-KITCHENS hand detection data.

        This method processes EPIC-KITCHENS dataset videos using their provided
        hand detection annotations, converting them to our standard format while
        applying spatial validation constraints.

        Args:
            imgs_rgb: Array of RGB images for dimension reference

        Returns:
            Dictionary containing detection results in the same format as _process_frames
        """
        num_frames = len(imgs_rgb)

        detection_arrays = self._initialize_detection_arrays(num_frames)

        # Process each frame using EPIC detection data
        for idx in range(num_frames):
            try:
                epic_data = self.filtered_hand_detection_data[self.sorted_keys[idx]]

                if len(epic_data) == 0:
                    continue

                # Process frame detections
                self._process_epic_frame_detections(idx, epic_data, detection_arrays)
            except KeyError:
                logger.warning(f"Missing EPIC data for frame {idx}")
                continue
            except Exception as e:
                logger.warning(f"EPIC frame {idx} processing failed: {str(e)}")
                continue

        return {
            'left_hand_detected': detection_arrays['left_hand_detected'],
            'right_hand_detected': detection_arrays['right_hand_detected'],
            'left_bboxes': detection_arrays['left_bboxes'],
            'right_bboxes': detection_arrays['right_bboxes'],
            'left_bboxes_ctr': detection_arrays['left_bboxes_ctr'],
            'right_bboxes_ctr': detection_arrays['right_bboxes_ctr']
        }

    def _process_epic_frame_detections(self, idx: int, epic_data: List[Any],
                                       detection_arrays: Dict[str, npt.NDArray]) -> None:
        """Process EPIC detections for a single frame."""
        # Process left and right hands separately
        left_detected, left_bbox, left_bbox_ctr = self._process_epic_hand_detection(epic_data, "left")
        right_detected, right_bbox, right_bbox_ctr = self._process_epic_hand_detection(epic_data, "right")

        # Store results in pre-allocated arrays
        detection_arrays['left_hand_detected'][idx] = left_detected
        detection_arrays['right_hand_detected'][idx] = right_detected
        if left_detected:
            detection_arrays['left_bboxes'][idx] = left_bbox
            detection_arrays['left_bboxes_ctr'][idx] = left_bbox_ctr
        if right_detected:
            detection_arrays['right_bboxes'][idx] = right_bbox
            detection_arrays['right_bboxes_ctr'][idx] = right_bbox_ctr

        # Quality check: If hands appear crossed (left hand on right side),
        # mark both as invalid to avoid confusion
        if left_detected and right_detected:
            self._validate_hand_positions(idx, left_bbox_ctr, right_bbox_ctr, detection_arrays)

    def _validate_hand_positions(self, idx: int, left_bbox_ctr: npt.NDArray, right_bbox_ctr: npt.NDArray,
                                 detection_arrays: Dict[str, npt.NDArray]) -> None:
        """Validate that hands are on correct sides of the image."""
        if left_bbox_ctr[0] > right_bbox_ctr[0]:
            # Left hand appears to be on the right side - mark both as invalid
            detection_arrays['left_hand_detected'][idx] = False
            detection_arrays['right_hand_detected'][idx] = False

    def _process_epic_hand_detection(self,
                                     epic_data: List[Any],
                                     hand_side: HandSide) -> Tuple[bool, BBoxArray, CenterArray]:
        """
        Process EPIC hand detection data for a single frame and hand side.

        This method extracts and validates hand detection data from EPIC annotations,
        converting normalized coordinates to pixel coordinates and applying spatial
        validation constraints.

        Args:
            epic_data: List of detection objects for the current frame
            hand_side: Either "left" or "right" specifying which hand to process

        Returns:
            Tuple of (is_detected: bool, bbox: ndarray, bbox_center: ndarray)
        """
        if hand_side not in hand_side_dict:
            raise ValueError(f"Invalid hand side: {hand_side}")

        # Default empty result for failed detections
        empty_result = (False, np.array([0, 0, 0, 0]), np.array([0, 0]))

        try:
            # Filter and validate detection data
            hand_data = self._filter_epic_hand_data(epic_data, hand_side)
            if not hand_data:
                return empty_result

            # Validate data structure
            if not self._validate_epic_data_structure(hand_data):
                return empty_result

            # Extract and process bounding box
            bbox, bbox_center = self._extract_epic_bbox(hand_data[0])

            # Validate bounding box coordinates
            if not self._validate_bbox_coordinates(hand_data[0].bbox, hand_side):
                return empty_result

            # Apply spatial validation
            is_valid = self._validate_spatial_position(bbox_center, hand_side)
            return (is_valid, bbox, bbox_center) if is_valid else empty_result

        except Exception as e:
            logger.warning(f"Unexpected error processing {hand_side} hand detection: {str(e)}")
            return empty_result

    def _filter_epic_hand_data(self, epic_data: List[Any], hand_side: HandSide) -> List[Any]:
        """Filter EPIC detection data for the specified hand side."""
        return [data for data in epic_data if data.side.value == hand_side_dict[hand_side]]

    def _extract_epic_bbox(self, hand_data: Any) -> Tuple[BBoxArray, CenterArray]:
        """Extract bounding box and center from EPIC hand detection data."""
        bbox_cls = hand_data.bbox

        # Convert normalized coordinates to pixel coordinates
        bbox = np.array([
            bbox_cls.left * self.W,
            bbox_cls.top * self.H,
            bbox_cls.right * self.W,
            bbox_cls.bottom * self.H
        ])

        # Calculate center point for spatial validation
        bbox_center = np.array([
            (bbox[0] + bbox[2]) / 2,
            (bbox[1] + bbox[3]) / 2
        ]).astype(np.int32)

        return bbox, bbox_center

    def _validate_spatial_position(self, bbox_center: CenterArray, hand_side: HandSide) -> bool:
        """Validate that hand center is on correct side of image."""
        if hand_side == "left":
            return bbox_center[0] <= (self.center + self.HAND_SIDE_MARGIN)
        else:  # right
            return bbox_center[0] >= (self.center - self.HAND_SIDE_MARGIN)

    def _validate_bbox_coordinates(self, bbox_cls: Any, hand_side: HandSide) -> bool:
        """Validate bounding box coordinates are within valid range [0,1]."""
        if not (0 <= bbox_cls.left <= 1 and 0 <= bbox_cls.right <= 1 and
                0 <= bbox_cls.top <= 1 and 0 <= bbox_cls.bottom <= 1):
            logger.warning(f"Invalid bbox coordinates detected for {hand_side} hand: "
                           f"left={bbox_cls.left:.3f}, right={bbox_cls.right:.3f}, "
                           f"top={bbox_cls.top:.3f}, bottom={bbox_cls.bottom:.3f}")
            return False
        return True

    # ============================================================================
    # UTILITY/HELPER METHODS (General utilities and post-processing)
    # ============================================================================

    def _post_process_detections(self, detection_results: DetectionResults) -> DetectionResults:
        """
        Apply post-processing to improve detection temporal consistency.

        This method applies several filters and enhancements to the raw detection
        results to improve their quality and temporal coherence:
        1. Filter out large spatial jumps that indicate tracking errors
        2. Interpolate short gaps in detection sequences
        3. Calculate quality metrics (distance to image edges)

        Args:
            detection_results: Raw detection results from frame processing

        Returns:
            Enhanced detection results with improved temporal consistency
        """
        # Filter out large jumps for both hands
        left_results = self._filter_large_jumps(
            detection_results['left_hand_detected'],
            detection_results['left_bboxes'],
            detection_results['left_bboxes_ctr'],
            max_jump=self.MAX_SPATIAL_JUMP,
            lookahead=self.MAX_JUMP_LOOKAHEAD
        )
        right_results = self._filter_large_jumps(
            detection_results['right_hand_detected'],
            detection_results['right_bboxes'],
            detection_results['right_bboxes_ctr'],
            max_jump=self.MAX_SPATIAL_JUMP,
            lookahead=self.MAX_JUMP_LOOKAHEAD
        )

        # Interpolate missing detections for smooth trajectories
        left_results = self._interpolate_detections(*left_results, max_gap=self.MAX_INTERPOLATION_GAP)
        right_results = self._interpolate_detections(*right_results, max_gap=self.MAX_INTERPOLATION_GAP)

        # Calculate quality metrics: minimum distance from bbox center to image edges
        left_bbox_min_dist = get_bbox_center_min_dist_to_edge(left_results[1], self.W, self.H)
        right_bbox_min_dist = get_bbox_center_min_dist_to_edge(right_results[1], self.W, self.H)

        return {
            'left_hand_detected': left_results[0],
            'right_hand_detected': right_results[0],
            'left_bboxes': left_results[1],
            'right_bboxes': right_results[1],
            'left_bboxes_ctr': left_results[2],
            'right_bboxes_ctr': right_results[2],
            'left_bbox_min_dist_to_edge': left_bbox_min_dist,
            'right_bbox_min_dist_to_edge': right_bbox_min_dist
        }

    def _generate_visualization(self, imgs_rgb: np.ndarray, results: Dict[str, np.ndarray]) -> List[np.ndarray]:
        """
        Generate visualization of detection results for quality assessment.

        Creates annotated frames showing detected bounding boxes for visual
        inspection of detection quality and temporal consistency.

        Args:
            imgs_rgb: Original RGB video frames
            results: Processed detection results

        Returns:
            List of annotated images with bounding boxes drawn
        """
        list_img_annot = []
        for idx in range(len(imgs_rgb)):
            left_bbox = None
            right_bbox = None

            # Prepare bounding boxes for visualization
            if results['left_hand_detected'][idx] or results['right_hand_detected'][idx]:
                left_bbox = results['left_bboxes'][idx] if results['left_hand_detected'][idx] else None
                right_bbox = results['right_bboxes'][idx] if results['right_hand_detected'][idx] else None

            # Generate annotated image
            img_annot = self.visualize_detections(imgs_rgb[idx], left_bbox, right_bbox, show_image=False)
            list_img_annot.append(img_annot)
        return list_img_annot

    def _save_results(self, paths: Paths, results: DetectionResults, visualization_results: List[npt.NDArray[np.uint8]]) -> None:
        """
        Save all processed results to disk in standardized format.

        Args:
            paths: Paths object containing output file locations
            results: Processed detection results
            visualization_results: Generated visualization frames
        """
        # Create output directory if it doesn't exist
        if not os.path.exists(paths.bbox_processor):
            os.makedirs(paths.bbox_processor)

        # Save detection data in compressed NumPy format
        np.savez(paths.bbox_data, **results)

        # Save visualization video with lossless compression
        media.write_video(paths.video_bboxes, visualization_results, fps=15, codec="ffv1")

    def _interpolate_detections(self, detected: DetectionFlagArray,
                                bboxes: BBoxArray,
                                centers: CenterArray,
                                max_gap: int = 10) -> Tuple[DetectionFlagArray, BBoxArray, CenterArray]:
        """
        Interpolate bounding boxes and detection status for short gaps in tracking.

        This method fills in missing detections using linear interpolation when the
        gap is small enough to reasonably assume continuous hand motion. This helps
        create smoother trajectories for downstream processing.

        Args:
            detected: Boolean array of detection status per frame
            bboxes: Array of bounding boxes [N, 4] format [x1, y1, x2, y2]
            centers: Array of bbox centers [N, 2] format [x, y]
            max_gap: Maximum gap size (in frames) to interpolate over

        Returns:
            Tuple of (interpolated detection status, interpolated bboxes, interpolated centers)
        """
        detected = detected.copy()
        bboxes = bboxes.copy()
        centers = centers.copy()

        # Handle single-frame gaps first (most common case)
        for i in range(1, len(detected) - 1):
            if not detected[i] and detected[i-1] and detected[i+1]:
                # Get valid bboxes/centers before and after gap
                start_bbox = bboxes[i-1]
                end_bbox = bboxes[i+1]
                start_center = centers[i-1]
                end_center = centers[i+1]

                # Linear interpolation with t = 0.5 for single frame
                interpolated_bbox = 0.5 * (start_bbox + end_bbox)
                interpolated_center = 0.5 * (start_center + end_center)

                # Validate interpolated values are reasonable
                if self._is_valid_bbox(interpolated_bbox) and self._is_valid_center(interpolated_center):
                    bboxes[i] = interpolated_bbox
                    centers[i] = interpolated_center
                    detected[i] = True

        # Handle multi-frame gaps
        non_detect_start = None
        for i in range(1, len(detected) - 1):
            # Start of non-detection sequence
            if detected[i-1] and not detected[i]:
                non_detect_start = i
            # End of non-detection sequence
            elif non_detect_start is not None and not detected[i] and detected[i+1]:
                non_detect_end = i
                gap_size = non_detect_end - non_detect_start + 1

                # Only interpolate if gap is small enough and has valid detections on both sides
                if gap_size <= max_gap:
                    # Get valid bboxes/centers before and after gap
                    start_bbox = bboxes[non_detect_start - 1]
                    end_bbox = bboxes[non_detect_end + 1]
                    start_center = centers[non_detect_start - 1]
                    end_center = centers[non_detect_end + 1]

                    # Generate interpolation steps
                    steps = gap_size + 1
                    for j in range(gap_size):
                        t = (j + 1) / steps  # Interpolation factor

                        # Linear interpolation of bbox coordinates
                        bboxes[non_detect_start + j] = (1 - t) * start_bbox + t * end_bbox

                        # Linear interpolation of center coordinates
                        centers[non_detect_start + j] = (1 - t) * start_center + t * end_center

                        # Mark as detected
                        detected[non_detect_start + j] = True

                non_detect_start = None

        return detected, bboxes, centers

    def _is_valid_bbox(self, bbox: BBoxArray) -> bool:
        """Validate that bbox coordinates are reasonable."""
        if bbox is None or len(bbox) != 4:
            return False
        # Check for reasonable bounds (not negative, not too large)
        return (bbox >= 0).all() and (bbox[:2] < bbox[2:]).all() and bbox.max() < max(self.W, self.H) * 2

    def _is_valid_center(self, center: CenterArray) -> bool:
        """Validate that center coordinates are reasonable."""
        if center is None or len(center) != 2:
            return False
        # Check for reasonable bounds
        return (center >= 0).all() and center[0] < self.W * 2 and center[1] < self.H * 2

    def visualize_detections(self, img: npt.NDArray[np.uint8],
|
| 727 |
+
left_bbox: Optional[npt.NDArray[np.float32]] = None,
|
| 728 |
+
right_bbox: Optional[npt.NDArray[np.float32]] = None,
|
| 729 |
+
show_image: bool = True) -> npt.NDArray[np.uint8]:
|
| 730 |
+
"""
|
| 731 |
+
Visualize hand detections by drawing bounding boxes on the image.
|
| 732 |
+
|
| 733 |
+
This method creates annotated images showing detected hand locations with
|
| 734 |
+
color-coded bounding boxes (red for left hand, green for right hand).
|
| 735 |
+
|
| 736 |
+
Args:
|
| 737 |
+
img: Input RGB image to annotate
|
| 738 |
+
left_bbox: Left hand bounding box [x1, y1, x2, y2] or None if not detected
|
| 739 |
+
right_bbox: Right hand bounding box [x1, y1, x2, y2] or None if not detected
|
| 740 |
+
show_image: Whether to display the image using cv2.imshow
|
| 741 |
+
|
| 742 |
+
Returns:
|
| 743 |
+
The annotated image
|
| 744 |
+
"""
|
| 745 |
+
# Work directly with the input image (assumed to be in BGR format)
|
| 746 |
+
img_bgr = img
|
| 747 |
+
|
| 748 |
+
# Draw left hand bounding box in red
|
| 749 |
+
if left_bbox is not None and not np.array_equal(left_bbox, np.array([0, 0, 0, 0])):
|
| 750 |
+
cv2.rectangle(
|
| 751 |
+
img_bgr,
|
| 752 |
+
(int(left_bbox[0]), int(left_bbox[1])),
|
| 753 |
+
(int(left_bbox[2]), int(left_bbox[3])),
|
| 754 |
+
self.LEFT_HAND_COLOR,
|
| 755 |
+
self.BBOX_THICKNESS
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
# Draw right hand bounding box in green
|
| 759 |
+
if right_bbox is not None and not np.array_equal(right_bbox, np.array([0, 0, 0, 0])):
|
| 760 |
+
cv2.rectangle(
|
| 761 |
+
img_bgr,
|
| 762 |
+
(int(right_bbox[0]), int(right_bbox[1])),
|
| 763 |
+
(int(right_bbox[2]), int(right_bbox[3])),
|
| 764 |
+
self.RIGHT_HAND_COLOR,
|
| 765 |
+
self.BBOX_THICKNESS
|
| 766 |
+
)
|
| 767 |
+
|
| 768 |
+
# Optionally display the image for debugging
|
| 769 |
+
if show_image:
|
| 770 |
+
cv2.imshow("Hand Detections", img_bgr)
|
| 771 |
+
cv2.waitKey(0)
|
| 772 |
+
cv2.destroyAllWindows()
|
| 773 |
+
|
| 774 |
+
return img_bgr
|
| 775 |
+
|
| 776 |
+
@staticmethod
|
| 777 |
+
def _filter_large_jumps(detected: DetectionFlagArray,
|
| 778 |
+
bboxes: BBoxArray,
|
| 779 |
+
centers: CenterArray,
|
| 780 |
+
max_jump: float = 200.0,
|
| 781 |
+
lookahead: int = 10) -> Tuple[DetectionFlagArray, BBoxArray, CenterArray]:
|
| 782 |
+
"""
|
| 783 |
+
Filter out small groups of detections that are spatially inconsistent with the trajectory.
|
| 784 |
+
|
| 785 |
+
This method identifies and removes isolated detections that are far from the
|
| 786 |
+
expected trajectory, which usually indicate false positives or tracking errors.
|
| 787 |
+
It helps maintain temporal consistency in hand tracking.
|
| 788 |
+
|
| 789 |
+
Args:
|
| 790 |
+
detected: Boolean array of detection status per frame
|
| 791 |
+
bboxes: Array of bounding boxes [N, 4] format [x1, y1, x2, y2]
|
| 792 |
+
centers: Array of bbox centers [N, 2] format [x, y]
|
| 793 |
+
max_jump: Maximum allowed distance (in pixels) between consecutive detections
|
| 794 |
+
lookahead: Maximum number of consecutive distant points to filter as a group
|
| 795 |
+
|
| 796 |
+
Returns:
|
| 797 |
+
Tuple of (filtered detection status, filtered bboxes, filtered centers)
|
| 798 |
+
"""
|
| 799 |
+
detected = detected.copy()
|
| 800 |
+
bboxes = bboxes.copy()
|
| 801 |
+
centers = centers.copy()
|
| 802 |
+
|
| 803 |
+
# Templates for clearing invalid detections
|
| 804 |
+
empty_bbox = np.array([0, 0, 0, 0])
|
| 805 |
+
empty_center = np.array([0, 0])
|
| 806 |
+
|
| 807 |
+
i = 0
|
| 808 |
+
while i < len(detected):
|
| 809 |
+
# Find next detected point to compare against
|
| 810 |
+
next_valid = i + 1
|
| 811 |
+
|
| 812 |
+
if next_valid >= len(detected):
|
| 813 |
+
break
|
| 814 |
+
|
| 815 |
+
# Calculate spatial distance to next detection
|
| 816 |
+
dist = np.linalg.norm(centers[next_valid] - centers[i])
|
| 817 |
+
|
| 818 |
+
if dist > max_jump:
|
| 819 |
+
# Large jump detected - check if it's part of a small group of outliers
|
| 820 |
+
distant_points = []
|
| 821 |
+
ref_center = centers[i] # Use current point as reference
|
| 822 |
+
|
| 823 |
+
# Look ahead to find consecutive distant points
|
| 824 |
+
for j in range(next_valid, len(detected)):
|
| 825 |
+
curr_dist = np.linalg.norm(centers[j] - ref_center)
|
| 826 |
+
if curr_dist > max_jump:
|
| 827 |
+
distant_points.append(j)
|
| 828 |
+
else:
|
| 829 |
+
break
|
| 830 |
+
|
| 831 |
+
# If we found a small group of distant points, filter them out
|
| 832 |
+
if len(distant_points) > 0 and len(distant_points) <= lookahead:
|
| 833 |
+
for idx in distant_points:
|
| 834 |
+
detected[idx] = False
|
| 835 |
+
bboxes[idx] = empty_bbox
|
| 836 |
+
centers[idx] = empty_center
|
| 837 |
+
logging.warning(f"Filtered out frame {idx} as part of small distant group")
|
| 838 |
+
|
| 839 |
+
i = next_valid
|
| 840 |
+
|
| 841 |
+
return detected, bboxes, centers
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
|
| 851 |
+
|
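(Note on the gap-filling logic above: it is plain linear interpolation between the last detection before a gap and the first detection after it, with t = (j + 1) / (gap_size + 1). The standalone sketch below uses made-up toy values, not anything from this diff, and shows the intended effect for a three-frame gap.)

    import numpy as np

    # Hypothetical toy trajectory: centers known at frames 0 and 4, missing in between.
    detected = np.array([True, False, False, False, True])
    centers = np.array([[100.0, 100.0], [0, 0], [0, 0], [0, 0], [180.0, 140.0]])

    gap_start, gap_end = 1, 3
    gap_size = gap_end - gap_start + 1
    steps = gap_size + 1                       # same "steps" as in _interpolate_detections
    for j in range(gap_size):
        t = (j + 1) / steps                    # interpolation factor in (0, 1)
        centers[gap_start + j] = (1 - t) * centers[gap_start - 1] + t * centers[gap_end + 1]
        detected[gap_start + j] = True

    print(centers)  # frames 1-3 now sit evenly spaced on the segment from (100, 100) to (180, 140)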
phantom/phantom/processors/hand_processor.py
ADDED
@@ -0,0 +1,675 @@
"""
Hand Processor Module

This module converts detected hand bounding boxes into detailed 3D hand poses using
state-of-the-art pose estimation models, with optional depth-based refinement for improved accuracy.

Processing Pipeline:
1. Load video frames and bounding box data from previous stage
2. Apply HaMeR pose estimation within detected bounding boxes
3. Filter poses based on edge proximity and quality metrics
4. Optionally refine 3D poses using depth data and segmentation
5. Generate hand mesh models and extract keypoint trajectories
6. Save processed hand sequences for downstream tasks

The module supports multiple processing modes:
- Hand2DProcessor: 2D pose estimation only (faster, camera-based)
- Hand3DProcessor: Full 3D processing with depth alignment (more accurate, if depth is available)

Output Data:
- HandSequence objects containing pose trajectories
- 2D keypoint positions in image coordinates
- 3D keypoint positions in camera coordinates
- Hand detection flags per frame
- Annotated visualization videos
"""

import glob
import os
import logging
from tqdm import tqdm
import numpy as np
import mediapy as media
import open3d as o3d  # type: ignore
from typing import Tuple, Optional, Dict, Any
import trimesh
from collections import defaultdict
import argparse

from phantom.utils.pcd_utils import get_visible_points, get_pcd_from_points, icp_registration, get_point_cloud_of_segmask, get_3D_points_from_pixels, remove_outliers, get_bbox_of_3d_points, trim_pcd_to_bbox, visualize_pcds
from phantom.utils.transform_utils import transform_pts
from phantom.processors.base_processor import BaseProcessor
from phantom.detectors.detector_hamer import DetectorHamer
from phantom.processors.phantom_data import HandSequence, HandFrame, hand_side_dict
from phantom.processors.paths import Paths
from phantom.processors.segmentation_processor import HandSegmentationProcessor

logger = logging.getLogger(__name__)

class HandBaseProcessor(BaseProcessor):
    """
    Base class for hand pose processing using HaMeR detection and optional depth refinement.

    The processor operates on the output of BBoxProcessor, using detected hand bounding boxes
    to guide pose estimation. It supports both 2D and 3D processing modes, with the 3D mode
    providing enhanced accuracy through depth sensor integration.

    Processing Workflow:
    1. Load video frames and bounding box detection results
    2. For each frame with detected hands:
       - Apply HaMeR pose estimation within bounding box
       - Validate pose quality (edge proximity, confidence)
       - Optionally generate hand segmentation masks for depth refinement
       - Optionally apply depth-based pose refinement
    3. Generate temporal hand sequences with smooth trajectories
    4. Save processed results and visualization videos

    Attributes:
        process_hand_masks (bool): Whether to generate hand segmentation masks
        apply_depth_alignment (bool): Whether to use depth-based pose refinement
        detector_hamer (DetectorHamer): HaMeR pose estimation model
        hand_mask_processor: Segmentation processor for hand mask generation
        H (int): Video frame height
        W (int): Video frame width
        imgs_depth (np.ndarray): Depth images for 3D refinement
        left_masks (np.ndarray): Left hand segmentation masks
        right_masks (np.ndarray): Right hand segmentation masks
    """
    def __init__(self, args: argparse.Namespace) -> None:
        """
        Initialize the hand processor with configuration parameters.

        Args:
            args: Command line arguments containing processing configuration
                including depth processing flags and model parameters
        """
        super().__init__(args)
        self.process_hand_masks: bool = False
        self._initialize_detectors()
        self.hand_mask_processor: Optional[HandSegmentationProcessor] = None
        self.apply_depth_alignment: bool = False

    def _initialize_detectors(self) -> None:
        """
        Initialize all required detection models.

        Sets up the HaMeR detector for hand pose estimation.
        """
        self.detector_hamer = DetectorHamer()

    def process_one_demo(self, data_sub_folder: str) -> None:
        """
        Process a single demonstration video to extract hand poses and segmentation.

        Args:
            data_sub_folder: Path to the demonstration data folder containing
                video files, bounding box data, and optional depth data
        """
        save_folder = self.get_save_folder(data_sub_folder)

        paths = self.get_paths(save_folder)

        # Load RGB video frames
        imgs_rgb = media.read_video(getattr(paths, f"video_left"))
        self.H, self.W, _ = imgs_rgb[0].shape

        # Load depth data if available (for 3D processing)
        if os.path.exists(paths.depth):
            self.imgs_depth = np.load(paths.depth)
        else:
            self.imgs_depth = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))

        # Load hand segmentation masks if available
        if os.path.exists(paths.masks_hand_left) and os.path.exists(paths.masks_hand_right):
            self.left_masks = np.load(paths.masks_hand_left)
            self.right_masks = np.load(paths.masks_hand_right)
        else:
            self.left_masks = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))
            self.right_masks = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))

        # Load bounding box detection results from previous stage
        bbox_data = np.load(paths.bbox_data)
        left_hand_detected = bbox_data["left_hand_detected"]
        right_hand_detected = bbox_data["right_hand_detected"]
        left_bboxes = bbox_data["left_bboxes"]
        right_bboxes = bbox_data["right_bboxes"]

        # Validate data consistency
        assert len(left_hand_detected) == len(right_hand_detected)
        assert len(left_hand_detected) == len(imgs_rgb)

        # Process left and right hand sequences
        left_sequence = self._process_all_frames(imgs_rgb, left_bboxes, left_hand_detected, "left")
        right_sequence = self._process_all_frames(imgs_rgb, right_bboxes, right_hand_detected, "right")

        # Generate hand segmentation masks if enabled
        if self.process_hand_masks:
            self._get_hand_masks(data_sub_folder, left_sequence, right_sequence)
            self.left_masks = np.load(paths.masks_hand_left)
            self.right_masks = np.load(paths.masks_hand_right)

        # Apply depth-based pose refinement if enabled
        if self.apply_depth_alignment:
            left_sequence = self._process_all_frames_depth_alignment(imgs_rgb, left_hand_detected, "left", left_sequence)
            right_sequence = self._process_all_frames_depth_alignment(imgs_rgb, right_hand_detected, "right", right_sequence)

        # Save processed sequences and generate visualizations
        self._save_results(paths, left_sequence, right_sequence)

    def _process_all_frames(self, imgs_rgb: np.ndarray, bboxes: np.ndarray,
                            hand_detections: np.ndarray, hand_side: str) -> HandSequence:
        """
        Process all frames in a video sequence to extract hand poses.

        This method iterates through all video frames, applying pose estimation
        where hands are detected and creating empty frames where they are not.
        It maintains temporal consistency and provides quality filtering.

        Args:
            imgs_rgb: RGB video frames, shape (num_frames, height, width, 3)
            bboxes: Hand bounding boxes per frame, shape (num_frames, 4)
            hand_detections: Boolean flags indicating valid detections per frame
            hand_side: "left" or "right" to specify which hand is being processed

        Returns:
            HandSequence object containing processed pose data for all frames
        """
        sequence = HandSequence()

        for img_idx in tqdm(range(len(imgs_rgb)), disable=False, leave=False):
            if not hand_detections[img_idx]:
                # Create empty frame for missing detections
                sequence.add_frame(HandFrame.create_empty_frame(
                    frame_idx=img_idx,
                    img_rgb=imgs_rgb[img_idx],
                ))
                continue

            # Process frame with detected hand
            frame_data = self._process_frame(img_idx, imgs_rgb[img_idx], bboxes[img_idx],
                                             hand_side)
            sequence.add_frame(frame_data)

        return sequence

    def _process_frame(self, img_idx: int, img_rgb: np.ndarray, bbox: np.ndarray,
                       hand_side: str, view: bool = False) -> HandFrame:
        """
        Process a single frame to extract hand pose and validate quality.

        This method applies HaMeR pose estimation within the detected bounding box
        and performs quality checks to ensure the pose is suitable for downstream
        processing. Poor quality poses (e.g., hands too close to image edges) are
        rejected to maintain data quality.

        Args:
            img_idx: Index of the current frame
            img_rgb: RGB image data for this frame
            bbox: Hand bounding box coordinates [x1, y1, x2, y2]
            hand_side: "left" or "right" specifying which hand is being processed
            view: Whether to display debug visualizations

        Returns:
            HandFrame object containing pose data or empty frame if quality is poor
        """
        try:
            # Apply HaMeR pose estimation within bounding box
            processed_data = self._process_image_with_hamer(img_rgb, bbox[None,...], hand_side, img_idx, view=view)

            # Quality check: reject poses where keypoints are too close to image edges
            if self.are_kpts_too_close_to_margin(processed_data["kpts_2d"], self.W, self.H, margin=5, threshold=0.1):
                logger.error(f"Error processing frame {img_idx}: Edge hand")
                return HandFrame.create_empty_frame(
                    frame_idx=img_idx,
                    img_rgb=img_rgb,
                )

            # Create frame with validated pose data
            frame_data = HandFrame(
                frame_idx=img_idx,
                hand_detected=True,
                img_rgb=img_rgb,
                img_hamer=processed_data["img_hamer"],
                kpts_2d=processed_data["kpts_2d"],
                kpts_3d=processed_data["kpts_3d"],
            )

            return frame_data

        except Exception as e:
            logger.error(f"Error processing frame {img_idx}: {str(e)}")
            return HandFrame.create_empty_frame(
                frame_idx=img_idx,
                img_rgb=img_rgb,
            )

    def are_kpts_too_close_to_margin(self, kpts_2d: np.ndarray, img_width: int, img_height: int,
                                     margin: int = 20, threshold: float = 0.5) -> bool:
        """
        Filter hand keypoints based on proximity to image edges.

        This quality check rejects hand poses where too many keypoints are near
        the image boundaries, which typically indicates partial occlusion or
        tracking errors that would lead to poor pose estimates.

        Args:
            kpts_2d: 2D keypoint positions, shape (N, 2) where N is number of keypoints
            img_width: Image width in pixels
            img_height: Image height in pixels
            margin: Distance from edge (in pixels) to consider "too close"
            threshold: Fraction of keypoints that triggers rejection (e.g., 0.5 = 50%)

        Returns:
            True if hand should be rejected due to edge proximity, False otherwise
        """
        x = kpts_2d[:, 0]
        y = kpts_2d[:, 1]

        # Create boolean mask for keypoints near any image edge
        near_edge = (
            (x < margin) |
            (y < margin) |
            (x > img_width - margin) |
            (y > img_height - margin)
        )

        frac_near_edge = np.mean(near_edge)  # Fraction of keypoints near edge
        return frac_near_edge > threshold

    def _save_results(self, paths: Paths, left_sequence: HandSequence, right_sequence: HandSequence) -> None:
        """
        Save processed hand sequences and generate visualization videos.

        Args:
            paths: Paths object containing output file locations
            left_sequence: Processed left hand pose sequence
            right_sequence: Processed right hand pose sequence
        """
        # Create output directory
        if not os.path.exists(getattr(paths, f"hand_processor")):
            os.makedirs(getattr(paths, f"hand_processor"))

        # Save hand sequence data in compressed format
        left_sequence.save(getattr(paths, f"hand_data_left"))
        right_sequence.save(getattr(paths, f"hand_data_right"))

        # Save RGB frames for reference
        media.write_video(getattr(paths, f"video_rgb_imgs"), left_sequence.imgs_rgb, fps=10, codec="ffv1")

        # Load additional visualization components
        imgs_bbox = media.read_video(getattr(paths, f"video_bboxes"))

        # Load segmentation visualization if available
        if os.path.exists(getattr(paths, f"video_sam_arm")):
            imgs_sam = media.read_video(getattr(paths, f"video_sam_arm"))
        else:
            imgs_sam = np.zeros((len(left_sequence.imgs_rgb), left_sequence.imgs_rgb[0].shape[0], left_sequence.imgs_rgb[0].shape[1], 3))

        # Create comprehensive annotation video showing all processing stages
        annot_imgs = []
        for idx in range(len(left_sequence.imgs_rgb)):
            img_hamer_left = left_sequence.imgs_hamer[idx]
            img_hamer_right = right_sequence.imgs_hamer[idx]
            img_bbox = imgs_bbox[idx]
            img_sam = imgs_sam[idx]

            # Combine visualizations in 2x2 grid: [bbox, sam] on top, [left_hand, right_hand] on bottom
            annot_img = np.vstack((np.hstack((img_bbox, img_sam)), np.hstack((img_hamer_left, img_hamer_right)))).astype(np.uint8)
            annot_imgs.append(annot_img)

        # Save comprehensive visualization video
        media.write_video(getattr(paths, f"video_annot"), np.array(annot_imgs), fps=10, codec="h264")  # mp4

    def _create_hand_mesh(self, hamer_out: Dict[str, Any]) -> trimesh.Trimesh:
        """
        Create a 3D triangle mesh from HaMeR pose estimation output.

        Args:
            hamer_out: HaMeR output dictionary containing vertex positions

        Returns:
            Trimesh object representing the hand mesh
        """
        return trimesh.Trimesh(hamer_out["verts"].copy(), self.detector_hamer.faces_left.copy(), process=False)

    def _get_hand_masks(self, data_sub_folder: str, hamer_data_left: HandSequence, hamer_data_right: HandSequence) -> None:
        """
        Generate hand segmentation masks using processed pose data.

        This method integrates with the segmentation processor to generate
        detailed hand masks that can be used for depth-based pose refinement.

        Args:
            data_sub_folder: Path to demonstration data folder
            hamer_data_left: Processed left hand sequence for guidance
            hamer_data_right: Processed right hand sequence for guidance
        """
        hamer_data = {
            "left": hamer_data_left,
            "right": hamer_data_right
        }
        self.hand_mask_processor.process_one_demo(data_sub_folder, hamer_data)

    @staticmethod
    def _get_visible_pts_from_hamer(detector_hamer: DetectorHamer, hamer_out: Dict[str, Any], mesh: trimesh.Trimesh,
                                    img_depth: np.ndarray, cam_intrinsics: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
        """
        Identify visible hand vertices and their corresponding depth points.

        Args:
            detector_hamer: HaMeR detector instance for coordinate projections
            hamer_out: HaMeR output containing pose estimates and camera parameters
            mesh: 3D hand mesh generated from HaMeR output
            img_depth: Depth image corresponding to the RGB frame
            cam_intrinsics: Camera intrinsic parameters for 3D projection

        Returns:
            Tuple of (visible_points_3d, visible_hamer_vertices):
                - visible_points_3d: 3D points from depth image at visible mesh locations
                - visible_hamer_vertices: Corresponding vertices from the HaMeR mesh
        """
        # Perform ray-casting to identify visible mesh vertices
        visible_hamer_vertices, _ = get_visible_points(mesh, origin=np.array([0,0,0]))

        # Project 3D vertices to 2D image coordinates
        visible_points_2d = detector_hamer.project_3d_kpt_to_2d(
            (visible_hamer_vertices-hamer_out["T_cam_pred"].cpu().numpy()).astype(np.float32),
            hamer_out["img_w"], hamer_out["img_h"], hamer_out["scaled_focal_length"],
            hamer_out["camera_center"], hamer_out["T_cam_pred"])

        # Filter out points that fall outside the depth image boundaries
        original_visible_points_2d = visible_points_2d.copy()

        # Create valid region mask (note: depth indexing is [y, x])
        valid_mask = ((original_visible_points_2d[:, 0] < img_depth.shape[1]) &
                      (original_visible_points_2d[:, 1] < img_depth.shape[0]))

        visible_points_2d = visible_points_2d[valid_mask]
        visible_hamer_vertices = visible_hamer_vertices[valid_mask]

        # Convert 2D depth pixels to 3D points using camera intrinsics
        visible_points_3d = get_3D_points_from_pixels(visible_points_2d, img_depth, cam_intrinsics)

        return visible_points_3d, visible_hamer_vertices

    @staticmethod
    def _get_transformation_estimate(visible_points_3d: np.ndarray,
                                     visible_hamer_vertices: np.ndarray,
                                     pcd: o3d.geometry.PointCloud) -> Tuple[np.ndarray, o3d.geometry.PointCloud]:
        """
        Estimate transformation to align HaMeR mesh with observed point cloud.

        This method uses Iterative Closest Point (ICP) registration to find the
        optimal transformation that aligns the visible parts of the predicted
        hand mesh with the point cloud extracted from depth and segmentation data.

        Args:
            visible_points_3d: 3D points from depth image at mesh locations
            visible_hamer_vertices: Corresponding vertices from HaMeR mesh
            pcd: Point cloud from segmentation and depth data

        Returns:
            Tuple of (transformation_matrix, aligned_mesh_pointcloud):
                - transformation_matrix: 4x4 transformation to align mesh with depth
                - aligned_mesh_pointcloud: Transformed mesh point cloud after alignment
        """
        # Get initial transformation estimate using median translation
        T_0 = HandBaseProcessor._get_initial_transformation_estimate(visible_points_3d, visible_hamer_vertices)

        # Create point cloud from visible mesh vertices
        visible_hamer_pcd = get_pcd_from_points(visible_hamer_vertices, colors=np.ones_like(visible_hamer_vertices) * [0, 1, 0])

        try:
            # Apply ICP registration for fine alignment
            aligned_hamer_pcd, T = icp_registration(visible_hamer_pcd, pcd, voxel_size=0.005, init_transform=T_0)
        except Exception as e:
            logger.error(f"ICP registration failed: {e}")
            return T_0, visible_hamer_pcd

        return T, aligned_hamer_pcd

    @staticmethod
    def _get_initial_transformation_estimate(visible_points_3d: np.ndarray,
                                             visible_hamer_vertices: np.ndarray) -> np.ndarray:
        """
        Compute initial transformation estimate for mesh-to-depth alignment.

        This method provides a coarse alignment between the HaMeR prediction and
        the depth-based point cloud using median translation. It assumes that
        orientation is approximately correct and only translation correction is needed.

        Args:
            visible_points_3d: 3D points from depth image
            visible_hamer_vertices: Corresponding HaMeR mesh vertices

        Returns:
            4x4 transformation matrix with estimated translation
        """
        # Calculate median translation between corresponding point sets
        translation = np.nanmedian(visible_points_3d - visible_hamer_vertices, axis=0)

        # Create transformation matrix (identity rotation + translation)
        T_0 = np.eye(4)
        if not np.isnan(translation).any():
            T_0[:3, 3] = translation

        return T_0


class Hand2DProcessor(HandBaseProcessor):
    """
    2D hand pose processor optimized for speed and RGB-only operation.

    This processor focuses on extracting 2D hand poses and basic 3D estimates
    without depth sensor integration. It's designed for applications where
    depth sensors are not available.
    """
    def __init__(self, args: argparse.Namespace) -> None:
        """
        Initialize 2D hand processor with RGB-only configuration.

        Args:
            args: Command line arguments for processor configuration
        """
        super().__init__(args)

    def _process_image_with_hamer(self, img_rgb: np.ndarray, bboxes: np.ndarray, hand_side: str,
                                  img_idx: int, view: bool = False) -> Dict[str, Any]:
        """
        Process RGB image with HaMeR for 2D pose estimation.

        Args:
            img_rgb: RGB image to process
            bboxes: Hand bounding boxes for pose estimation guidance
            hand_side: "left" or "right" specifying which hand to process
            img_idx: Frame index for debugging and logging
            view: Whether to display debug visualizations

        Returns:
            Dictionary containing:
                - img_hamer: Annotated image with pose visualization
                - kpts_3d: Estimated 3D keypoints
                - kpts_2d: 2D keypoint projections in image coordinates

        Raises:
            ValueError: If no valid hand pose is detected in the image
        """
        # Configure HaMeR for target hand side
        is_right = np.array([hand_side_dict[str(hand_side)]*True]*len(bboxes))

        # Apply HaMeR pose estimation
        hamer_out = self.detector_hamer.detect_hand_keypoints(
            img_rgb,
            hand_side=hand_side,
            bboxes=bboxes,
            is_right=is_right,
            camera_params=self.intrinsics_dict,
            visualize=False
        )

        if hamer_out is None or not hamer_out.get("success", False):
            raise ValueError("No hand detected in image")

        return {
            "img_hamer": hamer_out["annotated_img"][:,:,::-1],  # Convert BGR to RGB
            "kpts_3d": hamer_out["kpts_3d"],
            "kpts_2d": hamer_out['kpts_2d']
        }

class Hand3DProcessor(HandBaseProcessor):
    """
    3D hand pose processor with depth-based refinement capabilities.

    This processor provides more accurate 3D hand poses by combining HaMeR
    estimation with depth sensor data and hand segmentation. It uses point cloud
    registration techniques to refine the initial pose estimates, resulting in
    poses that are better aligned with the physical environment.

    Processing Enhancements:
    - Mesh generation from HaMeR output for visibility analysis
    - Hand segmentation using SAM2 for accurate depth extraction
    - ICP-based alignment between predicted mesh and observed point cloud
    """
    def __init__(self, args: argparse.Namespace) -> None:
        """
        Initialize 3D hand processor with depth refinement capabilities.

        Args:
            args: Command line arguments containing depth processing configuration
        """
        super().__init__(args)
        self.args = args

        # Storage for HaMeR outputs needed for depth alignment
        self.hamer_out_dict: Dict[str, Dict[int, Dict[str, Any]]] = {
            "left": defaultdict(dict),
            "right": defaultdict(dict)
        }

        # Enable advanced processing features
        self.process_hand_masks = True
        self.apply_depth_alignment = True
        self.hand_mask_processor = HandSegmentationProcessor(self.args)

    def _process_image_with_hamer(self, img_rgb: np.ndarray, bboxes: np.ndarray, hand_side: str,
                                  img_idx: int, view: bool = False) -> Dict[str, Any]:
        """
        Process RGB image with HaMeR optimized for subsequent depth refinement.

        This method applies HaMeR pose estimation configured for 3D processing,
        storing intermediate results needed for later depth-based refinement.

        Args:
            img_rgb: RGB image to process
            bboxes: Hand bounding boxes for pose estimation guidance
            hand_side: "left" or "right" specifying which hand to process
            img_idx: Frame index for result storage and debugging
            view: Whether to display debug visualizations

        Returns:
            Dictionary containing pose estimation results

        Raises:
            ValueError: If no valid hand pose is detected in the image
        """
        # Configure HaMeR for target hand side
        is_right = np.array([hand_side_dict[str(hand_side)]*True]*len(bboxes))

        # Apply HaMeR with 2D keypoint focus (3D refinement happens later)
        hamer_out = self.detector_hamer.detect_hand_keypoints(
            img_rgb,
            hand_side=hand_side,
            bboxes=bboxes,
            is_right=is_right,
            kpts_2d_only=True,  # Initial processing focuses on 2D
            camera_params=self.intrinsics_dict
        )

        if hamer_out is None or not hamer_out.get("success", False):
            raise ValueError("No hand detected in image")

        # Store HaMeR output for later depth alignment processing
        self.hamer_out_dict[hand_side][img_idx] = hamer_out

        return {
            "img_hamer": hamer_out["annotated_img"][:,:,::-1],  # Convert BGR to RGB
            "kpts_3d": hamer_out["kpts_3d"],
            "kpts_2d": hamer_out['kpts_2d']
        }

    def _process_all_frames_depth_alignment(self, imgs_rgb: np.ndarray, hand_detections: np.ndarray,
                                            hand_side: str, sequence: Optional[HandSequence] = None) -> HandSequence:
        """
        Apply depth-based refinement to all frames in the sequence.

        This method performs the depth alignment stage of processing, using
        segmentation masks and depth data to refine the initial HaMeR pose
        estimates for improved 3D accuracy.

        Args:
            imgs_rgb: RGB video frames for reference
            hand_detections: Boolean flags indicating frames with valid detections
            hand_side: "left" or "right" specifying which hand to process
            sequence: HandSequence containing initial pose estimates to refine

        Returns:
            HandSequence with refined 3D poses aligned to depth data
        """
        for img_idx in tqdm(range(len(imgs_rgb)), disable=False, leave=False):
            if not hand_detections[img_idx]:
                continue

            # Apply depth-based refinement to this frame
            frame_data = sequence.get_frame(img_idx)
            frame_data.kpts_3d = self._depth_alignment(img_idx, hand_side, imgs_rgb[img_idx])
            sequence.modify_frame(img_idx, frame_data)

        return sequence

    def _depth_alignment(self, img_idx: int, hand_side: str, img_rgb: np.ndarray) -> np.ndarray:
        """
        Perform depth-based pose refinement for a single frame.

        Algorithm Steps:
        1. Extract depth image and segmentation mask for the frame
        2. Obtain 3D hand mesh from HaMeR output
        3. Create point cloud from segmented depth region
        4. Identify visible mesh vertices through ray casting
        5. Apply ICP registration between mesh and point cloud
        6. Transform original keypoints using computed alignment

        Args:
            img_idx: Index of the frame to process
            hand_side: "left" or "right" specifying which hand to process
            img_rgb: RGB image for reference (used in point cloud generation)

        Returns:
            Refined 3D keypoint positions aligned with depth data
        """
        # Load frame-specific data
        img_depth = self.imgs_depth[img_idx]
        mask = self.left_masks[img_idx] if hand_side == "left" else self.right_masks[img_idx]
        hamer_out = self.hamer_out_dict[hand_side][img_idx]

        # Create 3D hand mesh from HaMeR pose estimate
        mesh = self._create_hand_mesh(hamer_out)

        # Generate point cloud from depth image within segmented hand region
        pcd = get_point_cloud_of_segmask(mask, img_depth, img_rgb, self.intrinsics_dict, visualize=False)

        # Identify visible mesh vertices and corresponding depth points
        visible_points_3d, visible_hamer_vertices = self._get_visible_pts_from_hamer(
            self.detector_hamer,
            hamer_out,
            mesh,
            img_depth,
            self.intrinsics_dict
        )

        # Compute optimal transformation using ICP registration
        T, _ = self._get_transformation_estimate(visible_points_3d, visible_hamer_vertices, pcd)

        # Apply transformation to refine original keypoint positions
        kpts_3d = transform_pts(hamer_out["kpts_3d"], T)

        return kpts_3d
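(Note on the depth-alignment step above: it reduces to a coarse median-translation guess followed by ICP refinement. The sketch below reproduces that two-step idea directly with open3d on made-up arrays; the repo's own helpers, get_pcd_from_points and icp_registration, are assumed to wrap roughly this logic, and all data here is hypothetical.)

    import numpy as np
    import open3d as o3d

    # Hypothetical stand-in data: predicted mesh vertices and the depth points observed for them.
    hamer_vertices = np.random.rand(500, 3).astype(np.float64)
    depth_points = hamer_vertices + np.array([0.02, -0.01, 0.30])      # mostly a rigid offset
    depth_points += np.random.normal(scale=1e-3, size=depth_points.shape)

    # Step 1: coarse alignment from the median per-axis translation (identity rotation).
    T0 = np.eye(4)
    T0[:3, 3] = np.nanmedian(depth_points - hamer_vertices, axis=0)

    # Step 2: point-to-point ICP refinement with open3d, seeded by T0.
    src = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(hamer_vertices))
    tgt = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(depth_points))
    result = o3d.pipelines.registration.registration_icp(
        src, tgt, max_correspondence_distance=0.01, init=T0,
        estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint())
    T = result.transformation

    # Apply the refined rigid transform to keypoints (same role as transform_pts above).
    kpts = np.random.rand(21, 3)
    kpts_refined = (T[:3, :3] @ kpts.T).T + T[:3, 3]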
phantom/phantom/processors/handinpaint_processor.py
ADDED
@@ -0,0 +1,485 @@
| 1 |
+
"""
|
| 2 |
+
Hand Inpainting Processor Module
|
| 3 |
+
|
| 4 |
+
This module removes human hands from demonstration videos using the E2FGVI model.
|
| 5 |
+
|
| 6 |
+
Paper:
|
| 7 |
+
Towards An End-to-End Framework for Flow-Guided Video Inpainting
|
| 8 |
+
https://github.com/MCG-NKU/E2FGVI.git
|
| 9 |
+
|
| 10 |
+
Processing Pipeline:
|
| 11 |
+
1. Load pre-trained E2FGVI model and initialize GPU processing
|
| 12 |
+
2. Read input video frames and corresponding hand segmentation masks
|
| 13 |
+
3. Process frames in batches with neighboring temporal context
|
| 14 |
+
4. Apply mask-guided inpainting to remove hand regions
|
| 15 |
+
5. Verify complete processing and handle any missed frames
|
| 16 |
+
6. Save final hand-free video for robot learning applications
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import cv2
|
| 20 |
+
from PIL import Image
|
| 21 |
+
import numpy as np
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
import torch
|
| 26 |
+
import mediapy as media
|
| 27 |
+
import logging
|
| 28 |
+
import gc
|
| 29 |
+
from typing import List, Tuple, Optional, Any, Union
|
| 30 |
+
|
| 31 |
+
from phantom.processors.base_processor import BaseProcessor
|
| 32 |
+
from phantom.utils.data_utils import get_parent_folder_of_package
|
| 33 |
+
from E2FGVI.model.e2fgvi_hq import InpaintGenerator # type: ignore
|
| 34 |
+
from E2FGVI.core.utils import to_tensors # type: ignore
|
| 35 |
+
|
| 36 |
+
DEFAULT_CHECKPOINT = 'E2FGVI/release_model/E2FGVI-HQ-CVPR22.pth'
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
|
| 40 |
+
class HandInpaintProcessor(BaseProcessor):
|
| 41 |
+
"""
|
| 42 |
+
Hand inpainting processor for removing human hands from demonstration videos.
|
| 43 |
+
|
| 44 |
+
Attributes:
|
| 45 |
+
model: E2FGVI neural network model for video inpainting
|
| 46 |
+
device: GPU/CPU device for model execution
|
| 47 |
+
ref_length (int): Spacing between reference frames for temporal consistency
|
| 48 |
+
num_ref (int): Number of reference frames to use (-1 for automatic)
|
| 49 |
+
neighbor_stride (int): Spacing between neighboring frames in temporal context
|
| 50 |
+
batch_size (int): Number of frame groups to process simultaneously
|
| 51 |
+
scale_factor (int): Resolution scaling factor for processing optimization
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(self, args: Any) -> None:
|
| 55 |
+
"""
|
| 56 |
+
Initialize the hand inpainting processor with E2FGVI model and parameters.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
args: Command line arguments containing processing configuration
|
| 60 |
+
including scale factor and other inpainting parameters
|
| 61 |
+
"""
|
| 62 |
+
super().__init__(args)
|
| 63 |
+
|
| 64 |
+
# Load pre-trained E2FGVI model
|
| 65 |
+
root_dir = get_parent_folder_of_package("E2FGVI")
|
| 66 |
+
checkpoint_path = Path(root_dir, DEFAULT_CHECKPOINT)
|
| 67 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 68 |
+
|
| 69 |
+
# Initialize and load the inpainting model
|
| 70 |
+
self.model = InpaintGenerator().to(self.device)
|
| 71 |
+
data = torch.load(checkpoint_path, map_location=self.device)
|
| 72 |
+
self.model.load_state_dict(data)
|
| 73 |
+
self.model.eval()
|
| 74 |
+
|
| 75 |
+
# Configure temporal processing parameters
|
| 76 |
+
self.ref_length: int = 20 # Spacing between reference frames
|
| 77 |
+
self.num_ref: int = -1 # Number of reference frames (-1 = automatic)
|
| 78 |
+
self.neighbor_stride: int = 5 # Stride for neighboring frame selection
|
| 79 |
+
|
| 80 |
+
# Configure batch processing parameters for memory optimization
|
| 81 |
+
self.batch_size: int = 10 # Number of frame groups per batch
|
| 82 |
+
self.scale_factor: int = getattr(args, 'scale_factor', 2) # Resolution scaling
|
| 83 |
+
|
| 84 |
+
def _clear_gpu_memory(self) -> None:
|
| 85 |
+
"""Clear GPU memory cache and trigger garbage collection."""
|
| 86 |
+
torch.cuda.empty_cache()
|
| 87 |
+
gc.collect()
|
| 88 |
+
|
| 89 |
+
def process_one_demo(self, data_sub_folder: str) -> None:
|
| 90 |
+
"""
|
| 91 |
+
Process a single demonstration video to remove hand regions.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
data_sub_folder: Path to demonstration data folder containing
|
| 95 |
+
input video and hand segmentation masks
|
| 96 |
+
"""
|
| 97 |
+
save_folder = self.get_save_folder(data_sub_folder)
|
| 98 |
+
paths = self.get_paths(save_folder)
|
| 99 |
+
if not os.path.exists(paths.inpaint_processor):
|
| 100 |
+
os.makedirs(paths.inpaint_processor)
|
| 101 |
+
|
| 102 |
+
self._process_frames(paths)
|
| 103 |
+
|
| 104 |
+
def _process_frames(self, paths: Any) -> None:
|
| 105 |
+
"""
|
| 106 |
+
Process all video frames to remove hand regions using E2FGVI inpainting.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
paths: Paths object containing input video and mask file locations
|
| 110 |
+
"""
|
| 111 |
+
# Load and prepare video frames
|
| 112 |
+
frames = self._load_and_prepare_frames(paths)
|
| 113 |
+
video_length = len(frames)
|
| 114 |
+
logger.info(f"Processing {video_length} frames")
|
| 115 |
+
|
| 116 |
+
# Initialize tracking arrays for processed frames
|
| 117 |
+
comp_frames: List[Optional[np.ndarray]] = [None] * video_length
|
| 118 |
+
processed_frame_mask: List[bool] = [False] * video_length
|
| 119 |
+
|
| 120 |
+
# Process frames in batches with temporal overlap for consistency
|
| 121 |
+
self._process_frames_in_batches(frames, paths, comp_frames, processed_frame_mask)
|
| 122 |
+
|
| 123 |
+
# Handle any missed frames
|
| 124 |
+
self._process_missed_frames(frames, paths, comp_frames, processed_frame_mask)
|
| 125 |
+
|
| 126 |
+
# Final verification and save
|
| 127 |
+
self._verify_and_save_results(comp_frames, paths)
|
| 128 |
+
|
| 129 |
+
def _load_and_prepare_frames(self, paths: Any) -> List[Image.Image]:
|
| 130 |
+
"""Load video frames and prepare them for processing."""
|
| 131 |
+
frames = self.read_frame_from_videos(paths.video_rgb_imgs)
|
| 132 |
+
|
| 133 |
+
# Calculate output dimensions based on configuration
|
| 134 |
+
h, w = frames[0].height, frames[0].width
|
| 135 |
+
|
| 136 |
+
if self.epic:
|
| 137 |
+
size = (w, h)
|
| 138 |
+
else:
|
| 139 |
+
if self.square:
|
| 140 |
+
output_resolution = np.array([self.output_resolution, self.output_resolution])
|
| 141 |
+
else:
|
| 142 |
+
output_resolution = np.array([int(w/h*self.output_resolution), self.output_resolution])
|
| 143 |
+
output_resolution = output_resolution.astype(np.int32)
|
| 144 |
+
size = output_resolution
|
| 145 |
+
frames, size = self.resize_frames(frames, size)
|
| 146 |
+
|
| 147 |
+
return frames
|
| 148 |
+
|
| 149 |
+
def _process_frames_in_batches(self, frames: List[Image.Image], paths: Any,
|
| 150 |
+
comp_frames: List[Optional[np.ndarray]],
|
| 151 |
+
processed_frame_mask: List[bool]) -> None:
|
| 152 |
+
"""Process frames in batches with temporal overlap."""
|
| 153 |
+
video_length = len(frames)
|
| 154 |
+
h, w = frames[0].height, frames[0].width
|
| 155 |
+
|
| 156 |
+
for batch_start in tqdm(range(0, video_length, self.batch_size * self.neighbor_stride),
|
| 157 |
+
desc="Processing batches"):
|
| 158 |
+
batch_end = min(batch_start + self.batch_size * self.neighbor_stride + self.neighbor_stride, video_length)
|
| 159 |
+
|
| 160 |
+
# Prepare batch data
|
| 161 |
+
batch_data = self._prepare_batch_data(frames, paths, batch_start, batch_end, h, w)
|
| 162 |
+
|
| 163 |
+
# Process frames within batch
|
| 164 |
+
self._process_batch_frames(frames, batch_data, batch_start, batch_end,
|
| 165 |
+
comp_frames, processed_frame_mask, h, w)
|
| 166 |
+
|
| 167 |
+
# Clean up batch memory
|
| 168 |
+
del batch_data['batch_imgs'], batch_data['batch_masks']
|
| 169 |
+
self._clear_gpu_memory()
|
| 170 |
+
|
| 171 |
+
def _prepare_batch_data(self, frames: List[Image.Image], paths: Any,
|
| 172 |
+
batch_start: int, batch_end: int, h: int, w: int) -> dict:
|
| 173 |
+
"""Prepare batch data including frames, masks, and binary masks."""
|
| 174 |
+
batch_frames = frames[batch_start:batch_end]
|
| 175 |
+
batch_imgs = to_tensors()(batch_frames).unsqueeze(0).to(self.device) * 2 - 1
|
| 176 |
+
|
| 177 |
+
batch_masks = self.read_mask(paths.masks_arm, (w, h))[batch_start:batch_end]
|
| 178 |
+
batch_masks = to_tensors()(batch_masks).unsqueeze(0).to(self.device)
|
| 179 |
+
|
| 180 |
+
binary_masks = self._create_binary_masks(paths.masks_arm, batch_start, batch_end, w, h)
|
| 181 |
+
|
| 182 |
+
return {
|
| 183 |
+
'batch_imgs': batch_imgs,
|
| 184 |
+
'batch_masks': batch_masks,
|
| 185 |
+
'binary_masks': binary_masks
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
def _create_binary_masks(self, mask_path: str, batch_start: int, batch_end: int,
|
| 189 |
+
w: int, h: int) -> List[np.ndarray]:
|
| 190 |
+
"""Create binary masks for the batch."""
|
| 191 |
+
masks = self.read_mask(mask_path, (w, h))[batch_start:batch_end]
|
| 192 |
+
binary_masks = []
|
| 193 |
+
|
| 194 |
+
for mask in masks:
|
| 195 |
+
mask_array = np.array(mask)
|
| 196 |
+
binary_mask = np.expand_dims((mask_array != 0).astype(np.uint8), 2)
|
| 197 |
+
binary_mask = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
|
| 198 |
+
binary_mask = np.expand_dims(binary_mask, 2)
|
| 199 |
+
binary_masks.append(binary_mask)
|
| 200 |
+
|
| 201 |
+
return binary_masks
|
| 202 |
+
|
| 203 |
+
def _process_batch_frames(self, frames: List[Image.Image], batch_data: dict,
|
| 204 |
+
batch_start: int, batch_end: int,
|
| 205 |
+
comp_frames: List[Optional[np.ndarray]],
|
| 206 |
+
processed_frame_mask: List[bool], h: int, w: int) -> None:
|
| 207 |
+
"""Process individual frames within a batch."""
|
| 208 |
+
stride = max(1, self.neighbor_stride if batch_start + self.batch_size * self.neighbor_stride < len(frames) else 1)
|
| 209 |
+
|
| 210 |
+
for frame_idx in range(batch_start, batch_end, stride):
|
| 211 |
+
neighbor_ids = self._get_neighbor_ids(frame_idx, batch_start, batch_end)
|
| 212 |
+
ref_ids = self.get_ref_index(frame_idx, neighbor_ids, batch_end)
|
| 213 |
+
|
| 214 |
+
if not neighbor_ids:
|
| 215 |
+
continue
|
| 216 |
+
|
| 217 |
+
# Convert to batch-relative indices
|
| 218 |
+
batch_neighbor_ids = [i - batch_start for i in neighbor_ids]
|
| 219 |
+
batch_ref_ids = [i - batch_start for i in ref_ids if batch_start <= i < batch_end]
|
| 220 |
+
|
| 221 |
+
# Process frame with temporal context
|
| 222 |
+
self._process_single_frame(frames, batch_data, neighbor_ids, batch_neighbor_ids,
|
| 223 |
+
batch_ref_ids, comp_frames, processed_frame_mask, h, w)
|
| 224 |
+
|
| 225 |
+
self._clear_gpu_memory()
|
| 226 |
+
|
| 227 |
+
def _get_neighbor_ids(self, frame_idx: int, batch_start: int, batch_end: int) -> List[int]:
|
| 228 |
+
"""Get neighboring frame indices for temporal context."""
|
| 229 |
+
return list(range(
|
| 230 |
+
max(batch_start, frame_idx - self.neighbor_stride),
|
| 231 |
+
min(batch_end, frame_idx + self.neighbor_stride + 1)
|
| 232 |
+
))
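For intuition, the window above is just a symmetric index range around the current frame, clamped to the batch boundaries. A minimal standalone sketch (the neighbor_stride value of 5 is an illustrative assumption, not the project default):

def neighbor_window(frame_idx, batch_start, batch_end, neighbor_stride=5):
    # Symmetric window around frame_idx, clamped to [batch_start, batch_end)
    return list(range(max(batch_start, frame_idx - neighbor_stride),
                      min(batch_end, frame_idx + neighbor_stride + 1)))

print(neighbor_window(12, 10, 40))  # [10, 11, 12, 13, 14, 15, 16, 17]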
|
| 233 |
+
|
| 234 |
+
def _process_single_frame(self, frames: List[Image.Image], batch_data: dict,
|
| 235 |
+
neighbor_ids: List[int], batch_neighbor_ids: List[int],
|
| 236 |
+
batch_ref_ids: List[int], comp_frames: List[Optional[np.ndarray]],
|
| 237 |
+
processed_frame_mask: List[bool], h: int, w: int) -> None:
|
| 238 |
+
"""Process a single frame with its temporal context."""
|
| 239 |
+
batch_start = neighbor_ids[0] - batch_neighbor_ids[0]
|
| 240 |
+
|
| 241 |
+
# Select relevant frames and masks
|
| 242 |
+
selected_imgs = batch_data['batch_imgs'][:, batch_neighbor_ids + batch_ref_ids, :, :, :]
|
| 243 |
+
selected_masks = batch_data['batch_masks'][:, batch_neighbor_ids + batch_ref_ids, :, :]
|
| 244 |
+
|
| 245 |
+
with torch.no_grad():
|
| 246 |
+
# Apply masks and generate inpainted frames
|
| 247 |
+
masked_imgs = selected_imgs * (1 - selected_masks)
|
| 248 |
+
masked_imgs = self._pad_images(masked_imgs, h, w)
|
| 249 |
+
|
| 250 |
+
pred_imgs, _ = self.model(masked_imgs, len(batch_neighbor_ids))
|
| 251 |
+
pred_imgs = (pred_imgs[:, :, :h, :w] + 1) / 2
|
| 252 |
+
pred_imgs = (pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
|
| 253 |
+
|
| 254 |
+
# Composite with original background
|
| 255 |
+
for i, idx in enumerate(neighbor_ids):
|
| 256 |
+
binary_mask = batch_data['binary_masks'][idx - batch_start]
|
| 257 |
+
original_frame = np.array(frames[idx])
|
| 258 |
+
|
| 259 |
+
inpainted_frame = (pred_imgs[i] * binary_mask +
|
| 260 |
+
original_frame * (1 - binary_mask))
|
| 261 |
+
|
| 262 |
+
# Average with previous results if frame was already processed
|
| 263 |
+
if comp_frames[idx] is None:
|
| 264 |
+
comp_frames[idx] = inpainted_frame
|
| 265 |
+
else:
|
| 266 |
+
comp_frames[idx] = ((comp_frames[idx].astype(np.float32) +
|
| 267 |
+
inpainted_frame.astype(np.float32)) / 2).astype(np.uint8)
|
| 268 |
+
processed_frame_mask[idx] = True
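Because consecutive batches overlap, the same frame can be inpainted twice; the branch above keeps the first result and averages in the second. A minimal sketch of that compositing rule on plain NumPy arrays (shapes and dtypes here are illustrative assumptions):

import numpy as np

def composite(pred, original, binary_mask, previous=None):
    # Paste the prediction inside the mask, keep the original background outside it,
    # and average with an earlier result for the same frame if one exists.
    blended = pred * binary_mask + original * (1 - binary_mask)
    if previous is None:
        return blended.astype(np.uint8)
    return ((previous.astype(np.float32) + blended.astype(np.float32)) / 2).astype(np.uint8)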
|
| 269 |
+
|
| 270 |
+
def _process_missed_frames(self, frames: List[Image.Image], paths: Any,
|
| 271 |
+
comp_frames: List[Optional[np.ndarray]],
|
| 272 |
+
processed_frame_mask: List[bool]) -> None:
|
| 273 |
+
"""Process any frames that were missed during batch processing."""
|
| 274 |
+
unprocessed_frames = [i for i, processed in enumerate(processed_frame_mask) if not processed]
|
| 275 |
+
|
| 276 |
+
if not unprocessed_frames:
|
| 277 |
+
return
|
| 278 |
+
|
| 279 |
+
logger.warning(f"Found {len(unprocessed_frames)} unprocessed frames at indices: {unprocessed_frames}")
|
| 280 |
+
|
| 281 |
+
# Determine processing context for missed frames
|
| 282 |
+
start_idx, end_idx = self._get_missed_frame_context(unprocessed_frames, processed_frame_mask, len(frames))
|
| 283 |
+
|
| 284 |
+
logger.info(f"Processing missed frames from {start_idx} to {end_idx}")
|
| 285 |
+
self._process_missed_frame_sequence(frames, paths, unprocessed_frames,
|
| 286 |
+
start_idx, end_idx, comp_frames, processed_frame_mask)
|
| 287 |
+
|
| 288 |
+
def _get_missed_frame_context(self, unprocessed_frames: List[int],
|
| 289 |
+
processed_frame_mask: List[bool], video_length: int) -> Tuple[int, int]:
|
| 290 |
+
"""Get the context range for processing missed frames."""
|
| 291 |
+
last_processed_idx = max([i for i, processed in enumerate(processed_frame_mask[:unprocessed_frames[0]])
|
| 292 |
+
if processed], default=-1)
|
| 293 |
+
if last_processed_idx == -1:
|
| 294 |
+
last_processed_idx = 0
|
| 295 |
+
|
| 296 |
+
next_processed_idx = min([i for i, processed in enumerate(processed_frame_mask[unprocessed_frames[-1]:],
|
| 297 |
+
start=unprocessed_frames[-1]) if processed], default=video_length)
|
| 298 |
+
|
| 299 |
+
start_idx = max(0, last_processed_idx - self.neighbor_stride)
|
| 300 |
+
end_idx = min(video_length, next_processed_idx + self.neighbor_stride)
|
| 301 |
+
|
| 302 |
+
return start_idx, end_idx
|
| 303 |
+
|
| 304 |
+
def _process_missed_frame_sequence(self, frames: List[Image.Image], paths: Any,
|
| 305 |
+
unprocessed_frames: List[int], start_idx: int, end_idx: int,
|
| 306 |
+
comp_frames: List[Optional[np.ndarray]],
|
| 307 |
+
processed_frame_mask: List[bool]) -> None:
|
| 308 |
+
"""Process the sequence containing missed frames."""
|
| 309 |
+
h, w = frames[0].height, frames[0].width
|
| 310 |
+
|
| 311 |
+
# Prepare sequence data
|
| 312 |
+
batch_frames = frames[start_idx:end_idx]
|
| 313 |
+
batch_imgs = to_tensors()(batch_frames).unsqueeze(0).to(self.device) * 2 - 1
|
| 314 |
+
|
| 315 |
+
batch_masks = self.read_mask(paths.masks_arm, (w, h))[start_idx:end_idx]
|
| 316 |
+
batch_masks = to_tensors()(batch_masks).unsqueeze(0).to(self.device)
|
| 317 |
+
|
| 318 |
+
binary_masks = self._create_binary_masks(paths.masks_arm, start_idx, end_idx, w, h)
|
| 319 |
+
|
| 320 |
+
# Process each missed frame
|
| 321 |
+
for idx in tqdm(unprocessed_frames, desc="Processing missed frames"):
|
| 322 |
+
self._process_missed_single_frame(frames, batch_imgs, batch_masks, binary_masks,
|
| 323 |
+
idx, start_idx, end_idx, comp_frames, processed_frame_mask, h, w)
|
| 324 |
+
|
| 325 |
+
del batch_imgs, batch_masks
|
| 326 |
+
self._clear_gpu_memory()
|
| 327 |
+
|
| 328 |
+
def _process_missed_single_frame(self, frames: List[Image.Image], batch_imgs: torch.Tensor,
|
| 329 |
+
batch_masks: torch.Tensor, binary_masks: List[np.ndarray],
|
| 330 |
+
frame_idx: int, start_idx: int, end_idx: int,
|
| 331 |
+
comp_frames: List[Optional[np.ndarray]],
|
| 332 |
+
processed_frame_mask: List[bool], h: int, w: int) -> None:
|
| 333 |
+
"""Process a single missed frame."""
|
| 334 |
+
relative_start = frame_idx - start_idx
|
| 335 |
+
neighbor_ids = list(range(
|
| 336 |
+
max(0, relative_start - self.neighbor_stride),
|
| 337 |
+
min(end_idx - start_idx, relative_start + self.neighbor_stride + 1)
|
| 338 |
+
))
|
| 339 |
+
ref_ids = self.get_ref_index(relative_start, neighbor_ids, end_idx - start_idx)
|
| 340 |
+
|
| 341 |
+
with torch.no_grad():
|
| 342 |
+
selected_imgs = batch_imgs[:, neighbor_ids + ref_ids, :, :, :]
|
| 343 |
+
selected_masks = batch_masks[:, neighbor_ids + ref_ids, :, :]
|
| 344 |
+
|
| 345 |
+
masked_imgs = selected_imgs * (1 - selected_masks)
|
| 346 |
+
masked_imgs = self._pad_images(masked_imgs, h, w)
|
| 347 |
+
|
| 348 |
+
pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
|
| 349 |
+
pred_imgs = (pred_imgs[:, :, :h, :w] + 1) / 2
|
| 350 |
+
pred_imgs = (pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
|
| 351 |
+
|
| 352 |
+
relative_idx = frame_idx - start_idx - neighbor_ids[0]
|
| 353 |
+
binary_mask = binary_masks[frame_idx - start_idx]
|
| 354 |
+
original_frame = np.array(frames[frame_idx])
|
| 355 |
+
|
| 356 |
+
inpainted_frame = (pred_imgs[relative_idx] * binary_mask +
|
| 357 |
+
original_frame * (1 - binary_mask))
|
| 358 |
+
comp_frames[frame_idx] = inpainted_frame
|
| 359 |
+
processed_frame_mask[frame_idx] = True
|
| 360 |
+
|
| 361 |
+
def _verify_and_save_results(self, comp_frames: List[Optional[np.ndarray]], paths: Any) -> None:
|
| 362 |
+
"""Verify all frames were processed and save the final video."""
|
| 363 |
+
missing_frames = [i for i, frame in enumerate(comp_frames)
|
| 364 |
+
if frame is None or (isinstance(frame, np.ndarray) and frame.size == 0)]
|
| 365 |
+
|
| 366 |
+
if missing_frames:
|
| 367 |
+
raise RuntimeError(f"Still found unprocessed frames after cleanup: {missing_frames}")
|
| 368 |
+
|
| 369 |
+
logger.info("Successfully processed all frames")
|
| 370 |
+
|
| 371 |
+
# Save final inpainted video
|
| 372 |
+
media.write_video(paths.video_human_inpaint, comp_frames, fps=15, codec="ffv1")
|
| 373 |
+
|
| 374 |
+
def get_ref_index(self, f: int, neighbor_ids: List[int], length: int) -> List[int]:
|
| 375 |
+
"""
|
| 376 |
+
Select reference frame indices for temporal consistency.
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
f: Current frame index
|
| 380 |
+
neighbor_ids: List of neighboring frame indices
|
| 381 |
+
length: Total length of the sequence
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
List of reference frame indices for temporal consistency
|
| 385 |
+
"""
|
| 386 |
+
if self.num_ref == -1:
|
| 387 |
+
# Automatic reference selection: every ref_length frames not in neighbors
|
| 388 |
+
ref_index = [
|
| 389 |
+
i for i in range(0, length, self.ref_length)
|
| 390 |
+
if i not in neighbor_ids
|
| 391 |
+
]
|
| 392 |
+
else:
|
| 393 |
+
# Limited reference selection: specific number around current frame
|
| 394 |
+
ref_index = []
|
| 395 |
+
for i in range(max(0, f - self.ref_length * (self.num_ref // 2)),
|
| 396 |
+
min(length, f + self.ref_length * (self.num_ref // 2)) + 1,
|
| 397 |
+
self.ref_length):
|
| 398 |
+
if i not in neighbor_ids and len(ref_index) < self.num_ref:
|
| 399 |
+
ref_index.append(i)
|
| 400 |
+
return ref_index
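In the automatic mode (num_ref == -1) this simply picks every ref_length-th frame that is not already in the temporal neighborhood. A standalone sketch of that branch, with ref_length=10 assumed for illustration:

def ref_index_auto(neighbor_ids, length, ref_length=10):
    # Every ref_length-th frame in the sequence, excluding the local neighborhood
    return [i for i in range(0, length, ref_length) if i not in neighbor_ids]

assert ref_index_auto([18, 19, 20, 21, 22], 50) == [0, 10, 30, 40]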
|
| 401 |
+
|
| 402 |
+
@staticmethod
|
| 403 |
+
def read_mask(mask_path: str, size: Tuple[int, int]) -> List[Image.Image]:
|
| 404 |
+
"""
|
| 405 |
+
Load and process hand segmentation masks for inpainting guidance.
|
| 406 |
+
|
| 407 |
+
Args:
|
| 408 |
+
mask_path: Path to mask file containing hand segmentation data
|
| 409 |
+
size: Target size (width, height) for mask resizing
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
List of processed PIL Images containing binary hand masks
|
| 413 |
+
"""
|
| 414 |
+
masks = []
|
| 415 |
+
frames_media = np.load(mask_path, allow_pickle=True)
|
| 416 |
+
frames = [frame for frame in frames_media]
|
| 417 |
+
|
| 418 |
+
for mask_frame in frames:
|
| 419 |
+
# Convert to PIL Image and resize
|
| 420 |
+
mask_img = Image.fromarray(mask_frame)
|
| 421 |
+
mask_img = mask_img.resize(size, Image.NEAREST)
|
| 422 |
+
mask_array = np.array(mask_img.convert('L'))
|
| 423 |
+
|
| 424 |
+
# Create binary mask
|
| 425 |
+
binary_mask = np.array(mask_array > 0).astype(np.uint8)
|
| 426 |
+
|
| 427 |
+
# Apply morphological dilation to expand mask boundaries
|
| 428 |
+
# This helps ensure complete coverage of hand regions
|
| 429 |
+
dilated_mask = cv2.dilate(binary_mask,
|
| 430 |
+
cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
|
| 431 |
+
iterations=4)
|
| 432 |
+
masks.append(Image.fromarray(dilated_mask * 255))
|
| 433 |
+
return masks
|
| 434 |
+
|
| 435 |
+
@staticmethod
|
| 436 |
+
def read_frame_from_videos(video_path: str) -> List[Image.Image]:
|
| 437 |
+
"""
|
| 438 |
+
Load video frames and convert to PIL Images.
|
| 439 |
+
|
| 440 |
+
Args:
|
| 441 |
+
video_path: Path to video file
|
| 442 |
+
|
| 443 |
+
Returns:
|
| 444 |
+
List of PIL Images containing video frames
|
| 445 |
+
"""
|
| 446 |
+
return [Image.fromarray(frame) for frame in media.read_video(video_path)]
|
| 447 |
+
|
| 448 |
+
@staticmethod
|
| 449 |
+
def resize_frames(frames: List[Image.Image], size: Optional[Tuple[int, int]] = None) -> Tuple[List[Image.Image], Tuple[int, int]]:
|
| 450 |
+
"""
|
| 451 |
+
Resize video frames to target resolution.
|
| 452 |
+
|
| 453 |
+
Args:
|
| 454 |
+
frames: List of PIL Images to resize
|
| 455 |
+
size: Target size (width, height), or None to keep original
|
| 456 |
+
|
| 457 |
+
Returns:
|
| 458 |
+
Tuple containing resized frames and final size
|
| 459 |
+
"""
|
| 460 |
+
if size is None:
    size = (frames[0].width, frames[0].height)  # keep the original resolution, as documented
return ([f.resize(size) for f in frames], size)
|
| 461 |
+
|
| 462 |
+
@staticmethod
|
| 463 |
+
def _pad_images(img_tensor: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
| 464 |
+
"""
|
| 465 |
+
Pad image tensor to meet model input requirements.
|
| 466 |
+
|
| 467 |
+
Args:
|
| 468 |
+
img_tensor: Input image tensor to pad
|
| 469 |
+
h: Original height
|
| 470 |
+
w: Original width
|
| 471 |
+
|
| 472 |
+
Returns:
|
| 473 |
+
Padded image tensor suitable for model input
|
| 474 |
+
"""
|
| 475 |
+
# Model requires specific dimension multiples
|
| 476 |
+
mod_size_h, mod_size_w = 60, 108
|
| 477 |
+
|
| 478 |
+
# Calculate required padding
|
| 479 |
+
h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
|
| 480 |
+
w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
|
| 481 |
+
|
| 482 |
+
# Apply reflection padding to avoid boundary artifacts
|
| 483 |
+
img_tensor = torch.cat([img_tensor, torch.flip(img_tensor, [3])], 3)[:, :, :, :h + h_pad, :]
|
| 484 |
+
return torch.cat([img_tensor, torch.flip(img_tensor, [4])], 4)[:, :, :, :, :w + w_pad]
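The padding targets multiples of 60 in height and 108 in width, matching the mod_size constants above. A quick worked example of the pad amounts for a hypothetical 480x854 frame:

h, w = 480, 854                      # hypothetical input resolution
mod_size_h, mod_size_w = 60, 108
h_pad = (mod_size_h - h % mod_size_h) % mod_size_h  # 0  -> height is already a multiple of 60
w_pad = (mod_size_w - w % mod_size_w) % mod_size_w  # 10 -> width padded from 854 to 864
print(h_pad, w_pad)  # 0 10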
|
| 485 |
+
|
phantom/phantom/processors/paths.py
ADDED
|
@@ -0,0 +1,219 @@
| 1 |
+
"""
|
| 2 |
+
Path management for Phantom.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import List, Dict, Optional
|
| 8 |
+
import yaml
|
| 9 |
+
|
| 10 |
+
from phantom.utils.image_utils import convert_video_to_images
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class Paths:
|
| 14 |
+
"""Data class containing all file paths used by processors."""
|
| 15 |
+
data_path: Path
|
| 16 |
+
robot_name: str = "franka"
|
| 17 |
+
|
| 18 |
+
def __post_init__(self):
|
| 19 |
+
"""Compute derived paths based on base paths."""
|
| 20 |
+
# Convert string paths to Path objects if needed
|
| 21 |
+
if isinstance(self.data_path, str):
|
| 22 |
+
self.data_path = Path(self.data_path)
|
| 23 |
+
|
| 24 |
+
# Validate data path
|
| 25 |
+
if not self.data_path.exists():
|
| 26 |
+
raise FileNotFoundError(f"Data path does not exist: {self.data_path}")
|
| 27 |
+
|
| 28 |
+
# Videos
|
| 29 |
+
self.video_left = self.data_path / "video_L.mp4"
|
| 30 |
+
self.video_right = self.data_path / "video_R.mp4"
|
| 31 |
+
self.video_rgb_imgs = self.data_path / "video_rgb_imgs.mkv"
|
| 32 |
+
|
| 33 |
+
# Image folders
|
| 34 |
+
self.original_images_folder = self.data_path / "original_images"
|
| 35 |
+
# self._setup_original_images()
|
| 36 |
+
self.original_images_folder_reverse = self.data_path / "original_images_reverse"
|
| 37 |
+
# self._setup_original_images_reverse()
|
| 38 |
+
|
| 39 |
+
# Epic annotations
|
| 40 |
+
self.hand_detection_data = self.data_path / "hand_det.pkl"
|
| 41 |
+
self.cam_extrinsics_data = self.data_path / "extrinsics.npy"
|
| 42 |
+
|
| 43 |
+
# Depth
|
| 44 |
+
self.depth = self.data_path / "depth.npy"
|
| 45 |
+
|
| 46 |
+
# Bbox processor
|
| 47 |
+
self.bbox_processor = self.data_path / "bbox_processor"
|
| 48 |
+
self.bbox_data = self.bbox_processor / "bbox_data.npz"
|
| 49 |
+
self.video_bboxes = self.bbox_processor / "video_bboxes.mkv"
|
| 50 |
+
|
| 51 |
+
# Segmentation processor
|
| 52 |
+
self.segmentation_processor = self.data_path / "segmentation_processor"
|
| 53 |
+
self.masks_arm = self.segmentation_processor / "masks_arm.npy"
|
| 54 |
+
self.video_masks_arm = self.segmentation_processor / "video_masks_arm.mkv"
|
| 55 |
+
self.video_sam_arm = self.segmentation_processor / "video_sam_arm.mkv"
|
| 56 |
+
for side in ["left", "right"]:
|
| 57 |
+
setattr(self, f"masks_hand_{side}", self.segmentation_processor / f"masks_hand_{side}.npy")
|
| 58 |
+
setattr(self, f"video_masks_hand_{side}", self.segmentation_processor / f"video_masks_hand_{side}.mkv")
|
| 59 |
+
setattr(self, f"video_sam_hand_{side}", self.segmentation_processor / f"video_sam_hand_{side}.mkv")
|
| 60 |
+
|
| 61 |
+
# Hand Processor
|
| 62 |
+
self.hand_processor = self.data_path / "hand_processor"
|
| 63 |
+
for side in ["left", "right"]:
|
| 64 |
+
setattr(self, f"hand_data_{side}", self.hand_processor / f"hand_data_{side}.npz")
|
| 65 |
+
setattr(self, f"hand_data_3d_{side}", self.hand_processor / f"hand_data_3d_{side}.npz")
|
| 66 |
+
self.video_annot = self.data_path / "video_annot.mp4"
|
| 67 |
+
|
| 68 |
+
# Action processor
|
| 69 |
+
self.action_processor = self.data_path / "action_processor"
|
| 70 |
+
for side in ["left", "right"]:
|
| 71 |
+
setattr(self, f"actions_{side}", self.action_processor / f"actions_{side}.npz")
|
| 72 |
+
|
| 73 |
+
# Smoothing processor
|
| 74 |
+
self.smoothing_processor = self.data_path / "smoothing_processor"
|
| 75 |
+
for side in ["left", "right"]:
|
| 76 |
+
setattr(self, f"smoothed_actions_{side}", self.smoothing_processor / f"smoothed_actions_{side}.npz")
|
| 77 |
+
|
| 78 |
+
# Inpaint processor
|
| 79 |
+
self.inpaint_processor = self.data_path / "inpaint_processor"
|
| 80 |
+
self.video_overlay = self.data_path / "video_overlay.mkv"
|
| 81 |
+
self.video_human_inpaint = self.inpaint_processor / "video_human_inpaint.mkv"
|
| 82 |
+
self.video_inpaint_overlay = self.inpaint_processor / "video_inpaint_overlay.mkv"
|
| 83 |
+
self.video_birdview = self.inpaint_processor / "video_birdview.mkv"
|
| 84 |
+
self.training_data = self.inpaint_processor / "training_data.npz"
|
| 85 |
+
|
| 86 |
+
def _setup_original_images(self):
|
| 87 |
+
"""Set up original images paths."""
|
| 88 |
+
convert_video_to_images(self.video_left, self.original_images_folder, square=False)
|
| 89 |
+
image_paths = sorted(
|
| 90 |
+
list(self.original_images_folder.glob("*.jpg")),
|
| 91 |
+
key=lambda x: int(x.stem)
|
| 92 |
+
)
|
| 93 |
+
self.original_images = image_paths
|
| 94 |
+
|
| 95 |
+
def _setup_original_images_reverse(self):
|
| 96 |
+
"""Set up original images paths."""
|
| 97 |
+
convert_video_to_images(self.video_left, self.original_images_folder_reverse, square=False, reverse=True)
|
| 98 |
+
image_paths = sorted(
|
| 99 |
+
list(self.original_images_folder_reverse.glob("*.jpg")),
|
| 100 |
+
key=lambda x: int(x.stem)
|
| 101 |
+
)
|
| 102 |
+
self.original_images_reverse = image_paths
|
| 103 |
+
|
| 104 |
+
def ensure_directories_exist(self):
|
| 105 |
+
"""
|
| 106 |
+
Create necessary directories if they don't exist.
|
| 107 |
+
"""
|
| 108 |
+
# Create all necessary directories
|
| 109 |
+
directories = [
|
| 110 |
+
self.data_path,
|
| 111 |
+
]
|
| 112 |
+
|
| 113 |
+
for directory in directories:
|
| 114 |
+
if isinstance(directory, Path) and not directory.exists():
|
| 115 |
+
directory.mkdir(parents=True, exist_ok=True)
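A minimal usage sketch of the dataclass above. The demo folder name is hypothetical, and the folder must already exist because __post_init__ validates it:

paths = Paths(data_path="./data/task_0/demo_0", robot_name="franka")
paths.ensure_directories_exist()
print(paths.masks_arm)       # data/task_0/demo_0/segmentation_processor/masks_arm.npy
print(paths.hand_data_left)  # data/task_0/demo_0/hand_processor/hand_data_left.npz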
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class PathsConfig:
|
| 120 |
+
"""
|
| 121 |
+
Configuration for paths used in the project.
|
| 122 |
+
|
| 123 |
+
This class handles loading and saving path configurations from files,
|
| 124 |
+
and provides methods for creating Paths objects.
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
def __init__(self, config_file: Optional[str] = None) -> None:
|
| 128 |
+
"""
|
| 129 |
+
Initialize paths configuration.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
config_file: Path to configuration file. If None, use default config.
|
| 133 |
+
"""
|
| 134 |
+
self.config: dict[str, str] = {}
|
| 135 |
+
if config_file:
|
| 136 |
+
self.load_config(config_file)
|
| 137 |
+
else:
|
| 138 |
+
self.set_default_config()
|
| 139 |
+
|
| 140 |
+
def load_config(self, config_file: str) -> None:
|
| 141 |
+
"""
|
| 142 |
+
Load configuration from a YAML file.
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
config_file: Path to configuration file
|
| 146 |
+
|
| 147 |
+
Raises:
|
| 148 |
+
FileNotFoundError: If config file doesn't exist
|
| 149 |
+
yaml.YAMLError: If config file is invalid YAML
|
| 150 |
+
"""
|
| 151 |
+
try:
|
| 152 |
+
with open(config_file, 'r') as f:
|
| 153 |
+
self.config = yaml.safe_load(f)
|
| 154 |
+
except FileNotFoundError:
|
| 155 |
+
raise FileNotFoundError(f"Configuration file not found: {config_file}")
|
| 156 |
+
except yaml.YAMLError as e:
|
| 157 |
+
raise yaml.YAMLError(f"Invalid YAML in configuration file {config_file}: {e}")
|
| 158 |
+
|
| 159 |
+
def save_config(self, config_file: str) -> None:
|
| 160 |
+
"""
|
| 161 |
+
Save configuration to a YAML file.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
config_file: Path to save configuration file
|
| 165 |
+
|
| 166 |
+
Raises:
|
| 167 |
+
OSError: If unable to write to the file
|
| 168 |
+
"""
|
| 169 |
+
with open(config_file, 'w') as f:
|
| 170 |
+
yaml.dump(self.config, f, default_flow_style=False)
|
| 171 |
+
|
| 172 |
+
def set_default_config(self) -> None:
|
| 173 |
+
"""Set default configuration values."""
|
| 174 |
+
self.config = {
|
| 175 |
+
'data_root': './data',
|
| 176 |
+
'processed_root': './processed_data',
|
| 177 |
+
'project_name': 'phantom',
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
def get_paths(self, demo_name: str, robot_name: str = "franka") -> Paths:
|
| 181 |
+
"""
|
| 182 |
+
Get Paths object for a specific demo.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
demo_name: Name of the demo
|
| 186 |
+
robot_name: Name of the robot
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
Paths object for the demo
|
| 190 |
+
"""
|
| 191 |
+
data_path = os.path.join(self.config['data_root'], demo_name)
|
| 192 |
+
|
| 193 |
+
return Paths(
|
| 194 |
+
data_path=Path(data_path),
|
| 195 |
+
robot_name=robot_name
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
def get_all_demo_paths(self) -> List[str]:
|
| 199 |
+
"""
|
| 200 |
+
Get list of all demo paths in data root.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
List of demo paths
|
| 204 |
+
"""
|
| 205 |
+
data_root = self.config['data_root']
|
| 206 |
+
all_data_collection_folders = [
|
| 207 |
+
f for f in os.listdir(data_root)
|
| 208 |
+
if os.path.isdir(os.path.join(data_root, f))
|
| 209 |
+
]
|
| 210 |
+
|
| 211 |
+
all_data_folders = [
|
| 212 |
+
os.path.join(d1, d2)
|
| 213 |
+
for d1 in os.listdir(data_root)
|
| 214 |
+
if os.path.isdir(os.path.join(data_root, d1))
|
| 215 |
+
for d2 in os.listdir(os.path.join(data_root, d1))
|
| 216 |
+
if os.path.isdir(os.path.join(data_root, d1, d2))
|
| 217 |
+
]
|
| 218 |
+
|
| 219 |
+
return sorted(all_data_folders, key=lambda x: tuple(map(int, x.rsplit('/', 2)[-2:])))
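A minimal usage sketch of PathsConfig, assuming the default config and a data_root laid out as <data_root>/<collection>/<demo> with numeric folder names (required by the sorting key above):

cfg = PathsConfig()                    # no config file: falls back to set_default_config()
for demo in cfg.get_all_demo_paths():  # e.g. "0/3" under ./data
    demo_paths = cfg.get_paths(demo)
    print(demo_paths.video_left)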
|
phantom/phantom/processors/phantom_data.py
ADDED
|
@@ -0,0 +1,340 @@
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict, List, Optional, Callable, Any
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
hand_side_dict = {
|
| 6 |
+
'left': 0,
|
| 7 |
+
'right': 1,
|
| 8 |
+
}
|
| 9 |
+
|
| 10 |
+
class LazyLoadingMixin:
|
| 11 |
+
"""Mixin to provide lazy loading functionality for cached properties."""
|
| 12 |
+
|
| 13 |
+
def _invalidate_cache(self) -> None:
|
| 14 |
+
"""Invalidate all cached properties. Override in subclasses."""
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
def _get_cached_property(self, cache_attr: str, compute_func: Callable[[], Any]) -> Any:
|
| 18 |
+
"""Generic lazy loading for cached properties."""
|
| 19 |
+
if getattr(self, cache_attr) is None:
|
| 20 |
+
setattr(self, cache_attr, compute_func())
|
| 21 |
+
return getattr(self, cache_attr)
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class TrainingData:
|
| 25 |
+
"""Container for processing results"""
|
| 26 |
+
frame_idx: int
|
| 27 |
+
valid: bool
|
| 28 |
+
action_pos_left: np.ndarray
|
| 29 |
+
action_orixyzw_left: np.ndarray
|
| 30 |
+
action_pos_right: np.ndarray
|
| 31 |
+
action_orixyzw_right: np.ndarray
|
| 32 |
+
action_gripper_left: np.ndarray
|
| 33 |
+
action_gripper_right: np.ndarray
|
| 34 |
+
gripper_width_left: np.ndarray
|
| 35 |
+
gripper_width_right: np.ndarray
|
| 36 |
+
|
| 37 |
+
@classmethod
|
| 38 |
+
def create_empty_frame(cls, frame_idx: int) -> 'TrainingData':
|
| 39 |
+
"""Create a frame with no hand detection"""
|
| 40 |
+
return cls(
|
| 41 |
+
frame_idx=frame_idx,
|
| 42 |
+
valid=False,
|
| 43 |
+
action_pos_left=np.zeros((3,)),
|
| 44 |
+
action_orixyzw_left=np.zeros((4,)),
|
| 45 |
+
action_pos_right=np.zeros((3,)),
|
| 46 |
+
action_orixyzw_right=np.zeros((4,)),
|
| 47 |
+
action_gripper_left=0,
|
| 48 |
+
action_gripper_right=0,
|
| 49 |
+
gripper_width_left=0,
|
| 50 |
+
gripper_width_right=0,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
class TrainingDataSequence(LazyLoadingMixin):
|
| 54 |
+
"""Container for a sequence of training data"""
|
| 55 |
+
def __init__(self):
|
| 56 |
+
self.frames: List[TrainingData] = []
|
| 57 |
+
self.metadata: Dict = {}
|
| 58 |
+
|
| 59 |
+
self._frame_indices: Optional[np.ndarray] = None
|
| 60 |
+
self._valid: Optional[np.ndarray] = None
|
| 61 |
+
self._action_pos_left: Optional[np.ndarray] = None
|
| 62 |
+
self._action_orixyzw_left: Optional[np.ndarray] = None
|
| 63 |
+
self._action_pos_right: Optional[np.ndarray] = None
|
| 64 |
+
self._action_orixyzw_right: Optional[np.ndarray] = None
|
| 65 |
+
self._action_gripper_left: Optional[np.ndarray] = None
|
| 66 |
+
self._action_gripper_right: Optional[np.ndarray] = None
|
| 67 |
+
self._gripper_width_left: Optional[np.ndarray] = None
|
| 68 |
+
self._gripper_width_right: Optional[np.ndarray] = None
|
| 69 |
+
|
| 70 |
+
def add_frame(self, frame: TrainingData) -> None:
|
| 71 |
+
"""Add a frame to the sequence and invalidate cached properties."""
|
| 72 |
+
self.frames.append(frame)
|
| 73 |
+
self._invalidate_cache()
|
| 74 |
+
|
| 75 |
+
def save(self, path: str) -> None:
|
| 76 |
+
"""Save the sequence to disk in both frame-wise and sequence-wise formats"""
|
| 77 |
+
|
| 78 |
+
sequence_data = {
|
| 79 |
+
'frame_indices': self.frame_indices,
|
| 80 |
+
'valid': self.valid,
|
| 81 |
+
'action_pos_left': self.action_pos_left,
|
| 82 |
+
'action_orixyzw_left': self.action_orixyzw_left,
|
| 83 |
+
'action_pos_right': self.action_pos_right,
|
| 84 |
+
'action_orixyzw_right': self.action_orixyzw_right,
|
| 85 |
+
'action_gripper_left': self.action_gripper_left,
|
| 86 |
+
'action_gripper_right': self.action_gripper_right,
|
| 87 |
+
'gripper_width_left': self.gripper_width_left,
|
| 88 |
+
'gripper_width_right': self.gripper_width_right,
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
np.savez_compressed(
|
| 92 |
+
path,
|
| 93 |
+
**sequence_data
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def frame_indices(self) -> np.ndarray:
|
| 98 |
+
"""Lazy loading of all frame indices"""
|
| 99 |
+
return self._get_cached_property(
|
| 100 |
+
'_frame_indices',
|
| 101 |
+
lambda: np.arange(len(self.frames))
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
@property
|
| 105 |
+
def valid(self) -> np.ndarray:
|
| 106 |
+
"""Lazy loading of all valid flags"""
|
| 107 |
+
return self._get_cached_property(
|
| 108 |
+
'_valid',
|
| 109 |
+
lambda: np.stack([f.valid for f in self.frames])
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
@property
|
| 113 |
+
def action_pos_left(self) -> np.ndarray:
|
| 114 |
+
"""Lazy loading of all action positions"""
|
| 115 |
+
return self._get_cached_property(
|
| 116 |
+
'_action_pos_left',
|
| 117 |
+
lambda: np.stack([f.action_pos_left for f in self.frames])
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def action_orixyzw_left(self) -> np.ndarray:
|
| 122 |
+
"""Lazy loading of all action orientations"""
|
| 123 |
+
return self._get_cached_property(
|
| 124 |
+
'_action_orixyzw_left',
|
| 125 |
+
lambda: np.stack([f.action_orixyzw_left for f in self.frames])
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
@property
|
| 129 |
+
def action_pos_right(self) -> np.ndarray:
|
| 130 |
+
"""Lazy loading of all action positions"""
|
| 131 |
+
return self._get_cached_property(
|
| 132 |
+
'_action_pos_right',
|
| 133 |
+
lambda: np.stack([f.action_pos_right for f in self.frames])
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
@property
|
| 137 |
+
def action_orixyzw_right(self) -> np.ndarray:
|
| 138 |
+
"""Lazy loading of all action orientations"""
|
| 139 |
+
return self._get_cached_property(
|
| 140 |
+
'_action_orixyzw_right',
|
| 141 |
+
lambda: np.stack([f.action_orixyzw_right for f in self.frames])
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
@property
|
| 145 |
+
def action_gripper_left(self) -> np.ndarray:
|
| 146 |
+
"""Lazy loading of all action gripper distances"""
|
| 147 |
+
return self._get_cached_property(
|
| 148 |
+
'_action_gripper_left',
|
| 149 |
+
lambda: np.stack([f.action_gripper_left for f in self.frames])
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
@property
|
| 153 |
+
def action_gripper_right(self) -> np.ndarray:
|
| 154 |
+
"""Lazy loading of all action gripper distances"""
|
| 155 |
+
return self._get_cached_property(
|
| 156 |
+
'_action_gripper_right',
|
| 157 |
+
lambda: np.stack([f.action_gripper_right for f in self.frames])
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
@property
|
| 161 |
+
def gripper_width_left(self) -> np.ndarray:
|
| 162 |
+
"""Lazy loading of all gripper widths"""
|
| 163 |
+
return self._get_cached_property(
|
| 164 |
+
'_gripper_width_left',
|
| 165 |
+
lambda: np.stack([f.gripper_width_left for f in self.frames])
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
@property
|
| 169 |
+
def gripper_width_right(self) -> np.ndarray:
|
| 170 |
+
"""Lazy loading of all gripper widths"""
|
| 171 |
+
return self._get_cached_property(
|
| 172 |
+
'_gripper_width_right',
|
| 173 |
+
lambda: np.stack([f.gripper_width_right for f in self.frames])
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
def _invalidate_cache(self):
|
| 177 |
+
"""Invalidate all cached properties."""
|
| 178 |
+
self._frame_indices = None
|
| 179 |
+
self._valid = None
|
| 180 |
+
self._action_pos_left = None
|
| 181 |
+
self._action_orixyzw_left = None
|
| 182 |
+
self._action_pos_right = None
|
| 183 |
+
self._action_orixyzw_right = None
|
| 184 |
+
self._action_gripper_left = None
|
| 185 |
+
self._action_gripper_right = None
|
| 186 |
+
self._gripper_width_left = None
|
| 187 |
+
self._gripper_width_right = None
|
| 188 |
+
|
| 189 |
+
@classmethod
|
| 190 |
+
def load(cls, path: str) -> 'TrainingDataSequence':
|
| 191 |
+
"""Load a sequence from disk"""
|
| 192 |
+
data = np.load(path, allow_pickle=True)
|
| 193 |
+
sequence = cls()
|
| 194 |
+
|
| 195 |
+
sequence._frame_indices = data['frame_indices']
|
| 196 |
+
sequence._valid = data['valid']
|
| 197 |
+
sequence._action_pos_left = data['action_pos_left']
|
| 198 |
+
sequence._action_orixyzw_left = data['action_orixyzw_left']
|
| 199 |
+
sequence._action_pos_right = data['action_pos_right']
|
| 200 |
+
sequence._action_orixyzw_right = data['action_orixyzw_right']
|
| 201 |
+
sequence._action_gripper_left = data['action_gripper_left']
|
| 202 |
+
sequence._action_gripper_right = data['action_gripper_right']
|
| 203 |
+
sequence._gripper_width_left = data['gripper_width_left']
|
| 204 |
+
sequence._gripper_width_right = data['gripper_width_right']
|
| 205 |
+
|
| 206 |
+
return sequence
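A round-trip sketch of the save/load pair above; the file name is hypothetical:

seq = TrainingDataSequence()
seq.add_frame(TrainingData.create_empty_frame(0))
seq.save("demo_actions.npz")

loaded = TrainingDataSequence.load("demo_actions.npz")
print(loaded.valid.shape)            # (1,)
print(loaded.action_pos_left.shape)  # (1, 3)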
|
| 207 |
+
|
| 208 |
+
@dataclass
|
| 209 |
+
class HandFrame:
|
| 210 |
+
"""Data structure for a single frame of hand data"""
|
| 211 |
+
frame_idx: int
|
| 212 |
+
hand_detected: bool
|
| 213 |
+
img_rgb: np.ndarray
|
| 214 |
+
img_hamer: np.ndarray
|
| 215 |
+
kpts_2d: np.ndarray # shape: (N, 2)
|
| 216 |
+
kpts_3d: np.ndarray # shape: (N, 3)
|
| 217 |
+
|
| 218 |
+
@classmethod
|
| 219 |
+
def create_empty_frame(cls, frame_idx: int, img_rgb: np.ndarray) -> 'HandFrame':
|
| 220 |
+
"""Create a frame with no hand detection"""
|
| 221 |
+
return cls(
|
| 222 |
+
frame_idx=frame_idx,
|
| 223 |
+
hand_detected=False,
|
| 224 |
+
img_rgb=img_rgb,
|
| 225 |
+
img_hamer=np.zeros_like(img_rgb),
|
| 226 |
+
kpts_2d=np.zeros((21, 2)),
|
| 227 |
+
kpts_3d=np.zeros((21, 3)),
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
class HandSequence(LazyLoadingMixin):
|
| 231 |
+
"""Container for a sequence of hand data"""
|
| 232 |
+
def __init__(self):
|
| 233 |
+
self.frames: List[HandFrame] = []
|
| 234 |
+
self.metadata: Dict = {}
|
| 235 |
+
|
| 236 |
+
self._frame_indices: Optional[np.ndarray] = None
|
| 237 |
+
self._hand_detected: Optional[np.ndarray] = None
|
| 238 |
+
self._img_rgb: Optional[np.ndarray] = None
|
| 239 |
+
self._img_hamer: Optional[np.ndarray] = None
|
| 240 |
+
self._kpts_2d: Optional[np.ndarray] = None
|
| 241 |
+
self._kpts_3d: Optional[np.ndarray] = None
|
| 242 |
+
|
| 243 |
+
def add_frame(self, frame: HandFrame) -> None:
|
| 244 |
+
"""Add a frame to the sequence and invalidate cached properties."""
|
| 245 |
+
self.frames.append(frame)
|
| 246 |
+
self._invalidate_cache()
|
| 247 |
+
|
| 248 |
+
def get_frame(self, frame_idx: int) -> HandFrame:
|
| 249 |
+
"""Get a frame by index."""
|
| 250 |
+
return self.frames[frame_idx]
|
| 251 |
+
|
| 252 |
+
def modify_frame(self, frame_idx: int, frame: HandFrame) -> None:
|
| 253 |
+
"""Modify a frame at the given index and invalidate cached properties."""
|
| 254 |
+
self.frames[frame_idx] = frame
|
| 255 |
+
self._invalidate_cache()
|
| 256 |
+
|
| 257 |
+
def save(self, path: str) -> None:
|
| 258 |
+
"""Save the sequence to disk in both frame-wise and sequence-wise formats"""
|
| 259 |
+
sequence_data = {
|
| 260 |
+
'hand_detected': self.hand_detected,
|
| 261 |
+
'kpts_2d': self.kpts_2d,
|
| 262 |
+
'kpts_3d': self.kpts_3d,
|
| 263 |
+
'frame_indices': self.frame_indices,
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
np.savez_compressed(
|
| 267 |
+
path,
|
| 268 |
+
**sequence_data
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
@property
|
| 272 |
+
def frame_indices(self) -> np.ndarray:
|
| 273 |
+
"""Lazy loading of all frame indices"""
|
| 274 |
+
return self._get_cached_property(
|
| 275 |
+
'_frame_indices',
|
| 276 |
+
lambda: np.arange(len(self.frames))
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
@property
|
| 280 |
+
def hand_detected(self) -> np.ndarray:
|
| 281 |
+
"""Lazy loading of all hand detection flags"""
|
| 282 |
+
return self._get_cached_property(
|
| 283 |
+
'_hand_detected',
|
| 284 |
+
lambda: np.stack([f.hand_detected for f in self.frames])
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
@property
|
| 288 |
+
def imgs_rgb(self) -> np.ndarray:
|
| 289 |
+
"""Lazy loading of all RGB images"""
|
| 290 |
+
return self._get_cached_property(
|
| 291 |
+
'_img_rgb',
|
| 292 |
+
lambda: np.stack([f.img_rgb for f in self.frames])
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
@property
|
| 296 |
+
def imgs_hamer(self) -> np.ndarray:
|
| 297 |
+
"""Lazy loading of all HAMER images"""
|
| 298 |
+
return self._get_cached_property(
|
| 299 |
+
'_img_hamer',
|
| 300 |
+
lambda: np.stack([f.img_hamer for f in self.frames])
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
@property
|
| 304 |
+
def kpts_2d(self) -> np.ndarray:
|
| 305 |
+
"""Lazy loading of all 2D keypoints"""
|
| 306 |
+
return self._get_cached_property(
|
| 307 |
+
'_kpts_2d',
|
| 308 |
+
lambda: np.stack([f.kpts_2d for f in self.frames])
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
@property
|
| 312 |
+
def kpts_3d(self) -> np.ndarray:
|
| 313 |
+
"""Lazy loading of all 3D keypoints"""
|
| 314 |
+
return self._get_cached_property(
|
| 315 |
+
'_kpts_3d',
|
| 316 |
+
lambda: np.stack([f.kpts_3d for f in self.frames])
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
@classmethod
|
| 320 |
+
def load(cls, path: str) -> 'HandSequence':
|
| 321 |
+
"""Load a sequence from disk"""
|
| 322 |
+
data = np.load(path, allow_pickle=True)
|
| 323 |
+
sequence = cls()
|
| 324 |
+
|
| 325 |
+
# Load pre-computed sequence-wise data
|
| 326 |
+
sequence._frame_indices = data['frame_indices']
|
| 327 |
+
sequence._hand_detected = data['hand_detected']
|
| 328 |
+
sequence._kpts_2d = data['kpts_2d']
|
| 329 |
+
sequence._kpts_3d = data['kpts_3d']
|
| 330 |
+
|
| 331 |
+
return sequence
|
| 332 |
+
|
| 333 |
+
def _invalidate_cache(self):
|
| 334 |
+
"""Invalidate all cached properties."""
|
| 335 |
+
self._frame_indices = None
|
| 336 |
+
self._hand_detected = None
|
| 337 |
+
self._img_rgb = None
|
| 338 |
+
self._img_hamer = None
|
| 339 |
+
self._kpts_2d = None
|
| 340 |
+
self._kpts_3d = None
|
phantom/phantom/processors/robotinpaint_processor.py
ADDED
|
@@ -0,0 +1,785 @@
| 1 |
+
"""
|
| 2 |
+
Robot Inpainting Processor Module
|
| 3 |
+
|
| 4 |
+
This module uses MuJoCo to render robot models and overlay them onto human demonstration videos.
|
| 5 |
+
|
| 6 |
+
Processing Pipeline:
|
| 7 |
+
1. Load smoothed robot trajectories from previous processing stages
|
| 8 |
+
2. Initialize MuJoCo robot simulation with calibrated camera parameters
|
| 9 |
+
3. For each frame:
|
| 10 |
+
- Move simulated robot to target pose from human demonstration
|
| 11 |
+
- Render robot from calibrated camera viewpoint
|
| 12 |
+
- Apply depth-based occlusion handling (Optional)
|
| 13 |
+
- Create robot overlay on human demonstration video
|
| 14 |
+
4. Generate training data with robot state annotations
|
| 15 |
+
5. Save robot-inpainted videos and training data
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import pdb
|
| 20 |
+
import numpy as np
|
| 21 |
+
import cv2
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
import mediapy as media
|
| 24 |
+
from scipy.spatial.transform import Rotation
|
| 25 |
+
from typing import Tuple, Dict, List, Optional, Any, Union
|
| 26 |
+
import logging
|
| 27 |
+
from dataclasses import dataclass
|
| 28 |
+
|
| 29 |
+
from phantom.processors.phantom_data import TrainingData, TrainingDataSequence, HandSequence
|
| 30 |
+
from phantom.processors.base_processor import BaseProcessor
|
| 31 |
+
from phantom.twin_bimanual_robot import TwinBimanualRobot, MujocoCameraParams
|
| 32 |
+
from phantom.twin_robot import TwinRobot
|
| 33 |
+
from phantom.processors.paths import Paths
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class RobotState:
|
| 40 |
+
"""
|
| 41 |
+
Container for robot state data including pose and gripper configuration.
|
| 42 |
+
|
| 43 |
+
Attributes:
|
| 44 |
+
pos: 3D position coordinates in world frame
|
| 45 |
+
ori_xyzw: Quaternion orientation in XYZW format (scalar-last)
|
| 46 |
+
gripper_pos: Gripper opening distance or action value
|
| 47 |
+
"""
|
| 48 |
+
pos: np.ndarray
|
| 49 |
+
ori_xyzw: np.ndarray
|
| 50 |
+
gripper_pos: float
|
| 51 |
+
|
| 52 |
+
class RobotInpaintProcessor(BaseProcessor):
|
| 53 |
+
"""
|
| 54 |
+
Uses MuJoCo to overlay the simulated robot onto human-inpainted images.
|
| 55 |
+
"""
|
| 56 |
+
# Processing constants for quality control and output formatting
|
| 57 |
+
TRACKING_ERROR_THRESHOLD = 0.05 # Maximum tracking error in meters
|
| 58 |
+
DEFAULT_FPS = 15 # Standard frame rate for output videos
|
| 59 |
+
DEFAULT_CODEC = "ffv1" # Lossless codec for high-quality output
|
| 60 |
+
|
| 61 |
+
def __init__(self, args: Any) -> None:
|
| 62 |
+
"""
|
| 63 |
+
Initialize the robot inpainting processor with simulation parameters.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
args: Command line arguments containing robot configuration,
|
| 67 |
+
camera parameters, and processing options
|
| 68 |
+
"""
|
| 69 |
+
super().__init__(args)
|
| 70 |
+
self.use_depth = self.depth_for_overlay
|
| 71 |
+
self._initialize_robot()
|
| 72 |
+
|
| 73 |
+
def _initialize_robot(self) -> None:
|
| 74 |
+
"""
|
| 75 |
+
Initialize the twin robot simulation with calibrated camera parameters.
|
| 76 |
+
"""
|
| 77 |
+
# Generate MuJoCo camera parameters from real-world calibration
|
| 78 |
+
camera_params = self._get_mujoco_camera_params()
|
| 79 |
+
img_w, img_h = self._get_image_dimensions()
|
| 80 |
+
|
| 81 |
+
# Initialize appropriate robot configuration
|
| 82 |
+
if self.bimanual_setup == "single_arm":
|
| 83 |
+
self.twin_robot = TwinRobot(
|
| 84 |
+
self.robot,
|
| 85 |
+
self.gripper,
|
| 86 |
+
camera_params,
|
| 87 |
+
camera_height=img_h,
|
| 88 |
+
camera_width=img_w,
|
| 89 |
+
render=self.render,
|
| 90 |
+
n_steps_short=3,
|
| 91 |
+
n_steps_long=75,
|
| 92 |
+
debug_cameras=self.debug_cameras,
|
| 93 |
+
square=self.square,
|
| 94 |
+
)
|
| 95 |
+
else:
|
| 96 |
+
self.twin_robot = TwinBimanualRobot(
|
| 97 |
+
self.robot,
|
| 98 |
+
self.gripper,
|
| 99 |
+
self.bimanual_setup,
|
| 100 |
+
camera_params,
|
| 101 |
+
camera_height=img_h,
|
| 102 |
+
camera_width=img_w,
|
| 103 |
+
render=self.render,
|
| 104 |
+
n_steps_short=10,
|
| 105 |
+
n_steps_long=75,
|
| 106 |
+
debug_cameras=self.debug_cameras,
|
| 107 |
+
epic=self.epic,
|
| 108 |
+
joint_controller=False, # Use operational-space control
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
def __del__(self):
|
| 112 |
+
"""Clean up robot simulation resources."""
|
| 113 |
+
if hasattr(self, 'twin_robot'):
|
| 114 |
+
self.twin_robot.close()
|
| 115 |
+
|
| 116 |
+
def process_one_demo(self, data_sub_folder: str) -> None:
|
| 117 |
+
"""
|
| 118 |
+
Process a single demonstration to create robot-inpainted visualization.
|
| 119 |
+
|
| 120 |
+
Args:
|
| 121 |
+
data_sub_folder: Path to demonstration data folder containing
|
| 122 |
+
smoothed trajectories and original video data
|
| 123 |
+
"""
|
| 124 |
+
save_folder = self.get_save_folder(data_sub_folder)
|
| 125 |
+
if self._should_skip_processing(save_folder):
|
| 126 |
+
return
|
| 127 |
+
paths = self.get_paths(save_folder)
|
| 128 |
+
|
| 129 |
+
# Reinitialize robot simulation for each demo to ensure clean state
|
| 130 |
+
self.__del__()
|
| 131 |
+
self._initialize_robot()
|
| 132 |
+
|
| 133 |
+
# Load and prepare demonstration data
|
| 134 |
+
data = self._load_data(paths)
|
| 135 |
+
images = self._load_images(paths, data["union_indices"])
|
| 136 |
+
gripper_actions, gripper_widths = self._process_gripper_widths(paths, data)
|
| 137 |
+
|
| 138 |
+
# Process all frames to generate robot overlays and training data
|
| 139 |
+
sequence, img_overlay, img_birdview = self._process_frames(images, data, gripper_actions, gripper_widths)
|
| 140 |
+
|
| 141 |
+
# Save comprehensive results
|
| 142 |
+
self._save_results(paths, sequence, img_overlay, img_birdview)
|
| 143 |
+
|
| 144 |
+
def _process_frames(self, images: Dict[str, np.ndarray], data: Dict[str, np.ndarray],
|
| 145 |
+
gripper_actions: Dict[str, np.ndarray], gripper_widths: Dict[str, np.ndarray]) -> Tuple[TrainingDataSequence, List[np.ndarray], Optional[List[np.ndarray]]]:
|
| 146 |
+
"""
|
| 147 |
+
Process each frame to generate robot overlays and training data.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
images: Dictionary containing human demonstration images and masks
|
| 151 |
+
data: Robot trajectory data (positions and orientations)
|
| 152 |
+
gripper_actions: Processed gripper action commands
|
| 153 |
+
gripper_widths: Gripper opening distances
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Tuple containing:
|
| 157 |
+
- TrainingDataSequence with robot state annotations
|
| 158 |
+
- List of robot overlay images
|
| 159 |
+
- Optional list of bird's eye view images (if debug cameras enabled)
|
| 160 |
+
"""
|
| 161 |
+
sequence = TrainingDataSequence()
|
| 162 |
+
img_overlay = []
|
| 163 |
+
img_birdview = None
|
| 164 |
+
if "birdview" in self.debug_cameras:
|
| 165 |
+
img_birdview = []
|
| 166 |
+
|
| 167 |
+
for idx in tqdm(range(len(images['human_imgs'])), desc="Processing frames"):
|
| 168 |
+
# Extract robot states for current frame
|
| 169 |
+
left_state = self._get_robot_state(
|
| 170 |
+
data['ee_pts_left'][idx],
|
| 171 |
+
data['ee_oris_left'][idx],
|
| 172 |
+
gripper_widths['left'][idx]
|
| 173 |
+
)
|
| 174 |
+
right_state = self._get_robot_state(
|
| 175 |
+
data['ee_pts_right'][idx],
|
| 176 |
+
data['ee_oris_right'][idx],
|
| 177 |
+
gripper_widths['right'][idx]
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Process individual frame with robot simulation
|
| 181 |
+
frame_results = self._process_single_frame(
|
| 182 |
+
images, left_state, right_state, idx
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# Handle failed processing (tracking errors, simulation issues)
|
| 186 |
+
if frame_results is None:
|
| 187 |
+
print(f"sdfsdfsTracking error too large at frame {idx}, skipping")
|
| 188 |
+
sequence.add_frame(TrainingData.create_empty_frame(
|
| 189 |
+
frame_idx=idx,
|
| 190 |
+
))
|
| 191 |
+
img_overlay.append(np.zeros_like(images['human_imgs'][idx]))
|
| 192 |
+
if "birdview" in self.debug_cameras:
|
| 193 |
+
img_birdview.append(np.zeros_like(images['human_imgs'][idx]))
|
| 194 |
+
else:
|
| 195 |
+
# Create comprehensive training data annotation
|
| 196 |
+
sequence.add_frame(TrainingData(
|
| 197 |
+
frame_idx=idx,
|
| 198 |
+
valid=True,
|
| 199 |
+
action_pos_left=left_state.pos,
|
| 200 |
+
action_orixyzw_left=left_state.ori_xyzw,
|
| 201 |
+
action_pos_right=right_state.pos,
|
| 202 |
+
action_orixyzw_right=right_state.ori_xyzw,
|
| 203 |
+
action_gripper_left=gripper_actions['left'][idx],
|
| 204 |
+
action_gripper_right=gripper_actions['right'][idx],
|
| 205 |
+
gripper_width_left=gripper_widths['left'][idx],
|
| 206 |
+
gripper_width_right=gripper_widths['right'][idx],
|
| 207 |
+
))
|
| 208 |
+
img_overlay.append(frame_results['rgb_robot_overlay'])
|
| 209 |
+
if "birdview" in self.debug_cameras:
|
| 210 |
+
img_birdview.append(frame_results['birdview_img'])
|
| 211 |
+
return sequence, img_overlay, img_birdview
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _process_single_frame(self, images: Dict[str, np.ndarray],
|
| 215 |
+
left_state: RobotState,
|
| 216 |
+
right_state: RobotState,
|
| 217 |
+
idx: int) -> Optional[Dict[str, np.ndarray]]:
|
| 218 |
+
"""
|
| 219 |
+
Process a single frame to generate robot overlay and validate tracking.
|
| 220 |
+
|
| 221 |
+
Args:
|
| 222 |
+
images: Dictionary containing human images and segmentation data
|
| 223 |
+
left_state: Target state for left robot arm
|
| 224 |
+
right_state: Target state for right robot arm
|
| 225 |
+
idx: Frame index for initialization and logging
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
Dictionary containing rendered robot overlay and debug camera views,
|
| 229 |
+
or None if tracking error exceeds threshold
|
| 230 |
+
"""
|
| 231 |
+
# Prepare robot target state based on configuration
|
| 232 |
+
if self.bimanual_setup == "single_arm":
|
| 233 |
+
if self.target_hand == "left":
|
| 234 |
+
target_state = {
|
| 235 |
+
"pos": left_state.pos,
|
| 236 |
+
"ori_xyzw": left_state.ori_xyzw,
|
| 237 |
+
"gripper_pos": left_state.gripper_pos,
|
| 238 |
+
}
|
| 239 |
+
else:
|
| 240 |
+
target_state = {
|
| 241 |
+
"pos": right_state.pos,
|
| 242 |
+
"ori_xyzw": right_state.ori_xyzw,
|
| 243 |
+
"gripper_pos": right_state.gripper_pos,
|
| 244 |
+
}
|
| 245 |
+
else:
|
| 246 |
+
# Bimanual configuration requires coordinated control
|
| 247 |
+
target_state = {
|
| 248 |
+
"pos": [right_state.pos, left_state.pos],
|
| 249 |
+
"ori_xyzw": [right_state.ori_xyzw, left_state.ori_xyzw],
|
| 250 |
+
"gripper_pos": [right_state.gripper_pos, left_state.gripper_pos],
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
# Move robot to target state and get simulation results
|
| 254 |
+
robot_results = self.twin_robot.move_to_target_state(
|
| 255 |
+
target_state, init=(idx == 0) # Initialize on first frame
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# Validate tracking accuracy to ensure quality
|
| 259 |
+
if self.bimanual_setup == "single_arm":
|
| 260 |
+
if robot_results['pos_err'] > self.TRACKING_ERROR_THRESHOLD:
|
| 261 |
+
print(f"Tracking error too large at frame {idx}, skipping", robot_results['pos_err'])
|
| 262 |
+
logger.warning(f"Tracking error too large at frame {idx}, skipping")
|
| 263 |
+
return None
|
| 264 |
+
else:
|
| 265 |
+
if robot_results['left_pos_err'] > self.TRACKING_ERROR_THRESHOLD or robot_results['right_pos_err'] > self.TRACKING_ERROR_THRESHOLD:
|
| 266 |
+
logger.warning(f"Tracking error too large at frame {idx}, skipping")
|
| 267 |
+
return None
|
| 268 |
+
|
| 269 |
+
# Generate robot overlay using appropriate method
|
| 270 |
+
if self.use_depth:
|
| 271 |
+
rgb_robot_overlay = self._process_robot_overlay_with_depth(
|
| 272 |
+
images['human_imgs'][idx],
|
| 273 |
+
images['human_masks'][idx],
|
| 274 |
+
images['imgs_depth'][idx],
|
| 275 |
+
robot_results
|
| 276 |
+
)
|
| 277 |
+
else:
|
| 278 |
+
rgb_robot_overlay = self._process_robot_overlay(
|
| 279 |
+
images['human_imgs'][idx], robot_results
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Prepare output with main overlay and debug camera views
|
| 283 |
+
output = {
|
| 284 |
+
'rgb_robot_overlay': rgb_robot_overlay,
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
# Add debug camera views if requested
|
| 288 |
+
for cam in self.debug_cameras:
|
| 289 |
+
output[f"{cam}_img"] = (robot_results[f"{cam}_img"] * 255).astype(np.uint8)
|
| 290 |
+
|
| 291 |
+
return output
|
| 292 |
+
|
| 293 |
+
def _should_skip_processing(self, save_folder: str) -> bool:
|
| 294 |
+
"""
|
| 295 |
+
Check if processing should be skipped due to existing output files.
|
| 296 |
+
|
| 297 |
+
Args:
|
| 298 |
+
save_folder: Directory where output files would be saved
|
| 299 |
+
|
| 300 |
+
Returns:
|
| 301 |
+
True if processing should be skipped, False otherwise
|
| 302 |
+
"""
|
| 303 |
+
if self.skip_existing:
|
| 304 |
+
try:
|
| 305 |
+
with os.scandir(save_folder) as it:
|
| 306 |
+
existing_files = {entry.name for entry in it if entry.is_file()}
|
| 307 |
+
if str("video_overlay"+f"_{self.robot}_{self.bimanual_setup}.mkv") in existing_files:
|
| 308 |
+
print(f"Skipping existing demo {save_folder}")
|
| 309 |
+
return True
|
| 310 |
+
except FileNotFoundError:
|
| 311 |
+
return False
|
| 312 |
+
return False
|
| 313 |
+
|
| 314 |
+
def _load_data(self, paths: Paths) -> Dict[str, np.ndarray]:
|
| 315 |
+
"""
|
| 316 |
+
Load robot trajectory data from smoothed action files.
|
| 317 |
+
|
| 318 |
+
Args:
|
| 319 |
+
paths: Paths object containing file locations
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
Dictionary containing robot trajectory data and frame indices
|
| 323 |
+
"""
|
| 324 |
+
if self.bimanual_setup == "single_arm":
|
| 325 |
+
# Get paths based on target hand for single-arm operation
|
| 326 |
+
smoothed_base = getattr(paths, f"smoothed_actions_{self.target_hand}")
|
| 327 |
+
actions_base = getattr(paths, f"actions_{self.target_hand}")
|
| 328 |
+
smoothed_actions_path = str(smoothed_base).replace(".npz", f"_{self.bimanual_setup}.npz")
|
| 329 |
+
actions_path = str(actions_base).replace(".npz", f"_{self.bimanual_setup}.npz")
|
| 330 |
+
|
| 331 |
+
# Load actual trajectory data for target hand
|
| 332 |
+
ee_pts = np.load(smoothed_actions_path)["ee_pts"]
|
| 333 |
+
ee_oris = np.load(smoothed_actions_path)["ee_oris"]
|
| 334 |
+
|
| 335 |
+
# Create dummy data for non-target hand
|
| 336 |
+
dummy_pts = np.zeros((len(ee_pts), 3))
|
| 337 |
+
dummy_oris = np.eye(3)[None, :, :].repeat(len(ee_oris), axis=0)
|
| 338 |
+
|
| 339 |
+
# Create data dictionary with target hand data and dummy data for other hand
|
| 340 |
+
other_hand = "right" if self.target_hand == "left" else "left"
|
| 341 |
+
return {
|
| 342 |
+
f'ee_pts_{self.target_hand}': ee_pts,
|
| 343 |
+
f'ee_oris_{self.target_hand}': ee_oris,
|
| 344 |
+
f'ee_pts_{other_hand}': dummy_pts,
|
| 345 |
+
f'ee_oris_{other_hand}': dummy_oris,
|
| 346 |
+
'union_indices': np.load(actions_path, allow_pickle=True)["union_indices"]
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
# Load bimanual trajectory data
|
| 350 |
+
smoothed_actions_left_path = str(paths.smoothed_actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 351 |
+
smoothed_actions_right_path = str(paths.smoothed_actions_right).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 352 |
+
actions_left_path = str(paths.actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 353 |
+
return {
|
| 354 |
+
'ee_pts_left': np.load(smoothed_actions_left_path)["ee_pts"],
|
| 355 |
+
'ee_oris_left': np.load(smoothed_actions_left_path)["ee_oris"],
|
| 356 |
+
'ee_pts_right': np.load(smoothed_actions_right_path)["ee_pts"],
|
| 357 |
+
'ee_oris_right': np.load(smoothed_actions_right_path)["ee_oris"],
|
| 358 |
+
'union_indices': np.load(actions_left_path, allow_pickle=True)["union_indices"]
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
def _load_images(self, paths: Paths, union_indices: np.ndarray) -> Dict[str, np.ndarray]:
|
| 362 |
+
"""
|
| 363 |
+
Load and index human demonstration images and associated data.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
paths: Paths object containing image file locations
|
| 367 |
+
union_indices: Frame indices to extract from full video sequences
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
Dictionary containing indexed human images, masks, and depth data
|
| 371 |
+
"""
|
| 372 |
+
return {
|
| 373 |
+
'human_masks': np.load(paths.masks_arm)[union_indices],
|
| 374 |
+
'human_imgs': np.array(media.read_video(paths.video_human_inpaint))[union_indices],
|
| 375 |
+
'imgs_depth': np.load(paths.depth)[union_indices] if self.use_depth else None
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
def _process_gripper_widths(self, paths: Paths, data: Dict[str, np.ndarray]) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
|
| 379 |
+
"""
|
| 380 |
+
Process gripper distance data into robot action commands.
|
| 381 |
+
|
| 382 |
+
Args:
|
| 383 |
+
paths: Paths object containing smoothed action file locations
|
| 384 |
+
data: Dictionary containing trajectory data and frame indices
|
| 385 |
+
|
| 386 |
+
Returns:
|
| 387 |
+
Tuple containing:
|
| 388 |
+
- Dictionary of gripper action commands for each hand
|
| 389 |
+
- Dictionary of gripper width values for each hand
|
| 390 |
+
"""
|
| 391 |
+
if self.bimanual_setup == "single_arm":
|
| 392 |
+
# Get the appropriate smoothed actions path based on target hand
|
| 393 |
+
base_path = getattr(paths, f"smoothed_actions_{self.target_hand}")
|
| 394 |
+
smoothed_actions_path = str(base_path).replace(".npz", f"_{self.bimanual_setup}.npz")
|
| 395 |
+
|
| 396 |
+
# Compute gripper actions and widths from smoothed data
|
| 397 |
+
actions, widths = self._compute_gripper_actions(
|
| 398 |
+
np.load(smoothed_actions_path)["ee_widths"]
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
# Create return dictionaries with actions for target hand, zeros for the other
|
| 402 |
+
num_indices = len(data['union_indices'])
|
| 403 |
+
other_hand = "right" if self.target_hand == "left" else "left"
|
| 404 |
+
|
| 405 |
+
return (
|
| 406 |
+
{self.target_hand: actions, other_hand: np.zeros(num_indices)},
|
| 407 |
+
{self.target_hand: widths, other_hand: np.zeros(num_indices)}
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# Process bimanual gripper data
|
| 411 |
+
smoothed_actions_left_path = str(paths.smoothed_actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 412 |
+
smoothed_actions_right_path = str(paths.smoothed_actions_right).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 413 |
+
left_actions, left_widths = self._compute_gripper_actions(
|
| 414 |
+
np.load(smoothed_actions_left_path)["ee_widths"]
|
| 415 |
+
)
|
| 416 |
+
right_actions, right_widths = self._compute_gripper_actions(
|
| 417 |
+
np.load(smoothed_actions_right_path)["ee_widths"]
|
| 418 |
+
)
|
| 419 |
+
return {'left': left_actions, 'right': right_actions}, {'left': left_widths, 'right': right_widths}
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def _compute_gripper_actions(self, list_gripper_dist: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 423 |
+
"""
|
| 424 |
+
Convert continuous gripper distances to discrete robot gripper actions.

Args:
|
| 426 |
+
list_gripper_dist: Array of gripper distances throughout trajectory
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
Tuple containing:
|
| 430 |
+
- Gripper action commands (0 for grasp, distance for open)
|
| 431 |
+
- Processed gripper width values
|
| 432 |
+
"""
|
| 433 |
+
try:
|
| 434 |
+
# Analyze gripper distance range and determine grasp threshold
|
| 435 |
+
min_val, max_val = np.min(list_gripper_dist), np.max(list_gripper_dist)
|
| 436 |
+
thresh = min_val + 0.2 * (max_val - min_val) # 20% above minimum
|
| 437 |
+
|
| 438 |
+
# Classify gripper states: 0 = closed/grasping, 1 = open
|
| 439 |
+
gripper_state = np.array([0 if dist < thresh else 1 for dist in list_gripper_dist])
|
| 440 |
+
|
| 441 |
+
# Find range of grasping action
|
| 442 |
+
min_idx_pos = np.where(gripper_state == 0)[0][0]
|
| 443 |
+
max_idx_pos = np.where(gripper_state == 0)[0][-1]
|
| 444 |
+
|
| 445 |
+
# Generate gripper action commands
|
| 446 |
+
list_gripper_actions = []
|
| 447 |
+
for idx in range(len(list_gripper_dist)):
|
| 448 |
+
if min_idx_pos <= idx <= max_idx_pos:
|
| 449 |
+
# During grasping phase: use grasp command (0) and limit distance
|
| 450 |
+
list_gripper_actions.append(0)
|
| 451 |
+
list_gripper_dist[idx] = np.min([list_gripper_dist[idx], thresh])
|
| 452 |
+
else:
|
| 453 |
+
# Outside grasping phase: use distance as action command
|
| 454 |
+
list_gripper_actions.append(list_gripper_dist[idx])
|
| 455 |
+
except (IndexError, ValueError):
|
| 456 |
+
# Fallback: use distances directly if processing fails
|
| 457 |
+
list_gripper_actions = list_gripper_dist.tolist()
|
| 458 |
+
|
| 459 |
+
return np.array(list_gripper_actions), list_gripper_dist
|
| 460 |
+
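As a rough illustration of the thresholding in _compute_gripper_actions (a sketch with made-up distances, not code from the repository): the grasp threshold sits 20% of the way from the minimum to the maximum gripper distance, and every frame between the first and last sub-threshold frame is treated as a grasp.

import numpy as np

dists = np.array([0.08, 0.06, 0.02, 0.015, 0.02, 0.05, 0.09])
thresh = dists.min() + 0.2 * (dists.max() - dists.min())       # 0.03
closed = np.where(dists < thresh)[0]                           # frames 2..4
first, last = closed[0], closed[-1]
actions = [0 if first <= i <= last else d for i, d in enumerate(dists)]
# actions -> [0.08, 0.06, 0, 0, 0, 0.05, 0.09]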
|
| 461 |
+
def _get_robot_state(self, ee_pt: np.ndarray, ori_matrix: np.ndarray, gripper_dist: float) -> RobotState:
|
| 462 |
+
"""
|
| 463 |
+
Convert trajectory data to robot state representation.
|
| 464 |
+
|
| 465 |
+
Args:
|
| 466 |
+
ee_pt: End-effector position in 3D space
|
| 467 |
+
ori_matrix: 3x3 rotation matrix for end-effector orientation
|
| 468 |
+
gripper_dist: Gripper opening distance
|
| 469 |
+
|
| 470 |
+
Returns:
|
| 471 |
+
RobotState object containing pose and gripper information
|
| 472 |
+
"""
|
| 473 |
+
# Convert rotation matrix to quaternion (XYZW format for robot control)
|
| 474 |
+
ori_xyzw = Rotation.from_matrix(ori_matrix).as_quat(scalar_first=False)
|
| 475 |
+
robot_state = RobotState(pos=ee_pt, ori_xyzw=ori_xyzw, gripper_pos=gripper_dist)
|
| 476 |
+
return robot_state
|
| 477 |
+
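Note the quaternion ordering: as_quat(scalar_first=False) yields XYZW, the layout used here for robot control, while the MuJoCo camera helper further below uses scalar_first=True for WXYZ. A quick sanity check with the identity rotation, purely for illustration:

import numpy as np
from scipy.spatial.transform import Rotation

Rotation.from_matrix(np.eye(3)).as_quat(scalar_first=False)    # [0., 0., 0., 1.]  (XYZW)
Rotation.from_matrix(np.eye(3)).as_quat(scalar_first=True)     # [1., 0., 0., 0.]  (WXYZ)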
|
| 478 |
+
def _process_robot_overlay(self, img: np.ndarray, robot_results: Dict[str, Any]) -> np.ndarray:
|
| 479 |
+
"""
|
| 480 |
+
Create robot overlay on human image using segmentation masks.
|
| 481 |
+
|
| 482 |
+
Args:
|
| 483 |
+
img: Original human demonstration image
|
| 484 |
+
robot_results: Dictionary containing robot rendering results
|
| 485 |
+
|
| 486 |
+
Returns:
|
| 487 |
+
Image with robot overlay applied
|
| 488 |
+
"""
|
| 489 |
+
# Extract robot rendering and segmentation data
|
| 490 |
+
rgb_img_sim = (robot_results['rgb_img'] * 255).astype(np.uint8)
|
| 491 |
+
H, W = rgb_img_sim.shape[:2]
|
| 492 |
+
|
| 493 |
+
# Resize robot rendering and masks to match output resolution
|
| 494 |
+
if self.square:
|
| 495 |
+
rgb_img_sim = cv2.resize(rgb_img_sim, (self.output_resolution, self.output_resolution))
|
| 496 |
+
robot_mask = cv2.resize(robot_results['robot_mask'], (self.output_resolution, self.output_resolution))
|
| 497 |
+
robot_mask[robot_mask > 0] = 1
|
| 498 |
+
gripper_mask = cv2.resize(robot_results['gripper_mask'], (self.output_resolution, self.output_resolution))
|
| 499 |
+
gripper_mask[gripper_mask > 0] = 1
|
| 500 |
+
else:
|
| 501 |
+
rgb_img_sim = cv2.resize(rgb_img_sim, (int(W/H*self.output_resolution), self.output_resolution))
|
| 502 |
+
robot_mask = cv2.resize(robot_results['robot_mask'], (int(W/H*self.output_resolution), self.output_resolution))
|
| 503 |
+
robot_mask[robot_mask > 0] = 1
|
| 504 |
+
gripper_mask = cv2.resize(robot_results['gripper_mask'], (int(W/H*self.output_resolution), self.output_resolution))
|
| 505 |
+
gripper_mask[gripper_mask > 0] = 1
|
| 506 |
+
|
| 507 |
+
# Create overlay by compositing robot over human image
|
| 508 |
+
img_robot_overlay = img.copy()
|
| 509 |
+
overlay_mask = (robot_mask == 1) | (gripper_mask == 1)
|
| 510 |
+
img_robot_overlay[overlay_mask] = rgb_img_sim[overlay_mask]
|
| 511 |
+
|
| 512 |
+
return img_robot_overlay
|
| 513 |
+
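The compositing itself is a boolean masked copy; a tiny sketch with made-up 2x2 images (not part of the original file):

import numpy as np

img = np.full((2, 2, 3), 100, dtype=np.uint8)     # human frame
sim = np.full((2, 2, 3), 200, dtype=np.uint8)     # rendered robot, same resolution
robot_mask = np.array([[1, 0], [0, 0]])
gripper_mask = np.array([[0, 0], [0, 1]])

out = img.copy()
m = (robot_mask == 1) | (gripper_mask == 1)
out[m] = sim[m]      # only pixels under the robot or gripper mask are replaced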
|
| 514 |
+
def _process_robot_overlay_with_depth(self, img: np.ndarray, hand_mask: np.ndarray,
|
| 515 |
+
img_depth: np.ndarray, robot_results: Dict[str, Any]) -> np.ndarray:
|
| 516 |
+
"""
|
| 517 |
+
Create depth-aware robot overlay with realistic occlusion handling.
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
img: Original human demonstration image
|
| 521 |
+
hand_mask: Segmentation mask of human hand regions
|
| 522 |
+
img_depth: Depth image corresponding to the demonstration
|
| 523 |
+
robot_results: Dictionary containing robot rendering and depth results
|
| 524 |
+
|
| 525 |
+
Returns:
|
| 526 |
+
Image with depth-aware robot overlay applied
|
| 527 |
+
"""
|
| 528 |
+
# Extract robot rendering and depth data
|
| 529 |
+
robot_mask = robot_results['robot_mask']
|
| 530 |
+
gripper_mask = robot_results['gripper_mask']
|
| 531 |
+
rgb_img_sim = robot_results['rgb_img']
|
| 532 |
+
depth_img_sim = np.squeeze(robot_results['depth_img'])
|
| 533 |
+
H, W = rgb_img_sim.shape[:2]
|
| 534 |
+
|
| 535 |
+
# Create masked depth images for occlusion analysis
|
| 536 |
+
depth_sim_masked = self._create_masked_depth(depth_img_sim, robot_mask, gripper_mask)
|
| 537 |
+
depth_masked = self._create_masked_depth(img_depth, robot_mask, gripper_mask)
|
| 538 |
+
|
| 539 |
+
# Process hand mask for improved occlusion handling
|
| 540 |
+
hand_mask = self._dilate_mask(hand_mask.astype(np.uint8))
|
| 541 |
+
|
| 542 |
+
# Create overlay mask using depth-based occlusion
|
| 543 |
+
img_robot_overlay = img.copy()
|
| 544 |
+
overlay_mask = self._create_overlay_mask(
|
| 545 |
+
robot_mask, gripper_mask, depth_masked, depth_sim_masked, hand_mask
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
# Convert and resize robot rendering
|
| 549 |
+
rgb_img_sim = (rgb_img_sim * 255).astype(np.uint8)
|
| 550 |
+
|
| 551 |
+
if self.square:
|
| 552 |
+
resize_shape = (self.output_resolution, self.output_resolution)
|
| 553 |
+
else:
|
| 554 |
+
resize_shape = (int(W/H*self.output_resolution), self.output_resolution)
|
| 555 |
+
|
| 556 |
+
# Apply final overlay with depth-aware occlusion
|
| 557 |
+
rgb_img_sim = cv2.resize(rgb_img_sim, resize_shape)
|
| 558 |
+
overlay_mask = cv2.resize(overlay_mask.astype(np.uint8), resize_shape)
|
| 559 |
+
overlay_mask[overlay_mask > 0] = 1
|
| 560 |
+
overlay_mask = overlay_mask.astype(bool)
|
| 561 |
+
|
| 562 |
+
img_robot_overlay[overlay_mask] = rgb_img_sim[overlay_mask]
|
| 563 |
+
|
| 564 |
+
return img_robot_overlay
|
| 565 |
+
|
| 566 |
+
def _create_masked_depth(self, depth_img: np.ndarray, robot_mask: np.ndarray,
|
| 567 |
+
gripper_mask: np.ndarray) -> np.ndarray:
|
| 568 |
+
"""
|
| 569 |
+
Create depth image masked to robot regions for occlusion analysis.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
depth_img: Input depth image
|
| 573 |
+
robot_mask: Binary mask indicating robot regions
|
| 574 |
+
gripper_mask: Binary mask indicating gripper regions
|
| 575 |
+
|
| 576 |
+
Returns:
|
| 577 |
+
Depth image with values only in robot/gripper regions
|
| 578 |
+
"""
|
| 579 |
+
masked_img = np.zeros_like(depth_img)
|
| 580 |
+
mask = (robot_mask == 1) | (gripper_mask == 1)
|
| 581 |
+
masked_img[mask] = depth_img[mask]
|
| 582 |
+
return masked_img
|
| 583 |
+
|
| 584 |
+
def _dilate_mask(self, mask: np.ndarray) -> np.ndarray:
|
| 585 |
+
"""
|
| 586 |
+
Apply morphological dilation to expand mask boundaries.
|
| 587 |
+
|
| 588 |
+
Args:
|
| 589 |
+
mask: Binary mask to dilate
|
| 590 |
+
|
| 591 |
+
Returns:
|
| 592 |
+
Dilated binary mask
|
| 593 |
+
"""
|
| 594 |
+
kernel = np.ones((5, 5), np.uint8)
|
| 595 |
+
return cv2.dilate(mask, kernel, iterations=1)
|
| 596 |
+
|
| 597 |
+
def _create_overlay_mask(self, robot_mask: np.ndarray, gripper_mask: np.ndarray,
|
| 598 |
+
depth_masked: np.ndarray, depth_sim_masked: np.ndarray,
|
| 599 |
+
hand_mask: np.ndarray) -> np.ndarray:
|
| 600 |
+
"""
|
| 601 |
+
Create sophisticated overlay mask using depth-based occlusion reasoning.
|
| 602 |
+
|
| 603 |
+
Args:
|
| 604 |
+
robot_mask: Binary mask for robot body regions
|
| 605 |
+
gripper_mask: Binary mask for robot gripper regions
|
| 606 |
+
depth_masked: Real depth image masked to robot regions
|
| 607 |
+
depth_sim_masked: Simulated robot depth masked to robot regions
|
| 608 |
+
hand_mask: Binary mask for human hand regions
|
| 609 |
+
|
| 610 |
+
Returns:
|
| 611 |
+
Binary mask indicating where robot overlay should be applied
|
| 612 |
+
"""
|
| 613 |
+
# Start with basic robot visibility mask
|
| 614 |
+
overlay_mask = (robot_mask == 1) | (gripper_mask == 1)
|
| 615 |
+
|
| 616 |
+
# Apply depth-based occlusion: hide robot when it's behind real objects
|
| 617 |
+
# and not in hand regions (where occlusion handling is more complex)
|
| 618 |
+
overlay_mask[(depth_masked < depth_sim_masked) & (hand_mask == 0)] = 0
|
| 619 |
+
|
| 620 |
+
return overlay_mask
|
| 621 |
+
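To make the occlusion rule concrete, a toy 1x3-pixel sketch (illustrative values only): a robot pixel is dropped from the overlay wherever the real scene is closer than the rendered robot, unless it falls inside the dilated hand mask.

import numpy as np

robot_mask   = np.array([[1, 1, 1]])
gripper_mask = np.array([[0, 0, 0]])
real_depth   = np.array([[0.4, 0.9, 0.4]])   # scene depth, masked to robot pixels
sim_depth    = np.array([[0.6, 0.6, 0.6]])   # rendered robot depth
hand_mask    = np.array([[0, 0, 1]])

overlay = (robot_mask == 1) | (gripper_mask == 1)
overlay[(real_depth < sim_depth) & (hand_mask == 0)] = 0
# overlay -> [[False, True, True]]: pixel 0 is occluded by the scene,
# pixel 2 stays visible because it lies inside the hand mask.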
|
| 622 |
+
def _save_results(self, paths: Paths, sequence: TrainingDataSequence, img_overlay: List[np.ndarray],
|
| 623 |
+
img_birdview: Optional[List[np.ndarray]] = None) -> None:
|
| 624 |
+
"""
|
| 625 |
+
Save comprehensive robot inpainting results to disk.
|
| 626 |
+
|
| 627 |
+
Args:
|
| 628 |
+
paths: Paths object containing output file locations
|
| 629 |
+
sequence: Training data sequence with robot state annotations
|
| 630 |
+
img_overlay: List of robot overlay images
|
| 631 |
+
img_birdview: Optional list of bird's eye view images for analysis
|
| 632 |
+
"""
|
| 633 |
+
# Create output directory
|
| 634 |
+
os.makedirs(paths.inpaint_processor, exist_ok=True)
|
| 635 |
+
|
| 636 |
+
if len(img_overlay) == 0:
|
| 637 |
+
print("No robot inpainted images, skipping")
|
| 638 |
+
return
|
| 639 |
+
|
| 640 |
+
# Save main robot-inpainted video
|
| 641 |
+
video_path = str(paths.video_overlay).split(".mkv")[0] + f"_{self.robot}_{self.bimanual_setup}.mkv"
|
| 642 |
+
self._save_video(video_path, img_overlay)
|
| 643 |
+
|
| 644 |
+
# Save bird's eye view video for analysis and debugging
|
| 645 |
+
if img_birdview is not None:
|
| 646 |
+
birdview_path = str(paths.video_birdview).split(".mkv")[0] + f"_{self.robot}_{self.bimanual_setup}.mkv"
|
| 647 |
+
self._save_video(birdview_path, np.array(img_birdview))
|
| 648 |
+
|
| 649 |
+
# Save comprehensive training data with robot state annotations
|
| 650 |
+
training_data_path = str(paths.training_data).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 651 |
+
sequence.save(training_data_path)
|
| 652 |
+
|
| 653 |
+
def _save_video(self, path: str, frames: List[np.ndarray]) -> None:
|
| 654 |
+
"""
|
| 655 |
+
Save video with consistent encoding parameters.
|
| 656 |
+
|
| 657 |
+
Args:
|
| 658 |
+
path: Output video file path
|
| 659 |
+
frames: List of video frames to save
|
| 660 |
+
"""
|
| 661 |
+
media.write_video(
|
| 662 |
+
path,
|
| 663 |
+
frames,
|
| 664 |
+
fps=self.DEFAULT_FPS,
|
| 665 |
+
codec=self.DEFAULT_CODEC
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
def _get_mujoco_camera_params(self) -> MujocoCameraParams:
|
| 669 |
+
"""
|
| 670 |
+
Generate MuJoCo camera parameters from real-world camera calibration.
|
| 671 |
+
|
| 672 |
+
Returns:
|
| 673 |
+
MujocoCameraParams object with calibrated camera settings
|
| 674 |
+
"""
|
| 675 |
+
# Extract real-world camera extrinsics and convert to MuJoCo format
|
| 676 |
+
extrinsics = self.extrinsics[0]
|
| 677 |
+
camera_ori_wxyz = self._convert_real_camera_ori_to_mujoco(
|
| 678 |
+
np.array(extrinsics["camera_base_ori"])
|
| 679 |
+
)
|
| 680 |
+
|
| 681 |
+
# Calculate image dimensions and camera intrinsics
|
| 682 |
+
img_w, img_h = self._get_image_dimensions()
|
| 683 |
+
offset = self._calculate_image_offset(img_w, img_h)
|
| 684 |
+
fx, fy, cx, cy = self._get_camera_intrinsics(offset)
|
| 685 |
+
sensor_width, sensor_height = self._calculate_sensor_size(img_w, img_h, fx, fy)
|
| 686 |
+
|
| 687 |
+
# Select appropriate camera name based on dataset
|
| 688 |
+
if self.epic:
|
| 689 |
+
camera_name = "zed"
|
| 690 |
+
else:
|
| 691 |
+
camera_name = "frontview"
|
| 692 |
+
|
| 693 |
+
return MujocoCameraParams(
|
| 694 |
+
name=camera_name,
|
| 695 |
+
pos=extrinsics["camera_base_pos"],
|
| 696 |
+
ori_wxyz=camera_ori_wxyz,
|
| 697 |
+
fov=self.intrinsics_dict["v_fov"],
|
| 698 |
+
resolution=(img_h, img_w),
|
| 699 |
+
sensorsize=np.array([sensor_width, sensor_height]),
|
| 700 |
+
principalpixel=np.array([img_w/2-cx, cy-img_h/2]),
|
| 701 |
+
focalpixel=np.array([fx, fy])
|
| 702 |
+
)
|
| 703 |
+
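As a rough numeric illustration of this mapping (hypothetical 1920x1080 intrinsics, mirroring the formulas in the helpers below):

import numpy as np

img_w, img_h = 1920, 1080
fx, fy, cx, cy = 1400.0, 1400.0, 965.0, 545.0
sensorsize     = np.array([img_w / fy / 1000, img_h / fx / 1000])   # ~[1.37e-3, 7.71e-4]
principalpixel = np.array([img_w / 2 - cx, cy - img_h / 2])         # [-5., 5.]
focalpixel     = np.array([fx, fy])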
|
| 704 |
+
def _get_image_dimensions(self) -> Tuple[int, int]:
|
| 705 |
+
"""
|
| 706 |
+
Calculate image dimensions based on input resolution configuration.
|
| 707 |
+
|
| 708 |
+
Returns:
|
| 709 |
+
Tuple of (width, height) in pixels
|
| 710 |
+
"""
|
| 711 |
+
# Epic dataset frames are 456x256
if self.input_resolution == 256:
    img_w = 456
# Phantom paper setup uses 16:9 frames (1920x1080)
elif self.input_resolution == 1080:
    img_w = self.input_resolution * 16 // 9
else:
    raise ValueError(f"Unsupported input resolution: {self.input_resolution}")
img_h = self.input_resolution
return img_w, img_h
|
| 719 |
+
|
| 720 |
+
def _calculate_image_offset(self, img_w: int, img_h: int) -> int:
|
| 721 |
+
"""
|
| 722 |
+
Calculate horizontal image offset for square aspect ratio processing.
|
| 723 |
+
|
| 724 |
+
Args:
|
| 725 |
+
img_w: Image width in pixels
|
| 726 |
+
img_h: Image height in pixels
|
| 727 |
+
|
| 728 |
+
Returns:
|
| 729 |
+
Horizontal offset in pixels
|
| 730 |
+
"""
|
| 731 |
+
if self.square:
|
| 732 |
+
offset = (img_w - img_h) // 2
|
| 733 |
+
else:
|
| 734 |
+
offset = 0
|
| 735 |
+
return offset
|
| 736 |
+
|
| 737 |
+
def _get_camera_intrinsics(self, offset: int) -> Tuple[float, float, float, float]:
|
| 738 |
+
"""
|
| 739 |
+
Extract camera intrinsic parameters with offset correction.
|
| 740 |
+
|
| 741 |
+
Args:
|
| 742 |
+
offset: Horizontal offset for principal point adjustment
|
| 743 |
+
|
| 744 |
+
Returns:
|
| 745 |
+
Tuple of (fx, fy, cx, cy) camera intrinsic parameters
|
| 746 |
+
"""
|
| 747 |
+
return self.intrinsics_dict["fx"], self.intrinsics_dict["fy"], self.intrinsics_dict["cx"]+offset, self.intrinsics_dict["cy"]
|
| 748 |
+
|
| 749 |
+
def _calculate_sensor_size(self, img_w: int, img_h: int, fx: float, fy: float) -> Tuple[float, float]:
|
| 750 |
+
"""
|
| 751 |
+
Calculate physical sensor dimensions from image resolution and focal length.
|
| 752 |
+
|
| 753 |
+
Args:
|
| 754 |
+
img_w: Image width in pixels
|
| 755 |
+
img_h: Image height in pixels
|
| 756 |
+
fx: Focal length in x direction (pixels)
|
| 757 |
+
fy: Focal length in y direction (pixels)
|
| 758 |
+
|
| 759 |
+
Returns:
|
| 760 |
+
Tuple of (sensor_width, sensor_height) in meters
|
| 761 |
+
"""
|
| 762 |
+
sensor_width = img_w / fy / 1000
|
| 763 |
+
sensor_height = img_h / fx / 1000
|
| 764 |
+
return sensor_width, sensor_height
|
| 765 |
+
|
| 766 |
+
@staticmethod
|
| 767 |
+
def _convert_real_camera_ori_to_mujoco(camera_ori_matrix: np.ndarray) -> np.ndarray:
|
| 768 |
+
"""
|
| 769 |
+
Convert real-world camera orientation to MuJoCo coordinate system.
|
| 770 |
+
|
| 771 |
+
Args:
|
| 772 |
+
camera_ori_matrix: 3x3 rotation matrix in real-world coordinates
|
| 773 |
+
|
| 774 |
+
Returns:
|
| 775 |
+
Quaternion in WXYZ format for MuJoCo
|
| 776 |
+
"""
|
| 777 |
+
# Apply coordinate system transformation (flip Y and Z axes)
|
| 778 |
+
camera_ori_matrix[:, [1, 2]] = -camera_ori_matrix[:, [1, 2]]
|
| 779 |
+
|
| 780 |
+
# Convert to quaternion in MuJoCo's WXYZ format
|
| 781 |
+
r = Rotation.from_matrix(camera_ori_matrix)
|
| 782 |
+
camera_ori_wxyz = r.as_quat(scalar_first=True)
|
| 783 |
+
return camera_ori_wxyz
|
| 784 |
+
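A minimal sketch of this conversion, using the identity camera orientation as a hypothetical input:

import numpy as np
from scipy.spatial.transform import Rotation

ori = np.eye(3)
ori[:, [1, 2]] = -ori[:, [1, 2]]                                 # flip Y and Z axes
quat_wxyz = Rotation.from_matrix(ori).as_quat(scalar_first=True)
# -> [0., 1., 0., 0.], i.e. a 180 degree rotation about the camera X axis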
|
| 785 |
+
|
phantom/phantom/processors/segmentation_processor.py
ADDED
|
@@ -0,0 +1,1056 @@
| 1 |
+
"""
|
| 2 |
+
Segmentation Processor Module
|
| 3 |
+
|
| 4 |
+
This module uses SAM2 to create masks of hands and arms in video sequences.
|
| 5 |
+
|
| 6 |
+
Processing Pipeline:
|
| 7 |
+
1. Load video frames and detection/pose data from previous stages
|
| 8 |
+
2. Initialize segmentation with highest-quality detection frame
|
| 9 |
+
3. Propagate segmentation bidirectionally (forward and reverse)
|
| 10 |
+
4. Combine temporal results for complete sequence coverage
|
| 11 |
+
5. Generate visualization videos and save segmentation masks
|
| 12 |
+
|
| 13 |
+
The module supports different segmentation modes:
|
| 14 |
+
- HandSegmentationProcessor: Precise hand-only segmentation
|
| 15 |
+
- ArmSegmentationProcessor: Combined hand + arm segmentation
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import logging
|
| 20 |
+
import shutil
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
import numpy as np
|
| 23 |
+
import mediapy as media
|
| 24 |
+
import argparse
|
| 25 |
+
from typing import Dict, Tuple, Optional, List
|
| 26 |
+
|
| 27 |
+
from phantom.processors.paths import Paths
|
| 28 |
+
from phantom.processors.base_processor import BaseProcessor
|
| 29 |
+
from phantom.detectors.detector_sam2 import DetectorSam2
|
| 30 |
+
from phantom.detectors.detector_detectron2 import DetectorDetectron2
|
| 31 |
+
from phantom.utils.bbox_utils import get_overlap_score
|
| 32 |
+
from phantom.processors.phantom_data import HandSequence
|
| 33 |
+
|
| 34 |
+
logger = logging.getLogger(__name__)
|
| 35 |
+
|
| 36 |
+
# Configuration constants for segmentation processing
|
| 37 |
+
DEFAULT_FPS = 10
|
| 38 |
+
DEFAULT_OVERLAP_THRESHOLD = 0.5
|
| 39 |
+
DEFAULT_CODEC = "ffv1"
|
| 40 |
+
ANNOTATION_CODEC = "h264"
|
| 41 |
+
|
| 42 |
+
class BaseSegmentationProcessor(BaseProcessor):
|
| 43 |
+
"""
|
| 44 |
+
Base class for video segmentation processing using SAM2.
|
| 45 |
+
|
| 46 |
+
The base processor establishes the framework for temporal segmentation processing,
|
| 47 |
+
where segmentation masks are propagated both forward and backward through time
|
| 48 |
+
to ensure temporal consistency and complete coverage of the video sequence.
|
| 49 |
+
|
| 50 |
+
Attributes:
|
| 51 |
+
detector_sam (DetectorSam2): SAM2 segmentation model instance
|
| 52 |
+
"""
|
| 53 |
+
def __init__(self, args: argparse.Namespace) -> None:
|
| 54 |
+
"""
|
| 55 |
+
Initialize the base segmentation processor.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
args: Command line arguments containing segmentation configuration
|
| 59 |
+
"""
|
| 60 |
+
super().__init__(args)
|
| 61 |
+
self.detector_sam = DetectorSam2()
|
| 62 |
+
|
| 63 |
+
def process_one_demo(self, data_sub_folder: str) -> None:
|
| 64 |
+
"""
|
| 65 |
+
Process a single demonstration - to be implemented by subclasses.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
data_sub_folder: Path to demonstration data folder
|
| 69 |
+
|
| 70 |
+
Raises:
|
| 71 |
+
NotImplementedError: Must be implemented by concrete subclasses
|
| 72 |
+
"""
|
| 73 |
+
raise NotImplementedError("Subclasses must implement this method")
|
| 74 |
+
|
| 75 |
+
def _load_hamer_data(self, paths: Paths) -> Dict[str, HandSequence]:
|
| 76 |
+
"""
|
| 77 |
+
Load hand pose estimation data from previous processing stage.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
paths: Paths object containing file locations
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Dictionary containing left and right hand sequences
|
| 84 |
+
"""
|
| 85 |
+
if self.bimanual_setup == "single_arm":
|
| 86 |
+
if self.target_hand == "left":
|
| 87 |
+
return {"left": HandSequence.load(paths.hand_data_left)}
|
| 88 |
+
elif self.target_hand == "right":
|
| 89 |
+
return {"right": HandSequence.load(paths.hand_data_right)}
|
| 90 |
+
else:
|
| 91 |
+
raise ValueError(f"Invalid target hand: {self.target_hand}")
|
| 92 |
+
elif self.bimanual_setup == "shoulders":
|
| 93 |
+
return {
|
| 94 |
+
"left": HandSequence.load(paths.hand_data_left),
|
| 95 |
+
"right": HandSequence.load(paths.hand_data_right)
|
| 96 |
+
}
|
| 97 |
+
else:
|
| 98 |
+
raise ValueError(f"Invalid bimanual setup: {self.bimanual_setup}")
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def _load_video(video_path: str) -> np.ndarray:
|
| 102 |
+
"""
|
| 103 |
+
Load and validate video frames from disk.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
video_path: Path to video file
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
Array of RGB video frames
|
| 110 |
+
|
| 111 |
+
Raises:
|
| 112 |
+
FileNotFoundError: If video file doesn't exist
|
| 113 |
+
ValueError: If video file is empty or corrupted
|
| 114 |
+
"""
|
| 115 |
+
if not os.path.exists(video_path):
|
| 116 |
+
raise FileNotFoundError(f"Video file not found: {video_path}")
|
| 117 |
+
|
| 118 |
+
imgs_rgb = media.read_video(video_path)
|
| 119 |
+
if len(imgs_rgb) == 0:
|
| 120 |
+
raise ValueError("Empty video file")
|
| 121 |
+
|
| 122 |
+
return imgs_rgb
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
def _load_bbox_data(bbox_path: str) -> Dict[str, np.ndarray]:
|
| 126 |
+
"""
|
| 127 |
+
Load and validate bounding box detection data.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
bbox_path: Path to bounding box data file
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
Dictionary containing detection results from bounding box processor
|
| 134 |
+
|
| 135 |
+
Raises:
|
| 136 |
+
FileNotFoundError: If bounding box data file doesn't exist
|
| 137 |
+
"""
|
| 138 |
+
if not os.path.exists(bbox_path):
|
| 139 |
+
raise FileNotFoundError(f"Bbox data not found: {bbox_path}")
|
| 140 |
+
|
| 141 |
+
return np.load(bbox_path)
|
| 142 |
+
|
| 143 |
+
@staticmethod
|
| 144 |
+
def _combine_sam_images(
|
| 145 |
+
imgs_rgb: np.ndarray,
|
| 146 |
+
imgs_forward: Dict[int, np.ndarray],
|
| 147 |
+
imgs_reverse: Dict[int, np.ndarray]
|
| 148 |
+
) -> np.ndarray:
|
| 149 |
+
"""
|
| 150 |
+
Combine forward and reverse SAM visualization images.
|
| 151 |
+
|
| 152 |
+
This method merges the visualization results from bidirectional
|
| 153 |
+
processing to create a complete visualization sequence.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
imgs_rgb: Original RGB frames for shape reference
|
| 157 |
+
imgs_forward: Forward propagation visualization results
|
| 158 |
+
imgs_reverse: Reverse propagation visualization results
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
Combined visualization array
|
| 162 |
+
"""
|
| 163 |
+
result = np.zeros_like(imgs_rgb)
|
| 164 |
+
# Fill in forward propagation results
|
| 165 |
+
for idx in imgs_forward:
|
| 166 |
+
result[idx] = imgs_forward[idx]
|
| 167 |
+
# Fill in reverse propagation results (may overwrite forward results)
|
| 168 |
+
for idx in imgs_reverse:
|
| 169 |
+
result[idx] = imgs_reverse[idx]
|
| 170 |
+
return result
|
| 171 |
+
|
| 172 |
+
@staticmethod
|
| 173 |
+
def _combine_masks(
|
| 174 |
+
imgs_rgb: np.ndarray,
|
| 175 |
+
masks_forward: Dict[int, np.ndarray],
|
| 176 |
+
masks_reverse: Dict[int, np.ndarray]
|
| 177 |
+
) -> np.ndarray:
|
| 178 |
+
"""
|
| 179 |
+
Combine forward and reverse segmentation masks.
|
| 180 |
+
|
| 181 |
+
This method merges segmentation masks from bidirectional processing
|
| 182 |
+
to ensure complete temporal coverage of the video sequence.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
imgs_rgb: Original RGB frames for shape reference
|
| 186 |
+
masks_forward: Forward propagation mask results
|
| 187 |
+
masks_reverse: Reverse propagation mask results
|
| 188 |
+
|
| 189 |
+
Returns:
|
| 190 |
+
Combined mask array with shape (num_frames, height, width)
|
| 191 |
+
"""
|
| 192 |
+
result = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))
|
| 193 |
+
for idx in masks_forward:
|
| 194 |
+
result[idx] = masks_forward[idx][0]
|
| 195 |
+
for idx in masks_reverse:
|
| 196 |
+
result[idx] = masks_reverse[idx][0]
|
| 197 |
+
return result
|
| 198 |
+
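A toy illustration of this merge (shapes reduced to 1x1 masks, values made up): forward results are written first, and the reverse pass overwrites any overlapping frames.

import numpy as np

frames = np.zeros((4, 1, 1, 3), dtype=np.uint8)
masks_forward = {2: np.ones((1, 1, 1)), 3: np.ones((1, 1, 1))}
masks_reverse = {0: np.ones((1, 1, 1)), 1: np.ones((1, 1, 1)), 2: np.zeros((1, 1, 1))}

combined = np.zeros((len(frames), frames[0].shape[0], frames[0].shape[1]))
for idx in masks_forward:
    combined[idx] = masks_forward[idx][0]
for idx in masks_reverse:        # reverse pass wins on overlapping frames
    combined[idx] = masks_reverse[idx][0]
# combined[:, 0, 0] -> [1., 1., 0., 1.]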
|
| 199 |
+
class ArmSegmentationProcessor(BaseSegmentationProcessor):
|
| 200 |
+
"""
|
| 201 |
+
Processor for segmenting combined hand and arm regions in video sequences.
|
| 202 |
+
|
| 203 |
+
Attributes:
|
| 204 |
+
detectron_detector (DetectorDetectron2): Detectron2 model for initial detection
|
| 205 |
+
"""
|
| 206 |
+
def __init__(self, args: argparse.Namespace) -> None:
|
| 207 |
+
"""
|
| 208 |
+
Initialize the arm segmentation processor with detection models.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
args: Command line arguments containing model configuration
|
| 212 |
+
"""
|
| 213 |
+
super().__init__(args)
|
| 214 |
+
|
| 215 |
+
# Initialize Detectron2 for initial hand/arm detection
|
| 216 |
+
root_dir = "../submodules/phantom-hamer/"
|
| 217 |
+
self.detectron_detector = DetectorDetectron2(root_dir)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def process_one_demo(self, data_sub_folder: str, hamer_data: Optional[Dict[str, HandSequence]] = None) -> None:
|
| 221 |
+
"""
|
| 222 |
+
Process a single video demonstration to generate combined hand + arm segmentation masks.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
data_sub_folder: Path to the subfolder containing the demo data
|
| 226 |
+
hamer_data: Optional pre-loaded hand pose data for segmentation guidance
|
| 227 |
+
|
| 228 |
+
Raises:
|
| 229 |
+
FileNotFoundError: If required input files are not found
|
| 230 |
+
ValueError: If video frames or bounding boxes are invalid
|
| 231 |
+
"""
|
| 232 |
+
# Setup and load all required data
|
| 233 |
+
save_folder, paths, imgs_rgb, bbox_data, det_bbox_data, hamer_data = self._setup_processing(
|
| 234 |
+
data_sub_folder, hamer_data
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# Process based on setup type
|
| 238 |
+
if self.bimanual_setup == "single_arm":
|
| 239 |
+
masks = self._process_single_arm(imgs_rgb, bbox_data, det_bbox_data, hamer_data, paths)
|
| 240 |
+
elif self.bimanual_setup == "shoulders":
|
| 241 |
+
masks = self._process_bimanual(imgs_rgb, bbox_data, det_bbox_data, hamer_data, paths)
|
| 242 |
+
else:
|
| 243 |
+
raise ValueError(f"Invalid bimanual setup: {self.bimanual_setup}")
|
| 244 |
+
|
| 245 |
+
# Create visualization and save results
|
| 246 |
+
sam_imgs = self._create_visualization(imgs_rgb, masks)
|
| 247 |
+
self._validate_output_consistency(imgs_rgb, masks, sam_imgs)
|
| 248 |
+
self._save_results(paths, masks, sam_imgs)
|
| 249 |
+
|
| 250 |
+
def _setup_processing(
|
| 251 |
+
self,
|
| 252 |
+
data_sub_folder: str,
|
| 253 |
+
hamer_data: Optional[Dict[str, HandSequence]]
|
| 254 |
+
) -> Tuple[str, Paths, np.ndarray, Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, HandSequence]]:
|
| 255 |
+
"""
|
| 256 |
+
Setup processing environment and load all required data.
|
| 257 |
+
|
| 258 |
+
Args:
|
| 259 |
+
data_sub_folder: Path to the subfolder containing the demo data
|
| 260 |
+
hamer_data: Optional pre-loaded hand pose data
|
| 261 |
+
|
| 262 |
+
Returns:
|
| 263 |
+
Tuple containing: (save_folder, paths, imgs_rgb, bbox_data, det_bbox_data, hamer_data)
|
| 264 |
+
"""
|
| 265 |
+
save_folder = self.get_save_folder(data_sub_folder)
|
| 266 |
+
paths = self.get_paths(save_folder)
|
| 267 |
+
paths._setup_original_images()
|
| 268 |
+
paths._setup_original_images_reverse()
|
| 269 |
+
|
| 270 |
+
# Load and validate all input data
|
| 271 |
+
imgs_rgb = self._load_video(paths.video_left)
|
| 272 |
+
bbox_data = self._load_bbox_data(paths.bbox_data)
|
| 273 |
+
det_bbox_data = self.get_detectron_bboxes(imgs_rgb, bbox_data)
|
| 274 |
+
if hamer_data is None:
|
| 275 |
+
hamer_data = self._load_hamer_data(paths)
|
| 276 |
+
|
| 277 |
+
return save_folder, paths, imgs_rgb, bbox_data, det_bbox_data, hamer_data
|
| 278 |
+
|
| 279 |
+
def _process_single_arm(
|
| 280 |
+
self,
|
| 281 |
+
imgs_rgb: np.ndarray,
|
| 282 |
+
bbox_data: Dict[str, np.ndarray],
|
| 283 |
+
det_bbox_data: Dict[str, np.ndarray],
|
| 284 |
+
hamer_data: Dict[str, HandSequence],
|
| 285 |
+
paths: Paths
|
| 286 |
+
) -> np.ndarray:
|
| 287 |
+
"""
|
| 288 |
+
Process single arm setup (left or right hand only).
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
imgs_rgb: RGB video frames
|
| 292 |
+
bbox_data: Bounding box detection data
|
| 293 |
+
det_bbox_data: Detectron2 refined bounding boxes
|
| 294 |
+
hamer_data: Hand pose estimation data
|
| 295 |
+
paths: Paths object for file management
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
Boolean segmentation masks
|
| 299 |
+
"""
|
| 300 |
+
if self.target_hand == "left":
|
| 301 |
+
hand_data = self._process_hand_data(
|
| 302 |
+
imgs_rgb,
|
| 303 |
+
bbox_data["left_bboxes"],
|
| 304 |
+
bbox_data["left_bbox_min_dist_to_edge"],
|
| 305 |
+
bbox_data["left_hand_detected"],
|
| 306 |
+
det_bbox_data["left_det_bboxes"],
|
| 307 |
+
hamer_data["left"],
|
| 308 |
+
paths,
|
| 309 |
+
"left"
|
| 310 |
+
)
|
| 311 |
+
masks = hand_data["left_masks"].astype(np.bool_)
|
| 312 |
+
elif self.target_hand == "right":
|
| 313 |
+
hand_data = self._process_hand_data(
|
| 314 |
+
imgs_rgb,
|
| 315 |
+
bbox_data["right_bboxes"],
|
| 316 |
+
bbox_data["right_bbox_min_dist_to_edge"],
|
| 317 |
+
bbox_data["right_hand_detected"],
|
| 318 |
+
det_bbox_data["right_det_bboxes"],
|
| 319 |
+
hamer_data["right"],
|
| 320 |
+
paths,
|
| 321 |
+
"right"
|
| 322 |
+
)
|
| 323 |
+
masks = hand_data["right_masks"].astype(np.bool_)
|
| 324 |
+
else:
|
| 325 |
+
raise ValueError(f"Invalid target hand: {self.target_hand}")
|
| 326 |
+
|
| 327 |
+
return masks.astype(np.bool_)
|
| 328 |
+
|
| 329 |
+
def _process_bimanual(
|
| 330 |
+
self,
|
| 331 |
+
imgs_rgb: np.ndarray,
|
| 332 |
+
bbox_data: Dict[str, np.ndarray],
|
| 333 |
+
det_bbox_data: Dict[str, np.ndarray],
|
| 334 |
+
hamer_data: Dict[str, HandSequence],
|
| 335 |
+
paths: Paths
|
| 336 |
+
) -> np.ndarray:
|
| 337 |
+
"""
|
| 338 |
+
Process bimanual setup (both hands combined).
|
| 339 |
+
|
| 340 |
+
Args:
|
| 341 |
+
imgs_rgb: RGB video frames
|
| 342 |
+
bbox_data: Bounding box detection data
|
| 343 |
+
det_bbox_data: Detectron2 refined bounding boxes
|
| 344 |
+
hamer_data: Hand pose estimation data
|
| 345 |
+
paths: Paths object for file management
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
Combined boolean segmentation masks
|
| 349 |
+
"""
|
| 350 |
+
# Process left hand with arm segmentation
|
| 351 |
+
left_data = self._process_hand_data(
|
| 352 |
+
imgs_rgb,
|
| 353 |
+
bbox_data["left_bboxes"],
|
| 354 |
+
bbox_data["left_bbox_min_dist_to_edge"],
|
| 355 |
+
bbox_data["left_hand_detected"],
|
| 356 |
+
det_bbox_data["left_det_bboxes"],
|
| 357 |
+
hamer_data["left"],
|
| 358 |
+
paths,
|
| 359 |
+
"left"
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# Process right hand with arm segmentation
|
| 363 |
+
right_data = self._process_hand_data(
|
| 364 |
+
imgs_rgb,
|
| 365 |
+
bbox_data["right_bboxes"],
|
| 366 |
+
bbox_data["right_bbox_min_dist_to_edge"],
|
| 367 |
+
bbox_data["right_hand_detected"],
|
| 368 |
+
det_bbox_data["right_det_bboxes"],
|
| 369 |
+
hamer_data["right"],
|
| 370 |
+
paths,
|
| 371 |
+
"right"
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
# Convert to boolean masks and combine
|
| 375 |
+
left_masks = left_data["left_masks"].astype(np.bool_)
|
| 376 |
+
right_masks = right_data["right_masks"].astype(np.bool_)
|
| 377 |
+
|
| 378 |
+
# Generate combined video masks by taking the union of left and right masks
|
| 379 |
+
masks = np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1]))
|
| 380 |
+
for idx in range(len(imgs_rgb)):
|
| 381 |
+
masks[idx] = left_masks[idx] | right_masks[idx]
|
| 382 |
+
|
| 383 |
+
return masks.astype(np.bool_)
|
| 384 |
+
|
| 385 |
+
def _create_visualization(self, imgs_rgb: np.ndarray, masks: np.ndarray) -> np.ndarray:
|
| 386 |
+
"""
|
| 387 |
+
Create visualization by masking out segmented regions.
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
imgs_rgb: Original RGB video frames
|
| 391 |
+
masks: Boolean segmentation masks
|
| 392 |
+
|
| 393 |
+
Returns:
|
| 394 |
+
Visualization images with masked regions set to black
|
| 395 |
+
"""
|
| 396 |
+
sam_imgs = []
|
| 397 |
+
for idx in range(len(imgs_rgb)):
|
| 398 |
+
img = imgs_rgb[idx].copy() # Create copy to avoid modifying original
|
| 399 |
+
mask = masks[idx]
|
| 400 |
+
img[mask] = 0 # Set masked regions to black
|
| 401 |
+
sam_imgs.append(img)
|
| 402 |
+
return np.array(sam_imgs)
|
| 403 |
+
|
| 404 |
+
def _validate_output_consistency(
|
| 405 |
+
self,
|
| 406 |
+
imgs_rgb: np.ndarray,
|
| 407 |
+
masks: np.ndarray,
|
| 408 |
+
sam_imgs: np.ndarray
|
| 409 |
+
) -> None:
|
| 410 |
+
"""
|
| 411 |
+
Validate that output arrays have consistent dimensions.
|
| 412 |
+
|
| 413 |
+
Args:
|
| 414 |
+
imgs_rgb: Original RGB video frames
|
| 415 |
+
masks: Segmentation masks
|
| 416 |
+
sam_imgs: Visualization images
|
| 417 |
+
|
| 418 |
+
Raises:
|
| 419 |
+
AssertionError: If dimensions don't match
|
| 420 |
+
"""
|
| 421 |
+
assert len(sam_imgs) == len(imgs_rgb), "Visualization length doesn't match input"
|
| 422 |
+
assert len(masks) == len(imgs_rgb), "Masks length doesn't match input"
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def _process_hand_data(
|
| 426 |
+
self,
|
| 427 |
+
imgs_rgb: np.ndarray,
|
| 428 |
+
bboxes: np.ndarray,
|
| 429 |
+
bbox_min_dist: np.ndarray,
|
| 430 |
+
hand_detected: np.ndarray,
|
| 431 |
+
det_bboxes: np.ndarray,
|
| 432 |
+
hamer_data: HandSequence,
|
| 433 |
+
paths: Paths,
|
| 434 |
+
hand_side: str
|
| 435 |
+
) -> Dict[str, np.ndarray]:
|
| 436 |
+
"""
|
| 437 |
+
Process segmentation data for a single hand (left or right) with arm inclusion.
|
| 438 |
+
|
| 439 |
+
Args:
|
| 440 |
+
imgs_rgb: RGB video frames
|
| 441 |
+
bboxes: Hand bounding boxes from detection stage
|
| 442 |
+
bbox_min_dist: Minimum distances to image edges (quality metric)
|
| 443 |
+
hand_detected: Boolean flags indicating valid hand detections
|
| 444 |
+
det_bboxes: Refined bounding boxes from Detectron2
|
| 445 |
+
hamer_data: Hand pose data for segmentation guidance
|
| 446 |
+
paths: Paths object for file management
|
| 447 |
+
hand_side: "left" or "right" specifying which hand to process
|
| 448 |
+
|
| 449 |
+
Returns:
|
| 450 |
+
Dictionary containing segmentation masks and visualization images
|
| 451 |
+
"""
|
| 452 |
+
# Handle cases with no valid detections
|
| 453 |
+
if not hand_detected.any() or max(bbox_min_dist) == 0:
|
| 454 |
+
return {
|
| 455 |
+
f"{hand_side}_masks": np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1])),
|
| 456 |
+
f"{hand_side}_sam_imgs": np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1], 3))
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
# Extract hand pose keypoints for segmentation guidance
|
| 460 |
+
kpts_2d = hamer_data.kpts_2d
|
| 461 |
+
|
| 462 |
+
# Find the frame with highest quality (furthest from edges)
|
| 463 |
+
max_dist_idx = np.argmax(bbox_min_dist)
|
| 464 |
+
points = np.expand_dims(kpts_2d[max_dist_idx], axis=1)
|
| 465 |
+
bbox_dets = det_bboxes[max_dist_idx]
|
| 466 |
+
|
| 467 |
+
# Use original bounding box if Detectron2 detection failed
|
| 468 |
+
if bbox_dets.sum() == 0:
|
| 469 |
+
bbox_dets = bboxes[max_dist_idx]
|
| 470 |
+
|
| 471 |
+
# Process segmentation in both temporal directions
|
| 472 |
+
masks_forward, sam_imgs_forward = self._run_sam_segmentation(
|
| 473 |
+
paths, bbox_dets, points, max_dist_idx, reverse=False
|
| 474 |
+
)
|
| 475 |
+
masks_reverse, sam_imgs_reverse = self._run_sam_segmentation(
|
| 476 |
+
paths, bbox_dets, points, max_dist_idx, reverse=True
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
# Combine bidirectional results
|
| 480 |
+
sam_imgs = self._combine_sam_images(imgs_rgb, sam_imgs_forward, sam_imgs_reverse)
|
| 481 |
+
masks = self._combine_masks(imgs_rgb, masks_forward, masks_reverse)
|
| 482 |
+
|
| 483 |
+
return {
|
| 484 |
+
f"{hand_side}_masks": masks,
|
| 485 |
+
f"{hand_side}_sam_imgs": sam_imgs
|
| 486 |
+
}
|
| 487 |
+
|
| 488 |
+
def _run_sam_segmentation(
|
| 489 |
+
self,
|
| 490 |
+
paths: Paths,
|
| 491 |
+
bbox_dets: np.ndarray,
|
| 492 |
+
points: np.ndarray,
|
| 493 |
+
max_dist_idx: int,
|
| 494 |
+
reverse: bool
|
| 495 |
+
) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
|
| 496 |
+
"""
|
| 497 |
+
Process video segmentation in either forward or reverse temporal direction.
|
| 498 |
+
|
| 499 |
+
Args:
|
| 500 |
+
paths: Paths object for file management
|
| 501 |
+
bbox_dets: Detectron2 bounding box for initialization
|
| 502 |
+
points: Hand keypoints for segmentation guidance
|
| 503 |
+
max_dist_idx: Index of highest-quality frame for initialization
|
| 504 |
+
reverse: Whether to process in reverse temporal order
|
| 505 |
+
|
| 506 |
+
Returns:
|
| 507 |
+
Tuple of (segmentation_masks, visualization_images)
|
| 508 |
+
"""
|
| 509 |
+
return self.detector_sam.segment_video(
|
| 510 |
+
paths.original_images_folder,
|
| 511 |
+
bbox_dets,
|
| 512 |
+
points,
|
| 513 |
+
[max_dist_idx],
|
| 514 |
+
reverse=reverse
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
def get_detectron_bboxes(self, imgs_rgb: np.ndarray, bbox_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
| 518 |
+
"""
|
| 519 |
+
Generate enhanced bounding boxes using Detectron2 for improved segmentation.
|
| 520 |
+
|
| 521 |
+
Args:
|
| 522 |
+
imgs_rgb: Array of RGB frames with shape (N, H, W, 3)
|
| 523 |
+
bbox_data: Initial bounding box data from hand detection stage containing:
|
| 524 |
+
- left_bboxes: Left hand bounding boxes
|
| 525 |
+
- right_bboxes: Right hand bounding boxes
|
| 526 |
+
- left_hand_detected: Boolean flags for left hand detection
|
| 527 |
+
- right_hand_detected: Boolean flags for right hand detection
|
| 528 |
+
- left_bbox_min_dist_to_edge: Quality metrics for left hand
|
| 529 |
+
- right_bbox_min_dist_to_edge: Quality metrics for right hand
|
| 530 |
+
|
| 531 |
+
Returns:
|
| 532 |
+
Dictionary containing refined bounding boxes:
|
| 533 |
+
- left_det_bboxes: Enhanced left hand bounding boxes
|
| 534 |
+
- right_det_bboxes: Enhanced right hand bounding boxes
|
| 535 |
+
|
| 536 |
+
Raises:
|
| 537 |
+
ValueError: If input array is empty or has incorrect shape
|
| 538 |
+
"""
|
| 539 |
+
self._validate_detectron_input(imgs_rgb)
|
| 540 |
+
|
| 541 |
+
# Extract detection data and initialize output arrays
|
| 542 |
+
detection_data = self._extract_detection_data(bbox_data)
|
| 543 |
+
left_det_bboxes, right_det_bboxes = self._initialize_bbox_arrays(imgs_rgb)
|
| 544 |
+
|
| 545 |
+
# Process only highest-quality frames for efficiency
|
| 546 |
+
idx_list = self._get_quality_frame_indices(bbox_data)
|
| 547 |
+
|
| 548 |
+
for idx in tqdm(idx_list, desc="Processing frames"):
|
| 549 |
+
try:
|
| 550 |
+
self._process_detectron_frame(
|
| 551 |
+
idx, imgs_rgb, detection_data, left_det_bboxes, right_det_bboxes
|
| 552 |
+
)
|
| 553 |
+
except Exception as e:
|
| 554 |
+
logging.error(f"Error processing frame {idx}: {str(e)}")
|
| 555 |
+
|
| 556 |
+
return {"left_det_bboxes": left_det_bboxes, "right_det_bboxes": right_det_bboxes}
|
| 557 |
+
|
| 558 |
+
def _validate_detectron_input(self, imgs_rgb: np.ndarray) -> None:
|
| 559 |
+
"""
|
| 560 |
+
Validate input array for Detectron2 processing.
|
| 561 |
+
|
| 562 |
+
Args:
|
| 563 |
+
imgs_rgb: Array of RGB frames
|
| 564 |
+
|
| 565 |
+
Raises:
|
| 566 |
+
ValueError: If input array is empty or has incorrect shape
|
| 567 |
+
"""
|
| 568 |
+
if len(imgs_rgb) == 0:
|
| 569 |
+
raise ValueError("Empty input array - no video frames provided")
|
| 570 |
+
|
| 571 |
+
if len(imgs_rgb.shape) != 4 or imgs_rgb.shape[-1] != 3:
|
| 572 |
+
raise ValueError(f"Expected input shape (N, H, W, 3), got {imgs_rgb.shape}. "
|
| 573 |
+
f"Input should be RGB video frames.")
|
| 574 |
+
|
| 575 |
+
def _extract_detection_data(self, bbox_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
|
| 576 |
+
"""
|
| 577 |
+
Extract detection data from bounding box data.
|
| 578 |
+
|
| 579 |
+
Args:
|
| 580 |
+
bbox_data: Bounding box detection data
|
| 581 |
+
|
| 582 |
+
Returns:
|
| 583 |
+
Dictionary containing extracted detection data
|
| 584 |
+
"""
|
| 585 |
+
return {
|
| 586 |
+
"left_bboxes": bbox_data["left_bboxes"],
|
| 587 |
+
"right_bboxes": bbox_data["right_bboxes"],
|
| 588 |
+
"left_hand_detected": bbox_data["left_hand_detected"],
|
| 589 |
+
"right_hand_detected": bbox_data["right_hand_detected"]
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
def _initialize_bbox_arrays(self, imgs_rgb: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 593 |
+
"""
|
| 594 |
+
Initialize output bounding box arrays.
|
| 595 |
+
|
| 596 |
+
Args:
|
| 597 |
+
imgs_rgb: RGB video frames for shape reference
|
| 598 |
+
|
| 599 |
+
Returns:
|
| 600 |
+
Tuple of (left_det_bboxes, right_det_bboxes) initialized arrays
|
| 601 |
+
"""
|
| 602 |
+
left_det_bboxes = np.zeros((len(imgs_rgb), 4))
|
| 603 |
+
right_det_bboxes = np.zeros((len(imgs_rgb), 4))
|
| 604 |
+
return left_det_bboxes, right_det_bboxes
|
| 605 |
+
|
| 606 |
+
def _get_quality_frame_indices(self, bbox_data: Dict[str, np.ndarray]) -> List[int]:
|
| 607 |
+
"""
|
| 608 |
+
Get indices of highest-quality frames for processing.
|
| 609 |
+
|
| 610 |
+
Args:
|
| 611 |
+
bbox_data: Bounding box detection data
|
| 612 |
+
|
| 613 |
+
Returns:
|
| 614 |
+
List of frame indices to process
|
| 615 |
+
"""
|
| 616 |
+
idx_left = np.argmax(bbox_data["left_bbox_min_dist_to_edge"])
|
| 617 |
+
idx_right = np.argmax(bbox_data["right_bbox_min_dist_to_edge"])
|
| 618 |
+
return [idx_left, idx_right]
|
| 619 |
+
|
| 620 |
+
def _process_detectron_frame(
|
| 621 |
+
self,
|
| 622 |
+
idx: int,
|
| 623 |
+
imgs_rgb: np.ndarray,
|
| 624 |
+
detection_data: Dict[str, np.ndarray],
|
| 625 |
+
left_det_bboxes: np.ndarray,
|
| 626 |
+
right_det_bboxes: np.ndarray
|
| 627 |
+
) -> None:
|
| 628 |
+
"""
|
| 629 |
+
Process a single frame with Detectron2 detection.
|
| 630 |
+
|
| 631 |
+
Args:
|
| 632 |
+
idx: Frame index to process
|
| 633 |
+
imgs_rgb: RGB video frames
|
| 634 |
+
detection_data: Extracted detection data
|
| 635 |
+
left_det_bboxes: Left hand bounding box output array
|
| 636 |
+
right_det_bboxes: Right hand bounding box output array
|
| 637 |
+
"""
|
| 638 |
+
left_hand_detected = detection_data["left_hand_detected"]
|
| 639 |
+
right_hand_detected = detection_data["right_hand_detected"]
|
| 640 |
+
|
| 641 |
+
# Skip frames without any hand detections
|
| 642 |
+
if not left_hand_detected[idx] and not right_hand_detected[idx]:
|
| 643 |
+
left_det_bboxes[idx] = np.array([0, 0, 0, 0])
|
| 644 |
+
right_det_bboxes[idx] = np.array([0, 0, 0, 0])
|
| 645 |
+
return
|
| 646 |
+
|
| 647 |
+
# Apply Detectron2 detection
|
| 648 |
+
img = imgs_rgb[idx]
|
| 649 |
+
det_bboxes, det_scores = self.detectron_detector.get_bboxes(img, visualize=False)
|
| 650 |
+
|
| 651 |
+
if len(det_bboxes) == 0:
|
| 652 |
+
return
|
| 653 |
+
|
| 654 |
+
# Match left hand detection with Detectron2 results
|
| 655 |
+
if left_hand_detected[idx]:
|
| 656 |
+
self._match_hand_detection(
|
| 657 |
+
idx, "left", detection_data, det_bboxes, left_det_bboxes
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
# Match right hand detection with Detectron2 results
|
| 661 |
+
if right_hand_detected[idx]:
|
| 662 |
+
self._match_hand_detection(
|
| 663 |
+
idx, "right", detection_data, det_bboxes, right_det_bboxes
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
def _match_hand_detection(
|
| 667 |
+
self,
|
| 668 |
+
idx: int,
|
| 669 |
+
hand_side: str,
|
| 670 |
+
detection_data: Dict[str, np.ndarray],
|
| 671 |
+
det_bboxes: np.ndarray,
|
| 672 |
+
output_bboxes: np.ndarray
|
| 673 |
+
) -> None:
|
| 674 |
+
"""
|
| 675 |
+
Match hand detection with Detectron2 results using overlap scores.
|
| 676 |
+
|
| 677 |
+
Args:
|
| 678 |
+
idx: Frame index
|
| 679 |
+
hand_side: "left" or "right" hand
|
| 680 |
+
detection_data: Extracted detection data
|
| 681 |
+
det_bboxes: Detectron2 detection results
|
| 682 |
+
output_bboxes: Output bounding box array to update
|
| 683 |
+
"""
|
| 684 |
+
bbox = detection_data[f"{hand_side}_bboxes"][idx]
|
| 685 |
+
overlap_scores = []
|
| 686 |
+
|
| 687 |
+
for det_bbox in det_bboxes:
|
| 688 |
+
overlap_score = get_overlap_score(bbox, det_bbox)
|
| 689 |
+
overlap_scores.append(overlap_score)
|
| 690 |
+
|
| 691 |
+
if np.max(overlap_scores) > DEFAULT_OVERLAP_THRESHOLD:
|
| 692 |
+
best_idx = np.argmax(overlap_scores)
|
| 693 |
+
output_bboxes[idx] = det_bboxes[best_idx].astype(np.int32)
|
| 694 |
+
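# Illustrative sketch (not part of the module): how the overlap-threshold matching above
# behaves with a simple IoU-style score. `iou` is a stand-in assumption for the project's
# get_overlap_score, and 0.5 stands in for DEFAULT_OVERLAP_THRESHOLD.
import numpy as np

def iou(box_a, box_b):
    # boxes are [x1, y1, x2, y2]
    x1, y1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x2, y2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / float(area_a + area_b - inter)

hand_bbox = np.array([100, 100, 200, 200])                      # detection from the hand stage
det_bboxes = np.array([[90, 95, 210, 205], [400, 400, 450, 450]])  # Detectron2-style candidates
scores = [iou(hand_bbox, b) for b in det_bboxes]
if np.max(scores) > 0.5:                                        # threshold stand-in
    best = det_bboxes[int(np.argmax(scores))].astype(np.int32)  # matched box is kept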
|
| 695 |
+
@staticmethod
|
| 696 |
+
def _save_results(
|
| 697 |
+
paths: Paths,
|
| 698 |
+
masks: np.ndarray,
|
| 699 |
+
sam_imgs: np.ndarray,
|
| 700 |
+
fps: int = DEFAULT_FPS
|
| 701 |
+
) -> None:
|
| 702 |
+
"""
|
| 703 |
+
Save arm segmentation results to disk.
|
| 704 |
+
|
| 705 |
+
Args:
|
| 706 |
+
paths: Paths object containing output file locations
|
| 707 |
+
masks: Combined arm segmentation masks
|
| 708 |
+
sam_imgs: SAM visualization images
|
| 709 |
+
fps: Frames per second for output videos (default: 10)
|
| 710 |
+
"""
|
| 711 |
+
ArmSegmentationProcessor._create_output_directory(paths)
|
| 712 |
+
|
| 713 |
+
try:
|
| 714 |
+
ArmSegmentationProcessor._save_mask_data(paths, masks)
|
| 715 |
+
ArmSegmentationProcessor._create_videos(paths, masks, sam_imgs, fps)
|
| 716 |
+
except Exception as e:
|
| 717 |
+
logging.error(f"Error saving results: {str(e)}")
|
| 718 |
+
raise
|
| 719 |
+
|
| 720 |
+
ArmSegmentationProcessor._cleanup_temp_files(paths)
|
| 721 |
+
ArmSegmentationProcessor._update_annotation_video(paths, masks, sam_imgs, fps)
|
| 722 |
+
|
| 723 |
+
@staticmethod
|
| 724 |
+
def _create_output_directory(paths: Paths) -> None:
|
| 725 |
+
"""
|
| 726 |
+
Create output directory for segmentation results.
|
| 727 |
+
|
| 728 |
+
Args:
|
| 729 |
+
paths: Paths object containing output directory location
|
| 730 |
+
"""
|
| 731 |
+
if not os.path.exists(paths.segmentation_processor):
|
| 732 |
+
os.makedirs(paths.segmentation_processor)
|
| 733 |
+
|
| 734 |
+
@staticmethod
|
| 735 |
+
def _save_mask_data(paths: Paths, masks: np.ndarray) -> None:
|
| 736 |
+
"""
|
| 737 |
+
Save mask data to disk.
|
| 738 |
+
|
| 739 |
+
Args:
|
| 740 |
+
paths: Paths object containing output file locations
|
| 741 |
+
masks: Segmentation masks to save
|
| 742 |
+
"""
|
| 743 |
+
np.save(paths.masks_arm, masks)
|
| 744 |
+
|
| 745 |
+
@staticmethod
|
| 746 |
+
def _create_videos(paths: Paths, masks: np.ndarray, sam_imgs: np.ndarray, fps: int) -> None:
|
| 747 |
+
"""
|
| 748 |
+
Create visualization videos from masks and SAM images.
|
| 749 |
+
|
| 750 |
+
Args:
|
| 751 |
+
paths: Paths object containing output file locations
|
| 752 |
+
masks: Segmentation masks
|
| 753 |
+
sam_imgs: SAM visualization images
|
| 754 |
+
fps: Frames per second for output videos
|
| 755 |
+
"""
|
| 756 |
+
for name, data in [
|
| 757 |
+
("video_masks_arm", masks),
|
| 758 |
+
("video_sam_arm", sam_imgs),
|
| 759 |
+
]:
|
| 760 |
+
output_path = getattr(paths, name)
|
| 761 |
+
media.write_video(output_path, data, fps=fps, codec=DEFAULT_CODEC)
|
| 762 |
+
|
| 763 |
+
@staticmethod
|
| 764 |
+
def _cleanup_temp_files(paths: Paths) -> None:
|
| 765 |
+
"""
|
| 766 |
+
Clean up temporary directories created during processing.
|
| 767 |
+
|
| 768 |
+
Args:
|
| 769 |
+
paths: Paths object containing temporary directory locations
|
| 770 |
+
"""
|
| 771 |
+
if os.path.exists(paths.original_images_folder):
|
| 772 |
+
shutil.rmtree(paths.original_images_folder)
|
| 773 |
+
if os.path.exists(paths.original_images_folder_reverse):
|
| 774 |
+
shutil.rmtree(paths.original_images_folder_reverse)
|
| 775 |
+
|
| 776 |
+
@staticmethod
|
| 777 |
+
def _update_annotation_video(paths: Paths, masks: np.ndarray, sam_imgs: np.ndarray, fps: int) -> None:
|
| 778 |
+
"""
|
| 779 |
+
Update existing annotation video with segmentation results.
|
| 780 |
+
|
| 781 |
+
Args:
|
| 782 |
+
paths: Paths object containing annotation video location
|
| 783 |
+
masks: Segmentation masks
|
| 784 |
+
sam_imgs: SAM visualization images
|
| 785 |
+
fps: Frames per second for output video
|
| 786 |
+
"""
|
| 787 |
+
if os.path.exists(paths.video_annot):
|
| 788 |
+
annot_imgs = media.read_video(paths.video_annot)
|
| 789 |
+
for idx in range(len(annot_imgs)):
|
| 790 |
+
annot_img = annot_imgs[idx]
|
| 791 |
+
h = masks[idx].shape[0]
|
| 792 |
+
w = masks[idx].shape[1]
|
| 793 |
+
# Insert segmentation visualization in the top-right quadrant
|
| 794 |
+
annot_img[:h, w:, :] = sam_imgs[idx]
|
| 795 |
+
media.write_video(paths.video_annot, annot_imgs, fps=fps, codec=ANNOTATION_CODEC)
|
| 796 |
+
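# Illustrative sketch (not part of the module): the quadrant insertion above appears to assume
# the annotation frames are a grid of (h, w) panels, so annot_img[:h, w:, :] addresses the
# top-right panel. A minimal, self-contained version of that slicing with placeholder sizes:
import numpy as np

h, w = 120, 160
annot_img = np.zeros((2 * h, 2 * w, 3), dtype=np.uint8)   # hypothetical 2x2 annotation canvas
sam_panel = np.full((h, w, 3), 255, dtype=np.uint8)       # stand-in SAM visualization frame
annot_img[:h, w:, :] = sam_panel                          # overwrite the top-right quadrant
assert annot_img[:h, w:, :].shape == sam_panel.shape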
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
class HandSegmentationProcessor(BaseSegmentationProcessor):
|
| 800 |
+
"""
|
| 801 |
+
Processor for precise hand-only segmentation in video sequences.
|
| 802 |
+
|
| 803 |
+
Attributes:
|
| 804 |
+
Inherits detector_sam from BaseSegmentationProcessor
|
| 805 |
+
"""
|
| 806 |
+
def __init__(self, args: argparse.Namespace) -> None:
|
| 807 |
+
"""
|
| 808 |
+
Initialize the hand segmentation processor.
|
| 809 |
+
|
| 810 |
+
Args:
|
| 811 |
+
args: Command line arguments containing segmentation configuration
|
| 812 |
+
"""
|
| 813 |
+
super().__init__(args)
|
| 814 |
+
|
| 815 |
+
def process_one_demo(self, data_sub_folder: str, hamer_data: Optional[Dict[str, HandSequence]] = None) -> None:
|
| 816 |
+
"""
|
| 817 |
+
Process a single video demonstration to generate precise hand segmentation masks.
|
| 818 |
+
|
| 819 |
+
Args:
|
| 820 |
+
data_sub_folder: Path to the subfolder containing the demo data
|
| 821 |
+
hamer_data: Optional pre-loaded hand pose data for segmentation guidance
|
| 822 |
+
|
| 823 |
+
Raises:
|
| 824 |
+
FileNotFoundError: If required input files are not found
|
| 825 |
+
ValueError: If video frames or bounding boxes are invalid
|
| 826 |
+
"""
|
| 827 |
+
save_folder = self.get_save_folder(data_sub_folder)
|
| 828 |
+
|
| 829 |
+
paths = self.get_paths(save_folder)
|
| 830 |
+
paths._setup_original_images()
|
| 831 |
+
paths._setup_original_images_reverse()
|
| 832 |
+
|
| 833 |
+
# Load and validate input data
|
| 834 |
+
imgs_rgb = self._load_video(paths.video_left)
|
| 835 |
+
bbox_data = self._load_bbox_data(paths.bbox_data)
|
| 836 |
+
if hamer_data is None:
|
| 837 |
+
hamer_data = self._load_hamer_data(paths)
|
| 838 |
+
|
| 839 |
+
# Process left and right hands separately for precise segmentation
|
| 840 |
+
left_data = self._process_hand_data(
|
| 841 |
+
imgs_rgb,
|
| 842 |
+
bbox_data["left_bboxes"],
|
| 843 |
+
bbox_data["left_bbox_min_dist_to_edge"],
|
| 844 |
+
bbox_data["left_hand_detected"],
|
| 845 |
+
hamer_data["left"],
|
| 846 |
+
paths,
|
| 847 |
+
"left"
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
right_data = self._process_hand_data(
|
| 851 |
+
imgs_rgb,
|
| 852 |
+
bbox_data["right_bboxes"],
|
| 853 |
+
bbox_data["right_bbox_min_dist_to_edge"],
|
| 854 |
+
bbox_data["right_hand_detected"],
|
| 855 |
+
hamer_data["right"],
|
| 856 |
+
paths,
|
| 857 |
+
"right"
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
# Convert to boolean masks
|
| 861 |
+
left_masks = left_data["left_masks"].astype(np.bool_)
|
| 862 |
+
left_sam_imgs = left_data["left_sam_imgs"]
|
| 863 |
+
right_masks = right_data["right_masks"].astype(np.bool_)
|
| 864 |
+
right_sam_imgs = right_data["right_sam_imgs"]
|
| 865 |
+
|
| 866 |
+
# Save results with separate left/right hand data
|
| 867 |
+
self._save_results(paths, left_masks, left_sam_imgs, right_masks, right_sam_imgs)
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def _process_hand_data(
|
| 871 |
+
self,
|
| 872 |
+
imgs_rgb: np.ndarray,
|
| 873 |
+
bboxes: np.ndarray,
|
| 874 |
+
bbox_min_dist: np.ndarray,
|
| 875 |
+
hand_detected: np.ndarray,
|
| 876 |
+
hamer_data: HandSequence,
|
| 877 |
+
paths: Paths,
|
| 878 |
+
hand_side: str
|
| 879 |
+
) -> Dict[str, np.ndarray]:
|
| 880 |
+
"""
|
| 881 |
+
Process hand segmentation data for a single hand (left or right).
|
| 882 |
+
|
| 883 |
+
Args:
|
| 884 |
+
imgs_rgb: RGB video frames
|
| 885 |
+
bboxes: Hand bounding boxes from detection stage
|
| 886 |
+
bbox_min_dist: Minimum distances to image edges (quality metric)
|
| 887 |
+
hand_detected: Boolean flags indicating valid hand detections
|
| 888 |
+
hamer_data: Hand pose data for segmentation guidance
|
| 889 |
+
paths: Paths object for file management
|
| 890 |
+
hand_side: "left" or "right" specifying which hand to process
|
| 891 |
+
|
| 892 |
+
Returns:
|
| 893 |
+
Dictionary containing segmentation masks and visualization images
|
| 894 |
+
"""
|
| 895 |
+
# Handle cases with no valid detections
|
| 896 |
+
if not hand_detected.any() or np.max(bbox_min_dist) == 0:
|
| 897 |
+
return {
|
| 898 |
+
f"{hand_side}_masks": np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1])),
|
| 899 |
+
f"{hand_side}_sam_imgs": np.zeros((len(imgs_rgb), imgs_rgb[0].shape[0], imgs_rgb[0].shape[1], 3))
|
| 900 |
+
}
|
| 901 |
+
|
| 902 |
+
# Extract hand pose keypoints for segmentation guidance
|
| 903 |
+
kpts_2d = hamer_data.kpts_2d
|
| 904 |
+
|
| 905 |
+
# Find the frame with highest quality (furthest from edges)
|
| 906 |
+
max_dist_idx = np.argmax(bbox_min_dist)
|
| 907 |
+
bbox = bboxes[max_dist_idx]
|
| 908 |
+
points = np.expand_dims(kpts_2d[max_dist_idx], axis=1)
|
| 909 |
+
|
| 910 |
+
# Process segmentation in both temporal directions
|
| 911 |
+
masks_forward, sam_imgs_forward = self._run_sam_segmentation(
|
| 912 |
+
paths, bbox, points, max_dist_idx, reverse=False, output_bboxes=bboxes
|
| 913 |
+
)
|
| 914 |
+
masks_reverse, sam_imgs_reverse = self._run_sam_segmentation(
|
| 915 |
+
paths, bbox, points, max_dist_idx, reverse=True, output_bboxes=bboxes
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
# Combine bidirectional results
|
| 919 |
+
sam_imgs = self._combine_sam_images(imgs_rgb, sam_imgs_forward, sam_imgs_reverse)
|
| 920 |
+
masks = self._combine_masks(imgs_rgb, masks_forward, masks_reverse)
|
| 921 |
+
|
| 922 |
+
return {
|
| 923 |
+
f"{hand_side}_masks": masks,
|
| 924 |
+
f"{hand_side}_sam_imgs": sam_imgs
|
| 925 |
+
}
|
| 926 |
+
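# Illustrative sketch (not part of the module): the bidirectional pass above yields two
# per-frame mask dictionaries (forward from the anchor frame, reverse toward frame 0).
# A minimal way to merge them into one (N, H, W) boolean array, assuming integer frame
# keys and OR-ing overlapping frames; the project's _combine_masks may differ.
import numpy as np

def combine_masks(n_frames, height, width, masks_forward, masks_reverse):
    combined = np.zeros((n_frames, height, width), dtype=bool)
    for mask_dict in (masks_forward, masks_reverse):
        for frame_idx, mask in mask_dict.items():
            combined[frame_idx] |= mask.astype(bool)
    return combined

fwd = {2: np.ones((4, 4)), 3: np.ones((4, 4))}   # hypothetical forward-pass masks
rev = {0: np.ones((4, 4)), 1: np.ones((4, 4))}   # hypothetical reverse-pass masks
merged = combine_masks(5, 4, 4, fwd, rev)
assert merged.sum() == 4 * 16                    # frames 0-3 filled, frame 4 empty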
|
| 927 |
+
|
| 928 |
+
def _run_sam_segmentation(
|
| 929 |
+
self,
|
| 930 |
+
paths: Paths,
|
| 931 |
+
bbox: np.ndarray,
|
| 932 |
+
points: np.ndarray,
|
| 933 |
+
max_dist_idx: int,
|
| 934 |
+
reverse: bool,
|
| 935 |
+
output_bboxes: np.ndarray
|
| 936 |
+
) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
|
| 937 |
+
"""
|
| 938 |
+
Process video segmentation in either forward or reverse temporal direction.
|
| 939 |
+
|
| 940 |
+
Args:
|
| 941 |
+
paths: Paths object for file management
|
| 942 |
+
bbox: Initial bounding box for segmentation
|
| 943 |
+
points: Hand keypoints for segmentation guidance
|
| 944 |
+
max_dist_idx: Index of highest-quality frame for initialization
|
| 945 |
+
reverse: Whether to process in reverse temporal order
|
| 946 |
+
output_bboxes: All bounding boxes for the sequence
|
| 947 |
+
|
| 948 |
+
Returns:
|
| 949 |
+
Tuple of (segmentation_masks, visualization_images)
|
| 950 |
+
"""
|
| 951 |
+
return self.detector_sam.segment_video(
|
| 952 |
+
paths.original_images_folder,
|
| 953 |
+
bbox,
|
| 954 |
+
points,
|
| 955 |
+
[max_dist_idx],
|
| 956 |
+
reverse=reverse,
|
| 957 |
+
output_bboxes=output_bboxes
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
@staticmethod
|
| 961 |
+
def _save_results(
|
| 962 |
+
paths: Paths,
|
| 963 |
+
left_masks: np.ndarray,
|
| 964 |
+
left_sam_imgs: np.ndarray,
|
| 965 |
+
right_masks: np.ndarray,
|
| 966 |
+
right_sam_imgs: np.ndarray,
|
| 967 |
+
fps: int = DEFAULT_FPS
|
| 968 |
+
) -> None:
|
| 969 |
+
"""
|
| 970 |
+
Save hand segmentation results to disk.
|
| 971 |
+
|
| 972 |
+
Args:
|
| 973 |
+
paths: Paths object containing output file locations
|
| 974 |
+
left_masks: Left hand segmentation masks
|
| 975 |
+
left_sam_imgs: Left hand SAM visualization images
|
| 976 |
+
right_masks: Right hand segmentation masks
|
| 977 |
+
right_sam_imgs: Right hand SAM visualization images
|
| 978 |
+
fps: Frames per second for output videos (default: 10)
|
| 979 |
+
"""
|
| 980 |
+
HandSegmentationProcessor._create_output_directory(paths)
|
| 981 |
+
|
| 982 |
+
try:
|
| 983 |
+
HandSegmentationProcessor._save_hand_mask_data(paths, left_masks, right_masks)
|
| 984 |
+
HandSegmentationProcessor._create_hand_videos(paths, left_masks, left_sam_imgs, right_masks, right_sam_imgs, fps)
|
| 985 |
+
except Exception as e:
|
| 986 |
+
logging.error(f"Error saving results: {str(e)}")
|
| 987 |
+
raise
|
| 988 |
+
|
| 989 |
+
HandSegmentationProcessor._cleanup_temp_files(paths)
|
| 990 |
+
|
| 991 |
+
@staticmethod
|
| 992 |
+
def _create_output_directory(paths: Paths) -> None:
|
| 993 |
+
"""
|
| 994 |
+
Create output directory for segmentation results.
|
| 995 |
+
|
| 996 |
+
Args:
|
| 997 |
+
paths: Paths object containing output directory location
|
| 998 |
+
"""
|
| 999 |
+
if not os.path.exists(paths.segmentation_processor):
|
| 1000 |
+
os.makedirs(paths.segmentation_processor)
|
| 1001 |
+
|
| 1002 |
+
@staticmethod
|
| 1003 |
+
def _save_hand_mask_data(paths: Paths, left_masks: np.ndarray, right_masks: np.ndarray) -> None:
|
| 1004 |
+
"""
|
| 1005 |
+
Save hand mask data to disk.
|
| 1006 |
+
|
| 1007 |
+
Args:
|
| 1008 |
+
paths: Paths object containing output file locations
|
| 1009 |
+
left_masks: Left hand segmentation masks
|
| 1010 |
+
right_masks: Right hand segmentation masks
|
| 1011 |
+
"""
|
| 1012 |
+
np.save(paths.masks_hand_left, left_masks)
|
| 1013 |
+
np.save(paths.masks_hand_right, right_masks)
|
| 1014 |
+
|
| 1015 |
+
@staticmethod
|
| 1016 |
+
def _create_hand_videos(
|
| 1017 |
+
paths: Paths,
|
| 1018 |
+
left_masks: np.ndarray,
|
| 1019 |
+
left_sam_imgs: np.ndarray,
|
| 1020 |
+
right_masks: np.ndarray,
|
| 1021 |
+
right_sam_imgs: np.ndarray,
|
| 1022 |
+
fps: int
|
| 1023 |
+
) -> None:
|
| 1024 |
+
"""
|
| 1025 |
+
Create visualization videos for hand segmentation.
|
| 1026 |
+
|
| 1027 |
+
Args:
|
| 1028 |
+
paths: Paths object containing output file locations
|
| 1029 |
+
left_masks: Left hand segmentation masks
|
| 1030 |
+
left_sam_imgs: Left hand SAM visualization images
|
| 1031 |
+
right_masks: Right hand segmentation masks
|
| 1032 |
+
right_sam_imgs: Right hand SAM visualization images
|
| 1033 |
+
fps: Frames per second for output videos
|
| 1034 |
+
"""
|
| 1035 |
+
for name, data in [
|
| 1036 |
+
("video_masks_hand_left", left_masks),
|
| 1037 |
+
("video_masks_hand_right", right_masks),
|
| 1038 |
+
("video_sam_hand_left", left_sam_imgs),
|
| 1039 |
+
("video_sam_hand_right", right_sam_imgs),
|
| 1040 |
+
]:
|
| 1041 |
+
output_path = getattr(paths, name)
|
| 1042 |
+
media.write_video(output_path, data, fps=fps, codec=DEFAULT_CODEC)
|
| 1043 |
+
|
| 1044 |
+
@staticmethod
|
| 1045 |
+
def _cleanup_temp_files(paths: Paths) -> None:
|
| 1046 |
+
"""
|
| 1047 |
+
Clean up temporary directories created during processing.
|
| 1048 |
+
|
| 1049 |
+
Args:
|
| 1050 |
+
paths: Paths object containing temporary directory locations
|
| 1051 |
+
"""
|
| 1052 |
+
if os.path.exists(paths.original_images_folder):
|
| 1053 |
+
shutil.rmtree(paths.original_images_folder)
|
| 1054 |
+
if os.path.exists(paths.original_images_folder_reverse):
|
| 1055 |
+
shutil.rmtree(paths.original_images_folder_reverse)
|
| 1056 |
+
|
phantom/phantom/processors/smoothing_processor.py
ADDED
|
@@ -0,0 +1,303 @@
| 1 |
+
"""
|
| 2 |
+
Trajectory Smoothing Processor Module
|
| 3 |
+
|
| 4 |
+
This module performs trajectory smoothing for end-effector positions, orientations, and gripper states
|
| 5 |
+
extracted from human demonstrations.
|
| 6 |
+
|
| 7 |
+
Processing Pipeline:
|
| 8 |
+
1. Load processed action data from previous pipeline stages
|
| 9 |
+
2. Apply Gaussian Process smoothing to 3D position trajectories
|
| 10 |
+
3. Apply SLERP-based smoothing to rotation matrix trajectories
|
| 11 |
+
4. Apply Gaussian Process smoothing to gripper distance trajectories
|
| 12 |
+
5. Save smoothed trajectories for robot execution
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
from typing import Optional
|
| 17 |
+
import argparse
|
| 18 |
+
import numpy as np
|
| 19 |
+
import logging
|
| 20 |
+
from sklearn.gaussian_process import GaussianProcessRegressor # type: ignore
|
| 21 |
+
from sklearn.gaussian_process.kernels import RBF, WhiteKernel # type: ignore
|
| 22 |
+
from scipy.spatial.transform import Rotation, Slerp
|
| 23 |
+
|
| 24 |
+
from phantom.processors.base_processor import BaseProcessor
|
| 25 |
+
from phantom.processors.paths import Paths
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
def gaussian_kernel(size: int, sigma: float) -> np.ndarray:
|
| 30 |
+
"""
|
| 31 |
+
Generate a centered Gaussian kernel for local smoothing operations.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
size: Size of the kernel (should be odd for proper centering)
|
| 35 |
+
sigma: Standard deviation of the Gaussian distribution
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
Normalized Gaussian kernel array
|
| 39 |
+
|
| 40 |
+
Raises:
|
| 41 |
+
ValueError: If size is not positive
|
| 42 |
+
"""
|
| 43 |
+
if size <= 0:
|
| 44 |
+
raise ValueError("Kernel size must be positive")
|
| 45 |
+
|
| 46 |
+
x = np.arange(size) - size // 2
|
| 47 |
+
kernel = np.exp(-0.5 * (x / sigma) ** 2)
|
| 48 |
+
return kernel / kernel.sum()
|
| 49 |
+
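# Illustrative usage (not part of the module): the kernel is centered and normalized, so it
# can be used directly as weights for a local weighted average. The function body below
# mirrors gaussian_kernel above so the snippet is self-contained.
import numpy as np

def _gaussian_kernel(size: int, sigma: float) -> np.ndarray:
    x = np.arange(size) - size // 2
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()

k = _gaussian_kernel(9, 2.0)
assert np.isclose(k.sum(), 1.0)    # weights sum to one
assert np.argmax(k) == 4           # peak sits at the center index (size // 2)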
|
| 50 |
+
class SmoothingProcessor(BaseProcessor):
|
| 51 |
+
"""
|
| 52 |
+
This processor takes raw trajectory data extracted from human demonstrations
|
| 53 |
+
and applies smoothing techniques to create executable robot trajectories.
|
| 54 |
+
|
| 55 |
+
Attributes:
|
| 56 |
+
bimanual_setup (str): Configuration mode ("single_arm" or bimanual type)
|
| 57 |
+
target_hand (str): Target hand for single-arm processing ("left" or "right")
|
| 58 |
+
"""
|
| 59 |
+
def __init__(self, args: argparse.Namespace) -> None:
|
| 60 |
+
"""
|
| 61 |
+
Initialize the smoothing processor with configuration parameters.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
args: Command line arguments containing smoothing configuration
|
| 65 |
+
including bimanual setup and target hand specification
|
| 66 |
+
"""
|
| 67 |
+
super().__init__(args)
|
| 68 |
+
|
| 69 |
+
def process_one_demo(self, data_sub_folder: str) -> None:
|
| 70 |
+
"""
|
| 71 |
+
Process and smooth trajectories for a single demonstration.
|
| 72 |
+
|
| 73 |
+
Args:
|
| 74 |
+
data_sub_folder: Path to demonstration data folder containing
|
| 75 |
+
processed action trajectories from previous stages
|
| 76 |
+
"""
|
| 77 |
+
save_folder = self.get_save_folder(data_sub_folder)
|
| 78 |
+
paths = self.get_paths(save_folder)
|
| 79 |
+
|
| 80 |
+
# Handle single-arm processing mode
|
| 81 |
+
if self.bimanual_setup == "single_arm":
|
| 82 |
+
self._process_single_arm_demo(paths)
|
| 83 |
+
else:
|
| 84 |
+
self._process_bimanual_demo(paths)
|
| 85 |
+
|
| 86 |
+
def _process_single_arm_demo(self, paths: Paths) -> None:
|
| 87 |
+
"""
|
| 88 |
+
Process single-arm demonstration data.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
paths: Paths object containing file locations
|
| 92 |
+
"""
|
| 93 |
+
# Load action data for target hand
|
| 94 |
+
actions_path = self._get_actions_path(paths)
|
| 95 |
+
actions = np.load(actions_path, allow_pickle=True)
|
| 96 |
+
|
| 97 |
+
# Apply smoothing to each trajectory component
|
| 98 |
+
smoothed_ee_pts = self.gaussian_process_smoothing(actions["ee_pts"])
|
| 99 |
+
|
| 100 |
+
# Apply rotation smoothing with configuration-specific parameters
|
| 101 |
+
if self.constrained_hand:
|
| 102 |
+
smoothed_ee_oris = self.gaussian_slerp_smoothing(
|
| 103 |
+
actions["ee_oris"], sigma=10.0, kernel_size=41
|
| 104 |
+
)
|
| 105 |
+
else:
|
| 106 |
+
smoothed_ee_oris = self.gaussian_slerp_smoothing(
|
| 107 |
+
actions["ee_oris"], sigma=10.0
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
smoothed_ee_widths = self.gaussian_process_smoothing(actions["ee_widths"])
|
| 111 |
+
|
| 112 |
+
# Save results based on target hand
|
| 113 |
+
if self.target_hand == "left":
|
| 114 |
+
self._save_results(paths, smoothed_ee_pts_left=smoothed_ee_pts,
|
| 115 |
+
smoothed_ee_oris_left=smoothed_ee_oris,
|
| 116 |
+
smoothed_ee_widths_left=smoothed_ee_widths)
|
| 117 |
+
else:
|
| 118 |
+
self._save_results(paths, smoothed_ee_pts_right=smoothed_ee_pts,
|
| 119 |
+
smoothed_ee_oris_right=smoothed_ee_oris,
|
| 120 |
+
smoothed_ee_widths_right=smoothed_ee_widths)
|
| 121 |
+
|
| 122 |
+
def _process_bimanual_demo(self, paths: Paths) -> None:
|
| 123 |
+
"""
|
| 124 |
+
Process bimanual demonstration data.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
paths: Paths object containing file locations
|
| 128 |
+
"""
|
| 129 |
+
# Load data for both hands
|
| 130 |
+
actions_left_path = str(paths.actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 131 |
+
actions_right_path = str(paths.actions_right).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 132 |
+
actions_left = np.load(actions_left_path, allow_pickle=True)
|
| 133 |
+
actions_right = np.load(actions_right_path, allow_pickle=True)
|
| 134 |
+
|
| 135 |
+
# Apply position smoothing using Gaussian Process regression
|
| 136 |
+
smoothed_ee_pts_left = self.gaussian_process_smoothing(actions_left["ee_pts"])
|
| 137 |
+
smoothed_ee_pts_right = self.gaussian_process_smoothing(actions_right["ee_pts"])
|
| 138 |
+
|
| 139 |
+
# Apply Gaussian-weighted SLERP smoothing with the same kernel parameters for both arms
|
| 140 |
+
smoothed_ee_oris_left = self.gaussian_slerp_smoothing(
|
| 141 |
+
actions_left["ee_oris"], sigma=10.0, kernel_size=21
|
| 142 |
+
)
|
| 143 |
+
smoothed_ee_oris_right = self.gaussian_slerp_smoothing(
|
| 144 |
+
actions_right["ee_oris"], sigma=10.0, kernel_size=21
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Apply gripper distance smoothing
|
| 148 |
+
smoothed_ee_widths_left = self.gaussian_process_smoothing(actions_left["ee_widths"])
|
| 149 |
+
smoothed_ee_widths_right = self.gaussian_process_smoothing(actions_right["ee_widths"])
|
| 150 |
+
|
| 151 |
+
# Save all smoothed trajectories
|
| 152 |
+
self._save_results(paths, smoothed_ee_pts_left, smoothed_ee_oris_left, smoothed_ee_widths_left,
|
| 153 |
+
smoothed_ee_pts_right, smoothed_ee_oris_right, smoothed_ee_widths_right)
|
| 154 |
+
|
| 155 |
+
def _get_actions_path(self, paths: Paths) -> str:
|
| 156 |
+
"""
|
| 157 |
+
Get the appropriate actions file path based on target hand.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
paths: Paths object containing file locations
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Path to the actions file for the target hand
|
| 164 |
+
"""
|
| 165 |
+
if self.target_hand == "left":
|
| 166 |
+
base_path = str(paths.actions_left)
|
| 167 |
+
else:
|
| 168 |
+
base_path = str(paths.actions_right)
|
| 169 |
+
return base_path.split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 170 |
+
|
| 171 |
+
def _save_results(self, paths: Paths, smoothed_ee_pts_left: Optional[np.ndarray] = None,
|
| 172 |
+
smoothed_ee_oris_left: Optional[np.ndarray] = None,
|
| 173 |
+
smoothed_ee_widths_left: Optional[np.ndarray] = None,
|
| 174 |
+
smoothed_ee_pts_right: Optional[np.ndarray] = None,
|
| 175 |
+
smoothed_ee_oris_right: Optional[np.ndarray] = None,
|
| 176 |
+
smoothed_ee_widths_right: Optional[np.ndarray] = None) -> None:
|
| 177 |
+
"""
|
| 178 |
+
Save smoothed trajectory results to disk.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
paths: Paths object containing output file locations
|
| 182 |
+
smoothed_ee_pts_left: Smoothed left hand position trajectory
|
| 183 |
+
smoothed_ee_oris_left: Smoothed left hand orientation trajectory
|
| 184 |
+
smoothed_ee_widths_left: Smoothed left hand gripper trajectory
|
| 185 |
+
smoothed_ee_pts_right: Smoothed right hand position trajectory
|
| 186 |
+
smoothed_ee_oris_right: Smoothed right hand orientation trajectory
|
| 187 |
+
smoothed_ee_widths_right: Smoothed right hand gripper trajectory
|
| 188 |
+
"""
|
| 189 |
+
# Create output directory
|
| 190 |
+
os.makedirs(paths.smoothing_processor, exist_ok=True)
|
| 191 |
+
|
| 192 |
+
# Save left hand trajectories if provided
|
| 193 |
+
if smoothed_ee_pts_left is not None:
|
| 194 |
+
smoothed_actions_left_path = str(paths.smoothed_actions_left).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 195 |
+
np.savez(smoothed_actions_left_path,
|
| 196 |
+
ee_pts=smoothed_ee_pts_left,
|
| 197 |
+
ee_oris=smoothed_ee_oris_left,
|
| 198 |
+
ee_widths=smoothed_ee_widths_left)
|
| 199 |
+
|
| 200 |
+
# Save right hand trajectories if provided
|
| 201 |
+
if smoothed_ee_pts_right is not None:
|
| 202 |
+
smoothed_actions_right_path = str(paths.smoothed_actions_right).split(".npz")[0] + f"_{self.bimanual_setup}.npz"
|
| 203 |
+
np.savez(smoothed_actions_right_path,
|
| 204 |
+
ee_pts=smoothed_ee_pts_right,
|
| 205 |
+
ee_oris=smoothed_ee_oris_right,
|
| 206 |
+
ee_widths=smoothed_ee_widths_right)
|
| 207 |
+
|
| 208 |
+
@staticmethod
|
| 209 |
+
def gaussian_slerp_smoothing(rot_mats: np.ndarray, sigma: float = 2, kernel_size: int = 9) -> np.ndarray:
|
| 210 |
+
"""
|
| 211 |
+
Apply Gaussian-weighted SLERP smoothing to rotation matrices.
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
rot_mats: Array of rotation matrices to smooth, shape (N, 3, 3)
|
| 215 |
+
sigma: Standard deviation for Gaussian kernel
|
| 216 |
+
kernel_size: Size of the smoothing kernel (should be odd)
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
Array of smoothed rotation matrices, shape (N, 3, 3)
|
| 220 |
+
|
| 221 |
+
Raises:
|
| 222 |
+
ValueError: If kernel_size is not odd
|
| 223 |
+
"""
|
| 224 |
+
if kernel_size % 2 != 1:
|
| 225 |
+
raise ValueError("Kernel size must be odd for proper centering")
|
| 226 |
+
|
| 227 |
+
half_k = kernel_size // 2
|
| 228 |
+
N = len(rot_mats)
|
| 229 |
+
|
| 230 |
+
# Step 1: Convert rotation matrices to quaternions for interpolation
|
| 231 |
+
quats = Rotation.from_matrix(rot_mats).as_quat()
|
| 232 |
+
|
| 233 |
+
# Step 2: Apply hemisphere correction to ensure quaternion continuity
|
| 234 |
+
quats_fixed = [quats[0]]
|
| 235 |
+
for i in range(1, N):
|
| 236 |
+
q = quats[i]
|
| 237 |
+
# Choose quaternion hemisphere that minimizes distance to previous quaternion
|
| 238 |
+
if np.dot(q, quats_fixed[-1]) < 0:
|
| 239 |
+
q = -q
|
| 240 |
+
quats_fixed.append(q)
|
| 241 |
+
quats_fixed = np.array(quats_fixed)
|
| 242 |
+
|
| 243 |
+
# Step 3: Prepare normalized Gaussian weights for local smoothing
|
| 244 |
+
weights = gaussian_kernel(kernel_size, sigma)
|
| 245 |
+
|
| 246 |
+
# Step 4: Apply weighted SLERP averaging for each time point
|
| 247 |
+
smoothed_rots = []
|
| 248 |
+
for i in range(N):
|
| 249 |
+
# Define local neighborhood around current time point
|
| 250 |
+
start = max(0, i - half_k)
|
| 251 |
+
end = min(N, i + half_k + 1)
|
| 252 |
+
|
| 253 |
+
# Extract local quaternions and corresponding weights
|
| 254 |
+
local_quats = quats_fixed[start:end]
|
| 255 |
+
local_weights = weights[half_k - (i - start): half_k + (end - i)]
|
| 256 |
+
|
| 257 |
+
# Normalize weights for current neighborhood
|
| 258 |
+
local_weights = local_weights / local_weights.sum()  # out-of-place: the slice above is a view into `weights`
|
| 259 |
+
|
| 260 |
+
# Initialize weighted average with first quaternion
|
| 261 |
+
q_avg = local_quats[0]
|
| 262 |
+
r_avg = Rotation.from_quat(q_avg)
|
| 263 |
+
|
| 264 |
+
# Iteratively apply weighted SLERP interpolation
|
| 265 |
+
for j in range(1, len(local_quats)):
|
| 266 |
+
r_next = Rotation.from_quat(local_quats[j])
|
| 267 |
+
# Use SLERP with weight proportional to current quaternion's contribution
|
| 268 |
+
r_avg = Slerp([0, 1], Rotation.concatenate([r_avg, r_next]))([local_weights[j] / (local_weights[:j+1].sum())])[0]
|
| 269 |
+
|
| 270 |
+
smoothed_rots.append(r_avg.as_matrix())
|
| 271 |
+
|
| 272 |
+
return np.stack(smoothed_rots)
|
| 273 |
+
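# Illustrative usage (not part of the module), assuming the package is importable as `phantom`
# per setup.py (the import path is an assumption): smooth a jittery sequence of z-axis
# rotations and check that the output is still a valid rotation sequence.
import numpy as np
from scipy.spatial.transform import Rotation
from phantom.processors.smoothing_processor import SmoothingProcessor  # assumed import path

angles = np.linspace(0, np.pi / 2, 50) + np.random.normal(0, 0.05, 50)  # noisy z-rotations
rot_mats = Rotation.from_euler("z", angles).as_matrix()                 # shape (50, 3, 3)
smoothed = SmoothingProcessor.gaussian_slerp_smoothing(rot_mats, sigma=2.0, kernel_size=9)
assert smoothed.shape == (50, 3, 3)
assert np.allclose(smoothed[0] @ smoothed[0].T, np.eye(3), atol=1e-6)   # still orthonormal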
|
| 274 |
+
@staticmethod
|
| 275 |
+
def gaussian_process_smoothing(pts: np.ndarray) -> np.ndarray:
|
| 276 |
+
"""
|
| 277 |
+
Apply Gaussian process smoothing to trajectory points.
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
pts: Trajectory points to smooth, shape (N,) for 1D or (N, D) for multi-dimensional
|
| 281 |
+
|
| 282 |
+
Returns:
|
| 283 |
+
Smoothed trajectory points with same shape as input
|
| 284 |
+
|
| 285 |
+
Raises:
|
| 286 |
+
ValueError: If pts is empty
|
| 287 |
+
"""
|
| 288 |
+
if len(pts) == 0:
|
| 289 |
+
raise ValueError("Cannot smooth empty trajectory")
|
| 290 |
+
|
| 291 |
+
# Create time indices as features for GP regression
|
| 292 |
+
time = np.arange(len(pts))[:, None] # Time as a single feature
|
| 293 |
+
|
| 294 |
+
# Configure GP kernel: RBF for smoothness + White noise for robustness
|
| 295 |
+
kernel = RBF(length_scale=1) + WhiteKernel(noise_level=1)
|
| 296 |
+
gpr = GaussianProcessRegressor(kernel=kernel, normalize_y=True)
|
| 297 |
+
|
| 298 |
+
# Handle 1D trajectory case
|
| 299 |
+
if pts.ndim == 1:
|
| 300 |
+
return gpr.fit(time, pts).predict(time)
|
| 301 |
+
|
| 302 |
+
# Handle multi-dimensional trajectory case by processing each dimension independently
|
| 303 |
+
return np.column_stack([gpr.fit(time, pts[:, i]).predict(time) for i in range(pts.shape[1])])
|
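# Illustrative usage (not part of the module): Gaussian-process smoothing of a noisy 3D
# trajectory with the same RBF + WhiteKernel setup as above; the trajectory is synthetic.
import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

t = np.arange(100)[:, None]                                   # time as the single feature
pts = np.column_stack([np.sin(t / 10.0), np.cos(t / 10.0), t / 50.0])
pts_noisy = pts + np.random.normal(0, 0.05, pts.shape)

gpr = GaussianProcessRegressor(kernel=RBF(length_scale=1) + WhiteKernel(noise_level=1),
                               normalize_y=True)
smoothed = np.column_stack([gpr.fit(t, pts_noisy[:, i]).predict(t) for i in range(3)])
assert smoothed.shape == pts_noisy.shape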
phantom/phantom/twin_bimanual_robot.py
ADDED
|
@@ -0,0 +1,597 @@
| 1 |
+
"""
|
| 2 |
+
Virtual twin bimanual robot implementation for MuJoCo simulation.
|
| 3 |
+
|
| 4 |
+
This module provides a TwinBimanualRobot class that creates a virtual representation
|
| 5 |
+
of a bimanual (two-arm) robot system in MuJoCo using the robosuite framework.
|
| 6 |
+
The twin robot can be controlled via end-effector poses or joint positions and
|
| 7 |
+
provides observation data including RGB images, depth maps, and robot masks.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from collections import deque
|
| 11 |
+
import re
|
| 12 |
+
import cv2
|
| 13 |
+
import pdb
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import numpy as np
|
| 16 |
+
from scipy.spatial.transform import Rotation
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Tuple, Union, Any
|
| 19 |
+
|
| 20 |
+
from robosuite.controllers import load_controller_config # type: ignore
|
| 21 |
+
from robosuite.utils.camera_utils import get_real_depth_map # type: ignore
|
| 22 |
+
from robomimic.envs.env_robosuite import EnvRobosuite # type: ignore
|
| 23 |
+
import robomimic.utils.obs_utils as ObsUtils # type: ignore
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass
|
| 27 |
+
class MujocoCameraParams:
|
| 28 |
+
"""
|
| 29 |
+
Camera parameters for MuJoCo simulation.
|
| 30 |
+
|
| 31 |
+
Attributes:
|
| 32 |
+
name: Camera name identifier
|
| 33 |
+
pos: 3D position of camera in world coordinates
|
| 34 |
+
ori_wxyz: Camera orientation as quaternion (w, x, y, z)
|
| 35 |
+
fov: Field of view in degrees
|
| 36 |
+
resolution: Image resolution as (width, height)
|
| 37 |
+
sensorsize: Physical sensor size in mm
|
| 38 |
+
principalpixel: Principal point coordinates in pixels
|
| 39 |
+
focalpixel: Focal length in pixels
|
| 40 |
+
"""
|
| 41 |
+
name: str
|
| 42 |
+
pos: np.ndarray
|
| 43 |
+
ori_wxyz: np.ndarray
|
| 44 |
+
fov: float
|
| 45 |
+
resolution: Tuple[int, int]
|
| 46 |
+
sensorsize: np.ndarray
|
| 47 |
+
principalpixel: np.ndarray
|
| 48 |
+
focalpixel: np.ndarray
|
| 49 |
+
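# Illustrative sketch (not part of the module): constructing camera parameters with
# hypothetical values for a 1080p ZED-style camera. It uses the MujocoCameraParams
# dataclass defined above; every number here is a placeholder, not the project's calibration.
import numpy as np

example_cam = MujocoCameraParams(
    name="zed",
    pos=np.array([0.0, -0.5, 1.2]),
    ori_wxyz=np.array([1.0, 0.0, 0.0, 0.0]),
    fov=60.0,
    resolution=(1920, 1080),
    sensorsize=np.array([0.0056, 0.0031]),      # placeholder sensor size
    principalpixel=np.array([960.0, 540.0]),
    focalpixel=np.array([1000.0, 1000.0]),
)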
|
| 50 |
+
# Color constants for visualization (RGBA format)
|
| 51 |
+
THUMB_COLOR = [0, 1, 0, 1] # Green for thumb
|
| 52 |
+
INDEX_COLOR = [1, 0, 0, 1] # Red for index finger
|
| 53 |
+
HAND_EE_COLOR = [0, 0, 1, 1] # Blue for hand end-effector
|
| 54 |
+
|
| 55 |
+
# Transformation matrix for Epic Kitchen setup - converts from base frame to robot frame
|
| 56 |
+
BASE_T_1 = np.array([[0.0, -1.0, 0.0, 0.0],
|
| 57 |
+
[ 0.5, 0.0, 0.866, 0.2],
|
| 58 |
+
[-0.866, 0.0, 0.5, 1.50],
|
| 59 |
+
[ 0.0, 0.0, 0.0, 1.0]])
|
| 60 |
+
|
| 61 |
+
def convert_real_camera_ori_to_mujoco(camera_ori_matrix: np.ndarray) -> np.ndarray:
|
| 62 |
+
"""
|
| 63 |
+
Convert camera orientation from real world to MuJoCo XML format.
|
| 64 |
+
|
| 65 |
+
MuJoCo uses a different coordinate system convention, so we need to
|
| 66 |
+
flip the Y and Z axes of the rotation matrix before converting to quaternion.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
camera_ori_matrix: 3x3 rotation matrix in real-world coordinates
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
Camera orientation as quaternion in MuJoCo format (w, x, y, z)
|
| 73 |
+
"""
|
| 74 |
+
camera_ori_matrix = camera_ori_matrix.copy()  # work on a copy so the caller's matrix is not mutated
camera_ori_matrix[:, [1, 2]] = -camera_ori_matrix[:, [1, 2]]
|
| 75 |
+
r = Rotation.from_matrix(camera_ori_matrix)
|
| 76 |
+
camera_ori_wxyz = r.as_quat(scalar_first=True)
|
| 77 |
+
return camera_ori_wxyz
|
| 78 |
+
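# Illustrative check (not part of the module): for an identity real-world orientation,
# flipping the Y and Z axes yields a 180-degree rotation about X, i.e. the MuJoCo
# quaternion (w, x, y, z) = (0, 1, 0, 0) up to sign.
import numpy as np
from scipy.spatial.transform import Rotation

m = np.eye(3)
m[:, [1, 2]] = -m[:, [1, 2]]                       # same axis flip as above
q_wxyz = Rotation.from_matrix(m).as_quat(scalar_first=True)
assert np.allclose(np.abs(q_wxyz), [0, 1, 0, 0], atol=1e-9)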
|
| 79 |
+
class TwinBimanualRobot:
|
| 80 |
+
"""
|
| 81 |
+
Virtual twin of a bimanual robot system in MuJoCo simulation.
|
| 82 |
+
|
| 83 |
+
This class creates a simulated bimanual robot that can be controlled via
|
| 84 |
+
end-effector poses or joint positions. It provides functionality for:
|
| 85 |
+
- Robot pose control (OSC or joint-level)
|
| 86 |
+
- Camera observation collection (RGB, depth, segmentation)
|
| 87 |
+
- Robot and gripper mask generation
|
| 88 |
+
- Observation history management
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(self, robot_name: str, gripper_name: str, bimanual_setup: str,
|
| 92 |
+
camera_params: MujocoCameraParams, camera_height: int, camera_width: int,
|
| 93 |
+
render: bool, n_steps_short: int, n_steps_long: int, square: bool = False,
|
| 94 |
+
debug_cameras: list[str] = [], epic: bool = False, joint_controller: bool = False):
|
| 95 |
+
"""
|
| 96 |
+
Initialize the bimanual robot twin.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
robot_name: Type of robot (e.g., "Kinova3")
|
| 100 |
+
gripper_name: Type of gripper (e.g., "Robotiq85")
|
| 101 |
+
bimanual_setup: Configuration for bimanual setup
|
| 102 |
+
camera_params: Camera configuration parameters
|
| 103 |
+
camera_height: Height of camera images in pixels
|
| 104 |
+
camera_width: Width of camera images in pixels
|
| 105 |
+
render: Whether to render the simulation visually
|
| 106 |
+
n_steps_short: Number of simulation steps for quick movements
|
| 107 |
+
n_steps_long: Number of simulation steps for initial/slow movements
|
| 108 |
+
square: Whether to crop images to square aspect ratio
|
| 109 |
+
debug_cameras: Additional camera names for debugging views
|
| 110 |
+
epic: Whether to use Epic Kitchen coordinate system
|
| 111 |
+
joint_controller: Whether to use joint-level control instead of OSC
|
| 112 |
+
"""
|
| 113 |
+
# Store configuration parameters
|
| 114 |
+
self.robot_name = robot_name
|
| 115 |
+
self.gripper_name = gripper_name
|
| 116 |
+
self.bimanual_setup = bimanual_setup
|
| 117 |
+
self.camera_params = camera_params
|
| 118 |
+
self.render = render
|
| 119 |
+
self.n_steps_long = n_steps_long
|
| 120 |
+
self.n_steps_short = n_steps_short
|
| 121 |
+
self.num_frames = 2 # Number of frames to keep in observation history
|
| 122 |
+
self.camera_height = camera_height
|
| 123 |
+
self.camera_width = camera_width
|
| 124 |
+
self.camera_name = "zed" # Main camera name
|
| 125 |
+
self.square = square
|
| 126 |
+
self.debug_cameras = list(debug_cameras) if debug_cameras else []
|
| 127 |
+
self.epic = epic # Epic Kitchen mode flag
|
| 128 |
+
self.joint_controller = joint_controller # Control mode flag
|
| 129 |
+
|
| 130 |
+
# Configure observation specifications for robomimic
|
| 131 |
+
obs_spec = dict(
|
| 132 |
+
obs=dict(
|
| 133 |
+
low_dim=["robot0_eef_pos"], # End-effector position observations
|
| 134 |
+
rgb=[f"{self.camera_params.name}_image"] + [f"{cam}_image" for cam in self.debug_cameras],
|
| 135 |
+
),
|
| 136 |
+
)
|
| 137 |
+
ObsUtils.initialize_obs_utils_with_obs_specs(
|
| 138 |
+
obs_modality_specs=obs_spec)
|
| 139 |
+
|
| 140 |
+
# Configure robosuite environment options
|
| 141 |
+
options: dict[str, Union[str, list[str], dict[str, Any], bool, int, np.ndarray]] = {}
|
| 142 |
+
options["env_name"] = "PhantomBimanual"
|
| 143 |
+
options["bimanual_setup"] = bimanual_setup
|
| 144 |
+
options["robots"] = [self.robot_name, self.robot_name] # Two identical robots
|
| 145 |
+
if self.robot_name == "Kinova3":
|
| 146 |
+
options["gripper_types"] = [f"{self.gripper_name}GripperRealKinova", f"{self.gripper_name}GripperRealKinova"]
|
| 147 |
+
else:
|
| 148 |
+
options["gripper_types"] = [f"{self.gripper_name}Gripper", f"{self.gripper_name}Gripper"]
|
| 149 |
+
|
| 150 |
+
# Configure controller (OSC pose control by default)
|
| 151 |
+
controller_config = load_controller_config(default_controller="OSC_POSE")
|
| 152 |
+
controller_config["control_delta"] = False # Use absolute positioning
|
| 153 |
+
controller_config["uncouple_pos_ori"] = False # Couple position and orientation
|
| 154 |
+
options["controller_configs"] = controller_config
|
| 155 |
+
|
| 156 |
+
# Override with joint controller if specified
|
| 157 |
+
if self.joint_controller:
|
| 158 |
+
controller_config = load_controller_config(default_controller="JOINT_POSITION")
|
| 159 |
+
controller_config["input_type"] = "absolute"
|
| 160 |
+
controller_config["input_max"] = 10
|
| 161 |
+
controller_config["input_min"] = -10
|
| 162 |
+
controller_config["output_max"] = 10
|
| 163 |
+
controller_config["output_min"] = -10
|
| 164 |
+
controller_config["kd"] = 200 # Derivative gain
|
| 165 |
+
controller_config["kv"] = 200 # Velocity gain
|
| 166 |
+
controller_config["kp"] = 1000 # Proportional gain
|
| 167 |
+
controller_config["kp_limits"] = [0, 1000] # Proportional gain limits
|
| 168 |
+
options["controller_configs"] = controller_config
|
| 169 |
+
|
| 170 |
+
# Camera and observation settings
|
| 171 |
+
options["camera_heights"] = self.camera_height
|
| 172 |
+
options["camera_widths"] = self.camera_width
|
| 173 |
+
options["camera_segmentations"] = "instance" # Instance segmentation masks
|
| 174 |
+
options["direct_gripper_control"] = True
|
| 175 |
+
options["use_depth_obs"] = True
|
| 176 |
+
|
| 177 |
+
# Apply Epic Kitchen coordinate transformation if enabled
|
| 178 |
+
if self.epic:
|
| 179 |
+
self.base_T_1 = BASE_T_1
|
| 180 |
+
# Transform camera position and orientation to Epic Kitchen frame
|
| 181 |
+
self.camera_params.pos = self.base_T_1[:3, :3] @ self.camera_params.pos + self.base_T_1[:3, 3]
|
| 182 |
+
camera_ori_matrix = self.base_T_1[:3, :3] @ Rotation.from_quat(self.camera_params.ori_wxyz, scalar_first=True).as_matrix()
|
| 183 |
+
self.camera_params.ori_wxyz = Rotation.from_matrix(camera_ori_matrix).as_quat(scalar_first=True)
|
| 184 |
+
|
| 185 |
+
# Set camera parameters
|
| 186 |
+
options["camera_pos"] = self.camera_params.pos
|
| 187 |
+
options["camera_quat_wxyz"] = self.camera_params.ori_wxyz
|
| 188 |
+
options["camera_sensorsize"] = self.camera_params.sensorsize
|
| 189 |
+
options["camera_principalpixel"] = self.camera_params.principalpixel
|
| 190 |
+
options["camera_focalpixel"] = self.camera_params.focalpixel
|
| 191 |
+
|
| 192 |
+
# Create the robosuite environment
|
| 193 |
+
self.env = EnvRobosuite(
|
| 194 |
+
**options,
|
| 195 |
+
render=render,
|
| 196 |
+
render_offscreen=True, # Enable offscreen rendering for image capture
|
| 197 |
+
use_image_obs=True,
|
| 198 |
+
camera_names=[self.camera_params.name] + self.debug_cameras,
|
| 199 |
+
control_freq=20, # 20 Hz control frequency
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
# Initialize environment and compute robot base position
|
| 203 |
+
self.reset()
|
| 204 |
+
self.robot_base_pos = np.array([0, 0, self.env.env.robot_base_height+self.env.env.robot_base_offset])
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def reset(self):
|
| 208 |
+
"""Reset environment and clear observation history."""
|
| 209 |
+
self.env.reset()
|
| 210 |
+
self.obs_history = deque()
|
| 211 |
+
|
| 212 |
+
def close(self):
|
| 213 |
+
"""Close the simulation environment."""
|
| 214 |
+
self.env.env.close()
|
| 215 |
+
|
| 216 |
+
def get_action_from_ee_pose(self, ee_pos: np.ndarray, ee_quat_xyzw: np.ndarray, gripper_action: float,
|
| 217 |
+
use_base_offset: bool = False) -> np.ndarray:
|
| 218 |
+
"""
|
| 219 |
+
Convert end-effector pose to robot action vector.
|
| 220 |
+
|
| 221 |
+
This method transforms the desired end-effector position and orientation
|
| 222 |
+
into the action format expected by the robot controller.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
ee_pos: End-effector position as 3D array
|
| 226 |
+
ee_quat_xyzw: End-effector orientation as quaternion (x, y, z, w)
|
| 227 |
+
gripper_action: Gripper action value
|
| 228 |
+
use_base_offset: Whether to add robot base offset to position
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
Action vector [position(3), rotation(3), gripper(1)]
|
| 232 |
+
"""
|
| 233 |
+
# Handle batch inputs by taking the last element
|
| 234 |
+
if ee_pos.ndim > 1:
|
| 235 |
+
ee_pos = ee_pos[-1]
|
| 236 |
+
ee_quat_xyzw = ee_quat_xyzw[-1]
|
| 237 |
+
|
| 238 |
+
# Add base offset if requested and not in Epic mode
|
| 239 |
+
if use_base_offset and not self.epic:
|
| 240 |
+
ee_pos = ee_pos + self.robot_base_pos
|
| 241 |
+
|
| 242 |
+
# Apply coordinate transformations based on mode
|
| 243 |
+
if self.epic:
|
| 244 |
+
# Transform position and orientation to Epic Kitchen coordinate frame
|
| 245 |
+
ee_pos = self.base_T_1[:3, 3] + self.base_T_1[:3, :3] @ ee_pos
|
| 246 |
+
axis_angle = Rotation.from_matrix(self.base_T_1[:3, :3] @ Rotation.from_quat(ee_quat_xyzw).as_matrix()).as_rotvec()
|
| 247 |
+
else:
|
| 248 |
+
# Apply 135-degree Z rotation for standard setup
|
| 249 |
+
rot = Rotation.from_quat(ee_quat_xyzw)
|
| 250 |
+
rot_135deg = Rotation.from_euler('z', 135, degrees=True)
|
| 251 |
+
new_rot = rot * rot_135deg
|
| 252 |
+
axis_angle = new_rot.as_rotvec()
|
| 253 |
+
|
| 254 |
+
# Combine into action vector
|
| 255 |
+
action = np.concatenate([ee_pos, axis_angle, [gripper_action]])
|
| 256 |
+
|
| 257 |
+
return action
|
| 258 |
+
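# Illustrative sketch (not part of the module): the per-arm action layout produced above is
# [x, y, z, rx, ry, rz, gripper], where (rx, ry, rz) is an axis-angle rotation vector.
# This ignores the mode-specific rotation handling above; the pose values are placeholders,
# and 255 follows the 0=open / 255=closed convention used for the Robotiq85 mapping below.
import numpy as np
from scipy.spatial.transform import Rotation

ee_pos = np.array([0.4, 0.0, 0.3])
ee_quat_xyzw = np.array([0.0, 0.0, 0.0, 1.0])            # identity orientation
axis_angle = Rotation.from_quat(ee_quat_xyzw).as_rotvec()
action = np.concatenate([ee_pos, axis_angle, [255.0]])   # position(3) + rotation(3) + gripper(1)
assert action.shape == (7,)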
|
| 259 |
+
def _get_initial_obs_history(self, state: dict) -> deque:
|
| 260 |
+
"""
|
| 261 |
+
Initialize observation history by repeating the first observation.
|
| 262 |
+
|
| 263 |
+
This creates a history buffer filled with the initial robot state,
|
| 264 |
+
which is useful for algorithms that require temporal context.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
state: Initial robot state dictionary
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
Deque containing repeated initial observations
|
| 271 |
+
"""
|
| 272 |
+
obs_history = deque(
|
| 273 |
+
[self.move_to_target_state(state, init=True)],
|
| 274 |
+
maxlen=self.num_frames,
|
| 275 |
+
)
|
| 276 |
+
# Fill the remaining slots by re-observing the robot held at the same initial state
|
| 277 |
+
for _ in range(self.num_frames-1):
|
| 278 |
+
obs_history.append(self.move_to_target_state(state))
|
| 279 |
+
return obs_history
|
| 280 |
+
|
| 281 |
+
def get_obs_history(self, state: dict) -> list:
|
| 282 |
+
"""
|
| 283 |
+
Get observation history with specified length.
|
| 284 |
+
|
| 285 |
+
Maintains a rolling buffer of recent observations for temporal context.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
state: Current robot state dictionary
|
| 289 |
+
|
| 290 |
+
Returns:
|
| 291 |
+
List of recent observations (length = self.num_frames)
|
| 292 |
+
"""
|
| 293 |
+
if len(self.obs_history) == 0:
|
| 294 |
+
# Initialize history if empty
|
| 295 |
+
self.obs_history = self._get_initial_obs_history(state)
|
| 296 |
+
else:
|
| 297 |
+
# Add new observation to history
|
| 298 |
+
self.obs_history.append(self.move_to_target_state(state))
|
| 299 |
+
return list(self.obs_history)
|
| 300 |
+
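# Illustrative sketch (not part of the module): the rolling observation buffer above relies
# on deque(maxlen=num_frames), which silently drops the oldest entry when a new one is added.
from collections import deque

num_frames = 2
obs_history = deque(["obs_t0"], maxlen=num_frames)
obs_history.append("obs_t1")
obs_history.append("obs_t2")        # "obs_t0" is evicted automatically
assert list(obs_history) == ["obs_t1", "obs_t2"]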
|
| 301 |
+
def move_to_target_state(self, state: dict, init: bool = False) -> dict:
|
| 302 |
+
"""
|
| 303 |
+
Move robot to target state and collect observation data.
|
| 304 |
+
|
| 305 |
+
This is the main method for controlling the robot and collecting observations.
|
| 306 |
+
It handles both pose and joint control modes, and collects RGB, depth,
|
| 307 |
+
and segmentation data along with tracking errors.
|
| 308 |
+
|
| 309 |
+
Args:
|
| 310 |
+
state: Target state containing positions, orientations, and gripper states
|
| 311 |
+
init: Whether this is an initialization step (uses longer movement time)
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
Dictionary containing observation data:
|
| 315 |
+
- robot_mask: Binary mask showing robot pixels
|
| 316 |
+
- gripper_mask: Binary mask showing gripper pixels
|
| 317 |
+
- rgb_img: RGB camera image
|
| 318 |
+
- depth_img: Depth camera image
|
| 319 |
+
- robot_pos: Robot end-effector position
|
| 320 |
+
- left_pos_err: Left arm position tracking error
|
| 321 |
+
- right_pos_err: Right arm position tracking error
|
| 322 |
+
- {cam}_img: Additional camera images if debug_cameras specified
|
| 323 |
+
"""
|
| 324 |
+
# Convert gripper positions to actions based on controller type
|
| 325 |
+
if not self.joint_controller:
|
| 326 |
+
# Use pose controller with gripper position mapping
|
| 327 |
+
gripper_action_0 = self._convert_handgripper_pos_to_action(state["gripper_pos"][0])
|
| 328 |
+
gripper_action_1 = self._convert_handgripper_pos_to_action(state["gripper_pos"][1])
|
| 329 |
+
gripper_action = [gripper_action_0, gripper_action_1]
|
| 330 |
+
else:
|
| 331 |
+
# Use joint controller with direct gripper control
|
| 332 |
+
gripper_action = [state["gripper_pos"][0]*255, state["gripper_pos"][1]*255]
|
| 333 |
+
|
| 334 |
+
# Choose movement duration based on whether this is initialization
|
| 335 |
+
n_steps = self.n_steps_long if init else self.n_steps_short
|
| 336 |
+
|
| 337 |
+
# Execute movement based on controller type
|
| 338 |
+
if not self.joint_controller:
|
| 339 |
+
# Move using pose control
|
| 340 |
+
obs = self.move_to_pose(state["pos"], state["ori_xyzw"], gripper_action, n_steps)
|
| 341 |
+
else:
|
| 342 |
+
# Move using joint control
|
| 343 |
+
obs = self.move_to_pose(state["pos"], state["ori_xyzw"], gripper_action, n_steps, state["q0"], state["q1"])
|
| 344 |
+
|
| 345 |
+
# Extract observation data from simulation
|
| 346 |
+
robot_mask = np.squeeze(self.get_robot_mask(obs))
|
| 347 |
+
gripper_mask = np.squeeze(self.get_gripper_mask(obs))
|
| 348 |
+
rgb_img = self.get_image(obs)
|
| 349 |
+
depth_img = self.get_depth_image(obs)
|
| 350 |
+
robot_pos = obs["robot0_eef_pos"] - self.robot_base_pos
|
| 351 |
+
|
| 352 |
+
# Calculate end-effector tracking errors for both arms
|
| 353 |
+
if not self.epic:
|
| 354 |
+
# Standard coordinate frame
|
| 355 |
+
right_pos_error = np.linalg.norm(obs['robot0_eef_pos']-self.robot_base_pos - state["pos"][0])
|
| 356 |
+
left_pos_error = np.linalg.norm(obs['robot1_eef_pos']-self.robot_base_pos - state["pos"][1])
|
| 357 |
+
else:
|
| 358 |
+
# Epic Kitchen coordinate frame
|
| 359 |
+
right_pos_error = np.linalg.norm(obs['robot0_eef_pos']-self.base_T_1[:3, 3] - self.base_T_1[:3, :3] @ state["pos"][0])
|
| 360 |
+
left_pos_error = np.linalg.norm(obs['robot1_eef_pos']-self.base_T_1[:3, 3] - self.base_T_1[:3, :3] @ state["pos"][1])
|
| 361 |
+
|
| 362 |
+
# Compile output dictionary
|
| 363 |
+
output = {
|
| 364 |
+
"robot_mask": robot_mask,
|
| 365 |
+
"gripper_mask": gripper_mask,
|
| 366 |
+
"rgb_img": rgb_img,
|
| 367 |
+
"depth_img": depth_img,
|
| 368 |
+
"robot_pos": robot_pos,
|
| 369 |
+
"left_pos_err": left_pos_error,
|
| 370 |
+
"right_pos_err": right_pos_error,
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
# Add debug camera images if specified
|
| 374 |
+
for cam in self.debug_cameras:
|
| 375 |
+
cam_img = self.get_camera_image(obs, cam)
|
| 376 |
+
output[f"{cam}_img"] = cam_img
|
| 377 |
+
|
| 378 |
+
return output
|
| 379 |
+
|
| 380 |
+
def _convert_handgripper_pos_to_action(self, gripper_pos: float) -> np.ndarray:
|
| 381 |
+
"""
|
| 382 |
+
Convert hand gripper position to robot gripper action.
|
| 383 |
+
|
| 384 |
+
Maps from physical gripper opening distance to robot action values.
|
| 385 |
+
Different gripper types may have different mappings.
|
| 386 |
+
|
| 387 |
+
Args:
|
| 388 |
+
gripper_pos: Gripper opening distance in meters
|
| 389 |
+
|
| 390 |
+
Returns:
|
| 391 |
+
Robot gripper action value (0-255 for Robotiq85)
|
| 392 |
+
|
| 393 |
+
Raises:
|
| 394 |
+
ValueError: If gripper type is not supported
|
| 395 |
+
"""
|
| 396 |
+
if self.gripper_name == "Robotiq85":
|
| 397 |
+
# Robotiq85 gripper specifications
|
| 398 |
+
min_gripper_pos, max_gripper_pos = 0.0, 0.085 # 0 to 8.5cm opening
|
| 399 |
+
gripper_pos = np.clip(gripper_pos, min_gripper_pos, max_gripper_pos)
|
| 400 |
+
open_gripper_action, closed_gripper_action = 0, 255 # 0=open, 255=closed
|
| 401 |
+
# Linear interpolation between open and closed states
|
| 402 |
+
return np.interp(gripper_pos, [min_gripper_pos, max_gripper_pos], [closed_gripper_action, open_gripper_action])
|
| 403 |
+
else:
|
| 404 |
+
raise ValueError(f"Gripper name {self.gripper_name} not supported")
|
| 405 |
+
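# Illustrative check (not part of the module): the Robotiq85 mapping above is a linear
# interpolation from opening distance (0 to 0.085 m) to action value (255 = closed, 0 = open).
import numpy as np

def handgripper_pos_to_action(gripper_pos: float) -> float:
    gripper_pos = np.clip(gripper_pos, 0.0, 0.085)
    return float(np.interp(gripper_pos, [0.0, 0.085], [255, 0]))

assert handgripper_pos_to_action(0.0) == 255.0            # fully closed hand -> closed gripper
assert handgripper_pos_to_action(0.085) == 0.0            # fully open hand -> open gripper
assert np.isclose(handgripper_pos_to_action(0.0425), 127.5)  # halfway opening -> mid action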
|
| 406 |
+
def move_to_pose(self, ee_pos: dict, ee_ori: dict, gripper_action: dict, n_steps: int, q0=None, q1=None) -> dict:
|
| 407 |
+
"""
|
| 408 |
+
Execute robot movement to target pose.
|
| 409 |
+
|
| 410 |
+
Sends action commands to the simulation for the specified number of steps.
|
| 411 |
+
Handles both pose control (OSC) and joint control modes.
|
| 412 |
+
|
| 413 |
+
Args:
|
| 414 |
+
ee_pos: End-effector positions for both arms {0: pos0, 1: pos1}
|
| 415 |
+
ee_ori: End-effector orientations for both arms {0: ori0, 1: ori1}
|
| 416 |
+
gripper_action: Gripper actions for both arms {0: grip0, 1: grip1}
|
| 417 |
+
n_steps: Number of simulation steps to execute
|
| 418 |
+
q0: Joint positions for arm 0 (only for joint controller)
|
| 419 |
+
q1: Joint positions for arm 1 (only for joint controller)
|
| 420 |
+
|
| 421 |
+
Returns:
|
| 422 |
+
Final observation dictionary from simulation
|
| 423 |
+
"""
|
| 424 |
+
if not self.joint_controller:
|
| 425 |
+
# Pose control mode: convert poses to actions
|
| 426 |
+
action_0 = self.get_action_from_ee_pose(ee_pos[0], ee_ori[0], gripper_action[0], use_base_offset=True)
|
| 427 |
+
action_1 = self.get_action_from_ee_pose(ee_pos[1], ee_ori[1], gripper_action[1], use_base_offset=True)
|
| 428 |
+
action = np.concatenate([action_0, action_1])
|
| 429 |
+
else:
|
| 430 |
+
# Joint control mode: convert joint angles from degrees to radians
|
| 431 |
+
q0_new = []
|
| 432 |
+
for rot_q in q0:
|
| 433 |
+
if rot_q >= 180:
|
| 434 |
+
q0_new.append((rot_q/180*np.pi-2*np.pi)) # Handle angle wrapping
|
| 435 |
+
else:
|
| 436 |
+
q0_new.append(rot_q/180*np.pi)
|
| 437 |
+
q1_new = []
|
| 438 |
+
for rot_q in q1:
|
| 439 |
+
if rot_q >= 180:
|
| 440 |
+
q1_new.append((rot_q/180*np.pi-2*np.pi)) # Handle angle wrapping
|
| 441 |
+
else:
|
| 442 |
+
q1_new.append(rot_q/180*np.pi)
|
| 443 |
+
|
| 444 |
+
# Combine joint positions and gripper actions
|
| 445 |
+
action_0 = q0_new
|
| 446 |
+
action_1 = q1_new
|
| 447 |
+
action = np.concatenate([action_0, np.array(gripper_action[0]).reshape(1,), action_1, np.array(gripper_action[1]).reshape(1,)])
|
| 448 |
+
|
| 449 |
+
# Execute action for specified number of steps
|
| 450 |
+
for _ in range(n_steps):
|
| 451 |
+
obs, _, _, _ = self.env.step(action)
|
| 452 |
+
if self.render:
|
| 453 |
+
self.env.render()
|
| 454 |
+
return obs
|
| 455 |
+
|
| 456 |
+
def get_proprioception(self, obs: dict) -> np.ndarray:
|
| 457 |
+
"""
|
| 458 |
+
Get proprioceptive information (robot's internal state).
|
| 459 |
+
|
| 460 |
+
Args:
|
| 461 |
+
obs: Observation dictionary from simulation
|
| 462 |
+
|
| 463 |
+
Returns:
|
| 464 |
+
End-effector position of first robot
|
| 465 |
+
"""
|
| 466 |
+
pos = obs["robot0_eef_pos"]
|
| 467 |
+
return pos
|
| 468 |
+
|
| 469 |
+
def get_image(self, obs: dict) -> np.ndarray:
|
| 470 |
+
"""
|
| 471 |
+
Extract RGB image from observation.
|
| 472 |
+
|
| 473 |
+
Handles image format conversion and optional square cropping.
|
| 474 |
+
|
| 475 |
+
Args:
|
| 476 |
+
obs: Observation dictionary containing image data
|
| 477 |
+
|
| 478 |
+
Returns:
|
| 479 |
+
RGB image as numpy array (H, W, 3)
|
| 480 |
+
"""
|
| 481 |
+
img = obs[f"{self.camera_name}_image"]
|
| 482 |
+
img = img.transpose(1, 2, 0) # Convert from CHW to HWC format
|
| 483 |
+
height = img.shape[0]
|
| 484 |
+
width = img.shape[1]
|
| 485 |
+
|
| 486 |
+
# Crop to square if requested
|
| 487 |
+
if self.square:
|
| 488 |
+
n_remove = int((width - height)/2)
|
| 489 |
+
img = img[:,n_remove:-n_remove,:]
|
| 490 |
+
return img
|
| 491 |
+
|
| 492 |
+
def get_camera_image(self, obs: dict, camera_name: str) -> np.ndarray:
|
| 493 |
+
"""
|
| 494 |
+
Extract RGB image from specific camera.
|
| 495 |
+
|
| 496 |
+
Args:
|
| 497 |
+
obs: Observation dictionary containing image data
|
| 498 |
+
camera_name: Name of the camera to extract image from
|
| 499 |
+
|
| 500 |
+
Returns:
|
| 501 |
+
RGB image as numpy array (H, W, 3)
|
| 502 |
+
"""
|
| 503 |
+
img = obs[f"{camera_name}_image"]
|
| 504 |
+
img = img.transpose(1, 2, 0) # Convert from CHW to HWC format
|
| 505 |
+
height = img.shape[0]
|
| 506 |
+
width = img.shape[1]
|
| 507 |
+
|
| 508 |
+
# Crop to square if requested
|
| 509 |
+
if self.square:
|
| 510 |
+
n_remove = int((width - height)/2)
|
| 511 |
+
img = img[:,n_remove:-n_remove,:]
|
| 512 |
+
return img
|
| 513 |
+
|
| 514 |
+
def get_seg_image(self, obs: dict) -> np.ndarray:
|
| 515 |
+
"""
|
| 516 |
+
Extract instance segmentation image.
|
| 517 |
+
|
| 518 |
+
Args:
|
| 519 |
+
obs: Observation dictionary containing segmentation data
|
| 520 |
+
|
| 521 |
+
Returns:
|
| 522 |
+
Segmentation image as uint8 array where each pixel value
|
| 523 |
+
represents a different object instance ID
|
| 524 |
+
"""
|
| 525 |
+
img = obs[f"{self.camera_name}_segmentation_instance"]
|
| 526 |
+
height = img.shape[0]
|
| 527 |
+
width = img.shape[1]
|
| 528 |
+
|
| 529 |
+
# Crop to square if requested
|
| 530 |
+
if self.square:
|
| 531 |
+
n_remove = int((width - height)/2)
|
| 532 |
+
img = img[:,n_remove:-n_remove,:]
|
| 533 |
+
img = img.astype(np.uint8)
|
| 534 |
+
return img
|
| 535 |
+
|
| 536 |
+
def get_depth_image(self, obs: dict) -> np.ndarray:
|
| 537 |
+
"""
|
| 538 |
+
Extract and process depth image.
|
| 539 |
+
|
| 540 |
+
Converts raw depth buffer to real-world depth values using
|
| 541 |
+
robosuite's depth processing utilities.
|
| 542 |
+
|
| 543 |
+
Args:
|
| 544 |
+
obs: Observation dictionary containing depth data
|
| 545 |
+
|
| 546 |
+
Returns:
|
| 547 |
+
Depth image as numpy array where values represent
|
| 548 |
+
distance in meters
|
| 549 |
+
"""
|
| 550 |
+
img = obs[f"{self.camera_name}_depth"]
|
| 551 |
+
img = get_real_depth_map(sim=self.env.env.sim, depth_map=img)
|
| 552 |
+
height = img.shape[0]
|
| 553 |
+
width = img.shape[1]
|
| 554 |
+
|
| 555 |
+
# Crop to square if requested
|
| 556 |
+
if self.square:
|
| 557 |
+
n_remove = int((width - height)/2)
|
| 558 |
+
img = img[:,n_remove:-n_remove,:]
|
| 559 |
+
return img
|
| 560 |
+
|
| 561 |
+
def get_robot_mask(self, obs: dict) -> np.ndarray:
|
| 562 |
+
"""
|
| 563 |
+
Generate binary mask for robot pixels.
|
| 564 |
+
|
| 565 |
+
Uses instance segmentation to identify which pixels belong to
|
| 566 |
+
the robot arms (instance IDs 1 and 4).
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
obs: Observation dictionary containing segmentation data
|
| 570 |
+
|
| 571 |
+
Returns:
|
| 572 |
+
Binary mask where 1 indicates robot pixels, 0 otherwise
|
| 573 |
+
"""
|
| 574 |
+
seg_img = self.get_seg_image(obs)
|
| 575 |
+
mask = np.zeros_like(seg_img)
|
| 576 |
+
mask[seg_img == 1] = 1 # First robot arm
|
| 577 |
+
mask[seg_img == 4] = 1 # Second robot arm
|
| 578 |
+
return mask
|
| 579 |
+
|
| 580 |
+
def get_gripper_mask(self, obs: dict) -> np.ndarray:
|
| 581 |
+
"""
|
| 582 |
+
Generate binary mask for gripper pixels.
|
| 583 |
+
|
| 584 |
+
Uses instance segmentation to identify which pixels belong to
|
| 585 |
+
the robot grippers (instance IDs 3 and 6).
|
| 586 |
+
|
| 587 |
+
Args:
|
| 588 |
+
obs: Observation dictionary containing segmentation data
|
| 589 |
+
|
| 590 |
+
Returns:
|
| 591 |
+
Binary mask where 1 indicates gripper pixels, 0 otherwise
|
| 592 |
+
"""
|
| 593 |
+
seg_img = self.get_seg_image(obs)
|
| 594 |
+
mask = np.zeros_like(seg_img)
|
| 595 |
+
mask[seg_img == 3] = 1 # First gripper
|
| 596 |
+
mask[seg_img == 6] = 1 # Second gripper
|
| 597 |
+
return mask
|
phantom/phantom/twin_robot.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Virtual twin single-arm robot implementation for MuJoCo simulation.
|
| 3 |
+
|
| 4 |
+
This module provides a TwinRobot class that creates a virtual representation
|
| 5 |
+
of a single-arm robot system in MuJoCo using the robosuite framework.
|
| 6 |
+
The twin robot can be controlled via end-effector poses and provides
|
| 7 |
+
observation data including RGB images, depth maps, and robot masks.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from collections import deque
|
| 11 |
+
import cv2
|
| 12 |
+
import numpy as np
|
| 13 |
+
from scipy.spatial.transform import Rotation
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Tuple, Union, Any
|
| 16 |
+
|
| 17 |
+
from robosuite.controllers import load_controller_config # type: ignore
|
| 18 |
+
from robosuite.utils.camera_utils import get_real_depth_map # type: ignore
|
| 19 |
+
from robomimic.envs.env_robosuite import EnvRobosuite # type: ignore
|
| 20 |
+
import robomimic.utils.obs_utils as ObsUtils # type: ignore
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class MujocoCameraParams:
|
| 25 |
+
"""
|
| 26 |
+
Camera parameters for MuJoCo simulation.
|
| 27 |
+
|
| 28 |
+
Attributes:
|
| 29 |
+
name: Camera name identifier
|
| 30 |
+
pos: 3D position of camera in world coordinates
|
| 31 |
+
ori_wxyz: Camera orientation as quaternion (w, x, y, z)
|
| 32 |
+
fov: Field of view in degrees
|
| 33 |
+
resolution: Image resolution as (width, height)
|
| 34 |
+
sensorsize: Physical sensor size in mm
|
| 35 |
+
principalpixel: Principal point coordinates in pixels
|
| 36 |
+
focalpixel: Focal length in pixels
|
| 37 |
+
"""
|
| 38 |
+
name: str
|
| 39 |
+
pos: np.ndarray
|
| 40 |
+
ori_wxyz: np.ndarray
|
| 41 |
+
fov: float
|
| 42 |
+
resolution: Tuple[int, int]
|
| 43 |
+
sensorsize: np.ndarray
|
| 44 |
+
principalpixel: np.ndarray
|
| 45 |
+
focalpixel: np.ndarray
|
| 46 |
+
|
| 47 |
+
# Color constants for visualization (RGBA format)
|
| 48 |
+
THUMB_COLOR = [0, 1, 0, 1] # Green for thumb
|
| 49 |
+
INDEX_COLOR = [1, 0, 0, 1] # Red for index finger
|
| 50 |
+
HAND_EE_COLOR = [0, 0, 1, 1] # Blue for hand end-effector
|
| 51 |
+
|
| 52 |
+
def convert_real_camera_ori_to_mujoco(camera_ori_matrix: np.ndarray) -> np.ndarray:
|
| 53 |
+
"""
|
| 54 |
+
Convert camera orientation from real world to MuJoCo XML format.
|
| 55 |
+
|
| 56 |
+
MuJoCo uses a different coordinate system convention, so we need to
|
| 57 |
+
flip the Y and Z axes of the rotation matrix before converting to quaternion.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
camera_ori_matrix: 3x3 rotation matrix in real-world coordinates
|
| 61 |
+
|
| 62 |
+
Returns:
|
| 63 |
+
Camera orientation as quaternion in MuJoCo format (w, x, y, z)
|
| 64 |
+
"""
|
| 65 |
+
camera_ori_matrix[:, [1, 2]] = -camera_ori_matrix[:, [1, 2]]
|
| 66 |
+
r = Rotation.from_matrix(camera_ori_matrix)
|
| 67 |
+
camera_ori_wxyz = r.as_quat(scalar_first=True)
|
| 68 |
+
return camera_ori_wxyz
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TwinRobot:
|
| 72 |
+
"""
|
| 73 |
+
Virtual twin of a single-arm robot system in MuJoCo simulation.
|
| 74 |
+
|
| 75 |
+
This class creates a simulated single-arm robot that can be controlled via
|
| 76 |
+
end-effector poses. It provides functionality for:
|
| 77 |
+
- Robot pose control using OSC (Operational Space Control)
|
| 78 |
+
- Camera observation collection (RGB, depth, segmentation)
|
| 79 |
+
- Robot and gripper mask generation
|
| 80 |
+
- Observation history management
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
# Robot configuration constants
|
| 84 |
+
DEFAULT_ROBOT_BASE_POS = np.array([-0.56, 0, 0.912])
|
| 85 |
+
|
| 86 |
+
def __init__(self, robot_name: str, gripper_name: str, camera_params: MujocoCameraParams, camera_height: int, camera_width: int,
|
| 87 |
+
render: bool, n_steps_short: int, n_steps_long: int, debug_cameras: list[str] = [],
|
| 88 |
+
square: bool = False):
|
| 89 |
+
"""
|
| 90 |
+
Initialize the single-arm robot twin.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
robot_name: Type of robot (e.g., "Kinova3")
|
| 94 |
+
gripper_name: Type of gripper (e.g., "Robotiq85")
|
| 95 |
+
camera_params: Camera configuration parameters
|
| 96 |
+
camera_height: Height of camera images in pixels
|
| 97 |
+
camera_width: Width of camera images in pixels
|
| 98 |
+
render: Whether to render the simulation visually
|
| 99 |
+
n_steps_short: Number of simulation steps for quick movements
|
| 100 |
+
n_steps_long: Number of simulation steps for initial/slow movements
|
| 101 |
+
debug_cameras: Additional camera names for debugging views
|
| 102 |
+
square: Whether to crop images to square aspect ratio
|
| 103 |
+
"""
|
| 104 |
+
# Store configuration parameters
|
| 105 |
+
self.robot_name = robot_name
|
| 106 |
+
self.gripper_name = gripper_name
|
| 107 |
+
self.camera_params = camera_params
|
| 108 |
+
self.render = render
|
| 109 |
+
self.n_steps_long = n_steps_long
|
| 110 |
+
self.n_steps_short= n_steps_short
|
| 111 |
+
self.num_frames = 2 # Number of frames to keep in observation history
|
| 112 |
+
self.camera_height = camera_height
|
| 113 |
+
self.camera_width = camera_width
|
| 114 |
+
self.camera_name = "frontview" # Main camera name for single-arm setup
|
| 115 |
+
self.square = square
|
| 116 |
+
self.debug_cameras = list(debug_cameras) if debug_cameras else []
|
| 117 |
+
|
| 118 |
+
# Configure observation specifications for robomimic
|
| 119 |
+
obs_spec = dict(
|
| 120 |
+
obs=dict(
|
| 121 |
+
low_dim=["robot0_eef_pos"], # End-effector position observations
|
| 122 |
+
rgb=[f"{self.camera_params.name}_image"] + [f"{cam}_image" for cam in self.debug_cameras],
|
| 123 |
+
),
|
| 124 |
+
)
|
| 125 |
+
ObsUtils.initialize_obs_utils_with_obs_specs(
|
| 126 |
+
obs_modality_specs=obs_spec)
|
| 127 |
+
|
| 128 |
+
# Configure robosuite environment options
|
| 129 |
+
options: dict[str, Union[str, list[str], dict[str, Any], bool, int, np.ndarray]] = {}
|
| 130 |
+
options["env_name"] = "Phantom" # Single-arm environment
|
| 131 |
+
options["robots"] = [self.robot_name] # Single robot
|
| 132 |
+
options["gripper_types"] = [f"{self.gripper_name}Gripper"] # Single gripper
|
| 133 |
+
|
| 134 |
+
# Configure OSC pose controller
|
| 135 |
+
controller_config = load_controller_config(default_controller="OSC_POSE")
|
| 136 |
+
controller_config["control_delta"] = False # Use absolute positioning
|
| 137 |
+
controller_config["uncouple_pos_ori"] = False # Couple position and orientation
|
| 138 |
+
options["controller_configs"] = controller_config
|
| 139 |
+
|
| 140 |
+
# Camera and observation settings
|
| 141 |
+
options["camera_heights"] = self.camera_height
|
| 142 |
+
options["camera_widths"] = self.camera_width
|
| 143 |
+
options["camera_segmentations"] = "instance" # Instance segmentation masks
|
| 144 |
+
options["direct_gripper_control"] = True
|
| 145 |
+
options["use_depth_obs"] = True
|
| 146 |
+
|
| 147 |
+
# Set camera parameters
|
| 148 |
+
options["camera_pos"] = self.camera_params.pos
|
| 149 |
+
options["camera_quat_wxyz"] = self.camera_params.ori_wxyz
|
| 150 |
+
options["camera_sensorsize"] = self.camera_params.sensorsize
|
| 151 |
+
options["camera_principalpixel"] = self.camera_params.principalpixel
|
| 152 |
+
options["camera_focalpixel"] = self.camera_params.focalpixel
|
| 153 |
+
|
| 154 |
+
# Create the robosuite environment
|
| 155 |
+
self.env = EnvRobosuite(
|
| 156 |
+
**options,
|
| 157 |
+
render=render,
|
| 158 |
+
render_offscreen=True, # Enable offscreen rendering for image capture
|
| 159 |
+
use_image_obs=True,
|
| 160 |
+
camera_names=[self.camera_params.name] + self.debug_cameras,
|
| 161 |
+
control_freq=20, # 20 Hz control frequency
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# Initialize environment and set robot base position
|
| 165 |
+
self.reset()
|
| 166 |
+
self.robot_base_pos = self.DEFAULT_ROBOT_BASE_POS # Fixed base position for single-arm setup
|
| 167 |
+
|
| 168 |
+
def reset(self):
|
| 169 |
+
"""Reset environment and clear observation history."""
|
| 170 |
+
self.env.reset()
|
| 171 |
+
self.obs_history = deque()
|
| 172 |
+
|
| 173 |
+
def close(self):
|
| 174 |
+
"""Close the simulation environment."""
|
| 175 |
+
self.env.env.close()
|
| 176 |
+
|
| 177 |
+
def get_action_from_ee_pose(self, ee_pos: np.ndarray, ee_quat_xyzw: np.ndarray, gripper_action: float,
|
| 178 |
+
use_base_offset: bool = False) -> np.ndarray:
|
| 179 |
+
"""
|
| 180 |
+
Convert end-effector pose to robot action vector.
|
| 181 |
+
|
| 182 |
+
This method transforms the desired end-effector position and orientation
|
| 183 |
+
into the action format expected by the robot controller.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
ee_pos: End-effector position as 3D array
|
| 187 |
+
ee_quat_xyzw: End-effector orientation as quaternion (x, y, z, w)
|
| 188 |
+
gripper_action: Gripper action value
|
| 189 |
+
use_base_offset: Whether to add robot base offset to position
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
Action vector [position(3), rotation(3), gripper(1)]
|
| 193 |
+
"""
|
| 194 |
+
# Handle batch inputs by taking the last element
|
| 195 |
+
if ee_pos.ndim > 1:
|
| 196 |
+
ee_pos = ee_pos[-1]
|
| 197 |
+
ee_quat_xyzw = ee_quat_xyzw[-1]
|
| 198 |
+
|
| 199 |
+
# Add base offset if requested
|
| 200 |
+
if use_base_offset:
|
| 201 |
+
ee_pos = ee_pos + self.robot_base_pos
|
| 202 |
+
|
| 203 |
+
# Apply -135 degree Z rotation for single-arm setup coordinate conversion
|
| 204 |
+
rot = Rotation.from_quat(ee_quat_xyzw)
|
| 205 |
+
rot_135deg = Rotation.from_euler('z', -135, degrees=True)
|
| 206 |
+
new_rot = rot * rot_135deg
|
| 207 |
+
|
| 208 |
+
# Convert rotation to axis-angle representation
|
| 209 |
+
# Note: commented lines show alternative approach using quaternion directly
|
| 210 |
+
# quat_rotated = rot_rotated135.as_quat()
|
| 211 |
+
# axis_angle = Rotation.from_quat(quat_rotated).as_rotvec()
|
| 212 |
+
axis_angle = new_rot.as_rotvec()
|
| 213 |
+
|
| 214 |
+
# Combine position, rotation, and gripper action into action vector
|
| 215 |
+
action = np.concatenate([ee_pos, axis_angle, [gripper_action]])
|
| 216 |
+
|
| 217 |
+
return action
|
| 218 |
+
|
| 219 |
+
def _get_initial_obs_history(self, state: dict) -> deque:
|
| 220 |
+
"""
|
| 221 |
+
Initialize observation history by repeating the first observation.
|
| 222 |
+
|
| 223 |
+
This creates a history buffer filled with the initial robot state,
|
| 224 |
+
which is useful for algorithms that require temporal context.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
state: Initial robot state dictionary
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
Deque containing repeated initial observations
|
| 231 |
+
"""
|
| 232 |
+
obs_history = deque(
|
| 233 |
+
[self.move_to_target_state(state, init=True)],
|
| 234 |
+
maxlen=self.num_frames,
|
| 235 |
+
)
|
| 236 |
+
# Fill remaining slots with copies of the initial observation
|
| 237 |
+
for _ in range(self.num_frames-1):
|
| 238 |
+
obs_history.append(self.move_to_target_state(state))
|
| 239 |
+
return obs_history
|
| 240 |
+
|
| 241 |
+
def get_obs_history(self, state: dict) -> list:
|
| 242 |
+
"""
|
| 243 |
+
Get observation history with specified length.
|
| 244 |
+
|
| 245 |
+
Maintains a rolling buffer of recent observations for temporal context.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
state: Current robot state dictionary
|
| 249 |
+
|
| 250 |
+
Returns:
|
| 251 |
+
List of recent observations (length = self.num_frames)
|
| 252 |
+
"""
|
| 253 |
+
if len(self.obs_history) == 0:
|
| 254 |
+
# Initialize history if empty
|
| 255 |
+
self.obs_history = self._get_initial_obs_history(state)
|
| 256 |
+
else:
|
| 257 |
+
# Add new observation to history
|
| 258 |
+
self.obs_history.append(self.move_to_target_state(state))
|
| 259 |
+
return list(self.obs_history)
|
| 260 |
+
|
| 261 |
+
def move_to_target_state(self, state: dict, init=False) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
| 262 |
+
"""
|
| 263 |
+
Move robot to target state and collect observation data.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
state: Target state containing position, orientation, and gripper state
|
| 267 |
+
init: Whether this is an initialization step (uses longer movement time)
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
Dictionary containing observation data:
|
| 271 |
+
- robot_mask: Binary mask showing robot pixels
|
| 272 |
+
- gripper_mask: Binary mask showing gripper pixels
|
| 273 |
+
- rgb_img: RGB camera image
|
| 274 |
+
- depth_img: Depth camera image
|
| 275 |
+
- robot_pos: Robot end-effector position relative to base
|
| 276 |
+
- pos_err: Position tracking error magnitude
|
| 277 |
+
- {cam}_img: Additional camera images if debug_cameras specified
|
| 278 |
+
"""
|
| 279 |
+
# Convert gripper position to robot action
|
| 280 |
+
gripper_action = self._convert_handgripper_pos_to_action(state["gripper_pos"])
|
| 281 |
+
|
| 282 |
+
# Choose movement duration based on whether this is initialization
|
| 283 |
+
n_steps = self.n_steps_long if init else self.n_steps_short
|
| 284 |
+
|
| 285 |
+
# Execute movement to target pose
|
| 286 |
+
obs = self.move_to_pose(state["pos"], state["ori_xyzw"], float(gripper_action), n_steps)
|
| 287 |
+
|
| 288 |
+
# Extract observation data from simulation
|
| 289 |
+
robot_mask = np.squeeze(self.get_robot_mask(obs))
|
| 290 |
+
gripper_mask = np.squeeze(self.get_gripper_mask(obs))
|
| 291 |
+
rgb_img = self.get_image(obs)
|
| 292 |
+
depth_img = self.get_depth_image(obs)
|
| 293 |
+
robot_pos = obs["robot0_eef_pos"] - self.robot_base_pos
|
| 294 |
+
pos_error = np.linalg.norm(robot_pos - state["pos"])
|
| 295 |
+
|
| 296 |
+
# Compile output dictionary
|
| 297 |
+
output = {
|
| 298 |
+
"robot_mask": robot_mask,
|
| 299 |
+
"gripper_mask": gripper_mask,
|
| 300 |
+
"rgb_img": rgb_img,
|
| 301 |
+
"depth_img": depth_img,
|
| 302 |
+
"robot_pos": robot_pos,
|
| 303 |
+
"pos_err": pos_error,
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
# Add debug camera images if specified
|
| 307 |
+
for cam in self.debug_cameras:
|
| 308 |
+
cam_img = self.get_cam_image(obs, cam)
|
| 309 |
+
output[f"{cam}_img"] = cam_img
|
| 310 |
+
|
| 311 |
+
return output
|
| 312 |
+
|
| 313 |
+
def _convert_handgripper_pos_to_action(self, gripper_pos: float) -> np.ndarray:
|
| 314 |
+
"""
|
| 315 |
+
Convert hand gripper position to robot gripper action.
|
| 316 |
+
|
| 317 |
+
Maps from physical gripper opening distance to robot action values.
|
| 318 |
+
Different gripper types may have different mappings.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
gripper_pos: Gripper opening distance in meters
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
Robot gripper action value (0-255 for Robotiq85)
|
| 325 |
+
|
| 326 |
+
Raises:
|
| 327 |
+
ValueError: If gripper type is not supported
|
| 328 |
+
"""
|
| 329 |
+
if self.gripper_name == "Robotiq85":
|
| 330 |
+
# Robotiq85 gripper specifications
|
| 331 |
+
min_gripper_pos, max_gripper_pos = 0.0, 0.085 # 0 to 8.5cm opening
|
| 332 |
+
gripper_pos = np.clip(gripper_pos, min_gripper_pos, max_gripper_pos)
|
| 333 |
+
open_gripper_action, closed_gripper_action = 0, 255 # 0=open, 255=closed
|
| 334 |
+
# Linear interpolation between open and closed states
|
| 335 |
+
return np.interp(gripper_pos, [min_gripper_pos, max_gripper_pos], [closed_gripper_action, open_gripper_action])
|
| 336 |
+
else:
|
| 337 |
+
raise ValueError(f"Gripper name {self.gripper_name} not supported")
|
| 338 |
+
|
| 339 |
+
def move_to_pose(self, ee_pos: np.ndarray, ee_ori: np.ndarray, gripper_action: float, n_steps: int) -> dict:
|
| 340 |
+
"""
|
| 341 |
+
Execute robot movement to target pose.
|
| 342 |
+
|
| 343 |
+
Sends action commands to the simulation for the specified number of steps.
|
| 344 |
+
|
| 345 |
+
Args:
|
| 346 |
+
ee_pos: End-effector position as 3D array
|
| 347 |
+
ee_ori: End-effector orientation as quaternion (x, y, z, w)
|
| 348 |
+
gripper_action: Gripper action value
|
| 349 |
+
n_steps: Number of simulation steps to execute
|
| 350 |
+
|
| 351 |
+
Returns:
|
| 352 |
+
Final observation dictionary from simulation
|
| 353 |
+
"""
|
| 354 |
+
# Convert pose to action vector
|
| 355 |
+
action = self.get_action_from_ee_pose(ee_pos, ee_ori, gripper_action, use_base_offset=True)
|
| 356 |
+
|
| 357 |
+
# Execute action for specified number of steps
|
| 358 |
+
for _ in range(n_steps):
|
| 359 |
+
obs, _, _, _ = self.env.step(action)
|
| 360 |
+
if self.render:
|
| 361 |
+
self.env.render()
|
| 362 |
+
return obs
|
| 363 |
+
|
| 364 |
+
def get_image(self, obs: dict) -> np.ndarray:
|
| 365 |
+
"""
|
| 366 |
+
Extract RGB image from observation.
|
| 367 |
+
|
| 368 |
+
Handles image format conversion and optional square cropping.
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
obs: Observation dictionary containing image data
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
RGB image as numpy array (H, W, 3)
|
| 375 |
+
"""
|
| 376 |
+
img = obs[f"{self.camera_name}_image"]
|
| 377 |
+
img = img.transpose(1, 2, 0) # Convert from CHW to HWC format
|
| 378 |
+
height = img.shape[0]
|
| 379 |
+
width = img.shape[1]
|
| 380 |
+
|
| 381 |
+
# Crop to square if requested
|
| 382 |
+
if self.square:
|
| 383 |
+
n_remove = int((width - height)/2)
|
| 384 |
+
img = img[:,n_remove:-n_remove,:]
|
| 385 |
+
return img
|
| 386 |
+
|
| 387 |
+
def get_cam_image(self, obs: dict, camera_name: str) -> np.ndarray:
|
| 388 |
+
"""
|
| 389 |
+
Extract RGB image from specific camera.
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
obs: Observation dictionary containing image data
|
| 393 |
+
camera_name: Name of the camera to extract image from
|
| 394 |
+
|
| 395 |
+
Returns:
|
| 396 |
+
RGB image as numpy array (H, W, 3)
|
| 397 |
+
"""
|
| 398 |
+
img = obs[f"{camera_name}_image"]
|
| 399 |
+
img = img.transpose(1, 2, 0) # Convert from CHW to HWC format
|
| 400 |
+
height = img.shape[0]
|
| 401 |
+
width = img.shape[1]
|
| 402 |
+
|
| 403 |
+
# Crop to square if requested
|
| 404 |
+
if self.square:
|
| 405 |
+
n_remove = int((width - height)/2)
|
| 406 |
+
img = img[:,n_remove:-n_remove,:]
|
| 407 |
+
return img
|
| 408 |
+
|
| 409 |
+
def get_seg_image(self, obs: dict) -> np.ndarray:
|
| 410 |
+
"""
|
| 411 |
+
Extract instance segmentation image.
|
| 412 |
+
|
| 413 |
+
Args:
|
| 414 |
+
obs: Observation dictionary containing segmentation data
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
Segmentation image as uint8 array where each pixel value
|
| 418 |
+
represents a different object instance ID
|
| 419 |
+
"""
|
| 420 |
+
img = obs["frontview_segmentation_instance"] # Fixed camera name for single-arm
|
| 421 |
+
height = img.shape[0]
|
| 422 |
+
width = img.shape[1]
|
| 423 |
+
|
| 424 |
+
# Crop to square if requested
|
| 425 |
+
if self.square:
|
| 426 |
+
n_remove = int((width - height)/2)
|
| 427 |
+
img = img[:,n_remove:-n_remove,:]
|
| 428 |
+
img = img.astype(np.uint8)
|
| 429 |
+
return img
|
| 430 |
+
|
| 431 |
+
def get_depth_image(self, obs: dict) -> np.ndarray:
|
| 432 |
+
"""
|
| 433 |
+
Extract and process depth image.
|
| 434 |
+
|
| 435 |
+
Converts raw depth buffer to real-world depth values using
|
| 436 |
+
robosuite's depth processing utilities.
|
| 437 |
+
|
| 438 |
+
Args:
|
| 439 |
+
obs: Observation dictionary containing depth data
|
| 440 |
+
|
| 441 |
+
Returns:
|
| 442 |
+
Depth image as numpy array where values represent
|
| 443 |
+
distance in meters
|
| 444 |
+
"""
|
| 445 |
+
img = obs["frontview_depth"] # Fixed camera name for single-arm
|
| 446 |
+
img = get_real_depth_map(sim=self.env.env.sim, depth_map=img)
|
| 447 |
+
height = img.shape[0]
|
| 448 |
+
width = img.shape[1]
|
| 449 |
+
|
| 450 |
+
# Crop to square if requested
|
| 451 |
+
if self.square:
|
| 452 |
+
n_remove = int((width - height)/2)
|
| 453 |
+
img = img[:,n_remove:-n_remove,:]
|
| 454 |
+
return img
|
| 455 |
+
|
| 456 |
+
def get_robot_mask(self, obs: dict) -> np.ndarray:
|
| 457 |
+
"""
|
| 458 |
+
Generate binary mask for robot pixels.
|
| 459 |
+
|
| 460 |
+
Uses instance segmentation to identify which pixels belong to
|
| 461 |
+
the robot arm (instance ID 1).
|
| 462 |
+
|
| 463 |
+
Args:
|
| 464 |
+
obs: Observation dictionary containing segmentation data
|
| 465 |
+
|
| 466 |
+
Returns:
|
| 467 |
+
Binary mask where 1 indicates robot pixels, 0 otherwise
|
| 468 |
+
"""
|
| 469 |
+
seg_img = self.get_seg_image(obs)
|
| 470 |
+
mask = np.zeros_like(seg_img)
|
| 471 |
+
mask[seg_img == 1] = 1 # Robot arm
|
| 472 |
+
return mask
|
| 473 |
+
|
| 474 |
+
def get_gripper_mask(self, obs: dict) -> np.ndarray:
|
| 475 |
+
"""
|
| 476 |
+
Generate binary mask for gripper pixels.
|
| 477 |
+
|
| 478 |
+
Uses instance segmentation to identify which pixels belong to
|
| 479 |
+
the robot gripper (instance ID 3).
|
| 480 |
+
|
| 481 |
+
Args:
|
| 482 |
+
obs: Observation dictionary containing segmentation data
|
| 483 |
+
|
| 484 |
+
Returns:
|
| 485 |
+
Binary mask where 1 indicates gripper pixels, 0 otherwise
|
| 486 |
+
"""
|
| 487 |
+
seg_img = self.get_seg_image(obs)
|
| 488 |
+
mask = np.zeros_like(seg_img)
|
| 489 |
+
mask[seg_img == 3] = 1 # Gripper
|
| 490 |
+
return mask
|
phantom/phantom/utils/__init__.py
ADDED
|
File without changes
|
phantom/phantom/utils/bbox_utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import numpy.typing as npt
|
| 3 |
+
|
| 4 |
+
def get_bbox_center(bbox: np.ndarray) -> np.ndarray:
|
| 5 |
+
"""Calculate center point of bounding box."""
|
| 6 |
+
return np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2])
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_bbox_area(bbox: np.ndarray) -> float:
|
| 10 |
+
"""Get the area of a bounding box."""
|
| 11 |
+
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_overlap_score(bbox1: np.ndarray, bbox2: np.ndarray) -> float:
|
| 15 |
+
""" Get the overlap area between two boxes divided by the area of the smaller box """
|
| 16 |
+
area1 = get_bbox_area(bbox1)
|
| 17 |
+
area2 = get_bbox_area(bbox2)
|
| 18 |
+
overlap_area = get_overlap_area(bbox1, bbox2)
|
| 19 |
+
return overlap_area / min(area1, area2)
|
| 20 |
+
|
| 21 |
+
def get_overlap_area(bbox1: np.ndarray, bbox2: np.ndarray) -> float:
|
| 22 |
+
""" Get the overlap area between two boxes """
|
| 23 |
+
return max(0, min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0])) * max(0, min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]))
|
| 24 |
+
|
| 25 |
+
def get_bbox_center_min_dist_to_edge(bboxes: npt.NDArray[np.float32], W: int, H: int) -> npt.NDArray[np.float32]:
|
| 26 |
+
"""
|
| 27 |
+
Get the minimum distance of the bbox center to the edge of the image.
|
| 28 |
+
"""
|
| 29 |
+
center_min_dist_to_edge_list = []
|
| 30 |
+
for bbox in bboxes:
|
| 31 |
+
x1, y1, x2, y2 = bbox
|
| 32 |
+
center = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
|
| 33 |
+
min_dist_to_edge = min(center[0], center[1], W - center[0], H - center[1])
|
| 34 |
+
center_min_dist_to_edge_list.append(min_dist_to_edge)
|
| 35 |
+
return np.array(center_min_dist_to_edge_list)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
phantom/phantom/utils/data_utils.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
def get_finger_poses_from_pkl(path: Path) -> dict:
|
| 8 |
+
"""Get human finger poses from pkl file."""
|
| 9 |
+
finger_poses = pd.read_pickle(path)
|
| 10 |
+
thumb_poses = np.vstack(finger_poses["thumb"])
|
| 11 |
+
index_poses = np.vstack(finger_poses["index"])
|
| 12 |
+
hand_ee_poses = np.vstack(finger_poses["hand_ee"])
|
| 13 |
+
skeleton_poses = np.stack(finger_poses["skeleton"], axis=0)
|
| 14 |
+
hand_poses = np.stack(finger_poses["hand_pose"], axis=0)
|
| 15 |
+
all_global_orient = np.vstack(finger_poses["global_orient"])
|
| 16 |
+
data = {
|
| 17 |
+
"thumb": thumb_poses,
|
| 18 |
+
"index": index_poses,
|
| 19 |
+
"hand_ee": hand_ee_poses,
|
| 20 |
+
"skeleton": skeleton_poses,
|
| 21 |
+
"hand_pose": hand_poses,
|
| 22 |
+
"global_orient": all_global_orient
|
| 23 |
+
}
|
| 24 |
+
return data
|
| 25 |
+
|
| 26 |
+
def get_parent_folder_of_package(package_name: str) -> str:
|
| 27 |
+
# Import the package
|
| 28 |
+
package = __import__(package_name)
|
| 29 |
+
|
| 30 |
+
# Get the absolute path of the imported package
|
| 31 |
+
package_path = package.__file__
|
| 32 |
+
if package_path is None:
|
| 33 |
+
raise ValueError(f"Package {package_name} does not have a valid __file__ attribute")
|
| 34 |
+
package_path = os.path.abspath(package_path)
|
| 35 |
+
|
| 36 |
+
# Get the parent directory of the package directory
|
| 37 |
+
return os.path.dirname(os.path.dirname(package_path))
|
| 38 |
+
|
phantom/phantom/utils/image_utils.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import os
|
| 5 |
+
import mediapy as media
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Dict, List, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
@dataclass
|
| 10 |
+
class BoundingBox:
|
| 11 |
+
xmin: int
|
| 12 |
+
ymin: int
|
| 13 |
+
xmax: int
|
| 14 |
+
ymax: int
|
| 15 |
+
|
| 16 |
+
@property
|
| 17 |
+
def xyxy(self) -> List[float]:
|
| 18 |
+
return [self.xmin, self.ymin, self.xmax, self.ymax]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class DetectionResult:
|
| 23 |
+
score: float
|
| 24 |
+
label: str
|
| 25 |
+
box: BoundingBox
|
| 26 |
+
mask: Optional[np.ndarray] = None
|
| 27 |
+
|
| 28 |
+
@classmethod
|
| 29 |
+
def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
|
| 30 |
+
return cls(
|
| 31 |
+
score=detection_dict["score"],
|
| 32 |
+
label=detection_dict["label"],
|
| 33 |
+
box=BoundingBox(
|
| 34 |
+
xmin=detection_dict["box"]["xmin"],
|
| 35 |
+
ymin=detection_dict["box"]["ymin"],
|
| 36 |
+
xmax=detection_dict["box"]["xmax"],
|
| 37 |
+
ymax=detection_dict["box"]["ymax"],
|
| 38 |
+
),
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
def get_transformation_matrix_from_extrinsics(camera_extrinsics: List[Dict]) -> np.ndarray:
|
| 42 |
+
"""Get homogeneous transformation matrix from camera extrinsics."""
|
| 43 |
+
cam_base_pos = np.array(camera_extrinsics[0]["camera_base_pos"])
|
| 44 |
+
cam_base_ori = np.array(camera_extrinsics[0]["camera_base_ori"])
|
| 45 |
+
T_cam2robot = np.eye(4)
|
| 46 |
+
T_cam2robot[:3, 3] = cam_base_pos
|
| 47 |
+
T_cam2robot[:3, :3] = np.array(cam_base_ori).reshape(3, 3)
|
| 48 |
+
return T_cam2robot
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_intrinsics_from_json(json_path: str) -> Tuple[np.ndarray, dict]:
|
| 52 |
+
with open(json_path, "r") as f:
|
| 53 |
+
camera_intrinsics = json.load(f)
|
| 54 |
+
|
| 55 |
+
# Get camera matrix
|
| 56 |
+
fx = camera_intrinsics["left"]["fx"]
|
| 57 |
+
fy = camera_intrinsics["left"]["fy"]
|
| 58 |
+
cx = camera_intrinsics["left"]["cx"]
|
| 59 |
+
cy = camera_intrinsics["left"]["cy"]
|
| 60 |
+
v_fov = camera_intrinsics["left"]["v_fov"]
|
| 61 |
+
intrinsics_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
|
| 62 |
+
|
| 63 |
+
intrinsics_dict = {
|
| 64 |
+
"fx": fx,
|
| 65 |
+
"fy": fy,
|
| 66 |
+
"cx": cx,
|
| 67 |
+
"cy": cy,
|
| 68 |
+
"v_fov": v_fov,
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
return intrinsics_matrix, intrinsics_dict
|
| 72 |
+
|
| 73 |
+
def resize_binary_image(image: np.ndarray, new_size: int) -> np.ndarray:
|
| 74 |
+
max_value = np.max(image)
|
| 75 |
+
|
| 76 |
+
# Resize the image
|
| 77 |
+
resized_image = cv2.resize(image, (new_size, new_size), interpolation=cv2.INTER_NEAREST)
|
| 78 |
+
|
| 79 |
+
if max_value == 1:
|
| 80 |
+
_, binary_image = cv2.threshold(resized_image, 0.5, 1, cv2.THRESH_BINARY)
|
| 81 |
+
else:
|
| 82 |
+
_, binary_image = cv2.threshold(resized_image, 127, 255, cv2.THRESH_BINARY)
|
| 83 |
+
|
| 84 |
+
return binary_image
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def convert_video_to_images(video_path: str, save_folder: str, square=False, reverse=False):
|
| 88 |
+
"""Save each frame of video as an image in save_folder."""
|
| 89 |
+
if not os.path.exists(save_folder):
|
| 90 |
+
os.makedirs(save_folder)
|
| 91 |
+
|
| 92 |
+
imgs = np.array(media.read_video(str(video_path)))
|
| 93 |
+
n_imgs = len(imgs)
|
| 94 |
+
if reverse:
|
| 95 |
+
imgs = imgs[::-1]
|
| 96 |
+
for idx in range(n_imgs):
|
| 97 |
+
img = imgs[idx]
|
| 98 |
+
if square:
|
| 99 |
+
delta = (img.shape[1] - img.shape[0]) // 2
|
| 100 |
+
img = img[:, delta:-delta, :]
|
| 101 |
+
media.write_image(f"{save_folder}/{idx:05d}.jpg", img)
|
| 102 |
+
|
| 103 |
+
|
phantom/phantom/utils/pcd_utils.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from typing import Tuple, Optional
|
| 3 |
+
import open3d as o3d # type: ignore
|
| 4 |
+
import trimesh
|
| 5 |
+
from sklearn.neighbors import NearestNeighbors # type: ignore
|
| 6 |
+
|
| 7 |
+
def preprocess_point_cloud(pcd: o3d.geometry.PointCloud,
|
| 8 |
+
voxel_size: float) -> Tuple[o3d.geometry.PointCloud, o3d.pipelines.registration.Feature]:
|
| 9 |
+
"""
|
| 10 |
+
Downsample point cloud to desired voxel resolution and compute FPFH features.
|
| 11 |
+
"""
|
| 12 |
+
pcd_down = pcd.voxel_down_sample(voxel_size)
|
| 13 |
+
radius_normal = voxel_size * 2
|
| 14 |
+
pcd_down.estimate_normals(o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30))
|
| 15 |
+
radius_feature = voxel_size * 5
|
| 16 |
+
pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
|
| 17 |
+
pcd_down, o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100))
|
| 18 |
+
return pcd_down, pcd_fpfh
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def global_registration(source_pcd: o3d.geometry.PointCloud, target_pcd: o3d.geometry.PointCloud,
|
| 22 |
+
voxel_size: float) -> o3d.pipelines.registration.RegistrationResult:
|
| 23 |
+
"""
|
| 24 |
+
Register two point clouds using global registration with RANSAC.
|
| 25 |
+
"""
|
| 26 |
+
source_down, source_fpfh = preprocess_point_cloud(source_pcd, voxel_size)
|
| 27 |
+
target_down, target_fpfh = preprocess_point_cloud(target_pcd, voxel_size)
|
| 28 |
+
|
| 29 |
+
distance_threshold = voxel_size * 1.5
|
| 30 |
+
result_ransac = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
|
| 31 |
+
source_down, target_down, source_fpfh, target_fpfh, True,
|
| 32 |
+
distance_threshold,
|
| 33 |
+
o3d.pipelines.registration.TransformationEstimationPointToPoint(),
|
| 34 |
+
4, # RANSAC iterations
|
| 35 |
+
[o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
|
| 36 |
+
o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)],
|
| 37 |
+
o3d.pipelines.registration.RANSACConvergenceCriteria(4000000, 500))
|
| 38 |
+
|
| 39 |
+
return result_ransac
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def icp_registration(source_pcd: o3d.geometry.PointCloud, target_pcd: o3d.geometry.PointCloud,
|
| 43 |
+
voxel_size: float=0.05, use_global_registration:bool=True,
|
| 44 |
+
init_transform:Optional[np.ndarray]=None) -> Tuple[o3d.geometry.PointCloud, np.ndarray]:
|
| 45 |
+
"""
|
| 46 |
+
Register two point clouds using ICP algorithm.
|
| 47 |
+
"""
|
| 48 |
+
# Optional global registration using RANSAC
|
| 49 |
+
if use_global_registration:
|
| 50 |
+
if init_transform is None:
|
| 51 |
+
result_ransac = global_registration(source_pcd, target_pcd, voxel_size)
|
| 52 |
+
init_transform = result_ransac.transformation
|
| 53 |
+
else:
|
| 54 |
+
init_transform = np.eye(4)
|
| 55 |
+
|
| 56 |
+
# Refine alignment using ICP
|
| 57 |
+
max_correspondence_distance = voxel_size * 5
|
| 58 |
+
result_icp = o3d.pipelines.registration.registration_icp(
|
| 59 |
+
source=source_pcd, target=target_pcd, max_correspondence_distance=max_correspondence_distance,
|
| 60 |
+
init=init_transform,
|
| 61 |
+
estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint())
|
| 62 |
+
|
| 63 |
+
if np.array_equal(init_transform, result_icp.transformation):
|
| 64 |
+
result_ransac = global_registration(source_pcd, target_pcd, voxel_size)
|
| 65 |
+
init_transform = result_ransac.transformation
|
| 66 |
+
result_icp = o3d.pipelines.registration.registration_icp(
|
| 67 |
+
source=source_pcd, target=target_pcd, max_correspondence_distance=max_correspondence_distance,
|
| 68 |
+
init=init_transform,
|
| 69 |
+
estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint())
|
| 70 |
+
|
| 71 |
+
aligned_source_pcd = source_pcd.transform(result_icp.transformation)
|
| 72 |
+
|
| 73 |
+
return aligned_source_pcd, result_icp.transformation
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_visible_points(mesh, origin: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 77 |
+
"""
|
| 78 |
+
Return list of points in mesh that are visible from origin.
|
| 79 |
+
"""
|
| 80 |
+
intersector = trimesh.ray.ray_triangle.RayMeshIntersector(mesh)
|
| 81 |
+
pts = mesh.vertices
|
| 82 |
+
vectors = pts - origin
|
| 83 |
+
directions = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
|
| 84 |
+
visible_triangle_indices = intersector.intersects_first(np.tile(origin, (pts.shape[0], 1)), directions)
|
| 85 |
+
visible_triangles = mesh.faces[visible_triangle_indices]
|
| 86 |
+
visible_vertex_indices = np.unique(visible_triangles)
|
| 87 |
+
visible_points = pts[visible_vertex_indices]
|
| 88 |
+
return np.array(visible_points).astype(np.float32), np.array(visible_vertex_indices)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_pcd_from_points(points: np.ndarray, colors: Optional[np.ndarray]=None) -> o3d.geometry.PointCloud:
|
| 92 |
+
"""
|
| 93 |
+
Convert a list of points to an Open3D point cloud.
|
| 94 |
+
"""
|
| 95 |
+
pcd = o3d.geometry.PointCloud()
|
| 96 |
+
pcd.points = o3d.utility.Vector3dVector(points)
|
| 97 |
+
if colors is not None:
|
| 98 |
+
pcd.colors = o3d.utility.Vector3dVector(colors)
|
| 99 |
+
pcd.remove_non_finite_points()
|
| 100 |
+
return pcd
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def visualize_pcds(list_pcds: list, visible: bool=True) -> np.ndarray:
|
| 104 |
+
"""
|
| 105 |
+
Visualize a list of point clouds.
|
| 106 |
+
"""
|
| 107 |
+
visualization_image = None
|
| 108 |
+
vis = o3d.visualization.Visualizer()
|
| 109 |
+
vis.create_window(visible=visible)
|
| 110 |
+
opt = vis.get_render_option()
|
| 111 |
+
opt.background_color = np.asarray([0.2, 0.2, 0.2])
|
| 112 |
+
for pcd in list_pcds:
|
| 113 |
+
if pcd is not None:
|
| 114 |
+
vis.add_geometry(pcd)
|
| 115 |
+
vis.poll_events()
|
| 116 |
+
vis.update_renderer()
|
| 117 |
+
if not visible:
|
| 118 |
+
visualization_image = vis.capture_screen_float_buffer(do_render=True)
|
| 119 |
+
visualization_image = (255.0 * np.asarray(visualization_image)).astype(np.uint8)
|
| 120 |
+
if visible:
|
| 121 |
+
vis.run()
|
| 122 |
+
vis.destroy_window()
|
| 123 |
+
if visualization_image is None:
|
| 124 |
+
visualization_image = np.array([])
|
| 125 |
+
return visualization_image
|
| 126 |
+
|
| 127 |
+
def radius_outlier_detection(points: np.ndarray, radius: float=5,
|
| 128 |
+
min_neighbors: int=5) -> Tuple[np.ndarray, np.ndarray]:
|
| 129 |
+
"""
|
| 130 |
+
Detect outliers in a point cloud using radius-based outlier detection.
|
| 131 |
+
"""
|
| 132 |
+
# Fit the NearestNeighbors model
|
| 133 |
+
nbrs = NearestNeighbors(radius=radius).fit(points)
|
| 134 |
+
|
| 135 |
+
# Get the number of neighbors for each point within the specified radius
|
| 136 |
+
distances, indices = nbrs.radius_neighbors(points)
|
| 137 |
+
|
| 138 |
+
# Detect points with fewer neighbors than the minimum threshold
|
| 139 |
+
outliers_mask = np.array([len(neigh) < min_neighbors for neigh in indices])
|
| 140 |
+
|
| 141 |
+
outlier_pts = points[outliers_mask]
|
| 142 |
+
|
| 143 |
+
return outliers_mask, outlier_pts
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def remove_outliers(pcd: o3d.geometry.PointCloud, radius: float=5,
|
| 147 |
+
min_neighbors: int=5) -> Tuple[o3d.geometry.PointCloud, np.ndarray]:
|
| 148 |
+
"""
|
| 149 |
+
Remove outliers from a point cloud using radius-based outlier detection.
|
| 150 |
+
"""
|
| 151 |
+
outlier_indices, outlier_pts = radius_outlier_detection(np.asarray(pcd.points),
|
| 152 |
+
radius=radius, min_neighbors=min_neighbors)
|
| 153 |
+
filtered_pts = np.asarray(pcd.points)[~outlier_indices]
|
| 154 |
+
filtered_colors = np.asarray(pcd.colors)[~outlier_indices]
|
| 155 |
+
filtered_pcd = get_pcd_from_points(filtered_pts, colors=filtered_colors)
|
| 156 |
+
return filtered_pcd, outlier_indices
|
| 157 |
+
|
| 158 |
+
def get_3D_points_from_pixels(pixels_2d: np.ndarray, depth_map: np.ndarray, intrinsics: dict) -> np.ndarray:
|
| 159 |
+
"""
|
| 160 |
+
Convert an array of pixel coordinates and depth map to 3D points.
|
| 161 |
+
"""
|
| 162 |
+
px = pixels_2d[:, 0]
|
| 163 |
+
py = pixels_2d[:, 1]
|
| 164 |
+
|
| 165 |
+
x = (px - intrinsics["cx"]) / intrinsics["fx"]
|
| 166 |
+
y = (py - intrinsics["cy"]) / intrinsics["fy"]
|
| 167 |
+
|
| 168 |
+
if len(depth_map.shape) == 3:
|
| 169 |
+
depth_map = depth_map[:, :, 0]
|
| 170 |
+
|
| 171 |
+
depth = depth_map[py, px]
|
| 172 |
+
|
| 173 |
+
X = x * depth
|
| 174 |
+
Y = y * depth
|
| 175 |
+
|
| 176 |
+
points_3d = np.stack((X, Y, depth), axis=1)
|
| 177 |
+
return points_3d
|
| 178 |
+
|
| 179 |
+
def get_point_cloud_of_segmask(mask: np.ndarray, depth_img: np.ndarray, img: np.ndarray,
|
| 180 |
+
intrinsics: dict, visualize: bool=False) -> o3d.geometry.PointCloud:
|
| 181 |
+
"""
|
| 182 |
+
Return the point cloud that corresponds to the segmentation mask in the depth image.
|
| 183 |
+
"""
|
| 184 |
+
idxs_y, idxs_x = mask.nonzero()
|
| 185 |
+
pixels_2d = np.stack((idxs_x, idxs_y), axis=1)
|
| 186 |
+
seg_points = get_3D_points_from_pixels(pixels_2d, depth_img, intrinsics)
|
| 187 |
+
seg_colors = img[idxs_y, idxs_x, :] / 255.0 # Normalize to [0,1] for cv2
|
| 188 |
+
|
| 189 |
+
pcd = get_pcd_from_points(seg_points, colors=seg_colors)
|
| 190 |
+
|
| 191 |
+
if visualize:
|
| 192 |
+
visualize_pcds([pcd])
|
| 193 |
+
|
| 194 |
+
return pcd
|
| 195 |
+
|
| 196 |
+
def get_bbox_of_3d_points(points: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 197 |
+
"""
|
| 198 |
+
Return the bounding box of 3D points.
|
| 199 |
+
"""
|
| 200 |
+
min_xyz = np.min(points, axis=0)
|
| 201 |
+
max_xyz = np.max(points, axis=0)
|
| 202 |
+
return min_xyz, max_xyz
|
| 203 |
+
|
| 204 |
+
def trim_pcd_to_bbox(pcd: o3d.geometry.PointCloud, bbox: Tuple[np.ndarray, np.ndarray]) -> o3d.geometry.PointCloud:
|
| 205 |
+
"""
|
| 206 |
+
Trim a point cloud to the specified bounding box.
|
| 207 |
+
"""
|
| 208 |
+
min_xyz, max_xyz = bbox
|
| 209 |
+
trimmed_pcd = pcd.crop(o3d.geometry.AxisAlignedBoundingBox(min_xyz, max_xyz))
|
| 210 |
+
return trimmed_pcd
|
phantom/phantom/utils/transform_utils.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
EPS = np.finfo(float).eps * 4.0
|
| 5 |
+
|
| 6 |
+
def transform_pts(pts: np.ndarray, T: np.ndarray) -> np.ndarray:
|
| 7 |
+
pts = np.hstack([pts, np.ones((len(pts), 1))])
|
| 8 |
+
pts = np.dot(T, pts.T).T
|
| 9 |
+
return pts[:, :3]
|
| 10 |
+
|
| 11 |
+
def project_point_to_plane(point: np.ndarray, plane_coeffs: np.ndarray) -> np.ndarray:
|
| 12 |
+
"""
|
| 13 |
+
Projects a 3D point onto a plane defined by its coefficients.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
point (array-like): Coordinates of the point to be projected (x0, y0, z0).
|
| 17 |
+
plane_coeffs (array-like): Coefficients of the plane (a, b, c, d) for ax + by + cz + d = 0.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
numpy.ndarray: The projected point's coordinates on the plane.
|
| 21 |
+
"""
|
| 22 |
+
# Convert inputs to numpy arrays
|
| 23 |
+
point = np.array(point)
|
| 24 |
+
plane_coeffs = np.array(plane_coeffs)
|
| 25 |
+
|
| 26 |
+
# Extract the plane normal vector and constant term
|
| 27 |
+
normal = plane_coeffs[:3] # [a, b, c]
|
| 28 |
+
d = plane_coeffs[3]
|
| 29 |
+
|
| 30 |
+
# Normalize the plane normal vector
|
| 31 |
+
normal_magnitude = np.linalg.norm(normal)
|
| 32 |
+
if normal_magnitude == 0:
|
| 33 |
+
raise ValueError("Invalid plane coefficients: normal vector cannot have zero magnitude.")
|
| 34 |
+
normal /= normal_magnitude
|
| 35 |
+
|
| 36 |
+
# Calculate the signed distance from the point to the plane
|
| 37 |
+
distance = np.dot(normal, point) + d / normal_magnitude
|
| 38 |
+
|
| 39 |
+
# Project the point onto the plane
|
| 40 |
+
projected_point = point - distance * normal
|
| 41 |
+
|
| 42 |
+
return projected_point
|
| 43 |
+
|
phantom/setup.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import setuptools
|
| 2 |
+
|
| 3 |
+
setuptools.setup(
|
| 4 |
+
name="phantom",
|
| 5 |
+
version="0.1",
|
| 6 |
+
packages=setuptools.find_packages(exclude=["submodules", "submodules.*"]),
|
| 7 |
+
)
|
phantom/submodules/phantom-E2FGVI/.gitignore
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Customized
|
| 2 |
+
*.pth
|
| 3 |
+
*.pt
|
| 4 |
+
keys.txt
|
| 5 |
+
results/
|
| 6 |
+
.vscode/
|
| 7 |
+
|
| 8 |
+
# Byte-compiled / optimized / DLL files
|
| 9 |
+
__pycache__/
|
| 10 |
+
*.py[cod]
|
| 11 |
+
*$py.class
|
| 12 |
+
|
| 13 |
+
# C extensions
|
| 14 |
+
*.so
|
| 15 |
+
|
| 16 |
+
# Distribution / packaging
|
| 17 |
+
.Python
|
| 18 |
+
build/
|
| 19 |
+
develop-eggs/
|
| 20 |
+
dist/
|
| 21 |
+
downloads/
|
| 22 |
+
eggs/
|
| 23 |
+
.eggs/
|
| 24 |
+
lib/
|
| 25 |
+
lib64/
|
| 26 |
+
parts/
|
| 27 |
+
sdist/
|
| 28 |
+
var/
|
| 29 |
+
wheels/
|
| 30 |
+
pip-wheel-metadata/
|
| 31 |
+
share/python-wheels/
|
| 32 |
+
*.egg-info/
|
| 33 |
+
.installed.cfg
|
| 34 |
+
*.egg
|
| 35 |
+
MANIFEST
|
| 36 |
+
|
| 37 |
+
# PyInstaller
|
| 38 |
+
# Usually these files are written by a python script from a template
|
| 39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 40 |
+
*.manifest
|
| 41 |
+
*.spec
|
| 42 |
+
|
| 43 |
+
# Installer logs
|
| 44 |
+
pip-log.txt
|
| 45 |
+
pip-delete-this-directory.txt
|
| 46 |
+
|
| 47 |
+
# Unit test / coverage reports
|
| 48 |
+
htmlcov/
|
| 49 |
+
.tox/
|
| 50 |
+
.nox/
|
| 51 |
+
.coverage
|
| 52 |
+
.coverage.*
|
| 53 |
+
.cache
|
| 54 |
+
nosetests.xml
|
| 55 |
+
coverage.xml
|
| 56 |
+
*.cover
|
| 57 |
+
*.py,cover
|
| 58 |
+
.hypothesis/
|
| 59 |
+
.pytest_cache/
|
| 60 |
+
|
| 61 |
+
# Translations
|
| 62 |
+
*.mo
|
| 63 |
+
*.pot
|
| 64 |
+
|
| 65 |
+
# Django stuff:
|
| 66 |
+
*.log
|
| 67 |
+
local_settings.py
|
| 68 |
+
db.sqlite3
|
| 69 |
+
db.sqlite3-journal
|
| 70 |
+
|
| 71 |
+
# Flask stuff:
|
| 72 |
+
instance/
|
| 73 |
+
.webassets-cache
|
| 74 |
+
|
| 75 |
+
# Scrapy stuff:
|
| 76 |
+
.scrapy
|
| 77 |
+
|
| 78 |
+
# Sphinx documentation
|
| 79 |
+
docs/_build/
|
| 80 |
+
|
| 81 |
+
# PyBuilder
|
| 82 |
+
target/
|
| 83 |
+
|
| 84 |
+
# Jupyter Notebook
|
| 85 |
+
.ipynb_checkpoints
|
| 86 |
+
|
| 87 |
+
# IPython
|
| 88 |
+
profile_default/
|
| 89 |
+
ipython_config.py
|
| 90 |
+
|
| 91 |
+
# pyenv
|
| 92 |
+
.python-version
|
| 93 |
+
|
| 94 |
+
# pipenv
|
| 95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 98 |
+
# install all needed dependencies.
|
| 99 |
+
#Pipfile.lock
|
| 100 |
+
|
| 101 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 102 |
+
__pypackages__/
|
| 103 |
+
|
| 104 |
+
# Celery stuff
|
| 105 |
+
celerybeat-schedule
|
| 106 |
+
celerybeat.pid
|
| 107 |
+
|
| 108 |
+
# SageMath parsed files
|
| 109 |
+
*.sage.py
|
| 110 |
+
|
| 111 |
+
# Environments
|
| 112 |
+
.env
|
| 113 |
+
.venv
|
| 114 |
+
env/
|
| 115 |
+
venv/
|
| 116 |
+
ENV/
|
| 117 |
+
env.bak/
|
| 118 |
+
venv.bak/
|
| 119 |
+
|
| 120 |
+
# Spyder project settings
|
| 121 |
+
.spyderproject
|
| 122 |
+
.spyproject
|
| 123 |
+
|
| 124 |
+
# Rope project settings
|
| 125 |
+
.ropeproject
|
| 126 |
+
|
| 127 |
+
# mkdocs documentation
|
| 128 |
+
/site
|
| 129 |
+
|
| 130 |
+
# mypy
|
| 131 |
+
.mypy_cache/
|
| 132 |
+
.dmypy.json
|
| 133 |
+
dmypy.json
|
| 134 |
+
|
| 135 |
+
# Pyre type checker
|
| 136 |
+
.pyre/
|
phantom/submodules/phantom-E2FGVI/E2FGVI/__init__.py
ADDED
|
File without changes
|
phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"seed": 2021,
|
| 3 |
+
"save_dir": "release_model/",
|
| 4 |
+
"train_data_loader": {
|
| 5 |
+
"name": "youtube-vos",
|
| 6 |
+
"data_root": "datasets",
|
| 7 |
+
"w": 432,
|
| 8 |
+
"h": 240,
|
| 9 |
+
"num_local_frames": 5,
|
| 10 |
+
"num_ref_frames": 3
|
| 11 |
+
},
|
| 12 |
+
"losses": {
|
| 13 |
+
"hole_weight": 1,
|
| 14 |
+
"valid_weight": 1,
|
| 15 |
+
"flow_weight": 1,
|
| 16 |
+
"adversarial_weight": 0.01,
|
| 17 |
+
"GAN_LOSS": "hinge"
|
| 18 |
+
},
|
| 19 |
+
"model": {
|
| 20 |
+
"net": "e2fgvi",
|
| 21 |
+
"no_dis": 0
|
| 22 |
+
},
|
| 23 |
+
"trainer": {
|
| 24 |
+
"type": "Adam",
|
| 25 |
+
"beta1": 0,
|
| 26 |
+
"beta2": 0.99,
|
| 27 |
+
"lr": 1e-4,
|
| 28 |
+
"batch_size": 8,
|
| 29 |
+
"num_workers": 2,
|
| 30 |
+
"log_freq": 100,
|
| 31 |
+
"save_freq": 5e3,
|
| 32 |
+
"iterations": 50e4,
|
| 33 |
+
"scheduler": {
|
| 34 |
+
"type": "MultiStepLR",
|
| 35 |
+
"milestones": [
|
| 36 |
+
40e4
|
| 37 |
+
],
|
| 38 |
+
"gamma": 0.1
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
}
|
phantom/submodules/phantom-E2FGVI/E2FGVI/configs/train_e2fgvi_hq.json
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"seed": 2021,
|
| 3 |
+
"save_dir": "release_model/",
|
| 4 |
+
"train_data_loader": {
|
| 5 |
+
"name": "youtube-vos",
|
| 6 |
+
"data_root": "datasets",
|
| 7 |
+
"w": 432,
|
| 8 |
+
"h": 240,
|
| 9 |
+
"num_local_frames": 5,
|
| 10 |
+
"num_ref_frames": 3
|
| 11 |
+
},
|
| 12 |
+
"losses": {
|
| 13 |
+
"hole_weight": 1,
|
| 14 |
+
"valid_weight": 1,
|
| 15 |
+
"flow_weight": 1,
|
| 16 |
+
"adversarial_weight": 0.01,
|
| 17 |
+
"GAN_LOSS": "hinge"
|
| 18 |
+
},
|
| 19 |
+
"model": {
|
| 20 |
+
"net": "e2fgvi_hq",
|
| 21 |
+
"no_dis": 0
|
| 22 |
+
},
|
| 23 |
+
"trainer": {
|
| 24 |
+
"type": "Adam",
|
| 25 |
+
"beta1": 0,
|
| 26 |
+
"beta2": 0.99,
|
| 27 |
+
"lr": 1e-4,
|
| 28 |
+
"batch_size": 8,
|
| 29 |
+
"num_workers": 2,
|
| 30 |
+
"log_freq": 100,
|
| 31 |
+
"save_freq": 5e3,
|
| 32 |
+
"iterations": 50e4,
|
| 33 |
+
"scheduler": {
|
| 34 |
+
"type": "MultiStepLR",
|
| 35 |
+
"milestones": [
|
| 36 |
+
40e4
|
| 37 |
+
],
|
| 38 |
+
"gamma": 0.1
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
}
|